diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index c46e4e985449a3fae8aa3041782b02ab8213c7cb..bebb8c6cd0867f8018a65d421b30938f60da49db 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -1515,6 +1515,10 @@ fn fork_fusion( * Looks for would-be monoid reduces, if not for a gate on the reduction. * Partially predicate the gated reduction to allow for a proper monoid * reduction. + * + * Looks for monoid reduces that occur through a gated write to a single + * location, and lifts them to a proper monoid reduction with a single gated + * write afterwards. */ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) { for id in editor.node_ids() { @@ -1676,6 +1680,121 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) { }); } } + + for id in editor.node_ids() { + // Identify reduce/write/phi cycle through which a sparse AND reduction + // is occurring. + let nodes = &editor.func().nodes; + let Some((join, init, reduct)) = nodes[id.idx()].try_reduce() else { + continue; + }; + let join_pred = nodes[join.idx()].try_join().unwrap(); + let join_succ = editor + .get_users(join) + .filter(|id| nodes[id.idx()].is_control()) + .next() + .unwrap(); + let Some((_, phi_data)) = nodes[reduct.idx()].try_phi() else { + continue; + }; + if phi_data.len() != 2 { + continue; + } + let phi_other_use = if phi_data[0] == id { + phi_data[1] + } else if phi_data[1] == id { + phi_data[0] + } else { + continue; + }; + let Some((collect, data, indices)) = nodes[phi_other_use.idx()].try_write() else { + continue; + }; + if collect != id { + continue; + } + if indices.into_iter().any(|idx| idx.is_position()) { + continue; + } + if !is_false(editor, data) { + continue; + } + let Some(preds) = nodes[join_pred.idx()].try_region() else { + continue; + }; + if preds.len() != 2 { + continue; + } + let Some((if1, _)) = nodes[preds[0].idx()].try_control_proj() else { 
continue; + }; + let Some((if2, sel)) = nodes[preds[1].idx()].try_control_proj() else { + continue; + }; + if if1 != if2 { + continue; + } + let Some((_, mut cond)) = nodes[if1.idx()].try_if() else { + continue; + }; + + // Transform to a monoid reduction and a single gated write. + let negated = phi_other_use == phi_data[sel]; + let indices = indices.to_vec().into_boxed_slice(); + editor.edit(|mut edit| { + let t = edit.add_constant(Constant::Boolean(true)); + let t = edit.add_node(Node::Constant { id: t }); + let f = edit.add_constant(Constant::Boolean(false)); + let f = edit.add_node(Node::Constant { id: f }); + if negated { + cond = edit.add_node(Node::Unary { + op: UnaryOperator::Not, + input: cond, + }); + } + let reduce_id = NodeID::new(edit.num_node_ids()); + let and_id = NodeID::new(edit.num_node_ids() + 1); + edit.add_node(Node::Reduce { + control: join, + init: t, + reduct: and_id, + }); + edit.add_node(Node::Binary { + op: BinaryOperator::And, + left: cond, + right: reduce_id, + }); + + let new_if = edit.add_node(Node::If { + control: join, + cond: reduce_id, + }); + let cpj1 = edit.add_node(Node::ControlProjection { + control: new_if, + selection: 0, + }); + let cpj2 = edit.add_node(Node::ControlProjection { + control: new_if, + selection: 1, + }); + let region = edit.add_node(Node::Region { + preds: Box::new([cpj1, cpj2]), + }); + let write = edit.add_node(Node::Write { + collect: init, + data: f, + indices, + }); + let phi = edit.add_node(Node::Phi { + control: region, + data: Box::new([write, init]), + }); + edit = edit.replace_all_uses_where(id, phi, |other_id| { + *other_id != phi_other_use && *other_id != reduct + })?; + edit.replace_all_uses_where(join, region, |id| *id == join_succ) + }); + } } /* diff --git a/juno_samples/rodinia/bfs/src/cpu.sch b/juno_samples/rodinia/bfs/src/cpu.sch index f564cd36571564dfc352315c848c14c36ad5970f..07006edbd791810fc90e5e756315449687d13d7f 100644 --- a/juno_samples/rodinia/bfs/src/cpu.sch +++ 
b/juno_samples/rodinia/bfs/src/cpu.sch @@ -54,9 +54,13 @@ if !feature("seq") { inline(bfs@cost_init, bfs@loop1, bfs@loop2); init = init_body; } -fork-tile[8, 0, false, true](init, traverse, collect); delete-uncalled(*); const-inline(*); +clean-monoid-reduces(collect); +simpl!(*); + +fork-tile[8, 0, false, true](init, traverse, collect); +clean-monoid-reduces(collect); simpl!(*); fork-split(init, traverse, collect); diff --git a/juno_samples/rodinia/bfs/src/gpu.sch b/juno_samples/rodinia/bfs/src/gpu.sch index 541d15d7a5b90b17a484c98c2ed216c5912bd666..ea81f330072e891df1250a88f54e87e7d73610d0 100644 --- a/juno_samples/rodinia/bfs/src/gpu.sch +++ b/juno_samples/rodinia/bfs/src/gpu.sch @@ -39,6 +39,7 @@ fixpoint { simpl!(collect); fork-tile[1024, 0, false, true](init, traverse, collect); -fork-split(init, traverse, collect); +let out = fork-split(init, traverse, collect); +simpl!(*); gcm(*);