From f092ae383b280a6df9778e62198c4341d2f1e8ad Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sun, 2 Mar 2025 19:15:45 -0600
Subject: [PATCH] more backprop opt

---
 hercules_opt/src/fork_transforms.rs       | 102 ++++++++++++++++++----
 hercules_opt/src/simplify_cfg.rs          |  15 +++-
 juno_samples/rodinia/backprop/src/gpu.sch |   7 +-
 3 files changed, 106 insertions(+), 18 deletions(-)

diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index 1c220b99..e1598463 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -319,12 +319,12 @@ pub fn fork_fission<'a>(
         .collect();
 
     let mut created_forks = Vec::new();
-    
-    // This does the reduction fission 
+
+    // This does the reduction fission
     for fork in forks {
         let join = fork_join_map[&fork.0];
 
-        // FIXME: Don't make multiple forks for reduces that are in cycles with each other. 
+        // FIXME: Don't make multiple forks for reduces that are in cycles with each other.
         let reduce_partition = default_reduce_partition(editor, fork.0, join);
 
         if !editor.func().labels[fork.0.idx()].contains(&fork_label) {
@@ -332,14 +332,19 @@ pub fn fork_fission<'a>(
         }
 
         if editor.is_mutable(fork.0) {
-            created_forks = fork_reduce_fission_helper(editor, fork_join_map, reduce_partition, nodes_in_fork_joins, fork.0);
+            created_forks = fork_reduce_fission_helper(
+                editor,
+                fork_join_map,
+                reduce_partition,
+                nodes_in_fork_joins,
+                fork.0,
+            );
             if created_forks.is_empty() {
                 continue;
             } else {
                 return created_forks;
             }
         }
-            
     }
 
     created_forks
@@ -503,13 +508,17 @@ pub fn fork_reduce_fission_helper<'a>(
 
     let mut new_forks = Vec::new();
 
-    let mut new_control_pred: NodeID = editor.get_uses(fork).filter(|n| editor.node(n).is_control()).next().unwrap();
+    let mut new_control_pred: NodeID = editor
+        .get_uses(fork)
+        .filter(|n| editor.node(n).is_control())
+        .next()
+        .unwrap();
 
     let mut new_fork = NodeID::new(0);
     let mut new_join = NodeID::new(0);
 
