diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index 5dfe2915f5f3e30f56b6665dc27d23cd40cca3d4..f6aafa35bd2c2d94324e63b2b91213ad0c2e9c4f 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -1048,9 +1048,20 @@ impl Constant { } } - /* - * Useful for GVN. - */ + pub fn is_false(&self) -> bool { + match self { + Constant::Boolean(false) => true, + _ => false, + } + } + + pub fn is_true(&self) -> bool { + match self { + Constant::Boolean(true) => true, + _ => false, + } + } + pub fn is_zero(&self) -> bool { match self { Constant::Integer8(0) => true, diff --git a/hercules_opt/src/pred.rs b/hercules_opt/src/pred.rs index ed7c3a855b016608aa194cc9f2cd89f05d836bde..587c4507a60b7827f4ca4e32789547feeafc0bdf 100644 --- a/hercules_opt/src/pred.rs +++ b/hercules_opt/src/pred.rs @@ -136,6 +136,69 @@ pub fn predication(editor: &mut FunctionEditor, typing: &Vec<TypeID>) { bad_branches.insert(branch); } } + + // Do a quick and dirty rewrite to convert select(a, b, false) to a && b and + // select(a, b, true) to a || b. + for id in editor.node_ids() { + let nodes = &editor.func().nodes; + if let Node::Ternary { + op: TernaryOperator::Select, + first, + second, + third, + } = nodes[id.idx()] + { + if let Some(cons) = nodes[second.idx()].try_constant() + && editor.get_constant(cons).is_false() + { + editor.edit(|mut edit| { + let node = edit.add_node(Node::Binary { + op: BinaryOperator::And, + left: first, + right: third, + }); + edit = edit.replace_all_uses(id, node)?; + edit.delete_node(id) + }); + } else if let Some(cons) = nodes[third.idx()].try_constant() + && editor.get_constant(cons).is_false() + { + editor.edit(|mut edit| { + let node = edit.add_node(Node::Binary { + op: BinaryOperator::And, + left: first, + right: second, + }); + edit = edit.replace_all_uses(id, node)?; + edit.delete_node(id) + }); + } else if let Some(cons) = nodes[second.idx()].try_constant() + && editor.get_constant(cons).is_true() + { + editor.edit(|mut edit| { + let node = edit.add_node(Node::Binary { + op: BinaryOperator::Or, + left: first, + right: third, + }); + edit = edit.replace_all_uses(id, node)?; + edit.delete_node(id) + }); + } else if let Some(cons) = nodes[third.idx()].try_constant() + && editor.get_constant(cons).is_true() + { + editor.edit(|mut edit| { + let node = edit.add_node(Node::Binary { + op: BinaryOperator::Or, + left: first, + right: second, + }); + edit = edit.replace_all_uses(id, node)?; + edit.delete_node(id) + }); + } + } + } } /* diff --git a/hercules_opt/src/schedule.rs b/hercules_opt/src/schedule.rs index d7ae40488d75a1da7ef65b8a53a894bc0f62cded..9bc7823ee7f5837cf49387170e548a9174340f42 100644 --- a/hercules_opt/src/schedule.rs +++ b/hercules_opt/src/schedule.rs @@ -69,6 +69,26 @@ pub fn infer_parallel_reduce( chain_id = reduct; } + // If the use is a phi that uses the reduce and a write, then we might + // want to parallelize this still. Set the chain ID to the write. + if let Node::Phi { + control: _, + ref data, + } = func.nodes[chain_id.idx()] + && data.len() + == data + .into_iter() + .filter(|phi_use| **phi_use == last_reduce) + .count() + + 1 + { + chain_id = *data + .into_iter() + .filter(|phi_use| **phi_use != last_reduce) + .next() + .unwrap(); + } + // Check for a Write-Reduce tight cycle. if let Node::Write { collect, @@ -130,12 +150,13 @@ pub fn infer_monoid_reduce( reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>, ) { let is_binop_monoid = |op| { - matches!( - op, - BinaryOperator::Add | BinaryOperator::Mul | BinaryOperator::Or | BinaryOperator::And - ) + op == BinaryOperator::Add + || op == BinaryOperator::Mul + || op == BinaryOperator::Or + || op == BinaryOperator::And }; - let is_intrinsic_monoid = |intrinsic| matches!(intrinsic, Intrinsic::Max | Intrinsic::Min); + let is_intrinsic_monoid = + |intrinsic| intrinsic == Intrinsic::Max || intrinsic == Intrinsic::Min; for id in editor.node_ids() { let func = editor.func(); diff --git a/juno_samples/rodinia/bfs/src/bfs.jn b/juno_samples/rodinia/bfs/src/bfs.jn index 51dcd945429dfde02cb2313afa404e81f8722c84..ca0f77743ce831817fb528f4de029932d30099a0 100644 --- a/juno_samples/rodinia/bfs/src/bfs.jn +++ b/juno_samples/rodinia/bfs/src/bfs.jn @@ -43,10 +43,10 @@ fn bfs<n, m: usize>(graph_nodes: Node[n], source: u32, edges: u32[m]) -> i32[n] } @loop2 for i in 0..n { + stop = stop && updated[i]; if updated[i] { mask[i] = true; visited[i] = true; - stop = false; updated[i] = false; } } diff --git a/juno_samples/rodinia/bfs/src/cpu.sch b/juno_samples/rodinia/bfs/src/cpu.sch index 44cfa8ad0161fac0afbccc2d383637ec8a2f1aa0..ae67fdd987e961a95311a7d3aaa0f94fe31f1687 100644 --- a/juno_samples/rodinia/bfs/src/cpu.sch +++ b/juno_samples/rodinia/bfs/src/cpu.sch @@ -23,7 +23,8 @@ fixpoint { fork-guard-elim(*); } simpl!(*); +predication(*); +simpl!(*); unforkify(*); - gcm(*);