From bb1e436d9f5a2e4c2dc57b8d62166fe030caf24d Mon Sep 17 00:00:00 2001
From: Xavier Routh <xrouth2@illinois.edu>
Date: Sat, 1 Mar 2025 21:30:28 -0600
Subject: [PATCH] fork fission w/ reduces

---
 hercules_opt/src/fork_transforms.rs           | 107 +++++++-----------
 juno_samples/fork_join_tests/src/cpu.sch      |  18 ++-
 .../fork_join_tests/src/fork_join_tests.jn    |  16 +++
 juno_samples/fork_join_tests/src/main.rs      |   5 +
 juno_scheduler/src/compile.rs                 |   3 +-
 juno_scheduler/src/ir.rs                      |   1 +
 juno_scheduler/src/pm.rs                      | 107 +++++++++++++++++-
 7 files changed, 191 insertions(+), 66 deletions(-)

diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index e6db0345..1df5338d 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -306,40 +306,33 @@ where
 
 pub fn fork_fission<'a>(
     editor: &'a mut FunctionEditor,
-    _control_subgraph: &Subgraph,
-    _types: &Vec<TypeID>,
-    _loop_tree: &LoopTree,
+    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
+    reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
+    loop_tree: &LoopTree,
     fork_join_map: &HashMap<NodeID, NodeID>,
-) -> () {
-    let forks: Vec<_> = editor
-        .func()
-        .nodes
-        .iter()
-        .enumerate()
-        .filter_map(|(idx, node)| {
-            if node.is_fork() {
-                Some(NodeID::new(idx))
-            } else {
-                None
-            }
-        })
+) -> Vec<NodeID> {
+    let mut forks: Vec<_> = loop_tree
+        .bottom_up_loops()
+        .into_iter()
+        .filter(|(k, _)| editor.func().nodes[k.idx()].is_fork())
         .collect();
 
-    let control_pred = NodeID::new(0);
-
-    // This does the reduction fission:
+    let mut created_forks = Vec::new();
+    // Reduction fission: stop after the first fork whose fission creates new forks.
     for fork in forks.clone() {
-        // FIXME: If there is control in between fork and join, don't just give up.
-        let join = fork_join_map[&fork];
-        let join_pred = editor.func().nodes[join.idx()].try_join().unwrap();
-        if join_pred != fork {
-            todo!("Can't do fork fission on nodes with internal control")
-            // Inner control LOOPs are hard
-            // inner control in general *should* work right now without modifications.
+        let join = fork_join_map[&fork.0];
+        let reduce_partition = default_reduce_partition(editor, fork.0, join);
+
+        if editor.is_mutable(fork.0) {
+            created_forks = fork_reduce_fission_helper(editor, fork_join_map, reduce_partition, nodes_in_fork_joins, fork.0);
+            if !created_forks.is_empty() {
+                break;
+            }
         }
-        let reduce_partition = default_reduce_partition(editor, fork, join);
-        fork_reduce_fission_helper(editor, fork_join_map, reduce_partition, control_pred, fork);
+            
     }
+
+    created_forks
 }
 
 /** Split a 1D fork into two forks, placing select intermediate data into buffers. */
@@ -488,48 +481,38 @@ where
     }
 }
 
-/** Split a 1D fork into a separate fork for each reduction. */
+/** Split a fork into a separate fork for each reduction. */
 pub fn fork_reduce_fission_helper<'a>(
     editor: &'a mut FunctionEditor,
     fork_join_map: &HashMap<NodeID, NodeID>,
     reduce_partition: SparseNodeMap<ForkID>, // Describes how the reduces of the fork should be split,
-    original_control_pred: NodeID,           // What the new fork connects to.
-
+    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
     fork: NodeID,
