From bb1e436d9f5a2e4c2dc57b8d62166fe030caf24d Mon Sep 17 00:00:00 2001
From: Xavier Routh <xrouth2@illinois.edu>
Date: Sat, 1 Mar 2025 21:30:28 -0600
Subject: [PATCH 1/7] fork fission w/ reduces

---
 hercules_opt/src/fork_transforms.rs           | 107 +++++++-----------
 juno_samples/fork_join_tests/src/cpu.sch      |  18 ++-
 .../fork_join_tests/src/fork_join_tests.jn    |  16 +++
 juno_samples/fork_join_tests/src/main.rs      |   5 +
 juno_scheduler/src/compile.rs                 |   3 +-
 juno_scheduler/src/ir.rs                      |   1 +
 juno_scheduler/src/pm.rs                      | 107 +++++++++++++++++-
 7 files changed, 191 insertions(+), 66 deletions(-)

diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index e6db0345..1df5338d 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -306,40 +306,33 @@ where
 
 pub fn fork_fission<'a>(
     editor: &'a mut FunctionEditor,
-    _control_subgraph: &Subgraph,
-    _types: &Vec<TypeID>,
-    _loop_tree: &LoopTree,
+    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
+    reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
+    loop_tree: &LoopTree,
     fork_join_map: &HashMap<NodeID, NodeID>,
-) -> () {
-    let forks: Vec<_> = editor
-        .func()
-        .nodes
-        .iter()
-        .enumerate()
-        .filter_map(|(idx, node)| {
-            if node.is_fork() {
-                Some(NodeID::new(idx))
-            } else {
-                None
-            }
-        })
+) -> Vec<NodeID> {
+    let mut forks: Vec<_> = loop_tree
+        .bottom_up_loops()
+        .into_iter()
+        .filter(|(k, _)| editor.func().nodes[k.idx()].is_fork())
         .collect();
 
-    let control_pred = NodeID::new(0);
-
-    // This does the reduction fission:
+    let mut created_forks = Vec::new();
+    // This does the reduction fission 
     for fork in forks.clone() {
-        // FIXME: If there is control in between fork and join, don't just give up.
-        let join = fork_join_map[&fork];
-        let join_pred = editor.func().nodes[join.idx()].try_join().unwrap();
-        if join_pred != fork {
-            todo!("Can't do fork fission on nodes with internal control")
-            // Inner control LOOPs are hard
-            // inner control in general *should* work right now without modifications.
+        let join = fork_join_map[&fork.0];
+        let reduce_partition = default_reduce_partition(editor, fork.0, join);
+
+        if editor.is_mutable(fork.0) {
+            created_forks = fork_reduce_fission_helper(editor, fork_join_map, reduce_partition, nodes_in_fork_joins, fork.0);
+            if !created_forks.is_empty() {
+                break;
+            }
         }
-        let reduce_partition = default_reduce_partition(editor, fork, join);
-        fork_reduce_fission_helper(editor, fork_join_map, reduce_partition, control_pred, fork);
+            
     }
+
+    created_forks
 }
 
 /** Split a 1D fork into two forks, placing select intermediate data into buffers. */
@@ -488,48 +481,38 @@ where
     }
 }
 
-/** Split a 1D fork into a separate fork for each reduction. */
+/** Split a fork into a separate fork for each reduction. */
 pub fn fork_reduce_fission_helper<'a>(
     editor: &'a mut FunctionEditor,
     fork_join_map: &HashMap<NodeID, NodeID>,
     reduce_partition: SparseNodeMap<ForkID>, // Describes how the reduces of the fork should be split,
-    original_control_pred: NodeID,           // What the new fork connects to.
-
+    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
     fork: NodeID,