-    let subgraph = &nodes_in_fork_joins[&fork]; 
-    
+    let subgraph = &nodes_in_fork_joins[&fork];
+
     // Gets everything between fork & join that this reduce needs. (ALL CONTROL)
     editor.edit(|mut edit| {
         for reduce in reduce_partition {
@@ -522,7 +531,7 @@ pub fn fork_reduce_fission_helper<'a>(
             new_fork = mapping[&fork];
             new_forks.push(new_fork);
             new_join = mapping[&join];
-            
+
             // Atttach new_fork after control_pred
             let (old_control_pred, _) = edit.get_node(new_fork).try_fork().unwrap().clone();
             edit = edit.replace_all_uses_where(old_control_pred, new_control_pred, |usee| {
@@ -532,7 +541,7 @@ pub fn fork_reduce_fission_helper<'a>(
             // Replace uses of reduce
             edit = edit.replace_all_uses(reduce, mapping[&reduce])?;
             new_control_pred = new_join;
-        };
+        }
 
         // Replace original join w/ new final join
         edit = edit.replace_all_uses_where(join, new_join, |_| true)?;
@@ -1502,6 +1511,10 @@ fn fork_fusion(
  * element. This aides in parallelizing outer loops. Looks only at reduces with
  * the monoid reduce schedule, since that indicates a particular structure which
  * is annoying to check for again.
+ *
+ * Looks for would-be monoid reduces, if not for a gate on the reduction.
+ * Partially predicate the gated reduction to allow for a proper monoid
+ * reduction.
  */
 pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
     for id in editor.node_ids() {
@@ -1512,7 +1525,7 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
         let Some((_, init, reduct)) = nodes[id.idx()].try_reduce() else {
             continue;
         };
-        let out_uses: Vec<_> = editor.get_users(id).filter(|id| *id != reduct).collect();
+        let out_users: Vec<_> = editor.get_users(id).filter(|id| *id != reduct).collect();
 
         match nodes[reduct.idx()] {
             Node::Binary {
@@ -1532,7 +1545,7 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
                         left: init,
                         right: id,
                     });
-                    for u in out_uses {
+                    for u in out_users {
                         edit.sub_edit(u, final_op);
                     }
                     edit.replace_all_uses_where(id, final_op, |u| *u != reduct && *u != final_op)
@@ -1555,7 +1568,7 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
                         left: init,
                         right: id,
                     });
-                    for u in out_uses {
+                    for u in out_users {
                         edit.sub_edit(u, final_op);
                     }
                     edit.replace_all_uses_where(id, final_op, |u| *u != reduct && *u != final_op)
@@ -1574,7 +1587,7 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
                         intrinsic: Intrinsic::Max,
                         args: Box::new([init, id]),
                     });
-                    for u in out_uses {
+                    for u in out_users {
                         edit.sub_edit(u, final_op);
                     }
                     edit.replace_all_uses_where(id, final_op, |u| *u != reduct && *u != final_op)
@@ -1593,7 +1606,7 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
                         intrinsic: Intrinsic::Min,
                         args: Box::new([init, id]),
                     });
-                    for u in out_uses {
+                    for u in out_users {
                         edit.sub_edit(u, final_op);
                     }
                     edit.replace_all_uses_where(id, final_op, |u| *u != reduct && *u != final_op)
@@ -1602,6 +1615,65 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
             _ => {}
         }
     }
+
+    for id in editor.node_ids() {
+        if !editor.func().schedules[id.idx()].contains(&Schedule::MonoidReduce) {
+            continue;
+        }
+        let nodes = &editor.func().nodes;
+        let Some((control, init, reduct)) = nodes[id.idx()].try_reduce() else {
+            continue;
+        };
+        if let Node::Phi {
+            control: phi_control,
+            ref data,
+        } = nodes[reduct.idx()]
+            && data.len() == 2
+            && data.contains(&id)
+            && let other = *data
+                .into_iter()
+                .filter(|other| **other != id)
+                .next()
+                .unwrap()
+            && let Node::Binary {
+                op: BinaryOperator::Add,
+                left,
+                right,
+            } = nodes[other.idx()]
+            && ((left == id) ^ (right == id))
+        {
+            let gated_input = if left == id { right } else { left };
+            let data = data.clone();
+            editor.edit(|mut edit| {
+                let zero = edit.add_zero_constant(typing[id.idx()]);
+                let zero = edit.add_node(Node::Constant { id: zero });
+                let phi = edit.add_node(Node::Phi {
+                    control: phi_control,
+                    data: data
+                        .iter()
+                        .map(|phi_use| if *phi_use == id { zero } else { gated_input })
+                        .collect(),
+                });
+                let new_reduce_id = NodeID::new(edit.num_node_ids());
+                let new_reduct_id = NodeID::new(edit.num_node_ids() + 1);
+                let new_reduce = Node::Reduce {
+                    control,
+                    init,
+                    reduct: new_reduct_id,
+                };
+                let new_add = Node::Binary {
+                    op: BinaryOperator::Add,
+                    left: new_reduce_id,
+                    right: phi,
+                };
+                let new_reduce = edit.add_node(new_reduce);
+                edit.add_node(new_add);
+                edit = edit.replace_all_uses(id, new_reduce)?;
+                edit = edit.delete_node(id)?;
+                Ok(edit)
+            });
+        }
+    }
 }
 
 /*
diff --git a/hercules_opt/src/simplify_cfg.rs b/hercules_opt/src/simplify_cfg.rs
index cf39db2b..b13cf0c3 100644
--- a/hercules_opt/src/simplify_cfg.rs
+++ b/hercules_opt/src/simplify_cfg.rs
@@ -126,11 +126,24 @@ fn remove_useless_fork_joins(
 
     // Third, get rid of fork-joins.
     for (fork, join) in fork_join_map {
-        if editor.get_users(*fork).len() == 1 && editor.get_users(*join).len() == 1 {
+        if editor.get_users(*join).len() == 1 {
             let fork_use = get_uses(&editor.func().nodes[fork.idx()]).as_ref()[0];
             let join_use = get_uses(&editor.func().nodes[join.idx()]).as_ref()[0];
+            let tids: Vec<_> = editor
+                .get_users(*fork)
+                .filter(|id| editor.func().nodes[id.idx()].is_thread_id())
+                .collect();
 
             editor.edit(|mut edit| {
+                if !tids.is_empty() {
+                    let u64_ty = edit.add_type(Type::UnsignedInteger64);
+                    let zero = edit.add_zero_constant(u64_ty);
+                    let zero = edit.add_node(Node::Constant { id: zero });
+                    for tid in tids {
+                        edit = edit.replace_all_uses(tid, zero)?;
+                        edit = edit.delete_node(tid)?;
+                    }
+                }
                 edit = edit.replace_all_uses(*join, join_use)?;
                 edit = edit.replace_all_uses(*fork, fork_use)?;
                 edit = edit.delete_node(*fork)?;
diff --git a/juno_samples/rodinia/backprop/src/gpu.sch b/juno_samples/rodinia/backprop/src/gpu.sch
index d0be79db..f8cc84a3 100644
--- a/juno_samples/rodinia/backprop/src/gpu.sch
+++ b/juno_samples/rodinia/backprop/src/gpu.sch
@@ -33,7 +33,11 @@ fixpoint {
 reduce-slf(*);
 simpl!(*);
 
-fork-tile[16, 0, false, true](layer_forward@inner_loop);
+fork-extend[32](layer_forward@inner_loop);
+clean-monoid-reduces(layer_forward);
+simpl!(layer_forward);
+fork-tile[32, 0, false, true](layer_forward@inner_loop);
+clean-monoid-reduces(layer_forward);
 let out = fork-split(layer_forward@inner_loop);
 clean-monoid-reduces(layer_forward);
 simpl!(layer_forward);
@@ -47,5 +51,4 @@ fork-tile[32, 0, false, true](adjust_weights);
 fork-split(adjust_weights);
 simpl!(adjust_weights);
 
-xdot[true](*);
 gcm(*);
-- 
GitLab