-) -> (NodeID, NodeID) {
+) -> Vec<NodeID> {
     let join = fork_join_map[&fork];
 
-    let mut new_control_pred: NodeID = original_control_pred;
-    // Important edges are: Reduces,
+    let mut new_forks = Vec::new();
 
-    // NOTE:
-    // Say two reduce are in a fork, s.t  reduce A depends on reduce B
-    // If user wants A and B in separate forks:
-    // - we can simply refuse
-    // - or we can duplicate B
+    let mut new_control_pred: NodeID = editor.get_uses(fork).filter(|n| editor.node(n).is_control()).next().unwrap();
 
     let mut new_fork = NodeID::new(0);
     let mut new_join = NodeID::new(0);
 
+    let subgraph = &nodes_in_fork_joins[&fork]; 
+    
     // Gets everything between fork & join that this reduce needs. (ALL CONTROL)
-    for reduce in reduce_partition {
-        let reduce = reduce.0;
-
-        let function = editor.func();
-        let subgraph = find_reduce_dependencies(function, reduce, fork);
-
-        let mut subgraph: HashSet<NodeID> = subgraph.into_iter().collect();
-
-        subgraph.insert(join);
-        subgraph.insert(fork);
-        subgraph.insert(reduce);
-
-        let (_, mapping, _) = copy_subgraph(editor, subgraph);
+    editor.edit(|mut edit| {
+        for reduce in reduce_partition {
+            let reduce = reduce.0;
 
-        new_fork = mapping[&fork];
-        new_join = mapping[&join];
+            let a = copy_subgraph_in_edit(edit, subgraph.clone())?;
+            edit = a.0;
+            let mapping = a.1;
 
-        editor.edit(|mut edit| {
+            new_fork = mapping[&fork];
+            new_forks.push(new_fork);
+            new_join = mapping[&join];
+            
             // Atttach new_fork after control_pred
             let (old_control_pred, _) = edit.get_node(new_fork).try_fork().unwrap().clone();
             edit = edit.replace_all_uses_where(old_control_pred, new_control_pred, |usee| {
@@ -538,13 +521,9 @@ pub fn fork_reduce_fission_helper<'a>(
 
             // Replace uses of reduce
             edit = edit.replace_all_uses(reduce, mapping[&reduce])?;
-            Ok(edit)
-        });
-
-        new_control_pred = new_join;
-    }
+            new_control_pred = new_join;
+        };
 
-    editor.edit(|mut edit| {
         // Replace original join w/ new final join
         edit = edit.replace_all_uses_where(join, new_join, |_| true)?;
 
@@ -553,10 +532,12 @@ pub fn fork_reduce_fission_helper<'a>(
 
         // Replace all users of original fork, and then delete it, leftover users will be DCE'd.
         edit = edit.replace_all_uses(fork, new_fork)?;
-        edit.delete_node(fork)
+        edit = edit.delete_node(fork)?;
+
+        Ok(edit)
     });
 
-    (new_fork, new_join)
+    new_forks
 }
 
 pub fn fork_coalesce(
diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch
index f46c91d6..185ea441 100644
--- a/juno_samples/fork_join_tests/src/cpu.sch
+++ b/juno_samples/fork_join_tests/src/cpu.sch
@@ -3,7 +3,7 @@ gvn(*);
 phi-elim(*);
 dce(*);
 
-let auto = auto-outline(test1, test2, test3, test4, test5, test7, test8, test9);
+let auto = auto-outline(test1, test2, test3, test4, test5, test7, test8, test9, test10);
 cpu(auto.test1);
 cpu(auto.test2);
 cpu(auto.test3);
@@ -12,6 +12,7 @@ cpu(auto.test5);
 cpu(auto.test7);
 cpu(auto.test8);
 cpu(auto.test9);
+cpu(auto.test10);
 
 let test1_cpu = auto.test1;
 rename["test1_cpu"](test1_cpu);
@@ -96,4 +97,19 @@ dce(auto.test8);
 
 no-memset(test9@const);
 
+fork-split(auto.test10);
+xdot[true](auto.test10);
+fork-fission(auto.test10);
+dce(auto.test10);
+xdot[false](auto.test10);
+dce(auto.test10);
+simplify-cfg(auto.test10);
+dce(auto.test10);
+xdot[true](auto.test10);
+
+fork-split(auto.test10);
+
+unforkify(auto.test10);
+
+
 gcm(*);
diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn
index 334fc2bf..51284c29 100644
--- a/juno_samples/fork_join_tests/src/fork_join_tests.jn
+++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn
@@ -147,3 +147,19 @@ fn test9<r, c : usize>(input : i32[r, c]) -> i32[r, c] {
 
   return out;
 }
+
+#[entry]
+fn test10(i0 : u64, i1 : u64, input : i32) -> i32 {
+  let arr0 : i32[4, 5];
+  let arr1 : i32[4, 5];
+  let arr2 : i32[4, 5];
+  @loop2 for k = 0 to 5 {
+    @loop3 for j = 0 to 4 {
+      arr0[j, k] += input;
+      arr1[j, k] += 2;
+      arr2[j, k] += input * 2;
+    }
+  }
+  
+  return arr0[i0, i1] + arr1[i0, i1] + arr2[i0, i1];
+}
diff --git a/juno_samples/fork_join_tests/src/main.rs b/juno_samples/fork_join_tests/src/main.rs
index e66309b2..277484f6 100644
--- a/juno_samples/fork_join_tests/src/main.rs
+++ b/juno_samples/fork_join_tests/src/main.rs
@@ -74,6 +74,11 @@ fn main() {
             5 + 6 + 8 + 9,
         ];
         assert(&correct, output);
+
+        let mut r = runner!(test10);
+        let output = r.run(1, 2, 5).await;
+        let correct = 17i32;
+        assert_eq!(correct, output);
     });
 }
 
diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs
index 9d020c64..39fb3469 100644
--- a/juno_scheduler/src/compile.rs
+++ b/juno_scheduler/src/compile.rs
@@ -126,7 +126,7 @@ impl FromStr for Appliable {
             "ip-sroa" | "interprocedural-sroa" => {
                 Ok(Appliable::Pass(ir::Pass::InterproceduralSROA))
             }
-            "fork-fission-bufferize" | "fork-fission" => {
+            "fork-fission-bufferize" => {
                 Ok(Appliable::Pass(ir::Pass::ForkFissionBufferize))
             }
             "fork-dim-merge" => Ok(Appliable::Pass(ir::Pass::ForkDimMerge)),
@@ -134,6 +134,7 @@ impl FromStr for Appliable {
             "fork-chunk" | "fork-tile" => Ok(Appliable::Pass(ir::Pass::ForkChunk)),
             "fork-extend" => Ok(Appliable::Pass(ir::Pass::ForkExtend)),
             "fork-unroll" | "unroll" => Ok(Appliable::Pass(ir::Pass::ForkUnroll)),
+            "fork-fission" | "fission" => Ok(Appliable::Pass(ir::Pass::ForkFission)),
             "fork-fusion" | "fusion" => Ok(Appliable::Pass(ir::Pass::ForkFusion)),
             "fork-reshape" => Ok(Appliable::Pass(ir::Pass::ForkReshape)),
             "lift-dc-math" => Ok(Appliable::Pass(ir::Pass::LiftDCMath)),
diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs
index ab1495b8..d381e861 100644
--- a/juno_scheduler/src/ir.rs
+++ b/juno_scheduler/src/ir.rs
@@ -18,6 +18,7 @@ pub enum Pass {
     ForkDimMerge,
     ForkExtend,
     ForkFissionBufferize,
+    ForkFission, 
     ForkFusion,
     ForkGuardElim,
     ForkInterchange,
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 456df2ed..9281925a 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -2740,7 +2740,104 @@ fn run_pass(
             }
             pm.delete_gravestones();
             pm.clear_analyses();
-        }
+        },
+        Pass::ForkFission => {
+            assert!(args.len() == 0);
+
+            pm.make_fork_join_maps();
+            pm.make_loops();
+            pm.make_reduce_cycles();
+            pm.make_nodes_in_fork_joins();
+            let fork_join_maps = pm.fork_join_maps.take().unwrap();
+            let loops = pm.loops.take().unwrap();
+            let reduce_cycles = pm.reduce_cycles.take().unwrap();
+            let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap();
+
+            let mut created_fork_joins = vec![vec![]; pm.functions.len()];
+
+            for ((((func, fork_join_map), loop_tree), reduce_cycles), nodes_in_fork_joins,
+            ) in build_selection(pm, selection, false)
+                .into_iter()
+                .zip(fork_join_maps.iter())
+                .zip(loops.iter())
+                .zip(reduce_cycles.iter())
+                .zip(nodes_in_fork_joins.iter())
+            {
+                let Some(mut func) = func else {
+                    continue;
+                };
+                
+                let new_forks = fork_fission(
+                    &mut func,
+                    nodes_in_fork_joins,
+                    reduce_cycles,
+                    loop_tree,
+                    fork_join_map,
+                );
+                
+                let created_fork_joins = &mut created_fork_joins[func.func_id().idx()];
+
+                for f in new_forks {
+                    created_fork_joins.push(f);
+                };
+
+                changed |= func.modified();
+            }
+
+            pm.clear_analyses();
+            pm.make_nodes_in_fork_joins();
+            let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap();
+            let mut new_fork_joins = HashMap::new();
+
+            for (mut func, created_fork_joins) in
+                build_editors(pm).into_iter().zip(created_fork_joins)
+            {
+                // For every function, create a label for the set of nodes in
+                // each fork-join that was created by the fission.
+                let name = func.func().name.clone();
+                let func_id = func.func_id();
+                let labels = create_labels_for_node_sets(
+                    &mut func,
+                    created_fork_joins.into_iter().map(|fork| {
+                        nodes_in_fork_joins[func_id.idx()][&fork]
+                            .iter()
+                            .map(|id| *id)
+                    }),
+                );
+
+                // Assemble those labels into a record for this function. The
+                // fields are <function>.fj_top and <function>.fj_bottom, chosen
+                // by the parity of the label's index (even = top, odd = bottom).
+                let mut func_record = HashMap::new();
+                for (idx, label) in labels {
+                    let fmt = if idx % 2 == 0 { "fj_top" } else { "fj_bottom" };
+                    func_record.insert(
+                        fmt.to_string(),
+                        Value::Label {
+                            labels: vec![LabelInfo {
+                                func: func_id,
+                                label: label,
+                            }],
+                        },
+                    );
+                }
+
+                // Try to avoid creating unnecessary record entries.
+                if !func_record.is_empty() {
+                    new_fork_joins.entry(name).insert_entry(Value::Record {
+                        fields: func_record,
+                    });
+                }
+            }
+
+            pm.delete_gravestones();
+            pm.clear_analyses();
+
+            result = Value::Record {
+                fields: new_fork_joins,
+            };
+
+        },
         Pass::ForkFissionBufferize => {
             assert!(args.len() == 1 || args.len() == 2);
             let Some(Value::Label {
@@ -2777,6 +2874,14 @@ fn run_pass(
 
             let fork_label = fork_labels[0].label;
 
+            // Assert that at most one function is in the selection.
+            let num_functions = build_selection(pm, selection.clone(), false)
+            .iter()
+            .filter(|func| func.is_some())
+            .count();
+        
+            assert!(num_functions <= 1);
+
             for (
                 ((((func, fork_join_map), loop_tree), typing), reduce_cycles),
                 nodes_in_fork_joins,
-- 
GitLab