From 0a6326e6aaaffec59d8490f89b9bc187e79a686d Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Thu, 27 Feb 2025 19:06:34 -0600 Subject: [PATCH] Infer parallelreduce in bfs --- hercules_ir/src/ir.rs | 17 ++++++-- hercules_opt/src/pred.rs | 63 ++++++++++++++++++++++++++++ hercules_opt/src/schedule.rs | 31 +++++++++++--- juno_samples/rodinia/bfs/src/bfs.jn | 2 +- juno_samples/rodinia/bfs/src/cpu.sch | 3 +- 5 files changed, 106 insertions(+), 10 deletions(-) diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index 5dfe2915..f6aafa35 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -1048,9 +1048,20 @@ impl Constant { } } - /* - * Useful for GVN. - */ + pub fn is_false(&self) -> bool { + match self { + Constant::Boolean(false) => true, + _ => false, + } + } + + pub fn is_true(&self) -> bool { + match self { + Constant::Boolean(true) => true, + _ => false, + } + } + pub fn is_zero(&self) -> bool { match self { Constant::Integer8(0) => true, diff --git a/hercules_opt/src/pred.rs b/hercules_opt/src/pred.rs index ed7c3a85..587c4507 100644 --- a/hercules_opt/src/pred.rs +++ b/hercules_opt/src/pred.rs @@ -136,6 +136,69 @@ pub fn predication(editor: &mut FunctionEditor, typing: &Vec<TypeID>) { bad_branches.insert(branch); } } + + // Do a quick and dirty rewrite to convert select(a, b, false) to a && b and + // select(a, b, true) to a || b. + for id in editor.node_ids() { + let nodes = &editor.func().nodes; + if let Node::Ternary { + op: TernaryOperator::Select, + first, + second, + third, + } = nodes[id.idx()] + { + if let Some(cons) = nodes[second.idx()].try_constant() + && editor.get_constant(cons).is_false() + { + editor.edit(|mut edit| { + let node = edit.add_node(Node::Binary { + op: BinaryOperator::And, + left: first, + right: third, + }); + edit = edit.replace_all_uses(id, node)?; + edit.delete_node(id) + }); + } else if let Some(cons) = nodes[third.idx()].try_constant() + && editor.get_constant(cons).is_false() + { + editor.edit(|mut edit| { + let node = edit.add_node(Node::Binary { + op: BinaryOperator::And, + left: first, + right: second, + }); + edit = edit.replace_all_uses(id, node)?; + edit.delete_node(id) + }); + } else if let Some(cons) = nodes[second.idx()].try_constant() + && editor.get_constant(cons).is_true() + { + editor.edit(|mut edit| { + let node = edit.add_node(Node::Binary { + op: BinaryOperator::Or, + left: first, + right: third, + }); + edit = edit.replace_all_uses(id, node)?; + edit.delete_node(id) + }); + } else if let Some(cons) = nodes[third.idx()].try_constant() + && editor.get_constant(cons).is_true() + { + editor.edit(|mut edit| { + let node = edit.add_node(Node::Binary { + op: BinaryOperator::Or, + left: first, + right: second, + }); + edit = edit.replace_all_uses(id, node)?; + edit.delete_node(id) + }); + } + } + } } /* diff --git a/hercules_opt/src/schedule.rs b/hercules_opt/src/schedule.rs index d7ae4048..9bc7823e 100644 --- a/hercules_opt/src/schedule.rs +++ b/hercules_opt/src/schedule.rs @@ -69,6 +69,26 @@ pub fn infer_parallel_reduce( chain_id = reduct; } + // If the use is a phi that uses the reduce and a write, then we might + // want to parallelize this still. Set the chain ID to the write. + if let Node::Phi { + control: _, + ref data, + } = func.nodes[chain_id.idx()] + && data.len() + == data + .into_iter() + .filter(|phi_use| **phi_use == last_reduce) + .count() + + 1 + { + chain_id = *data + .into_iter() + .filter(|phi_use| **phi_use != last_reduce) + .next() + .unwrap(); + } + // Check for a Write-Reduce tight cycle. if let Node::Write { collect, @@ -130,12 +150,13 @@ pub fn infer_monoid_reduce( reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>, ) { let is_binop_monoid = |op| { - matches!( - op, - BinaryOperator::Add | BinaryOperator::Mul | BinaryOperator::Or | BinaryOperator::And - ) + op == BinaryOperator::Add + || op == BinaryOperator::Mul + || op == BinaryOperator::Or + || op == BinaryOperator::And }; - let is_intrinsic_monoid = |intrinsic| matches!(intrinsic, Intrinsic::Max | Intrinsic::Min); + let is_intrinsic_monoid = + |intrinsic| intrinsic == Intrinsic::Max || intrinsic == Intrinsic::Min; for id in editor.node_ids() { let func = editor.func(); diff --git a/juno_samples/rodinia/bfs/src/bfs.jn b/juno_samples/rodinia/bfs/src/bfs.jn index 51dcd945..ca0f7774 100644 --- a/juno_samples/rodinia/bfs/src/bfs.jn +++ b/juno_samples/rodinia/bfs/src/bfs.jn @@ -43,10 +43,10 @@ fn bfs<n, m: usize>(graph_nodes: Node[n], source: u32, edges: u32[m]) -> i32[n] } @loop2 for i in 0..n { + stop = stop && updated[i]; if updated[i] { mask[i] = true; visited[i] = true; - stop = false; updated[i] = false; } } diff --git a/juno_samples/rodinia/bfs/src/cpu.sch b/juno_samples/rodinia/bfs/src/cpu.sch index 44cfa8ad..ae67fdd9 100644 --- a/juno_samples/rodinia/bfs/src/cpu.sch +++ b/juno_samples/rodinia/bfs/src/cpu.sch @@ -23,7 +23,8 @@ fixpoint { fork-guard-elim(*); } simpl!(*); +predication(*); +simpl!(*); unforkify(*); - gcm(*); -- GitLab