From 6a0c4a3410d40410dbdac7a92926583e9c6e640a Mon Sep 17 00:00:00 2001
From: Xavier Routh <xrouth2@illinois.edu>
Date: Sat, 1 Feb 2025 15:04:24 -0600
Subject: [PATCH 01/25] control in reduce cycle fixes

---
 Cargo.lock                                    | 11 ++++
 Cargo.toml                                    |  3 +-
 hercules_opt/src/forkify.rs                   | 12 +++++
 hercules_opt/src/unforkify.rs                 |  2 +-
 .../hercules_tests/tests/loop_tests.rs        |  4 +-
 juno_scheduler/src/pm.rs                      | 53 +++++++++++--------
 6 files changed, 59 insertions(+), 26 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 49630436..ad69bc72 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1181,6 +1181,17 @@ dependencies = [
  "with_builtin_macros",
 ]
 
+[[package]]
+name = "juno_test"
+version = "0.1.0"
+dependencies = [
+ "async-std",
+ "hercules_rt",
+ "juno_build",
+ "rand",
+ "with_builtin_macros",
+]
+
 [[package]]
 name = "juno_utils"
 version = "0.1.0"
diff --git a/Cargo.toml b/Cargo.toml
index ced011a9..46fc7eaa 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -21,7 +21,7 @@ members = [
 	"hercules_samples/ccp",
 
 	"juno_samples/simple3",
-  "juno_samples/patterns",
+	"juno_samples/patterns",
 	"juno_samples/matmul",
 	"juno_samples/casts_and_intrinsics",
 	"juno_samples/nested_ccp",
@@ -30,4 +30,5 @@ members = [
   	"juno_samples/cava",
 	"juno_samples/concat",
   	"juno_samples/schedule_test",
+	"juno_samples/test",
 ]
diff --git a/hercules_opt/src/forkify.rs b/hercules_opt/src/forkify.rs
index ec4e9fbc..0f06627d 100644
--- a/hercules_opt/src/forkify.rs
+++ b/hercules_opt/src/forkify.rs
@@ -152,6 +152,7 @@ pub fn forkify_loop(
         .filter(|id| !l.control[id.idx()])
         .collect();
 
+    // FIXME: @xrouth
     if loop_preds.len() != 1 {
         return false;
     }
@@ -388,6 +389,7 @@ nest! {
             is_associative: bool,
         },
         LoopDependant(NodeID),
+        ControlDependant(NodeID), // This phi is reductionable, but its cycle might depend on internal control within the loop.
         UsedByDependant(NodeID),
     }
 }
@@ -398,6 +400,7 @@ impl LoopPHI {
             LoopPHI::Reductionable { phi, .. } => *phi,
             LoopPHI::LoopDependant(node_id) => *node_id,
             LoopPHI::UsedByDependant(node_id) => *node_id,
+            LoopPHI::ControlDependant(node_id) => *node_id,
         }
     }
 }
@@ -415,6 +418,9 @@ pub fn analyze_phis<'a>(
     loop_nodes: &'a HashSet<NodeID>,
 ) -> impl Iterator<Item = LoopPHI> + 'a {
 
+    // We are also moving the phi from the top of the loop (the header),
+    // to the very end (the join). If there are uses of the phi somewhere in the loop,
+    // then they may try to use the phi (now a reduce) before it hits the join. 
     // Find data cycles within the loop of this phi, 
     // Start from the phis loop_continue_latch, and walk its uses until we find the original phi. 
 
@@ -509,6 +515,12 @@ pub fn analyze_phis<'a>(
             // to have headers that postdominate the loop continue latch. The value of the PHI used needs to be defined
             // by the time the reduce is triggered (at the end of the loop's internal control).
 
+            // If anything in the intersection is a phi (that isn't this own phi), then the reduction cycle depends on control.
+            // Which is not allowed.
+            if intersection.iter().any(|cycle_node| editor.node(cycle_node).is_phi() && *cycle_node != *phi) || editor.node(loop_continue_latch).is_phi() {
+                return LoopPHI::ControlDependant(*phi);
+            }
+
             // No nodes in data cycles with this phi (in the loop) are used outside the loop, besides the loop_continue_latch.
             // If some other node in the cycle is used, there is not a valid node to assign it after making the cycle a reduce.
             if intersection
