From 6a0c4a3410d40410dbdac7a92926583e9c6e640a Mon Sep 17 00:00:00 2001
From: Xavier Routh <xrouth2@illinois.edu>
Date: Sat, 1 Feb 2025 15:04:24 -0600
Subject: [PATCH 01/33] control in reduce cycle fixes

---
 Cargo.lock                                    | 11 ++++
 Cargo.toml                                    |  3 +-
 hercules_opt/src/forkify.rs                   | 12 +++++
 hercules_opt/src/unforkify.rs                 |  2 +-
 .../hercules_tests/tests/loop_tests.rs        |  4 +-
 juno_scheduler/src/pm.rs                      | 53 +++++++++++--------
 6 files changed, 59 insertions(+), 26 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 49630436..ad69bc72 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1181,6 +1181,17 @@ dependencies = [
  "with_builtin_macros",
 ]
 
+[[package]]
+name = "juno_test"
+version = "0.1.0"
+dependencies = [
+ "async-std",
+ "hercules_rt",
+ "juno_build",
+ "rand",
+ "with_builtin_macros",
+]
+
 [[package]]
 name = "juno_utils"
 version = "0.1.0"
diff --git a/Cargo.toml b/Cargo.toml
index ced011a9..46fc7eaa 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -21,7 +21,7 @@ members = [
 	"hercules_samples/ccp",
 
 	"juno_samples/simple3",
-  "juno_samples/patterns",
+	"juno_samples/patterns",
 	"juno_samples/matmul",
 	"juno_samples/casts_and_intrinsics",
 	"juno_samples/nested_ccp",
@@ -30,4 +30,5 @@ members = [
   	"juno_samples/cava",
 	"juno_samples/concat",
   	"juno_samples/schedule_test",
+	"juno_samples/test",
 ]
diff --git a/hercules_opt/src/forkify.rs b/hercules_opt/src/forkify.rs
index ec4e9fbc..0f06627d 100644
--- a/hercules_opt/src/forkify.rs
+++ b/hercules_opt/src/forkify.rs
@@ -152,6 +152,7 @@ pub fn forkify_loop(
         .filter(|id| !l.control[id.idx()])
         .collect();
 
+    // FIXME: @xrouth
     if loop_preds.len() != 1 {
         return false;
     }
@@ -388,6 +389,7 @@ nest! {
             is_associative: bool,
         },
         LoopDependant(NodeID),
+        ControlDependant(NodeID), // This phi is reductionable, but its cycle may depend on internal control within the loop.
         UsedByDependant(NodeID),
     }
 }
@@ -398,6 +400,7 @@ impl LoopPHI {
             LoopPHI::Reductionable { phi, .. } => *phi,
             LoopPHI::LoopDependant(node_id) => *node_id,
             LoopPHI::UsedByDependant(node_id) => *node_id,
+            LoopPHI::ControlDependant(node_id) => *node_id,
         }
     }
 }
@@ -415,6 +418,9 @@ pub fn analyze_phis<'a>(
     loop_nodes: &'a HashSet<NodeID>,
 ) -> impl Iterator<Item = LoopPHI> + 'a {
 
+    // We are also moving the phi from the top of the loop (the header)
+    // to the very end (the join). If there are uses of the phi somewhere in the loop,
+    // then they may try to use the phi (now a reduce) before it hits the join.
     // Find data cycles within the loop of this phi, 
     // Start from the phis loop_continue_latch, and walk its uses until we find the original phi. 
 
@@ -509,6 +515,12 @@ pub fn analyze_phis<'a>(
             // to have headers that postdominate the loop continue latch. The value of the PHI used needs to be defined
             // by the time the reduce is triggered (at the end of the loop's internal control).
 
+            // If anything in the intersection is a phi (other than this phi itself), then the reduction cycle depends on control,
+            // which is not allowed.
+            if intersection.iter().any(|cycle_node| editor.node(cycle_node).is_phi() && *cycle_node != *phi) || editor.node(loop_continue_latch).is_phi() {
+                return LoopPHI::ControlDependant(*phi);
+            }
+
             // No nodes in data cycles with this phi (in the loop) are used outside the loop, besides the loop_continue_latch.
             // If some other node in the cycle is used, there is not a valid node to assign it after making the cycle a reduce.
             if intersection
diff --git a/hercules_opt/src/unforkify.rs b/hercules_opt/src/unforkify.rs
index 85ffd233..7d158d1a 100644
--- a/hercules_opt/src/unforkify.rs
+++ b/hercules_opt/src/unforkify.rs
@@ -133,7 +133,7 @@ pub fn unforkify(
         if factors.len() > 1 {
             // For now, don't convert multi-dimensional fork-joins. Rely on pass
             // that splits fork-joins.
-            continue;
+            break; // Because we have to unforkify top down, we can't unforkify forks that are contained within other forks.
         }
         let join_control = nodes[join.idx()].try_join().unwrap();
         let tids: Vec<_> = editor
diff --git a/hercules_test/hercules_tests/tests/loop_tests.rs b/hercules_test/hercules_tests/tests/loop_tests.rs
index 5832a161..192c1366 100644
--- a/hercules_test/hercules_tests/tests/loop_tests.rs
+++ b/hercules_test/hercules_tests/tests/loop_tests.rs
@@ -401,7 +401,7 @@ fn matmul_pipeline() {
     let dyn_consts = [I, J, K];
 
     // FIXME: This path should not leave the crate
-    let mut module = parse_module_from_hbin("../../juno_samples/matmul/out.hbin");
+    let mut module = parse_module_from_hbin("../../juno_samples/test/out.hbin");
     //
     let mut correct_c: Box<[i32]> = (0..I * K).map(|_| 0).collect();
     for i in 0..I {
@@ -425,7 +425,7 @@ fn matmul_pipeline() {
     };
     assert_eq!(correct_c[0], value);
 
-    let schedule = Some(default_schedule![Xdot, ForkSplit, Unforkify, Xdot,]);
+    let schedule = Some(default_schedule![AutoOutline, InterproceduralSROA, SROA, InferSchedules, DCE, Xdot, GCM]);
 
     module = run_schedule_on_hercules(module, schedule).unwrap();
 
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 2371e0f2..d2772c71 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -1471,29 +1471,38 @@ fn run_pass(
         }
         Pass::Forkify => {
             assert!(args.is_empty());
-            pm.make_fork_join_maps();
-            pm.make_control_subgraphs();
-            pm.make_loops();
-            let fork_join_maps = pm.fork_join_maps.take().unwrap();
-            let loops = pm.loops.take().unwrap();
-            let control_subgraphs = pm.control_subgraphs.take().unwrap();
-            for (((func, fork_join_map), loop_nest), control_subgraph) in
-                build_selection(pm, selection)
-                    .into_iter()
-                    .zip(fork_join_maps.iter())
-                    .zip(loops.iter())
-                    .zip(control_subgraphs.iter())
-            {
-                let Some(mut func) = func else {
-                    continue;
-                };
-                // TODO: uses direct return from forkify for now instead of
-                // func.modified, see comment on top of `forkify` for why. Fix
-                // this eventually.
-                changed |= forkify(&mut func, control_subgraph, fork_join_map, loop_nest);
+            loop {
+                let mut inner_changed = false;
+                pm.make_fork_join_maps();
+                pm.make_control_subgraphs();
+                pm.make_loops();
+                let fork_join_maps = pm.fork_join_maps.take().unwrap();
+                let loops = pm.loops.take().unwrap();
+                let control_subgraphs = pm.control_subgraphs.take().unwrap();
+                for (((func, fork_join_map), loop_nest), control_subgraph) in
+                    build_selection(pm, selection.clone())
+                        .into_iter()
+                        .zip(fork_join_maps.iter())
+                        .zip(loops.iter())
+                        .zip(control_subgraphs.iter())
+                {
+                    let Some(mut func) = func else {
+                        continue;
+                    };
+                    // TODO: uses direct return from forkify for now instead of
+                    // func.modified, see comment on top of `forkify` for why. Fix
+                    // this eventually.
+                    let c = forkify(&mut func, control_subgraph, fork_join_map, loop_nest);
+                    changed |= c;
+                    inner_changed |= c; 
+                }
+                pm.delete_gravestones();
+                pm.clear_analyses();
+
+                if !inner_changed {
+                    break;
+                }
             }
-            pm.delete_gravestones();
-            pm.clear_analyses();
         }
         Pass::GCM => {
             assert!(args.is_empty());
-- 
GitLab
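
A minimal sketch of the situation the new ControlDependant variant rejects, written as plain Rust rather than Hercules IR (the function and names are hypothetical): the value carried around the loop flows through a second phi created by the branch, so the would-be reduce cycle depends on control inside the loop.

    // Hypothetical analogue of a loop forkify must now reject: the merge of
    // the two branch values lowers to a phi, and that phi sits on acc's data
    // cycle (acc -> add/mul -> branch phi -> acc), so acc cannot become a
    // plain reduce over a fork-join.
    fn control_dependent_cycle(n: usize) -> i32 {
        let mut acc = 0;
        for i in 0..n {
            acc = if i % 2 == 0 { acc + 1 } else { acc * 2 };
        }
        acc
    }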


From fab913636e8b63cef26af9db6df2eb699f415161 Mon Sep 17 00:00:00 2001
From: Xavier Routh <xrouth2@illinois.edu>
Date: Sat, 1 Feb 2025 16:30:10 -0600
Subject: [PATCH 02/33] misc

---
 hercules_opt/src/forkify.rs | 1 -
 1 file changed, 1 deletion(-)

diff --git a/hercules_opt/src/forkify.rs b/hercules_opt/src/forkify.rs
index 0f06627d..299422c1 100644
--- a/hercules_opt/src/forkify.rs
+++ b/hercules_opt/src/forkify.rs
@@ -514,7 +514,6 @@ pub fn analyze_phis<'a>(
             // PHIs on the frontier of the uses by the candidate phi, i.e in uses_for_dependance need
             // to have headers that postdominate the loop continue latch. The value of the PHI used needs to be defined
             // by the time the reduce is triggered (at the end of the loop's internal control).
-
             // If anything in the intersection is a phi (other than this phi itself), then the reduction cycle depends on control,
             // which is not allowed.
             if intersection.iter().any(|cycle_node| editor.node(cycle_node).is_phi() && *cycle_node != *phi) || editor.node(loop_continue_latch).is_phi() {
-- 
GitLab


From 43b4022c38f65fcf5c2abad523c63f59888dabe9 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sat, 1 Feb 2025 18:29:52 -0600
Subject: [PATCH 03/33] simple test

---
 Cargo.lock                                    | 10 +++++++
 Cargo.toml                                    |  1 +
 juno_samples/fork_join_tests/Cargo.toml       | 21 +++++++++++++
 juno_samples/fork_join_tests/build.rs         | 24 +++++++++++++++
 .../fork_join_tests/src/fork_join_tests.jn    | 10 +++++++
 juno_samples/fork_join_tests/src/gpu.sch      | 30 +++++++++++++++++++
 juno_samples/fork_join_tests/src/main.rs      | 17 +++++++++++
 7 files changed, 113 insertions(+)
 create mode 100644 juno_samples/fork_join_tests/Cargo.toml
 create mode 100644 juno_samples/fork_join_tests/build.rs
 create mode 100644 juno_samples/fork_join_tests/src/fork_join_tests.jn
 create mode 100644 juno_samples/fork_join_tests/src/gpu.sch
 create mode 100644 juno_samples/fork_join_tests/src/main.rs

diff --git a/Cargo.lock b/Cargo.lock
index b8bf2278..af7902c6 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1130,6 +1130,16 @@ dependencies = [
  "with_builtin_macros",
 ]
 
+[[package]]
+name = "juno_fork_join_tests"
+version = "0.1.0"
+dependencies = [
+ "async-std",
+ "hercules_rt",
+ "juno_build",
+ "with_builtin_macros",
+]
+
 [[package]]
 name = "juno_frontend"
 version = "0.1.0"
diff --git a/Cargo.toml b/Cargo.toml
index f7b9322a..890d7924 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -31,4 +31,5 @@ members = [
 	"juno_samples/concat",
 	"juno_samples/schedule_test",
 	"juno_samples/edge_detection",
+	"juno_samples/fork_join_tests",
 ]
diff --git a/juno_samples/fork_join_tests/Cargo.toml b/juno_samples/fork_join_tests/Cargo.toml
new file mode 100644
index 00000000..a109e782
--- /dev/null
+++ b/juno_samples/fork_join_tests/Cargo.toml
@@ -0,0 +1,21 @@
+[package]
+name = "juno_fork_join_tests"
+version = "0.1.0"
+authors = ["Russel Arbore <rarbore2@illinois.edu>"]
+edition = "2021"
+
+[[bin]]
+name = "juno_fork_join_tests"
+path = "src/main.rs"
+
+[features]
+cuda = ["juno_build/cuda", "hercules_rt/cuda"]
+
+[build-dependencies]
+juno_build = { path = "../../juno_build" }
+
+[dependencies]
+juno_build = { path = "../../juno_build" }
+hercules_rt = { path = "../../hercules_rt" }
+with_builtin_macros = "0.1.0"
+async-std = "*"
diff --git a/juno_samples/fork_join_tests/build.rs b/juno_samples/fork_join_tests/build.rs
new file mode 100644
index 00000000..796e9f32
--- /dev/null
+++ b/juno_samples/fork_join_tests/build.rs
@@ -0,0 +1,24 @@
+use juno_build::JunoCompiler;
+
+fn main() {
+    #[cfg(not(feature = "cuda"))]
+    {
+        JunoCompiler::new()
+            .file_in_src("fork_join_tests.jn")
+            .unwrap()
+            .schedule_in_src("cpu.sch")
+            .unwrap()
+            .build()
+            .unwrap();
+    }
+    #[cfg(feature = "cuda")]
+    {
+        JunoCompiler::new()
+            .file_in_src("fork_join_tests.jn")
+            .unwrap()
+            .schedule_in_src("gpu.sch")
+            .unwrap()
+            .build()
+            .unwrap();
+    }
+}
diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn
new file mode 100644
index 00000000..aa8eb4bb
--- /dev/null
+++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn
@@ -0,0 +1,10 @@
+#[entry]
+fn test1(input : i32) -> i32[4, 4] {
+  let arr : i32[4, 4];
+  for i = 0 to 4 {
+    for j = 0 to 4 {
+      arr[i, j] = input;
+    }
+  }
+  return arr;
+}
diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch
new file mode 100644
index 00000000..e2fe980e
--- /dev/null
+++ b/juno_samples/fork_join_tests/src/gpu.sch
@@ -0,0 +1,30 @@
+gvn(*);
+phi-elim(*);
+dce(*);
+
+let out = auto-outline(*);
+gpu(out.test1);
+
+ip-sroa(*);
+sroa(*);
+dce(*);
+gvn(*);
+phi-elim(*);
+dce(*);
+
+fixpoint panic after 20 {
+  forkify(*);
+  fork-guard-elim(*);
+  fork-coalesce(*);
+}
+
+gvn(*);
+phi-elim(*);
+dce(*);
+
+fixpoint panic after 20 {
+  infer-schedules(*);
+}
+xdot[true](*);
+
+gcm(*);
diff --git a/juno_samples/fork_join_tests/src/main.rs b/juno_samples/fork_join_tests/src/main.rs
new file mode 100644
index 00000000..6e5f2182
--- /dev/null
+++ b/juno_samples/fork_join_tests/src/main.rs
@@ -0,0 +1,17 @@
+#![feature(concat_idents)]
+
+use hercules_rt::runner;
+
+juno_build::juno!("fork_join_tests");
+
+fn main() {
+    async_std::task::block_on(async {
+        let mut r = runner!(tests1);
+        let output = r.run(5).await;
+    });
+}
+
+#[test]
+fn implicit_clone_test() {
+    main();
+}
-- 
GitLab


From 5eeb70cb675bdd7ac2cf4ac4887f803d24ebe165 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sat, 1 Feb 2025 18:52:57 -0600
Subject: [PATCH 04/33] fixes for test

---
 hercules_cg/src/gpu.rs                   |  8 +++---
 juno_samples/fork_join_tests/src/cpu.sch | 31 ++++++++++++++++++++++++
 juno_samples/fork_join_tests/src/gpu.sch |  1 -
 juno_samples/fork_join_tests/src/main.rs | 13 +++++++++-
 juno_scheduler/src/pm.rs                 |  2 ++
 5 files changed, 50 insertions(+), 5 deletions(-)
 create mode 100644 juno_samples/fork_join_tests/src/cpu.sch

diff --git a/hercules_cg/src/gpu.rs b/hercules_cg/src/gpu.rs
index 81e31396..afc016a4 100644
--- a/hercules_cg/src/gpu.rs
+++ b/hercules_cg/src/gpu.rs
@@ -539,15 +539,17 @@ namespace cg = cooperative_groups;
         w: &mut String,
     ) -> Result<(), Error> {
         write!(w, "\n")?;
-        for (id, goto) in gotos.iter() {
-            let goto_block = self.get_block_name(*id, false);
+        let rev_po = self.control_subgraph.rev_po(NodeID::new(0));
+        for id in rev_po {
+            let goto = &gotos[&id];
+            let goto_block = self.get_block_name(id, false);
             write!(w, "{}:\n", goto_block)?;
             if goto_debug {
                 write!(w, "\tprintf(\"goto {}\\n\");\n", goto_block)?;
             }
             write!(w, "{}", goto.init)?;
             if !goto.post_init.is_empty() {
-                let goto_block = self.get_block_name(*id, true);
+                let goto_block = self.get_block_name(id, true);
                 write!(w, "{}:\n", goto_block)?;
                 write!(w, "{}", goto.post_init)?;
             }
diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch
new file mode 100644
index 00000000..81f5a12c
--- /dev/null
+++ b/juno_samples/fork_join_tests/src/cpu.sch
@@ -0,0 +1,31 @@
+gvn(*);
+phi-elim(*);
+dce(*);
+
+let out = auto-outline(*);
+cpu(out.test1);
+
+ip-sroa(*);
+sroa(*);
+dce(*);
+gvn(*);
+phi-elim(*);
+dce(*);
+
+fixpoint panic after 20 {
+  forkify(*);
+  fork-guard-elim(*);
+  fork-coalesce(*);
+}
+
+gvn(*);
+phi-elim(*);
+dce(*);
+
+fixpoint panic after 20 {
+  infer-schedules(*);
+}
+fork-split(*);
+unforkify(*);
+
+gcm(*);
diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch
index e2fe980e..e4e4e04f 100644
--- a/juno_samples/fork_join_tests/src/gpu.sch
+++ b/juno_samples/fork_join_tests/src/gpu.sch
@@ -25,6 +25,5 @@ dce(*);
 fixpoint panic after 20 {
   infer-schedules(*);
 }
-xdot[true](*);
 
 gcm(*);
diff --git a/juno_samples/fork_join_tests/src/main.rs b/juno_samples/fork_join_tests/src/main.rs
index 6e5f2182..a63b3f78 100644
--- a/juno_samples/fork_join_tests/src/main.rs
+++ b/juno_samples/fork_join_tests/src/main.rs
@@ -6,8 +6,19 @@ juno_build::juno!("fork_join_tests");
 
 fn main() {
     async_std::task::block_on(async {
-        let mut r = runner!(tests1);
+        let mut r = runner!(test1);
         let output = r.run(5).await;
+        let correct = vec![5i32; 16];
+        #[cfg(not(feature = "cuda"))]
+        {
+            assert_eq!(output.as_slice::<i32>(), &correct);
+        }
+        #[cfg(feature = "cuda")]
+        {
+            let mut dst = vec![0i32; 16];
+            let output = output.to_cpu_ref(&mut dst);
+            assert_eq!(output.as_slice::<i32>(), &correct);
+        }
     });
 }
 
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index f6fe2fc1..5d398804 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -575,6 +575,8 @@ impl PassManager {
         self.postdoms = None;
         self.fork_join_maps = None;
         self.fork_join_nests = None;
+        self.fork_control_maps = None;
+        self.fork_trees = None;
         self.loops = None;
         self.reduce_cycles = None;
         self.data_nodes_in_fork_joins = None;
-- 
GitLab


From d67263e6a871fa539067ec84e87c58f242809ea4 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sun, 2 Feb 2025 10:50:56 -0600
Subject: [PATCH 05/33] second fork test, fails on cpu w/ forkify/unforkify

---
 hercules_opt/src/float_collections.rs         | 27 ++++++++++++------
 juno_samples/fork_join_tests/src/cpu.sch      |  1 +
 .../fork_join_tests/src/fork_join_tests.jn    | 13 +++++++++
 juno_samples/fork_join_tests/src/gpu.sch      |  2 ++
 juno_samples/fork_join_tests/src/main.rs      | 28 ++++++++++++-------
 juno_scheduler/src/pm.rs                      | 17 +++++------
 6 files changed, 60 insertions(+), 28 deletions(-)

diff --git a/hercules_opt/src/float_collections.rs b/hercules_opt/src/float_collections.rs
index faa38375..6ef050c2 100644
--- a/hercules_opt/src/float_collections.rs
+++ b/hercules_opt/src/float_collections.rs
@@ -1,3 +1,5 @@
+use std::collections::BTreeMap;
+
 use hercules_ir::*;
 
 use crate::*;
@@ -7,27 +9,36 @@ use crate::*;
  * allowed.
  */
 pub fn float_collections(
-    editors: &mut [FunctionEditor],
+    editors: &mut BTreeMap<FunctionID, FunctionEditor>,
     typing: &ModuleTyping,
     callgraph: &CallGraph,
     devices: &Vec<Device>,
 ) {
-    let topo = callgraph.topo();
+    let topo: Vec<_> = callgraph
+        .topo()
+        .into_iter()
+        .filter(|id| editors.contains_key(&id))
+        .collect();
     for to_float_id in topo {
         // Collection constants float until reaching an AsyncRust function.
         if devices[to_float_id.idx()] == Device::AsyncRust {
             continue;
         }
 
+        // Check that all callers are in the selection as well.
+        for caller in callgraph.get_callers(to_float_id) {
+            assert!(editors.contains_key(&caller), "PANIC: FloatCollections called where a function ({:?}, {:?}) is in the selection but one of its callers ({:?}) is not. This means no collections will be floated from the callee, since the caller can't be modified to hold floated collections.", to_float_id, editors[&to_float_id].func().name, caller);
+        }
+
         // Find the target constant nodes in the function.
-        let cons: Vec<(NodeID, Node)> = editors[to_float_id.idx()]
+        let cons: Vec<(NodeID, Node)> = editors[&to_float_id]
             .func()
             .nodes
             .iter()
             .enumerate()
             .filter(|(_, node)| {
                 node.try_constant()
-                    .map(|cons_id| !editors[to_float_id.idx()].get_constant(cons_id).is_scalar())
+                    .map(|cons_id| !editors[&to_float_id].get_constant(cons_id).is_scalar())
                     .unwrap_or(false)
             })
             .map(|(idx, node)| (NodeID::new(idx), node.clone()))
@@ -37,12 +48,12 @@ pub fn float_collections(
         }
 
         // Each constant node becomes a new parameter.
-        let mut new_param_types = editors[to_float_id.idx()].func().param_types.clone();
+        let mut new_param_types = editors[&to_float_id].func().param_types.clone();
         let old_num_params = new_param_types.len();
         for (id, _) in cons.iter() {
             new_param_types.push(typing[to_float_id.idx()][id.idx()]);
         }
-        let success = editors[to_float_id.idx()].edit(|mut edit| {
+        let success = editors.get_mut(&to_float_id).unwrap().edit(|mut edit| {
             for (idx, (id, _)) in cons.iter().enumerate() {
                 let param = edit.add_node(Node::Parameter {
                     index: idx + old_num_params,
@@ -59,7 +70,7 @@ pub fn float_collections(
 
         // Add constants in callers and pass them into calls.
         for caller in callgraph.get_callers(to_float_id) {
-            let calls: Vec<(NodeID, Node)> = editors[caller.idx()]
+            let calls: Vec<(NodeID, Node)> = editors[&caller]
                 .func()
                 .nodes
                 .iter()
@@ -71,7 +82,7 @@ pub fn float_collections(
                 })
                 .map(|(idx, node)| (NodeID::new(idx), node.clone()))
                 .collect();
-            let success = editors[caller.idx()].edit(|mut edit| {
+            let success = editors.get_mut(&caller).unwrap().edit(|mut edit| {
                 let cons_ids: Vec<_> = cons
                     .iter()
                     .map(|(_, node)| edit.add_node(node.clone()))
diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch
index 81f5a12c..a6b1afe7 100644
--- a/juno_samples/fork_join_tests/src/cpu.sch
+++ b/juno_samples/fork_join_tests/src/cpu.sch
@@ -4,6 +4,7 @@ dce(*);
 
 let out = auto-outline(*);
 cpu(out.test1);
+cpu(out.test2);
 
 ip-sroa(*);
 sroa(*);
diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn
index aa8eb4bb..4a6a94c9 100644
--- a/juno_samples/fork_join_tests/src/fork_join_tests.jn
+++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn
@@ -8,3 +8,16 @@ fn test1(input : i32) -> i32[4, 4] {
   }
   return arr;
 }
+
+#[entry]
+fn test2(input : i32) -> i32[4, 4] {
+  let arr : i32[4, 4];
+  for i = 0 to 8 {
+    for j = 0 to 4 {
+      for k = 0 to 4 {
+        arr[j, k] += input;
+      }
+    }
+  }
+  return arr;
+}
diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch
index e4e4e04f..b506c4a4 100644
--- a/juno_samples/fork_join_tests/src/gpu.sch
+++ b/juno_samples/fork_join_tests/src/gpu.sch
@@ -4,6 +4,7 @@ dce(*);
 
 let out = auto-outline(*);
 gpu(out.test1);
+gpu(out.test2);
 
 ip-sroa(*);
 sroa(*);
@@ -26,4 +27,5 @@ fixpoint panic after 20 {
   infer-schedules(*);
 }
 
+float-collections(test2, out.test2);
 gcm(*);
diff --git a/juno_samples/fork_join_tests/src/main.rs b/juno_samples/fork_join_tests/src/main.rs
index a63b3f78..de1b0805 100644
--- a/juno_samples/fork_join_tests/src/main.rs
+++ b/juno_samples/fork_join_tests/src/main.rs
@@ -5,20 +5,28 @@ use hercules_rt::runner;
 juno_build::juno!("fork_join_tests");
 
 fn main() {
+    #[cfg(not(feature = "cuda"))]
+    let assert = |correct, output: hercules_rt::HerculesCPURefMut<'_>| {
+        assert_eq!(output.as_slice::<i32>(), &correct);
+    };
+
+    #[cfg(feature = "cuda")]
+    let assert = |correct, output: hercules_rt::HerculesCUDARefMut<'_>| {
+        let mut dst = vec![0i32; 16];
+        let output = output.to_cpu_ref(&mut dst);
+        assert_eq!(output.as_slice::<i32>(), &correct);
+    };
+
     async_std::task::block_on(async {
         let mut r = runner!(test1);
         let output = r.run(5).await;
         let correct = vec![5i32; 16];
-        #[cfg(not(feature = "cuda"))]
-        {
-            assert_eq!(output.as_slice::<i32>(), &correct);
-        }
-        #[cfg(feature = "cuda")]
-        {
-            let mut dst = vec![0i32; 16];
-            let output = output.to_cpu_ref(&mut dst);
-            assert_eq!(output.as_slice::<i32>(), &correct);
-        }
+        assert(correct, output);
+
+        let mut r = runner!(test2);
+        let output = r.run(3).await;
+        let correct = vec![24i32; 16];
+        assert(correct, output);
     });
 }
 
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 5d398804..1ebc885c 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -1305,7 +1305,7 @@ fn run_pass(
     pm: &mut PassManager,
     pass: Pass,
     args: Vec<Value>,
-    selection: Option<Vec<CodeLocation>>,
+    mut selection: Option<Vec<CodeLocation>>,
 ) -> Result<(Value, bool), SchedulerError> {
     let mut result = Value::Record {
         fields: HashMap::new(),
@@ -1441,13 +1441,6 @@ fn run_pass(
         }
         Pass::FloatCollections => {
             assert!(args.is_empty());
-            if let Some(_) = selection {
-                return Err(SchedulerError::PassError {
-                    pass: "floatCollections".to_string(),
-                    error: "must be applied to the entire module".to_string(),
-                });
-            }
-
             pm.make_typing();
             pm.make_callgraph();
             pm.make_devices();
@@ -1455,11 +1448,15 @@ fn run_pass(
             let callgraph = pm.callgraph.take().unwrap();
             let devices = pm.devices.take().unwrap();
 
-            let mut editors = build_editors(pm);
+            // Modify the selection to include callers of selected functions.
+            let mut editors = build_selection(pm, selection)
+                .into_iter()
+                .filter_map(|editor| editor.map(|editor| (editor.func_id(), editor)))
+                .collect();
             float_collections(&mut editors, &typing, &callgraph, &devices);
 
             for func in editors {
-                changed |= func.modified();
+                changed |= func.1.modified();
             }
 
             pm.delete_gravestones();
-- 
GitLab
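
As a sanity check on the expectation hard-coded in main.rs: test2's outer i-loop runs 8 times and each iteration adds input once to every arr[j, k], so with input = 3 every element ends at 24. A worked check of that arithmetic:

    // test2 expectation: 8 outer iterations, one add of `input` per element
    // per iteration; the runner is invoked with input = 3.
    let input = 3i32;
    assert_eq!(8 * input, 24);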


From ba614011066dcde2305a649f7911368d99f95a84 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sun, 2 Feb 2025 11:12:34 -0600
Subject: [PATCH 06/33] fix reduce cycles

---
 hercules_ir/src/fork_join_analysis.rs    | 34 ++++++++++++------------
 juno_samples/fork_join_tests/src/cpu.sch |  9 +++++++
 2 files changed, 26 insertions(+), 17 deletions(-)

diff --git a/hercules_ir/src/fork_join_analysis.rs b/hercules_ir/src/fork_join_analysis.rs
index 263fa952..7a098a35 100644
--- a/hercules_ir/src/fork_join_analysis.rs
+++ b/hercules_ir/src/fork_join_analysis.rs
@@ -140,23 +140,23 @@ fn reduce_cycle_dfs_helper(
     }
 
     current_visited.insert(iter);
-    let found_reduce = get_uses(&function.nodes[iter.idx()])
-        .as_ref()
-        .into_iter()
-        .any(|u| {
-            !current_visited.contains(u)
-                && !function.nodes[u.idx()].is_control()
-                && isnt_outside_fork_join(*u)
-                && reduce_cycle_dfs_helper(
-                    function,
-                    *u,
-                    fork,
-                    reduce,
-                    current_visited,
-                    in_cycle,
-                    fork_join_nest,
-                )
-        });
+    let mut found_reduce = false;
+
+    // This intentionally does not short-circuit, so every use is visited.
+    for u in get_uses(&function.nodes[iter.idx()]).as_ref() {
+        found_reduce |= !current_visited.contains(u)
+            && !function.nodes[u.idx()].is_control()
+            && isnt_outside_fork_join(*u)
+            && reduce_cycle_dfs_helper(
+                function,
+                *u,
+                fork,
+                reduce,
+                current_visited,
+                in_cycle,
+                fork_join_nest,
+            )
+    }
     if found_reduce {
         in_cycle.insert(iter);
     }
diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch
index a6b1afe7..2889cec0 100644
--- a/juno_samples/fork_join_tests/src/cpu.sch
+++ b/juno_samples/fork_join_tests/src/cpu.sch
@@ -18,6 +18,9 @@ fixpoint panic after 20 {
   fork-guard-elim(*);
   fork-coalesce(*);
 }
+gvn(*);
+phi-elim(*);
+dce(*);
 
 gvn(*);
 phi-elim(*);
@@ -27,6 +30,12 @@ fixpoint panic after 20 {
   infer-schedules(*);
 }
 fork-split(*);
+gvn(*);
+phi-elim(*);
+dce(*);
 unforkify(*);
+gvn(*);
+phi-elim(*);
+dce(*);
 
 gcm(*);
-- 
GitLab
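
A self-contained illustration of the behavioral difference the DFS fix above relies on, using plain integers in place of IR nodes: `any` stops at the first hit, so siblings of the first cycle member were never visited and never inserted into `in_cycle`; the replacement loop visits every use.

    let uses = [1, 2, 3];
    let mut visited = Vec::new();
    // Short-circuiting: stops after the first hit, so 2 and 3 are skipped.
    let _ = uses.iter().any(|&u| {
        visited.push(u);
        u == 1
    });
    assert_eq!(visited, [1]);
    // Non-short-circuiting, as in the fixed helper: every use is visited.
    visited.clear();
    let mut found = false;
    for &u in uses.iter() {
        visited.push(u);
        found |= u == 1;
    }
    assert!(found);
    assert_eq!(visited, [1, 2, 3]);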


From f2865cbfd0fd7f3a7e84c08e70283c8c0eedefdf Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sun, 2 Feb 2025 11:21:02 -0600
Subject: [PATCH 07/33] interesting test

---
 juno_samples/fork_join_tests/src/cpu.sch      |  1 +
 .../fork_join_tests/src/fork_join_tests.jn    | 23 +++++++++++++++++++
 juno_samples/fork_join_tests/src/gpu.sch      |  1 +
 juno_samples/fork_join_tests/src/main.rs      |  5 ++++
 4 files changed, 30 insertions(+)

diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch
index 2889cec0..0263c275 100644
--- a/juno_samples/fork_join_tests/src/cpu.sch
+++ b/juno_samples/fork_join_tests/src/cpu.sch
@@ -5,6 +5,7 @@ dce(*);
 let out = auto-outline(*);
 cpu(out.test1);
 cpu(out.test2);
+cpu(out.test3);
 
 ip-sroa(*);
 sroa(*);
diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn
index 4a6a94c9..073cfd1e 100644
--- a/juno_samples/fork_join_tests/src/fork_join_tests.jn
+++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn
@@ -21,3 +21,26 @@ fn test2(input : i32) -> i32[4, 4] {
   }
   return arr;
 }
+
+#[entry]
+fn test3(input : i32) -> i32[3, 3] {
+  let arr1 : i32[3, 3];
+  for i = 0 to 3 {
+    for j = 0 to 3 {
+      arr1[i, j] = (i + j) as i32 + input;
+    }
+  }
+  let arr2 : i32[3, 3];
+  for i = 0 to 3 {
+    for j = 0 to 3 {
+      arr2[i, j] = arr1[3 - i, 3 - j];
+    }
+  }
+  let arr3 : i32[3, 3];
+  for i = 0 to 3 {
+    for j = 0 to 3 {
+      arr3[i, j] = arr2[i, j] + 7;
+    }
+  }
+  return arr3;
+}
diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch
index b506c4a4..80f1bbc9 100644
--- a/juno_samples/fork_join_tests/src/gpu.sch
+++ b/juno_samples/fork_join_tests/src/gpu.sch
@@ -5,6 +5,7 @@ dce(*);
 let out = auto-outline(*);
 gpu(out.test1);
 gpu(out.test2);
+gpu(out.test3);
 
 ip-sroa(*);
 sroa(*);
diff --git a/juno_samples/fork_join_tests/src/main.rs b/juno_samples/fork_join_tests/src/main.rs
index de1b0805..4384ecd5 100644
--- a/juno_samples/fork_join_tests/src/main.rs
+++ b/juno_samples/fork_join_tests/src/main.rs
@@ -27,6 +27,11 @@ fn main() {
         let output = r.run(3).await;
         let correct = vec![24i32; 16];
         assert(correct, output);
+
+        let mut r = runner!(test3);
+        let output = r.run(0).await;
+        let correct = vec![11, 10, 9, 10, 9, 8, 9, 8, 7];
+        assert(correct, output);
     });
 }
 
-- 
GitLab


From 912a729ac05bd2077f2ab864cef04e1dddde7667 Mon Sep 17 00:00:00 2001
From: Xavier Routh <xrouth2@illinois.edu>
Date: Sun, 2 Feb 2025 12:34:13 -0600
Subject: [PATCH 08/33] unforkify fixes

---
 hercules_opt/src/unforkify.rs |  6 +++--
 juno_scheduler/src/pm.rs      | 42 ++++++++++++++++++++++-------------
 2 files changed, 30 insertions(+), 18 deletions(-)

diff --git a/hercules_opt/src/unforkify.rs b/hercules_opt/src/unforkify.rs
index 7d158d1a..a08d1667 100644
--- a/hercules_opt/src/unforkify.rs
+++ b/hercules_opt/src/unforkify.rs
@@ -118,7 +118,7 @@ pub fn unforkify(
     // control insides of the fork-join should become the successor of the true
     // projection node, and what was the use of the join should become a use of
     // the new region.
-    for l in loop_tree.bottom_up_loops().into_iter().rev() {
+    for l in loop_tree.bottom_up_loops().iter().rev() {
         if !editor.node(l.0).is_fork() {
             continue;
         }
@@ -133,7 +133,8 @@ pub fn unforkify(
         if factors.len() > 1 {
             // For now, don't convert multi-dimensional fork-joins. Rely on pass
             // that splits fork-joins.
-            break; // Because we have to unforkify top down, we can't unforkify forks that are contained within other forks.
+            // We can't unforkify, because then the outer fork's reduce would depend on non-fork control.
+            break;
         }
         let join_control = nodes[join.idx()].try_join().unwrap();
         let tids: Vec<_> = editor
@@ -293,5 +294,6 @@ pub fn unforkify(
 
             Ok(edit)
         });
+        break;
     }
 }
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index d2772c71..378d8730 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -1815,25 +1815,35 @@ fn run_pass(
         }
         Pass::Unforkify => {
             assert!(args.is_empty());
-            pm.make_fork_join_maps();
-            pm.make_loops();
+            loop {
+                let mut inner_changed = false;
 
-            let fork_join_maps = pm.fork_join_maps.take().unwrap();
-            let loops = pm.loops.take().unwrap();
+                pm.make_fork_join_maps();
+                pm.make_loops();
 
-            for ((func, fork_join_map), loop_tree) in build_selection(pm, selection)
-                .into_iter()
-                .zip(fork_join_maps.iter())
-                .zip(loops.iter())
-            {
-                let Some(mut func) = func else {
-                    continue;
-                };
-                unforkify(&mut func, fork_join_map, loop_tree);
-                changed |= func.modified();
+                let fork_join_maps = pm.fork_join_maps.take().unwrap();
+                let loops = pm.loops.take().unwrap();
+
+                for ((func, fork_join_map), loop_tree) in build_selection(pm, selection.clone())
+                    .into_iter()
+                    .zip(fork_join_maps.iter())
+                    .zip(loops.iter())
+                {
+                    let Some(mut func) = func else {
+                        continue;
+                    };
+                    unforkify(&mut func, fork_join_map, loop_tree);
+                    changed |= func.modified();
+                    inner_changed |= func.modified();
+                }
+                pm.delete_gravestones();
+                pm.clear_analyses();
+
+                if !inner_changed {
+                    break;
+                }
+                break;
             }
-            pm.delete_gravestones();
-            pm.clear_analyses();
         }
         Pass::ForkCoalesce => {
             assert!(args.is_empty());
-- 
GitLab


From d390705275493b662dcfa92a8d7ad35551d119a4 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sun, 2 Feb 2025 13:20:04 -0600
Subject: [PATCH 09/33] whoops

---
 juno_samples/fork_join_tests/src/fork_join_tests.jn | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn
index 073cfd1e..3d003f3c 100644
--- a/juno_samples/fork_join_tests/src/fork_join_tests.jn
+++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn
@@ -33,7 +33,7 @@ fn test3(input : i32) -> i32[3, 3] {
   let arr2 : i32[3, 3];
   for i = 0 to 3 {
     for j = 0 to 3 {
-      arr2[i, j] = arr1[3 - i, 3 - j];
+      arr2[i, j] = arr1[2 - i, 2 - j];
     }
   }
   let arr3 : i32[3, 3];
-- 
GitLab
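
With the corrected indices, test3's hard-coded expectation can be re-derived: arr1[i, j] = i + j + input, arr2[i, j] = arr1[2 - i, 2 - j] = 4 - i - j for input = 0, and arr3[i, j] = arr2[i, j] + 7 = 11 - i - j.

    // Re-derivation of test3's expected output (input = 0).
    let correct: Vec<i32> = (0..3)
        .flat_map(|i| (0..3).map(move |j| 11 - i - j))
        .collect();
    assert_eq!(correct, vec![11, 10, 9, 10, 9, 8, 9, 8, 7]);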


From ed33189c2fcaf725a8ae34f09d4fb963826ac4b6 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sun, 2 Feb 2025 13:38:57 -0600
Subject: [PATCH 10/33] whoops x2

---
 hercules_ir/src/ir.rs                    | 2 +-
 juno_samples/fork_join_tests/src/gpu.sch | 1 +
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs
index 846347b0..5c575ea1 100644
--- a/hercules_ir/src/ir.rs
+++ b/hercules_ir/src/ir.rs
@@ -1857,7 +1857,7 @@ pub fn indices_may_overlap(indices1: &[Index], indices2: &[Index]) -> bool {
  * list of indices B.
  */
 pub fn indices_contain_other_indices(indices_a: &[Index], indices_b: &[Index]) -> bool {
-    if indices_a.len() < indices_b.len() {
+    if indices_a.len() > indices_b.len() {
         return false;
     }
 
diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch
index 80f1bbc9..0647d781 100644
--- a/juno_samples/fork_join_tests/src/gpu.sch
+++ b/juno_samples/fork_join_tests/src/gpu.sch
@@ -28,5 +28,6 @@ fixpoint panic after 20 {
   infer-schedules(*);
 }
 
+xdot[true](*);
 float-collections(test2, out.test2);
 gcm(*);
-- 
GitLab
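
A simplified model of the comparison this one-character fix corrects, using integer position lists in place of Index values (the helper below is hypothetical): a list of indices A contains a list B only when A is no longer than B and agrees with B positionally, so it is a *longer* A that must be rejected, which is the `>` the patch installs.

    // Prefix-containment sketch: A contains B iff A is no longer than B and
    // matches B on every position A specifies.
    fn indices_contain(a: &[usize], b: &[usize]) -> bool {
        if a.len() > b.len() {
            return false;
        }
        a.iter().zip(b.iter()).all(|(x, y)| x == y)
    }

    assert!(indices_contain(&[0], &[0, 3]));  // field 0 covers field 0, index 3
    assert!(!indices_contain(&[0, 3], &[0])); // the reverse does not hold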


From e1f5634f38cca264a25abe83174434ef333af9c3 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sun, 2 Feb 2025 13:59:43 -0600
Subject: [PATCH 11/33] fix to antideps

---
 hercules_opt/src/gcm.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs
index 271bfaf1..b13c919a 100644
--- a/hercules_opt/src/gcm.rs
+++ b/hercules_opt/src/gcm.rs
@@ -290,7 +290,7 @@ fn basic_blocks(
                     .collect();
                 for mutator in reverse_postorder.iter() {
                     let mutator_early = schedule_early[mutator.idx()].unwrap();
-                    if dom.does_dom(root_early, mutator_early)
+                    if dom.does_prop_dom(root_early, mutator_early)
                         && (root_early != mutator_early
                             || root_block_iterated_users.contains(&mutator))
                         && mutating_objects(function, func_id, *mutator, objects)
-- 
GitLab


From 7f381ff54caddce24b0808725390a982e27bfdbc Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sun, 2 Feb 2025 14:38:13 -0600
Subject: [PATCH 12/33] hack for gpu

---
 hercules_cg/src/gpu.rs  | 17 ++++++++---------
 hercules_opt/src/gcm.rs | 18 ++++++++++++++++--
 2 files changed, 24 insertions(+), 11 deletions(-)

diff --git a/hercules_cg/src/gpu.rs b/hercules_cg/src/gpu.rs
index afc016a4..8f186aa7 100644
--- a/hercules_cg/src/gpu.rs
+++ b/hercules_cg/src/gpu.rs
@@ -622,23 +622,23 @@ extern \"C\" {} {}(",
             write!(pass_args, "ret")?;
             write!(w, "\tcudaMalloc((void**)&ret, sizeof({}));\n", ret_type)?;
         }
-        write!(w, "\tcudaError_t err;\n");
+        write!(w, "\tcudaError_t err;\n")?;
         write!(
             w,
             "\t{}_gpu<<<{}, {}, {}>>>({});\n",
             self.function.name, num_blocks, num_threads, dynamic_shared_offset, pass_args
         )?;
-        write!(w, "\terr = cudaGetLastError();\n");
+        write!(w, "\terr = cudaGetLastError();\n")?;
         write!(
             w,
             "\tif (cudaSuccess != err) {{ printf(\"Error1: %s\\n\", cudaGetErrorString(err)); }}\n"
-        );
+        )?;
         write!(w, "\tcudaDeviceSynchronize();\n")?;
-        write!(w, "\terr = cudaGetLastError();\n");
+        write!(w, "\terr = cudaGetLastError();\n")?;
         write!(
             w,
             "\tif (cudaSuccess != err) {{ printf(\"Error2: %s\\n\", cudaGetErrorString(err)); }}\n"
-        );
+        )?;
         if has_ret_var {
             // Copy return from device to host, whether it's primitive value or collection pointer
             write!(w, "\t{} host_ret;\n", ret_type)?;
@@ -1150,7 +1150,8 @@ extern \"C\" {} {}(",
             // for all threads. Otherwise, it can be inside or outside block fork.
             // If inside, it's stored in shared memory so we "allocate" it once
             // and parallelize memset to 0. If outside, we initialize as offset
-            // to backing, but if multi-block grid, don't memset to avoid grid-level sync.
+            // to backing, but if multi-block grid, don't memset to avoid grid-
+            // level sync.
             Node::Constant { id: cons_id } => {
                 let is_primitive = self.types[self.typing[id.idx()].idx()].is_primitive();
                 let cg_tile = match state {
@@ -1192,9 +1193,7 @@ extern \"C\" {} {}(",
                     )?;
                 }
                 if !is_primitive
-                    && (state != KernelState::OutBlock
-                        || is_block_parallel.is_none()
-                        || !is_block_parallel.unwrap())
+                    && (state != KernelState::OutBlock || !is_block_parallel.unwrap_or(false))
                 {
                     let data_size =
                         self.get_size(self.typing[id.idx()], None, Some(extra_dim_collects));
diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs
index b13c919a..65f7c2d0 100644
--- a/hercules_opt/src/gcm.rs
+++ b/hercules_opt/src/gcm.rs
@@ -90,6 +90,7 @@ pub fn gcm(
         loops,
         fork_join_map,
         objects,
+        devices,
     );
 
     let liveness = liveness_dataflow(
@@ -174,6 +175,7 @@ fn basic_blocks(
     loops: &LoopTree,
     fork_join_map: &HashMap<NodeID, NodeID>,
     objects: &CollectionObjects,
+    devices: &Vec<Device>,
 ) -> BasicBlocks {
     let mut bbs: Vec<Option<NodeID>> = vec![None; function.nodes.len()];
 
@@ -421,9 +423,18 @@ fn basic_blocks(
                 // If the next node further up the dominator tree is in a shallower
                 // loop nest or if we can get out of a reduce loop when we don't
                 // need to be in one, place this data node in a higher-up location.
-                // Only do this is the node isn't a constant or undef.
+                // Only do this if the node isn't a constant or undef - if a
+                // node is a constant or undef, we want its placement to be as
+                // control dependent as possible, even inside loops. In GPU
+                // functions specifically, lift constants that may be returned
+                // outside fork-joins.
                 let is_constant_or_undef =
                     function.nodes[id.idx()].is_constant() || function.nodes[id.idx()].is_undef();
+                let is_gpu_returned = devices[func_id.idx()] == Device::CUDA
+                    && objects[&func_id]
+                        .objects(id)
+                        .into_iter()
+                        .any(|obj| objects[&func_id].returned_objects().contains(obj));
                 let old_nest = loops
                     .header_of(location)
                     .map(|header| loops.nesting(header).unwrap());
@@ -444,7 +455,10 @@ fn basic_blocks(
                 // loop use the reduce node forming the loop, so the dominator chain
                 // will consist of one block, and this loop won't ever iterate.
                 let currently_at_join = function.nodes[location.idx()].is_join();
-                if !is_constant_or_undef && (shallower_nest || currently_at_join) {
+
+                if (!is_constant_or_undef || is_gpu_returned)
+                    && (shallower_nest || currently_at_join)
+                {
                     location = control_node;
                 }
             }
-- 
GitLab


From 25725cb1df4ec7b3c59f963c00aa5a2844a52916 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sun, 2 Feb 2025 14:50:58 -0600
Subject: [PATCH 13/33] Ok fix antideps for real this time

---
 hercules_opt/src/gcm.rs  | 19 +++++++++++++++++--
 juno_scheduler/src/pm.rs |  3 +++
 2 files changed, 20 insertions(+), 2 deletions(-)

diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs
index 65f7c2d0..3ff6d2fe 100644
--- a/hercules_opt/src/gcm.rs
+++ b/hercules_opt/src/gcm.rs
@@ -1,5 +1,5 @@
 use std::cell::Ref;
-use std::collections::{BTreeMap, BTreeSet, HashMap, VecDeque};
+use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet, VecDeque};
 use std::iter::{empty, once, zip, FromIterator};
 
 use bitvec::prelude::*;
@@ -76,6 +76,7 @@ pub fn gcm(
     dom: &DomTree,
     fork_join_map: &HashMap<NodeID, NodeID>,
     loops: &LoopTree,
+    reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
     objects: &CollectionObjects,
     devices: &Vec<Device>,
     object_device_demands: &FunctionObjectDeviceDemands,
@@ -88,6 +89,7 @@ pub fn gcm(
         reverse_postorder,
         dom,
         loops,
+        reduce_cycles,
         fork_join_map,
         objects,
         devices,
@@ -173,6 +175,7 @@ fn basic_blocks(
     reverse_postorder: &Vec<NodeID>,
     dom: &DomTree,
     loops: &LoopTree,
+    reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
     fork_join_map: &HashMap<NodeID, NodeID>,
     objects: &CollectionObjects,
     devices: &Vec<Device>,
@@ -246,6 +249,9 @@ fn basic_blocks(
     //    but not forwarding read - forwarding reads are collapsed, and the
     //    bottom read is treated as reading from the transitive parent of the
     //    forwarding read(s).
+    // 3: If the node producing the collection is a reduce node, then any read
+    //    users that aren't in the reduce's cycle shouldn't anti-depend on any
+    //    mutators in the reduce cycle.
     let mut antideps = BTreeSet::new();
     for id in reverse_postorder.iter() {
         // Find a terminating read node and the collections it reads.
@@ -271,6 +277,10 @@ fn basic_blocks(
             // TODO: make this less outrageously inefficient.
             let func_objects = &objects[&func_id];
             for root in roots.iter() {
+                let root_is_reduce_and_read_isnt_in_cycle = reduce_cycles
+                    .get(root)
+                    .map(|cycle| !cycle.contains(&id))
+                    .unwrap_or(false);
                 let root_early = schedule_early[root.idx()].unwrap();
                 let mut root_block_iterated_users: BTreeSet<NodeID> = BTreeSet::new();
                 let mut workset = BTreeSet::new();
@@ -292,12 +302,17 @@ fn basic_blocks(
                     .collect();
                 for mutator in reverse_postorder.iter() {
                     let mutator_early = schedule_early[mutator.idx()].unwrap();
-                    if dom.does_prop_dom(root_early, mutator_early)
+                    if dom.does_dom(root_early, mutator_early)
                         && (root_early != mutator_early
                             || root_block_iterated_users.contains(&mutator))
                         && mutating_objects(function, func_id, *mutator, objects)
                             .any(|mutated| read_objs.contains(&mutated))
                         && id != mutator
+                        && (!root_is_reduce_and_read_isnt_in_cycle
+                            || !reduce_cycles
+                                .get(root)
+                                .map(|cycle| cycle.contains(mutator))
+                                .unwrap_or(false))
                     {
                         antideps.insert((*id, *mutator));
                     }
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 1ebc885c..db4455e8 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -1567,6 +1567,7 @@ fn run_pass(
                 pm.make_doms();
                 pm.make_fork_join_maps();
                 pm.make_loops();
+                pm.make_reduce_cycles();
                 pm.make_collection_objects();
                 pm.make_devices();
                 pm.make_object_device_demands();
@@ -1577,6 +1578,7 @@ fn run_pass(
                 let doms = pm.doms.take().unwrap();
                 let fork_join_maps = pm.fork_join_maps.take().unwrap();
                 let loops = pm.loops.take().unwrap();
+                let reduce_cycles = pm.reduce_cycles.take().unwrap();
                 let control_subgraphs = pm.control_subgraphs.take().unwrap();
                 let collection_objects = pm.collection_objects.take().unwrap();
                 let devices = pm.devices.take().unwrap();
@@ -1598,6 +1600,7 @@ fn run_pass(
                         &doms[id.idx()],
                         &fork_join_maps[id.idx()],
                         &loops[id.idx()],
+                        &reduce_cycles[id.idx()],
                         &collection_objects,
                         &devices,
                         &object_device_demands[id.idx()],
-- 
GitLab
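
Boiled down, the guard added to the anti-dependence check reads as follows (names taken from the diff; this is a sketch of the predicate, not the full pass): an anti-dependence edge is suppressed exactly when the reading node sits outside the reduce's cycle while the mutator sits inside it.

    // Factored-out form of the added condition.
    let read_outside_cycle = reduce_cycles
        .get(root)
        .map(|cycle| !cycle.contains(&id))
        .unwrap_or(false);
    let mutator_in_cycle = reduce_cycles
        .get(root)
        .map(|cycle| cycle.contains(mutator))
        .unwrap_or(false);
    // By De Morgan, this matches the diff's
    // (!root_is_reduce_and_read_isnt_in_cycle || !mutator_in_cycle).
    let keep_antidep = !(read_outside_cycle && mutator_in_cycle);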


From c7d47ec8de0737d73395c06cd3ea5e835b2324d8 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sun, 2 Feb 2025 14:51:22 -0600
Subject: [PATCH 14/33] remove xdot from schedule

---
 juno_samples/fork_join_tests/src/gpu.sch | 1 -
 1 file changed, 1 deletion(-)

diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch
index 0647d781..80f1bbc9 100644
--- a/juno_samples/fork_join_tests/src/gpu.sch
+++ b/juno_samples/fork_join_tests/src/gpu.sch
@@ -28,6 +28,5 @@ fixpoint panic after 20 {
   infer-schedules(*);
 }
 
-xdot[true](*);
 float-collections(test2, out.test2);
 gcm(*);
-- 
GitLab


From 0caa58c5be3deb70cd0ef5dfde77b02a09bfbf2a Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sun, 2 Feb 2025 15:28:02 -0600
Subject: [PATCH 15/33] Test requiring outer split + unforkify

---
 .../fork_join_tests/src/fork_join_tests.jn        | 15 +++++++++++++++
 juno_samples/fork_join_tests/src/gpu.sch          |  1 +
 juno_samples/fork_join_tests/src/main.rs          |  5 +++++
 3 files changed, 21 insertions(+)

diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn
index 3d003f3c..55e0a37e 100644
--- a/juno_samples/fork_join_tests/src/fork_join_tests.jn
+++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn
@@ -44,3 +44,18 @@ fn test3(input : i32) -> i32[3, 3] {
   }
   return arr3;
 }
+
+#[entry]
+fn test4(input : i32) -> i32[4, 4] {
+  let arr : i32[4, 4];
+  for i = 0 to 4 {
+    for j = 0 to 4 {
+      let acc = arr[i, j];
+      for k = 0 to 7 {
+        acc += input;
+      }
+      arr[i, j] = acc;
+    }
+  }
+  return arr;
+}
diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch
index 80f1bbc9..bf35caea 100644
--- a/juno_samples/fork_join_tests/src/gpu.sch
+++ b/juno_samples/fork_join_tests/src/gpu.sch
@@ -6,6 +6,7 @@ let out = auto-outline(*);
 gpu(out.test1);
 gpu(out.test2);
 gpu(out.test3);
+gpu(out.test4);
 
 ip-sroa(*);
 sroa(*);
diff --git a/juno_samples/fork_join_tests/src/main.rs b/juno_samples/fork_join_tests/src/main.rs
index 4384ecd5..cbd42c50 100644
--- a/juno_samples/fork_join_tests/src/main.rs
+++ b/juno_samples/fork_join_tests/src/main.rs
@@ -32,6 +32,11 @@ fn main() {
         let output = r.run(0).await;
         let correct = vec![11, 10, 9, 10, 9, 8, 9, 8, 7];
         assert(correct, output);
+
+        let mut r = runner!(test4);
+        let output = r.run(9).await;
+        let correct = vec![63i32; 16];
+        assert(correct, output);
     });
 }
 
-- 
GitLab


From 50555d655836a3090d605dfe622a9cc1127076a6 Mon Sep 17 00:00:00 2001
From: Xavier Routh <xrouth2@illinois.edu>
Date: Sun, 2 Feb 2025 15:42:08 -0600
Subject: [PATCH 16/33] interpreter fixes + product consts

---
 hercules_test/hercules_interpreter/src/interpreter.rs | 3 +--
 hercules_test/hercules_interpreter/src/value.rs       | 4 ++--
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/hercules_test/hercules_interpreter/src/interpreter.rs b/hercules_test/hercules_interpreter/src/interpreter.rs
index 871e304a..22ef062a 100644
--- a/hercules_test/hercules_interpreter/src/interpreter.rs
+++ b/hercules_test/hercules_interpreter/src/interpreter.rs
@@ -783,8 +783,7 @@ impl<'a> FunctionExecutionState<'a> {
                                 &self.module.dynamic_constants,
                                 &self.dynamic_constant_params,
                             )
-                        })
-                        .rev();
+                        });
 
                     let n_tokens: usize = factors.clone().product();
 
diff --git a/hercules_test/hercules_interpreter/src/value.rs b/hercules_test/hercules_interpreter/src/value.rs
index 53911e05..adbed6e6 100644
--- a/hercules_test/hercules_interpreter/src/value.rs
+++ b/hercules_test/hercules_interpreter/src/value.rs
@@ -156,8 +156,8 @@ impl<'a> InterpreterVal {
             Constant::Float64(v) => Self::Float64(v),
 
             Constant::Product(ref type_id, ref constant_ids) => {
-                // Self::Product((), ())
-                todo!()
+                let contents = constant_ids.iter().map(|const_id| InterpreterVal::from_constant(&constants[const_id.idx()], constants, types, dynamic_constants, dynamic_constant_params));
+                InterpreterVal::Product(*type_id, contents.collect_vec().into_boxed_slice())
             }
             Constant::Summation(_, _, _) => todo!(),
             Constant::Array(type_id) => {
-- 
GitLab


From acf060c7daf486ef28a68a292beda3132850891d Mon Sep 17 00:00:00 2001
From: Xavier Routh <xrouth2@illinois.edu>
Date: Sun, 2 Feb 2025 15:42:26 -0600
Subject: [PATCH 17/33] read schedule from file

---
 juno_scheduler/src/lib.rs | 36 ++++++++++++++++++++++++++++++++++++
 1 file changed, 36 insertions(+)

diff --git a/juno_scheduler/src/lib.rs b/juno_scheduler/src/lib.rs
index 571d1fbf..2479af98 100644
--- a/juno_scheduler/src/lib.rs
+++ b/juno_scheduler/src/lib.rs
@@ -146,6 +146,42 @@ pub fn run_schedule_on_hercules(
     .map_err(|e| format!("Scheduling Error: {}", e))
 }
 
+
+pub fn run_schedule_from_file_on_hercules(
+    module: Module,
+    sched_filename: Option<String>,
+) -> Result<Module, String> {
+    let sched = process_schedule(sched_filename)?;
+
+    // Prepare the scheduler's string table and environment
+    // For this, we put all of the Hercules function names into the environment
+    // and string table
+    let mut strings = StringTable::new();
+    let mut env = Env::new();
+
+    env.open_scope();
+
+    for (idx, func) in module.functions.iter().enumerate() {
+        let func_name = strings.lookup_string(func.name.clone());
+        env.insert(
+            func_name,
+            Value::HerculesFunction {
+                func: FunctionID::new(idx),
+            },
+        );
+    }
+
+    env.open_scope();
+    schedule_module(
+        module,
+        sched,
+        strings,
+        env,
+        JunoFunctions { func_ids: vec![] },
+    )
+    .map_err(|e| format!("Scheduling Error: {}", e))
+}
+
 pub fn schedule_hercules(
     module: Module,
     sched_filename: Option<String>,
-- 
GitLab
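
A hypothetical call site for the new entry point, mirroring how run_schedule_on_hercules is used in loop_tests.rs; the .hbin and .sch paths below are placeholders.

    // Load a serialized module, then apply a schedule parsed from a file.
    let module = parse_module_from_hbin("../../juno_samples/test/out.hbin");
    let module = run_schedule_from_file_on_hercules(module, Some("src/cpu.sch".to_string()))
        .expect("scheduling failed");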


From 61bc9ae455c3933f930d7a980a4cc4efc39afb6f Mon Sep 17 00:00:00 2001
From: Xavier Routh <xrouth2@illinois.edu>
Date: Sun, 2 Feb 2025 15:43:08 -0600
Subject: [PATCH 18/33] unforkify fix

---
 hercules_opt/src/fork_guard_elim.rs | 8 ++++----
 juno_scheduler/src/pm.rs            | 1 -
 2 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/hercules_opt/src/fork_guard_elim.rs b/hercules_opt/src/fork_guard_elim.rs
index 052fd0e4..f6914b74 100644
--- a/hercules_opt/src/fork_guard_elim.rs
+++ b/hercules_opt/src/fork_guard_elim.rs
@@ -39,7 +39,6 @@ struct GuardedFork {
     guard_if: NodeID,
     fork_taken_proj: NodeID,
     fork_skipped_proj: NodeID,
-    guard_pred: NodeID,
     guard_join_region: NodeID,
     phi_reduce_map: HashMap<NodeID, NodeID>,
     factor: Factor, // The factor that matches the guard
@@ -302,7 +301,6 @@ fn guarded_fork(
         guard_if: if_node,
         fork_taken_proj: *control,
         fork_skipped_proj: other_pred,
-        guard_pred: if_pred,
         guard_join_region: join_control,
         phi_reduce_map: phi_nodes,
         factor,
@@ -323,13 +321,15 @@ pub fn fork_guard_elim(editor: &mut FunctionEditor, fork_join_map: &HashMap<Node
         join,
         fork_taken_proj,
         fork_skipped_proj,
-        guard_pred,
         phi_reduce_map,
         factor,
         guard_if,
         guard_join_region,
     } in guard_info
-    {
+    {   
+        let Some(guard_pred) = editor.get_uses(guard_if).next() else {
+            unreachable!()
+        };
         let new_fork_info = if let Factor::Max(idx, dc) = factor {
             let Node::Fork {
                 control: _,
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index db3904f7..dd2ae73a 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -1895,7 +1895,6 @@ fn run_pass(
                 if !inner_changed {
                     break;
                 }
-                break;
             }
         }
         Pass::ForkCoalesce => {
-- 
GitLab
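
Note: the substantive change here is that guard_pred is no longer cached in
GuardedFork when the guarded fork is recognized; it is re-derived at
elimination time. A minimal sketch of the lookup (names as in the patch),
assuming get_uses yields a node's inputs in order, so an If node's first use
is its control predecessor:

    let Some(guard_pred) = editor.get_uses(guard_if).next() else {
        unreachable!()
    };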


From 82ab7c076640225cf0ff4de2b8443a46dcc9b95f Mon Sep 17 00:00:00 2001
From: Xavier Routh <xrouth2@illinois.edu>
Date: Sun, 2 Feb 2025 22:39:36 -0600
Subject: [PATCH 19/33] fork dim merge

---
 hercules_opt/src/fork_transforms.rs | 161 ++++++++++++++++++++++++++++
 juno_scheduler/src/compile.rs       |   3 +-
 juno_scheduler/src/ir.rs            |   2 +
 juno_scheduler/src/lib.rs           |   3 +-
 juno_scheduler/src/pm.rs            |  30 ++++--
 5 files changed, 189 insertions(+), 10 deletions(-)

diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index e23f586f..58ace775 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -1,5 +1,6 @@
 use std::collections::{HashMap, HashSet};
 use std::iter::zip;
+use std::thread::ThreadId;
 
 use bimap::BiMap;
 use itertools::Itertools;
@@ -693,3 +694,163 @@ pub(crate) fn split_fork(
         None
     }
 }
+
+// Splits a dimension of a single fork join into multiple. 
+// Iterates an outer loop original_dim / tile_size times 
+// adds a tile_size loop as the inner loop 
+// Assumes that tile size divides original dim evenly.
+pub fn chunk_fork_unguarded(
+    editor: &mut FunctionEditor,
+    fork: NodeID,
+    dim_idx: usize,
+    tile_size: DynamicConstantID,
+) -> () {
+    // tid_dim_idx = tid_dim_idx * tile_size + tid_(dim_idx + 1)
+
+    let Node::Fork { control: old_control, factors: ref old_factors} = *editor.node(fork) else {return};
+
+    let mut new_factors: Vec<_> = old_factors.to_vec();
+
+    let fork_users: Vec<_> = editor.get_users(fork).collect();
+
+
+    for tid in fork_users {
+        let Node::ThreadID { control: _, dimension: tid_dim } = editor.node(tid).clone() else { continue };
+        editor.edit(|mut edit| {
+            if tid_dim > dim_idx {
+                let new_tid = Node::ThreadID { control: fork, dimension: tid_dim + 1 };
+                let new_tid = edit.add_node(new_tid);
+                edit.replace_all_uses(tid, new_tid)
+            } else if tid_dim == dim_idx {
+                let tile_tid = Node::ThreadID { control: fork, dimension: tid_dim + 1 };
+                let tile_tid = edit.add_node(tile_tid);
+                
+                let tile_size = edit.add_node(Node::DynamicConstant { id: tile_size });
+                let mul = edit.add_node(Node::Binary { left: tid, right: tile_size, op: BinaryOperator::Mul });
+                let add = edit.add_node(Node::Binary { left: mul, right: tile_tid, op: BinaryOperator::Add });
+                edit.replace_all_uses_where(tid, add, |usee| *usee != mul )
+            } else {
+                Ok(edit)
+            }
+        });
+    }
+
+    editor.edit(|mut edit| {
+        let outer = DynamicConstant::div(new_factors[dim_idx], tile_size);
+        new_factors.insert(dim_idx + 1, tile_size);
+        new_factors[dim_idx] = edit.add_dynamic_constant(outer);
+
+        let new_fork = Node::Fork { control: old_control, factors: new_factors.into() };
+        let new_fork = edit.add_node(new_fork);
+
+        edit.replace_all_uses(fork, new_fork)
+    });
+}
+
+
+pub fn merge_all_fork_dims(
+    editor: &mut FunctionEditor,
+    fork_join_map: &HashMap<NodeID, NodeID>,
+) {
+    for (fork, _) in fork_join_map {
+        let Node::Fork { control: _, factors: dims } = editor.node(fork) else {
+            unreachable!();
+        };
+
+        let mut fork = *fork;
+        for _ in 0..dims.len() - 1 {
+            let outer = 0;
+            let inner = 1;
+            fork = fork_dim_merge(editor, fork, outer, inner);
+        }
+    }
+}
+
+// Merges two dimensions of a single fork-join into one.
+// The fused dimension iterates dim(dim_idx1) * dim(dim_idx2) times;
+// thread ids over the original dimensions are recovered via div/rem.
+pub fn fork_dim_merge(
+    editor: &mut FunctionEditor,
+    fork: NodeID,
+    dim_idx1: usize,
+    dim_idx2: usize,
+) -> NodeID {
+    // tid for dim_idx1 is replaced with merged_tid / dim(dim_idx2)
+    // tid for dim_idx2 is replaced with merged_tid % dim(dim_idx2)
+    assert_ne!(dim_idx1, dim_idx2);
+
+    // Outer is the smaller index, i.e. closer to the left of the factors array.
+    let (outer_idx, inner_idx) = if dim_idx2 < dim_idx1 {
+        (dim_idx2, dim_idx1)
+    } else {
+        (dim_idx1, dim_idx2)
+    };
+
+    let Node::Fork { control: old_control, factors: ref old_factors} = *editor.node(fork) else {return fork};
+
+    let mut new_factors: Vec<_> = old_factors.to_vec();
+
+    
+
+    let fork_users: Vec<_> = editor.get_users(fork).collect();
+
+    let mut new_nodes = vec![];
+
+    let outer_dc_id = new_factors[outer_idx];
+    let inner_dc_id = new_factors[inner_idx];
+
+    let mut new_fork_id = NodeID::new(0);
+    editor.edit(|mut edit| {
+        new_factors[outer_idx] = edit.add_dynamic_constant(DynamicConstant::mul(new_factors[outer_idx], new_factors[inner_idx]));
+        new_factors.remove(inner_idx);
+
+        let new_fork = Node::Fork { control: old_control, factors: new_factors.into() };
+        let new_fork = edit.add_node(new_fork);
+        new_fork_id = new_fork;
+
+
+        edit = edit.replace_all_uses(fork, new_fork)?;
+        edit.delete_node(fork)
+    });
+
+    
+
+    for tid in fork_users {
+        let Node::ThreadID { control: _, dimension: tid_dim } = editor.node(tid).clone() else { continue };
+
+        println!("tid: {:?}", tid);
+        editor.edit(|mut edit| {
+            if tid_dim > inner_idx {
+                let new_tid = Node::ThreadID { control: new_fork_id, dimension: tid_dim - 1 };
+                let new_tid = edit.add_node(new_tid);
+                edit.replace_all_uses(tid, new_tid)
+            } else if tid_dim == outer_idx {
+                let outer_tid = Node::ThreadID { control: new_fork_id, dimension: outer_idx };
+                let outer_tid = edit.add_node(outer_tid);
+
+                let outer_dc = edit.add_node(Node::DynamicConstant { id: outer_dc_id });
+                new_nodes.push(outer_tid);
+
+                // inner_idx % dim(outer_idx)
+                let rem = edit.add_node(Node::Binary { left: outer_tid, right: outer_dc, op: BinaryOperator::Rem});
+                
+                edit.replace_all_uses(tid, rem)
+            } else if tid_dim == inner_idx {
+                let outer_tid = Node::ThreadID { control: new_fork_id, dimension: outer_idx };
+                let outer_tid = edit.add_node(outer_tid);
+
+                let outer_dc = edit.add_node(Node::DynamicConstant { id: outer_dc_id });
+                // inner_idx / dim(outer_idx)
+                let div = edit.add_node(Node::Binary { left: outer_tid, right: outer_dc, op: BinaryOperator::Div});
+                
+                edit.replace_all_uses(tid, div)
+            } else {
+                Ok(edit)
+            }
+        });
+    };
+
+    return new_fork_id;
+
+}
\ No newline at end of file
diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs
index 11a8ec53..07ad5e7a 100644
--- a/juno_scheduler/src/compile.rs
+++ b/juno_scheduler/src/compile.rs
@@ -108,7 +108,8 @@ impl FromStr for Appliable {
             "inline" => Ok(Appliable::Pass(ir::Pass::Inline)),
             "ip-sroa" | "interprocedural-sroa" => {
                 Ok(Appliable::Pass(ir::Pass::InterproceduralSROA))
-            }
+            },
+            "fork-dim-merge" => Ok(Appliable::Pass(ir::Pass::ForkDimMerge)),
             "lift-dc-math" => Ok(Appliable::Pass(ir::Pass::LiftDCMath)),
             "outline" => Ok(Appliable::Pass(ir::Pass::Outline)),
             "phi-elim" => Ok(Appliable::Pass(ir::Pass::PhiElim)),
diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs
index d6a41baf..939ef925 100644
--- a/juno_scheduler/src/ir.rs
+++ b/juno_scheduler/src/ir.rs
@@ -12,6 +12,8 @@ pub enum Pass {
     ForkSplit,
     ForkCoalesce,
     Forkify,
+    ForkDimMerge,
+    ForkChunk,
     GCM,
     GVN,
     InferSchedules,
diff --git a/juno_scheduler/src/lib.rs b/juno_scheduler/src/lib.rs
index 2479af98..ad9195fb 100644
--- a/juno_scheduler/src/lib.rs
+++ b/juno_scheduler/src/lib.rs
@@ -60,7 +60,7 @@ fn build_schedule(sched_filename: String) -> Result<ScheduleStmt, String> {
     }
 }
 
-fn process_schedule(sched_filename: Option<String>) -> Result<ScheduleStmt, String> {
+pub fn process_schedule(sched_filename: Option<String>) -> Result<ScheduleStmt, String> {
     if let Some(name) = sched_filename {
         build_schedule(name)
     } else {
@@ -146,7 +146,6 @@ pub fn run_schedule_on_hercules(
     .map_err(|e| format!("Scheduling Error: {}", e))
 }
 
-
 pub fn run_schedule_from_file_on_hercules(
     module: Module,
     sched_filename: Option<String>,
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 43f355c3..8b71d24e 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -1871,14 +1871,11 @@ fn run_pass(
         }
         Pass::Unforkify => {
             assert!(args.is_empty());
-            loop {
-                let mut inner_changed = false;
-
-                pm.make_fork_join_maps();
-                pm.make_loops();
+            pm.make_fork_join_maps();
+            pm.make_loops();
 
-                let fork_join_maps = pm.fork_join_maps.take().unwrap();
-                let loops = pm.loops.take().unwrap();
+            let fork_join_maps = pm.fork_join_maps.take().unwrap();
+            let loops = pm.loops.take().unwrap();
 
             for ((func, fork_join_map), loop_tree) in build_selection(pm, selection)
                 .into_iter()
@@ -1894,6 +1891,24 @@ fn run_pass(
             pm.delete_gravestones();
             pm.clear_analyses();
         }
+        Pass::ForkDimMerge => {
+            assert!(args.is_empty());
+            pm.make_fork_join_maps();
+            let fork_join_maps = pm.fork_join_maps.take().unwrap();
+            for (func, fork_join_map) in
+                build_selection(pm, selection)
+                    .into_iter()
+                    .zip(fork_join_maps.iter())
+            {
+                let Some(mut func) = func else {
+                    continue;
+                };
+                merge_all_fork_dims(&mut func, fork_join_map);
+                changed |= func.modified();
+            }
+            pm.delete_gravestones();
+            pm.clear_analyses();
+        }
         Pass::ForkCoalesce => {
             assert!(args.is_empty());
             pm.make_fork_join_maps();
@@ -1991,6 +2006,7 @@ fn run_pass(
             // Put BasicBlocks back, since it's needed for Codegen.
             pm.bbs = bbs;
         }
+        Pass::ForkChunk => todo!(),
     }
     println!("Ran Pass: {:?}", pass);
 
-- 
GitLab
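
Note: a standalone sketch of the index arithmetic a dimension merge implies
(illustrative helper, not Hercules API), assuming the merged thread id
enumerates the fused dimension with the outer coordinate varying slowest:

    // Recover the original (outer, inner) coordinates from a merged id.
    fn split_merged_tid(merged: usize, inner_dim: usize) -> (usize, usize) {
        (merged / inner_dim, merged % inner_dim)
    }

    // Merging dims of sizes 3 and 4: merged id 7 maps back to (1, 3).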


From 1788bd1f2ba4ec496cb5afa62da884f92c1761a4 Mon Sep 17 00:00:00 2001
From: Xavier Routh <xrouth2@illinois.edu>
Date: Mon, 3 Feb 2025 11:10:10 -0600
Subject: [PATCH 20/33] tiling + dim merge with one edit per loop dim

---
 hercules_opt/src/fork_transforms.rs | 96 ++++++++++++++++-------------
 juno_scheduler/src/compile.rs       |  1 +
 juno_scheduler/src/ir.rs            |  1 +
 juno_scheduler/src/pm.rs            | 34 ++++++++++
 4 files changed, 90 insertions(+), 42 deletions(-)

diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index 58ace775..cbb09bbf 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -695,6 +695,24 @@ pub(crate) fn split_fork(
     }
 }
 
+pub fn chunk_all_forks_unguarded(
+    editor: &mut FunctionEditor,
+    fork_join_map: &HashMap<NodeID, NodeID>,
+    dim_idx: usize,
+    tile_size: usize,
+) -> () {
+    // Add dc
+    let mut dc_id = DynamicConstantID::new(0);
+    editor.edit(|mut edit| {
+        dc_id = edit.add_dynamic_constant(DynamicConstant::Constant(tile_size));
+        Ok(edit)
+    });
+
+    for (fork, _ ) in fork_join_map {
+        chunk_fork_unguarded(editor, *fork, dim_idx, dc_id);
+    }
+
+}
 // Splits a dimension of a single fork join into multiple. 
 // Iterates an outer loop original_dim / tile_size times 
 // adds a tile_size loop as the inner loop 
@@ -711,39 +729,36 @@ pub fn chunk_fork_unguarded(
 
     let mut new_factors: Vec<_> = old_factors.to_vec();
 
-    let fork_users: Vec<_> = editor.get_users(fork).collect();
+    let fork_users: Vec<_> = editor.get_users(fork).map(|f| (f, editor.node(f).clone())).collect();
+
+    editor.edit(|mut edit| {
+        let outer = DynamicConstant::div(new_factors[dim_idx], tile_size);
+        new_factors.insert(dim_idx + 1, tile_size);
+        new_factors[dim_idx] = edit.add_dynamic_constant(outer);
+
+        let new_fork = Node::Fork { control: old_control, factors: new_factors.into() };
+        let new_fork = edit.add_node(new_fork);
 
+        edit = edit.replace_all_uses(fork, new_fork)?;
 
-    for tid in fork_users {
-        let Node::ThreadID { control: _, dimension: tid_dim } = editor.node(tid).clone() else { continue };
-        editor.edit(|mut edit| {
+        for (tid, node) in fork_users {
+            let Node::ThreadID { control: _, dimension: tid_dim } = node else {continue};
             if tid_dim > dim_idx {
-                let new_tid = Node::ThreadID { control: fork, dimension: tid_dim + 1 };
+                let new_tid = Node::ThreadID { control: new_fork, dimension: tid_dim + 1 };
                 let new_tid = edit.add_node(new_tid);
-                edit.replace_all_uses(tid, new_tid)
+                edit = edit.replace_all_uses(tid, new_tid)?;
             } else if tid_dim == dim_idx {
-                let tile_tid = Node::ThreadID { control: fork, dimension: tid_dim + 1 };
+                let tile_tid = Node::ThreadID { control: new_fork, dimension: tid_dim + 1 };
                 let tile_tid = edit.add_node(tile_tid);
                 
                 let tile_size = edit.add_node(Node::DynamicConstant { id: tile_size });
                 let mul = edit.add_node(Node::Binary { left: tid, right: tile_size, op: BinaryOperator::Mul });
                 let add = edit.add_node(Node::Binary { left: mul, right: tile_tid, op: BinaryOperator::Add });
-                edit.replace_all_uses_where(tid, add, |usee| *usee != mul )
-            } else {
-                Ok(edit)
+                edit = edit.replace_all_uses_where(tid, add, |usee| *usee != mul )?;
             }
-        });
-    }
-
-    editor.edit(|mut edit| {
-        let outer = DynamicConstant::div(new_factors[dim_idx], tile_size);
-        new_factors.insert(dim_idx + 1, tile_size);
-        new_factors[dim_idx] = edit.add_dynamic_constant(outer);
-
-        let new_fork = Node::Fork { control: old_control, factors: new_factors.into() };
-        let new_fork = edit.add_node(new_fork);
-
-        edit.replace_all_uses(fork, new_fork)
+        }
+        edit = edit.delete_node(fork)?;
+        Ok(edit)
     });
 }
 
@@ -791,9 +806,8 @@ pub fn fork_dim_merge(
 
     let mut new_factors: Vec<_> = old_factors.to_vec();
 
-    
 
-    let fork_users: Vec<_> = editor.get_users(fork).collect();
+    let fork_users: Vec<_> = editor.get_users(fork).map(|f| (f, editor.node(f).clone())).collect();
 
     let mut new_nodes = vec![];
 
@@ -801,6 +815,7 @@ pub fn fork_dim_merge(
     let inner_dc_id = new_factors[inner_idx];
 
     let mut new_fork_id = NodeID::new(0);
+
     editor.edit(|mut edit| {
         new_factors[outer_idx] = edit.add_dynamic_constant(DynamicConstant::mul(new_factors[outer_idx], new_factors[inner_idx]));
         new_factors.remove(inner_idx);
@@ -809,22 +824,20 @@ pub fn fork_dim_merge(
         let new_fork = edit.add_node(new_fork);
         new_fork_id = new_fork;
 
+        edit.sub_edit(fork, new_fork);
 
         edit = edit.replace_all_uses(fork, new_fork)?;
-        edit.delete_node(fork)
-    });
-
-    
-
-    for tid in fork_users {
-        let Node::ThreadID { control: _, dimension: tid_dim } = editor.node(tid).clone() else { continue };
+        edit = edit.delete_node(fork)?;
 
-        println!("tid: {:?}", tid);
-        editor.edit(|mut edit| {
+        for (tid, node) in fork_users {
+            // FIXME: Do we want sub edits in this?
+            
+            let Node::ThreadID { control: _, dimension: tid_dim } = node else { continue };
             if tid_dim > inner_idx {
                 let new_tid = Node::ThreadID { control: new_fork_id, dimension: tid_dim - 1 };
                 let new_tid = edit.add_node(new_tid);
-                edit.replace_all_uses(tid, new_tid)
+                edit = edit.replace_all_uses(tid, new_tid)?;
+                edit.sub_edit(tid, new_tid);
             } else if tid_dim == outer_idx {
                 let outer_tid = Node::ThreadID { control: new_fork_id, dimension: outer_idx };
                 let outer_tid = edit.add_node(outer_tid);
@@ -834,8 +847,8 @@ pub fn fork_dim_merge(
 
                 // inner_idx % dim(outer_idx)
                 let rem = edit.add_node(Node::Binary { left: outer_tid, right: outer_dc, op: BinaryOperator::Rem});
-                
-                edit.replace_all_uses(tid, rem)
+                edit.sub_edit(tid, rem);
+                edit = edit.replace_all_uses(tid, rem)?;
             } else if tid_dim == inner_idx {
                 let outer_tid = Node::ThreadID { control: new_fork_id, dimension: outer_idx };
                 let outer_tid = edit.add_node(outer_tid);
@@ -843,13 +856,12 @@ pub fn fork_dim_merge(
                 let outer_dc = edit.add_node(Node::DynamicConstant { id: outer_dc_id });
                 // inner_idx / dim(outer_idx)
                 let div = edit.add_node(Node::Binary { left: outer_tid, right: outer_dc, op: BinaryOperator::Div});
-                
-                edit.replace_all_uses(tid, div)
-            } else {
-                Ok(edit)
+                edit.sub_edit(tid, div);
+                edit = edit.replace_all_uses(tid, div)?;
             }
-        });
-    };
+        }
+        Ok(edit)
+    });
 
     return new_fork_id;
 
diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs
index 07ad5e7a..49dedd2b 100644
--- a/juno_scheduler/src/compile.rs
+++ b/juno_scheduler/src/compile.rs
@@ -110,6 +110,7 @@ impl FromStr for Appliable {
                 Ok(Appliable::Pass(ir::Pass::InterproceduralSROA))
             },
             "fork-dim-merge" => Ok(Appliable::Pass(ir::Pass::ForkDimMerge)),
+            "fork-chunk" | "fork-tile" => Ok(Appliable::Pass(ir::Pass::ForkChunk)),
             "lift-dc-math" => Ok(Appliable::Pass(ir::Pass::LiftDCMath)),
             "outline" => Ok(Appliable::Pass(ir::Pass::Outline)),
             "phi-elim" => Ok(Appliable::Pass(ir::Pass::PhiElim)),
diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs
index 939ef925..796437a7 100644
--- a/juno_scheduler/src/ir.rs
+++ b/juno_scheduler/src/ir.rs
@@ -36,6 +36,7 @@ impl Pass {
     pub fn num_args(&self) -> usize {
         match self {
             Pass::Xdot => 1,
+            Pass::ForkChunk => 3, 
             _ => 0,
         }
     }
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 8b71d24e..5740d2a6 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -1891,6 +1891,40 @@ fn run_pass(
             pm.delete_gravestones();
             pm.clear_analyses();
         }
+        Pass::ForkChunk => {
+            assert_eq!(args.len(), 3);
+
+            let Some(Value::Boolean { val: guarded_flag }) = args.get(2) else {
+                panic!(); // How to error here?
+            };
+
+            let Some(Value::Integer { val: dim_idx }) = args.get(1) else {
+                panic!(); // How to error here?
+            };
+
+            let Some(Value::Integer { val: tile_size }) = args.get(0) else {
+                panic!(); // How to error here?
+            };
+
+            assert_eq!(*guarded_flag, true);
+            pm.make_fork_join_maps();
+            let fork_join_maps = pm.fork_join_maps.take().unwrap();
+            for (func, fork_join_map) in
+                build_selection(pm, selection)
+                    .into_iter()
+                    .zip(fork_join_maps.iter())
+            {
+                let Some(mut func) = func else {
+                    continue;
+                };
+                chunk_all_forks_unguarded(&mut func, fork_join_map, *dim_idx, *tile_size);
+                changed |= func.modified();
+            }
+            pm.delete_gravestones();
+            pm.clear_analyses();
+        }
         Pass::ForkDimMerge => {
             assert!(args.is_empty());
             pm.make_fork_join_maps();
-- 
GitLab
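
Note: the comment "tid_dim_idx = tid_dim_idx * tile_size + tid_(dim_idx + 1)"
is the identity the tiling rewrite encodes. A standalone sketch (illustrative
helper, not Hercules API), assuming tile_size evenly divides the original
dimension:

    // Reconstruct the original index from the (outer, tile) pair produced
    // by splitting one fork dimension into outer and inner dimensions.
    fn original_index(outer_tid: usize, tile_tid: usize, tile_size: usize) -> usize {
        outer_tid * tile_size + tile_tid
    }

    // With tile_size = 4, (outer_tid, tile_tid) = (1, 2) maps to index 6.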


From 27731fa4fe3553ba1a93a98115636b60a4dd00ce Mon Sep 17 00:00:00 2001
From: Xavier Routh <xrouth2@illinois.edu>
Date: Mon, 3 Feb 2025 11:15:56 -0600
Subject: [PATCH 21/33] check for out of bounds dim on chunking

---
 hercules_opt/src/fork_transforms.rs | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index cbb09bbf..190dbd25 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -725,8 +725,13 @@ pub fn chunk_fork_unguarded(
 ) -> () {
     // tid_dim_idx = tid_dim_idx * tile_size + tid_(dim_idx + 1)
 
+
     let Node::Fork { control: old_control, factors: ref old_factors} = *editor.node(fork) else {return};
 
+    if dim_idx >= old_factors.len()  {
+        return; // FIXME Error here?
+    }
+
     let mut new_factors: Vec<_> = old_factors.to_vec();
 
     let fork_users: Vec<_> = editor.get_users(fork).map(|f| (f, editor.node(f).clone())).collect();
-- 
GitLab
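
Note: with "fork-chunk"/"fork-tile" and "fork-dim-merge" registered in the
scheduler, the zero-argument merge pass can be driven the same way the loop
tests drive Forkify. A hedged sketch (ForkChunk takes three arguments, whose
schedule syntax is not shown in these patches):

    let schedule = Some(default_schedule![Forkify, ForkDimMerge,]);
    let module = run_schedule_on_hercules(module, schedule).unwrap();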


From 1f1c6cb94c80ac867958c41d2a84d506bf06ef92 Mon Sep 17 00:00:00 2001
From: Xavier Routh <xrouth2@illinois.edu>
Date: Mon, 3 Feb 2025 11:46:58 -0600
Subject: [PATCH 22/33] rewrite forkify as single edit per loop

---
 hercules_opt/src/forkify.rs | 75 ++++++++++++++++++++-----------------
 1 file changed, 40 insertions(+), 35 deletions(-)

diff --git a/hercules_opt/src/forkify.rs b/hercules_opt/src/forkify.rs
index 38b9aaaa..0a2d5601 100644
--- a/hercules_opt/src/forkify.rs
+++ b/hercules_opt/src/forkify.rs
@@ -298,29 +298,11 @@ pub fn forkify_loop(
     let (_, factors) = function.nodes[fork_id.idx()].try_fork().unwrap();
     let dimension = factors.len() - 1;
 
-    // Create ThreadID
-    editor.edit(|mut edit| {
-        let thread_id = Node::ThreadID {
-            control: fork_id,
-            dimension: dimension,
-        };
-        let thread_id_id = edit.add_node(thread_id);
-
-        // Replace uses that are inside with the thread id
-        edit = edit.replace_all_uses_where(canonical_iv.phi(), thread_id_id, |node| {
-            loop_nodes.contains(node)
-        })?;
+    // Start fallible edit:
 
-        // Replace uses that are outside with DC - 1. Or just give up.
-        let bound_dc_node = edit.add_node(Node::DynamicConstant { id: bound_dc_id });
-        edit = edit.replace_all_uses_where(canonical_iv.phi(), bound_dc_node, |node| {
-            !loop_nodes.contains(node)
-        })?;
-
-        edit.delete_node(canonical_iv.phi())
-    });
-
-    for reduction_phi in reductionable_phis {
+    let reductionable_phis_and_init: Vec<(_, NodeID)> =
+        reductionable_phis.iter().map(|reduction_phi| {
+        
         let LoopPHI::Reductionable {
             phi,
             data_cycle: _,
@@ -342,12 +324,41 @@ pub fn forkify_loop(
         .unwrap()
         .1;
 
-        editor.edit(|mut edit| {
+        (reduction_phi, init)
+    }).collect();
+
+    editor.edit(|mut edit| {
+        let thread_id = Node::ThreadID {
+            control: fork_id,
+            dimension: dimension,
+        };
+        let thread_id_id = edit.add_node(thread_id);
+
+        // Replace uses that are inside with the thread id
+        edit = edit.replace_all_uses_where(canonical_iv.phi(), thread_id_id, |node| {
+            loop_nodes.contains(node)
+        })?;
+
+        edit = edit.delete_node(canonical_iv.phi())?;
+
+        for (reduction_phi, init) in reductionable_phis_and_init {
+            let LoopPHI::Reductionable {
+                phi,
+                data_cycle: _,
+                continue_latch,
+                is_associative: _,
+            } = *reduction_phi
+            else {
+                panic!();
+            };
+            
             let reduce = Node::Reduce {
                 control: join_id,
                 init,
                 reduct: continue_latch,
             };
+            
+            
             let reduce_id = edit.add_node(reduce);
 
             if (!edit.get_node(init).is_reduce()
@@ -375,20 +386,14 @@ pub fn forkify_loop(
             edit = edit.replace_all_uses_where(continue_latch, reduce_id, |usee| {
                 !loop_nodes.contains(usee) && *usee != reduce_id
             })?;
-            edit.delete_node(phi)
-        });
-    }
-
-    // Replace all uses of the loop header with the fork
-    editor.edit(|edit| edit.replace_all_uses(l.header, fork_id));
+            edit = edit.delete_node(phi)?
 
-    editor.edit(|edit| edit.replace_all_uses(loop_continue_projection, fork_id));
+        }
 
-    editor.edit(|edit| edit.replace_all_uses(loop_exit_projection, join_id));
+        edit = edit.replace_all_uses(l.header, fork_id)?;
+        edit = edit.replace_all_uses(loop_continue_projection, fork_id)?;
+        edit = edit.replace_all_uses(loop_exit_projection, join_id)?;
 
-    // Get rid of loop condition
-    // DCE should get these, but delete them ourselves because we are nice :)
-    editor.edit(|mut edit| {
         edit = edit.delete_node(loop_continue_projection)?;
         edit = edit.delete_node(condition_node)?; // Might have to get rid of other users of this.
         edit = edit.delete_node(loop_exit_projection)?;
@@ -396,7 +401,7 @@ pub fn forkify_loop(
         edit = edit.delete_node(l.header)?;
         Ok(edit)
     });
-
+    
     return true;
 }
 
-- 
GitLab
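
Note: a minimal sketch of the single-edit pattern this patch adopts: every
mutation runs inside one editor.edit closure and chains with `?`, so the
whole forkify rewrite either commits or rolls back as a unit (dc_id and
old_node are assumed to be in scope):

    editor.edit(|mut edit| {
        let dc = edit.add_node(Node::DynamicConstant { id: dc_id });
        edit = edit.replace_all_uses(old_node, dc)?;
        edit.delete_node(old_node)
    });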


From 6382ef4263b16f54a8d3b4d5e3a795c9c9e11013 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Mon, 3 Feb 2025 15:44:25 -0600
Subject: [PATCH 23/33] fix ups

---
 Cargo.lock                                    |  11 --
 hercules_cg/src/fork_tree.rs                  |  19 +-
 hercules_opt/src/fork_guard_elim.rs           |  19 +-
 hercules_opt/src/fork_transforms.rs           | 166 +++++++++++-------
 hercules_opt/src/forkify.rs                   |  61 ++++---
 hercules_samples/dot/build.rs                 |   6 +-
 hercules_samples/dot/src/main.rs              |   2 +-
 hercules_samples/matmul/build.rs              |   6 +-
 hercules_samples/matmul/src/main.rs           |   6 +-
 .../hercules_interpreter/src/interpreter.rs   |  76 ++++----
 .../hercules_interpreter/src/value.rs         |  10 +-
 .../hercules_tests/tests/loop_tests.rs        |  40 ++---
 juno_frontend/src/semant.rs                   |  11 +-
 juno_samples/cava/src/main.rs                 |  45 ++---
 juno_samples/concat/src/main.rs               |   4 +-
 juno_samples/edge_detection/src/main.rs       |  11 +-
 juno_samples/matmul/src/main.rs               |  18 +-
 juno_samples/nested_ccp/src/main.rs           |   2 +-
 juno_samples/patterns/src/main.rs             |   2 +-
 juno_samples/schedule_test/build.rs           |   6 +-
 juno_samples/schedule_test/src/main.rs        |  13 +-
 juno_samples/simple3/src/main.rs              |   2 +-
 juno_scheduler/src/compile.rs                 |   2 +-
 juno_scheduler/src/ir.rs                      |   2 +-
 juno_scheduler/src/pm.rs                      |  31 ++--
 25 files changed, 336 insertions(+), 235 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 7a70825a..af7902c6 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1236,17 +1236,6 @@ dependencies = [
  "with_builtin_macros",
 ]
 
-[[package]]
-name = "juno_test"
-version = "0.1.0"
-dependencies = [
- "async-std",
- "hercules_rt",
- "juno_build",
- "rand",
- "with_builtin_macros",
-]
-
 [[package]]
 name = "juno_utils"
 version = "0.1.0"
diff --git a/hercules_cg/src/fork_tree.rs b/hercules_cg/src/fork_tree.rs
index 64a93160..c048f7e3 100644
--- a/hercules_cg/src/fork_tree.rs
+++ b/hercules_cg/src/fork_tree.rs
@@ -9,11 +9,16 @@ use crate::*;
  * c) no domination by any other fork that's also dominated by F, where we do count self-domination
  * Here too we include the non-fork start node, as key for all controls outside any fork.
  */
-pub fn fork_control_map(fork_join_nesting: &HashMap<NodeID, Vec<NodeID>>) -> HashMap<NodeID, HashSet<NodeID>> {
+pub fn fork_control_map(
+    fork_join_nesting: &HashMap<NodeID, Vec<NodeID>>,
+) -> HashMap<NodeID, HashSet<NodeID>> {
     let mut fork_control_map = HashMap::new();
     for (control, forks) in fork_join_nesting {
         let fork = forks.first().copied().unwrap_or(NodeID::new(0));
-        fork_control_map.entry(fork).or_insert_with(HashSet::new).insert(*control);
+        fork_control_map
+            .entry(fork)
+            .or_insert_with(HashSet::new)
+            .insert(*control);
     }
     fork_control_map
 }
@@ -24,13 +29,19 @@ pub fn fork_control_map(fork_join_nesting: &HashMap<NodeID, Vec<NodeID>>) -> Has
  * c) no domination by any other fork that's also dominated by F, where we don't count self-domination
  * Note that the fork_tree also includes the non-fork start node, as unique root node.
  */
-pub fn fork_tree(function: &Function, fork_join_nesting: &HashMap<NodeID, Vec<NodeID>>) -> HashMap<NodeID, HashSet<NodeID>> {
+pub fn fork_tree(
+    function: &Function,
+    fork_join_nesting: &HashMap<NodeID, Vec<NodeID>>,
+) -> HashMap<NodeID, HashSet<NodeID>> {
     let mut fork_tree = HashMap::new();
     for (control, forks) in fork_join_nesting {
         if function.nodes[control.idx()].is_fork() {
             fork_tree.entry(*control).or_insert_with(HashSet::new);
             let nesting_fork = forks.get(1).copied().unwrap_or(NodeID::new(0));
-            fork_tree.entry(nesting_fork).or_insert_with(HashSet::new).insert(*control);
+            fork_tree
+                .entry(nesting_fork)
+                .or_insert_with(HashSet::new)
+                .insert(*control);
         }
     }
     fork_tree
diff --git a/hercules_opt/src/fork_guard_elim.rs b/hercules_opt/src/fork_guard_elim.rs
index f6914b74..df40e60f 100644
--- a/hercules_opt/src/fork_guard_elim.rs
+++ b/hercules_opt/src/fork_guard_elim.rs
@@ -76,13 +76,16 @@ fn guarded_fork(
         };
 
         // Filter out any terms which are just 1s
-        let non_ones = xs.iter().filter(|i| {
-            if let DynamicConstant::Constant(1) = editor.get_dynamic_constant(**i).deref() {
-                false
-            } else {
-                true
-            }
-        }).collect::<Vec<_>>();
+        let non_ones = xs
+            .iter()
+            .filter(|i| {
+                if let DynamicConstant::Constant(1) = editor.get_dynamic_constant(**i).deref() {
+                    false
+                } else {
+                    true
+                }
+            })
+            .collect::<Vec<_>>();
         // If we're left with just one term x, we had max { 1, x }
         if non_ones.len() == 1 {
             Factor::Max(idx, *non_ones[0])
@@ -326,7 +329,7 @@ pub fn fork_guard_elim(editor: &mut FunctionEditor, fork_join_map: &HashMap<Node
         guard_if,
         guard_join_region,
     } in guard_info
-    {   
+    {
         let Some(guard_pred) = editor.get_uses(guard_if).next() else {
             unreachable!()
         };
diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index 190dbd25..ed6283fd 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -708,14 +708,13 @@ pub fn chunk_all_forks_unguarded(
         Ok(edit)
     });
 
-    for (fork, _ ) in fork_join_map {
+    for (fork, _) in fork_join_map {
         chunk_fork_unguarded(editor, *fork, dim_idx, dc_id);
     }
-
 }
-// Splits a dimension of a single fork join into multiple. 
-// Iterates an outer loop original_dim / tile_size times 
-// adds a tile_size loop as the inner loop 
+// Splits a dimension of a single fork join into multiple.
+// Iterates an outer loop original_dim / tile_size times
+// adds a tile_size loop as the inner loop
 // Assumes that tile size divides original dim evenly.
 pub fn chunk_fork_unguarded(
     editor: &mut FunctionEditor,
@@ -724,42 +723,68 @@ pub fn chunk_fork_unguarded(
     tile_size: DynamicConstantID,
 ) -> () {
     // tid_dim_idx = tid_dim_idx * tile_size + tid_(dim_idx + 1)
-
-
-    let Node::Fork { control: old_control, factors: ref old_factors} = *editor.node(fork) else {return};
-
-    if dim_idx >= old_factors.len()  {
-        return; // FIXME Error here?
-    }
-
+    let Node::Fork {
+        control: old_control,
+        factors: ref old_factors,
+    } = *editor.node(fork)
+    else {
+        return;
+    };
+    assert!(dim_idx < old_factors.len());
     let mut new_factors: Vec<_> = old_factors.to_vec();
-
-    let fork_users: Vec<_> = editor.get_users(fork).map(|f| (f, editor.node(f).clone())).collect();
+    let fork_users: Vec<_> = editor
+        .get_users(fork)
+        .map(|f| (f, editor.node(f).clone()))
+        .collect();
 
     editor.edit(|mut edit| {
         let outer = DynamicConstant::div(new_factors[dim_idx], tile_size);
         new_factors.insert(dim_idx + 1, tile_size);
         new_factors[dim_idx] = edit.add_dynamic_constant(outer);
 
-        let new_fork = Node::Fork { control: old_control, factors: new_factors.into() };
+        let new_fork = Node::Fork {
+            control: old_control,
+            factors: new_factors.into(),
+        };
         let new_fork = edit.add_node(new_fork);
 
         edit = edit.replace_all_uses(fork, new_fork)?;
 
         for (tid, node) in fork_users {
-            let Node::ThreadID { control: _, dimension: tid_dim } = node else {continue};
+            let Node::ThreadID {
+                control: _,
+                dimension: tid_dim,
+            } = node
+            else {
+                continue;
+            };
             if tid_dim > dim_idx {
-                let new_tid = Node::ThreadID { control: new_fork, dimension: tid_dim + 1 };
+                let new_tid = Node::ThreadID {
+                    control: new_fork,
+                    dimension: tid_dim + 1,
+                };
                 let new_tid = edit.add_node(new_tid);
                 edit = edit.replace_all_uses(tid, new_tid)?;
+                edit = edit.delete_node(tid)?;
             } else if tid_dim == dim_idx {
-                let tile_tid = Node::ThreadID { control: new_fork, dimension: tid_dim + 1 };
+                let tile_tid = Node::ThreadID {
+                    control: new_fork,
+                    dimension: tid_dim + 1,
+                };
                 let tile_tid = edit.add_node(tile_tid);
-                
+
                 let tile_size = edit.add_node(Node::DynamicConstant { id: tile_size });
-                let mul = edit.add_node(Node::Binary { left: tid, right: tile_size, op: BinaryOperator::Mul });
-                let add = edit.add_node(Node::Binary { left: mul, right: tile_tid, op: BinaryOperator::Add });
-                edit = edit.replace_all_uses_where(tid, add, |usee| *usee != mul )?;
+                let mul = edit.add_node(Node::Binary {
+                    left: tid,
+                    right: tile_size,
+                    op: BinaryOperator::Mul,
+                });
+                let add = edit.add_node(Node::Binary {
+                    left: mul,
+                    right: tile_tid,
+                    op: BinaryOperator::Add,
+                });
+                edit = edit.replace_all_uses_where(tid, add, |usee| *usee != mul)?;
             }
         }
         edit = edit.delete_node(fork)?;
@@ -767,13 +792,13 @@ pub fn chunk_fork_unguarded(
     });
 }
 
-
-pub fn merge_all_fork_dims(
-    editor: &mut FunctionEditor,
-    fork_join_map: &HashMap<NodeID, NodeID>,
-) {
+pub fn merge_all_fork_dims(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>) {
     for (fork, _) in fork_join_map {
-        let Node::Fork { control: _, factors: dims } = editor.node(fork) else {
+        let Node::Fork {
+            control: _,
+            factors: dims,
+        } = editor.node(fork)
+        else {
             unreachable!();
         };
 
@@ -786,10 +811,6 @@ pub fn merge_all_fork_dims(
     }
 }
 
-// Merges two dimensions of a single fork-join into one.
-// The fused dimension iterates dim(dim_idx1) * dim(dim_idx2) times;
-// thread ids over the original dimensions are recovered via div/rem.
 pub fn fork_dim_merge(
     editor: &mut FunctionEditor,
     fork: NodeID,
@@ -806,61 +827,85 @@ pub fn fork_dim_merge(
     } else {
         (dim_idx1, dim_idx2)
     };
-
-    let Node::Fork { control: old_control, factors: ref old_factors} = *editor.node(fork) else {return fork};
-
+    let Node::Fork {
+        control: old_control,
+        factors: ref old_factors,
+    } = *editor.node(fork)
+    else {
+        return fork;
+    };
     let mut new_factors: Vec<_> = old_factors.to_vec();
-
-
-    let fork_users: Vec<_> = editor.get_users(fork).map(|f| (f, editor.node(f).clone())).collect();
-
+    let fork_users: Vec<_> = editor
+        .get_users(fork)
+        .map(|f| (f, editor.node(f).clone()))
+        .collect();
     let mut new_nodes = vec![];
-
     let outer_dc_id = new_factors[outer_idx];
     let inner_dc_id = new_factors[inner_idx];
-
-    let mut new_fork_id = NodeID::new(0);
+    let mut new_fork = NodeID::new(0);
 
     editor.edit(|mut edit| {
-        new_factors[outer_idx] = edit.add_dynamic_constant(DynamicConstant::mul(new_factors[outer_idx], new_factors[inner_idx]));
+        new_factors[outer_idx] = edit.add_dynamic_constant(DynamicConstant::mul(
+            new_factors[outer_idx],
+            new_factors[inner_idx],
+        ));
         new_factors.remove(inner_idx);
-
-        let new_fork = Node::Fork { control: old_control, factors: new_factors.into() };
-        let new_fork = edit.add_node(new_fork);
-        new_fork_id = new_fork;
-
+        new_fork = edit.add_node(Node::Fork {
+            control: old_control,
+            factors: new_factors.into(),
+        });
         edit.sub_edit(fork, new_fork);
-
         edit = edit.replace_all_uses(fork, new_fork)?;
         edit = edit.delete_node(fork)?;
 
         for (tid, node) in fork_users {
-            // FIXME: Do we want sub edits in this?
-            
-            let Node::ThreadID { control: _, dimension: tid_dim } = node else { continue };
+            let Node::ThreadID {
+                control: _,
+                dimension: tid_dim,
+            } = node
+            else {
+                continue;
+            };
             if tid_dim > inner_idx {
-                let new_tid = Node::ThreadID { control: new_fork_id, dimension: tid_dim - 1 };
+                let new_tid = Node::ThreadID {
+                    control: new_fork,
+                    dimension: tid_dim - 1,
+                };
                 let new_tid = edit.add_node(new_tid);
                 edit = edit.replace_all_uses(tid, new_tid)?;
                 edit.sub_edit(tid, new_tid);
             } else if tid_dim == outer_idx {
-                let outer_tid = Node::ThreadID { control: new_fork_id, dimension: outer_idx };
+                let outer_tid = Node::ThreadID {
+                    control: new_fork,
+                    dimension: outer_idx,
+                };
                 let outer_tid = edit.add_node(outer_tid);
 
                 let outer_dc = edit.add_node(Node::DynamicConstant { id: outer_dc_id });
                 new_nodes.push(outer_tid);
 
                 // inner_idx % dim(outer_idx)
-                let rem = edit.add_node(Node::Binary { left: outer_tid, right: outer_dc, op: BinaryOperator::Rem});
+                let rem = edit.add_node(Node::Binary {
+                    left: outer_tid,
+                    right: outer_dc,
+                    op: BinaryOperator::Rem,
+                });
                 edit.sub_edit(tid, rem);
                 edit = edit.replace_all_uses(tid, rem)?;
             } else if tid_dim == inner_idx {
-                let outer_tid = Node::ThreadID { control: new_fork_id, dimension: outer_idx };
+                let outer_tid = Node::ThreadID {
+                    control: new_fork,
+                    dimension: outer_idx,
+                };
                 let outer_tid = edit.add_node(outer_tid);
 
                 let outer_dc = edit.add_node(Node::DynamicConstant { id: outer_dc_id });
                 // inner_idx / dim(outer_idx)
-                let div = edit.add_node(Node::Binary { left: outer_tid, right: outer_dc, op: BinaryOperator::Div});
+                let div = edit.add_node(Node::Binary {
+                    left: outer_tid,
+                    right: outer_dc,
+                    op: BinaryOperator::Div,
+                });
                 edit.sub_edit(tid, div);
                 edit = edit.replace_all_uses(tid, div)?;
             }
@@ -868,6 +913,5 @@ pub fn fork_dim_merge(
         Ok(edit)
     });
 
-    return new_fork_id;
-
-}
\ No newline at end of file
+    new_fork
+}
diff --git a/hercules_opt/src/forkify.rs b/hercules_opt/src/forkify.rs
index 0a2d5601..f6db06ca 100644
--- a/hercules_opt/src/forkify.rs
+++ b/hercules_opt/src/forkify.rs
@@ -300,32 +300,33 @@ pub fn forkify_loop(
 
     // Start fallible edit:
 
-    let reductionable_phis_and_init: Vec<(_, NodeID)> =
-        reductionable_phis.iter().map(|reduction_phi| {
-        
-        let LoopPHI::Reductionable {
-            phi,
-            data_cycle: _,
-            continue_latch,
-            is_associative: _,
-        } = reduction_phi
-        else {
-            panic!();
-        };
+    let reductionable_phis_and_init: Vec<(_, NodeID)> = reductionable_phis
+        .iter()
+        .map(|reduction_phi| {
+            let LoopPHI::Reductionable {
+                phi,
+                data_cycle: _,
+                continue_latch,
+                is_associative: _,
+            } = reduction_phi
+            else {
+                panic!();
+            };
 
-        let function = editor.func();
+            let function = editor.func();
 
-        let init = *zip(
-            editor.get_uses(l.header),
-            function.nodes[phi.idx()].try_phi().unwrap().1.iter(),
-        )
-        .filter(|(c, _)| *c == loop_pred)
-        .next()
-        .unwrap()
-        .1;
+            let init = *zip(
+                editor.get_uses(l.header),
+                function.nodes[phi.idx()].try_phi().unwrap().1.iter(),
+            )
+            .filter(|(c, _)| *c == loop_pred)
+            .next()
+            .unwrap()
+            .1;
 
-        (reduction_phi, init)
-    }).collect();
+            (reduction_phi, init)
+        })
+        .collect();
 
     editor.edit(|mut edit| {
         let thread_id = Node::ThreadID {
@@ -351,14 +352,13 @@ pub fn forkify_loop(
             else {
                 panic!();
             };
-            
+
             let reduce = Node::Reduce {
                 control: join_id,
                 init,
                 reduct: continue_latch,
             };
-            
-            
+
             let reduce_id = edit.add_node(reduce);
 
             if (!edit.get_node(init).is_reduce()
@@ -387,7 +387,6 @@ pub fn forkify_loop(
                 !loop_nodes.contains(usee) && *usee != reduce_id
             })?;
             edit = edit.delete_node(phi)?
-
         }
 
         edit = edit.replace_all_uses(l.header, fork_id)?;
@@ -401,7 +400,7 @@ pub fn forkify_loop(
         edit = edit.delete_node(l.header)?;
         Ok(edit)
     });
-    
+
     return true;
 }
 
@@ -538,7 +537,11 @@ pub fn analyze_phis<'a>(
             // by the time the reduce is triggered (at the end of the loop's internal control).
             // If anything in the intersection is a phi (that isn't this own phi), then the reduction cycle depends on control.
             // Which is not allowed.
-            if intersection.iter().any(|cycle_node| editor.node(cycle_node).is_phi() && *cycle_node != *phi) || editor.node(loop_continue_latch).is_phi() {
+            if intersection
+                .iter()
+                .any(|cycle_node| editor.node(cycle_node).is_phi() && *cycle_node != *phi)
+                || editor.node(loop_continue_latch).is_phi()
+            {
                 return LoopPHI::ControlDependant(*phi);
             }
 
diff --git a/hercules_samples/dot/build.rs b/hercules_samples/dot/build.rs
index 8657fdc1..c8de7e90 100644
--- a/hercules_samples/dot/build.rs
+++ b/hercules_samples/dot/build.rs
@@ -4,7 +4,11 @@ fn main() {
     JunoCompiler::new()
         .ir_in_src("dot.hir")
         .unwrap()
-        .schedule_in_src(if cfg!(feature = "cuda") { "gpu.sch" } else { "cpu.sch" })
+        .schedule_in_src(if cfg!(feature = "cuda") {
+            "gpu.sch"
+        } else {
+            "cpu.sch"
+        })
         .unwrap()
         .build()
         .unwrap();
diff --git a/hercules_samples/dot/src/main.rs b/hercules_samples/dot/src/main.rs
index 8862c11a..7f5b453a 100644
--- a/hercules_samples/dot/src/main.rs
+++ b/hercules_samples/dot/src/main.rs
@@ -1,8 +1,8 @@
 #![feature(concat_idents)]
 
-use hercules_rt::{runner, HerculesCPURef};
 #[cfg(feature = "cuda")]
 use hercules_rt::CUDABox;
+use hercules_rt::{runner, HerculesCPURef};
 
 juno_build::juno!("dot");
 
diff --git a/hercules_samples/matmul/build.rs b/hercules_samples/matmul/build.rs
index 735458c0..ed92e022 100644
--- a/hercules_samples/matmul/build.rs
+++ b/hercules_samples/matmul/build.rs
@@ -4,7 +4,11 @@ fn main() {
     JunoCompiler::new()
         .ir_in_src("matmul.hir")
         .unwrap()
-        .schedule_in_src(if cfg!(feature = "cuda") { "gpu.sch" } else { "cpu.sch" })
+        .schedule_in_src(if cfg!(feature = "cuda") {
+            "gpu.sch"
+        } else {
+            "cpu.sch"
+        })
         .unwrap()
         .build()
         .unwrap();
diff --git a/hercules_samples/matmul/src/main.rs b/hercules_samples/matmul/src/main.rs
index abd25ec9..5c879915 100644
--- a/hercules_samples/matmul/src/main.rs
+++ b/hercules_samples/matmul/src/main.rs
@@ -2,9 +2,9 @@
 
 use rand::random;
 
-use hercules_rt::{runner, HerculesCPURef};
 #[cfg(feature = "cuda")]
 use hercules_rt::CUDABox;
+use hercules_rt::{runner, HerculesCPURef};
 
 juno_build::juno!("matmul");
 
@@ -36,7 +36,9 @@ fn main() {
             let a = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut a));
             let b = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut b));
             let mut r = runner!(matmul);
-            let c = r.run(I as u64, J as u64, K as u64, a.get_ref(), b.get_ref()).await;
+            let c = r
+                .run(I as u64, J as u64, K as u64, a.get_ref(), b.get_ref())
+                .await;
             let mut c_cpu: Box<[i32]> = vec![0; correct_c.len()].into_boxed_slice();
             c.to_cpu_ref(&mut c_cpu);
             assert_eq!(&*c_cpu, &*correct_c);
diff --git a/hercules_test/hercules_interpreter/src/interpreter.rs b/hercules_test/hercules_interpreter/src/interpreter.rs
index 22ef062a..2e352644 100644
--- a/hercules_test/hercules_interpreter/src/interpreter.rs
+++ b/hercules_test/hercules_interpreter/src/interpreter.rs
@@ -69,18 +69,18 @@ pub fn dyn_const_value(
     match dc {
         DynamicConstant::Constant(v) => *v,
         DynamicConstant::Parameter(v) => dyn_const_params[*v],
-        DynamicConstant::Add(xs) => {
-            xs.iter().map(|x| dyn_const_value(x, dyn_const_values, dyn_const_params))
-                .fold(0, |s, v| s + v)
-        }
+        DynamicConstant::Add(xs) => xs
+            .iter()
+            .map(|x| dyn_const_value(x, dyn_const_values, dyn_const_params))
+            .fold(0, |s, v| s + v),
         DynamicConstant::Sub(a, b) => {
             dyn_const_value(a, dyn_const_values, dyn_const_params)
                 - dyn_const_value(b, dyn_const_values, dyn_const_params)
         }
-        DynamicConstant::Mul(xs) => {
-            xs.iter().map(|x| dyn_const_value(x, dyn_const_values, dyn_const_params))
-                .fold(1, |p, v| p * v)
-        }
+        DynamicConstant::Mul(xs) => xs
+            .iter()
+            .map(|x| dyn_const_value(x, dyn_const_values, dyn_const_params))
+            .fold(1, |p, v| p * v),
         DynamicConstant::Div(a, b) => {
             dyn_const_value(a, dyn_const_values, dyn_const_params)
                 / dyn_const_value(b, dyn_const_values, dyn_const_params)
@@ -89,28 +89,28 @@ pub fn dyn_const_value(
             dyn_const_value(a, dyn_const_values, dyn_const_params)
                 % dyn_const_value(b, dyn_const_values, dyn_const_params)
         }
-        DynamicConstant::Max(xs) => {
-            xs.iter().map(|x| dyn_const_value(x, dyn_const_values, dyn_const_params))
-                .fold(None, |m, v| {
-                    if let Some(m) = m {
-                        Some(max(m, v))
-                    } else {
-                        Some(v)
-                    }
-                })
-                .unwrap()
-        }
-        DynamicConstant::Min(xs) => {
-            xs.iter().map(|x| dyn_const_value(x, dyn_const_values, dyn_const_params))
-                .fold(None, |m, v| {
-                    if let Some(m) = m {
-                        Some(min(m, v))
-                    } else {
-                        Some(v)
-                    }
-                })
-                .unwrap()
-        }
+        DynamicConstant::Max(xs) => xs
+            .iter()
+            .map(|x| dyn_const_value(x, dyn_const_values, dyn_const_params))
+            .fold(None, |m, v| {
+                if let Some(m) = m {
+                    Some(max(m, v))
+                } else {
+                    Some(v)
+                }
+            })
+            .unwrap(),
+        DynamicConstant::Min(xs) => xs
+            .iter()
+            .map(|x| dyn_const_value(x, dyn_const_values, dyn_const_params))
+            .fold(None, |m, v| {
+                if let Some(m) = m {
+                    Some(min(m, v))
+                } else {
+                    Some(v)
+                }
+            })
+            .unwrap(),
     }
 }
 
@@ -775,15 +775,13 @@ impl<'a> FunctionExecutionState<'a> {
                     //     panic!("multi-dimensional forks unimplemented")
                     // }
 
-                    let factors = factors
-                        .iter()
-                        .map(|f| {
-                            dyn_const_value(
-                                &f,
-                                &self.module.dynamic_constants,
-                                &self.dynamic_constant_params,
-                            )
-                        });
+                    let factors = factors.iter().map(|f| {
+                        dyn_const_value(
+                            &f,
+                            &self.module.dynamic_constants,
+                            &self.dynamic_constant_params,
+                        )
+                    });
 
                     let n_tokens: usize = factors.clone().product();
 
diff --git a/hercules_test/hercules_interpreter/src/value.rs b/hercules_test/hercules_interpreter/src/value.rs
index adbed6e6..4a802f7a 100644
--- a/hercules_test/hercules_interpreter/src/value.rs
+++ b/hercules_test/hercules_interpreter/src/value.rs
@@ -156,7 +156,15 @@ impl<'a> InterpreterVal {
             Constant::Float64(v) => Self::Float64(v),
 
             Constant::Product(ref type_id, ref constant_ids) => {
-                let contents = constant_ids.iter().map(|const_id| InterpreterVal::from_constant(&constants[const_id.idx()], constants, types, dynamic_constants, dynamic_constant_params));
+                let contents = constant_ids.iter().map(|const_id| {
+                    InterpreterVal::from_constant(
+                        &constants[const_id.idx()],
+                        constants,
+                        types,
+                        dynamic_constants,
+                        dynamic_constant_params,
+                    )
+                });
                 InterpreterVal::Product(*type_id, contents.collect_vec().into_boxed_slice())
             }
             Constant::Summation(_, _, _) => todo!(),
diff --git a/hercules_test/hercules_tests/tests/loop_tests.rs b/hercules_test/hercules_tests/tests/loop_tests.rs
index 192c1366..795642b2 100644
--- a/hercules_test/hercules_tests/tests/loop_tests.rs
+++ b/hercules_test/hercules_tests/tests/loop_tests.rs
@@ -35,9 +35,7 @@ fn alternate_bounds_use_after_loop_no_tid() {
 
     println!("result: {:?}", result_1);
 
-    let schedule = default_schedule![
-        Forkify,
-    ];
+    let schedule = default_schedule![Forkify,];
 
     let module = run_schedule_on_hercules(module, Some(schedule)).unwrap();
 
@@ -61,9 +59,7 @@ fn alternate_bounds_use_after_loop() {
 
     println!("result: {:?}", result_1);
 
-    let schedule = Some(default_schedule![
-        Forkify,
-    ]);
+    let schedule = Some(default_schedule![Forkify,]);
 
     let module = run_schedule_on_hercules(module, schedule).unwrap();
 
@@ -108,10 +104,7 @@ fn do_while_separate_body() {
 
     println!("result: {:?}", result_1);
 
-    let schedule = Some(default_schedule![
-        PhiElim,
-        Forkify,
-    ]);
+    let schedule = Some(default_schedule![PhiElim, Forkify,]);
 
     let module = run_schedule_on_hercules(module, schedule).unwrap();
 
@@ -131,10 +124,7 @@ fn alternate_bounds_internal_control() {
 
     println!("result: {:?}", result_1);
 
-    let schedule = Some(default_schedule![
-        PhiElim,
-        Forkify,
-    ]);
+    let schedule = Some(default_schedule![PhiElim, Forkify,]);
 
     let module = run_schedule_on_hercules(module, schedule).unwrap();
 
@@ -155,10 +145,7 @@ fn alternate_bounds_internal_control2() {
 
     println!("result: {:?}", result_1);
 
-    let schedule = Some(default_schedule![
-        PhiElim,
-        Forkify,
-    ]);
+    let schedule = Some(default_schedule![PhiElim, Forkify,]);
 
     let module = run_schedule_on_hercules(module, schedule).unwrap();
 
@@ -366,16 +353,13 @@ fn look_at_local() {
         "/home/xavierrouth/dev/hercules/hercules_test/hercules_tests/save_me.hbin",
     );
 
-    let schedule = Some(default_schedule![
-    ]);
+    let schedule = Some(default_schedule![]);
 
     let result_1 = interp_module!(module, 0, dyn_consts, a.clone(), b.clone());
 
     let module = run_schedule_on_hercules(module.clone(), schedule).unwrap();
 
-    let schedule = Some(default_schedule![
-        Unforkify, Verify,
-    ]);
+    let schedule = Some(default_schedule![Unforkify, Verify,]);
 
     let module = run_schedule_on_hercules(module.clone(), schedule).unwrap();
 
@@ -425,7 +409,15 @@ fn matmul_pipeline() {
     };
     assert_eq!(correct_c[0], value);
 
-    let schedule = Some(default_schedule![AutoOutline, InterproceduralSROA, SROA, InferSchedules, DCE, Xdot, GCM]);
+    let schedule = Some(default_schedule![
+        AutoOutline,
+        InterproceduralSROA,
+        SROA,
+        InferSchedules,
+        DCE,
+        Xdot,
+        GCM
+    ]);
 
     module = run_schedule_on_hercules(module, schedule).unwrap();
 
diff --git a/juno_frontend/src/semant.rs b/juno_frontend/src/semant.rs
index e133e3c2..8668d1b4 100644
--- a/juno_frontend/src/semant.rs
+++ b/juno_frontend/src/semant.rs
@@ -752,7 +752,16 @@ fn analyze_program(
                             }
                             arg_info.push((ty, inout.is_some(), var));
 
-                            match process_irrefutable_pattern(pattern, false, var, ty, lexer, &mut stringtab, &mut env, &mut types) {
+                            match process_irrefutable_pattern(
+                                pattern,
+                                false,
+                                var,
+                                ty,
+                                lexer,
+                                &mut stringtab,
+                                &mut env,
+                                &mut types,
+                            ) {
                                 Ok(prep) => {
                                     stmts.extend(prep);
                                 }
diff --git a/juno_samples/cava/src/main.rs b/juno_samples/cava/src/main.rs
index 482bbf8d..e8a7e4e9 100644
--- a/juno_samples/cava/src/main.rs
+++ b/juno_samples/cava/src/main.rs
@@ -8,9 +8,9 @@ use self::camera_model::*;
 use self::cava_rust::CHAN;
 use self::image_proc::*;
 
-use hercules_rt::{runner, HerculesCPURef};
 #[cfg(feature = "cuda")]
 use hercules_rt::CUDABox;
+use hercules_rt::{runner, HerculesCPURef};
 
 use image::ImageError;
 
@@ -31,7 +31,6 @@ fn run_cava(
     coefs: &[f32],
     tonemap: &[f32],
 ) -> Box<[u8]> {
-
     assert_eq!(image.len(), CHAN * rows * cols);
     assert_eq!(tstw.len(), CHAN * CHAN);
     assert_eq!(ctrl_pts.len(), num_ctrl_pts * CHAN);
@@ -47,21 +46,24 @@ fn run_cava(
         let weights = HerculesCPURef::from_slice(weights);
         let coefs = HerculesCPURef::from_slice(coefs);
         let tonemap = HerculesCPURef::from_slice(tonemap);
-	    let mut r = runner!(cava);
-	    async_std::task::block_on(async {
-		r.run(
-		    rows as u64,
-		    cols as u64,
-		    num_ctrl_pts as u64,
-		    image,
-		    tstw,
-		    ctrl_pts,
-		    weights,
-		    coefs,
-		    tonemap,
-		)
-		.await
-	    }).as_slice::<u8>().to_vec().into_boxed_slice()
+        let mut r = runner!(cava);
+        async_std::task::block_on(async {
+            r.run(
+                rows as u64,
+                cols as u64,
+                num_ctrl_pts as u64,
+                image,
+                tstw,
+                ctrl_pts,
+                weights,
+                coefs,
+                tonemap,
+            )
+            .await
+        })
+        .as_slice::<u8>()
+        .to_vec()
+        .into_boxed_slice()
     }
 
     #[cfg(feature = "cuda")]
@@ -72,8 +74,8 @@ fn run_cava(
         let weights = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(weights));
         let coefs = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(coefs));
         let tonemap = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(tonemap));
-	    let mut r = runner!(cava);
-	    let res = async_std::task::block_on(async {
+        let mut r = runner!(cava);
+        let res = async_std::task::block_on(async {
             r.run(
                 rows as u64,
                 cols as u64,
@@ -86,7 +88,7 @@ fn run_cava(
                 tonemap.get_ref(),
             )
             .await
-	    });
+        });
         let num_out = unsafe { res.__size() / std::mem::size_of::<u8>() };
         let mut res_cpu: Box<[u8]> = vec![0; num_out].into_boxed_slice();
         res.to_cpu_ref(&mut res_cpu);
@@ -204,7 +206,8 @@ fn cava_harness(args: CavaInputs) {
                 .expect("Error saving verification image");
         }
 
-        let max_diff = result.iter()
+        let max_diff = result
+            .iter()
             .zip(cpu_result.iter())
             .map(|(a, b)| (*a as i16 - *b as i16).abs())
             .max()
diff --git a/juno_samples/concat/src/main.rs b/juno_samples/concat/src/main.rs
index 9674c2c5..547dee08 100644
--- a/juno_samples/concat/src/main.rs
+++ b/juno_samples/concat/src/main.rs
@@ -1,9 +1,9 @@
 #![feature(concat_idents)]
 
 use hercules_rt::runner;
-use hercules_rt::HerculesCPURef;
 #[cfg(feature = "cuda")]
 use hercules_rt::CUDABox;
+use hercules_rt::HerculesCPURef;
 
 juno_build::juno!("concat");
 
@@ -20,7 +20,7 @@ fn main() {
             assert_eq!(output, 42);
 
             const N: usize = 3;
-            let arr : Box<[i32]> = (2..=4).collect();
+            let arr: Box<[i32]> = (2..=4).collect();
             let arr = HerculesCPURef::from_slice(&arr);
 
             let mut r = runner!(concat_switch);
diff --git a/juno_samples/edge_detection/src/main.rs b/juno_samples/edge_detection/src/main.rs
index eda65016..3b067ebd 100644
--- a/juno_samples/edge_detection/src/main.rs
+++ b/juno_samples/edge_detection/src/main.rs
@@ -2,9 +2,9 @@
 
 mod edge_detection_rust;
 
-use hercules_rt::{runner, HerculesCPURef};
 #[cfg(feature = "cuda")]
 use hercules_rt::CUDABox;
+use hercules_rt::{runner, HerculesCPURef};
 
 use std::slice::from_raw_parts;
 
@@ -228,9 +228,9 @@ fn edge_detection_harness(args: EdgeDetectionInputs) {
         });
 
         #[cfg(not(feature = "cuda"))]
-        let result : Box<[f32]> = result.as_slice::<f32>().to_vec().into_boxed_slice();
+        let result: Box<[f32]> = result.as_slice::<f32>().to_vec().into_boxed_slice();
         #[cfg(feature = "cuda")]
-        let result : Box<[f32]> = {
+        let result: Box<[f32]> = {
             let num_out = unsafe { result.__size() / std::mem::size_of::<f32>() };
             let mut res_cpu: Box<[f32]> = vec![0.0; num_out].into_boxed_slice();
             result.to_cpu_ref(&mut res_cpu);
@@ -261,7 +261,10 @@ fn edge_detection_harness(args: EdgeDetectionInputs) {
                 theta,
             );
 
-            assert_eq!(result.as_ref(), <Vec<f32> as AsRef<[f32]>>::as_ref(&rust_result));
+            assert_eq!(
+                result.as_ref(),
+                <Vec<f32> as AsRef<[f32]>>::as_ref(&rust_result)
+            );
             println!("Frames {} match", i);
 
             if display_verify {
diff --git a/juno_samples/matmul/src/main.rs b/juno_samples/matmul/src/main.rs
index 50fe1760..2892cd34 100644
--- a/juno_samples/matmul/src/main.rs
+++ b/juno_samples/matmul/src/main.rs
@@ -2,9 +2,9 @@
 
 use rand::random;
 
-use hercules_rt::{runner, HerculesCPURef};
 #[cfg(feature = "cuda")]
 use hercules_rt::CUDABox;
+use hercules_rt::{runner, HerculesCPURef};
 
 juno_build::juno!("matmul");
 
@@ -28,10 +28,14 @@ fn main() {
             let a = HerculesCPURef::from_slice(&a);
             let b = HerculesCPURef::from_slice(&b);
             let mut r = runner!(matmul);
-            let c = r.run(I as u64, J as u64, K as u64, a.clone(), b.clone()).await;
+            let c = r
+                .run(I as u64, J as u64, K as u64, a.clone(), b.clone())
+                .await;
             assert_eq!(c.as_slice::<i32>(), &*correct_c);
             let mut r = runner!(tiled_64_matmul);
-            let tiled_c = r.run(I as u64, J as u64, K as u64, a.clone(), b.clone()).await;
+            let tiled_c = r
+                .run(I as u64, J as u64, K as u64, a.clone(), b.clone())
+                .await;
             assert_eq!(tiled_c.as_slice::<i32>(), &*correct_c);
         }
         #[cfg(feature = "cuda")]
@@ -39,12 +43,16 @@ fn main() {
             let a = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut a));
             let b = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut b));
             let mut r = runner!(matmul);
-            let c = r.run(I as u64, J as u64, K as u64, a.get_ref(), b.get_ref()).await;
+            let c = r
+                .run(I as u64, J as u64, K as u64, a.get_ref(), b.get_ref())
+                .await;
             let mut c_cpu: Box<[i32]> = vec![0; correct_c.len()].into_boxed_slice();
             c.to_cpu_ref(&mut c_cpu);
             assert_eq!(&*c_cpu, &*correct_c);
             let mut r = runner!(tiled_64_matmul);
-            let tiled_c = r.run(I as u64, J as u64, K as u64, a.get_ref(), b.get_ref()).await;
+            let tiled_c = r
+                .run(I as u64, J as u64, K as u64, a.get_ref(), b.get_ref())
+                .await;
             let mut tiled_c_cpu: Box<[i32]> = vec![0; correct_c.len()].into_boxed_slice();
             tiled_c.to_cpu_ref(&mut tiled_c_cpu);
             assert_eq!(&*tiled_c_cpu, &*correct_c);
diff --git a/juno_samples/nested_ccp/src/main.rs b/juno_samples/nested_ccp/src/main.rs
index bc99a4bd..b364c03c 100644
--- a/juno_samples/nested_ccp/src/main.rs
+++ b/juno_samples/nested_ccp/src/main.rs
@@ -1,8 +1,8 @@
 #![feature(concat_idents)]
 
-use hercules_rt::{runner, HerculesCPURef, HerculesCPURefMut};
 #[cfg(feature = "cuda")]
 use hercules_rt::CUDABox;
+use hercules_rt::{runner, HerculesCPURef, HerculesCPURefMut};
 
 juno_build::juno!("nested_ccp");
 
diff --git a/juno_samples/patterns/src/main.rs b/juno_samples/patterns/src/main.rs
index 5cc2e7c8..a5586c8b 100644
--- a/juno_samples/patterns/src/main.rs
+++ b/juno_samples/patterns/src/main.rs
@@ -1,6 +1,6 @@
 #![feature(concat_idents)]
 
-use hercules_rt::{runner};
+use hercules_rt::runner;
 
 juno_build::juno!("patterns");
 
diff --git a/juno_samples/schedule_test/build.rs b/juno_samples/schedule_test/build.rs
index 749a660c..0129c4de 100644
--- a/juno_samples/schedule_test/build.rs
+++ b/juno_samples/schedule_test/build.rs
@@ -4,7 +4,11 @@ fn main() {
     JunoCompiler::new()
         .file_in_src("code.jn")
         .unwrap()
-        .schedule_in_src(if cfg!(feature = "cuda") { "gpu.sch" } else { "cpu.sch" })
+        .schedule_in_src(if cfg!(feature = "cuda") {
+            "gpu.sch"
+        } else {
+            "cpu.sch"
+        })
         .unwrap()
         .build()
         .unwrap();
diff --git a/juno_samples/schedule_test/src/main.rs b/juno_samples/schedule_test/src/main.rs
index 1505d4e5..f769e750 100644
--- a/juno_samples/schedule_test/src/main.rs
+++ b/juno_samples/schedule_test/src/main.rs
@@ -2,9 +2,9 @@
 
 use rand::random;
 
-use hercules_rt::{runner, HerculesCPURef};
 #[cfg(feature = "cuda")]
 use hercules_rt::CUDABox;
+use hercules_rt::{runner, HerculesCPURef};
 
 juno_build::juno!("code");
 
@@ -43,7 +43,16 @@ fn main() {
             let b = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&b));
             let c = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&c));
             let mut r = runner!(test);
-            let res = r.run(N as u64, M as u64, K as u64, a.get_ref(), b.get_ref(), c.get_ref()).await;
+            let res = r
+                .run(
+                    N as u64,
+                    M as u64,
+                    K as u64,
+                    a.get_ref(),
+                    b.get_ref(),
+                    c.get_ref(),
+                )
+                .await;
             let mut res_cpu: Box<[i32]> = vec![0; correct_res.len()].into_boxed_slice();
             res.to_cpu_ref(&mut res_cpu);
             assert_eq!(&*res_cpu, &*correct_res);
diff --git a/juno_samples/simple3/src/main.rs b/juno_samples/simple3/src/main.rs
index 8eb78f7c..687ff414 100644
--- a/juno_samples/simple3/src/main.rs
+++ b/juno_samples/simple3/src/main.rs
@@ -1,8 +1,8 @@
 #![feature(concat_idents)]
 
-use hercules_rt::{runner, HerculesCPURef};
 #[cfg(feature = "cuda")]
 use hercules_rt::CUDABox;
+use hercules_rt::{runner, HerculesCPURef};
 
 juno_build::juno!("simple3");
 
diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs
index ea06a0f2..713c30d4 100644
--- a/juno_scheduler/src/compile.rs
+++ b/juno_scheduler/src/compile.rs
@@ -108,7 +108,7 @@ impl FromStr for Appliable {
             "inline" => Ok(Appliable::Pass(ir::Pass::Inline)),
             "ip-sroa" | "interprocedural-sroa" => {
                 Ok(Appliable::Pass(ir::Pass::InterproceduralSROA))
-            },
+            }
             "fork-dim-merge" => Ok(Appliable::Pass(ir::Pass::ForkDimMerge)),
             "fork-chunk" | "fork-tile" => Ok(Appliable::Pass(ir::Pass::ForkChunk)),
             "lift-dc-math" => Ok(Appliable::Pass(ir::Pass::LiftDCMath)),
diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs
index 796437a7..9e85509f 100644
--- a/juno_scheduler/src/ir.rs
+++ b/juno_scheduler/src/ir.rs
@@ -36,7 +36,7 @@ impl Pass {
     pub fn num_args(&self) -> usize {
         match self {
             Pass::Xdot => 1,
-            Pass::ForkChunk => 3, 
+            Pass::ForkChunk => 3,
             _ => 0,
         }
     }
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index d176b636..2142d5c5 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -1566,7 +1566,7 @@ fn run_pass(
                     // this eventually.
                     let c = forkify(&mut func, control_subgraph, fork_join_map, loop_nest);
                     changed |= c;
-                    inner_changed |= c; 
+                    inner_changed |= c;
                 }
                 pm.delete_gravestones();
                 pm.clear_analyses();
@@ -1921,24 +1921,32 @@ fn run_pass(
             let dim_idx = args.get(1);
 
             let Some(Value::Boolean { val: guarded_flag }) = args.get(2) else {
-                panic!(); // How to error here?
+                return Err(SchedulerError::PassError {
+                    pass: "forkChunk".to_string(),
+                    error: "expected boolean argument".to_string(),
+                });
             };
 
             let Some(Value::Integer { val: dim_idx }) = args.get(1) else {
-                panic!(); // How to error here?
+                return Err(SchedulerError::PassError {
+                    pass: "forkChunk".to_string(),
+                    error: "expected integer argument".to_string(),
+                });
             };
 
             let Some(Value::Integer { val: tile_size }) = args.get(0) else {
-                panic!(); // How to error here?
+                return Err(SchedulerError::PassError {
+                    pass: "forkChunk".to_string(),
+                    error: "expected integer argument".to_string(),
+                });
             };
 
             assert_eq!(*guarded_flag, true);
             pm.make_fork_join_maps();
             let fork_join_maps = pm.fork_join_maps.take().unwrap();
-            for (func, fork_join_map) in
-                build_selection(pm, selection)
-                    .into_iter()
-                    .zip(fork_join_maps.iter())
+            for (func, fork_join_map) in build_selection(pm, selection)
+                .into_iter()
+                .zip(fork_join_maps.iter())
             {
                 let Some(mut func) = func else {
                     continue;
@@ -1953,10 +1961,9 @@ fn run_pass(
             assert!(args.is_empty());
             pm.make_fork_join_maps();
             let fork_join_maps = pm.fork_join_maps.take().unwrap();
-            for (func, fork_join_map) in
-                build_selection(pm, selection)
-                    .into_iter()
-                    .zip(fork_join_maps.iter())
+            for (func, fork_join_map) in build_selection(pm, selection)
+                .into_iter()
+                .zip(fork_join_maps.iter())
             {
                 let Some(mut func) = func else {
                     continue;
-- 
GitLab
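
Note on the pm.rs hunk above: it swaps the bare `panic!()` placeholders in fork-chunk argument parsing for structured scheduler errors. A minimal sketch of that `let ... else` validation pattern, with `Value` and `SchedulerError` reduced to simplified stand-ins for the juno_scheduler types:

```rust
enum Value {
    Integer { val: usize },
    Boolean { val: bool },
}

enum SchedulerError {
    PassError { pass: String, error: String },
}

// Validate one positional argument; a shape mismatch becomes a
// recoverable error naming the offending pass, instead of a panic.
fn tile_size(args: &[Value]) -> Result<usize, SchedulerError> {
    let Some(Value::Integer { val }) = args.get(0) else {
        return Err(SchedulerError::PassError {
            pass: "forkChunk".to_string(),
            error: "expected integer argument".to_string(),
        });
    };
    Ok(*val)
}
```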


From 43404780896ebb6bd287c8a56783a08cded6852e Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Mon, 3 Feb 2025 15:53:31 -0600
Subject: [PATCH 24/33] whoops

---
 hercules_opt/src/fork_transforms.rs | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index ed6283fd..c4a6ba7f 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -868,7 +868,7 @@ pub fn fork_dim_merge(
             };
             if tid_dim > inner_idx {
                 let new_tid = Node::ThreadID {
-                    control: new_fork_id,
+                    control: new_fork,
                     dimension: tid_dim - 1,
                 };
                 let new_tid = edit.add_node(new_tid);
@@ -876,7 +876,7 @@ pub fn fork_dim_merge(
                 edit.sub_edit(tid, new_tid);
             } else if tid_dim == outer_idx {
                 let outer_tid = Node::ThreadID {
-                    control: new_fork_id,
+                    control: new_fork,
                     dimension: outer_idx,
                 };
                 let outer_tid = edit.add_node(outer_tid);
@@ -894,7 +894,7 @@ pub fn fork_dim_merge(
                 edit = edit.replace_all_uses(tid, rem)?;
             } else if tid_dim == inner_idx {
                 let outer_tid = Node::ThreadID {
-                    control: new_fork_id,
+                    control: new_fork,
                     dimension: outer_idx,
                 };
                 let outer_tid = edit.add_node(outer_tid);
-- 
GitLab
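
The one-word fix above replaces a stale `new_fork_id` binding with the `new_fork` actually produced by the merge; since both names have the same index type, the mix-up compiled silently. A self-contained illustration of this bug class (plain Rust, not Hercules code; the scenario is assumed):

```rust
// Two bindings of the same newtype index: using the stale one still
// type-checks, so only naming discipline or shadowing catches it.
#[derive(Clone, Copy, Debug, PartialEq)]
struct NodeID(usize);

fn main() {
    let new_fork_id = NodeID(0); // placeholder initialized before the edit
    let new_fork = NodeID(42);   // ID of the node the edit actually created

    // Both names are NodeIDs, so the compiler accepts either; only
    // `new_fork` is correct here.
    assert_ne!(new_fork_id, new_fork);

    // Shadowing the placeholder once the real node exists removes the
    // stale name from scope entirely:
    let new_fork_id = new_fork;
    assert_eq!(new_fork_id, NodeID(42));
}
```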


From e1124c929bd07adbe14d73b28af3fea777e5bbbe Mon Sep 17 00:00:00 2001
From: Xavier Routh <xrouth2@illinois.edu>
Date: Wed, 5 Feb 2025 12:06:36 -0600
Subject: [PATCH 25/33] add bufferize test

---
 juno_samples/fork_join_tests/src/cpu.sch            |  8 ++++++++
 juno_samples/fork_join_tests/src/fork_join_tests.jn | 13 +++++++++++++
 2 files changed, 21 insertions(+)

diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch
index 38010004..a557cd03 100644
--- a/juno_samples/fork_join_tests/src/cpu.sch
+++ b/juno_samples/fork_join_tests/src/cpu.sch
@@ -12,6 +12,7 @@ cpu(out.test3);
 cpu(out.test4);
 cpu(out.test5);
 
+
 ip-sroa(*);
 sroa(*);
 dce(*);
@@ -42,14 +43,21 @@ gvn(*);
 phi-elim(*);
 dce(*);
 
+xdot[true](*);
+fork-fission-bufferize(test7@loop, test7@bufferize1, test7@bufferize2, test7@bufferize3, test7@bufferize4);
 fork-tile[32, 0, true](test6@loop);
 let out = fork-split(test6@loop);
 //let out = outline(out.test6.fj1);
+
+let out7 = auto-outline(test7);
+cpu(out7.test7);
+
 let out = auto-outline(test6);
 cpu(out.test6);
 ip-sroa(*);
 sroa(*);
 unforkify(out.test6);
+unforkify(out7.test7);
 dce(*);
 ccp(*);
 gvn(*);
diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn
index 806cb0f1..128d3ce0 100644
--- a/juno_samples/fork_join_tests/src/fork_join_tests.jn
+++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn
@@ -81,3 +81,16 @@ fn test6(input: i32) -> i32[1024] {
   }
   return arr;
 }
+
+#[entry]
+fn test7(input : i32) -> i32[8] {
+  let arr : i32[8];
+  @loop for i = 0 to 8 {
+    @bufferize1 let a = arr[i];
+    @bufferize2 let b = a + arr[7-i];
+    @bufferize3 let c = b * i as i32;
+    @bufferize4 let d = c;
+    arr[i] = d;
+  }
+  return arr;
+}
\ No newline at end of file
-- 
GitLab
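
The test7 kernel and the `fork-fission-bufferize` schedule line above exercise bufferizing fission: the loop is cut at the labeled edges and the values crossing the cut are stored in a buffer between two fork-joins. A plain-Rust analogue of the intended rewrite (assumed semantics; the pass itself rewrites Hercules IR, not Rust):

```rust
// One loop becomes two, and the value crossing the cut is materialized
// into a temporary buffer so the two halves can be scheduled separately.
fn fissioned(arr: &[i32; 8]) -> [i32; 8] {
    let mut buf = [0i32; 8];
    let mut out = [0i32; 8];
    // First fork-join: producers up to the bufferized edge.
    for i in 0..8 {
        buf[i] = (arr[i] + arr[7 - i]) * i as i32;
    }
    // Second fork-join: consumers read the buffer instead of live values.
    for i in 0..8 {
        out[i] = buf[i];
    }
    out
}
```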


From 3d9e31eedef992af21699fd94e5ec5ff9fdbb644 Mon Sep 17 00:00:00 2001
From: Xavier Routh <xrouth2@illinois.edu>
Date: Wed, 5 Feb 2025 12:07:19 -0600
Subject: [PATCH 26/33] stash apply

---
 hercules_cg/src/cpu.rs                        |   4 +-
 hercules_opt/src/fork_transforms.rs           | 298 ++++++++++++++----
 .../hercules_tests/tests/loop_tests.rs        |  27 +-
 juno_samples/fork_join_tests/src/cpu.sch      |  12 +-
 .../fork_join_tests/src/fork_join_tests.jn    |   2 +
 juno_samples/fork_join_tests/src/main.rs      |  22 ++
 juno_samples/matmul/build.rs                  |   2 +
 juno_scheduler/src/compile.rs                 |   1 +
 juno_scheduler/src/ir.rs                      |   1 +
 juno_scheduler/src/pm.rs                      |  50 ++-
 10 files changed, 326 insertions(+), 93 deletions(-)

diff --git a/hercules_cg/src/cpu.rs b/hercules_cg/src/cpu.rs
index f6a1f309..20a3e6cb 100644
--- a/hercules_cg/src/cpu.rs
+++ b/hercules_cg/src/cpu.rs
@@ -694,7 +694,7 @@ impl<'a> CPUContext<'a> {
                         );
                         write!(
                             body,
-                            "  {} = call i64 @llvm.umin.i64(i64{},i64%dc{}))\n",
+                            "  {} = call i64 @llvm.umin.i64(i64{},i64%dc{})\n",
                             new_val,
                             cur_value,
                             x.idx()
@@ -719,7 +719,7 @@ impl<'a> CPUContext<'a> {
                         );
                         write!(
                             body,
-                            "  {} = call i64 @llvm.umax.i64(i64{},i64%dc{}))\n",
+                            "  {} = call i64 @llvm.umax.i64(i64{},i64%dc{})\n",
                             new_val,
                             cur_value,
                             x.idx()
diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index 456f670e..f82bbdfb 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -98,6 +98,31 @@ pub fn find_reduce_dependencies<'a>(
     ret_val
 }
 
+pub fn copy_subgraph_in_edit<'a>(
+    mut edit: FunctionEdit<'a, 'a>,
+    subgraph: HashSet<NodeID>,
+) -> (
+    Result<(FunctionEdit<'a, 'a>, HashMap<NodeID, NodeID>), FunctionEdit<'a, 'a>>// a map from old nodes to new nodes
+     // A list of (inside *old* node, outside node) s.t. inside old node -> outside node.
+    // The caller probably wants   
+) {
+    
+    let mut map: HashMap<NodeID, NodeID> = HashMap::new();
+    
+    // Copy nodes in subgraph
+    for old_id in subgraph.iter().cloned() {
+        let new_id = edit.copy_node(old_id);
+        map.insert(old_id, new_id);
+    }
+
+    // Update edges to new nodes
+    for old_id in subgraph.iter() {
+        edit = edit.replace_all_uses_where(*old_id, map[old_id], |node_id| map.values().contains(node_id))?;
+    }
+
+    Ok((edit, map))
+}
+
 pub fn copy_subgraph(
     editor: &mut FunctionEditor,
     subgraph: HashSet<NodeID>,
@@ -144,6 +169,37 @@ pub fn copy_subgraph(
     (new_nodes, map, outside_users)
 }
 
+
+pub fn fork_fission_bufferize_toplevel<'a>(
+    editor: &'a mut FunctionEditor<'a>,
+    loop_tree: &'a LoopTree,
+    fork_join_map: &'a HashMap<NodeID, NodeID>,
+    data_node_in_fork_joins: &'a HashMap<NodeID, HashSet<NodeID>>,
+    typing: &'a Vec<TypeID>
+) -> bool {
+
+    let forks: Vec<_> = loop_tree
+        .bottom_up_loops()
+        .into_iter()
+        .filter(|(k, _)| editor.func().nodes[k.idx()].is_fork())
+        .collect();
+
+    for l in forks {
+        let fork_info = &Loop {
+            header: l.0,
+            control: l.1.clone(),
+        };
+        let fork = fork_info.header;
+        let join = fork_join_map[&fork];
+        let mut edges = HashSet::new();
+        edges.insert((NodeID::new(8), NodeID::new(3)));
+        let result = fork_bufferize_fission_helper(editor, fork_info, edges, data_node_in_fork_joins, typing, fork, join);
+        return result.is_some();
+    }
+    return false;
+}
+
+
 pub fn fork_fission<'a>(
     editor: &'a mut FunctionEditor,
     _control_subgraph: &Subgraph,
@@ -184,71 +240,100 @@ pub fn fork_fission<'a>(
 
 /** Split a 1D fork into two forks, placing select intermediate data into buffers. */
 pub fn fork_bufferize_fission_helper<'a>(
-    editor: &'a mut FunctionEditor,
-    fork_join_map: &HashMap<NodeID, NodeID>,
+    editor: &'a mut FunctionEditor<'a>,
+    l: &Loop,
     bufferized_edges: HashSet<(NodeID, NodeID)>, // Describes what intermediate data should be bufferized.
-    _original_control_pred: NodeID,              // What the new fork connects to.
+    data_node_in_fork_joins: &'a HashMap<NodeID, HashSet<NodeID>>,
     types: &Vec<TypeID>,
     fork: NodeID,
-) -> (NodeID, NodeID) {
+    join: NodeID
+) -> Option<(NodeID, NodeID)> {
     // Returns the two forks that it generates.
 
-    // TODO: Check that bufferized edges src doesn't depend on anything that comes after the fork.
+    if bufferized_edges.is_empty() {
+        return None
+    }
 
-    // Copy fork + control intermediates + join to new fork + join,
-    // How does control get partitioned?
-    //      (depending on how it affects the data nodes on each side of the bufferized_edges)
-    //      may end up in each loop, fix me later.
-    // place new fork + join after join of first.
+    let all_loop_nodes = l.get_all_nodes();
 
-    // Only handle fork+joins with no inner control for now.
+    // FIXME: Cloning hell.
+    let data_nodes = data_node_in_fork_joins[&fork].clone();
+    let loop_nodes = editor.node_ids().filter(|node_id| all_loop_nodes[node_id.idx()]);
+    // Clone the subgraph that consists of this fork-join and all data and control nodes in it. 
+    let subgraph = HashSet::from_iter(data_nodes.into_iter().chain(loop_nodes));
 
-    // Create fork + join + Thread control
-    let join = fork_join_map[&fork];
-    let mut new_fork_id = NodeID::new(0);
-    let mut new_join_id = NodeID::new(0);
+    let mut outside_users = Vec::new(); // old_node, outside_user
 
-    editor.edit(|mut edit| {
-        new_join_id = edit.add_node(Node::Join { control: fork });
-        let factors = edit.get_node(fork).try_fork().unwrap().1;
-        new_fork_id = edit.add_node(Node::Fork {
-            control: new_join_id,
-            factors: factors.into(),
-        });
-        edit.replace_all_uses_where(fork, new_fork_id, |usee| *usee == join)
-    });
+    for node in subgraph.iter() {
+        for user in editor.get_users(*node) {
+            if !subgraph.iter().contains(&user) {
+                outside_users.push((*node, user));
+            }
+        }
+    }
 
-    for (src, dst) in bufferized_edges {
-        // FIXME: Disgusting cloning and allocationing and iterators.
-        let factors: Vec<_> = editor.func().nodes[fork.idx()]
-            .try_fork()
-            .unwrap()
-            .1
-            .iter()
-            .cloned()
-            .collect();
-        editor.edit(|mut edit| {
-            // Create write to buffer
+    let factors: Vec<_> = editor.func().nodes[fork.idx()]
+        .try_fork()
+        .unwrap()
+        .1
+        .iter()
+        .cloned()
+        .collect();
 
-            let thread_stuff_it = factors.into_iter().enumerate();
+    let thread_stuff_it = factors.into_iter().enumerate();
 
-            // FIxme: try to use unzip here? Idk why it wasn't working.
-            let tids = thread_stuff_it.clone().map(|(dim, _)| {
-                edit.add_node(Node::ThreadID {
-                    control: fork,
-                    dimension: dim,
-                })
+    // Control successors
+    let fork_pred = editor.get_uses(fork).filter(|a| editor.node(a).is_control()).next().unwrap();
+    let join_successor = editor.get_users(join).filter(|a| editor.node(a).is_control()).next().unwrap();
+
+    let mut new_fork_id = NodeID::new(0);
+    editor.edit(|edit| {
+        let (mut edit, map) = copy_subgraph_in_edit(edit, subgraph)?;
+
+        // Put new subgraph after old subgraph
+        println!("map: {:?}", map);
+        println!("join: {:?}, fork: {:?}", join, fork);
+        println!("fork_sccueue: {:?}", join_successor);
+        edit = edit.replace_all_uses_where(fork_pred, join, |a| *a == map[&fork])?;
+        edit = edit.replace_all_uses_where(join, map[&join], |a| *a == join_successor)?;
+
+        // Replace outside uses of reduces in old subgraph with reduces in new subgraph.
+        for (old_node, outside_user) in outside_users {
+            edit = edit.replace_all_uses_where(old_node, map[&old_node], |node| *node == outside_user)?;
+        }
+
+        // Add buffers to old subgraph
+
+        let new_join = map[&join];
+        let new_fork = map[&fork];
+        
+        // FIXME: Do this as part of copy subgraph?
+        // Add tids to original subgraph for indexing.
+        let mut old_tids = Vec::new();
+        let mut new_tids = Vec::new();
+        for (dim, _) in thread_stuff_it.clone() {
+            let old_id = edit.add_node(Node::ThreadID {
+                control: fork,
+                dimension: dim,
             });
 
-            let array_dims = thread_stuff_it.clone().map(|(_, factor)| (factor));
+            let new_id = edit.add_node(Node::ThreadID {
+                control: new_fork,
+                dimension: dim,
+            });
 
-            // Assume 1-d fork only for now.
-            // let tid = edit.add_node(Node::ThreadID { control: fork, dimension: 0 });
-            let position_idx = Index::Position(tids.collect::<Vec<_>>().into_boxed_slice());
+            old_tids.push(old_id);
+            new_tids.push(new_id);
+        }
+    
+        for (src, dst) in &bufferized_edges {
+            let array_dims = thread_stuff_it.clone().map(|(_, factor)| (factor));
+            let position_idx = Index::Position(old_tids.clone().into_boxed_slice());
+            
             let write = edit.add_node(Node::Write {
                 collect: NodeID::new(0),
-                data: src,
-                indices: vec![position_idx].into(),
+                data: *src,
+                indices: vec![position_idx.clone()].into(),
             });
             let ele_type = types[src.idx()];
             let empty_buffer = edit.add_type(hercules_ir::Type::Array(
@@ -258,36 +343,125 @@ pub fn fork_bufferize_fission_helper<'a>(
             let empty_buffer = edit.add_zero_constant(empty_buffer);
             let empty_buffer = edit.add_node(Node::Constant { id: empty_buffer });
             let reduce = Node::Reduce {
-                control: new_join_id,
+                control: join,
                 init: empty_buffer,
                 reduct: write,
             };
             let reduce = edit.add_node(reduce);
+
             // Fix write node
             edit = edit.replace_all_uses_where(NodeID::new(0), reduce, |usee| *usee == write)?;
 
-            // Create read from buffer
-            let tids = thread_stuff_it.clone().map(|(dim, _)| {
-                edit.add_node(Node::ThreadID {
-                    control: new_fork_id,
-                    dimension: dim,
-                })
-            });
 
-            let position_idx = Index::Position(tids.collect::<Vec<_>>().into_boxed_slice());
+            // Create reads from buffer
+            let position_idx = Index::Position(new_tids.clone().into_boxed_slice());
 
             let read = edit.add_node(Node::Read {
                 collect: reduce,
                 indices: vec![position_idx].into(),
             });
 
-            edit = edit.replace_all_uses_where(src, read, |usee| *usee == dst)?;
+            // Replace the new subgraph's uses of the bufferized edge's src with the read from the buffer
+            edit = edit.replace_all_uses_where(map[src], read, |usee| *usee == map[dst])?;
+            
+        }
+        
+        
+
 
-            Ok(edit)
-        });
-    }
+        new_fork_id = new_fork;
+
+        Ok(edit)
+    });
 
-    (fork, new_fork_id)
+    Some((fork, new_fork_id))
+
+    // let internal_control: Vec<NodeID> = Vec::new();
+
+    // // Create fork + join + Thread control
+    // let mut new_fork_id = NodeID::new(0);
+    // let mut new_join_id = NodeID::new(0);
+
+    // editor.edit(|mut edit| {
+    //     new_join_id = edit.add_node(Node::Join { control: fork });
+    //     let factors = edit.get_node(fork).try_fork().unwrap().1;
+    //     new_fork_id = edit.add_node(Node::Fork {
+    //         control: new_join_id,
+    //         factors: factors.into(),
+    //     });
+    //     edit.replace_all_uses_where(fork, new_fork_id, |usee| *usee == join)
+    // });
+
+    // for (src, dst) in bufferized_edges {
+    //     // FIXME: Disgusting cloning and allocationing and iterators.
+    //     let factors: Vec<_> = editor.func().nodes[fork.idx()]
+    //         .try_fork()
+    //         .unwrap()
+    //         .1
+    //         .iter()
+    //         .cloned()
+    //         .collect();
+    //     editor.edit(|mut edit| {
+    //         // Create write to buffer
+
+    //         let thread_stuff_it = factors.into_iter().enumerate();
+
+    //         // FIxme: try to use unzip here? Idk why it wasn't working.
+    //         let tids = thread_stuff_it.clone().map(|(dim, _)| {
+    //             edit.add_node(Node::ThreadID {
+    //                 control: fork,
+    //                 dimension: dim,
+    //             })
+    //         });
+
+    //         let array_dims = thread_stuff_it.clone().map(|(_, factor)| (factor));
+
+    //         // Assume 1-d fork only for now.
+    //         // let tid = edit.add_node(Node::ThreadID { control: fork, dimension: 0 });
+    //         let position_idx = Index::Position(tids.collect::<Vec<_>>().into_boxed_slice());
+    //         let write = edit.add_node(Node::Write {
+    //             collect: NodeID::new(0),
+    //             data: src,
+    //             indices: vec![position_idx].into(),
+    //         });
+    //         let ele_type = types[src.idx()];
+    //         let empty_buffer = edit.add_type(hercules_ir::Type::Array(
+    //             ele_type,
+    //             array_dims.collect::<Vec<_>>().into_boxed_slice(),
+    //         ));
+    //         let empty_buffer = edit.add_zero_constant(empty_buffer);
+    //         let empty_buffer = edit.add_node(Node::Constant { id: empty_buffer });
+    //         let reduce = Node::Reduce {
+    //             control: new_join_id,
+    //             init: empty_buffer,
+    //             reduct: write,
+    //         };
+    //         let reduce = edit.add_node(reduce);
+    //         // Fix write node
+    //         edit = edit.replace_all_uses_where(NodeID::new(0), reduce, |usee| *usee == write)?;
+
+    //         // Create read from buffer
+    //         let tids = thread_stuff_it.clone().map(|(dim, _)| {
+    //             edit.add_node(Node::ThreadID {
+    //                 control: new_fork_id,
+    //                 dimension: dim,
+    //             })
+    //         });
+
+    //         let position_idx = Index::Position(tids.collect::<Vec<_>>().into_boxed_slice());
+
+    //         let read = edit.add_node(Node::Read {
+    //             collect: reduce,
+    //             indices: vec![position_idx].into(),
+    //         });
+
+    //         edit = edit.replace_all_uses_where(src, read, |usee| *usee == dst)?;
+
+    //         Ok(edit)
+    //     });
+    // }
+
+    // (fork, new_fork_id)
 }
 
 /** Split a 1D fork into a separate fork for each reduction. */
diff --git a/hercules_test/hercules_tests/tests/loop_tests.rs b/hercules_test/hercules_tests/tests/loop_tests.rs
index 795642b2..f42a6520 100644
--- a/hercules_test/hercules_tests/tests/loop_tests.rs
+++ b/hercules_test/hercules_tests/tests/loop_tests.rs
@@ -385,7 +385,7 @@ fn matmul_pipeline() {
     let dyn_consts = [I, J, K];
 
     // FIXME: This path should not leave the crate
-    let mut module = parse_module_from_hbin("../../juno_samples/test/out.hbin");
+    let mut module = parse_module_from_hbin("../../juno_samples/matmul/out.hbin");
     //
     let mut correct_c: Box<[i32]> = (0..I * K).map(|_| 0).collect();
     for i in 0..I {
@@ -398,25 +398,20 @@ fn matmul_pipeline() {
 
     let result_1 = interp_module!(module, 1, dyn_consts, a.clone(), b.clone());
 
-    println!("golden: {:?}", correct_c);
-    println!("result: {:?}", result_1);
+    // println!("golden: {:?}", correct_c);
+    // println!("result: {:?}", result_1);
 
-    let InterpreterVal::Array(_, d) = result_1.clone() else {
-        panic!()
-    };
-    let InterpreterVal::Integer32(value) = d[0] else {
-        panic!()
-    };
-    assert_eq!(correct_c[0], value);
+    // let InterpreterVal::Array(_, d) = result_1.clone() else {
+    //     panic!()
+    // };
+    // let InterpreterVal::Integer32(value) = d[0] else {
+    //     panic!()
+    // };
+    // assert_eq!(correct_c[0], value);
 
     let schedule = Some(default_schedule![
-        AutoOutline,
-        InterproceduralSROA,
-        SROA,
-        InferSchedules,
-        DCE,
         Xdot,
-        GCM
+        Verify,
     ]);
 
     module = run_schedule_on_hercules(module, schedule).unwrap();
diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch
index a557cd03..947d0dc8 100644
--- a/juno_samples/fork_join_tests/src/cpu.sch
+++ b/juno_samples/fork_join_tests/src/cpu.sch
@@ -7,10 +7,13 @@ dce(*);
 
 let out = auto-outline(test1, test2, test3, test4, test5);
 cpu(out.test1);
+<<<<<<< Updated upstream
 cpu(out.test2);
 cpu(out.test3);
 cpu(out.test4);
 cpu(out.test5);
+=======
+>>>>>>> Stashed changes
 
 
 ip-sroa(*);
@@ -24,8 +27,15 @@ dce(*);
 fixpoint panic after 20 {
   forkify(*);
   fork-guard-elim(*);
-  fork-coalesce(*);
+  fork-coalesce(*);  
 }
+
+dce(*);
+gvn(*);
+
+xdot[true](*);
+fork-fission-bufferize(*);
+xdot[true](*);
 gvn(*);
 phi-elim(*);
 dce(*);
diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn
index 128d3ce0..90d06c2f 100644
--- a/juno_samples/fork_join_tests/src/fork_join_tests.jn
+++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn
@@ -9,6 +9,7 @@ fn test1(input : i32) -> i32[4, 4] {
   return arr;
 }
 
+/**
 #[entry]
 fn test2(input : i32) -> i32[4, 4] {
   let arr : i32[4, 4];
@@ -72,6 +73,7 @@ fn test5(input : i32) -> i32[4] {
   }
   return arr1;
 }
+<<<<<<< Updated upstream
 
 #[entry]
 fn test6(input: i32) -> i32[1024] {
diff --git a/juno_samples/fork_join_tests/src/main.rs b/juno_samples/fork_join_tests/src/main.rs
index 19838fd7..33c5602e 100644
--- a/juno_samples/fork_join_tests/src/main.rs
+++ b/juno_samples/fork_join_tests/src/main.rs
@@ -23,6 +23,7 @@ fn main() {
         let correct = vec![5i32; 16];
         assert(&correct, output);
 
+<<<<<<< Updated upstream
         let mut r = runner!(test2);
         let output = r.run(3).await;
         let correct = vec![24i32; 16];
@@ -47,6 +48,27 @@ fn main() {
         let output = r.run(73).await;
         let correct = (73i32..73i32+1024i32).collect();
         assert(&correct, output);
+=======
+        // let mut r = runner!(test2);
+        // let output = r.run(3).await;
+        // let correct = vec![24i32; 16];
+        // assert(correct, output);
+
+        // let mut r = runner!(test3);
+        // let output = r.run(0).await;
+        // let correct = vec![11, 10, 9, 10, 9, 8, 9, 8, 7];
+        // assert(correct, output);
+
+        // let mut r = runner!(test4);
+        // let output = r.run(9).await;
+        // let correct = vec![63i32; 16];
+        // assert(correct, output);
+
+        // let mut r = runner!(test5);
+        // let output = r.run(4).await;
+        // let correct = vec![7i32; 4];
+        // assert(correct, output);
+>>>>>>> Stashed changes
     });
 }
 
diff --git a/juno_samples/matmul/build.rs b/juno_samples/matmul/build.rs
index 0be838c6..d2813388 100644
--- a/juno_samples/matmul/build.rs
+++ b/juno_samples/matmul/build.rs
@@ -6,6 +6,8 @@ fn main() {
         JunoCompiler::new()
             .file_in_src("matmul.jn")
             .unwrap()
+            .schedule_in_src("cpu.sch")
+            .unwrap()
             .build()
             .unwrap();
     }
diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs
index 713c30d4..08e952a4 100644
--- a/juno_scheduler/src/compile.rs
+++ b/juno_scheduler/src/compile.rs
@@ -109,6 +109,7 @@ impl FromStr for Appliable {
             "ip-sroa" | "interprocedural-sroa" => {
                 Ok(Appliable::Pass(ir::Pass::InterproceduralSROA))
             }
+            "fork-fission-bufferize" => Ok(Appliable::Pass(ir::Pass::ForkFissionBufferize)),
             "fork-dim-merge" => Ok(Appliable::Pass(ir::Pass::ForkDimMerge)),
             "fork-chunk" | "fork-tile" => Ok(Appliable::Pass(ir::Pass::ForkChunk)),
             "lift-dc-math" => Ok(Appliable::Pass(ir::Pass::LiftDCMath)),
diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs
index 9e85509f..8a6e04ed 100644
--- a/juno_scheduler/src/ir.rs
+++ b/juno_scheduler/src/ir.rs
@@ -12,6 +12,7 @@ pub enum Pass {
     ForkSplit,
     ForkCoalesce,
     Forkify,
+    ForkFissionBufferize,
     ForkDimMerge,
     ForkChunk,
     GCM,
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index d4acac19..28b0cbf5 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -1987,10 +1987,7 @@ fn run_pass(
                 });
             };
             let Some(Value::Integer { val: dim_idx }) = args.get(1) else {
-                return Err(SchedulerError::PassError {
-                    pass: "forkChunk".to_string(),
-                    error: "expected integer argument".to_string(),
-                });
+                panic!(); // How to error here?
             };
             let Some(Value::Boolean { val: guarded_flag }) = args.get(2) else {
                 return Err(SchedulerError::PassError {
@@ -2002,9 +1999,10 @@ fn run_pass(
             assert_eq!(*guarded_flag, true);
             pm.make_fork_join_maps();
             let fork_join_maps = pm.fork_join_maps.take().unwrap();
-            for (func, fork_join_map) in build_selection(pm, selection)
-                .into_iter()
-                .zip(fork_join_maps.iter())
+            for (func, fork_join_map) in
+                build_selection(pm, selection)
+                    .into_iter()
+                    .zip(fork_join_maps.iter())
             {
                 let Some(mut func) = func else {
                     continue;
@@ -2015,13 +2013,41 @@ fn run_pass(
             pm.delete_gravestones();
             pm.clear_analyses();
         }
+        Pass::ForkFissionBufferize => {
+            assert!(args.is_empty());
+            pm.make_fork_join_maps();
+            pm.make_typing();
+            pm.make_loops();
+            pm.make_nodes_in_fork_joins();
+            let fork_join_maps = pm.fork_join_maps.take().unwrap();
+            let typing = pm.typing.take().unwrap();
+            let loops = pm.loops.take().unwrap();
+            let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap();
+            for ((((func, fork_join_map), loop_tree), typing), nodes_in_fork_joins) in
+                build_selection(pm, selection)
+                    .into_iter()
+                    .zip(fork_join_maps.iter())
+                    .zip(loops.iter())
+                    .zip(typing.iter())
+                    .zip(nodes_in_fork_joins.iter())
+            {
+                let Some(mut func) = func else {
+                    continue;
+                };
+                let result = fork_fission_bufferize_toplevel(&mut func, loop_tree, fork_join_map, nodes_in_fork_joins, typing);
+                changed |= result;
+            }
+            pm.delete_gravestones();
+            pm.clear_analyses();
+        }
         Pass::ForkDimMerge => {
             assert!(args.is_empty());
             pm.make_fork_join_maps();
             let fork_join_maps = pm.fork_join_maps.take().unwrap();
-            for (func, fork_join_map) in build_selection(pm, selection)
-                .into_iter()
-                .zip(fork_join_maps.iter())
+            for (func, fork_join_map) in
+                build_selection(pm, selection)
+                    .into_iter()
+                    .zip(fork_join_maps.iter())
             {
                 let Some(mut func) = func else {
                     continue;
@@ -2051,8 +2077,8 @@ fn run_pass(
                 let Some(mut func) = func else {
                     continue;
                 };
-                changed |= fork_coalesce(&mut func, loop_nest, fork_join_map);
-                // func.modified();
+                fork_coalesce(&mut func, loop_nest, fork_join_map);
+                changed |= func.modified();
             }
             pm.delete_gravestones();
             pm.clear_analyses();
-- 
GitLab
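
The core of `copy_subgraph_in_edit` above is: copy every node of the subgraph, then remap the copies' use edges so edges that stay inside the subgraph point at the copies while edges leaving it keep pointing at the originals. A toy version of the same strategy on plain adjacency lists (not Hercules IR; names and types here are illustrative):

```rust
use std::collections::{HashMap, HashSet};

// node -> nodes it uses; returns the old-to-new node map.
fn copy_subgraph(
    uses: &mut HashMap<usize, Vec<usize>>,
    subgraph: &HashSet<usize>,
    next_id: &mut usize,
) -> HashMap<usize, usize> {
    let mut map = HashMap::new(); // old -> new
    // Copy nodes, initially keeping their old use lists.
    for &old in subgraph {
        let new = *next_id;
        *next_id += 1;
        uses.insert(new, uses[&old].clone());
        map.insert(old, new);
    }
    // Redirect intra-subgraph edges of the copies onto the copied nodes.
    for &new in map.values() {
        for u in uses.get_mut(&new).unwrap() {
            if let Some(&mapped) = map.get(u) {
                *u = mapped;
            }
        }
    }
    map
}
```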


From d2cec2531bce086be7132bae76b16383d7ddccaa Mon Sep 17 00:00:00 2001
From: Xavier Routh <xrouth2@illinois.edu>
Date: Sat, 8 Feb 2025 13:23:19 -0600
Subject: [PATCH 27/33] fork fission bufferize

---
 hercules_opt/src/fork_transforms.rs           | 107 +++++++++++++-----
 juno_samples/fork_join_tests/src/blah.sch     |  34 ++++++
 juno_samples/fork_join_tests/src/cpu.sch      |  14 +--
 .../fork_join_tests/src/fork_join_tests.jn    |  24 ++--
 juno_samples/fork_join_tests/src/main.rs      |  28 +----
 juno_scheduler/src/pm.rs                      |  60 +++++++++-
 6 files changed, 198 insertions(+), 69 deletions(-)
 create mode 100644 juno_samples/fork_join_tests/src/blah.sch

diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index f82bbdfb..ea486c94 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -1,6 +1,7 @@
 use std::collections::{HashMap, HashSet};
 use std::iter::zip;
 use std::thread::ThreadId;
+use std::hash::Hash;
 
 use bimap::BiMap;
 use itertools::Itertools;
@@ -98,11 +99,11 @@ pub fn find_reduce_dependencies<'a>(
     ret_val
 }
 
-pub fn copy_subgraph_in_edit<'a>(
-    mut edit: FunctionEdit<'a, 'a>,
+pub fn copy_subgraph_in_edit<'a, 'b>(
+    mut edit: FunctionEdit<'a, 'b>,
     subgraph: HashSet<NodeID>,
 ) -> (
-    Result<(FunctionEdit<'a, 'a>, HashMap<NodeID, NodeID>), FunctionEdit<'a, 'a>>// a map from old nodes to new nodes
+    Result<(FunctionEdit<'a, 'b>, HashMap<NodeID, NodeID>), FunctionEdit<'a, 'b>>// a map from old nodes to new nodes
     // A list of (inside *old* node, outside node) s.t. inside old node -> outside node.
     // The caller probably wants   
 ) {
@@ -169,14 +170,56 @@ pub fn copy_subgraph(
     (new_nodes, map, outside_users)
 }
 
+fn is_strict_superset<T: Eq + Hash>(set1: &HashSet<T>, set2: &HashSet<T>) -> bool {
+    // A strict superset must be larger than its subset
+    if set1.len() <= set2.len() {
+        return false;
+    }
+    
+    // Every element in set2 must be in set1
+    set2.iter().all(|item| set1.contains(item))
+}
 
-pub fn fork_fission_bufferize_toplevel<'a>(
-    editor: &'a mut FunctionEditor<'a>,
-    loop_tree: &'a LoopTree,
-    fork_join_map: &'a HashMap<NodeID, NodeID>,
-    data_node_in_fork_joins: &'a HashMap<NodeID, HashSet<NodeID>>,
-    typing: &'a Vec<TypeID>
-) -> bool {
+pub fn find_bufferize_edges(
+    editor: &mut FunctionEditor,
+    fork: NodeID,
+    loop_tree: &LoopTree,
+    fork_join_map: &HashMap<NodeID, NodeID>,
+    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
+) -> HashSet<(NodeID, NodeID)> {
+
+    // println!("func: {:?}", editor.func_id());
+    let mut edges: HashSet<_> = HashSet::new();
+    // print labels 
+    for node in &nodes_in_fork_joins[&fork] {
+        println!("node: {:?}, label {:?}, ", node, editor.func().labels[node.idx()]);
+        let node_labels = &editor.func().labels[node.idx()];
+        for usee in editor.get_uses(*node) {
+            // If the usee's labels are a strict superset of this node's labels, make an edge.
+            let usee_labels = &editor.func().labels[usee.idx()];
+            // strict superset
+            if !(usee_labels.is_superset(node_labels) && usee_labels.len() > node_labels.len())  {
+                continue;
+            }
+
+            if editor.node(usee).is_control() || editor.node(node).is_control() {
+                continue;
+            }
+
+            edges.insert((usee, *node));
+        }
+    }
+    println!("edges: {:?}", edges);
+    edges
+}
+
+pub fn ff_bufferize_any_fork<'a, 'b>(
+    editor: &'b mut FunctionEditor<'a>,
+    loop_tree: &'b LoopTree,
+    fork_join_map: &'b HashMap<NodeID, NodeID>,
+    nodes_in_fork_joins: &'b HashMap<NodeID, HashSet<NodeID>>,
+    typing: &'b Vec<TypeID>
+) -> Option<(NodeID, NodeID)> where 'a: 'b {
 
     let forks: Vec<_> = loop_tree
         .bottom_up_loops()
@@ -185,18 +228,22 @@ pub fn fork_fission_bufferize_toplevel<'a>(
         .collect();
 
     for l in forks {
-        let fork_info = &Loop {
+        let fork_info = Loop {
             header: l.0,
             control: l.1.clone(),
         };
         let fork = fork_info.header;
         let join = fork_join_map[&fork];
-        let mut edges = HashSet::new();
-        edges.insert((NodeID::new(8), NodeID::new(3)));
-        let result = fork_bufferize_fission_helper(editor, fork_info, edges, data_node_in_fork_joins, typing, fork, join);
-        return result.is_some();
+
+        let edges = find_bufferize_edges(editor, fork, &loop_tree, &fork_join_map, &nodes_in_fork_joins);
+        let result = fork_bufferize_fission_helper(editor, &fork_info, &edges, nodes_in_fork_joins, typing, fork, join);
+        if result.is_none() {
+            continue
+        } else {
+            return result; 
+        }
     }
-    return false;
+    return None;
 }
 
 
@@ -239,15 +286,15 @@ pub fn fork_fission<'a>(
 }
 
 /** Split a 1D fork into two forks, placing select intermediate data into buffers. */
-pub fn fork_bufferize_fission_helper<'a>(
-    editor: &'a mut FunctionEditor<'a>,
+pub fn fork_bufferize_fission_helper<'a, 'b>(
+    editor: &'b mut FunctionEditor<'a>,
     l: &Loop,
-    bufferized_edges: HashSet<(NodeID, NodeID)>, // Describes what intermediate data should be bufferized.
-    data_node_in_fork_joins: &'a HashMap<NodeID, HashSet<NodeID>>,
-    types: &Vec<TypeID>,
+    bufferized_edges: &HashSet<(NodeID, NodeID)>, // Describes what intermediate data should be bufferized.
+    data_node_in_fork_joins: &'b HashMap<NodeID, HashSet<NodeID>>,
+    types: &'b Vec<TypeID>,
     fork: NodeID,
     join: NodeID
-) -> Option<(NodeID, NodeID)> {
+) -> Option<(NodeID, NodeID)> where 'a: 'b {
     // Returns the two forks that it generates.
 
     if bufferized_edges.is_empty() {
@@ -287,13 +334,15 @@ pub fn fork_bufferize_fission_helper<'a>(
     let join_successor = editor.get_users(join).filter(|a| editor.node(a).is_control()).next().unwrap();
 
     let mut new_fork_id = NodeID::new(0);
-    editor.edit(|edit| {
+
+    let edit_result = editor.edit(|edit| {
         let (mut edit, map) = copy_subgraph_in_edit(edit, subgraph)?;
 
         // Put new subgraph after old subgraph
         println!("map: {:?}", map);
         println!("join: {:?}, fork: {:?}", join, fork);
         println!("fork_sccueue: {:?}", join_successor);
+
         edit = edit.replace_all_uses_where(fork_pred, join, |a| *a == map[&fork])?;
         edit = edit.replace_all_uses_where(join, map[&join], |a| *a == join_successor)?;
 
@@ -302,6 +351,8 @@ pub fn fork_bufferize_fission_helper<'a>(
             edit = edit.replace_all_uses_where(old_node, map[&old_node], |node| *node == outside_user)?;
         }
 
+
+
         // Add buffers to old subgraph
 
         let new_join = map[&join];
@@ -326,7 +377,7 @@ pub fn fork_bufferize_fission_helper<'a>(
             new_tids.push(new_id);
         }
     
-        for (src, dst) in &bufferized_edges {
+        for (src, dst) in bufferized_edges {
             let array_dims = thread_stuff_it.clone().map(|(_, factor)| (factor));
             let position_idx = Index::Position(old_tids.clone().into_boxed_slice());
             
@@ -366,14 +417,18 @@ pub fn fork_bufferize_fission_helper<'a>(
             
         }
         
-        
-
 
         new_fork_id = new_fork;
 
         Ok(edit)
     });
 
+    println!("edit_result: {:?}", edit_result);
+    if !edit_result {
+        // FIXME: the edit was rejected; handle this gracefully instead of panicking.
+        todo!();
+    }
+
     Some((fork, new_fork_id))
 
     // let internal_control: Vec<NodeID> = Vec::new();
diff --git a/juno_samples/fork_join_tests/src/blah.sch b/juno_samples/fork_join_tests/src/blah.sch
new file mode 100644
index 00000000..52dea702
--- /dev/null
+++ b/juno_samples/fork_join_tests/src/blah.sch
@@ -0,0 +1,34 @@
+
+xdot[true](*);
+
+fixpoint panic after 20 {
+  forkify(*);
+  fork-guard-elim(*);
+  fork-coalesce(*);
+  dce(*);
+}
+
+xdot[true](*);
+
+//gvn(*);
+//phi-elim(*);
+//dce(*);
+
+//gvn(*);
+//phi-elim(*);
+//dce(*);
+
+//fixpoint panic after 20 {
+//  infer-schedules(*);
+//}
+
+//fork-split(*);
+//gvn(*);
+//phi-elim(*);
+//dce(*);
+//unforkify(*);
+//gvn(*);
+//phi-elim(*);
+//dce(*);
+
+//gcm(*);
diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch
index 947d0dc8..e04d8dfe 100644
--- a/juno_samples/fork_join_tests/src/cpu.sch
+++ b/juno_samples/fork_join_tests/src/cpu.sch
@@ -7,13 +7,10 @@ dce(*);
 
 let out = auto-outline(test1, test2, test3, test4, test5);
 cpu(out.test1);
-<<<<<<< Updated upstream
 cpu(out.test2);
 cpu(out.test3);
 cpu(out.test4);
 cpu(out.test5);
-=======
->>>>>>> Stashed changes
 
 
 ip-sroa(*);
@@ -33,9 +30,6 @@ fixpoint panic after 20 {
 dce(*);
 gvn(*);
 
-xdot[true](*);
-fork-fission-bufferize(*);
-xdot[true](*);
 gvn(*);
 phi-elim(*);
 dce(*);
@@ -53,10 +47,12 @@ gvn(*);
 phi-elim(*);
 dce(*);
 
-xdot[true](*);
-fork-fission-bufferize(test7@loop, test7@bufferize1, test7@bufferize2, test7@bufferize3, test7@bufferize4);
+fork-fission-bufferize(test7);
+dce(*);
+
 fork-tile[32, 0, true](test6@loop);
 let out = fork-split(test6@loop);
+fork-split(*);
 //let out = outline(out.test6.fj1);
 
 let out7 = auto-outline(test7);
@@ -69,6 +65,8 @@ sroa(*);
 unforkify(out.test6);
 unforkify(out7.test7);
 dce(*);
+unforkify(*);
+dce(*);
 ccp(*);
 gvn(*);
 phi-elim(*);
diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn
index 90d06c2f..ae3be778 100644
--- a/juno_samples/fork_join_tests/src/fork_join_tests.jn
+++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn
@@ -9,7 +9,6 @@ fn test1(input : i32) -> i32[4, 4] {
   return arr;
 }
 
-/**
 #[entry]
 fn test2(input : i32) -> i32[4, 4] {
   let arr : i32[4, 4];
@@ -73,7 +72,6 @@ fn test5(input : i32) -> i32[4] {
   }
   return arr1;
 }
-<<<<<<< Updated upstream
 
 #[entry]
 fn test6(input: i32) -> i32[1024] {
@@ -87,12 +85,22 @@ fn test6(input: i32) -> i32[1024] {
 #[entry]
 fn test7(input : i32) -> i32[8] {
   let arr : i32[8];
+  let out : i32[8];
+
+  for i = 0 to 8 {
+    arr[i] = i as i32;
+  }
+
   @loop for i = 0 to 8 {
-    @bufferize1 let a = arr[i];
-    @bufferize2 let b = a + arr[7-i];
-    @bufferize3 let c = b * i as i32;
-    @bufferize4 let d = c;
-    arr[i] = d;
+    let b: i32;
+    @bufferize1 {
+      let a = arr[i];
+      let a2 = a + arr[7-i];
+      b = a2 * i as i32;
+    }
+    let c = b;
+    let d = c + 10;
+    out[i] = d;
   }
-  return arr;
+  return out;
 }
\ No newline at end of file
diff --git a/juno_samples/fork_join_tests/src/main.rs b/juno_samples/fork_join_tests/src/main.rs
index 33c5602e..caf956a1 100644
--- a/juno_samples/fork_join_tests/src/main.rs
+++ b/juno_samples/fork_join_tests/src/main.rs
@@ -23,7 +23,6 @@ fn main() {
         let correct = vec![5i32; 16];
         assert(&correct, output);
 
-<<<<<<< Updated upstream
         let mut r = runner!(test2);
         let output = r.run(3).await;
         let correct = vec![24i32; 16];
@@ -44,31 +43,10 @@ fn main() {
         let correct = vec![7i32; 4];
         assert(&correct, output);
 
-        let mut r = runner!(test6);
-        let output = r.run(73).await;
-        let correct = (73i32..73i32+1024i32).collect();
+        let mut r = runner!(test7);
+        let output = r.run(0).await;
+        let correct = vec![10, 17, 24, 31, 38, 45, 52, 59];
         assert(&correct, output);
-=======
-        // let mut r = runner!(test2);
-        // let output = r.run(3).await;
-        // let correct = vec![24i32; 16];
-        // assert(correct, output);
-
-        // let mut r = runner!(test3);
-        // let output = r.run(0).await;
-        // let correct = vec![11, 10, 9, 10, 9, 8, 9, 8, 7];
-        // assert(correct, output);
-
-        // let mut r = runner!(test4);
-        // let output = r.run(9).await;
-        // let correct = vec![63i32; 16];
-        // assert(correct, output);
-
-        // let mut r = runner!(test5);
-        // let output = r.run(4).await;
-        // let correct = vec![7i32; 4];
-        // assert(correct, output);
->>>>>>> Stashed changes
     });
 }
 
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 28b0cbf5..2c5a3687 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -2015,6 +2015,8 @@ fn run_pass(
         }
         Pass::ForkFissionBufferize => {
             assert!(args.is_empty());
+            let mut created_fork_joins = vec![vec![]; pm.functions.len()];
+
             pm.make_fork_join_maps();
             pm.make_typing();
             pm.make_loops();
@@ -2034,9 +2036,63 @@ fn run_pass(
                 let Some(mut func) = func else {
                     continue;
                 };
-                let result = fork_fission_bufferize_toplevel(&mut func, loop_tree, fork_join_map, nodes_in_fork_joins, typing);
-                changed |= result;
+                if let Some((fork1, fork2)) = 
+                    ff_bufferize_any_fork(&mut func, loop_tree, fork_join_map, nodes_in_fork_joins, typing)
+                {
+                    let created_fork_joins = &mut created_fork_joins[func.func_id().idx()];
+                    created_fork_joins.push(fork1);
+                    created_fork_joins.push(fork2);
+                }
+                changed |= func.modified();
+            }
+
+            pm.make_nodes_in_fork_joins();
+            let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap();
+            let mut new_fork_joins = HashMap::new();
+
+            for (mut func, created_fork_joins) in
+                build_editors(pm).into_iter().zip(created_fork_joins)
+            {
+                 // For every function, create a label for every level of fork-
+                // joins resulting from the split.
+                let name = func.func().name.clone();
+                let func_id = func.func_id();
+                let labels = create_labels_for_node_sets(
+                    &mut func,
+                    created_fork_joins.into_iter().map(|fork| {
+                                nodes_in_fork_joins[func_id.idx()][&fork]
+                                    .iter()
+                                    .map(|id| *id)
+                            })
+                    ,
+                );
+
+                // Assemble those labels into a record for this function. The
+                // format of the records is <function>.fjN, where N is the
+                // level of the split fork-joins being referred to.
+                todo!();
+                // FIXME: What if there are multiple bufferized forks per function?
+                let mut func_record = HashMap::new();
+                for (idx, label) in labels {
+                    func_record.insert(
+                        format!("fj{}", idx),
+                        Value::Label {
+                            labels: vec![LabelInfo {
+                                func: func_id,
+                                label: label,
+                            }],
+                        },
+                    );
+                }
+
+                // Try to avoid creating unnecessary record entries.
+                if !func_record.is_empty() {
+                    new_fork_joins.entry(name).insert_entry(Value::Record {
+                        fields: func_record,
+                    });
+                }
             }
+
             pm.delete_gravestones();
             pm.clear_analyses();
         }
-- 
GitLab


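A quick sanity check on the new expected output: the bufferize test (test7 here, renamed test8 in the next patch) initializes arr[i] = i, so each iteration stores (arr[i] + arr[7-i]) * i + 10 = (i + (7 - i)) * i + 10 = 7*i + 10, which is exactly the reference vector asserted in main.rs. A minimal plain-Rust sketch, independent of the Juno toolchain, that reproduces it:

    fn main() {
        // arr[i] = i, so arr[i] + arr[7 - i] = 7 for every i;
        // the product is 7 * i, and the final store adds 10.
        let correct: Vec<i32> = (0..8).map(|i| 7 * i + 10).collect();
        assert_eq!(correct, vec![10, 17, 24, 31, 38, 45, 52, 59]);
    }
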
From c0d161709074c5c79705de7052953b0fb27489a4 Mon Sep 17 00:00:00 2001
From: Xavier Routh <xrouth2@illinois.edu>
Date: Sun, 9 Feb 2025 13:09:53 -0600
Subject: [PATCH 28/33] fork fission bufferize pm

---
 hercules_opt/src/fork_transforms.rs      | 144 +++++++++++++++--------
 juno_samples/fork_join_tests/src/cpu.sch |  11 +-
 juno_samples/fork_join_tests/src/main.rs |   2 +-
 juno_scheduler/src/ir.rs                 |   1 +
 juno_scheduler/src/pm.rs                 | 101 ++++++++++++----
 5 files changed, 179 insertions(+), 80 deletions(-)

diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index ea486c94..ad81d90b 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -1,7 +1,7 @@
 use std::collections::{HashMap, HashSet};
+use std::hash::Hash;
 use std::iter::zip;
 use std::thread::ThreadId;
-use std::hash::Hash;
 
 use bimap::BiMap;
 use itertools::Itertools;
@@ -102,14 +102,9 @@ pub fn find_reduce_dependencies<'a>(
 pub fn copy_subgraph_in_edit<'a, 'b>(
     mut edit: FunctionEdit<'a, 'b>,
     subgraph: HashSet<NodeID>,
-) -> (
-    Result<(FunctionEdit<'a, 'b>, HashMap<NodeID, NodeID>), FunctionEdit<'a, 'b>>// a map from old nodes to new nodes
-     // A list of (inside *old* node, outside node) s.t insde old node -> outside node. 
-    // The caller probably wants   
-) {
-    
+) -> (Result<(FunctionEdit<'a, 'b>, HashMap<NodeID, NodeID>), FunctionEdit<'a, 'b>>) {
     let mut map: HashMap<NodeID, NodeID> = HashMap::new();
-    
+
     // Copy nodes in subgraph
     for old_id in subgraph.iter().cloned() {
         let new_id = edit.copy_node(old_id);
@@ -118,7 +113,9 @@ pub fn copy_subgraph_in_edit<'a, 'b>(
 
     // Update edges to new nodes
     for old_id in subgraph.iter() {
-        edit = edit.replace_all_uses_where(*old_id, map[old_id], |node_id| map.values().contains(node_id))?;
+        edit = edit.replace_all_uses_where(*old_id, map[old_id], |node_id| {
+            map.values().contains(node_id)
+        })?;
     }
 
     Ok((edit, map))
@@ -175,38 +172,48 @@ fn is_strict_superset<T: Eq + Hash>(set1: &HashSet<T>, set2: &HashSet<T>) -> boo
     if set1.len() <= set2.len() {
         return false;
     }
-    
+
     // Every element in set2 must be in set1
     set2.iter().all(|item| set1.contains(item))
 }
 
 pub fn find_bufferize_edges(
-    editor: & mut FunctionEditor,
+    editor: &mut FunctionEditor,
     fork: NodeID,
-    loop_tree: & LoopTree,
+    loop_tree: &LoopTree,
     fork_join_map: &HashMap<NodeID, NodeID>,
-    nodes_in_fork_joins: & HashMap<NodeID, HashSet<NodeID>>,
+    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
+    data_label: &LabelID,
 ) -> HashSet<(NodeID, NodeID)> {
-
     // println!("func: {:?}", editor.func_id());
     let mut edges: HashSet<_> = HashSet::new();
-    // print labels 
+
+    println!("ndoes in fork joins: {:?}", &nodes_in_fork_joins[&fork]);
+    // print labels
     for node in &nodes_in_fork_joins[&fork] {
-        println!("node: {:?}, label {:?}, ", node, editor.func().labels[node.idx()]);
+        // Edge from *has data label* to *doesn't have data label*
         let node_labels = &editor.func().labels[node.idx()];
-        for usee in editor.get_uses(*node) {
-            // If usee labels is a superset of this node labels, then make an edge. 
-            let usee_labels = &editor.func().labels[usee.idx()];
-            // strict superset
-            if !(usee_labels.is_superset(node_labels) && usee_labels.len() > node_labels.len())  {
+
+        if !node_labels.contains(data_label) {
+            continue;
+        }
+
+        // Don't draw bufferize edges from fork tids
+        if editor.get_users(fork).contains(node) {
+            continue;
+        }
+
+        for user in editor.get_users(*node) {
+            let user_labels = &editor.func().labels[user.idx()];
+            if user_labels.contains(data_label) {
                 continue;
             }
 
-            if editor.node(usee).is_control() || editor.node(node).is_control() {
+            if editor.node(user).is_control() || editor.node(node).is_control() {
                 continue;
             }
 
-            edges.insert((usee, *node));
+            edges.insert((*node, user));
         }
     }
     println!("edges: {:?}", edges);
@@ -218,15 +225,20 @@ pub fn ff_bufferize_any_fork<'a, 'b>(
     loop_tree: &'b LoopTree,
     fork_join_map: &'b HashMap<NodeID, NodeID>,
     nodes_in_fork_joins: &'b HashMap<NodeID, HashSet<NodeID>>,
-    typing: &'b Vec<TypeID>
-) -> Option<(NodeID, NodeID)> where 'a: 'b {
-
+    typing: &'b Vec<TypeID>,
+    fork_label: &'b LabelID,
+    data_label: &'b LabelID,
+) -> Option<(NodeID, NodeID)>
+where
+    'a: 'b,
+{
     let forks: Vec<_> = loop_tree
         .bottom_up_loops()
         .into_iter()
         .filter(|(k, _)| editor.func().nodes[k.idx()].is_fork())
         .collect();
 
+    println!("fork_label {:?}", fork_label);
     for l in forks {
         let fork_info = Loop {
             header: l.0,
@@ -235,18 +247,39 @@ pub fn ff_bufferize_any_fork<'a, 'b>(
         let fork = fork_info.header;
         let join = fork_join_map[&fork];
 
-        let edges = find_bufferize_edges(editor, fork, &loop_tree, &fork_join_map, &nodes_in_fork_joins);
-        let result = fork_bufferize_fission_helper(editor, &fork_info, &edges, nodes_in_fork_joins, typing, fork, join);
+        println!("fork labels: {:?}", editor.func().labels[fork.idx()]);
+        if !editor.func().labels[fork.idx()].contains(fork_label) {
+            continue;
+        }
+
+        println!("fork: {:?}", fork);
+
+        let edges = find_bufferize_edges(
+            editor,
+            fork,
+            &loop_tree,
+            &fork_join_map,
+            &nodes_in_fork_joins,
+            data_label,
+        );
+        let result = fork_bufferize_fission_helper(
+            editor,
+            &fork_info,
+            &edges,
+            nodes_in_fork_joins,
+            typing,
+            fork,
+            join,
+        );
         if result.is_none() {
-            continue
+            continue;
         } else {
-            return result; 
+            return result;
         }
     }
     return None;
 }
 
-
 pub fn fork_fission<'a>(
     editor: &'a mut FunctionEditor,
     _control_subgraph: &Subgraph,
@@ -293,20 +326,25 @@ pub fn fork_bufferize_fission_helper<'a, 'b>(
     data_node_in_fork_joins: &'b HashMap<NodeID, HashSet<NodeID>>,
     types: &'b Vec<TypeID>,
     fork: NodeID,
-    join: NodeID
-) -> Option<(NodeID, NodeID)>  where 'a: 'b {
+    join: NodeID,
+) -> Option<(NodeID, NodeID)>
+where
+    'a: 'b,
+{
     // Returns the two forks that it generates.
 
     if bufferized_edges.is_empty() {
-        return None
+        return None;
     }
 
     let all_loop_nodes = l.get_all_nodes();
 
     // FIXME: Cloning hell.
     let data_nodes = data_node_in_fork_joins[&fork].clone();
-    let loop_nodes = editor.node_ids().filter(|node_id| all_loop_nodes[node_id.idx()]);
-    // Clone the subgraph that consists of this fork-join and all data and control nodes in it. 
+    let loop_nodes = editor
+        .node_ids()
+        .filter(|node_id| all_loop_nodes[node_id.idx()]);
+    // Clone the subgraph that consists of this fork-join and all data and control nodes in it.
     let subgraph = HashSet::from_iter(data_nodes.into_iter().chain(loop_nodes));
 
     let mut outside_users = Vec::new(); // old_node, outside_user
@@ -330,8 +368,16 @@ pub fn fork_bufferize_fission_helper<'a, 'b>(
     let thread_stuff_it = factors.into_iter().enumerate();
 
     // Control successors
-    let fork_pred = editor.get_uses(fork).filter(|a| editor.node(a).is_control()).next().unwrap();
-    let join_successor = editor.get_users(join).filter(|a| editor.node(a).is_control()).next().unwrap();
+    let fork_pred = editor
+        .get_uses(fork)
+        .filter(|a| editor.node(a).is_control())
+        .next()
+        .unwrap();
+    let join_successor = editor
+        .get_users(join)
+        .filter(|a| editor.node(a).is_control())
+        .next()
+        .unwrap();
 
     let mut new_fork_id = NodeID::new(0);
 
@@ -339,25 +385,24 @@ pub fn fork_bufferize_fission_helper<'a, 'b>(
         let (mut edit, map) = copy_subgraph_in_edit(edit, subgraph)?;
 
         // Put new subgraph after old subgraph
-        println!("map: {:?}", map);
-        println!("join: {:?}, fork: {:?}", join, fork);
-        println!("fork_sccueue: {:?}", join_successor);
+        // println!("map: {:?}", map);
+        // println!("join: {:?}, fork: {:?}", join, fork);
+        // println!("fork_sccueue: {:?}", join_successor);
 
         edit = edit.replace_all_uses_where(fork_pred, join, |a| *a == map[&fork])?;
         edit = edit.replace_all_uses_where(join, map[&join], |a| *a == join_successor)?;
 
         // Replace outside uses of reduces in old subgraph with reduces in new subgraph.
         for (old_node, outside_user) in outside_users {
-            edit = edit.replace_all_uses_where(old_node, map[&old_node], |node| *node == outside_user)?;
+            edit = edit
+                .replace_all_uses_where(old_node, map[&old_node], |node| *node == outside_user)?;
         }
 
-
-
         // Add buffers to old subgraph
 
         let new_join = map[&join];
         let new_fork = map[&fork];
-        
+
         // FIXME: Do this as part of copy subgraph?
         // Add tids to original subgraph for indexing.
         let mut old_tids = Vec::new();
@@ -376,11 +421,11 @@ pub fn fork_bufferize_fission_helper<'a, 'b>(
             old_tids.push(old_id);
             new_tids.push(new_id);
         }
-    
+
         for (src, dst) in bufferized_edges {
             let array_dims = thread_stuff_it.clone().map(|(_, factor)| (factor));
             let position_idx = Index::Position(old_tids.clone().into_boxed_slice());
-            
+
             let write = edit.add_node(Node::Write {
                 collect: NodeID::new(0),
                 data: *src,
@@ -403,7 +448,6 @@ pub fn fork_bufferize_fission_helper<'a, 'b>(
             // Fix write node
             edit = edit.replace_all_uses_where(NodeID::new(0), reduce, |usee| *usee == write)?;
 
-
             // Create reads from buffer
             let position_idx = Index::Position(new_tids.clone().into_boxed_slice());
 
@@ -414,9 +458,7 @@ pub fn fork_bufferize_fission_helper<'a, 'b>(
 
             // Replaces uses of bufferized edge src with corresponding reduce and read in old subgraph
             edit = edit.replace_all_uses_where(map[src], read, |usee| *usee == map[dst])?;
-            
         }
-        
 
         new_fork_id = new_fork;
 
@@ -426,7 +468,7 @@ pub fn fork_bufferize_fission_helper<'a, 'b>(
     println!("edit_result: {:?}", edit_result);
     if edit_result == false {
         todo!();
-        return None
+        return None;
     }
 
     Some((fork, new_fork_id))
diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch
index e04d8dfe..57290a62 100644
--- a/juno_samples/fork_join_tests/src/cpu.sch
+++ b/juno_samples/fork_join_tests/src/cpu.sch
@@ -47,23 +47,24 @@ gvn(*);
 phi-elim(*);
 dce(*);
 
-fork-fission-bufferize(test7);
-dce(*);
+let blah = fork-fission-bufferize[test8@loop, test8@bufferize1](test8);
+dce(blah.test8.fj_loop_top);
+
 
 fork-tile[32, 0, true](test6@loop);
 let out = fork-split(test6@loop);
 fork-split(*);
 //let out = outline(out.test6.fj1);
 
-let out7 = auto-outline(test7);
-cpu(out7.test7);
+let out8 = auto-outline(test8);
+cpu(out8.test8);
 
 let out = auto-outline(test6);
 cpu(out.test6);
 ip-sroa(*);
 sroa(*);
 unforkify(out.test6);
-unforkify(out7.test7);
+unforkify(out8.test8);
 dce(*);
 unforkify(*);
 dce(*);
diff --git a/juno_samples/fork_join_tests/src/main.rs b/juno_samples/fork_join_tests/src/main.rs
index caf956a1..fa0d80b3 100644
--- a/juno_samples/fork_join_tests/src/main.rs
+++ b/juno_samples/fork_join_tests/src/main.rs
@@ -43,7 +43,7 @@ fn main() {
         let correct = vec![7i32; 4];
         assert(&correct, output);
 
-        let mut r = runner!(test7);
+        let mut r = runner!(test8);
         let output = r.run(0).await;
         let correct = vec![10, 17, 24, 31, 38, 45, 52, 59];
         assert(&correct, output);
diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs
index 8a6e04ed..4e5dc4c5 100644
--- a/juno_scheduler/src/ir.rs
+++ b/juno_scheduler/src/ir.rs
@@ -38,6 +38,7 @@ impl Pass {
         match self {
             Pass::Xdot => 1,
             Pass::ForkChunk => 3,
+            Pass::ForkFissionBufferize => 2,
             _ => 0,
         }
     }
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 2c5a3687..75cd377a 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -1987,7 +1987,10 @@ fn run_pass(
                 });
             };
             let Some(Value::Integer { val: dim_idx }) = args.get(1) else {
-                panic!(); // How to error here?
+                return Err(SchedulerError::PassError {
+                    pass: "forkChunk".to_string(),
+                    error: "expected integer argument".to_string(),
+                });
             };
             let Some(Value::Boolean { val: guarded_flag }) = args.get(2) else {
                 return Err(SchedulerError::PassError {
@@ -1999,10 +2002,9 @@ fn run_pass(
             assert_eq!(*guarded_flag, true);
             pm.make_fork_join_maps();
             let fork_join_maps = pm.fork_join_maps.take().unwrap();
-            for (func, fork_join_map) in
-                build_selection(pm, selection)
-                    .into_iter()
-                    .zip(fork_join_maps.iter())
+            for (func, fork_join_map) in build_selection(pm, selection)
+                .into_iter()
+                .zip(fork_join_maps.iter())
             {
                 let Some(mut func) = func else {
                     continue;
@@ -2014,7 +2016,29 @@ fn run_pass(
             pm.clear_analyses();
         }
         Pass::ForkFissionBufferize => {
-            assert!(args.is_empty());
+            pm.make_fork_join_maps();
+
+            assert_eq!(args.len(), 2);
+            let Some(Value::Label {
+                labels: fork_labels,
+            }) = args.get(0)
+            else {
+                return Err(SchedulerError::PassError {
+                    pass: "forkFissionBufferize".to_string(),
+                    error: "expected label argument".to_string(),
+                });
+            };
+
+            let Some(Value::Label {
+                labels: fork_data_labels,
+            }) = args.get(1)
+            else {
+                return Err(SchedulerError::PassError {
+                    pass: "forkFissionBufferize".to_string(),
+                    error: "expected label argument".to_string(),
+                });
+            };
+
             let mut created_fork_joins = vec![vec![]; pm.functions.len()];
 
             pm.make_fork_join_maps();
@@ -2025,6 +2049,21 @@ fn run_pass(
             let typing = pm.typing.take().unwrap();
             let loops = pm.loops.take().unwrap();
             let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap();
+
+            // assert only one function is in the selection.
+            let num_functions = build_selection(pm, selection.clone())
+                .iter()
+                .filter(|func| func.is_some())
+                .count();
+
+            assert!(num_functions <= 1);
+            assert_eq!(fork_labels.len(), 1);
+            assert_eq!(fork_data_labels.len(), 1);
+
+            let fork_label = fork_labels[0].label;
+            let data_label = fork_data_labels[0].label;
+
+            // Only one func is_some().
             for ((((func, fork_join_map), loop_tree), typing), nodes_in_fork_joins) in
                 build_selection(pm, selection)
                     .into_iter()
@@ -2036,9 +2075,15 @@ fn run_pass(
                 let Some(mut func) = func else {
                     continue;
                 };
-                if let Some((fork1, fork2)) = 
-                    ff_bufferize_any_fork(&mut func, loop_tree, fork_join_map, nodes_in_fork_joins, typing)
-                {
+                if let Some((fork1, fork2)) = ff_bufferize_any_fork(
+                    &mut func,
+                    loop_tree,
+                    fork_join_map,
+                    nodes_in_fork_joins,
+                    typing,
+                    &fork_label,
+                    &data_label,
+                ) {
                     let created_fork_joins = &mut created_fork_joins[func.func_id().idx()];
                     created_fork_joins.push(fork1);
                     created_fork_joins.push(fork2);
@@ -2046,36 +2091,41 @@ fn run_pass(
                 changed |= func.modified();
             }
 
+            pm.clear_analyses();
             pm.make_nodes_in_fork_joins();
             let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap();
             let mut new_fork_joins = HashMap::new();
 
+            let fork_label_name = &pm.labels.borrow()[fork_label.idx()].clone();
+
             for (mut func, created_fork_joins) in
                 build_editors(pm).into_iter().zip(created_fork_joins)
             {
-                 // For every function, create a label for every level of fork-
+                // For every function, create a label for every level of fork-
                 // joins resulting from the split.
                 let name = func.func().name.clone();
                 let func_id = func.func_id();
                 let labels = create_labels_for_node_sets(
                     &mut func,
                     created_fork_joins.into_iter().map(|fork| {
-                                nodes_in_fork_joins[func_id.idx()][&fork]
-                                    .iter()
-                                    .map(|id| *id)
-                            })
-                    ,
+                        nodes_in_fork_joins[func_id.idx()][&fork]
+                            .iter()
+                            .map(|id| *id)
+                    }),
                 );
 
                 // Assemble those labels into a record for this function. The
                 // format of the records is <function>.fjN, where N is the
                 // level of the split fork-joins being referred to.
-                todo!();
-                // FIXME: What if there are multiple bufferized forks per function?
                 let mut func_record = HashMap::new();
                 for (idx, label) in labels {
+                    let fmt = if idx % 2 == 0 {
+                        format!("fj_{}_top", fork_label_name)
+                    } else {
+                        format!("fj_{}_bottom", fork_label_name)
+                    };
                     func_record.insert(
-                        format!("fj{}", idx),
+                        fmt,
                         Value::Label {
                             labels: vec![LabelInfo {
                                 func: func_id,
@@ -2095,15 +2145,21 @@ fn run_pass(
 
             pm.delete_gravestones();
             pm.clear_analyses();
+
+            result = Value::Record {
+                fields: new_fork_joins,
+            };
+
+            println!("result: {:?}", result);
+
         }
         Pass::ForkDimMerge => {
             assert!(args.is_empty());
             pm.make_fork_join_maps();
             let fork_join_maps = pm.fork_join_maps.take().unwrap();
-            for (func, fork_join_map) in
-                build_selection(pm, selection)
-                    .into_iter()
-                    .zip(fork_join_maps.iter())
+            for (func, fork_join_map) in build_selection(pm, selection)
+                .into_iter()
+                .zip(fork_join_maps.iter())
             {
                 let Some(mut func) = func else {
                     continue;
@@ -2211,7 +2267,6 @@ fn run_pass(
             // Put BasicBlocks back, since it's needed for Codegen.
             pm.bbs = bbs;
         }
-        Pass::ForkChunk => todo!(),
     }
     println!("Ran Pass: {:?}", pass);
 
-- 
GitLab


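The rewritten find_bufferize_edges above boils down to one rule: draw a bufferize edge from every data node carrying the data label to each user that does not carry it, skipping control nodes and the fork's own thread IDs. A self-contained sketch of that rule, with NodeID and LabelID replaced by plain integers (illustrative stand-ins, not the Hercules API, and without the restriction to nodes_in_fork_joins[&fork] that the real pass applies):

    use std::collections::{HashMap, HashSet};

    type Node = usize; // stand-in for NodeID
    type Label = u32; // stand-in for LabelID

    fn bufferize_edges(
        labels: &HashMap<Node, HashSet<Label>>,
        users: &HashMap<Node, Vec<Node>>,
        control: &HashSet<Node>,
        fork_tids: &HashSet<Node>,
        data_label: Label,
    ) -> HashSet<(Node, Node)> {
        let mut edges = HashSet::new();
        for (&node, node_labels) in labels {
            // Only cut below nodes that carry the data label, and never
            // below the fork's thread IDs.
            if !node_labels.contains(&data_label) || fork_tids.contains(&node) {
                continue;
            }
            for &user in users.get(&node).into_iter().flatten() {
                // Edge from *has data label* to *doesn't have data label*,
                // data nodes only.
                if labels[&user].contains(&data_label)
                    || control.contains(&user)
                    || control.contains(&node)
                {
                    continue;
                }
                edges.insert((node, user));
            }
        }
        edges
    }

    fn main() {
        let labels: HashMap<Node, HashSet<Label>> =
            [(0, [1].into()), (1, [1].into()), (2, HashSet::new())].into();
        let users: HashMap<Node, Vec<Node>> = [(0, vec![1]), (1, vec![2])].into();
        let edges = bufferize_edges(&labels, &users, &HashSet::new(), &HashSet::new(), 1);
        assert_eq!(edges, [(1, 2)].into()); // 0 -> 1 stays inside the label
    }
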
From ab3035178e8ad0433b8d7ce31ace008b648fe25b Mon Sep 17 00:00:00 2001
From: Xavier Routh <xrouth2@illinois.edu>
Date: Sun, 9 Feb 2025 13:31:05 -0600
Subject: [PATCH 29/33] pm fixes

---
 juno_samples/fork_join_tests/src/blah.sch     | 34 -------------------
 juno_samples/fork_join_tests/src/cpu.sch      | 12 ++++---
 .../fork_join_tests/src/fork_join_tests.jn    |  2 +-
 juno_scheduler/src/pm.rs                      |  5 ++-
 4 files changed, 10 insertions(+), 43 deletions(-)
 delete mode 100644 juno_samples/fork_join_tests/src/blah.sch

diff --git a/juno_samples/fork_join_tests/src/blah.sch b/juno_samples/fork_join_tests/src/blah.sch
deleted file mode 100644
index 52dea702..00000000
--- a/juno_samples/fork_join_tests/src/blah.sch
+++ /dev/null
@@ -1,34 +0,0 @@
-
-xdot[true](*);
-
-fixpoint panic after 20 {
-  forkify(*);
-  fork-guard-elim(*);
-  fork-coalesce(*);
-  dce(*);
-}
-
-xdot[true](*);
-
-//gvn(*);
-//phi-elim(*);
-//dce(*);
-
-//gvn(*);
-//phi-elim(*);
-//dce(*);
-
-//fixpoint panic after 20 {
-//  infer-schedules(*);
-//}
-
-//fork-split(*);
-//gvn(*);
-//phi-elim(*);
-//dce(*);
-//unforkify(*);
-//gvn(*);
-//phi-elim(*);
-//dce(*);
-
-//gcm(*);
diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch
index 3b672846..fd3fc0ff 100644
--- a/juno_samples/fork_join_tests/src/cpu.sch
+++ b/juno_samples/fork_join_tests/src/cpu.sch
@@ -5,7 +5,7 @@ gvn(*);
 phi-elim(*);
 dce(*);
 
-let auto = auto-outline(test1, test2, test3, test4, test5, test7);
+let auto = auto-outline(test1, test2, test3, test4, test5, test7, test8);
 cpu(auto.test1);
 cpu(auto.test2);
 cpu(auto.test3);
@@ -50,12 +50,8 @@ gvn(*);
 phi-elim(*);
 dce(*);
 
-
 fork-tile[32, 0, true](test6@loop);
 let out = fork-split(test6@loop);
-
-let fission = fork-fission-bufferize[test8@loop, test8@bufferize1](auto.test8);
-
 let out = outline(out.test6.fj1);
 cpu(out);
 ip-sroa(*);
@@ -75,4 +71,10 @@ dce(auto.test7);
 simplify-cfg(auto.test7);
 dce(auto.test7);
 
+let fission = fork-fission-bufferize[test8@loop, test8@bufferize1](auto.test8);
+dce(auto.test8);
+unforkify(auto.test8);
+ccp(auto.test8);
+dce(auto.test8);
+
 gcm(*);
diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn
index bc89b2e2..a765726f 100644
--- a/juno_samples/fork_join_tests/src/fork_join_tests.jn
+++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn
@@ -82,7 +82,7 @@ fn test6(input: i32) -> i32[1024] {
   return arr;
 }
 
-
+#[entry]
 fn test7(input: i32) -> i32 {
   let arr : i32[32];
   for i = 0 to 32 {
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index f762e5ff..58ff399f 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -2094,7 +2094,7 @@ fn run_pass(
             let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap();
 
             // assert only one function is in the selection.
-            let num_functions = build_selection(pm, selection.clone())
+            let num_functions = build_selection(pm, selection.clone(), false)
                 .iter()
                 .filter(|func| func.is_some())
                 .count();
@@ -2106,9 +2106,8 @@ fn run_pass(
             let fork_label = fork_labels[0].label;
             let data_label = fork_data_labels[0].label;
 
-            // Only one func is_some().
             for ((((func, fork_join_map), loop_tree), typing), nodes_in_fork_joins) in
-                build_selection(pm, selection)
+                build_selection(pm, selection, false)
                     .into_iter()
                     .zip(fork_join_maps.iter())
                     .zip(loops.iter())
-- 
GitLab


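With the label plumbing settled, fork-fission-bufferize now returns a record per function whose fields name the fork-joins produced by the split, keyed off the fork label's name: even-indexed created fork-joins (the original half) become fj_<name>_top and odd-indexed ones (the copy) become fj_<name>_bottom, which is what the schedule consumes as blah.test8.fj_loop_top earlier in the series. A sketch of just the key scheme, mirroring the pm.rs hunk from patch 28 with the surrounding record assembly elided:

    // Key for the idx-th created fork-join, given the fork label's name
    // (e.g. "loop"). Forks are pushed as (original, copy) pairs, so even
    // indices name the top fork-join and odd indices the bottom one.
    fn record_key(fork_label_name: &str, idx: usize) -> String {
        if idx % 2 == 0 {
            format!("fj_{}_top", fork_label_name)
        } else {
            format!("fj_{}_bottom", fork_label_name)
        }
    }
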
From b4978ef526f248e205f2df22c4493e649b080162 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Mon, 10 Feb 2025 09:28:34 -0600
Subject: [PATCH 30/33] cleanup

---
 hercules_opt/src/fork_transforms.rs      | 106 +----------------------
 juno_samples/fork_join_tests/src/cpu.sch |   3 +
 juno_samples/fork_join_tests/src/main.rs |  11 ++-
 3 files changed, 12 insertions(+), 108 deletions(-)

diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index ad81d90b..80ff7b8f 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -1,7 +1,6 @@
 use std::collections::{HashMap, HashSet};
 use std::hash::Hash;
 use std::iter::zip;
-use std::thread::ThreadId;
 
 use bimap::BiMap;
 use itertools::Itertools;
@@ -102,7 +101,7 @@ pub fn find_reduce_dependencies<'a>(
 pub fn copy_subgraph_in_edit<'a, 'b>(
     mut edit: FunctionEdit<'a, 'b>,
     subgraph: HashSet<NodeID>,
-) -> (Result<(FunctionEdit<'a, 'b>, HashMap<NodeID, NodeID>), FunctionEdit<'a, 'b>>) {
+) -> Result<(FunctionEdit<'a, 'b>, HashMap<NodeID, NodeID>), FunctionEdit<'a, 'b>> {
     let mut map: HashMap<NodeID, NodeID> = HashMap::new();
 
     // Copy nodes in subgraph
@@ -185,11 +184,8 @@ pub fn find_bufferize_edges(
     nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
     data_label: &LabelID,
 ) -> HashSet<(NodeID, NodeID)> {
-    // println!("func: {:?}", editor.func_id());
     let mut edges: HashSet<_> = HashSet::new();
 
-    println!("ndoes in fork joins: {:?}", &nodes_in_fork_joins[&fork]);
-    // print labels
     for node in &nodes_in_fork_joins[&fork] {
         // Edge from *has data label* to *doesn't have data label*
         let node_labels = &editor.func().labels[node.idx()];
@@ -216,7 +212,6 @@ pub fn find_bufferize_edges(
             edges.insert((*node, user));
         }
     }
-    println!("edges: {:?}", edges);
     edges
 }
 
@@ -238,7 +233,6 @@ where
         .filter(|(k, _)| editor.func().nodes[k.idx()].is_fork())
         .collect();
 
-    println!("fork_label {:?}", fork_label);
     for l in forks {
         let fork_info = Loop {
             header: l.0,
@@ -247,13 +241,10 @@ where
         let fork = fork_info.header;
         let join = fork_join_map[&fork];
 
-        println!("fork labels: {:?}", editor.func().labels[fork.idx()]);
         if !editor.func().labels[fork.idx()].contains(fork_label) {
             continue;
         }
 
-        println!("fork: {:?}", fork);
-
         let edges = find_bufferize_edges(
             editor,
             fork,
@@ -384,11 +375,6 @@ where
     let edit_result = editor.edit(|edit| {
         let (mut edit, map) = copy_subgraph_in_edit(edit, subgraph)?;
 
-        // Put new subgraph after old subgraph
-        // println!("map: {:?}", map);
-        // println!("join: {:?}, fork: {:?}", join, fork);
-        // println!("fork_sccueue: {:?}", join_successor);
-
         edit = edit.replace_all_uses_where(fork_pred, join, |a| *a == map[&fork])?;
         edit = edit.replace_all_uses_where(join, map[&join], |a| *a == join_successor)?;
 
@@ -398,8 +384,6 @@ where
                 .replace_all_uses_where(old_node, map[&old_node], |node| *node == outside_user)?;
         }
 
-        // Add buffers to old subgraph
-
         let new_join = map[&join];
         let new_fork = map[&fork];
 
@@ -465,100 +449,12 @@ where
         Ok(edit)
     });
 
-    println!("edit_result: {:?}", edit_result);
     if edit_result == false {
         todo!();
         return None;
     }
 
     Some((fork, new_fork_id))
-
-    // let internal_control: Vec<NodeID> = Vec::new();
-
-    // // Create fork + join + Thread control
-    // let mut new_fork_id = NodeID::new(0);
-    // let mut new_join_id = NodeID::new(0);
-
-    // editor.edit(|mut edit| {
-    //     new_join_id = edit.add_node(Node::Join { control: fork });
-    //     let factors = edit.get_node(fork).try_fork().unwrap().1;
-    //     new_fork_id = edit.add_node(Node::Fork {
-    //         control: new_join_id,
-    //         factors: factors.into(),
-    //     });
-    //     edit.replace_all_uses_where(fork, new_fork_id, |usee| *usee == join)
-    // });
-
-    // for (src, dst) in bufferized_edges {
-    //     // FIXME: Disgusting cloning and allocationing and iterators.
-    //     let factors: Vec<_> = editor.func().nodes[fork.idx()]
-    //         .try_fork()
-    //         .unwrap()
-    //         .1
-    //         .iter()
-    //         .cloned()
-    //         .collect();
-    //     editor.edit(|mut edit| {
-    //         // Create write to buffer
-
-    //         let thread_stuff_it = factors.into_iter().enumerate();
-
-    //         // FIxme: try to use unzip here? Idk why it wasn't working.
-    //         let tids = thread_stuff_it.clone().map(|(dim, _)| {
-    //             edit.add_node(Node::ThreadID {
-    //                 control: fork,
-    //                 dimension: dim,
-    //             })
-    //         });
-
-    //         let array_dims = thread_stuff_it.clone().map(|(_, factor)| (factor));
-
-    //         // Assume 1-d fork only for now.
-    //         // let tid = edit.add_node(Node::ThreadID { control: fork, dimension: 0 });
-    //         let position_idx = Index::Position(tids.collect::<Vec<_>>().into_boxed_slice());
-    //         let write = edit.add_node(Node::Write {
-    //             collect: NodeID::new(0),
-    //             data: src,
-    //             indices: vec![position_idx].into(),
-    //         });
-    //         let ele_type = types[src.idx()];
-    //         let empty_buffer = edit.add_type(hercules_ir::Type::Array(
-    //             ele_type,
-    //             array_dims.collect::<Vec<_>>().into_boxed_slice(),
-    //         ));
-    //         let empty_buffer = edit.add_zero_constant(empty_buffer);
-    //         let empty_buffer = edit.add_node(Node::Constant { id: empty_buffer });
-    //         let reduce = Node::Reduce {
-    //             control: new_join_id,
-    //             init: empty_buffer,
-    //             reduct: write,
-    //         };
-    //         let reduce = edit.add_node(reduce);
-    //         // Fix write node
-    //         edit = edit.replace_all_uses_where(NodeID::new(0), reduce, |usee| *usee == write)?;
-
-    //         // Create read from buffer
-    //         let tids = thread_stuff_it.clone().map(|(dim, _)| {
-    //             edit.add_node(Node::ThreadID {
-    //                 control: new_fork_id,
-    //                 dimension: dim,
-    //             })
-    //         });
-
-    //         let position_idx = Index::Position(tids.collect::<Vec<_>>().into_boxed_slice());
-
-    //         let read = edit.add_node(Node::Read {
-    //             collect: reduce,
-    //             indices: vec![position_idx].into(),
-    //         });
-
-    //         edit = edit.replace_all_uses_where(src, read, |usee| *usee == dst)?;
-
-    //         Ok(edit)
-    //     });
-    // }
-
-    // (fork, new_fork_id)
 }
 
 /** Split a 1D fork into a separate fork for each reduction. */
diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch
index fd3fc0ff..38a38c2b 100644
--- a/juno_samples/fork_join_tests/src/cpu.sch
+++ b/juno_samples/fork_join_tests/src/cpu.sch
@@ -74,7 +74,10 @@ dce(auto.test7);
 let fission = fork-fission-bufferize[test8@loop, test8@bufferize1](auto.test8);
 dce(auto.test8);
 unforkify(auto.test8);
+dce(auto.test8);
 ccp(auto.test8);
 dce(auto.test8);
+simplify-cfg(auto.test8);
+dce(auto.test8);
 
 gcm(*);
diff --git a/juno_samples/fork_join_tests/src/main.rs b/juno_samples/fork_join_tests/src/main.rs
index 4006afa8..1013a1f0 100644
--- a/juno_samples/fork_join_tests/src/main.rs
+++ b/juno_samples/fork_join_tests/src/main.rs
@@ -43,15 +43,20 @@ fn main() {
         let correct = vec![7i32; 4];
         assert(&correct, output);
 
-        let mut r = runner!(test8);
-        let output = r.run(0).await;
-        let correct = vec![10, 17, 24, 31, 38, 45, 52, 59];
+        let mut r = runner!(test6);
+        let output = r.run(73).await;
+        let correct = (73i32..73i32+1024i32).collect();
         assert(&correct, output);
 
         let mut r = runner!(test7);
         let output = r.run(42).await;
         let correct: i32 = (42i32..42i32+32i32).sum();
         assert_eq!(correct, output);
+
+        let mut r = runner!(test8);
+        let output = r.run(0).await;
+        let correct = vec![10, 17, 24, 31, 38, 45, 52, 59];
+        assert(&correct, output);
     });
 }
 
-- 
GitLab


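The cleanup above leaves fork_bufferize_fission_helper built entirely on the editor's edit-composition pattern: each rewrite (copy_subgraph_in_edit, replace_all_uses_where, add_node) consumes the FunctionEdit and hands it back as Result<FunctionEdit, FunctionEdit>, so ? aborts the closure at the first failed step and the editor can drop the whole edit. A toy model of that control flow, with stand-in types rather than the hercules_opt API:

    // Each step either threads the edit through (Ok) or hands it back
    // unapplied (Err); `?` short-circuits exactly as in the pass.
    struct Edit {
        steps: Vec<&'static str>,
    }

    fn rewrite(mut edit: Edit, step: &'static str, ok: bool) -> Result<Edit, Edit> {
        if ok {
            edit.steps.push(step);
            Ok(edit)
        } else {
            Err(edit)
        }
    }

    fn main() {
        let result = (|| {
            let edit = Edit { steps: vec![] };
            let edit = rewrite(edit, "copy subgraph", true)?;
            let edit = rewrite(edit, "rewire control", true)?;
            let edit = rewrite(edit, "insert buffers", true)?;
            Ok::<Edit, Edit>(edit)
        })();
        assert!(result.is_ok());
    }
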
From 7a9bdc262e76dbedb1849b3e3d318829a8c1fe3a Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Mon, 10 Feb 2025 09:39:16 -0600
Subject: [PATCH 31/33] gpu

---
 hercules_opt/src/fork_transforms.rs                | 14 ++++++--------
 .../fork_join_tests/src/fork_join_tests.jn         |  6 +++---
 juno_samples/fork_join_tests/src/gpu.sch           | 13 ++++++++++++-
 juno_scheduler/src/pm.rs                           |  3 ---
 4 files changed, 21 insertions(+), 15 deletions(-)

diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index 80ff7b8f..b5d3bb28 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -322,8 +322,6 @@ pub fn fork_bufferize_fission_helper<'a, 'b>(
 where
     'a: 'b,
 {
-    // Returns the two forks that it generates.
-
     if bufferized_edges.is_empty() {
         return None;
     }
@@ -384,7 +382,6 @@ where
                 .replace_all_uses_where(old_node, map[&old_node], |node| *node == outside_user)?;
         }
 
-        let new_join = map[&join];
         let new_fork = map[&fork];
 
         // FIXME: Do this as part of copy subgraph?
@@ -422,12 +419,14 @@ where
             ));
             let empty_buffer = edit.add_zero_constant(empty_buffer);
             let empty_buffer = edit.add_node(Node::Constant { id: empty_buffer });
+            edit = edit.add_schedule(empty_buffer, Schedule::NoResetConstant)?;
             let reduce = Node::Reduce {
                 control: join,
                 init: empty_buffer,
                 reduct: write,
             };
             let reduce = edit.add_node(reduce);
+            edit = edit.add_schedule(reduce, Schedule::ParallelReduce)?;
 
             // Fix write node
             edit = edit.replace_all_uses_where(NodeID::new(0), reduce, |usee| *usee == write)?;
@@ -449,12 +448,11 @@ where
         Ok(edit)
     });
 
-    if edit_result == false {
-        todo!();
-        return None;
+    if edit_result {
+        Some((fork, new_fork_id))
+    } else {
+        None
     }
-
-    Some((fork, new_fork_id))
 }
 
 /** Split a 1D fork into a separate fork for each reduction. */
diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn
index a765726f..886ab13b 100644
--- a/juno_samples/fork_join_tests/src/fork_join_tests.jn
+++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn
@@ -97,8 +97,8 @@ fn test7(input: i32) -> i32 {
 
 #[entry]
 fn test8(input : i32) -> i32[8] {
-  let arr : i32[8];
-  let out : i32[8];
+  @const1 let arr : i32[8];
+  @const2 let out : i32[8];
 
   for i = 0 to 8 {
     arr[i] = i as i32;
@@ -116,4 +116,4 @@ fn test8(input : i32) -> i32[8] {
     out[i] = d;
   }
   return out;
-}
\ No newline at end of file
+}
diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch
index bd1fd1d8..159fac94 100644
--- a/juno_samples/fork_join_tests/src/gpu.sch
+++ b/juno_samples/fork_join_tests/src/gpu.sch
@@ -5,18 +5,21 @@ no-memset(test3@const1);
 no-memset(test3@const2);
 no-memset(test3@const3);
 no-memset(test6@const);
+no-memset(test8@const1);
+no-memset(test8@const2);
 
 gvn(*);
 phi-elim(*);
 dce(*);
 
-let auto = auto-outline(test1, test2, test3, test4, test5, test7);
+let auto = auto-outline(test1, test2, test3, test4, test5, test7, test8);
 gpu(auto.test1);
 gpu(auto.test2);
 gpu(auto.test3);
 gpu(auto.test4);
 gpu(auto.test5);
 gpu(auto.test7);
+gpu(auto.test8);
 
 ip-sroa(*);
 sroa(*);
@@ -50,6 +53,14 @@ dce(auto.test7);
 simplify-cfg(auto.test7);
 dce(auto.test7);
 
+let fission = fork-fission-bufferize[test8@loop, test8@bufferize1](auto.test8);
+xdot[true](*);
+dce(auto.test8);
+ccp(auto.test8);
+dce(auto.test8);
+simplify-cfg(auto.test8);
+dce(auto.test8);
+
 ip-sroa(*);
 sroa(*);
 dce(*);
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 58ff399f..20825c54 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -2191,9 +2191,6 @@ fn run_pass(
             result = Value::Record {
                 fields: new_fork_joins,
             };
-
-            println!("result: {:?}", result);
-
         }
         Pass::ForkDimMerge => {
             assert!(args.is_empty());
-- 
GitLab


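The two schedules added in patch 31 annotate exactly the nodes bufferization materializes per cut edge: the zero-initialized buffer constant gets NoResetConstant and the Reduce folding the per-thread writes gets ParallelReduce. In plain-Rust terms the transformed test8 behaves roughly like the two-stage loop below; this is an analogy for the generated graph, using the 7*i + 10 arithmetic worked out earlier, not the actual generated code:

    fn main() {
        // Stage 1 (original fork-join): compute up to the cut and write
        // each thread's value into a zero-initialized buffer; the buffer
        // plays the role of the zero constant (NoResetConstant) folded by
        // the Reduce (ParallelReduce) over the per-thread Write nodes.
        let arr: Vec<i32> = (0..8).collect();
        let mut buf = vec![0i32; 8];
        for i in 0..8usize {
            buf[i] = (arr[i] + arr[7 - i]) * i as i32; // the @bufferize1 value
        }

        // Stage 2 (copied fork-join): a Read from the buffer replaces the
        // cut edge's source.
        let out: Vec<i32> = (0..8usize).map(|i| buf[i] + 10).collect();
        assert_eq!(out, vec![10, 17, 24, 31, 38, 45, 52, 59]);
    }
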
From 2b48a39eeea07a6654dafff1372e45ccd75a7321 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Mon, 10 Feb 2025 09:41:45 -0600
Subject: [PATCH 32/33] remove stale cpu schedule from matmul build

---
 juno_samples/matmul/build.rs | 2 --
 1 file changed, 2 deletions(-)

diff --git a/juno_samples/matmul/build.rs b/juno_samples/matmul/build.rs
index d2813388..0be838c6 100644
--- a/juno_samples/matmul/build.rs
+++ b/juno_samples/matmul/build.rs
@@ -6,8 +6,6 @@ fn main() {
         JunoCompiler::new()
             .file_in_src("matmul.jn")
             .unwrap()
-            .schedule_in_src("cpu.sch")
-            .unwrap()
             .build()
             .unwrap();
     }
-- 
GitLab


From 17d1c03be21ec990964917f073e0e25b8bae5f66 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Mon, 10 Feb 2025 09:42:18 -0600
Subject: [PATCH 33/33] cleanup

---
 hercules_opt/src/fork_transforms.rs | 10 ----------
 1 file changed, 10 deletions(-)

diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index b5d3bb28..7c423892 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -166,16 +166,6 @@ pub fn copy_subgraph(
     (new_nodes, map, outside_users)
 }
 
-fn is_strict_superset<T: Eq + Hash>(set1: &HashSet<T>, set2: &HashSet<T>) -> bool {
-    // A strict superset must be larger than its subset
-    if set1.len() <= set2.len() {
-        return false;
-    }
-
-    // Every element in set2 must be in set1
-    set2.iter().all(|item| set1.contains(item))
-}
-
 pub fn find_bufferize_edges(
     editor: &mut FunctionEditor,
     fork: NodeID,
-- 
GitLab