-) -> (NodeID, NodeID) {
+) -> Vec<NodeID> {
     let join = fork_join_map[&fork];
 
-    let mut new_control_pred: NodeID = original_control_pred;
-    // Important edges are: Reduces,
+    let mut new_forks = Vec::new();
 
-    // NOTE:
-    // Say two reduce are in a fork, s.t  reduce A depends on reduce B
-    // If user wants A and B in separate forks:
-    // - we can simply refuse
-    // - or we can duplicate B
+    let mut new_control_pred: NodeID = editor.get_uses(fork).filter(|n| editor.node(n).is_control()).next().unwrap();
 
     let mut new_fork = NodeID::new(0);
     let mut new_join = NodeID::new(0);
 
+    let subgraph = &nodes_in_fork_joins[&fork]; 
+    
     // Gets everything between fork & join that this reduce needs. (ALL CONTROL)
-    for reduce in reduce_partition {
-        let reduce = reduce.0;
-
-        let function = editor.func();
-        let subgraph = find_reduce_dependencies(function, reduce, fork);
-
-        let mut subgraph: HashSet<NodeID> = subgraph.into_iter().collect();
-
-        subgraph.insert(join);
-        subgraph.insert(fork);
-        subgraph.insert(reduce);
-
-        let (_, mapping, _) = copy_subgraph(editor, subgraph);
+    editor.edit(|mut edit| {
+        for reduce in reduce_partition {
+            let reduce = reduce.0;
 
-        new_fork = mapping[&fork];
-        new_join = mapping[&join];
+            let a = copy_subgraph_in_edit(edit, subgraph.clone())?;
+            edit = a.0;
+            let mapping = a.1;
 
-        editor.edit(|mut edit| {
+            new_fork = mapping[&fork];
+            new_forks.push(new_fork);
+            new_join = mapping[&join];
+            
             // Atttach new_fork after control_pred
             let (old_control_pred, _) = edit.get_node(new_fork).try_fork().unwrap().clone();
             edit = edit.replace_all_uses_where(old_control_pred, new_control_pred, |usee| {
@@ -538,13 +521,9 @@ pub fn fork_reduce_fission_helper<'a>(
 
             // Replace uses of reduce
             edit = edit.replace_all_uses(reduce, mapping[&reduce])?;
-            Ok(edit)
-        });
-
-        new_control_pred = new_join;
-    }
+            new_control_pred = new_join;
+        };
 
-    editor.edit(|mut edit| {
         // Replace original join w/ new final join
         edit = edit.replace_all_uses_where(join, new_join, |_| true)?;
 
@@ -553,10 +532,12 @@ pub fn fork_reduce_fission_helper<'a>(
 
         // Replace all users of original fork, and then delete it, leftover users will be DCE'd.
         edit = edit.replace_all_uses(fork, new_fork)?;
-        edit.delete_node(fork)
+        edit = edit.delete_node(fork)?;
+
+        Ok(edit)
     });
 
-    (new_fork, new_join)
+    new_forks
 }
 
 pub fn fork_coalesce(
diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch
index f46c91d6..185ea441 100644
--- a/juno_samples/fork_join_tests/src/cpu.sch
+++ b/juno_samples/fork_join_tests/src/cpu.sch
@@ -3,7 +3,7 @@ gvn(*);
 phi-elim(*);
 dce(*);
 
-let auto = auto-outline(test1, test2, test3, test4, test5, test7, test8, test9);
+let auto = auto-outline(test1, test2, test3, test4, test5, test7, test8, test9, test10);
 cpu(auto.test1);
 cpu(auto.test2);
 cpu(auto.test3);
@@ -12,6 +12,7 @@ cpu(auto.test5);
 cpu(auto.test7);
 cpu(auto.test8);
 cpu(auto.test9);
+cpu(auto.test10);
 
 let test1_cpu = auto.test1;
 rename["test1_cpu"](test1_cpu);
@@ -96,4 +97,19 @@ dce(auto.test8);
 
 no-memset(test9@const);
 
+fork-split(auto.test10);
+xdot[true](auto.test10);
+fork-fission(auto.test10);
+dce(auto.test10);
+xdot[false](auto.test10);
+dce(auto.test10);
+simplify-cfg(auto.test10);
+dce(auto.test10);
+xdot[true](auto.test10);
+
+fork-split(auto.test10);
+
+unforkify(auto.test10);
+
+
 gcm(*);
diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn
index 334fc2bf..51284c29 100644
--- a/juno_samples/fork_join_tests/src/fork_join_tests.jn
+++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn
@@ -147,3 +147,19 @@ fn test9<r, c : usize>(input : i32[r, c]) -> i32[r, c] {
 
   return out;
 }
+
+#[entry]
+fn test10(i0 : u64, i1 : u64, input : i32) -> i32 {
+  let arr0 : i32[4, 5];
+  let arr1 : i32[4, 5];
+  let arr2 : i32[4, 5];
+  @loop2 for k = 0 to 5 {
+    @loop3 for j = 0 to 4 {
+      arr0[j, k] += input;
+      arr1[j, k] += 2;
+      arr2[j, k] += input * 2;
+    }
+  }
+  
+  return arr0[i0, i1] + arr1[i0, i1] + arr2[i0, i1];
+}
diff --git a/juno_samples/fork_join_tests/src/main.rs b/juno_samples/fork_join_tests/src/main.rs
index e66309b2..277484f6 100644
--- a/juno_samples/fork_join_tests/src/main.rs
+++ b/juno_samples/fork_join_tests/src/main.rs
@@ -74,6 +74,11 @@ fn main() {
             5 + 6 + 8 + 9,
         ];
         assert(&correct, output);
+
+        let mut r = runner!(test10);
+        let output = r.run(1, 2, 5).await;
+        let correct = 17i32;
+        assert_eq!(correct, output);
     });
 }
 
diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs
index 9d020c64..39fb3469 100644
--- a/juno_scheduler/src/compile.rs
+++ b/juno_scheduler/src/compile.rs
@@ -126,7 +126,7 @@ impl FromStr for Appliable {
             "ip-sroa" | "interprocedural-sroa" => {
                 Ok(Appliable::Pass(ir::Pass::InterproceduralSROA))
             }
-            "fork-fission-bufferize" | "fork-fission" => {
+            "fork-fission-bufferize" => {
                 Ok(Appliable::Pass(ir::Pass::ForkFissionBufferize))
             }
             "fork-dim-merge" => Ok(Appliable::Pass(ir::Pass::ForkDimMerge)),
@@ -134,6 +134,7 @@ impl FromStr for Appliable {
             "fork-chunk" | "fork-tile" => Ok(Appliable::Pass(ir::Pass::ForkChunk)),
             "fork-extend" => Ok(Appliable::Pass(ir::Pass::ForkExtend)),
             "fork-unroll" | "unroll" => Ok(Appliable::Pass(ir::Pass::ForkUnroll)),
+            "fork-fission" | "fission" => Ok(Appliable::Pass(ir::Pass::ForkFission)),
             "fork-fusion" | "fusion" => Ok(Appliable::Pass(ir::Pass::ForkFusion)),
             "fork-reshape" => Ok(Appliable::Pass(ir::Pass::ForkReshape)),
             "lift-dc-math" => Ok(Appliable::Pass(ir::Pass::LiftDCMath)),
diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs
index ab1495b8..d381e861 100644
--- a/juno_scheduler/src/ir.rs
+++ b/juno_scheduler/src/ir.rs
@@ -18,6 +18,7 @@ pub enum Pass {
     ForkDimMerge,
     ForkExtend,
     ForkFissionBufferize,
+    ForkFission, 
     ForkFusion,
     ForkGuardElim,
     ForkInterchange,
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 456df2ed..9281925a 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -2740,7 +2740,104 @@ fn run_pass(
             }
             pm.delete_gravestones();
             pm.clear_analyses();
-        }
+        },
+        Pass::ForkFission => {
+            assert!(args.len() == 0);
+
+            pm.make_fork_join_maps();
+            pm.make_loops();
+            pm.make_reduce_cycles();
+            pm.make_nodes_in_fork_joins();
+            let fork_join_maps = pm.fork_join_maps.take().unwrap();
+            let loops = pm.loops.take().unwrap();
+            let reduce_cycles = pm.reduce_cycles.take().unwrap();
+            let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap();
+
+            let mut created_fork_joins = vec![vec![]; pm.functions.len()];
+
+            for ((((func, fork_join_map), loop_tree), reduce_cycles), nodes_in_fork_joins,
+            ) in build_selection(pm, selection, false)
+                .into_iter()
+                .zip(fork_join_maps.iter())
+                .zip(loops.iter())
+                .zip(reduce_cycles.iter())
+                .zip(nodes_in_fork_joins.iter())
+            {
+                let Some(mut func) = func else {
+                    continue;
+                };
+                
+                let new_forks = fork_fission(
+                    &mut func,
+                    nodes_in_fork_joins,
+                    reduce_cycles,
+                    loop_tree,
+                    fork_join_map,
+                );
+                
+                let created_fork_joins = &mut created_fork_joins[func.func_id().idx()];
+
+                for f in new_forks {
+                    created_fork_joins.push(f);
+                };
+
+                changed |= func.modified();
+            }
+
+            pm.clear_analyses();
+            pm.make_nodes_in_fork_joins();
+            let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap();
+            let mut new_fork_joins = HashMap::new();
+
+            for (mut func, created_fork_joins) in
+                build_editors(pm).into_iter().zip(created_fork_joins)
+            {
+                // For every function, create a label for each new fork-
+                // join produced by the fission.
+                let name = func.func().name.clone();
+                let func_id = func.func_id();
+                let labels = create_labels_for_node_sets(
+                    &mut func,
+                    created_fork_joins.into_iter().map(|fork| {
+                        nodes_in_fork_joins[func_id.idx()][&fork]
+                            .iter()
+                            .map(|id| *id)
+                    }),
+                );
+
+                // Assemble those labels into a record for this function. The
+                // format of the records is <function>.<f>, where <f> names the
+                // level of the split fork-joins being referred to.
+                let mut func_record = HashMap::new();
+                for (idx, label) in labels {
+                    let fmt = if idx % 2 == 0 { "fj_top" } else { "fj_bottom" };
+                    func_record.insert(
+                        fmt.to_string(),
+                        Value::Label {
+                            labels: vec![LabelInfo {
+                                func: func_id,
+                                label: label,
+                            }],
+                        },
+                    );
+                }
+
+                // Try to avoid creating unnecessary record entries.
+                if !func_record.is_empty() {
+                    new_fork_joins.entry(name).insert_entry(Value::Record {
+                        fields: func_record,
+                    });
+                }
+            }
+
+            pm.delete_gravestones();
+            pm.clear_analyses();
+
+            result = Value::Record {
+                fields: new_fork_joins,
+            };
+
+        },
         Pass::ForkFissionBufferize => {
             assert!(args.len() == 1 || args.len() == 2);
             let Some(Value::Label {
@@ -2777,6 +2874,14 @@ fn run_pass(
 
             let fork_label = fork_labels[0].label;
 
+            // assert only one function is in the selection.
+            let num_functions = build_selection(pm, selection.clone(), false)
+            .iter()
+            .filter(|func| func.is_some())
+            .count();
+        
+            assert!(num_functions <= 1);
+
             for (
                 ((((func, fork_join_map), loop_tree), typing), reduce_cycles),
                 nodes_in_fork_joins,
-- 
GitLab


From ef3b1cf7bdd7b61028dc778bd57c30f674b59d3d Mon Sep 17 00:00:00 2001
From: Xavier Routh <xrouth2@illinois.edu>
Date: Sat, 1 Mar 2025 21:43:20 -0600
Subject: [PATCH 2/7] rename passes

---
 juno_samples/fork_join_tests/src/cpu.sch | 2 +-
 juno_scheduler/src/compile.rs            | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch
index 185ea441..71a6eb92 100644
--- a/juno_samples/fork_join_tests/src/cpu.sch
+++ b/juno_samples/fork_join_tests/src/cpu.sch
@@ -99,7 +99,7 @@ no-memset(test9@const);
 
 fork-split(auto.test10);
 xdot[true](auto.test10);
-fork-fission(auto.test10);
+fork-fission-reduces(auto.test10);
 dce(auto.test10);
 xdot[false](auto.test10);
 dce(auto.test10);
diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs
index 39fb3469..fbe6d8f0 100644
--- a/juno_scheduler/src/compile.rs
+++ b/juno_scheduler/src/compile.rs
@@ -126,7 +126,7 @@ impl FromStr for Appliable {
             "ip-sroa" | "interprocedural-sroa" => {
                 Ok(Appliable::Pass(ir::Pass::InterproceduralSROA))
             }
-            "fork-fission-bufferize" => {
+            "fork-fission-bufferize" | "fork-fission" => {
                 Ok(Appliable::Pass(ir::Pass::ForkFissionBufferize))
             }
             "fork-dim-merge" => Ok(Appliable::Pass(ir::Pass::ForkDimMerge)),
@@ -134,7 +134,7 @@ impl FromStr for Appliable {
             "fork-chunk" | "fork-tile" => Ok(Appliable::Pass(ir::Pass::ForkChunk)),
             "fork-extend" => Ok(Appliable::Pass(ir::Pass::ForkExtend)),
             "fork-unroll" | "unroll" => Ok(Appliable::Pass(ir::Pass::ForkUnroll)),
-            "fork-fission" | "fission" => Ok(Appliable::Pass(ir::Pass::ForkFission)),
+            "fork-fission-reduces" => Ok(Appliable::Pass(ir::Pass::ForkFission)),
             "fork-fusion" | "fusion" => Ok(Appliable::Pass(ir::Pass::ForkFusion)),
             "fork-reshape" => Ok(Appliable::Pass(ir::Pass::ForkReshape)),
             "lift-dc-math" => Ok(Appliable::Pass(ir::Pass::LiftDCMath)),
-- 
GitLab


From 3f1067aad0779be73c635e024a3ffe2ba4b18c36 Mon Sep 17 00:00:00 2001
From: Russel Arbore <rarbore2@illinois.edu>
Date: Sat, 1 Mar 2025 22:35:26 -0600
Subject: [PATCH 3/7] remove xdot

---
 juno_samples/fork_join_tests/src/cpu.sch | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch
index 71a6eb92..df1b0345 100644
--- a/juno_samples/fork_join_tests/src/cpu.sch
+++ b/juno_samples/fork_join_tests/src/cpu.sch
@@ -98,14 +98,11 @@ dce(auto.test8);
 no-memset(test9@const);
 
 fork-split(auto.test10);
-xdot[true](auto.test10);
 fork-fission-reduces(auto.test10);
 dce(auto.test10);
-xdot[false](auto.test10);
 dce(auto.test10);
 simplify-cfg(auto.test10);
 dce(auto.test10);
-xdot[true](auto.test10);
 
 fork-split(auto.test10);
 
-- 
GitLab


From 0d4fb1cfde5d26305d25688a6d91b0cbf9e244be Mon Sep 17 00:00:00 2001
From: Xavier Routh <xrouth2@illinois.edu>
Date: Sun, 2 Mar 2025 10:00:30 -0600
Subject: [PATCH 4/7] fix labels for fork-fission-reduces, + gpu schedule for
 test

---
 hercules_opt/src/fork_transforms.rs      | 18 +++++++++---
 juno_samples/fork_join_tests/src/cpu.sch |  8 ++----
 juno_samples/fork_join_tests/src/gpu.sch |  9 +++++-
 juno_scheduler/src/ir.rs                 |  2 ++
 juno_scheduler/src/pm.rs                 | 35 +++++++++++++++++-------
 5 files changed, 51 insertions(+), 21 deletions(-)

diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index 1df5338d..7367a63e 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -310,23 +310,33 @@ pub fn fork_fission<'a>(
     reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
     loop_tree: &LoopTree,
     fork_join_map: &HashMap<NodeID, NodeID>,
+    fork_label: LabelID,
 ) -> Vec<NodeID> {
-    let mut forks: Vec<_> = loop_tree
+    let forks: Vec<_> = loop_tree
         .bottom_up_loops()
         .into_iter()
         .filter(|(k, _)| editor.func().nodes[k.idx()].is_fork())
         .collect();
 
     let mut created_forks = Vec::new();
+    
     // This does the reduction fission 
-    for fork in forks.clone() {
+    for fork in forks {
         let join = fork_join_map[&fork.0];
+
+        // FIXME: Don't make multiple forks for reduces that are in cycles with each other. 
         let reduce_partition = default_reduce_partition(editor, fork.0, join);
 
+        if !editor.func().labels[fork.0.idx()].contains(&fork_label) {
+            continue;
+        }
+
         if editor.is_mutable(fork.0) {
             created_forks = fork_reduce_fission_helper(editor, fork_join_map, reduce_partition, nodes_in_fork_joins, fork.0);
-            if !created_forks.is_empty() {
-                break;
+            if created_forks.is_empty() {
+                continue;
+            } else {
+                return created_forks;
             }
         }
             
diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch
index 71a6eb92..a3bb146b 100644
--- a/juno_samples/fork_join_tests/src/cpu.sch
+++ b/juno_samples/fork_join_tests/src/cpu.sch
@@ -98,15 +98,11 @@ dce(auto.test8);
 no-memset(test9@const);
 
 fork-split(auto.test10);
-xdot[true](auto.test10);
-fork-fission-reduces(auto.test10);
-dce(auto.test10);
-xdot[false](auto.test10);
+let fission_out = fork-fission-reduces[test10@loop3](auto.test10);
+fork-fusion(fission_out.test10_8.fj0, fission_out.test10_8.fj1);
 dce(auto.test10);
 simplify-cfg(auto.test10);
 dce(auto.test10);
-xdot[true](auto.test10);
-
 fork-split(auto.test10);
 
 unforkify(auto.test10);
diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch
index 81dc8d98..0d828f48 100644
--- a/juno_samples/fork_join_tests/src/gpu.sch
+++ b/juno_samples/fork_join_tests/src/gpu.sch
@@ -13,7 +13,7 @@ gvn(*);
 phi-elim(*);
 dce(*);
 
-let auto = auto-outline(test1, test2, test3, test4, test5, test7, test8, test9);
+let auto = auto-outline(test1, test2, test3, test4, test5, test7, test8, test9, test10);
 gpu(auto.test1);
 gpu(auto.test2);
 gpu(auto.test3);
@@ -22,6 +22,7 @@ gpu(auto.test5);
 gpu(auto.test7);
 gpu(auto.test8);
 gpu(auto.test9);
+gpu(auto.test10);
 
 ip-sroa(*);
 sroa(*);
@@ -75,6 +76,12 @@ dce(auto.test8);
 
 no-memset(test9@const);
 
+fork-split(auto.test10);
+fork-fission-reduces(auto.test10);
+dce(auto.test10);
+simplify-cfg(auto.test10);
+dce(auto.test10);
+
 ip-sroa(*);
 sroa(*);
 dce(*);
diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs
index d381e861..dd7a76a7 100644
--- a/juno_scheduler/src/ir.rs
+++ b/juno_scheduler/src/ir.rs
@@ -60,6 +60,7 @@ impl Pass {
             Pass::ForkChunk => num == 4,
             Pass::ForkExtend => num == 1,
             Pass::ForkFissionBufferize => num == 2 || num == 1,
+            Pass::ForkFission => num == 1,
             Pass::ForkInterchange => num == 2,
             Pass::ForkReshape => true,
             Pass::InterproceduralSROA => num == 0 || num == 1,
@@ -78,6 +79,7 @@ impl Pass {
             Pass::ForkChunk => "4",
             Pass::ForkExtend => "1",
             Pass::ForkFissionBufferize => "1 or 2",
+            Pass::ForkFission => "1",
             Pass::ForkInterchange => "2",
             Pass::ForkReshape => "any",
             Pass::InterproceduralSROA => "0 or 1",
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 9281925a..666ebdcd 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -2742,7 +2742,18 @@ fn run_pass(
             pm.clear_analyses();
         },
         Pass::ForkFission => {
-            assert!(args.len() == 0);
+            assert!(args.len() == 1);
+
+            let Some(Value::Label {
+                labels: fork_labels,
+            }) = args.get(0)
+            else {
+                return Err(SchedulerError::PassError {
+                    pass: "forkFission".to_string(),
+                    error: "expected label argument".to_string(),
+                });
+            };
+
 
             pm.make_fork_join_maps();
             pm.make_loops();
@@ -2755,6 +2766,17 @@ fn run_pass(
 
             let mut created_fork_joins = vec![vec![]; pm.functions.len()];
 
+            // assert only one function is in the selection.
+            let num_functions = build_selection(pm, selection.clone(), false)
+                .iter()
+                .filter(|func| func.is_some())
+                .count();
+
+            assert!(num_functions <= 1);
+            assert_eq!(fork_labels.len(), 1);
+
+            let fork_label = fork_labels[0].label;
+
             for ((((func, fork_join_map), loop_tree), reduce_cycles), nodes_in_fork_joins,
             ) in build_selection(pm, selection, false)
                 .into_iter()
@@ -2773,6 +2795,7 @@ fn run_pass(
                     reduce_cycles,
                     loop_tree,
                     fork_join_map,
+                    fork_label,
                 );
                 
                 let created_fork_joins = &mut created_fork_joins[func.func_id().idx()];
@@ -2810,7 +2833,7 @@ fn run_pass(
                 // level of the split fork-joins being referred to.
                 let mut func_record = HashMap::new();
                 for (idx, label) in labels {
-                    let fmt = if idx % 2 == 0 { "fj_top" } else { "fj_bottom" };
+                    let fmt = format!("fj{idx}");
                     func_record.insert(
                         fmt.to_string(),
                         Value::Label {
@@ -2874,14 +2897,6 @@ fn run_pass(
 
             let fork_label = fork_labels[0].label;
 
-            // assert only one function is in the selection.
-            let num_functions = build_selection(pm, selection.clone(), false)
-            .iter()
-            .filter(|func| func.is_some())
-            .count();
-        
-            assert!(num_functions <= 1);
-
             for (
                 ((((func, fork_join_map), loop_tree), typing), reduce_cycles),
                 nodes_in_fork_joins,
-- 
GitLab


From 9c111a2930bdd599d1707afd7a1739243aed8c16 Mon Sep 17 00:00:00 2001
From: Xavier Routh <xrouth2@illinois.edu>
Date: Sun, 2 Mar 2025 12:20:49 -0600
Subject: [PATCH 5/7] pass test10@loop3 label to fork-fission-reduces in gpu.sch

---
 juno_samples/fork_join_tests/src/gpu.sch | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch
index 0d828f48..c6d6bee7 100644
--- a/juno_samples/fork_join_tests/src/gpu.sch
+++ b/juno_samples/fork_join_tests/src/gpu.sch
@@ -77,7 +77,7 @@ dce(auto.test8);
 no-memset(test9@const);
 
 fork-split(auto.test10);
-fork-fission-reduces(auto.test10);
+fork-fission-reduces[test10@loop3](auto.test10);
 dce(auto.test10);
 simplify-cfg(auto.test10);
 dce(auto.test10);
-- 
GitLab


From 9b1936286e5c3cda44a4fb505f59c3c81915a013 Mon Sep 17 00:00:00 2001
From: Xavier Routh <xrouth2@illinois.edu>
Date: Sun, 2 Mar 2025 13:57:22 -0600
Subject: [PATCH 6/7] add test10 to float-collections in gpu.sch

---
 juno_samples/fork_join_tests/src/gpu.sch | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch
index c6d6bee7..b8e030a8 100644
--- a/juno_samples/fork_join_tests/src/gpu.sch
+++ b/juno_samples/fork_join_tests/src/gpu.sch
@@ -78,9 +78,11 @@ no-memset(test9@const);
 
 fork-split(auto.test10);
 fork-fission-reduces[test10@loop3](auto.test10);
+
 dce(auto.test10);
 simplify-cfg(auto.test10);
 dce(auto.test10);
+xdot[true](auto.test10);
 
 ip-sroa(*);
 sroa(*);
@@ -91,5 +93,5 @@ phi-elim(*);
 dce(*);
 
 gcm(*);
-float-collections(test2, auto.test2, test4, auto.test4, test5, auto.test5);
+float-collections(test2, auto.test2, test4, auto.test4, test5, auto.test5, test10, auto.test10);
 gcm(*);
-- 
GitLab


From 3d41b6f77aa4a5075c89805445cad227c86f5ac1 Mon Sep 17 00:00:00 2001
From: Xavier Routh <xrouth2@illinois.edu>
Date: Sun, 2 Mar 2025 14:01:33 -0600
Subject: [PATCH 7/7] remove leftover xdot call from gpu.sch

---
 juno_samples/fork_join_tests/src/gpu.sch | 1 -
 1 file changed, 1 deletion(-)

diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch
index b8e030a8..18caa8e5 100644
--- a/juno_samples/fork_join_tests/src/gpu.sch
+++ b/juno_samples/fork_join_tests/src/gpu.sch
@@ -82,7 +82,6 @@ fork-fission-reduces[test10@loop3](auto.test10);
 dce(auto.test10);
 simplify-cfg(auto.test10);
 dce(auto.test10);
-xdot[true](auto.test10);
 
 ip-sroa(*);
 sroa(*);
-- 
GitLab