From f092ae383b280a6df9778e62198c4341d2f1e8ad Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Sun, 2 Mar 2025 19:15:45 -0600 Subject: [PATCH] more backprop opt --- hercules_opt/src/fork_transforms.rs | 102 ++++++++++++++++++---- hercules_opt/src/simplify_cfg.rs | 15 +++- juno_samples/rodinia/backprop/src/gpu.sch | 7 +- 3 files changed, 106 insertions(+), 18 deletions(-) diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index 1c220b99..e1598463 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -319,12 +319,12 @@ pub fn fork_fission<'a>( .collect(); let mut created_forks = Vec::new(); - - // This does the reduction fission + + // This does the reduction fission for fork in forks { let join = fork_join_map[&fork.0]; - // FIXME: Don't make multiple forks for reduces that are in cycles with each other. + // FIXME: Don't make multiple forks for reduces that are in cycles with each other. let reduce_partition = default_reduce_partition(editor, fork.0, join); if !editor.func().labels[fork.0.idx()].contains(&fork_label) { @@ -332,14 +332,19 @@ pub fn fork_fission<'a>( } if editor.is_mutable(fork.0) { - created_forks = fork_reduce_fission_helper(editor, fork_join_map, reduce_partition, nodes_in_fork_joins, fork.0); + created_forks = fork_reduce_fission_helper( + editor, + fork_join_map, + reduce_partition, + nodes_in_fork_joins, + fork.0, + ); if created_forks.is_empty() { continue; } else { return created_forks; } } - } created_forks @@ -503,13 +508,17 @@ pub fn fork_reduce_fission_helper<'a>( let mut new_forks = Vec::new(); - let mut new_control_pred: NodeID = editor.get_uses(fork).filter(|n| editor.node(n).is_control()).next().unwrap(); + let mut new_control_pred: NodeID = editor + .get_uses(fork) + .filter(|n| editor.node(n).is_control()) + .next() + .unwrap(); let mut new_fork = NodeID::new(0); let mut new_join = NodeID::new(0); - let subgraph = &nodes_in_fork_joins[&fork]; - + let subgraph = &nodes_in_fork_joins[&fork]; + // Gets everything between fork & join that this reduce needs. (ALL CONTROL) editor.edit(|mut edit| { for reduce in reduce_partition { @@ -522,7 +531,7 @@ pub fn fork_reduce_fission_helper<'a>( new_fork = mapping[&fork]; new_forks.push(new_fork); new_join = mapping[&join]; - + // Atttach new_fork after control_pred let (old_control_pred, _) = edit.get_node(new_fork).try_fork().unwrap().clone(); edit = edit.replace_all_uses_where(old_control_pred, new_control_pred, |usee| { @@ -532,7 +541,7 @@ pub fn fork_reduce_fission_helper<'a>( // Replace uses of reduce edit = edit.replace_all_uses(reduce, mapping[&reduce])?; new_control_pred = new_join; - }; + } // Replace original join w/ new final join edit = edit.replace_all_uses_where(join, new_join, |_| true)?; @@ -1502,6 +1511,10 @@ fn fork_fusion( * element. This aides in parallelizing outer loops. Looks only at reduces with * the monoid reduce schedule, since that indicates a particular structure which * is annoying to check for again. + * + * Looks for would-be monoid reduces, if not for a gate on the reduction. + * Partially predicate the gated reduction to allow for a proper monoid + * reduction. */ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) { for id in editor.node_ids() { @@ -1512,7 +1525,7 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) { let Some((_, init, reduct)) = nodes[id.idx()].try_reduce() else { continue; }; - let out_uses: Vec<_> = editor.get_users(id).filter(|id| *id != reduct).collect(); + let out_users: Vec<_> = editor.get_users(id).filter(|id| *id != reduct).collect(); match nodes[reduct.idx()] { Node::Binary { @@ -1532,7 +1545,7 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) { left: init, right: id, }); - for u in out_uses { + for u in out_users { edit.sub_edit(u, final_op); } edit.replace_all_uses_where(id, final_op, |u| *u != reduct && *u != final_op) @@ -1555,7 +1568,7 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) { left: init, right: id, }); - for u in out_uses { + for u in out_users { edit.sub_edit(u, final_op); } edit.replace_all_uses_where(id, final_op, |u| *u != reduct && *u != final_op) @@ -1574,7 +1587,7 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) { intrinsic: Intrinsic::Max, args: Box::new([init, id]), }); - for u in out_uses { + for u in out_users { edit.sub_edit(u, final_op); } edit.replace_all_uses_where(id, final_op, |u| *u != reduct && *u != final_op) @@ -1593,7 +1606,7 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) { intrinsic: Intrinsic::Min, args: Box::new([init, id]), }); - for u in out_uses { + for u in out_users { edit.sub_edit(u, final_op); } edit.replace_all_uses_where(id, final_op, |u| *u != reduct && *u != final_op) @@ -1602,6 +1615,65 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) { _ => {} } } + + for id in editor.node_ids() { + if !editor.func().schedules[id.idx()].contains(&Schedule::MonoidReduce) { + continue; + } + let nodes = &editor.func().nodes; + let Some((control, init, reduct)) = nodes[id.idx()].try_reduce() else { + continue; + }; + if let Node::Phi { + control: phi_control, + ref data, + } = nodes[reduct.idx()] + && data.len() == 2 + && data.contains(&id) + && let other = *data + .into_iter() + .filter(|other| **other != id) + .next() + .unwrap() + && let Node::Binary { + op: BinaryOperator::Add, + left, + right, + } = nodes[other.idx()] + && ((left == id) ^ (right == id)) + { + let gated_input = if left == id { right } else { left }; + let data = data.clone(); + editor.edit(|mut edit| { + let zero = edit.add_zero_constant(typing[id.idx()]); + let zero = edit.add_node(Node::Constant { id: zero }); + let phi = edit.add_node(Node::Phi { + control: phi_control, + data: data + .iter() + .map(|phi_use| if *phi_use == id { zero } else { gated_input }) + .collect(), + }); + let new_reduce_id = NodeID::new(edit.num_node_ids()); + let new_reduct_id = NodeID::new(edit.num_node_ids() + 1); + let new_reduce = Node::Reduce { + control, + init, + reduct: new_reduct_id, + }; + let new_add = Node::Binary { + op: BinaryOperator::Add, + left: new_reduce_id, + right: phi, + }; + let new_reduce = edit.add_node(new_reduce); + edit.add_node(new_add); + edit = edit.replace_all_uses(id, new_reduce)?; + edit = edit.delete_node(id)?; + Ok(edit) + }); + } + } } /* diff --git a/hercules_opt/src/simplify_cfg.rs b/hercules_opt/src/simplify_cfg.rs index cf39db2b..b13cf0c3 100644 --- a/hercules_opt/src/simplify_cfg.rs +++ b/hercules_opt/src/simplify_cfg.rs @@ -126,11 +126,24 @@ fn remove_useless_fork_joins( // Third, get rid of fork-joins. for (fork, join) in fork_join_map { - if editor.get_users(*fork).len() == 1 && editor.get_users(*join).len() == 1 { + if editor.get_users(*join).len() == 1 { let fork_use = get_uses(&editor.func().nodes[fork.idx()]).as_ref()[0]; let join_use = get_uses(&editor.func().nodes[join.idx()]).as_ref()[0]; + let tids: Vec<_> = editor + .get_users(*fork) + .filter(|id| editor.func().nodes[id.idx()].is_thread_id()) + .collect(); editor.edit(|mut edit| { + if !tids.is_empty() { + let u64_ty = edit.add_type(Type::UnsignedInteger64); + let zero = edit.add_zero_constant(u64_ty); + let zero = edit.add_node(Node::Constant { id: zero }); + for tid in tids { + edit = edit.replace_all_uses(tid, zero)?; + edit = edit.delete_node(tid)?; + } + } edit = edit.replace_all_uses(*join, join_use)?; edit = edit.replace_all_uses(*fork, fork_use)?; edit = edit.delete_node(*fork)?; diff --git a/juno_samples/rodinia/backprop/src/gpu.sch b/juno_samples/rodinia/backprop/src/gpu.sch index d0be79db..f8cc84a3 100644 --- a/juno_samples/rodinia/backprop/src/gpu.sch +++ b/juno_samples/rodinia/backprop/src/gpu.sch @@ -33,7 +33,11 @@ fixpoint { reduce-slf(*); simpl!(*); -fork-tile[16, 0, false, true](layer_forward@inner_loop); +fork-extend[32](layer_forward@inner_loop); +clean-monoid-reduces(layer_forward); +simpl!(layer_forward); +fork-tile[32, 0, false, true](layer_forward@inner_loop); +clean-monoid-reduces(layer_forward); let out = fork-split(layer_forward@inner_loop); clean-monoid-reduces(layer_forward); simpl!(layer_forward); @@ -47,5 +51,4 @@ fork-tile[32, 0, false, true](adjust_weights); fork-split(adjust_weights); simpl!(adjust_weights); -xdot[true](*); gcm(*); -- GitLab