diff --git a/hercules_opt/src/unforkify.rs b/hercules_opt/src/unforkify.rs
index 85ffd233..7d158d1a 100644
--- a/hercules_opt/src/unforkify.rs
+++ b/hercules_opt/src/unforkify.rs
@@ -133,7 +133,7 @@ pub fn unforkify(
         if factors.len() > 1 {
             // For now, don't convert multi-dimensional fork-joins. Rely on pass
             // that splits fork-joins.
-            continue;
+            break; // Because we have to unforkify top down, we can't unforkify forks that are contained 
         }
         let join_control = nodes[join.idx()].try_join().unwrap();
         let tids: Vec<_> = editor
diff --git a/hercules_test/hercules_tests/tests/loop_tests.rs b/hercules_test/hercules_tests/tests/loop_tests.rs
index 5832a161..192c1366 100644
--- a/hercules_test/hercules_tests/tests/loop_tests.rs
+++ b/hercules_test/hercules_tests/tests/loop_tests.rs
@@ -401,7 +401,7 @@ fn matmul_pipeline() {
     let dyn_consts = [I, J, K];
 
     // FIXME: This path should not leave the crate
-    let mut module = parse_module_from_hbin("../../juno_samples/matmul/out.hbin");
+    let mut module = parse_module_from_hbin("../../juno_samples/test/out.hbin");
     //
     let mut correct_c: Box<[i32]> = (0..I * K).map(|_| 0).collect();
     for i in 0..I {
@@ -425,7 +425,7 @@ fn matmul_pipeline() {
     };
     assert_eq!(correct_c[0], value);
 
-    let schedule = Some(default_schedule![Xdot, ForkSplit, Unforkify, Xdot,]);
+    let schedule = Some(default_schedule![AutoOutline, InterproceduralSROA, SROA, InferSchedules, DCE, Xdot, GCM]);
 
     module = run_schedule_on_hercules(module, schedule).unwrap();
 
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 2371e0f2..d2772c71 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -1471,29 +1471,38 @@ fn run_pass(
         }
         Pass::Forkify => {
             assert!(args.is_empty());
-            pm.make_fork_join_maps();
-            pm.make_control_subgraphs();
-            pm.make_loops();
-            let fork_join_maps = pm.fork_join_maps.take().unwrap();
-            let loops = pm.loops.take().unwrap();
-            let control_subgraphs = pm.control_subgraphs.take().unwrap();
-            for (((func, fork_join_map), loop_nest), control_subgraph) in
-                build_selection(pm, selection)
-                    .into_iter()
-                    .zip(fork_join_maps.iter())
-                    .zip(loops.iter())
-                    .zip(control_subgraphs.iter())
-            {
-                let Some(mut func) = func else {
-                    continue;
-                };
-                // TODO: uses direct return from forkify for now instead of
-                // func.modified, see comment on top of `forkify` for why. Fix
-                // this eventually.
-                changed |= forkify(&mut func, control_subgraph, fork_join_map, loop_nest);
+            loop {
+                let mut inner_changed = false;
+                pm.make_fork_join_maps();
+                pm.make_control_subgraphs();
+                pm.make_loops();
+                let fork_join_maps = pm.fork_join_maps.take().unwrap();
+                let loops = pm.loops.take().unwrap();
+                let control_subgraphs = pm.control_subgraphs.take().unwrap();
+                for (((func, fork_join_map), loop_nest), control_subgraph) in
+                    build_selection(pm, selection.clone())
+                        .into_iter()
+                        .zip(fork_join_maps.iter())
+                        .zip(loops.iter())
+                        .zip(control_subgraphs.iter())
+                {
+                    let Some(mut func) = func else {
+                        continue;
+                    };
+                    // TODO: uses direct return from forkify for now instead of
+                    // func.modified, see comment on top of `forkify` for why. Fix
+                    // this eventually.
+                    let c = forkify(&mut func, control_subgraph, fork_join_map, loop_nest);
+                    changed |= c;
+                    inner_changed |= c; 
+                }
+                pm.delete_gravestones();
+                pm.clear_analyses();
+
+                if !inner_changed {
+                    break;
+                }
             }
-            pm.delete_gravestones();
-            pm.clear_analyses();
         }
         Pass::GCM => {
             assert!(args.is_empty());
-- 
GitLab


From fab913636e8b63cef26af9db6df2eb699f415161 Mon Sep 17 00:00:00 2001
From: Xavier Routh <xrouth2@illinois.edu>
Date: Sat, 1 Feb 2025 16:30:10 -0600
Subject: [PATCH 02/25] misc

---
 hercules_opt/src/forkify.rs | 1 -
 1 file changed, 1 deletion(-)

diff --git a/hercules_opt/src/forkify.rs b/hercules_opt/src/forkify.rs
index 0f06627d..299422c1 100644
--- a/hercules_opt/src/forkify.rs
+++ b/hercules_opt/src/forkify.rs
@@ -514,7 +514,6 @@ pub fn analyze_phis<'a>(
             // PHIs on the frontier of the uses by the candidate phi, i.e in uses_for_dependance need
             // to have headers that postdominate the loop continue latch. The value of the PHI used needs to be defined
             // by the time the reduce is triggered (at the end of the loop's internal control).
-
             // If anything in the intersection is a phi (that isn't this own phi), then the reduction cycle depends on control.
             // Which is not allowed.
             if intersection.iter().any(|cycle_node| editor.node(cycle_node).is_phi() && *cycle_node != *phi) || editor.node(loop_continue_latch).is_phi() {
-- 
GitLab


From 43b4022c38f65fcf5c2abad523c63f59888dabe9 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sat, 1 Feb 2025 18:29:52 -0600
Subject: [PATCH 03/25] simple test

---
 Cargo.lock                                    | 10 +++++++
 Cargo.toml                                    |  1 +
 juno_samples/fork_join_tests/Cargo.toml       | 21 +++++++++++++
 juno_samples/fork_join_tests/build.rs         | 24 +++++++++++++++
 .../fork_join_tests/src/fork_join_tests.jn    | 10 +++++++
 juno_samples/fork_join_tests/src/gpu.sch      | 30 +++++++++++++++++++
 juno_samples/fork_join_tests/src/main.rs      | 17 +++++++++++
 7 files changed, 113 insertions(+)
 create mode 100644 juno_samples/fork_join_tests/Cargo.toml
 create mode 100644 juno_samples/fork_join_tests/build.rs
 create mode 100644 juno_samples/fork_join_tests/src/fork_join_tests.jn
 create mode 100644 juno_samples/fork_join_tests/src/gpu.sch
 create mode 100644 juno_samples/fork_join_tests/src/main.rs

diff --git a/Cargo.lock b/Cargo.lock
index b8bf2278..af7902c6 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1130,6 +1130,16 @@ dependencies = [
  "with_builtin_macros",
 ]
 
+[[package]]
+name = "juno_fork_join_tests"
+version = "0.1.0"
+dependencies = [
+ "async-std",
+ "hercules_rt",
+ "juno_build",
+ "with_builtin_macros",
+]
+
 [[package]]
 name = "juno_frontend"
 version = "0.1.0"
diff --git a/Cargo.toml b/Cargo.toml
index f7b9322a..890d7924 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -31,4 +31,5 @@ members = [
 	"juno_samples/concat",
 	"juno_samples/schedule_test",
 	"juno_samples/edge_detection",
+	"juno_samples/fork_join_tests",
 ]
diff --git a/juno_samples/fork_join_tests/Cargo.toml b/juno_samples/fork_join_tests/Cargo.toml
new file mode 100644
index 00000000..a109e782
--- /dev/null
+++ b/juno_samples/fork_join_tests/Cargo.toml
@@ -0,0 +1,21 @@
+[package]
+name = "juno_fork_join_tests"
+version = "0.1.0"
+authors = ["Russel Arbore <rarbore2@illinois.edu>"]
+edition = "2021"
+
+[[bin]]
+name = "juno_fork_join_tests"
+path = "src/main.rs"
+
+[features]
+cuda = ["juno_build/cuda", "hercules_rt/cuda"]
+
+[build-dependencies]
+juno_build = { path = "../../juno_build" }
+
+[dependencies]
+juno_build = { path = "../../juno_build" }
+hercules_rt = { path = "../../hercules_rt" }
+with_builtin_macros = "0.1.0"
+async-std = "*"
diff --git a/juno_samples/fork_join_tests/build.rs b/juno_samples/fork_join_tests/build.rs
new file mode 100644
index 00000000..796e9f32
--- /dev/null
+++ b/juno_samples/fork_join_tests/build.rs
@@ -0,0 +1,24 @@
+use juno_build::JunoCompiler;
+
+fn main() {
+    #[cfg(not(feature = "cuda"))]
+    {
+        JunoCompiler::new()
+            .file_in_src("fork_join_tests.jn")
+            .unwrap()
+            .schedule_in_src("cpu.sch")
+            .unwrap()
+            .build()
+            .unwrap();
+    }
+    #[cfg(feature = "cuda")]
+    {
+        JunoCompiler::new()
+            .file_in_src("fork_join_tests.jn")
+            .unwrap()
+            .schedule_in_src("gpu.sch")
+            .unwrap()
+            .build()
+            .unwrap();
+    }
+}
diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn
new file mode 100644
index 00000000..aa8eb4bb
--- /dev/null
+++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn
@@ -0,0 +1,10 @@
+#[entry]
+fn test1(input : i32) -> i32[4, 4] {
+  let arr : i32[4, 4];
+  for i = 0 to 4 {
+    for j = 0 to 4 {
+      arr[i, j] = input;
+    }
+  }
+  return arr;
+}
diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch
new file mode 100644
index 00000000..e2fe980e
--- /dev/null
+++ b/juno_samples/fork_join_tests/src/gpu.sch
@@ -0,0 +1,30 @@
+gvn(*);
+phi-elim(*);
+dce(*);
+
+let out = auto-outline(*);
+gpu(out.test1);
+
+ip-sroa(*);
+sroa(*);
+dce(*);
+gvn(*);
+phi-elim(*);
+dce(*);
+
+fixpoint panic after 20 {
+  forkify(*);
+  fork-guard-elim(*);
+  fork-coalesce(*);
+}
+
+gvn(*);
+phi-elim(*);
+dce(*);
+
+fixpoint panic after 20 {
+  infer-schedules(*);
+}
+xdot[true](*);
+
+gcm(*);
diff --git a/juno_samples/fork_join_tests/src/main.rs b/juno_samples/fork_join_tests/src/main.rs
new file mode 100644
index 00000000..6e5f2182
--- /dev/null
+++ b/juno_samples/fork_join_tests/src/main.rs
@@ -0,0 +1,17 @@
+#![feature(concat_idents)]
+
+use hercules_rt::runner;
+
+juno_build::juno!("fork_join_tests");
+
+fn main() {
+    async_std::task::block_on(async {
+        let mut r = runner!(tests1);
+        let output = r.run(5).await;
+    });
+}
+
+#[test]
+fn implicit_clone_test() {
+    main();
+}
-- 
GitLab


From 5eeb70cb675bdd7ac2cf4ac4887f803d24ebe165 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sat, 1 Feb 2025 18:52:57 -0600
Subject: [PATCH 04/25] fixes for test

---
 hercules_cg/src/gpu.rs                   |  8 +++---
 juno_samples/fork_join_tests/src/cpu.sch | 31 ++++++++++++++++++++++++
 juno_samples/fork_join_tests/src/gpu.sch |  1 -
 juno_samples/fork_join_tests/src/main.rs | 13 +++++++++-
 juno_scheduler/src/pm.rs                 |  2 ++
 5 files changed, 50 insertions(+), 5 deletions(-)
 create mode 100644 juno_samples/fork_join_tests/src/cpu.sch

diff --git a/hercules_cg/src/gpu.rs b/hercules_cg/src/gpu.rs
index 81e31396..afc016a4 100644
--- a/hercules_cg/src/gpu.rs
+++ b/hercules_cg/src/gpu.rs
@@ -539,15 +539,17 @@ namespace cg = cooperative_groups;
         w: &mut String,
     ) -> Result<(), Error> {
         write!(w, "\n")?;
-        for (id, goto) in gotos.iter() {
-            let goto_block = self.get_block_name(*id, false);
+        let rev_po = self.control_subgraph.rev_po(NodeID::new(0));
+        for id in rev_po {
+            let goto = &gotos[&id];
+            let goto_block = self.get_block_name(id, false);
             write!(w, "{}:\n", goto_block)?;
             if goto_debug {
                 write!(w, "\tprintf(\"goto {}\\n\");\n", goto_block)?;
             }
             write!(w, "{}", goto.init)?;
             if !goto.post_init.is_empty() {
-                let goto_block = self.get_block_name(*id, true);
+                let goto_block = self.get_block_name(id, true);
                 write!(w, "{}:\n", goto_block)?;
                 write!(w, "{}", goto.post_init)?;
             }
diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch
new file mode 100644
index 00000000..81f5a12c
--- /dev/null
+++ b/juno_samples/fork_join_tests/src/cpu.sch
@@ -0,0 +1,31 @@
+gvn(*);
+phi-elim(*);
+dce(*);
+
+let out = auto-outline(*);
+cpu(out.test1);
+
+ip-sroa(*);
+sroa(*);
+dce(*);
+gvn(*);
+phi-elim(*);
+dce(*);
+
+fixpoint panic after 20 {
+  forkify(*);
+  fork-guard-elim(*);
+  fork-coalesce(*);
+}
+
+gvn(*);
+phi-elim(*);
+dce(*);
+
+fixpoint panic after 20 {
+  infer-schedules(*);
+}
+fork-split(*);
+unforkify(*);
+
+gcm(*);
diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch
index e2fe980e..e4e4e04f 100644
--- a/juno_samples/fork_join_tests/src/gpu.sch
+++ b/juno_samples/fork_join_tests/src/gpu.sch
@@ -25,6 +25,5 @@ dce(*);
 fixpoint panic after 20 {
   infer-schedules(*);
 }
-xdot[true](*);
 
 gcm(*);
diff --git a/juno_samples/fork_join_tests/src/main.rs b/juno_samples/fork_join_tests/src/main.rs
index 6e5f2182..a63b3f78 100644
--- a/juno_samples/fork_join_tests/src/main.rs
+++ b/juno_samples/fork_join_tests/src/main.rs
@@ -6,8 +6,19 @@ juno_build::juno!("fork_join_tests");
 
 fn main() {
     async_std::task::block_on(async {
-        let mut r = runner!(tests1);
+        let mut r = runner!(test1);
         let output = r.run(5).await;
+        let correct = vec![5i32; 16];
+        #[cfg(not(feature = "cuda"))]
+        {
+            assert_eq!(output.as_slice::<i32>(), &correct);
+        }
+        #[cfg(feature = "cuda")]
+        {
+            let mut dst = vec![0i32; 16];
+            let output = output.to_cpu_ref(&mut dst);
+            assert_eq!(output.as_slice::<i32>(), &correct);
+        }
     });
 }
 
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index f6fe2fc1..5d398804 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -575,6 +575,8 @@ impl PassManager {
         self.postdoms = None;
         self.fork_join_maps = None;
         self.fork_join_nests = None;
+        self.fork_control_maps = None;
+        self.fork_trees = None;
         self.loops = None;
         self.reduce_cycles = None;
         self.data_nodes_in_fork_joins = None;
-- 
GitLab


From d67263e6a871fa539067ec84e87c58f242809ea4 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sun, 2 Feb 2025 10:50:56 -0600
Subject: [PATCH 05/25] second fork test, fails on cpu w/ forkify/unforkify

---
 hercules_opt/src/float_collections.rs         | 27 ++++++++++++------
 juno_samples/fork_join_tests/src/cpu.sch      |  1 +
 .../fork_join_tests/src/fork_join_tests.jn    | 13 +++++++++
 juno_samples/fork_join_tests/src/gpu.sch      |  2 ++
 juno_samples/fork_join_tests/src/main.rs      | 28 ++++++++++++-------
 juno_scheduler/src/pm.rs                      | 17 +++++------
 6 files changed, 60 insertions(+), 28 deletions(-)

diff --git a/hercules_opt/src/float_collections.rs b/hercules_opt/src/float_collections.rs
index faa38375..6ef050c2 100644
--- a/hercules_opt/src/float_collections.rs
+++ b/hercules_opt/src/float_collections.rs
@@ -1,3 +1,5 @@
+use std::collections::BTreeMap;
+
 use hercules_ir::*;
 
 use crate::*;
@@ -7,27 +9,36 @@ use crate::*;
  * allowed.
  */
 pub fn float_collections(
-    editors: &mut [FunctionEditor],
+    editors: &mut BTreeMap<FunctionID, FunctionEditor>,
     typing: &ModuleTyping,
     callgraph: &CallGraph,
     devices: &Vec<Device>,
 ) {
-    let topo = callgraph.topo();
+    let topo: Vec<_> = callgraph
+        .topo()
+        .into_iter()
+        .filter(|id| editors.contains_key(&id))
+        .collect();
     for to_float_id in topo {
         // Collection constants float until reaching an AsyncRust function.
         if devices[to_float_id.idx()] == Device::AsyncRust {
             continue;
         }
 
+        // Check that all callers are in the selection as well.
+        for caller in callgraph.get_callers(to_float_id) {
+            assert!(editors.contains_key(&caller), "PANIC: FloatCollections called where a function ({:?}, {:?}) is in the selection but one of its callers ({:?}) is not. This means no collections will be floated from the callee, since the caller can't be modified to hold floated collections.", to_float_id, editors[&to_float_id].func().name, caller);
+        }
+
         // Find the target constant nodes in the function.
-        let cons: Vec<(NodeID, Node)> = editors[to_float_id.idx()]
+        let cons: Vec<(NodeID, Node)> = editors[&to_float_id]
             .func()
             .nodes
             .iter()
             .enumerate()
             .filter(|(_, node)| {
                 node.try_constant()
-                    .map(|cons_id| !editors[to_float_id.idx()].get_constant(cons_id).is_scalar())
+                    .map(|cons_id| !editors[&to_float_id].get_constant(cons_id).is_scalar())
                     .unwrap_or(false)
             })
             .map(|(idx, node)| (NodeID::new(idx), node.clone()))
@@ -37,12 +48,12 @@ pub fn float_collections(
         }
 
         // Each constant node becomes a new parameter.
-        let mut new_param_types = editors[to_float_id.idx()].func().param_types.clone();
+        let mut new_param_types = editors[&to_float_id].func().param_types.clone();
         let old_num_params = new_param_types.len();
         for (id, _) in cons.iter() {
             new_param_types.push(typing[to_float_id.idx()][id.idx()]);
         }
-        let success = editors[to_float_id.idx()].edit(|mut edit| {
+        let success = editors.get_mut(&to_float_id).unwrap().edit(|mut edit| {
             for (idx, (id, _)) in cons.iter().enumerate() {
                 let param = edit.add_node(Node::Parameter {
                     index: idx + old_num_params,
@@ -59,7 +70,7 @@ pub fn float_collections(
 
         // Add constants in callers and pass them into calls.
         for caller in callgraph.get_callers(to_float_id) {
-            let calls: Vec<(NodeID, Node)> = editors[caller.idx()]
+            let calls: Vec<(NodeID, Node)> = editors[&caller]
                 .func()
                 .nodes
                 .iter()
@@ -71,7 +82,7 @@ pub fn float_collections(
                 })
                 .map(|(idx, node)| (NodeID::new(idx), node.clone()))
                 .collect();
-            let success = editors[caller.idx()].edit(|mut edit| {
+            let success = editors.get_mut(&caller).unwrap().edit(|mut edit| {
                 let cons_ids: Vec<_> = cons
                     .iter()
                     .map(|(_, node)| edit.add_node(node.clone()))
diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch
index 81f5a12c..a6b1afe7 100644
--- a/juno_samples/fork_join_tests/src/cpu.sch
+++ b/juno_samples/fork_join_tests/src/cpu.sch
@@ -4,6 +4,7 @@ dce(*);
 
 let out = auto-outline(*);
 cpu(out.test1);
+cpu(out.test2);
 
 ip-sroa(*);
 sroa(*);
diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn
index aa8eb4bb..4a6a94c9 100644
--- a/juno_samples/fork_join_tests/src/fork_join_tests.jn
+++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn
@@ -8,3 +8,16 @@ fn test1(input : i32) -> i32[4, 4] {
   }
   return arr;
 }
+
+#[entry]
+fn test2(input : i32) -> i32[4, 4] {
+  let arr : i32[4, 4];
+  for i = 0 to 8 {
+    for j = 0 to 4 {
+      for k = 0 to 4 {
+        arr[j, k] += input;
+      }
+    }
+  }
+  return arr;
+}
diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch
index e4e4e04f..b506c4a4 100644
--- a/juno_samples/fork_join_tests/src/gpu.sch
+++ b/juno_samples/fork_join_tests/src/gpu.sch
@@ -4,6 +4,7 @@ dce(*);
 
 let out = auto-outline(*);
 gpu(out.test1);
+gpu(out.test2);
 
 ip-sroa(*);
 sroa(*);
@@ -26,4 +27,5 @@ fixpoint panic after 20 {
   infer-schedules(*);
 }
 
+float-collections(test2, out.test2);
 gcm(*);
diff --git a/juno_samples/fork_join_tests/src/main.rs b/juno_samples/fork_join_tests/src/main.rs
index a63b3f78..de1b0805 100644
--- a/juno_samples/fork_join_tests/src/main.rs
+++ b/juno_samples/fork_join_tests/src/main.rs
@@ -5,20 +5,28 @@ use hercules_rt::runner;
 juno_build::juno!("fork_join_tests");
 
 fn main() {
+    #[cfg(not(feature = "cuda"))]
+    let assert = |correct, output: hercules_rt::HerculesCPURefMut<'_>| {
+        assert_eq!(output.as_slice::<i32>(), &correct);
+    };
+
+    #[cfg(feature = "cuda")]
+    let assert = |correct, output: hercules_rt::HerculesCUDARefMut<'_>| {
+        let mut dst = vec![0i32; 16];
+        let output = output.to_cpu_ref(&mut dst);
+        assert_eq!(output.as_slice::<i32>(), &correct);
+    };
+
     async_std::task::block_on(async {
         let mut r = runner!(test1);
         let output = r.run(5).await;
         let correct = vec![5i32; 16];
-        #[cfg(not(feature = "cuda"))]
-        {
-            assert_eq!(output.as_slice::<i32>(), &correct);
-        }
-        #[cfg(feature = "cuda")]
-        {
-            let mut dst = vec![0i32; 16];
-            let output = output.to_cpu_ref(&mut dst);
-            assert_eq!(output.as_slice::<i32>(), &correct);
-        }
+        assert(correct, output);
+
+        let mut r = runner!(test2);
+        let output = r.run(3).await;
+        let correct = vec![24i32; 16];
+        assert(correct, output);
     });
 }
 
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 5d398804..1ebc885c 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -1305,7 +1305,7 @@ fn run_pass(
     pm: &mut PassManager,
     pass: Pass,
     args: Vec<Value>,
-    selection: Option<Vec<CodeLocation>>,
+    mut selection: Option<Vec<CodeLocation>>,
 ) -> Result<(Value, bool), SchedulerError> {
     let mut result = Value::Record {
         fields: HashMap::new(),
@@ -1441,13 +1441,6 @@ fn run_pass(
         }
         Pass::FloatCollections => {
             assert!(args.is_empty());
-            if let Some(_) = selection {
-                return Err(SchedulerError::PassError {
-                    pass: "floatCollections".to_string(),
-                    error: "must be applied to the entire module".to_string(),
-                });
-            }
-
             pm.make_typing();
             pm.make_callgraph();
             pm.make_devices();
@@ -1455,11 +1448,15 @@ fn run_pass(
             let callgraph = pm.callgraph.take().unwrap();
             let devices = pm.devices.take().unwrap();
 
-            let mut editors = build_editors(pm);
+            // Modify the selection to include callers of selected functions.
+            let mut editors = build_selection(pm, selection)
+                .into_iter()
+                .filter_map(|editor| editor.map(|editor| (editor.func_id(), editor)))
+                .collect();
             float_collections(&mut editors, &typing, &callgraph, &devices);
 
             for func in editors {
-                changed |= func.modified();
+                changed |= func.1.modified();
             }
 
             pm.delete_gravestones();
-- 
GitLab


From ba614011066dcde2305a649f7911368d99f95a84 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sun, 2 Feb 2025 11:12:34 -0600
Subject: [PATCH 06/25] fix reduce cycles

---
 hercules_ir/src/fork_join_analysis.rs    | 34 ++++++++++++------------
 juno_samples/fork_join_tests/src/cpu.sch |  9 +++++++
 2 files changed, 26 insertions(+), 17 deletions(-)

diff --git a/hercules_ir/src/fork_join_analysis.rs b/hercules_ir/src/fork_join_analysis.rs
index 263fa952..7a098a35 100644
--- a/hercules_ir/src/fork_join_analysis.rs
+++ b/hercules_ir/src/fork_join_analysis.rs
@@ -140,23 +140,23 @@ fn reduce_cycle_dfs_helper(
     }
 
     current_visited.insert(iter);
-    let found_reduce = get_uses(&function.nodes[iter.idx()])
-        .as_ref()
-        .into_iter()
-        .any(|u| {
-            !current_visited.contains(u)
-                && !function.nodes[u.idx()].is_control()
-                && isnt_outside_fork_join(*u)
-                && reduce_cycle_dfs_helper(
-                    function,
-                    *u,
-                    fork,
-                    reduce,
-                    current_visited,
-                    in_cycle,
-                    fork_join_nest,
-                )
-        });
+    let mut found_reduce = false;
+
+    // This doesn't short circuit on purpose.
+    for u in get_uses(&function.nodes[iter.idx()]).as_ref() {
+        found_reduce |= !current_visited.contains(u)
+            && !function.nodes[u.idx()].is_control()
+            && isnt_outside_fork_join(*u)
+            && reduce_cycle_dfs_helper(
+                function,
+                *u,
+                fork,
+                reduce,
+                current_visited,
+                in_cycle,
+                fork_join_nest,
+            )
+    }
     if found_reduce {
         in_cycle.insert(iter);
     }
diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch
index a6b1afe7..2889cec0 100644
--- a/juno_samples/fork_join_tests/src/cpu.sch
+++ b/juno_samples/fork_join_tests/src/cpu.sch
@@ -18,6 +18,9 @@ fixpoint panic after 20 {
   fork-guard-elim(*);
   fork-coalesce(*);
 }
+gvn(*);
+phi-elim(*);
+dce(*);
 
 gvn(*);
 phi-elim(*);
@@ -27,6 +30,12 @@ fixpoint panic after 20 {
   infer-schedules(*);
 }
 fork-split(*);
+gvn(*);
+phi-elim(*);
+dce(*);
 unforkify(*);
+gvn(*);
+phi-elim(*);
+dce(*);
 
 gcm(*);
-- 
GitLab


From f2865cbfd0fd7f3a7e84c08e70283c8c0eedefdf Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sun, 2 Feb 2025 11:21:02 -0600
Subject: [PATCH 07/25] interesting test

---
 juno_samples/fork_join_tests/src/cpu.sch      |  1 +
 .../fork_join_tests/src/fork_join_tests.jn    | 23 +++++++++++++++++++
 juno_samples/fork_join_tests/src/gpu.sch      |  1 +
 juno_samples/fork_join_tests/src/main.rs      |  5 ++++
 4 files changed, 30 insertions(+)

diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch
index 2889cec0..0263c275 100644
--- a/juno_samples/fork_join_tests/src/cpu.sch
+++ b/juno_samples/fork_join_tests/src/cpu.sch
@@ -5,6 +5,7 @@ dce(*);
 let out = auto-outline(*);
 cpu(out.test1);
 cpu(out.test2);
+cpu(out.test3);
 
 ip-sroa(*);
 sroa(*);
diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn
index 4a6a94c9..073cfd1e 100644
--- a/juno_samples/fork_join_tests/src/fork_join_tests.jn
+++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn
@@ -21,3 +21,26 @@ fn test2(input : i32) -> i32[4, 4] {
   }
   return arr;
 }
+
+#[entry]
+fn test3(input : i32) -> i32[3, 3] {
+  let arr1 : i32[3, 3];
+  for i = 0 to 3 {
+    for j = 0 to 3 {
+      arr1[i, j] = (i + j) as i32 + input;
+    }
+  }
+  let arr2 : i32[3, 3];
+  for i = 0 to 3 {
+    for j = 0 to 3 {
+      arr2[i, j] = arr1[3 - i, 3 - j];
+    }
+  }
+  let arr3 : i32[3, 3];
+  for i = 0 to 3 {
+    for j = 0 to 3 {
+      arr3[i, j] = arr2[i, j] + 7;
+    }
+  }
+  return arr3;
+}
diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch
index b506c4a4..80f1bbc9 100644
--- a/juno_samples/fork_join_tests/src/gpu.sch
+++ b/juno_samples/fork_join_tests/src/gpu.sch
@@ -5,6 +5,7 @@ dce(*);
 let out = auto-outline(*);
 gpu(out.test1);
 gpu(out.test2);
+gpu(out.test3);
 
 ip-sroa(*);
 sroa(*);
diff --git a/juno_samples/fork_join_tests/src/main.rs b/juno_samples/fork_join_tests/src/main.rs
index de1b0805..4384ecd5 100644
--- a/juno_samples/fork_join_tests/src/main.rs
+++ b/juno_samples/fork_join_tests/src/main.rs
@@ -27,6 +27,11 @@ fn main() {
         let output = r.run(3).await;
         let correct = vec![24i32; 16];
         assert(correct, output);
+
+        let mut r = runner!(test3);
+        let output = r.run(0).await;
+        let correct = vec![11, 10, 9, 10, 9, 8, 9, 8, 7];
+        assert(correct, output);
     });
 }
 
-- 
GitLab


From 912a729ac05bd2077f2ab864cef04e1dddde7667 Mon Sep 17 00:00:00 2001
From: Xavier Routh <xrouth2@illinois.edu>
Date: Sun, 2 Feb 2025 12:34:13 -0600
Subject: [PATCH 08/25] unforkify fixes

---
 hercules_opt/src/unforkify.rs |  6 +++--
 juno_scheduler/src/pm.rs      | 42 ++++++++++++++++++++++-------------
 2 files changed, 30 insertions(+), 18 deletions(-)

diff --git a/hercules_opt/src/unforkify.rs b/hercules_opt/src/unforkify.rs
index 7d158d1a..a08d1667 100644
--- a/hercules_opt/src/unforkify.rs
+++ b/hercules_opt/src/unforkify.rs
@@ -118,7 +118,7 @@ pub fn unforkify(
     // control insides of the fork-join should become the successor of the true
     // projection node, and what was the use of the join should become a use of
     // the new region.
-    for l in loop_tree.bottom_up_loops().into_iter().rev() {
+    for l in loop_tree.bottom_up_loops().iter().rev() {
         if !editor.node(l.0).is_fork() {
             continue;
         }
@@ -133,7 +133,8 @@ pub fn unforkify(
         if factors.len() > 1 {
             // For now, don't convert multi-dimensional fork-joins. Rely on pass
             // that splits fork-joins.
-            break; // Because we have to unforkify top down, we can't unforkify forks that are contained 
+            // We can't unforkify, because then the outer fork's reduce would depend on non-fork control.
+            break;
         }
         let join_control = nodes[join.idx()].try_join().unwrap();
         let tids: Vec<_> = editor
@@ -293,5 +294,6 @@ pub fn unforkify(
 
             Ok(edit)
         });
+        break;
     }
 }
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index d2772c71..378d8730 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -1815,25 +1815,35 @@ fn run_pass(
         }
         Pass::Unforkify => {
             assert!(args.is_empty());
-            pm.make_fork_join_maps();
-            pm.make_loops();
+            loop {
+                let mut inner_changed = false;
 
-            let fork_join_maps = pm.fork_join_maps.take().unwrap();
-            let loops = pm.loops.take().unwrap();
+                pm.make_fork_join_maps();
+                pm.make_loops();
 
-            for ((func, fork_join_map), loop_tree) in build_selection(pm, selection)
-                .into_iter()
-                .zip(fork_join_maps.iter())
-                .zip(loops.iter())
-            {
-                let Some(mut func) = func else {
-                    continue;
-                };
-                unforkify(&mut func, fork_join_map, loop_tree);
-                changed |= func.modified();
+                let fork_join_maps = pm.fork_join_maps.take().unwrap();
+                let loops = pm.loops.take().unwrap();
+
+                for ((func, fork_join_map), loop_tree) in build_selection(pm, selection.clone())
+                    .into_iter()
+                    .zip(fork_join_maps.iter())
+                    .zip(loops.iter())
+                {
+                    let Some(mut func) = func else {
+                        continue;
+                    };
+                    unforkify(&mut func, fork_join_map, loop_tree);
+                    changed |= func.modified();
+                    inner_changed |= func.modified();
+                }
+                pm.delete_gravestones();
+                pm.clear_analyses();
+
+                if !inner_changed {
+                    break;
+                }
+                break;
             }
-            pm.delete_gravestones();
-            pm.clear_analyses();
         }
         Pass::ForkCoalesce => {
             assert!(args.is_empty());
-- 
GitLab


From d390705275493b662dcfa92a8d7ad35551d119a4 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sun, 2 Feb 2025 13:20:04 -0600
Subject: [PATCH 09/25] whoops

---
 juno_samples/fork_join_tests/src/fork_join_tests.jn | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn
index 073cfd1e..3d003f3c 100644
--- a/juno_samples/fork_join_tests/src/fork_join_tests.jn
+++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn
@@ -33,7 +33,7 @@ fn test3(input : i32) -> i32[3, 3] {
   let arr2 : i32[3, 3];
   for i = 0 to 3 {
     for j = 0 to 3 {
-      arr2[i, j] = arr1[3 - i, 3 - j];
+      arr2[i, j] = arr1[2 - i, 2 - j];
     }
   }
   let arr3 : i32[3, 3];
-- 
GitLab


From ed33189c2fcaf725a8ae34f09d4fb963826ac4b6 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sun, 2 Feb 2025 13:38:57 -0600
Subject: [PATCH 10/25] whoops x2

---
 hercules_ir/src/ir.rs                    | 2 +-
 juno_samples/fork_join_tests/src/gpu.sch | 1 +
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs
index 846347b0..5c575ea1 100644
--- a/hercules_ir/src/ir.rs
+++ b/hercules_ir/src/ir.rs
@@ -1857,7 +1857,7 @@ pub fn indices_may_overlap(indices1: &[Index], indices2: &[Index]) -> bool {
  * list of indices B.
  */
 pub fn indices_contain_other_indices(indices_a: &[Index], indices_b: &[Index]) -> bool {
-    if indices_a.len() < indices_b.len() {
+    if indices_a.len() > indices_b.len() {
         return false;
     }
 
diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch
index 80f1bbc9..0647d781 100644
--- a/juno_samples/fork_join_tests/src/gpu.sch
+++ b/juno_samples/fork_join_tests/src/gpu.sch
@@ -28,5 +28,6 @@ fixpoint panic after 20 {
   infer-schedules(*);
 }
 
+xdot[true](*);
 float-collections(test2, out.test2);
 gcm(*);
-- 
GitLab


From e1f5634f38cca264a25abe83174434ef333af9c3 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sun, 2 Feb 2025 13:59:43 -0600
Subject: [PATCH 11/25] fix to antideps

---
 hercules_opt/src/gcm.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs
index 271bfaf1..b13c919a 100644
--- a/hercules_opt/src/gcm.rs
+++ b/hercules_opt/src/gcm.rs
@@ -290,7 +290,7 @@ fn basic_blocks(
                     .collect();
                 for mutator in reverse_postorder.iter() {
                     let mutator_early = schedule_early[mutator.idx()].unwrap();
-                    if dom.does_dom(root_early, mutator_early)
+                    if dom.does_prop_dom(root_early, mutator_early)
                         && (root_early != mutator_early
                             || root_block_iterated_users.contains(&mutator))
                         && mutating_objects(function, func_id, *mutator, objects)
-- 
GitLab


From 7f381ff54caddce24b0808725390a982e27bfdbc Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sun, 2 Feb 2025 14:38:13 -0600
Subject: [PATCH 12/25] hack for gpu

---
 hercules_cg/src/gpu.rs  | 17 ++++++++---------
 hercules_opt/src/gcm.rs | 18 ++++++++++++++++--
 2 files changed, 24 insertions(+), 11 deletions(-)

diff --git a/hercules_cg/src/gpu.rs b/hercules_cg/src/gpu.rs
index afc016a4..8f186aa7 100644
--- a/hercules_cg/src/gpu.rs
+++ b/hercules_cg/src/gpu.rs
@@ -622,23 +622,23 @@ extern \"C\" {} {}(",
             write!(pass_args, "ret")?;
             write!(w, "\tcudaMalloc((void**)&ret, sizeof({}));\n", ret_type)?;
         }
-        write!(w, "\tcudaError_t err;\n");
+        write!(w, "\tcudaError_t err;\n")?;
         write!(
             w,
             "\t{}_gpu<<<{}, {}, {}>>>({});\n",
             self.function.name, num_blocks, num_threads, dynamic_shared_offset, pass_args
         )?;
-        write!(w, "\terr = cudaGetLastError();\n");
+        write!(w, "\terr = cudaGetLastError();\n")?;
         write!(
             w,
             "\tif (cudaSuccess != err) {{ printf(\"Error1: %s\\n\", cudaGetErrorString(err)); }}\n"
-        );
+        )?;
         write!(w, "\tcudaDeviceSynchronize();\n")?;
-        write!(w, "\terr = cudaGetLastError();\n");
+        write!(w, "\terr = cudaGetLastError();\n")?;
         write!(
             w,
             "\tif (cudaSuccess != err) {{ printf(\"Error2: %s\\n\", cudaGetErrorString(err)); }}\n"
-        );
+        )?;
         if has_ret_var {
             // Copy return from device to host, whether it's primitive value or collection pointer
             write!(w, "\t{} host_ret;\n", ret_type)?;
@@ -1150,7 +1150,8 @@ extern \"C\" {} {}(",
             // for all threads. Otherwise, it can be inside or outside block fork.
             // If inside, it's stored in shared memory so we "allocate" it once
             // and parallelize memset to 0. If outside, we initialize as offset
-            // to backing, but if multi-block grid, don't memset to avoid grid-level sync.
+            // to backing, but if multi-block grid, don't memset to avoid grid-
+            // level sync.
             Node::Constant { id: cons_id } => {
                 let is_primitive = self.types[self.typing[id.idx()].idx()].is_primitive();
                 let cg_tile = match state {
@@ -1192,9 +1193,7 @@ extern \"C\" {} {}(",
                     )?;
                 }
                 if !is_primitive
-                    && (state != KernelState::OutBlock
-                        || is_block_parallel.is_none()
-                        || !is_block_parallel.unwrap())
+                    && (state != KernelState::OutBlock || !is_block_parallel.unwrap_or(false))
                 {
                     let data_size =
                         self.get_size(self.typing[id.idx()], None, Some(extra_dim_collects));
diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs
index b13c919a..65f7c2d0 100644
--- a/hercules_opt/src/gcm.rs
+++ b/hercules_opt/src/gcm.rs
@@ -90,6 +90,7 @@ pub fn gcm(
         loops,
         fork_join_map,
         objects,
+        devices,
     );
 
     let liveness = liveness_dataflow(
@@ -174,6 +175,7 @@ fn basic_blocks(
     loops: &LoopTree,
     fork_join_map: &HashMap<NodeID, NodeID>,
     objects: &CollectionObjects,
+    devices: &Vec<Device>,
 ) -> BasicBlocks {
     let mut bbs: Vec<Option<NodeID>> = vec![None; function.nodes.len()];
 
@@ -421,9 +423,18 @@ fn basic_blocks(
                 // If the next node further up the dominator tree is in a shallower
                 // loop nest or if we can get out of a reduce loop when we don't
                 // need to be in one, place this data node in a higher-up location.
-                // Only do this is the node isn't a constant or undef.
+                // Only do this if the node isn't a constant or undef - if a
+                // node is a constant or undef, we want its placement to be as
+                // control dependent as possible, even inside loops. In GPU
+                // functions specifically, lift constants that may be returned
+                // outside fork-joins.
                 let is_constant_or_undef =
                     function.nodes[id.idx()].is_constant() || function.nodes[id.idx()].is_undef();
+                let is_gpu_returned = devices[func_id.idx()] == Device::CUDA
+                    && objects[&func_id]
+                        .objects(id)
+                        .into_iter()
+                        .any(|obj| objects[&func_id].returned_objects().contains(obj));
                 let old_nest = loops
                     .header_of(location)
                     .map(|header| loops.nesting(header).unwrap());
@@ -444,7 +455,10 @@ fn basic_blocks(
                 // loop use the reduce node forming the loop, so the dominator chain
                 // will consist of one block, and this loop won't ever iterate.
                 let currently_at_join = function.nodes[location.idx()].is_join();
-                if !is_constant_or_undef && (shallower_nest || currently_at_join) {
+
+                if (!is_constant_or_undef || is_gpu_returned)
+                    && (shallower_nest || currently_at_join)
+                {
                     location = control_node;
                 }
             }
-- 
GitLab


From 25725cb1df4ec7b3c59f963c00aa5a2844a52916 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sun, 2 Feb 2025 14:50:58 -0600
Subject: [PATCH 13/25] Ok fix antideps for real this time

---
 hercules_opt/src/gcm.rs  | 19 +++++++++++++++++--
 juno_scheduler/src/pm.rs |  3 +++
 2 files changed, 20 insertions(+), 2 deletions(-)

diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs
index 65f7c2d0..3ff6d2fe 100644
--- a/hercules_opt/src/gcm.rs
+++ b/hercules_opt/src/gcm.rs
@@ -1,5 +1,5 @@
 use std::cell::Ref;
-use std::collections::{BTreeMap, BTreeSet, HashMap, VecDeque};
+use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet, VecDeque};
 use std::iter::{empty, once, zip, FromIterator};
 
 use bitvec::prelude::*;
@@ -76,6 +76,7 @@ pub fn gcm(
     dom: &DomTree,
     fork_join_map: &HashMap<NodeID, NodeID>,
     loops: &LoopTree,
+    reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
     objects: &CollectionObjects,
     devices: &Vec<Device>,
     object_device_demands: &FunctionObjectDeviceDemands,
@@ -88,6 +89,7 @@ pub fn gcm(
         reverse_postorder,
         dom,
         loops,
+        reduce_cycles,
         fork_join_map,
         objects,
         devices,
@@ -173,6 +175,7 @@ fn basic_blocks(
     reverse_postorder: &Vec<NodeID>,
     dom: &DomTree,
     loops: &LoopTree,
+    reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
     fork_join_map: &HashMap<NodeID, NodeID>,
     objects: &CollectionObjects,
     devices: &Vec<Device>,
@@ -246,6 +249,9 @@ fn basic_blocks(
     //    but not forwarding read - forwarding reads are collapsed, and the
     //    bottom read is treated as reading from the transitive parent of the
     //    forwarding read(s).
+    // 3: If the node producing the collection is a reduce node, then any read
+    //    users that aren't in the reduce's cycle shouldn't anti-depend on any
+    //    mutators in the reduce cycle.
     let mut antideps = BTreeSet::new();
     for id in reverse_postorder.iter() {
         // Find a terminating read node and the collections it reads.
@@ -271,6 +277,10 @@ fn basic_blocks(
             // TODO: make this less outrageously inefficient.
             let func_objects = &objects[&func_id];
             for root in roots.iter() {
+                let root_is_reduce_and_read_isnt_in_cycle = reduce_cycles
+                    .get(root)
+                    .map(|cycle| !cycle.contains(&id))
+                    .unwrap_or(false);
                 let root_early = schedule_early[root.idx()].unwrap();
                 let mut root_block_iterated_users: BTreeSet<NodeID> = BTreeSet::new();
                 let mut workset = BTreeSet::new();
@@ -292,12 +302,17 @@ fn basic_blocks(
                     .collect();
                 for mutator in reverse_postorder.iter() {
                     let mutator_early = schedule_early[mutator.idx()].unwrap();
-                    if dom.does_prop_dom(root_early, mutator_early)
+                    if dom.does_dom(root_early, mutator_early)
                         && (root_early != mutator_early
                             || root_block_iterated_users.contains(&mutator))
                         && mutating_objects(function, func_id, *mutator, objects)
                             .any(|mutated| read_objs.contains(&mutated))
                         && id != mutator
+                        && (!root_is_reduce_and_read_isnt_in_cycle
+                            || !reduce_cycles
+                                .get(root)
+                                .map(|cycle| cycle.contains(mutator))
+                                .unwrap_or(false))
                     {
                         antideps.insert((*id, *mutator));
                     }
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 1ebc885c..db4455e8 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -1567,6 +1567,7 @@ fn run_pass(
                 pm.make_doms();
                 pm.make_fork_join_maps();
                 pm.make_loops();
+                pm.make_reduce_cycles();
                 pm.make_collection_objects();
                 pm.make_devices();
                 pm.make_object_device_demands();
@@ -1577,6 +1578,7 @@ fn run_pass(
                 let doms = pm.doms.take().unwrap();
                 let fork_join_maps = pm.fork_join_maps.take().unwrap();
                 let loops = pm.loops.take().unwrap();
+                let reduce_cycles = pm.reduce_cycles.take().unwrap();
                 let control_subgraphs = pm.control_subgraphs.take().unwrap();
                 let collection_objects = pm.collection_objects.take().unwrap();
                 let devices = pm.devices.take().unwrap();
@@ -1598,6 +1600,7 @@ fn run_pass(
                         &doms[id.idx()],
                         &fork_join_maps[id.idx()],
                         &loops[id.idx()],
+                        &reduce_cycles[id.idx()],
                         &collection_objects,
                         &devices,
                         &object_device_demands[id.idx()],
-- 
GitLab


From c7d47ec8de0737d73395c06cd3ea5e835b2324d8 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sun, 2 Feb 2025 14:51:22 -0600
Subject: [PATCH 14/25] remove xdot from schedule

---
 juno_samples/fork_join_tests/src/gpu.sch | 1 -
 1 file changed, 1 deletion(-)

diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch
index 0647d781..80f1bbc9 100644
--- a/juno_samples/fork_join_tests/src/gpu.sch
+++ b/juno_samples/fork_join_tests/src/gpu.sch
@@ -28,6 +28,5 @@ fixpoint panic after 20 {
   infer-schedules(*);
 }
 
-xdot[true](*);
 float-collections(test2, out.test2);
 gcm(*);
-- 
GitLab


From 0caa58c5be3deb70cd0ef5dfde77b02a09bfbf2a Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sun, 2 Feb 2025 15:28:02 -0600
Subject: [PATCH 15/25] Test requiring outer split + unforkify

---
 .../fork_join_tests/src/fork_join_tests.jn        | 15 +++++++++++++++
 juno_samples/fork_join_tests/src/gpu.sch          |  1 +
 juno_samples/fork_join_tests/src/main.rs          |  5 +++++
 3 files changed, 21 insertions(+)

diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn
index 3d003f3c..55e0a37e 100644
--- a/juno_samples/fork_join_tests/src/fork_join_tests.jn
+++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn
@@ -44,3 +44,18 @@ fn test3(input : i32) -> i32[3, 3] {
   }
   return arr3;
 }
+
+#[entry]
+fn test4(input : i32) -> i32[4, 4] {
+  let arr : i32[4, 4];
+  for i = 0 to 4 {
+    for j = 0 to 4 {
+      let acc = arr[i, j];
+      for k = 0 to 7 {
+        acc += input;
+      }
+      arr[i, j] = acc;
+    }
+  }
+  return arr;
+}
diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch
index 80f1bbc9..bf35caea 100644
--- a/juno_samples/fork_join_tests/src/gpu.sch
+++ b/juno_samples/fork_join_tests/src/gpu.sch
@@ -6,6 +6,7 @@ let out = auto-outline(*);
 gpu(out.test1);
 gpu(out.test2);
 gpu(out.test3);
+gpu(out.test4);
 
 ip-sroa(*);
 sroa(*);
diff --git a/juno_samples/fork_join_tests/src/main.rs b/juno_samples/fork_join_tests/src/main.rs
index 4384ecd5..cbd42c50 100644
--- a/juno_samples/fork_join_tests/src/main.rs
+++ b/juno_samples/fork_join_tests/src/main.rs
@@ -32,6 +32,11 @@ fn main() {
         let output = r.run(0).await;
         let correct = vec![11, 10, 9, 10, 9, 8, 9, 8, 7];
         assert(correct, output);
+
+        let mut r = runner!(test4);
+        let output = r.run(9).await;
+        let correct = vec![63i32; 16];
+        assert(correct, output);
     });
 }
 
-- 
GitLab


From 50555d655836a3090d605dfe622a9cc1127076a6 Mon Sep 17 00:00:00 2001
From: Xavier Routh <xrouth2@illinois.edu>
Date: Sun, 2 Feb 2025 15:42:08 -0600
Subject: [PATCH 16/25] interpreter fixes + product consts

---
 hercules_test/hercules_interpreter/src/interpreter.rs | 3 +--
 hercules_test/hercules_interpreter/src/value.rs       | 4 ++--
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/hercules_test/hercules_interpreter/src/interpreter.rs b/hercules_test/hercules_interpreter/src/interpreter.rs
index 871e304a..22ef062a 100644
--- a/hercules_test/hercules_interpreter/src/interpreter.rs
+++ b/hercules_test/hercules_interpreter/src/interpreter.rs
@@ -783,8 +783,7 @@ impl<'a> FunctionExecutionState<'a> {
                                 &self.module.dynamic_constants,
                                 &self.dynamic_constant_params,
                             )
-                        })
-                        .rev();
+                        });
 
                     let n_tokens: usize = factors.clone().product();
 
diff --git a/hercules_test/hercules_interpreter/src/value.rs b/hercules_test/hercules_interpreter/src/value.rs
index 53911e05..adbed6e6 100644
--- a/hercules_test/hercules_interpreter/src/value.rs
+++ b/hercules_test/hercules_interpreter/src/value.rs
@@ -156,8 +156,8 @@ impl<'a> InterpreterVal {
             Constant::Float64(v) => Self::Float64(v),
 
             Constant::Product(ref type_id, ref constant_ids) => {
-                // Self::Product((), ())
-                todo!()
+                let contents = constant_ids.iter().map(|const_id| InterpreterVal::from_constant(&constants[const_id.idx()], constants, types, dynamic_constants, dynamic_constant_params));
+                InterpreterVal::Product(*type_id, contents.collect_vec().into_boxed_slice())
             }
             Constant::Summation(_, _, _) => todo!(),
             Constant::Array(type_id) => {
-- 
GitLab


From acf060c7daf486ef28a68a292beda3132850891d Mon Sep 17 00:00:00 2001
From: Xavier Routh <xrouth2@illinois.edu>
Date: Sun, 2 Feb 2025 15:42:26 -0600
Subject: [PATCH 17/25] read schedule from file

---
 juno_scheduler/src/lib.rs | 36 ++++++++++++++++++++++++++++++++++++
 1 file changed, 36 insertions(+)

diff --git a/juno_scheduler/src/lib.rs b/juno_scheduler/src/lib.rs
index 571d1fbf..2479af98 100644
--- a/juno_scheduler/src/lib.rs
+++ b/juno_scheduler/src/lib.rs
@@ -146,6 +146,42 @@ pub fn run_schedule_on_hercules(
     .map_err(|e| format!("Scheduling Error: {}", e))
 }
 
+
+pub fn run_schedule_from_file_on_hercules(
+    module: Module,
+    sched_filename: Option<String>,
+) -> Result<Module, String> {
+    let sched = process_schedule(sched_filename)?;
+
+    // Prepare the scheduler's string table and environment
+    // For this, we put all of the Hercules function names into the environment
+    // and string table
+    let mut strings = StringTable::new();
+    let mut env = Env::new();
+
+    env.open_scope();
+
+    for (idx, func) in module.functions.iter().enumerate() {
+        let func_name = strings.lookup_string(func.name.clone());
+        env.insert(
+            func_name,
+            Value::HerculesFunction {
+                func: FunctionID::new(idx),
+            },
+        );
+    }
+
+    env.open_scope();
+    schedule_module(
+        module,
+        sched,
+        strings,
+        env,
+        JunoFunctions { func_ids: vec![] },
+    )
+    .map_err(|e| format!("Scheduling Error: {}", e))
+}
+
 pub fn schedule_hercules(
     module: Module,
     sched_filename: Option<String>,
-- 
GitLab


From 61bc9ae455c3933f930d7a980a4cc4efc39afb6f Mon Sep 17 00:00:00 2001
From: Xavier Routh <xrouth2@illinois.edu>
Date: Sun, 2 Feb 2025 15:43:08 -0600
Subject: [PATCH 18/25] unforkify fix

---
 hercules_opt/src/fork_guard_elim.rs | 8 ++++----
 juno_scheduler/src/pm.rs            | 1 -
 2 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/hercules_opt/src/fork_guard_elim.rs b/hercules_opt/src/fork_guard_elim.rs
index 052fd0e4..f6914b74 100644
--- a/hercules_opt/src/fork_guard_elim.rs
+++ b/hercules_opt/src/fork_guard_elim.rs
@@ -39,7 +39,6 @@ struct GuardedFork {
     guard_if: NodeID,
     fork_taken_proj: NodeID,
     fork_skipped_proj: NodeID,
-    guard_pred: NodeID,
     guard_join_region: NodeID,
     phi_reduce_map: HashMap<NodeID, NodeID>,
     factor: Factor, // The factor that matches the guard
@@ -302,7 +301,6 @@ fn guarded_fork(
         guard_if: if_node,
         fork_taken_proj: *control,
         fork_skipped_proj: other_pred,
-        guard_pred: if_pred,
         guard_join_region: join_control,
         phi_reduce_map: phi_nodes,
         factor,
@@ -323,13 +321,15 @@ pub fn fork_guard_elim(editor: &mut FunctionEditor, fork_join_map: &HashMap<Node
         join,
         fork_taken_proj,
         fork_skipped_proj,
-        guard_pred,
         phi_reduce_map,
         factor,
         guard_if,
         guard_join_region,
     } in guard_info
-    {
+    {   
+        let Some(guard_pred) = editor.get_uses(guard_if).next() else {
+            unreachable!()
+        };
         let new_fork_info = if let Factor::Max(idx, dc) = factor {
             let Node::Fork {
                 control: _,
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index db3904f7..dd2ae73a 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -1895,7 +1895,6 @@ fn run_pass(
                 if !inner_changed {
                     break;
                 }
-                break;
             }
         }
         Pass::ForkCoalesce => {
-- 
GitLab


From 82ab7c076640225cf0ff4de2b8443a46dcc9b95f Mon Sep 17 00:00:00 2001
From: Xavier Routh <xrouth2@illinois.edu>
Date: Sun, 2 Feb 2025 22:39:36 -0600
Subject: [PATCH 19/25] fork dim merge

---
 hercules_opt/src/fork_transforms.rs | 161 ++++++++++++++++++++++++++++
 juno_scheduler/src/compile.rs       |   3 +-
 juno_scheduler/src/ir.rs            |   2 +
 juno_scheduler/src/lib.rs           |   3 +-
 juno_scheduler/src/pm.rs            |  30 ++++--
 5 files changed, 189 insertions(+), 10 deletions(-)

diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index e23f586f..58ace775 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -1,5 +1,6 @@
 use std::collections::{HashMap, HashSet};
 use std::iter::zip;
+use std::thread::ThreadId;
 
 use bimap::BiMap;
 use itertools::Itertools;
@@ -693,3 +694,163 @@ pub(crate) fn split_fork(
         None
     }
 }
+
+// Splits a dimension of a single fork join into multiple. 
+// Iterates an outer loop original_dim / tile_size times 
+// adds a tile_size loop as the inner loop 
+// Assumes that tile size divides original dim evenly.
+pub fn chunk_fork_unguarded(
+    editor: &mut FunctionEditor,
+    fork: NodeID,
+    dim_idx: usize,
+    tile_size: DynamicConstantID,
+) -> () {
+    // tid_dim_idx = tid_dim_idx * tile_size + tid_(dim_idx + 1)
+
+    let Node::Fork { control: old_control, factors: ref old_factors} = *editor.node(fork) else {return};
+
+    let mut new_factors: Vec<_> = old_factors.to_vec();
+
+    let fork_users: Vec<_> = editor.get_users(fork).collect();
+
+
+    for tid in fork_users {
+        let Node::ThreadID { control: _, dimension: tid_dim } = editor.node(tid).clone() else { continue };
+        editor.edit(|mut edit| {
+            if tid_dim > dim_idx {
+                let new_tid = Node::ThreadID { control: fork, dimension: tid_dim + 1 };
+                let new_tid = edit.add_node(new_tid);
+                edit.replace_all_uses(tid, new_tid)
+            } else if tid_dim == dim_idx {
+                let tile_tid = Node::ThreadID { control: fork, dimension: tid_dim + 1 };
+                let tile_tid = edit.add_node(tile_tid);
+                
+                let tile_size = edit.add_node(Node::DynamicConstant { id: tile_size });
+                let mul = edit.add_node(Node::Binary { left: tid, right: tile_size, op: BinaryOperator::Mul });
+                let add = edit.add_node(Node::Binary { left: mul, right: tile_tid, op: BinaryOperator::Add });
+                edit.replace_all_uses_where(tid, add, |usee| *usee != mul )
+            } else {
+                Ok(edit)
+            }
+        });
+    }
+
+    editor.edit(|mut edit| {
+        let outer = DynamicConstant::div(new_factors[dim_idx], tile_size);
+        new_factors.insert(dim_idx + 1, tile_size);
+        new_factors[dim_idx] = edit.add_dynamic_constant(outer);
+
+        let new_fork = Node::Fork { control: old_control, factors: new_factors.into() };
+        let new_fork = edit.add_node(new_fork);
+
+        edit.replace_all_uses(fork, new_fork)
+    });
+}
+
+
+pub fn merge_all_fork_dims(
+    editor: &mut FunctionEditor,
+    fork_join_map: &HashMap<NodeID, NodeID>,
+) {
+    for (fork, _) in fork_join_map {
+        let Node::Fork { control: _, factors: dims } = editor.node(fork) else {
+            unreachable!();
+        };
+
+        let mut fork = *fork;
+        for _ in 0..dims.len() - 1 {
+            let outer = 0;
+            let inner = 1;
+            fork = fork_dim_merge(editor, fork, outer, inner);
+        }
+    }
+}
+
+// Merges two dimensions of a single fork-join into one.
+// The outer dimension's factor becomes outer * inner and the inner
+// dimension is removed; thread IDs for the two old dimensions are
+// rewritten as div / rem of the merged thread ID. Returns the new fork.
+pub fn fork_dim_merge(
+    editor: &mut FunctionEditor,
+    fork: NodeID,
+    dim_idx1: usize,
+    dim_idx2: usize,
+) -> NodeID {
+    // tid_dim_idx1 (replaced w/) <- dim_idx1 / dim(dim_idx2)
+    // tid_dim_idx2 (replaced w/) <- dim_idx1 % dim(dim_idx2)
+    assert_ne!(dim_idx1, dim_idx2);
+
+    // Outer is smaller, and also closer to the left of the factors array.
+    let (outer_idx, inner_idx) = if dim_idx2 < dim_idx1 {
+        (dim_idx2, dim_idx1)
+    } else {
+        (dim_idx1, dim_idx2)
+    };
+
+    let Node::Fork { control: old_control, factors: ref old_factors} = *editor.node(fork) else {return fork};
+
+    let mut new_factors: Vec<_> = old_factors.to_vec();
+
+    
+
+    let fork_users: Vec<_> = editor.get_users(fork).collect();
+
+    let mut new_nodes = vec![];
+
+    let outer_dc_id = new_factors[outer_idx];
+    let inner_dc_id = new_factors[inner_idx];
+
+    let mut new_fork_id = NodeID::new(0);
+    editor.edit(|mut edit| {
+        new_factors[outer_idx] = edit.add_dynamic_constant(DynamicConstant::mul(new_factors[outer_idx], new_factors[inner_idx]));
+        new_factors.remove(inner_idx);
+
+        let new_fork = Node::Fork { control: old_control, factors: new_factors.into() };
+        let new_fork = edit.add_node(new_fork);
+        new_fork_id = new_fork;
+
+
+        edit = edit.replace_all_uses(fork, new_fork)?;
+        edit.delete_node(fork)
+    });
+
+    
+
+    for tid in fork_users {
+        let Node::ThreadID { control: _, dimension: tid_dim } = editor.node(tid).clone() else { continue };
+
+        println!("tid: {:?}", tid);
+        editor.edit(|mut edit| {
+            if tid_dim > inner_idx {
+                let new_tid = Node::ThreadID { control: new_fork_id, dimension: tid_dim - 1 };
+                let new_tid = edit.add_node(new_tid);
+                edit.replace_all_uses(tid, new_tid)
+            } else if tid_dim == outer_idx {
+                let outer_tid = Node::ThreadID { control: new_fork_id, dimension: outer_idx };
+                let outer_tid = edit.add_node(outer_tid);
+
+                let outer_dc = edit.add_node(Node::DynamicConstant { id: outer_dc_id });
+                new_nodes.push(outer_tid);
+
+                // inner_idx % dim(outer_idx)
+                let rem = edit.add_node(Node::Binary { left: outer_tid, right: outer_dc, op: BinaryOperator::Rem});
+                
+                edit.replace_all_uses(tid, rem)
+            } else if tid_dim == inner_idx {
+                let outer_tid = Node::ThreadID { control: new_fork_id, dimension: outer_idx };
+                let outer_tid = edit.add_node(outer_tid);
+
+                let outer_dc = edit.add_node(Node::DynamicConstant { id: outer_dc_id });
+                // inner_idx / dim(outer_idx)
+                let div = edit.add_node(Node::Binary { left: outer_tid, right: outer_dc, op: BinaryOperator::Div});
+                
+                edit.replace_all_uses(tid, div)
+            } else {
+                Ok(edit)
+            }
+        });
+    };
+
+    return new_fork_id;
+
+}
\ No newline at end of file
diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs
index 11a8ec53..07ad5e7a 100644
--- a/juno_scheduler/src/compile.rs
+++ b/juno_scheduler/src/compile.rs
@@ -108,7 +108,8 @@ impl FromStr for Appliable {
             "inline" => Ok(Appliable::Pass(ir::Pass::Inline)),
             "ip-sroa" | "interprocedural-sroa" => {
                 Ok(Appliable::Pass(ir::Pass::InterproceduralSROA))
-            }
+            },
+            "fork-dim-merge" => Ok(Appliable::Pass(ir::Pass::ForkDimMerge)),
             "lift-dc-math" => Ok(Appliable::Pass(ir::Pass::LiftDCMath)),
             "outline" => Ok(Appliable::Pass(ir::Pass::Outline)),
             "phi-elim" => Ok(Appliable::Pass(ir::Pass::PhiElim)),
diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs
index d6a41baf..939ef925 100644
--- a/juno_scheduler/src/ir.rs
+++ b/juno_scheduler/src/ir.rs
@@ -12,6 +12,8 @@ pub enum Pass {
     ForkSplit,
     ForkCoalesce,
     Forkify,
+    ForkDimMerge,
+    ForkChunk,
     GCM,
     GVN,
     InferSchedules,
diff --git a/juno_scheduler/src/lib.rs b/juno_scheduler/src/lib.rs
index 2479af98..ad9195fb 100644
--- a/juno_scheduler/src/lib.rs
+++ b/juno_scheduler/src/lib.rs
@@ -60,7 +60,7 @@ fn build_schedule(sched_filename: String) -> Result<ScheduleStmt, String> {
     }
 }
 
-fn process_schedule(sched_filename: Option<String>) -> Result<ScheduleStmt, String> {
+pub fn process_schedule(sched_filename: Option<String>) -> Result<ScheduleStmt, String> {
     if let Some(name) = sched_filename {
         build_schedule(name)
     } else {
@@ -146,7 +146,6 @@ pub fn run_schedule_on_hercules(
     .map_err(|e| format!("Scheduling Error: {}", e))
 }
 
-
 pub fn run_schedule_from_file_on_hercules(
     module: Module,
     sched_filename: Option<String>,
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 43f355c3..8b71d24e 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -1871,14 +1871,11 @@ fn run_pass(
         }
         Pass::Unforkify => {
             assert!(args.is_empty());
-            loop {
-                let mut inner_changed = false;
-
-                pm.make_fork_join_maps();
-                pm.make_loops();
+            pm.make_fork_join_maps();
+            pm.make_loops();
 
-                let fork_join_maps = pm.fork_join_maps.take().unwrap();
-                let loops = pm.loops.take().unwrap();
+            let fork_join_maps = pm.fork_join_maps.take().unwrap();
+            let loops = pm.loops.take().unwrap();
 
             for ((func, fork_join_map), loop_tree) in build_selection(pm, selection)
                 .into_iter()
@@ -1894,6 +1891,24 @@ fn run_pass(
             pm.delete_gravestones();
             pm.clear_analyses();
         }
+        Pass::ForkDimMerge => {
+            assert!(args.is_empty());
+            pm.make_fork_join_maps();
+            let fork_join_maps = pm.fork_join_maps.take().unwrap();
+            for (func, fork_join_map) in
+                build_selection(pm, selection)
+                    .into_iter()
+                    .zip(fork_join_maps.iter())
+            {
+                let Some(mut func) = func else {
+                    continue;
+                };
+                merge_all_fork_dims(&mut func, fork_join_map);
+                changed |= func.modified();
+            }
+            pm.delete_gravestones();
+            pm.clear_analyses();
+        }
         Pass::ForkCoalesce => {
             assert!(args.is_empty());
             pm.make_fork_join_maps();
@@ -1991,6 +2006,7 @@ fn run_pass(
             // Put BasicBlocks back, since it's needed for Codegen.
             pm.bbs = bbs;
         }
+        Pass::ForkChunk => todo!(),
     }
     println!("Ran Pass: {:?}", pass);
 
-- 
GitLab


From 1788bd1f2ba4ec496cb5afa62da884f92c1761a4 Mon Sep 17 00:00:00 2001
From: Xavier Routh <xrouth2@illinois.edu>
Date: Mon, 3 Feb 2025 11:10:10 -0600
Subject: [PATCH 20/25] tiling + dim merge with one edit per loop dim

---
 hercules_opt/src/fork_transforms.rs | 96 ++++++++++++++++-------------
 juno_scheduler/src/compile.rs       |  1 +
 juno_scheduler/src/ir.rs            |  1 +
 juno_scheduler/src/pm.rs            | 34 ++++++++++
 4 files changed, 90 insertions(+), 42 deletions(-)

diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index 58ace775..cbb09bbf 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -695,6 +695,24 @@ pub(crate) fn split_fork(
     }
 }
 
+pub fn chunk_all_forks_unguarded(
+    editor: &mut FunctionEditor,
+    fork_join_map: &HashMap<NodeID, NodeID>,
+    dim_idx: usize,
+    tile_size: usize,
+) -> () {
+    // Add dc
+    let mut dc_id = DynamicConstantID::new(0);
+    editor.edit(|mut edit| {
+        dc_id = edit.add_dynamic_constant(DynamicConstant::Constant(tile_size));
+        Ok(edit)
+    });
+
+    for (fork, _ ) in fork_join_map {
+        chunk_fork_unguarded(editor, *fork, dim_idx, dc_id);
+    }
+
+}
 // Splits a dimension of a single fork join into multiple. 
 // Iterates an outer loop original_dim / tile_size times 
 // adds a tile_size loop as the inner loop 
@@ -711,39 +729,36 @@ pub fn chunk_fork_unguarded(
 
     let mut new_factors: Vec<_> = old_factors.to_vec();
 
-    let fork_users: Vec<_> = editor.get_users(fork).collect();
+    let fork_users: Vec<_> = editor.get_users(fork).map(|f| (f, editor.node(f).clone())).collect();
+
+    editor.edit(|mut edit| {
+        let outer = DynamicConstant::div(new_factors[dim_idx], tile_size);
+        new_factors.insert(dim_idx + 1, tile_size);
+        new_factors[dim_idx] = edit.add_dynamic_constant(outer);
+
+        let new_fork = Node::Fork { control: old_control, factors: new_factors.into() };
+        let new_fork = edit.add_node(new_fork);
 
+        edit = edit.replace_all_uses(fork, new_fork)?;
 
-    for tid in fork_users {
-        let Node::ThreadID { control: _, dimension: tid_dim } = editor.node(tid).clone() else { continue };
-        editor.edit(|mut edit| {
+        for (tid, node) in fork_users {
+            let Node::ThreadID { control: _, dimension: tid_dim } = node else {continue};
             if tid_dim > dim_idx {
-                let new_tid = Node::ThreadID { control: fork, dimension: tid_dim + 1 };
+                let new_tid = Node::ThreadID { control: new_fork, dimension: tid_dim + 1 };
                 let new_tid = edit.add_node(new_tid);
-                edit.replace_all_uses(tid, new_tid)
+                edit = edit.replace_all_uses(tid, new_tid)?;
             } else if tid_dim == dim_idx {
-                let tile_tid = Node::ThreadID { control: fork, dimension: tid_dim + 1 };
+                let tile_tid = Node::ThreadID { control: new_fork, dimension: tid_dim + 1 };
                 let tile_tid = edit.add_node(tile_tid);
                 
                 let tile_size = edit.add_node(Node::DynamicConstant { id: tile_size });
                 let mul = edit.add_node(Node::Binary { left: tid, right: tile_size, op: BinaryOperator::Mul });
                 let add = edit.add_node(Node::Binary { left: mul, right: tile_tid, op: BinaryOperator::Add });
-                edit.replace_all_uses_where(tid, add, |usee| *usee != mul )
-            } else {
-                Ok(edit)
+                edit = edit.replace_all_uses_where(tid, add, |usee| *usee != mul )?;
             }
-        });
-    }
-
-    editor.edit(|mut edit| {
-        let outer = DynamicConstant::div(new_factors[dim_idx], tile_size);
-        new_factors.insert(dim_idx + 1, tile_size);
-        new_factors[dim_idx] = edit.add_dynamic_constant(outer);
-
-        let new_fork = Node::Fork { control: old_control, factors: new_factors.into() };
-        let new_fork = edit.add_node(new_fork);
-
-        edit.replace_all_uses(fork, new_fork)
+        }
+        edit = edit.delete_node(fork)?;
+        Ok(edit)
     });
 }
 
@@ -791,9 +806,8 @@ pub fn fork_dim_merge(
 
     let mut new_factors: Vec<_> = old_factors.to_vec();
 
-    
 
-    let fork_users: Vec<_> = editor.get_users(fork).collect();
+    let fork_users: Vec<_> = editor.get_users(fork).map(|f| (f, editor.node(f).clone())).collect();
 
     let mut new_nodes = vec![];
 
@@ -801,6 +815,7 @@ pub fn fork_dim_merge(
     let inner_dc_id = new_factors[inner_idx];
 
     let mut new_fork_id = NodeID::new(0);
+
     editor.edit(|mut edit| {
         new_factors[outer_idx] = edit.add_dynamic_constant(DynamicConstant::mul(new_factors[outer_idx], new_factors[inner_idx]));
         new_factors.remove(inner_idx);
@@ -809,22 +824,20 @@ pub fn fork_dim_merge(
         let new_fork = edit.add_node(new_fork);
         new_fork_id = new_fork;
 
+        edit.sub_edit(fork, new_fork);
 
         edit = edit.replace_all_uses(fork, new_fork)?;
-        edit.delete_node(fork)
-    });
-
-    
-
-    for tid in fork_users {
-        let Node::ThreadID { control: _, dimension: tid_dim } = editor.node(tid).clone() else { continue };
+        edit = edit.delete_node(fork)?;
 
-        println!("tid: {:?}", tid);
-        editor.edit(|mut edit| {
+        for (tid, node) in fork_users {
+            // FIXME: DO we want sub edits in this?
+            
+            let Node::ThreadID { control: _, dimension: tid_dim } = node else { continue };
             if tid_dim > inner_idx {
                 let new_tid = Node::ThreadID { control: new_fork_id, dimension: tid_dim - 1 };
                 let new_tid = edit.add_node(new_tid);
-                edit.replace_all_uses(tid, new_tid)
+                edit = edit.replace_all_uses(tid, new_tid)?;
+                edit.sub_edit(tid, new_tid);
             } else if tid_dim == outer_idx {
                 let outer_tid = Node::ThreadID { control: new_fork_id, dimension: outer_idx };
                 let outer_tid = edit.add_node(outer_tid);
@@ -834,8 +847,8 @@ pub fn fork_dim_merge(
 
                 // inner_idx % dim(outer_idx)
                 let rem = edit.add_node(Node::Binary { left: outer_tid, right: outer_dc, op: BinaryOperator::Rem});
-                
-                edit.replace_all_uses(tid, rem)
+                edit.sub_edit(tid, rem);
+                edit = edit.replace_all_uses(tid, rem)?;
             } else if tid_dim == inner_idx {
                 let outer_tid = Node::ThreadID { control: new_fork_id, dimension: outer_idx };
                 let outer_tid = edit.add_node(outer_tid);
@@ -843,13 +856,12 @@ pub fn fork_dim_merge(
                 let outer_dc = edit.add_node(Node::DynamicConstant { id: outer_dc_id });
                 // inner_idx / dim(outer_idx)
                 let div = edit.add_node(Node::Binary { left: outer_tid, right: outer_dc, op: BinaryOperator::Div});
-                
-                edit.replace_all_uses(tid, div)
-            } else {
-                Ok(edit)
+                edit.sub_edit(tid, div);
+                edit = edit.replace_all_uses(tid, div)?;
             }
-        });
-    };
+        }
+        Ok(edit)
+    });
 
     return new_fork_id;
 
diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs
index 07ad5e7a..49dedd2b 100644
--- a/juno_scheduler/src/compile.rs
+++ b/juno_scheduler/src/compile.rs
@@ -110,6 +110,7 @@ impl FromStr for Appliable {
                 Ok(Appliable::Pass(ir::Pass::InterproceduralSROA))
             },
             "fork-dim-merge" => Ok(Appliable::Pass(ir::Pass::ForkDimMerge)),
+            "fork-chunk" | "fork-tile" => Ok(Appliable::Pass(ir::Pass::ForkChunk)),
             "lift-dc-math" => Ok(Appliable::Pass(ir::Pass::LiftDCMath)),
             "outline" => Ok(Appliable::Pass(ir::Pass::Outline)),
             "phi-elim" => Ok(Appliable::Pass(ir::Pass::PhiElim)),
diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs
index 939ef925..796437a7 100644
--- a/juno_scheduler/src/ir.rs
+++ b/juno_scheduler/src/ir.rs
@@ -36,6 +36,7 @@ impl Pass {
     pub fn num_args(&self) -> usize {
         match self {
             Pass::Xdot => 1,
+            Pass::ForkChunk => 3, 
             _ => 0,
         }
     }
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 8b71d24e..5740d2a6 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -1891,6 +1891,40 @@ fn run_pass(
             pm.delete_gravestones();
             pm.clear_analyses();
         }
+        Pass::ForkChunk => {
+            assert_eq!(args.len(), 3);
+            let tile_size = args.get(0);
+            let dim_idx = args.get(1);
+
+            let Some(Value::Boolean { val: guarded_flag }) = args.get(2) else {
+                panic!(); // How to error here?
+            };
+
+            let Some(Value::Integer { val: dim_idx }) = args.get(1) else {
+                panic!(); // How to error here?
+            };
+
+            let Some(Value::Integer { val: tile_size }) = args.get(0) else {
+                panic!(); // How to error here?
+            };
+
+            assert_eq!(*guarded_flag, true);
+            pm.make_fork_join_maps();
+            let fork_join_maps = pm.fork_join_maps.take().unwrap();
+            for (func, fork_join_map) in
+                build_selection(pm, selection)
+                    .into_iter()
+                    .zip(fork_join_maps.iter())
+            {
+                let Some(mut func) = func else {
+                    continue;
+                };
+                chunk_all_forks_unguarded(&mut func, fork_join_map, *dim_idx, *tile_size);
+                changed |= func.modified();
+            }
+            pm.delete_gravestones();
+            pm.clear_analyses();
+        }
         Pass::ForkDimMerge => {
             assert!(args.is_empty());
             pm.make_fork_join_maps();
-- 
GitLab


From 27731fa4fe3553ba1a93a98115636b60a4dd00ce Mon Sep 17 00:00:00 2001
From: Xavier Routh <xrouth2@illinois.edu>
Date: Mon, 3 Feb 2025 11:15:56 -0600
Subject: [PATCH 21/25] check for out of bounds dim on chunking

---
 hercules_opt/src/fork_transforms.rs | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index cbb09bbf..190dbd25 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -725,8 +725,13 @@ pub fn chunk_fork_unguarded(
 ) -> () {
     // tid_dim_idx = tid_dim_idx * tile_size + tid_(dim_idx + 1)
 
+
     let Node::Fork { control: old_control, factors: ref old_factors} = *editor.node(fork) else {return};
 
+    if dim_idx >= old_factors.len()  {
+        return; // FIXME Error here?
+    }
+
     let mut new_factors: Vec<_> = old_factors.to_vec();
 
     let fork_users: Vec<_> = editor.get_users(fork).map(|f| (f, editor.node(f).clone())).collect();
-- 
GitLab


From 1f1c6cb94c80ac867958c41d2a84d506bf06ef92 Mon Sep 17 00:00:00 2001
From: Xavier Routh <xrouth2@illinois.edu>
Date: Mon, 3 Feb 2025 11:46:58 -0600
Subject: [PATCH 22/25] rewrite forkify as single edit per loop

---
 hercules_opt/src/forkify.rs | 75 ++++++++++++++++++++-----------------
 1 file changed, 40 insertions(+), 35 deletions(-)

diff --git a/hercules_opt/src/forkify.rs b/hercules_opt/src/forkify.rs
index 38b9aaaa..0a2d5601 100644
--- a/hercules_opt/src/forkify.rs
+++ b/hercules_opt/src/forkify.rs
@@ -298,29 +298,11 @@ pub fn forkify_loop(
     let (_, factors) = function.nodes[fork_id.idx()].try_fork().unwrap();
     let dimension = factors.len() - 1;
 
-    // Create ThreadID
-    editor.edit(|mut edit| {
-        let thread_id = Node::ThreadID {
-            control: fork_id,
-            dimension: dimension,
-        };
-        let thread_id_id = edit.add_node(thread_id);
-
-        // Replace uses that are inside with the thread id
-        edit = edit.replace_all_uses_where(canonical_iv.phi(), thread_id_id, |node| {
-            loop_nodes.contains(node)
-        })?;
+    // Start failable edit:
 
-        // Replace uses that are outside with DC - 1. Or just give up.
-        let bound_dc_node = edit.add_node(Node::DynamicConstant { id: bound_dc_id });
-        edit = edit.replace_all_uses_where(canonical_iv.phi(), bound_dc_node, |node| {
-            !loop_nodes.contains(node)
-        })?;
-
-        edit.delete_node(canonical_iv.phi())
-    });
-
-    for reduction_phi in reductionable_phis {
+    let redcutionable_phis_and_init: Vec<(_, NodeID)> = 
+        reductionable_phis.iter().map(|reduction_phi| {
+        
         let LoopPHI::Reductionable {
             phi,
             data_cycle: _,
@@ -342,12 +324,41 @@ pub fn forkify_loop(
         .unwrap()
         .1;
 
-        editor.edit(|mut edit| {
+        (reduction_phi, init)
+    }).collect();
+
+    editor.edit(|mut edit| {
+        let thread_id = Node::ThreadID {
+            control: fork_id,
+            dimension: dimension,
+        };
+        let thread_id_id = edit.add_node(thread_id);
+
+        // Replace uses that are inside with the thread id
+        edit = edit.replace_all_uses_where(canonical_iv.phi(), thread_id_id, |node| {
+            loop_nodes.contains(node)
+        })?;
+
+        edit = edit.delete_node(canonical_iv.phi())?;
+
+        for (reduction_phi, init) in redcutionable_phis_and_init {
+            let LoopPHI::Reductionable {
+                phi,
+                data_cycle: _,
+                continue_latch,
+                is_associative: _,
+            } = *reduction_phi
+            else {
+                panic!();
+            };
+            
             let reduce = Node::Reduce {
                 control: join_id,
                 init,
                 reduct: continue_latch,
             };
+            
+            
             let reduce_id = edit.add_node(reduce);
 
             if (!edit.get_node(init).is_reduce()
@@ -375,20 +386,14 @@ pub fn forkify_loop(
             edit = edit.replace_all_uses_where(continue_latch, reduce_id, |usee| {
                 !loop_nodes.contains(usee) && *usee != reduce_id
             })?;
-            edit.delete_node(phi)
-        });
-    }
-
-    // Replace all uses of the loop header with the fork
-    editor.edit(|edit| edit.replace_all_uses(l.header, fork_id));
+            edit = edit.delete_node(phi)?
 
-    editor.edit(|edit| edit.replace_all_uses(loop_continue_projection, fork_id));
+        }
 
-    editor.edit(|edit| edit.replace_all_uses(loop_exit_projection, join_id));
+        edit = edit.replace_all_uses(l.header, fork_id)?;
+        edit = edit.replace_all_uses(loop_continue_projection, fork_id)?;
+        edit = edit.replace_all_uses(loop_exit_projection, join_id)?;
 
-    // Get rid of loop condition
-    // DCE should get these, but delete them ourselves because we are nice :)
-    editor.edit(|mut edit| {
         edit = edit.delete_node(loop_continue_projection)?;
         edit = edit.delete_node(condition_node)?; // Might have to get rid of other users of this.
         edit = edit.delete_node(loop_exit_projection)?;
@@ -396,7 +401,7 @@ pub fn forkify_loop(
         edit = edit.delete_node(l.header)?;
         Ok(edit)
     });
-
+    
     return true;
 }
 
-- 
GitLab


From 6382ef4263b16f54a8d3b4d5e3a795c9c9e11013 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Mon, 3 Feb 2025 15:44:25 -0600
Subject: [PATCH 23/25] fix ups

---
 Cargo.lock                                    |  11 --
 hercules_cg/src/fork_tree.rs                  |  19 +-
 hercules_opt/src/fork_guard_elim.rs           |  19 +-
 hercules_opt/src/fork_transforms.rs           | 166 +++++++++++-------
 hercules_opt/src/forkify.rs                   |  61 ++++---
 hercules_samples/dot/build.rs                 |   6 +-
 hercules_samples/dot/src/main.rs              |   2 +-
 hercules_samples/matmul/build.rs              |   6 +-
 hercules_samples/matmul/src/main.rs           |   6 +-
 .../hercules_interpreter/src/interpreter.rs   |  76 ++++----
 .../hercules_interpreter/src/value.rs         |  10 +-
 .../hercules_tests/tests/loop_tests.rs        |  40 ++---
 juno_frontend/src/semant.rs                   |  11 +-
 juno_samples/cava/src/main.rs                 |  45 ++---
 juno_samples/concat/src/main.rs               |   4 +-
 juno_samples/edge_detection/src/main.rs       |  11 +-
 juno_samples/matmul/src/main.rs               |  18 +-
 juno_samples/nested_ccp/src/main.rs           |   2 +-
 juno_samples/patterns/src/main.rs             |   2 +-
 juno_samples/schedule_test/build.rs           |   6 +-
 juno_samples/schedule_test/src/main.rs        |  13 +-
 juno_samples/simple3/src/main.rs              |   2 +-
 juno_scheduler/src/compile.rs                 |   2 +-
 juno_scheduler/src/ir.rs                      |   2 +-
 juno_scheduler/src/pm.rs                      |  31 ++--
 25 files changed, 336 insertions(+), 235 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 7a70825a..af7902c6 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1236,17 +1236,6 @@ dependencies = [
  "with_builtin_macros",
 ]
 
-[[package]]
-name = "juno_test"
-version = "0.1.0"
-dependencies = [
- "async-std",
- "hercules_rt",
- "juno_build",
- "rand",
- "with_builtin_macros",
-]
-
 [[package]]
 name = "juno_utils"
 version = "0.1.0"
diff --git a/hercules_cg/src/fork_tree.rs b/hercules_cg/src/fork_tree.rs
index 64a93160..c048f7e3 100644
--- a/hercules_cg/src/fork_tree.rs
+++ b/hercules_cg/src/fork_tree.rs
@@ -9,11 +9,16 @@ use crate::*;
  * c) no domination by any other fork that's also dominated by F, where we do count self-domination
  * Here too we include the non-fork start node, as key for all controls outside any fork.
  */
-pub fn fork_control_map(fork_join_nesting: &HashMap<NodeID, Vec<NodeID>>) -> HashMap<NodeID, HashSet<NodeID>> {
+pub fn fork_control_map(
+    fork_join_nesting: &HashMap<NodeID, Vec<NodeID>>,
+) -> HashMap<NodeID, HashSet<NodeID>> {
     let mut fork_control_map = HashMap::new();
     for (control, forks) in fork_join_nesting {
         let fork = forks.first().copied().unwrap_or(NodeID::new(0));
-        fork_control_map.entry(fork).or_insert_with(HashSet::new).insert(*control);
+        fork_control_map
+            .entry(fork)
+            .or_insert_with(HashSet::new)
+            .insert(*control);
     }
     fork_control_map
 }
@@ -24,13 +29,19 @@ pub fn fork_control_map(fork_join_nesting: &HashMap<NodeID, Vec<NodeID>>) -> Has
  * c) no domination by any other fork that's also dominated by F, where we don't count self-domination
  * Note that the fork_tree also includes the non-fork start node, as unique root node.
  */
-pub fn fork_tree(function: &Function, fork_join_nesting: &HashMap<NodeID, Vec<NodeID>>) -> HashMap<NodeID, HashSet<NodeID>> {
+pub fn fork_tree(
+    function: &Function,
+    fork_join_nesting: &HashMap<NodeID, Vec<NodeID>>,
+) -> HashMap<NodeID, HashSet<NodeID>> {
     let mut fork_tree = HashMap::new();
     for (control, forks) in fork_join_nesting {
         if function.nodes[control.idx()].is_fork() {
             fork_tree.entry(*control).or_insert_with(HashSet::new);
             let nesting_fork = forks.get(1).copied().unwrap_or(NodeID::new(0));
-            fork_tree.entry(nesting_fork).or_insert_with(HashSet::new).insert(*control);
+            fork_tree
+                .entry(nesting_fork)
+                .or_insert_with(HashSet::new)
+                .insert(*control);
         }
     }
     fork_tree
diff --git a/hercules_opt/src/fork_guard_elim.rs b/hercules_opt/src/fork_guard_elim.rs
index f6914b74..df40e60f 100644
--- a/hercules_opt/src/fork_guard_elim.rs
+++ b/hercules_opt/src/fork_guard_elim.rs
@@ -76,13 +76,16 @@ fn guarded_fork(
         };
 
         // Filter out any terms which are just 1s
-        let non_ones = xs.iter().filter(|i| {
-            if let DynamicConstant::Constant(1) = editor.get_dynamic_constant(**i).deref() {
-                false
-            } else {
-                true
-            }
-        }).collect::<Vec<_>>();
+        let non_ones = xs
+            .iter()
+            .filter(|i| {
+                if let DynamicConstant::Constant(1) = editor.get_dynamic_constant(**i).deref() {
+                    false
+                } else {
+                    true
+                }
+            })
+            .collect::<Vec<_>>();
         // If we're left with just one term x, we had max { 1, x }
         if non_ones.len() == 1 {
             Factor::Max(idx, *non_ones[0])
@@ -326,7 +329,7 @@ pub fn fork_guard_elim(editor: &mut FunctionEditor, fork_join_map: &HashMap<Node
         guard_if,
         guard_join_region,
     } in guard_info
-    {   
+    {
         let Some(guard_pred) = editor.get_uses(guard_if).next() else {
             unreachable!()
         };
diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index 190dbd25..ed6283fd 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -708,14 +708,13 @@ pub fn chunk_all_forks_unguarded(
         Ok(edit)
     });
 
-    for (fork, _ ) in fork_join_map {
+    for (fork, _) in fork_join_map {
         chunk_fork_unguarded(editor, *fork, dim_idx, dc_id);
     }
-
 }
-// Splits a dimension of a single fork join into multiple. 
-// Iterates an outer loop original_dim / tile_size times 
-// adds a tile_size loop as the inner loop 
+// Splits a dimension of a single fork join into multiple.
+// Iterates an outer loop original_dim / tile_size times
+// adds a tile_size loop as the inner loop
 // Assumes that tile size divides original dim evenly.
 pub fn chunk_fork_unguarded(
     editor: &mut FunctionEditor,
@@ -724,42 +723,68 @@ pub fn chunk_fork_unguarded(
     tile_size: DynamicConstantID,
 ) -> () {
     // tid_dim_idx = tid_dim_idx * tile_size + tid_(dim_idx + 1)
-
-
-    let Node::Fork { control: old_control, factors: ref old_factors} = *editor.node(fork) else {return};
-
-    if dim_idx >= old_factors.len()  {
-        return; // FIXME Error here?
-    }
-
+    let Node::Fork {
+        control: old_control,
+        factors: ref old_factors,
+    } = *editor.node(fork)
+    else {
+        return;
+    };
+    assert!(dim_idx < old_factors.len());
     let mut new_factors: Vec<_> = old_factors.to_vec();
-
-    let fork_users: Vec<_> = editor.get_users(fork).map(|f| (f, editor.node(f).clone())).collect();
+    let fork_users: Vec<_> = editor
+        .get_users(fork)
+        .map(|f| (f, editor.node(f).clone()))
+        .collect();
 
     editor.edit(|mut edit| {
         let outer = DynamicConstant::div(new_factors[dim_idx], tile_size);
         new_factors.insert(dim_idx + 1, tile_size);
         new_factors[dim_idx] = edit.add_dynamic_constant(outer);
 
-        let new_fork = Node::Fork { control: old_control, factors: new_factors.into() };
+        let new_fork = Node::Fork {
+            control: old_control,
+            factors: new_factors.into(),
+        };
         let new_fork = edit.add_node(new_fork);
 
         edit = edit.replace_all_uses(fork, new_fork)?;
 
         for (tid, node) in fork_users {
-            let Node::ThreadID { control: _, dimension: tid_dim } = node else {continue};
+            let Node::ThreadID {
+                control: _,
+                dimension: tid_dim,
+            } = node
+            else {
+                continue;
+            };
             if tid_dim > dim_idx {
-                let new_tid = Node::ThreadID { control: new_fork, dimension: tid_dim + 1 };
+                let new_tid = Node::ThreadID {
+                    control: new_fork,
+                    dimension: tid_dim + 1,
+                };
                 let new_tid = edit.add_node(new_tid);
                 edit = edit.replace_all_uses(tid, new_tid)?;
+                edit = edit.delete_node(tid)?;
             } else if tid_dim == dim_idx {
-                let tile_tid = Node::ThreadID { control: new_fork, dimension: tid_dim + 1 };
+                let tile_tid = Node::ThreadID {
+                    control: new_fork,
+                    dimension: tid_dim + 1,
+                };
                 let tile_tid = edit.add_node(tile_tid);
-                
+
                 let tile_size = edit.add_node(Node::DynamicConstant { id: tile_size });
-                let mul = edit.add_node(Node::Binary { left: tid, right: tile_size, op: BinaryOperator::Mul });
-                let add = edit.add_node(Node::Binary { left: mul, right: tile_tid, op: BinaryOperator::Add });
-                edit = edit.replace_all_uses_where(tid, add, |usee| *usee != mul )?;
+                let mul = edit.add_node(Node::Binary {
+                    left: tid,
+                    right: tile_size,
+                    op: BinaryOperator::Mul,
+                });
+                let add = edit.add_node(Node::Binary {
+                    left: mul,
+                    right: tile_tid,
+                    op: BinaryOperator::Add,
+                });
+                edit = edit.replace_all_uses_where(tid, add, |usee| *usee != mul)?;
             }
         }
         edit = edit.delete_node(fork)?;
@@ -767,13 +792,13 @@ pub fn chunk_fork_unguarded(
     });
 }
 
-
-pub fn merge_all_fork_dims(
-    editor: &mut FunctionEditor,
-    fork_join_map: &HashMap<NodeID, NodeID>,
-) {
+pub fn merge_all_fork_dims(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>) {
     for (fork, _) in fork_join_map {
-        let Node::Fork { control: _, factors: dims } = editor.node(fork) else {
+        let Node::Fork {
+            control: _,
+            factors: dims,
+        } = editor.node(fork)
+        else {
             unreachable!();
         };
 
@@ -786,10 +811,6 @@ pub fn merge_all_fork_dims(
     }
 }
 
-// Splits a dimension of a single fork join into multiple. 
-// Iterates an outer loop original_dim / tile_size times 
-// adds a tile_size loop as the inner loop 
-// Assumes that tile size divides original dim evenly.
 pub fn fork_dim_merge(
     editor: &mut FunctionEditor,
     fork: NodeID,
@@ -806,61 +827,85 @@ pub fn fork_dim_merge(
     } else {
         (dim_idx1, dim_idx2)
     };
-
-    let Node::Fork { control: old_control, factors: ref old_factors} = *editor.node(fork) else {return fork};
-
+    let Node::Fork {
+        control: old_control,
+        factors: ref old_factors,
+    } = *editor.node(fork)
+    else {
+        return fork;
+    };
     let mut new_factors: Vec<_> = old_factors.to_vec();
-
-
-    let fork_users: Vec<_> = editor.get_users(fork).map(|f| (f, editor.node(f).clone())).collect();
-
+    let fork_users: Vec<_> = editor
+        .get_users(fork)
+        .map(|f| (f, editor.node(f).clone()))
+        .collect();
     let mut new_nodes = vec![];
-
     let outer_dc_id = new_factors[outer_idx];
     let inner_dc_id = new_factors[inner_idx];
-
-    let mut new_fork_id = NodeID::new(0);
+    let mut new_fork = NodeID::new(0);
 
     editor.edit(|mut edit| {
-        new_factors[outer_idx] = edit.add_dynamic_constant(DynamicConstant::mul(new_factors[outer_idx], new_factors[inner_idx]));
+        new_factors[outer_idx] = edit.add_dynamic_constant(DynamicConstant::mul(
+            new_factors[outer_idx],
+            new_factors[inner_idx],
+        ));
         new_factors.remove(inner_idx);
-
-        let new_fork = Node::Fork { control: old_control, factors: new_factors.into() };
-        let new_fork = edit.add_node(new_fork);
-        new_fork_id = new_fork;
-
+        new_fork = edit.add_node(Node::Fork {
+            control: old_control,
+            factors: new_factors.into(),
+        });
         edit.sub_edit(fork, new_fork);
-
         edit = edit.replace_all_uses(fork, new_fork)?;
         edit = edit.delete_node(fork)?;
 
         for (tid, node) in fork_users {
-            // FIXME: DO we want sub edits in this?
-            
-            let Node::ThreadID { control: _, dimension: tid_dim } = node else { continue };
+            let Node::ThreadID {
+                control: _,
+                dimension: tid_dim,
+            } = node
+            else {
+                continue;
+            };
             if tid_dim > inner_idx {
-                let new_tid = Node::ThreadID { control: new_fork_id, dimension: tid_dim - 1 };
+                let new_tid = Node::ThreadID {
+                    control: new_fork_id,
+                    dimension: tid_dim - 1,
+                };
                 let new_tid = edit.add_node(new_tid);
                 edit = edit.replace_all_uses(tid, new_tid)?;
                 edit.sub_edit(tid, new_tid);
             } else if tid_dim == outer_idx {
-                let outer_tid = Node::ThreadID { control: new_fork_id, dimension: outer_idx };
+                let outer_tid = Node::ThreadID {
+                    control: new_fork_id,
+                    dimension: outer_idx,
+                };
                 let outer_tid = edit.add_node(outer_tid);
 
                 let outer_dc = edit.add_node(Node::DynamicConstant { id: outer_dc_id });
                 new_nodes.push(outer_tid);
 
                 // inner_idx % dim(outer_idx)
-                let rem = edit.add_node(Node::Binary { left: outer_tid, right: outer_dc, op: BinaryOperator::Rem});
+                let rem = edit.add_node(Node::Binary {
+                    left: outer_tid,
+                    right: outer_dc,
+                    op: BinaryOperator::Rem,
+                });
                 edit.sub_edit(tid, rem);
                 edit = edit.replace_all_uses(tid, rem)?;
             } else if tid_dim == inner_idx {
-                let outer_tid = Node::ThreadID { control: new_fork_id, dimension: outer_idx };
+                let outer_tid = Node::ThreadID {
+                    control: new_fork_id,
+                    dimension: outer_idx,
+                };
                 let outer_tid = edit.add_node(outer_tid);
 
                 let outer_dc = edit.add_node(Node::DynamicConstant { id: outer_dc_id });
                 // inner_idx / dim(outer_idx)
-                let div = edit.add_node(Node::Binary { left: outer_tid, right: outer_dc, op: BinaryOperator::Div});
+                let div = edit.add_node(Node::Binary {
+                    left: outer_tid,
+                    right: outer_dc,
+                    op: BinaryOperator::Div,
+                });
                 edit.sub_edit(tid, div);
                 edit = edit.replace_all_uses(tid, div)?;
             }
@@ -868,6 +913,5 @@ pub fn fork_dim_merge(
         Ok(edit)
     });
 
-    return new_fork_id;
-
-}
\ No newline at end of file
+    new_fork
+}
diff --git a/hercules_opt/src/forkify.rs b/hercules_opt/src/forkify.rs
index 0a2d5601..f6db06ca 100644
--- a/hercules_opt/src/forkify.rs
+++ b/hercules_opt/src/forkify.rs
@@ -300,32 +300,33 @@ pub fn forkify_loop(
 
     // Start failable edit:
 
-    let redcutionable_phis_and_init: Vec<(_, NodeID)> = 
-        reductionable_phis.iter().map(|reduction_phi| {
-        
-        let LoopPHI::Reductionable {
-            phi,
-            data_cycle: _,
-            continue_latch,
-            is_associative: _,
-        } = reduction_phi
-        else {
-            panic!();
-        };
+    let redcutionable_phis_and_init: Vec<(_, NodeID)> = reductionable_phis
+        .iter()
+        .map(|reduction_phi| {
+            let LoopPHI::Reductionable {
+                phi,
+                data_cycle: _,
+                continue_latch,
+                is_associative: _,
+            } = reduction_phi
+            else {
+                panic!();
+            };
 
-        let function = editor.func();
+            let function = editor.func();
 
-        let init = *zip(
-            editor.get_uses(l.header),
-            function.nodes[phi.idx()].try_phi().unwrap().1.iter(),
-        )
-        .filter(|(c, _)| *c == loop_pred)
-        .next()
-        .unwrap()
-        .1;
+            let init = *zip(
+                editor.get_uses(l.header),
+                function.nodes[phi.idx()].try_phi().unwrap().1.iter(),
+            )
+            .filter(|(c, _)| *c == loop_pred)
+            .next()
+            .unwrap()
+            .1;
 
-        (reduction_phi, init)
-    }).collect();
+            (reduction_phi, init)
+        })
+        .collect();
 
     editor.edit(|mut edit| {
         let thread_id = Node::ThreadID {
@@ -351,14 +352,13 @@ pub fn forkify_loop(
             else {
                 panic!();
             };
-            
+
             let reduce = Node::Reduce {
                 control: join_id,
                 init,
                 reduct: continue_latch,
             };
-            
-            
+
             let reduce_id = edit.add_node(reduce);
 
             if (!edit.get_node(init).is_reduce()
@@ -387,7 +387,6 @@ pub fn forkify_loop(
                 !loop_nodes.contains(usee) && *usee != reduce_id
             })?;
             edit = edit.delete_node(phi)?
-
         }
 
         edit = edit.replace_all_uses(l.header, fork_id)?;
@@ -401,7 +400,7 @@ pub fn forkify_loop(
         edit = edit.delete_node(l.header)?;
         Ok(edit)
     });
-    
+
     return true;
 }
 
@@ -538,7 +537,11 @@ pub fn analyze_phis<'a>(
             // by the time the reduce is triggered (at the end of the loop's internal control).
             // If anything in the intersection is a phi (that isn't this own phi), then the reduction cycle depends on control.
             // Which is not allowed.
-            if intersection.iter().any(|cycle_node| editor.node(cycle_node).is_phi() && *cycle_node != *phi) || editor.node(loop_continue_latch).is_phi() {
+            if intersection
+                .iter()
+                .any(|cycle_node| editor.node(cycle_node).is_phi() && *cycle_node != *phi)
+                || editor.node(loop_continue_latch).is_phi()
+            {
                 return LoopPHI::ControlDependant(*phi);
             }
 
diff --git a/hercules_samples/dot/build.rs b/hercules_samples/dot/build.rs
index 8657fdc1..c8de7e90 100644
--- a/hercules_samples/dot/build.rs
+++ b/hercules_samples/dot/build.rs
@@ -4,7 +4,11 @@ fn main() {
     JunoCompiler::new()
         .ir_in_src("dot.hir")
         .unwrap()
-        .schedule_in_src(if cfg!(feature = "cuda") { "gpu.sch" } else { "cpu.sch" })
+        .schedule_in_src(if cfg!(feature = "cuda") {
+            "gpu.sch"
+        } else {
+            "cpu.sch"
+        })
         .unwrap()
         .build()
         .unwrap();
diff --git a/hercules_samples/dot/src/main.rs b/hercules_samples/dot/src/main.rs
index 8862c11a..7f5b453a 100644
--- a/hercules_samples/dot/src/main.rs
+++ b/hercules_samples/dot/src/main.rs
@@ -1,8 +1,8 @@
 #![feature(concat_idents)]
 
-use hercules_rt::{runner, HerculesCPURef};
 #[cfg(feature = "cuda")]
 use hercules_rt::CUDABox;
+use hercules_rt::{runner, HerculesCPURef};
 
 juno_build::juno!("dot");
 
diff --git a/hercules_samples/matmul/build.rs b/hercules_samples/matmul/build.rs
index 735458c0..ed92e022 100644
--- a/hercules_samples/matmul/build.rs
+++ b/hercules_samples/matmul/build.rs
@@ -4,7 +4,11 @@ fn main() {
     JunoCompiler::new()
         .ir_in_src("matmul.hir")
         .unwrap()
-        .schedule_in_src(if cfg!(feature = "cuda") { "gpu.sch" } else { "cpu.sch" })
+        .schedule_in_src(if cfg!(feature = "cuda") {
+            "gpu.sch"
+        } else {
+            "cpu.sch"
+        })
         .unwrap()
         .build()
         .unwrap();
diff --git a/hercules_samples/matmul/src/main.rs b/hercules_samples/matmul/src/main.rs
index abd25ec9..5c879915 100644
--- a/hercules_samples/matmul/src/main.rs
+++ b/hercules_samples/matmul/src/main.rs
@@ -2,9 +2,9 @@
 
 use rand::random;
 
-use hercules_rt::{runner, HerculesCPURef};
 #[cfg(feature = "cuda")]
 use hercules_rt::CUDABox;
+use hercules_rt::{runner, HerculesCPURef};
 
 juno_build::juno!("matmul");
 
@@ -36,7 +36,9 @@ fn main() {
             let a = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut a));
             let b = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut b));
             let mut r = runner!(matmul);
-            let c = r.run(I as u64, J as u64, K as u64, a.get_ref(), b.get_ref()).await;
+            let c = r
+                .run(I as u64, J as u64, K as u64, a.get_ref(), b.get_ref())
+                .await;
             let mut c_cpu: Box<[i32]> = vec![0; correct_c.len()].into_boxed_slice();
             c.to_cpu_ref(&mut c_cpu);
             assert_eq!(&*c_cpu, &*correct_c);
diff --git a/hercules_test/hercules_interpreter/src/interpreter.rs b/hercules_test/hercules_interpreter/src/interpreter.rs
index 22ef062a..2e352644 100644
--- a/hercules_test/hercules_interpreter/src/interpreter.rs
+++ b/hercules_test/hercules_interpreter/src/interpreter.rs
@@ -69,18 +69,18 @@ pub fn dyn_const_value(
     match dc {
         DynamicConstant::Constant(v) => *v,
         DynamicConstant::Parameter(v) => dyn_const_params[*v],
-        DynamicConstant::Add(xs) => {
-            xs.iter().map(|x| dyn_const_value(x, dyn_const_values, dyn_const_params))
-                .fold(0, |s, v| s + v)
-        }
+        DynamicConstant::Add(xs) => xs
+            .iter()
+            .map(|x| dyn_const_value(x, dyn_const_values, dyn_const_params))
+            .fold(0, |s, v| s + v),
         DynamicConstant::Sub(a, b) => {
             dyn_const_value(a, dyn_const_values, dyn_const_params)
                 - dyn_const_value(b, dyn_const_values, dyn_const_params)
         }
-        DynamicConstant::Mul(xs) => {
-            xs.iter().map(|x| dyn_const_value(x, dyn_const_values, dyn_const_params))
-                .fold(1, |p, v| p * v)
-        }
+        DynamicConstant::Mul(xs) => xs
+            .iter()
+            .map(|x| dyn_const_value(x, dyn_const_values, dyn_const_params))
+            .fold(1, |p, v| p * v),
         DynamicConstant::Div(a, b) => {
             dyn_const_value(a, dyn_const_values, dyn_const_params)
                 / dyn_const_value(b, dyn_const_values, dyn_const_params)
@@ -89,28 +89,28 @@ pub fn dyn_const_value(
             dyn_const_value(a, dyn_const_values, dyn_const_params)
                 % dyn_const_value(b, dyn_const_values, dyn_const_params)
         }
-        DynamicConstant::Max(xs) => {
-            xs.iter().map(|x| dyn_const_value(x, dyn_const_values, dyn_const_params))
-                .fold(None, |m, v| {
-                    if let Some(m) = m {
-                        Some(max(m, v))
-                    } else {
-                        Some(v)
-                    }
-                })
-                .unwrap()
-        }
-        DynamicConstant::Min(xs) => {
-            xs.iter().map(|x| dyn_const_value(x, dyn_const_values, dyn_const_params))
-                .fold(None, |m, v| {
-                    if let Some(m) = m {
-                        Some(min(m, v))
-                    } else {
-                        Some(v)
-                    }
-                })
-                .unwrap()
-        }
+        DynamicConstant::Max(xs) => xs
+            .iter()
+            .map(|x| dyn_const_value(x, dyn_const_values, dyn_const_params))
+            .fold(None, |m, v| {
+                if let Some(m) = m {
+                    Some(max(m, v))
+                } else {
+                    Some(v)
+                }
+            })
+            .unwrap(),
+        DynamicConstant::Min(xs) => xs
+            .iter()
+            .map(|x| dyn_const_value(x, dyn_const_values, dyn_const_params))
+            .fold(None, |m, v| {
+                if let Some(m) = m {
+                    Some(min(m, v))
+                } else {
+                    Some(v)
+                }
+            })
+            .unwrap(),
     }
 }
 
@@ -775,15 +775,13 @@ impl<'a> FunctionExecutionState<'a> {
                     //     panic!("multi-dimensional forks unimplemented")
                     // }
 
-                    let factors = factors
-                        .iter()
-                        .map(|f| {
-                            dyn_const_value(
-                                &f,
-                                &self.module.dynamic_constants,
-                                &self.dynamic_constant_params,
-                            )
-                        });
+                    let factors = factors.iter().map(|f| {
+                        dyn_const_value(
+                            &f,
+                            &self.module.dynamic_constants,
+                            &self.dynamic_constant_params,
+                        )
+                    });
 
                     let n_tokens: usize = factors.clone().product();
 
diff --git a/hercules_test/hercules_interpreter/src/value.rs b/hercules_test/hercules_interpreter/src/value.rs
index adbed6e6..4a802f7a 100644
--- a/hercules_test/hercules_interpreter/src/value.rs
+++ b/hercules_test/hercules_interpreter/src/value.rs
@@ -156,7 +156,15 @@ impl<'a> InterpreterVal {
             Constant::Float64(v) => Self::Float64(v),
 
             Constant::Product(ref type_id, ref constant_ids) => {
-                let contents = constant_ids.iter().map(|const_id| InterpreterVal::from_constant(&constants[const_id.idx()], constants, types, dynamic_constants, dynamic_constant_params));
+                let contents = constant_ids.iter().map(|const_id| {
+                    InterpreterVal::from_constant(
+                        &constants[const_id.idx()],
+                        constants,
+                        types,
+                        dynamic_constants,
+                        dynamic_constant_params,
+                    )
+                });
                 InterpreterVal::Product(*type_id, contents.collect_vec().into_boxed_slice())
             }
             Constant::Summation(_, _, _) => todo!(),
diff --git a/hercules_test/hercules_tests/tests/loop_tests.rs b/hercules_test/hercules_tests/tests/loop_tests.rs
index 192c1366..795642b2 100644
--- a/hercules_test/hercules_tests/tests/loop_tests.rs
+++ b/hercules_test/hercules_tests/tests/loop_tests.rs
@@ -35,9 +35,7 @@ fn alternate_bounds_use_after_loop_no_tid() {
 
     println!("result: {:?}", result_1);
 
-    let schedule = default_schedule![
-        Forkify,
-    ];
+    let schedule = default_schedule![Forkify,];
 
     let module = run_schedule_on_hercules(module, Some(schedule)).unwrap();
 
@@ -61,9 +59,7 @@ fn alternate_bounds_use_after_loop() {
 
     println!("result: {:?}", result_1);
 
-    let schedule = Some(default_schedule![
-        Forkify,
-    ]);
+    let schedule = Some(default_schedule![Forkify,]);
 
     let module = run_schedule_on_hercules(module, schedule).unwrap();
 
@@ -108,10 +104,7 @@ fn do_while_separate_body() {
 
     println!("result: {:?}", result_1);
 
-    let schedule = Some(default_schedule![
-        PhiElim,
-        Forkify,
-    ]);
+    let schedule = Some(default_schedule![PhiElim, Forkify,]);
 
     let module = run_schedule_on_hercules(module, schedule).unwrap();
 
@@ -131,10 +124,7 @@ fn alternate_bounds_internal_control() {
 
     println!("result: {:?}", result_1);
 
-    let schedule = Some(default_schedule![
-        PhiElim,
-        Forkify,
-    ]);
+    let schedule = Some(default_schedule![PhiElim, Forkify,]);
 
     let module = run_schedule_on_hercules(module, schedule).unwrap();
 
@@ -155,10 +145,7 @@ fn alternate_bounds_internal_control2() {
 
     println!("result: {:?}", result_1);
 
-    let schedule = Some(default_schedule![
-        PhiElim,
-        Forkify,
-    ]);
+    let schedule = Some(default_schedule![PhiElim, Forkify,]);
 
     let module = run_schedule_on_hercules(module, schedule).unwrap();
 
@@ -366,16 +353,13 @@ fn look_at_local() {
         "/home/xavierrouth/dev/hercules/hercules_test/hercules_tests/save_me.hbin",
     );
 
-    let schedule = Some(default_schedule![
-    ]);
+    let schedule = Some(default_schedule![]);
 
     let result_1 = interp_module!(module, 0, dyn_consts, a.clone(), b.clone());
 
     let module = run_schedule_on_hercules(module.clone(), schedule).unwrap();
 
-    let schedule = Some(default_schedule![
-        Unforkify, Verify,
-    ]);
+    let schedule = Some(default_schedule![Unforkify, Verify,]);
 
     let module = run_schedule_on_hercules(module.clone(), schedule).unwrap();
 
@@ -425,7 +409,15 @@ fn matmul_pipeline() {
     };
     assert_eq!(correct_c[0], value);
 
-    let schedule = Some(default_schedule![AutoOutline, InterproceduralSROA, SROA, InferSchedules, DCE, Xdot, GCM]);
+    let schedule = Some(default_schedule![
+        AutoOutline,
+        InterproceduralSROA,
+        SROA,
+        InferSchedules,
+        DCE,
+        Xdot,
+        GCM
+    ]);
 
     module = run_schedule_on_hercules(module, schedule).unwrap();
 
diff --git a/juno_frontend/src/semant.rs b/juno_frontend/src/semant.rs
index e133e3c2..8668d1b4 100644
--- a/juno_frontend/src/semant.rs
+++ b/juno_frontend/src/semant.rs
@@ -752,7 +752,16 @@ fn analyze_program(
                             }
                             arg_info.push((ty, inout.is_some(), var));
 
-                            match process_irrefutable_pattern(pattern, false, var, ty, lexer, &mut stringtab, &mut env, &mut types) {
+                            match process_irrefutable_pattern(
+                                pattern,
+                                false,
+                                var,
+                                ty,
+                                lexer,
+                                &mut stringtab,
+                                &mut env,
+                                &mut types,
+                            ) {
                                 Ok(prep) => {
                                     stmts.extend(prep);
                                 }
diff --git a/juno_samples/cava/src/main.rs b/juno_samples/cava/src/main.rs
index 482bbf8d..e8a7e4e9 100644
--- a/juno_samples/cava/src/main.rs
+++ b/juno_samples/cava/src/main.rs
@@ -8,9 +8,9 @@ use self::camera_model::*;
 use self::cava_rust::CHAN;
 use self::image_proc::*;
 
-use hercules_rt::{runner, HerculesCPURef};
 #[cfg(feature = "cuda")]
 use hercules_rt::CUDABox;
+use hercules_rt::{runner, HerculesCPURef};
 
 use image::ImageError;
 
@@ -31,7 +31,6 @@ fn run_cava(
     coefs: &[f32],
     tonemap: &[f32],
 ) -> Box<[u8]> {
-
     assert_eq!(image.len(), CHAN * rows * cols);
     assert_eq!(tstw.len(), CHAN * CHAN);
     assert_eq!(ctrl_pts.len(), num_ctrl_pts * CHAN);
@@ -47,21 +46,24 @@ fn run_cava(
         let weights = HerculesCPURef::from_slice(weights);
         let coefs = HerculesCPURef::from_slice(coefs);
         let tonemap = HerculesCPURef::from_slice(tonemap);
-	    let mut r = runner!(cava);
-	    async_std::task::block_on(async {
-		r.run(
-		    rows as u64,
-		    cols as u64,
-		    num_ctrl_pts as u64,
-		    image,
-		    tstw,
-		    ctrl_pts,
-		    weights,
-		    coefs,
-		    tonemap,
-		)
-		.await
-	    }).as_slice::<u8>().to_vec().into_boxed_slice()
+        let mut r = runner!(cava);
+        async_std::task::block_on(async {
+            r.run(
+                rows as u64,
+                cols as u64,
+                num_ctrl_pts as u64,
+                image,
+                tstw,
+                ctrl_pts,
+                weights,
+                coefs,
+                tonemap,
+            )
+            .await
+        })
+        .as_slice::<u8>()
+        .to_vec()
+        .into_boxed_slice()
     }
 
     #[cfg(feature = "cuda")]
@@ -72,8 +74,8 @@ fn run_cava(
         let weights = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(weights));
         let coefs = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(coefs));
         let tonemap = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(tonemap));
-	    let mut r = runner!(cava);
-	    let res = async_std::task::block_on(async {
+        let mut r = runner!(cava);
+        let res = async_std::task::block_on(async {
             r.run(
                 rows as u64,
                 cols as u64,
@@ -86,7 +88,7 @@ fn run_cava(
                 tonemap.get_ref(),
             )
             .await
-	    });
+        });
         let num_out = unsafe { res.__size() / std::mem::size_of::<u8>() };
         let mut res_cpu: Box<[u8]> = vec![0; num_out].into_boxed_slice();
         res.to_cpu_ref(&mut res_cpu);
@@ -204,7 +206,8 @@ fn cava_harness(args: CavaInputs) {
                 .expect("Error saving verification image");
         }
 
-        let max_diff = result.iter()
+        let max_diff = result
+            .iter()
             .zip(cpu_result.iter())
             .map(|(a, b)| (*a as i16 - *b as i16).abs())
             .max()
diff --git a/juno_samples/concat/src/main.rs b/juno_samples/concat/src/main.rs
index 9674c2c5..547dee08 100644
--- a/juno_samples/concat/src/main.rs
+++ b/juno_samples/concat/src/main.rs
@@ -1,9 +1,9 @@
 #![feature(concat_idents)]
 
 use hercules_rt::runner;
-use hercules_rt::HerculesCPURef;
 #[cfg(feature = "cuda")]
 use hercules_rt::CUDABox;
+use hercules_rt::HerculesCPURef;
 
 juno_build::juno!("concat");
 
@@ -20,7 +20,7 @@ fn main() {
             assert_eq!(output, 42);
 
             const N: usize = 3;
-            let arr : Box<[i32]> = (2..=4).collect();
+            let arr: Box<[i32]> = (2..=4).collect();
             let arr = HerculesCPURef::from_slice(&arr);
 
             let mut r = runner!(concat_switch);
diff --git a/juno_samples/edge_detection/src/main.rs b/juno_samples/edge_detection/src/main.rs
index eda65016..3b067ebd 100644
--- a/juno_samples/edge_detection/src/main.rs
+++ b/juno_samples/edge_detection/src/main.rs
@@ -2,9 +2,9 @@
 
 mod edge_detection_rust;
 
-use hercules_rt::{runner, HerculesCPURef};
 #[cfg(feature = "cuda")]
 use hercules_rt::CUDABox;
+use hercules_rt::{runner, HerculesCPURef};
 
 use std::slice::from_raw_parts;
 
@@ -228,9 +228,9 @@ fn edge_detection_harness(args: EdgeDetectionInputs) {
         });
 
         #[cfg(not(feature = "cuda"))]
-        let result : Box<[f32]> = result.as_slice::<f32>().to_vec().into_boxed_slice();
+        let result: Box<[f32]> = result.as_slice::<f32>().to_vec().into_boxed_slice();
         #[cfg(feature = "cuda")]
-        let result : Box<[f32]> = {
+        let result: Box<[f32]> = {
             let num_out = unsafe { result.__size() / std::mem::size_of::<f32>() };
             let mut res_cpu: Box<[f32]> = vec![0.0; num_out].into_boxed_slice();
             result.to_cpu_ref(&mut res_cpu);
@@ -261,7 +261,10 @@ fn edge_detection_harness(args: EdgeDetectionInputs) {
                 theta,
             );
 
-            assert_eq!(result.as_ref(), <Vec<f32> as AsRef<[f32]>>::as_ref(&rust_result));
+            assert_eq!(
+                result.as_ref(),
+                <Vec<f32> as AsRef<[f32]>>::as_ref(&rust_result)
+            );
             println!("Frames {} match", i);
 
             if display_verify {
diff --git a/juno_samples/matmul/src/main.rs b/juno_samples/matmul/src/main.rs
index 50fe1760..2892cd34 100644
--- a/juno_samples/matmul/src/main.rs
+++ b/juno_samples/matmul/src/main.rs
@@ -2,9 +2,9 @@
 
 use rand::random;
 
-use hercules_rt::{runner, HerculesCPURef};
 #[cfg(feature = "cuda")]
 use hercules_rt::CUDABox;
+use hercules_rt::{runner, HerculesCPURef};
 
 juno_build::juno!("matmul");
 
@@ -28,10 +28,14 @@ fn main() {
             let a = HerculesCPURef::from_slice(&a);
             let b = HerculesCPURef::from_slice(&b);
             let mut r = runner!(matmul);
-            let c = r.run(I as u64, J as u64, K as u64, a.clone(), b.clone()).await;
+            let c = r
+                .run(I as u64, J as u64, K as u64, a.clone(), b.clone())
+                .await;
             assert_eq!(c.as_slice::<i32>(), &*correct_c);
             let mut r = runner!(tiled_64_matmul);
-            let tiled_c = r.run(I as u64, J as u64, K as u64, a.clone(), b.clone()).await;
+            let tiled_c = r
+                .run(I as u64, J as u64, K as u64, a.clone(), b.clone())
+                .await;
             assert_eq!(tiled_c.as_slice::<i32>(), &*correct_c);
         }
         #[cfg(feature = "cuda")]
@@ -39,12 +43,16 @@ fn main() {
             let a = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut a));
             let b = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut b));
             let mut r = runner!(matmul);
-            let c = r.run(I as u64, J as u64, K as u64, a.get_ref(), b.get_ref()).await;
+            let c = r
+                .run(I as u64, J as u64, K as u64, a.get_ref(), b.get_ref())
+                .await;
             let mut c_cpu: Box<[i32]> = vec![0; correct_c.len()].into_boxed_slice();
             c.to_cpu_ref(&mut c_cpu);
             assert_eq!(&*c_cpu, &*correct_c);
             let mut r = runner!(tiled_64_matmul);
-            let tiled_c = r.run(I as u64, J as u64, K as u64, a.get_ref(), b.get_ref()).await;
+            let tiled_c = r
+                .run(I as u64, J as u64, K as u64, a.get_ref(), b.get_ref())
+                .await;
             let mut tiled_c_cpu: Box<[i32]> = vec![0; correct_c.len()].into_boxed_slice();
             tiled_c.to_cpu_ref(&mut tiled_c_cpu);
             assert_eq!(&*tiled_c_cpu, &*correct_c);
diff --git a/juno_samples/nested_ccp/src/main.rs b/juno_samples/nested_ccp/src/main.rs
index bc99a4bd..b364c03c 100644
--- a/juno_samples/nested_ccp/src/main.rs
+++ b/juno_samples/nested_ccp/src/main.rs
@@ -1,8 +1,8 @@
 #![feature(concat_idents)]
 
-use hercules_rt::{runner, HerculesCPURef, HerculesCPURefMut};
 #[cfg(feature = "cuda")]
 use hercules_rt::CUDABox;
+use hercules_rt::{runner, HerculesCPURef, HerculesCPURefMut};
 
 juno_build::juno!("nested_ccp");
 
diff --git a/juno_samples/patterns/src/main.rs b/juno_samples/patterns/src/main.rs
index 5cc2e7c8..a5586c8b 100644
--- a/juno_samples/patterns/src/main.rs
+++ b/juno_samples/patterns/src/main.rs
@@ -1,6 +1,6 @@
 #![feature(concat_idents)]
 
-use hercules_rt::{runner};
+use hercules_rt::runner;
 
 juno_build::juno!("patterns");
 
diff --git a/juno_samples/schedule_test/build.rs b/juno_samples/schedule_test/build.rs
index 749a660c..0129c4de 100644
--- a/juno_samples/schedule_test/build.rs
+++ b/juno_samples/schedule_test/build.rs
@@ -4,7 +4,11 @@ fn main() {
     JunoCompiler::new()
         .file_in_src("code.jn")
         .unwrap()
-        .schedule_in_src(if cfg!(feature = "cuda") { "gpu.sch" } else { "cpu.sch" })
+        .schedule_in_src(if cfg!(feature = "cuda") {
+            "gpu.sch"
+        } else {
+            "cpu.sch"
+        })
         .unwrap()
         .build()
         .unwrap();
diff --git a/juno_samples/schedule_test/src/main.rs b/juno_samples/schedule_test/src/main.rs
index 1505d4e5..f769e750 100644
--- a/juno_samples/schedule_test/src/main.rs
+++ b/juno_samples/schedule_test/src/main.rs
@@ -2,9 +2,9 @@
 
 use rand::random;
 
-use hercules_rt::{runner, HerculesCPURef};
 #[cfg(feature = "cuda")]
 use hercules_rt::CUDABox;
+use hercules_rt::{runner, HerculesCPURef};
 
 juno_build::juno!("code");
 
@@ -43,7 +43,16 @@ fn main() {
             let b = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&b));
             let c = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&c));
             let mut r = runner!(test);
-            let res = r.run(N as u64, M as u64, K as u64, a.get_ref(), b.get_ref(), c.get_ref()).await;
+            let res = r
+                .run(
+                    N as u64,
+                    M as u64,
+                    K as u64,
+                    a.get_ref(),
+                    b.get_ref(),
+                    c.get_ref(),
+                )
+                .await;
             let mut res_cpu: Box<[i32]> = vec![0; correct_res.len()].into_boxed_slice();
             res.to_cpu_ref(&mut res_cpu);
             assert_eq!(&*res_cpu, &*correct_res);
diff --git a/juno_samples/simple3/src/main.rs b/juno_samples/simple3/src/main.rs
index 8eb78f7c..687ff414 100644
--- a/juno_samples/simple3/src/main.rs
+++ b/juno_samples/simple3/src/main.rs
@@ -1,8 +1,8 @@
 #![feature(concat_idents)]
 
-use hercules_rt::{runner, HerculesCPURef};
 #[cfg(feature = "cuda")]
 use hercules_rt::CUDABox;
+use hercules_rt::{runner, HerculesCPURef};
 
 juno_build::juno!("simple3");
 
diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs
index ea06a0f2..713c30d4 100644
--- a/juno_scheduler/src/compile.rs
+++ b/juno_scheduler/src/compile.rs
@@ -108,7 +108,7 @@ impl FromStr for Appliable {
             "inline" => Ok(Appliable::Pass(ir::Pass::Inline)),
             "ip-sroa" | "interprocedural-sroa" => {
                 Ok(Appliable::Pass(ir::Pass::InterproceduralSROA))
-            },
+            }
             "fork-dim-merge" => Ok(Appliable::Pass(ir::Pass::ForkDimMerge)),
             "fork-chunk" | "fork-tile" => Ok(Appliable::Pass(ir::Pass::ForkChunk)),
             "lift-dc-math" => Ok(Appliable::Pass(ir::Pass::LiftDCMath)),
diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs
index 796437a7..9e85509f 100644
--- a/juno_scheduler/src/ir.rs
+++ b/juno_scheduler/src/ir.rs
@@ -36,7 +36,7 @@ impl Pass {
     pub fn num_args(&self) -> usize {
         match self {
             Pass::Xdot => 1,
-            Pass::ForkChunk => 3, 
+            Pass::ForkChunk => 3,
             _ => 0,
         }
     }
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index d176b636..2142d5c5 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -1566,7 +1566,7 @@ fn run_pass(
                     // this eventually.
                     let c = forkify(&mut func, control_subgraph, fork_join_map, loop_nest);
                     changed |= c;
-                    inner_changed |= c; 
+                    inner_changed |= c;
                 }
                 pm.delete_gravestones();
                 pm.clear_analyses();
@@ -1921,24 +1921,32 @@ fn run_pass(
             let dim_idx = args.get(1);
 
             let Some(Value::Boolean { val: guarded_flag }) = args.get(2) else {
-                panic!(); // How to error here?
+                return Err(SchedulerError::PassError {
+                    pass: "forkChunk".to_string(),
+                    error: "expected boolean argument".to_string(),
+                });
             };
 
             let Some(Value::Integer { val: dim_idx }) = args.get(1) else {
-                panic!(); // How to error here?
+                return Err(SchedulerError::PassError {
+                    pass: "forkChunk".to_string(),
+                    error: "expected integer argument".to_string(),
+                });
             };
 
             let Some(Value::Integer { val: tile_size }) = args.get(0) else {
-                panic!(); // How to error here?
+                return Err(SchedulerError::PassError {
+                    pass: "forkChunk".to_string(),
+                    error: "expected integer argument".to_string(),
+                });
             };
 
             assert_eq!(*guarded_flag, true);
             pm.make_fork_join_maps();
             let fork_join_maps = pm.fork_join_maps.take().unwrap();
-            for (func, fork_join_map) in
-                build_selection(pm, selection)
-                    .into_iter()
-                    .zip(fork_join_maps.iter())
+            for (func, fork_join_map) in build_selection(pm, selection)
+                .into_iter()
+                .zip(fork_join_maps.iter())
             {
                 let Some(mut func) = func else {
                     continue;
@@ -1953,10 +1961,9 @@ fn run_pass(
             assert!(args.is_empty());
             pm.make_fork_join_maps();
             let fork_join_maps = pm.fork_join_maps.take().unwrap();
-            for (func, fork_join_map) in
-                build_selection(pm, selection)
-                    .into_iter()
-                    .zip(fork_join_maps.iter())
+            for (func, fork_join_map) in build_selection(pm, selection)
+                .into_iter()
+                .zip(fork_join_maps.iter())
             {
                 let Some(mut func) = func else {
                     continue;
-- 
GitLab


From 43404780896ebb6bd287c8a56783a08cded6852e Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Mon, 3 Feb 2025 15:53:31 -0600
Subject: [PATCH 24/25] whoops

---
 hercules_opt/src/fork_transforms.rs | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index ed6283fd..c4a6ba7f 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -868,7 +868,7 @@ pub fn fork_dim_merge(
             };
             if tid_dim > inner_idx {
                 let new_tid = Node::ThreadID {
-                    control: new_fork_id,
+                    control: new_fork,
                     dimension: tid_dim - 1,
                 };
                 let new_tid = edit.add_node(new_tid);
@@ -876,7 +876,7 @@ pub fn fork_dim_merge(
                 edit.sub_edit(tid, new_tid);
             } else if tid_dim == outer_idx {
                 let outer_tid = Node::ThreadID {
-                    control: new_fork_id,
+                    control: new_fork,
                     dimension: outer_idx,
                 };
                 let outer_tid = edit.add_node(outer_tid);
@@ -894,7 +894,7 @@ pub fn fork_dim_merge(
                 edit = edit.replace_all_uses(tid, rem)?;
             } else if tid_dim == inner_idx {
                 let outer_tid = Node::ThreadID {
-                    control: new_fork_id,
+                    control: new_fork,
                     dimension: outer_idx,
                 };
                 let outer_tid = edit.add_node(outer_tid);
-- 
GitLab


From 13bc8938b80f96711eb26738fec53044a7800136 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Tue, 4 Feb 2025 09:52:26 -0600
Subject: [PATCH 25/25] use editor modified() for forkify

---
 juno_scheduler/src/pm.rs | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 80699fee..b2845913 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -1561,12 +1561,9 @@ fn run_pass(
                     let Some(mut func) = func else {
                         continue;
                     };
-                    // TODO: uses direct return from forkify for now instead of
-                    // func.modified, see comment on top of `forkify` for why. Fix
-                    // this eventually.
-                    let c = forkify(&mut func, control_subgraph, fork_join_map, loop_nest);
-                    changed |= c;
-                    inner_changed |= c;
+                    forkify(&mut func, control_subgraph, fork_join_map, loop_nest);
+                    changed |= func.modified();
+                    inner_changed |= func.modified();
                 }
                 pm.delete_gravestones();
                 pm.clear_analyses();
-- 
GitLab