From e0f26259f8895724ffc4d7767c7b66675ab2e871 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sat, 15 Feb 2025 17:45:14 -0600
Subject: [PATCH 01/10] new dot test

---
 Cargo.lock                   | 11 ++++++++++
 Cargo.toml                   | 39 +++++++++++++++++-------------------
 juno_samples/dot/Cargo.toml  | 22 ++++++++++++++++++++
 juno_samples/dot/build.rs    | 24 ++++++++++++++++++++++
 juno_samples/dot/src/cpu.sch | 17 ++++++++++++++++
 juno_samples/dot/src/dot.jn  | 10 +++++++++
 juno_samples/dot/src/gpu.sch | 18 +++++++++++++++++
 juno_samples/dot/src/main.rs | 27 +++++++++++++++++++++++++
 8 files changed, 147 insertions(+), 21 deletions(-)
 create mode 100644 juno_samples/dot/Cargo.toml
 create mode 100644 juno_samples/dot/build.rs
 create mode 100644 juno_samples/dot/src/cpu.sch
 create mode 100644 juno_samples/dot/src/dot.jn
 create mode 100644 juno_samples/dot/src/gpu.sch
 create mode 100644 juno_samples/dot/src/main.rs

diff --git a/Cargo.lock b/Cargo.lock
index 4431cf5d..81c37d79 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1170,6 +1170,17 @@ dependencies = [
  "with_builtin_macros",
 ]
 
+[[package]]
+name = "juno_dot"
+version = "0.1.0"
+dependencies = [
+ "async-std",
+ "hercules_rt",
+ "juno_build",
+ "rand 0.8.5",
+ "with_builtin_macros",
+]
+
 [[package]]
 name = "juno_edge_detection"
 version = "0.1.0"
diff --git a/Cargo.toml b/Cargo.toml
index 6514046b..7a3906fb 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -5,34 +5,31 @@ members = [
 	"hercules_ir",
 	"hercules_opt",
 	"hercules_rt",
-
-	"juno_utils",
-	"juno_frontend",
-	"juno_scheduler",
-	"juno_build",
-
-	"hercules_test/hercules_interpreter",
-	"hercules_test/hercules_tests",
-
-	"hercules_samples/dot",
-	"hercules_samples/matmul",
-	"hercules_samples/fac",
 	"hercules_samples/call",
 	"hercules_samples/ccp",
-
-	"juno_samples/simple3",
-	"juno_samples/patterns",
-	"juno_samples/matmul",
-	"juno_samples/casts_and_intrinsics",
-	"juno_samples/control",
+	"hercules_samples/dot",
+	"hercules_samples/fac",
+	"hercules_samples/matmul",
+	"hercules_test/hercules_interpreter",
+	"hercules_test/hercules_tests",
+	"juno_build",
+	"juno_frontend",
 	"juno_samples/antideps",
-	"juno_samples/implicit_clone",
+	"juno_samples/casts_and_intrinsics",
 	"juno_samples/cava",
 	"juno_samples/concat",
-	"juno_samples/schedule_test",
+	"juno_samples/control",
+	"juno_samples/dot",
 	"juno_samples/edge_detection",
 	"juno_samples/fork_join_tests",
+	"juno_samples/implicit_clone",
+	"juno_samples/matmul",
+	"juno_samples/median_window",
 	"juno_samples/multi_device",
+	"juno_samples/patterns",
 	"juno_samples/products",
-	"juno_samples/median_window",
+	"juno_samples/schedule_test",
+	"juno_samples/simple3",
+	"juno_scheduler",
+	"juno_utils",
 ]
diff --git a/juno_samples/dot/Cargo.toml b/juno_samples/dot/Cargo.toml
new file mode 100644
index 00000000..155a0b13
--- /dev/null
+++ b/juno_samples/dot/Cargo.toml
@@ -0,0 +1,22 @@
+[package]
+name = "juno_dot"
+version = "0.1.0"
+authors = ["Aaron Councilman <aaronjc4@illinois.edu>"]
+edition = "2021"
+
+[[bin]]
+name = "juno_dot"
+path = "src/main.rs"
+
+[features]
+cuda = ["juno_build/cuda", "hercules_rt/cuda"]
+
+[build-dependencies]
+juno_build = { path = "../../juno_build" }
+
+[dependencies]
+juno_build = { path = "../../juno_build" }
+hercules_rt = { path = "../../hercules_rt" }
+with_builtin_macros = "0.1.0"
+async-std = "*"
+rand = "*"
diff --git a/juno_samples/dot/build.rs b/juno_samples/dot/build.rs
new file mode 100644
index 00000000..b23b9860
--- /dev/null
+++ b/juno_samples/dot/build.rs
@@ -0,0 +1,24 @@
+use juno_build::JunoCompiler;
+
+fn main() {
+    #[cfg(not(feature = "cuda"))]
+    {
+        JunoCompiler::new()
+            .file_in_src("dot.jn")
+            .unwrap()
+            .schedule_in_src("cpu.sch")
+            .unwrap()
+            .build()
+            .unwrap();
+    }
+    #[cfg(feature = "cuda")]
+    {
+        JunoCompiler::new()
+            .file_in_src("dot.jn")
+            .unwrap()
+            .schedule_in_src("gpu.sch")
+            .unwrap()
+            .build()
+            .unwrap();
+    }
+}
diff --git a/juno_samples/dot/src/cpu.sch b/juno_samples/dot/src/cpu.sch
new file mode 100644
index 00000000..be110bde
--- /dev/null
+++ b/juno_samples/dot/src/cpu.sch
@@ -0,0 +1,17 @@
+phi-elim(*);
+
+forkify(*);
+fork-guard-elim(*);
+dce(*);
+
+fork-tile[8, 0, false, true](*);
+fork-split(*);
+
+let out = auto-outline(*);
+cpu(out.dot);
+ip-sroa(*);
+sroa(*);
+dce(*);
+
+unforkify(*);
+gcm(*);
diff --git a/juno_samples/dot/src/dot.jn b/juno_samples/dot/src/dot.jn
new file mode 100644
index 00000000..0421dc4c
--- /dev/null
+++ b/juno_samples/dot/src/dot.jn
@@ -0,0 +1,10 @@
+#[entry]
+fn dot<n : usize>(a : f32[n], b : f32[n]) -> f32 {
+  let res : f32;
+
+  for i = 0 to n {
+    res += a[i] * b[i];
+  }
+
+  return res;
+}
diff --git a/juno_samples/dot/src/gpu.sch b/juno_samples/dot/src/gpu.sch
new file mode 100644
index 00000000..b7ece681
--- /dev/null
+++ b/juno_samples/dot/src/gpu.sch
@@ -0,0 +1,18 @@
+phi-elim(*);
+
+forkify(*);
+fork-guard-elim(*);
+dce(*);
+
+fork-tile[8, 0, false, true](*);
+fork-split(*);
+
+let out = auto-outline(*);
+gpu(out.dot);
+ip-sroa(*);
+sroa(*);
+dce(*);
+
+unforkify(*);
+gcm(*);
+
diff --git a/juno_samples/dot/src/main.rs b/juno_samples/dot/src/main.rs
new file mode 100644
index 00000000..5d0aaf7b
--- /dev/null
+++ b/juno_samples/dot/src/main.rs
@@ -0,0 +1,27 @@
+#![feature(concat_idents)]
+use std::iter::zip;
+
+use rand::random;
+
+use hercules_rt::{runner, HerculesImmBox, HerculesImmBoxTo};
+
+juno_build::juno!("dot");
+
+fn main() {
+    async_std::task::block_on(async {
+        const N: u64 = 4096;
+        let a: Box<[f32]> = (0..N).map(|_| random::<f32>()).collect();
+        let b: Box<[f32]> = (0..N).map(|_| random::<f32>()).collect();
+        let a_herc = HerculesImmBox::from(&a as &[f32]);
+        let b_herc = HerculesImmBox::from(&b as &[f32]);
+        let mut r = runner!(dot);
+        let output = r.run(N, a_herc.to(), b_herc.to()).await;
+        let correct = zip(a, b).map(|(a, b)| a * b).sum();
+        assert_eq!(output, correct);
+    });
+}
+
+#[test]
+fn dot_test() {
+    main();
+}
-- 
GitLab


From 3329cb5f72e361010d2596f5724211798467a3dd Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sat, 15 Feb 2025 18:01:01 -0600
Subject: [PATCH 02/10] Rename tightassociative to monoidreduce

---
 hercules_cg/src/gpu.rs              |  4 ++--
 hercules_ir/src/einsum.rs           | 13 ++++++------
 hercules_ir/src/ir.rs               |  2 +-
 hercules_opt/src/fork_transforms.rs |  4 ++--
 hercules_opt/src/forkify.rs         | 31 ++++++++++++++---------------
 hercules_opt/src/schedule.rs        | 24 ++++++++++++----------
 juno_scheduler/src/compile.rs       |  2 +-
 juno_scheduler/src/pm.rs            |  2 +-
 8 files changed, 42 insertions(+), 40 deletions(-)

diff --git a/hercules_cg/src/gpu.rs b/hercules_cg/src/gpu.rs
index d6461a1e..33b239f7 100644
--- a/hercules_cg/src/gpu.rs
+++ b/hercules_cg/src/gpu.rs
@@ -821,7 +821,7 @@ extern \"C\" {} {}(",
             && fork_size.is_power_of_two()
             && reduces.iter().all(|&reduce| {
                 self.function.schedules[reduce.idx()].contains(&Schedule::ParallelReduce)
-                    || self.function.schedules[reduce.idx()].contains(&Schedule::TightAssociative)
+                    || self.function.schedules[reduce.idx()].contains(&Schedule::MonoidReduce)
             })
         {
             // If there's an associative Reduce, parallelize the larger factor
@@ -834,7 +834,7 @@ extern \"C\" {} {}(",
             // restriction doesn't help for parallel Writes, so nested parallelization
             // is possible.
             if reduces.iter().any(|&reduce| {
-                self.function.schedules[reduce.idx()].contains(&Schedule::TightAssociative)
+                self.function.schedules[reduce.idx()].contains(&Schedule::MonoidReduce)
             }) || fork_size > self.kernel_params.max_num_threads / subtree_quota
             {
                 if fork_size >= subtree_quota {
diff --git a/hercules_ir/src/einsum.rs b/hercules_ir/src/einsum.rs
index b222e1bc..6c2ca31b 100644
--- a/hercules_ir/src/einsum.rs
+++ b/hercules_ir/src/einsum.rs
@@ -150,13 +150,12 @@ pub fn einsum(
                 ctx.result_insert(reduce, total_id);
             }
             // The reduce defines a sum reduction over a set of fork dimensions.
-            else if function.schedules[reduce.idx()].contains(&Schedule::TightAssociative)
-                && let Node::Binary {
-                    op: BinaryOperator::Add,
-                    left,
-                    right,
-                } = function.nodes[reduct.idx()]
-                && (left == reduce || right == reduce)
+            else if let Node::Binary {
+                op: BinaryOperator::Add,
+                left,
+                right,
+            } = function.nodes[reduct.idx()]
+                && ((left == reduce) ^ (right == reduce))
             {
                 let data_expr = ctx.compute_math_expr(if left == reduce { right } else { left });
                 let reduce_expr = MathExpr::SumReduction(
diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs
index eb008904..972fd7f9 100644
--- a/hercules_ir/src/ir.rs
+++ b/hercules_ir/src/ir.rs
@@ -328,7 +328,7 @@ pub enum Schedule {
     Vectorizable,
     // This reduce can be re-associated. This may lower a sequential dependency
     // chain into a reduction tree.
-    TightAssociative,
+    MonoidReduce,
     // This constant node doesn't need to be memset to zero.
     NoResetConstant,
     // This call should be called in a spawned future.
diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index c32a517e..342728fd 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -1175,7 +1175,7 @@ fn fork_interchange(
     first_dim: usize,
     second_dim: usize,
 ) {
-    // Check that every reduce on the join is parallel or tight associative.
+    // Check that every reduce on the join is parallel or associative.
     let nodes = &editor.func().nodes;
     let schedules = &editor.func().schedules;
     if !editor
@@ -1183,7 +1183,7 @@ fn fork_interchange(
         .filter(|id| nodes[id.idx()].is_reduce())
         .all(|id| {
             schedules[id.idx()].contains(&Schedule::ParallelReduce)
-                || schedules[id.idx()].contains(&Schedule::TightAssociative)
+                || schedules[id.idx()].contains(&Schedule::MonoidReduce)
         })
     {
         // If not, we can't necessarily do interchange.
diff --git a/hercules_opt/src/forkify.rs b/hercules_opt/src/forkify.rs
index 774220df..2f6466c0 100644
--- a/hercules_opt/src/forkify.rs
+++ b/hercules_opt/src/forkify.rs
@@ -30,15 +30,17 @@ pub fn forkify(
 
     for l in natural_loops {
         // FIXME: Run on all-bottom level loops, as they can be independently optimized without recomputing analyses.
-        if editor.is_mutable(l.0) && forkify_loop(
-            editor,
-            control_subgraph,
-            fork_join_map,
-            &Loop {
-                header: l.0,
-                control: l.1.clone(),
-            },
-        ) {
+        if editor.is_mutable(l.0)
+            && forkify_loop(
+                editor,
+                control_subgraph,
+                fork_join_map,
+                &Loop {
+                    header: l.0,
+                    control: l.1.clone(),
+                },
+            )
+        {
             return true;
         }
     }
@@ -166,7 +168,6 @@ pub fn forkify_loop(
         return false;
     }
 
-
     // Get all phis used outside of the loop, they need to be reductionable.
     // For now just assume all phis will be phis used outside of the loop, except for the canonical iv.
     // FIXME: We need a different definiton of `loop_nodes` to check for phis used outside hte loop than the one
@@ -371,15 +372,13 @@ pub fn forkify_loop(
                 edit = edit.add_schedule(reduce_id, Schedule::ParallelReduce)?;
             }
             if (!edit.get_node(init).is_reduce()
-                && edit
-                    .get_schedule(init)
-                    .contains(&Schedule::TightAssociative))
+                && edit.get_schedule(init).contains(&Schedule::MonoidReduce))
                 || (!edit.get_node(continue_latch).is_reduce()
                     && edit
                         .get_schedule(continue_latch)
-                        .contains(&Schedule::TightAssociative))
+                        .contains(&Schedule::MonoidReduce))
             {
-                edit = edit.add_schedule(reduce_id, Schedule::TightAssociative)?;
+                edit = edit.add_schedule(reduce_id, Schedule::MonoidReduce)?;
             }
 
             edit = edit.replace_all_uses_where(phi, reduce_id, |usee| *usee != reduce_id)?;
@@ -539,7 +538,7 @@ pub fn analyze_phis<'a>(
             // PHIs on the frontier of the uses by the candidate phi, i.e in uses_for_dependance need
             // to have headers that postdominate the loop continue latch. The value of the PHI used needs to be defined
             // by the time the reduce is triggered (at the end of the loop's internal control).
-    
+
             // No nodes in data cycles with this phi (in the loop) are used outside the loop, besides the loop_continue_latch.
             // If some other node in the cycle is used, there is not a valid node to assign it after making the cycle a reduce.
             if intersection
diff --git a/hercules_opt/src/schedule.rs b/hercules_opt/src/schedule.rs
index fe894e47..4cb912fd 100644
--- a/hercules_opt/src/schedule.rs
+++ b/hercules_opt/src/schedule.rs
@@ -146,21 +146,25 @@ pub fn infer_vectorizable(editor: &mut FunctionEditor, fork_join_map: &HashMap<N
 }
 
 /*
- * Infer tight associative reduction loops. Exactly one of the associative
- * operation's operands must be the Reduce node, and all other operands must
- * not be in the Reduce node's cycle.
+ * Infer monoid reduction loops. Exactly one of the associative operation's
+ * operands must be the Reduce node, and all other operands must not be in the
+ * Reduce node's cycle.
  */
-pub fn infer_tight_associative(
+pub fn infer_monoid_reduce(
     editor: &mut FunctionEditor,
     reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
 ) {
-    let is_binop_associative = |op| {
+    let is_binop_monoid = |op| {
         matches!(
             op,
-            BinaryOperator::Add | BinaryOperator::Or | BinaryOperator::And | BinaryOperator::Xor
+            BinaryOperator::Add
+                | BinaryOperator::Mul
+                | BinaryOperator::Or
+                | BinaryOperator::And
+                | BinaryOperator::Xor
         )
     };
-    let is_intrinsic_associative = |intrinsic| matches!(intrinsic, Intrinsic::Max | Intrinsic::Min);
+    let is_intrinsic_monoid = |intrinsic| matches!(intrinsic, Intrinsic::Max | Intrinsic::Min);
 
     for id in editor.node_ids() {
         let func = editor.func();
@@ -172,12 +176,12 @@ pub fn infer_tight_associative(
             && (matches!(func.nodes[reduct.idx()], Node::Binary { left, right, op } 
                 if ((left == id && !reduce_cycles[&id].contains(&right)) || 
                     (right == id && !reduce_cycles[&id].contains(&left))) && 
-                    is_binop_associative(op))
+                    is_binop_monoid(op))
                 || matches!(&func.nodes[reduct.idx()], Node::IntrinsicCall { intrinsic, args }
-                if (args.contains(&id) && is_intrinsic_associative(*intrinsic) && 
+                if (args.contains(&id) && is_intrinsic_monoid(*intrinsic) && 
                     args.iter().filter(|arg| **arg != id).all(|arg| !reduce_cycles[&id].contains(arg)))))
         {
-            editor.edit(|edit| edit.add_schedule(id, Schedule::TightAssociative));
+            editor.edit(|edit| edit.add_schedule(id, Schedule::MonoidReduce));
         }
     }
 }
diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs
index 7c92e00d..188cb1c6 100644
--- a/juno_scheduler/src/compile.rs
+++ b/juno_scheduler/src/compile.rs
@@ -153,7 +153,7 @@ impl FromStr for Appliable {
             "gpu" | "cuda" | "nvidia" => Ok(Appliable::Device(Device::CUDA)),
             "host" | "rust" | "rust-async" => Ok(Appliable::Device(Device::AsyncRust)),
 
-            "associative" => Ok(Appliable::Schedule(Schedule::TightAssociative)),
+            "monoid" | "associative" => Ok(Appliable::Schedule(Schedule::MonoidReduce)),
             "parallel-fork" => Ok(Appliable::Schedule(Schedule::ParallelFork)),
             "parallel-reduce" => Ok(Appliable::Schedule(Schedule::ParallelReduce)),
             "vectorize" => Ok(Appliable::Schedule(Schedule::Vectorizable)),
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 8e152cfe..19bd78e2 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -2034,7 +2034,7 @@ fn run_pass(
                 infer_parallel_reduce(&mut func, fork_join_map, reduce_cycles);
                 infer_parallel_fork(&mut func, fork_join_map);
                 infer_vectorizable(&mut func, fork_join_map);
-                infer_tight_associative(&mut func, reduce_cycles);
+                infer_monoid_reduce(&mut func, reduce_cycles);
                 infer_no_reset_constants(&mut func, no_reset_constants);
                 changed |= func.modified();
             }
-- 
GitLab


From 822eb61ac348e7b80a36690234c305d427e2b03b Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sat, 15 Feb 2025 18:26:29 -0600
Subject: [PATCH 03/10] Clean monoid reduces pass

---
 hercules_ir/src/ir.rs               | 16 ++++++++
 hercules_opt/src/fork_transforms.rs | 61 +++++++++++++++++++++++++++++
 hercules_opt/src/schedule.rs        |  1 -
 hercules_opt/src/utils.rs           | 26 ++++++++++++
 juno_samples/dot/src/cpu.sch        |  2 +
 juno_scheduler/src/compile.rs       |  1 +
 juno_scheduler/src/ir.rs            |  1 +
 juno_scheduler/src/pm.rs            | 17 ++++++++
 8 files changed, 124 insertions(+), 1 deletion(-)

diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs
index 972fd7f9..e8dfc280 100644
--- a/hercules_ir/src/ir.rs
+++ b/hercules_ir/src/ir.rs
@@ -1086,6 +1086,22 @@ impl DynamicConstant {
         }
     }
 
+    pub fn is_zero(&self) -> bool {
+        if *self == DynamicConstant::Constant(0) {
+            true
+        } else {
+            false
+        }
+    }
+
+    pub fn is_one(&self) -> bool {
+        if *self == DynamicConstant::Constant(1) {
+            true
+        } else {
+            false
+        }
+    }
+
     pub fn try_parameter(&self) -> Option<usize> {
         if let DynamicConstant::Parameter(v) = self {
             Some(*v)
diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index 342728fd..0b5de1e5 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -1446,3 +1446,64 @@ fn fork_fusion(
         Ok(edit)
     })
 }
+
+/*
+ * Looks for monoid reductions where the initial input is not the identity
+ * element, and converts them into a form whose initial input is an identity
+ * element. This aides in parallelizing outer loops. Looks only at reduces with
+ * the monoid reduce schedule, since that indicates a particular structure which
+ * is annoying to check for again.
+ */
+pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
+    for id in editor.node_ids() {
+        if !editor.func().schedules[id.idx()].contains(&Schedule::MonoidReduce) {
+            continue;
+        }
+        let nodes = &editor.func().nodes;
+        let Some((_, init, reduct)) = nodes[id.idx()].try_reduce() else {
+            continue;
+        };
+
+        match nodes[reduct.idx()] {
+            Node::Binary {
+                op,
+                left: _,
+                right: _,
+            } if (op == BinaryOperator::Add || op == BinaryOperator::Or)
+                && !is_zero(editor, init) =>
+            {
+                editor.edit(|mut edit| {
+                    let zero = edit.add_zero_constant(typing[init.idx()]);
+                    let zero = edit.add_node(Node::Constant { id: zero });
+                    edit = edit.replace_all_uses_where(init, zero, |u| *u == id)?;
+                    let final_add = edit.add_node(Node::Binary {
+                        op,
+                        left: init,
+                        right: id,
+                    });
+                    edit.replace_all_uses_where(id, final_add, |u| *u != reduct && *u != final_add)
+                });
+            }
+            Node::Binary {
+                op,
+                left: _,
+                right: _,
+            } if (op == BinaryOperator::Mul || op == BinaryOperator::And)
+                && !is_one(editor, init) =>
+            {
+                editor.edit(|mut edit| {
+                    let one = edit.add_one_constant(typing[init.idx()]);
+                    let one = edit.add_node(Node::Constant { id: one });
+                    edit = edit.replace_all_uses_where(init, one, |u| *u == id)?;
+                    let final_add = edit.add_node(Node::Binary {
+                        op,
+                        left: init,
+                        right: id,
+                    });
+                    edit.replace_all_uses_where(id, final_add, |u| *u != reduct && *u != final_add)
+                });
+            }
+            _ => panic!(),
+        }
+    }
+}
diff --git a/hercules_opt/src/schedule.rs b/hercules_opt/src/schedule.rs
index 4cb912fd..7ecf07a4 100644
--- a/hercules_opt/src/schedule.rs
+++ b/hercules_opt/src/schedule.rs
@@ -161,7 +161,6 @@ pub fn infer_monoid_reduce(
                 | BinaryOperator::Mul
                 | BinaryOperator::Or
                 | BinaryOperator::And
-                | BinaryOperator::Xor
         )
     };
     let is_intrinsic_monoid = |intrinsic| matches!(intrinsic, Intrinsic::Max | Intrinsic::Min);
diff --git a/hercules_opt/src/utils.rs b/hercules_opt/src/utils.rs
index 3f12ad7c..1806d5c7 100644
--- a/hercules_opt/src/utils.rs
+++ b/hercules_opt/src/utils.rs
@@ -541,3 +541,29 @@ where
         nodes[fork.idx()].try_fork().unwrap().1.len() == rep_dims.len()
     })
 }
+
+pub fn is_zero(editor: &FunctionEditor, id: NodeID) -> bool {
+    let nodes = &editor.func().nodes;
+    nodes[id.idx()]
+        .try_constant()
+        .map(|id| editor.get_constant(id).is_zero())
+        .unwrap_or(false)
+        || nodes[id.idx()]
+            .try_dynamic_constant()
+            .map(|id| editor.get_dynamic_constant(id).is_zero())
+            .unwrap_or(false)
+        || nodes[id.idx()].is_undef()
+}
+
+pub fn is_one(editor: &FunctionEditor, id: NodeID) -> bool {
+    let nodes = &editor.func().nodes;
+    nodes[id.idx()]
+        .try_constant()
+        .map(|id| editor.get_constant(id).is_one())
+        .unwrap_or(false)
+        || nodes[id.idx()]
+            .try_dynamic_constant()
+            .map(|id| editor.get_dynamic_constant(id).is_one())
+            .unwrap_or(false)
+        || nodes[id.idx()].is_undef()
+}
diff --git a/juno_samples/dot/src/cpu.sch b/juno_samples/dot/src/cpu.sch
index be110bde..6ee00c8b 100644
--- a/juno_samples/dot/src/cpu.sch
+++ b/juno_samples/dot/src/cpu.sch
@@ -6,6 +6,8 @@ dce(*);
 
 fork-tile[8, 0, false, true](*);
 fork-split(*);
+infer-schedules(*);
+clean-monoid-reduces(*);
 
 let out = auto-outline(*);
 cpu(out.dot);
diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs
index 188cb1c6..2e930639 100644
--- a/juno_scheduler/src/compile.rs
+++ b/juno_scheduler/src/compile.rs
@@ -111,6 +111,7 @@ impl FromStr for Appliable {
             "auto-outline" => Ok(Appliable::Pass(ir::Pass::AutoOutline)),
             "ccp" => Ok(Appliable::Pass(ir::Pass::CCP)),
             "crc" | "collapse-read-chains" => Ok(Appliable::Pass(ir::Pass::CRC)),
+            "clean-monoid-reduces" => Ok(Appliable::Pass(ir::Pass::CleanMonoidReduces)),
             "dce" => Ok(Appliable::Pass(ir::Pass::DCE)),
             "delete-uncalled" => Ok(Appliable::DeleteUncalled),
             "float-collections" | "collections" => Ok(Appliable::Pass(ir::Pass::FloatCollections)),
diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs
index 205cd70b..d6f59bef 100644
--- a/juno_scheduler/src/ir.rs
+++ b/juno_scheduler/src/ir.rs
@@ -6,6 +6,7 @@ pub enum Pass {
     ArrayToProduct,
     AutoOutline,
     CCP,
+    CleanMonoidReduces,
     CRC,
     DCE,
     FloatCollections,
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 19bd78e2..45cebe80 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -1668,6 +1668,23 @@ fn run_pass(
             pm.delete_gravestones();
             pm.clear_analyses();
         }
+        Pass::CleanMonoidReduces => {
+            assert!(args.is_empty());
+            pm.make_typing();
+            let typing = pm.typing.take().unwrap();
+            for (func, typing) in build_selection(pm, selection, false)
+                .into_iter()
+                .zip(typing.iter())
+            {
+                let Some(mut func) = func else {
+                    continue;
+                };
+                clean_monoid_reduces(&mut func, typing);
+                changed |= func.modified();
+            }
+            pm.delete_gravestones();
+            pm.clear_analyses();
+        }
         Pass::CRC => {
             assert!(args.is_empty());
             for func in build_selection(pm, selection, false) {
-- 
GitLab


From 544c6659b5cfeaef555990e3d29808587cfa7b95 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sat, 15 Feb 2025 18:39:56 -0600
Subject: [PATCH 04/10] Found bug in fork tiling in the other direction

---
 hercules_opt/src/fork_transforms.rs |  2 +-
 juno_samples/dot/src/cpu.sch        |  5 +++++
 juno_samples/dot/src/dot.jn         |  4 ++--
 juno_samples/dot/src/main.rs        | 10 +++++-----
 4 files changed, 13 insertions(+), 8 deletions(-)

diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index 0b5de1e5..05606f5a 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -1503,7 +1503,7 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
                     edit.replace_all_uses_where(id, final_add, |u| *u != reduct && *u != final_add)
                 });
             }
-            _ => panic!(),
+            _ => {}
         }
     }
 }
diff --git a/juno_samples/dot/src/cpu.sch b/juno_samples/dot/src/cpu.sch
index 6ee00c8b..4e40e351 100644
--- a/juno_samples/dot/src/cpu.sch
+++ b/juno_samples/dot/src/cpu.sch
@@ -5,9 +5,12 @@ fork-guard-elim(*);
 dce(*);
 
 fork-tile[8, 0, false, true](*);
+fork-tile[32, 0, false, false](*);
 fork-split(*);
 infer-schedules(*);
 clean-monoid-reduces(*);
+infer-schedules(*);
+clean-monoid-reduces(*);
 
 let out = auto-outline(*);
 cpu(out.dot);
@@ -15,5 +18,7 @@ ip-sroa(*);
 sroa(*);
 dce(*);
 
+xdot[true](*);
+
 unforkify(*);
 gcm(*);
diff --git a/juno_samples/dot/src/dot.jn b/juno_samples/dot/src/dot.jn
index 0421dc4c..8c0e029c 100644
--- a/juno_samples/dot/src/dot.jn
+++ b/juno_samples/dot/src/dot.jn
@@ -1,6 +1,6 @@
 #[entry]
-fn dot<n : usize>(a : f32[n], b : f32[n]) -> f32 {
-  let res : f32;
+fn dot<n : usize>(a : i64[n], b : i64[n]) -> i64 {
+  let res : i64;
 
   for i = 0 to n {
     res += a[i] * b[i];
diff --git a/juno_samples/dot/src/main.rs b/juno_samples/dot/src/main.rs
index 5d0aaf7b..b73f8710 100644
--- a/juno_samples/dot/src/main.rs
+++ b/juno_samples/dot/src/main.rs
@@ -9,11 +9,11 @@ juno_build::juno!("dot");
 
 fn main() {
     async_std::task::block_on(async {
-        const N: u64 = 4096;
-        let a: Box<[f32]> = (0..N).map(|_| random::<f32>()).collect();
-        let b: Box<[f32]> = (0..N).map(|_| random::<f32>()).collect();
-        let a_herc = HerculesImmBox::from(&a as &[f32]);
-        let b_herc = HerculesImmBox::from(&b as &[f32]);
+        const N: u64 = 1024 * 1024;
+        let a: Box<[i64]> = (0..N).map(|_| random::<i64>() % 100).collect();
+        let b: Box<[i64]> = (0..N).map(|_| random::<i64>() % 100).collect();
+        let a_herc = HerculesImmBox::from(&a as &[i64]);
+        let b_herc = HerculesImmBox::from(&b as &[i64]);
         let mut r = runner!(dot);
         let output = r.run(N, a_herc.to(), b_herc.to()).await;
         let correct = zip(a, b).map(|(a, b)| a * b).sum();
-- 
GitLab


From d9a01c8b3536d28bba4726e9bc9e331a56c52d4f Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sat, 15 Feb 2025 18:44:41 -0600
Subject: [PATCH 05/10] Fix tile ThreadID dimension in chunk_fork_unguarded

---
 hercules_opt/src/fork_transforms.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index 05606f5a..7d4fa9a2 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -1002,7 +1002,7 @@ pub fn chunk_fork_unguarded(
                     } else if tid_dim == dim_idx {
                         let tile_tid = Node::ThreadID {
                             control: new_fork,
-                            dimension: tid_dim,
+                            dimension: tid_dim + 1,
                         };
                         let tile_tid = edit.add_node(tile_tid);
                         let inner_dc = edit.add_node(Node::DynamicConstant { id: inner_dc_id });
-- 
GitLab


From 900024ee60c8d8b8c5fb8f1979b2a19105ac7bed Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sat, 15 Feb 2025 21:53:56 -0600
Subject: [PATCH 06/10] Fixes for bufferize fission

---
 hercules_opt/src/fork_transforms.rs | 51 ++++++++++++++++++++++++++---
 hercules_opt/src/simplify_cfg.rs    | 45 ++++++++++++-------------
 juno_samples/dot/src/dot.jn         |  2 +-
 juno_scheduler/src/ir.rs            |  4 +--
 juno_scheduler/src/pm.rs            | 48 ++++++++++++++-------------
 5 files changed, 97 insertions(+), 53 deletions(-)

diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index 7d4fa9a2..e832e559 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -204,14 +204,48 @@ pub fn find_bufferize_edges(
     edges
 }
 
+pub fn ff_bufferize_create_not_reduce_cycle_label_helper(
+    editor: &mut FunctionEditor,
+    fork: NodeID,
+    fork_join_map: &HashMap<NodeID, NodeID>,
+    reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
+    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
+) -> LabelID {
+    let join = fork_join_map[&fork];
+    let mut nodes_not_in_a_reduce_cycle = nodes_in_fork_joins[&fork].clone();
+    for (cycle, reduce) in editor
+        .get_users(join)
+        .filter_map(|id| reduce_cycles.get(&id).map(|cycle| (cycle, id)))
+    {
+        nodes_not_in_a_reduce_cycle.remove(&reduce);
+        for id in cycle {
+            nodes_not_in_a_reduce_cycle.remove(id);
+        }
+    }
+    nodes_not_in_a_reduce_cycle.remove(&join);
+
+    let mut label = LabelID::new(0);
+    let success = editor.edit(|mut edit| {
+        label = edit.fresh_label();
+        for id in nodes_not_in_a_reduce_cycle {
+            edit = edit.add_label(id, label)?;
+        }
+        Ok(edit)
+    });
+
+    assert!(success);
+    label
+}
+
 pub fn ff_bufferize_any_fork<'a, 'b>(
     editor: &'b mut FunctionEditor<'a>,
     loop_tree: &'b LoopTree,
     fork_join_map: &'b HashMap<NodeID, NodeID>,
+    reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
     nodes_in_fork_joins: &'b HashMap<NodeID, HashSet<NodeID>>,
     typing: &'b Vec<TypeID>,
-    fork_label: &'b LabelID,
-    data_label: &'b LabelID,
+    fork_label: LabelID,
+    data_label: Option<LabelID>,
 ) -> Option<(NodeID, NodeID)>
 where
     'a: 'b,
@@ -230,17 +264,26 @@ where
         let fork = fork_info.header;
         let join = fork_join_map[&fork];
 
-        if !editor.func().labels[fork.idx()].contains(fork_label) {
+        if !editor.func().labels[fork.idx()].contains(&fork_label) {
             continue;
         }
 
+        let data_label = data_label.unwrap_or_else(|| {
+            ff_bufferize_create_not_reduce_cycle_label_helper(
+                editor,
+                fork,
+                fork_join_map,
+                reduce_cycles,
+                nodes_in_fork_joins,
+            )
+        });
         let edges = find_bufferize_edges(
             editor,
             fork,
             &loop_tree,
             &fork_join_map,
             &nodes_in_fork_joins,
-            data_label,
+            &data_label,
         );
         let result = fork_bufferize_fission_helper(
             editor,
diff --git a/hercules_opt/src/simplify_cfg.rs b/hercules_opt/src/simplify_cfg.rs
index d579012e..14a152dc 100644
--- a/hercules_opt/src/simplify_cfg.rs
+++ b/hercules_opt/src/simplify_cfg.rs
@@ -91,36 +91,33 @@ fn remove_useless_fork_joins(
     fork_join_map: &HashMap<NodeID, NodeID>,
     reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
 ) {
-    // First, try to get rid of reduces where possible. We can only delete all
-    // the reduces or none of the reduces in a particular fork-join, since even
-    // if one reduce may have no users outside the reduction cycle, it may be
-    // used by a reduce that is used outside the cycle, so it shouldn't be
-    // deleted. The reduction cycle may contain every reduce in a fork-join.
+    // First, try to get rid of reduces where possible. Look for reduces with no
+    // users outside their reduce cycle and whose reduce cycle contains no other
+    // reduce nodes.
     for (_, join) in fork_join_map {
-        let nodes = &editor.func().nodes;
         let reduces: Vec<_> = editor
             .get_users(*join)
-            .filter(|id| nodes[id.idx()].is_reduce())
+            .filter(|id| editor.func().nodes[id.idx()].is_reduce())
             .collect();
 
-        // If every reduce has users only in the reduce cycle, then all the
-        // reduces can be deleted, along with every node in the reduce cycles.
-        if reduces.iter().all(|reduce| {
-            editor
-                .get_users(*reduce)
-                .all(|user| reduce_cycles[reduce].contains(&user))
-        }) {
-            let mut all_the_nodes = HashSet::new();
-            for reduce in reduces {
-                all_the_nodes.insert(reduce);
-                all_the_nodes.extend(&reduce_cycles[&reduce]);
+        for reduce in reduces {
+            // If the reduce has users only in the reduce cycle, and none of
+            // the nodes in the cycle are reduce nodes, then the reduce and its
+            // whole cycle can be deleted.
+            if editor
+                .get_users(reduce)
+                .all(|user| reduce_cycles[&reduce].contains(&user))
+                && reduce_cycles[&reduce]
+                    .iter()
+                    .all(|id| !editor.func().nodes[id.idx()].is_reduce())
+            {
+                editor.edit(|mut edit| {
+                    for id in reduce_cycles[&reduce].iter() {
+                        edit = edit.delete_node(*id)?;
+                    }
+                    edit.delete_node(reduce)
+                });
             }
-            editor.edit(|mut edit| {
-                for id in all_the_nodes {
-                    edit = edit.delete_node(id)?;
-                }
-                Ok(edit)
-            });
         }
     }
 
diff --git a/juno_samples/dot/src/dot.jn b/juno_samples/dot/src/dot.jn
index 8c0e029c..cf097178 100644
--- a/juno_samples/dot/src/dot.jn
+++ b/juno_samples/dot/src/dot.jn
@@ -2,7 +2,7 @@
 fn dot<n : usize>(a : i64[n], b : i64[n]) -> i64 {
   let res : i64;
 
-  for i = 0 to n {
+  @loop for i = 0 to n {
     res += a[i] * b[i];
   }
 
diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs
index d6f59bef..11cf6b13 100644
--- a/juno_scheduler/src/ir.rs
+++ b/juno_scheduler/src/ir.rs
@@ -48,7 +48,7 @@ impl Pass {
         match self {
             Pass::ArrayToProduct => num == 0 || num == 1,
             Pass::ForkChunk => num == 4,
-            Pass::ForkFissionBufferize => num == 2,
+            Pass::ForkFissionBufferize => num == 2 || num == 1,
             Pass::ForkInterchange => num == 2,
             Pass::Print => num == 1,
             Pass::Rename => num == 1,
@@ -61,7 +61,7 @@ impl Pass {
         match self {
             Pass::ArrayToProduct => "0 or 1",
             Pass::ForkChunk => "4",
-            Pass::ForkFissionBufferize => "2",
+            Pass::ForkFissionBufferize => "1 or 2",
             Pass::ForkInterchange => "2",
             Pass::Print => "1",
             Pass::Rename => "1",
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 45cebe80..34c2474b 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -2414,7 +2414,7 @@ fn run_pass(
             pm.clear_analyses();
         }
         Pass::ForkFissionBufferize => {
-            assert_eq!(args.len(), 2);
+            assert!(args.len() == 1 || args.len() == 2);
             let Some(Value::Label {
                 labels: fork_labels,
             }) = args.get(0)
@@ -2425,25 +2425,17 @@ fn run_pass(
                 });
             };
 
-            let Some(Value::Label {
-                labels: fork_data_labels,
-            }) = args.get(1)
-            else {
-                return Err(SchedulerError::PassError {
-                    pass: "forkFissionBufferize".to_string(),
-                    error: "expected label argument".to_string(),
-                });
-            };
-
             let mut created_fork_joins = vec![vec![]; pm.functions.len()];
 
             pm.make_fork_join_maps();
             pm.make_typing();
             pm.make_loops();
+            pm.make_reduce_cycles();
             pm.make_nodes_in_fork_joins();
             let fork_join_maps = pm.fork_join_maps.take().unwrap();
             let typing = pm.typing.take().unwrap();
             let loops = pm.loops.take().unwrap();
+            let reduce_cycles = pm.reduce_cycles.take().unwrap();
             let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap();
 
             // assert only one function is in the selection.
@@ -2454,30 +2446,42 @@ fn run_pass(
 
             assert!(num_functions <= 1);
             assert_eq!(fork_labels.len(), 1);
-            assert_eq!(fork_data_labels.len(), 1);
 
             let fork_label = fork_labels[0].label;
-            let data_label = fork_data_labels[0].label;
 
-            for ((((func, fork_join_map), loop_tree), typing), nodes_in_fork_joins) in
-                build_selection(pm, selection, false)
-                    .into_iter()
-                    .zip(fork_join_maps.iter())
-                    .zip(loops.iter())
-                    .zip(typing.iter())
-                    .zip(nodes_in_fork_joins.iter())
+            for (
+                ((((func, fork_join_map), loop_tree), typing), reduce_cycles),
+                nodes_in_fork_joins,
+            ) in build_selection(pm, selection, false)
+                .into_iter()
+                .zip(fork_join_maps.iter())
+                .zip(loops.iter())
+                .zip(typing.iter())
+                .zip(reduce_cycles.iter())
+                .zip(nodes_in_fork_joins.iter())
             {
                 let Some(mut func) = func else {
                     continue;
                 };
+
+                let data_label = if let Some(Value::Label {
+                    labels: fork_data_labels,
+                }) = args.get(1)
+                {
+                    assert_eq!(fork_data_labels.len(), 1);
+                    Some(fork_data_labels[0].label)
+                } else {
+                    None
+                };
                 if let Some((fork1, fork2)) = ff_bufferize_any_fork(
                     &mut func,
                     loop_tree,
                     fork_join_map,
+                    reduce_cycles,
                     nodes_in_fork_joins,
                     typing,
-                    &fork_label,
-                    &data_label,
+                    fork_label,
+                    data_label,
                 ) {
                     let created_fork_joins = &mut created_fork_joins[func.func_id().idx()];
                     created_fork_joins.push(fork1);
-- 
GitLab


From c8ee3bdcd8756a4d4bbb7f5229688388fb2d9e9e Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sat, 15 Feb 2025 22:23:22 -0600
Subject: [PATCH 07/10] fission works for dot example

---
 hercules_opt/src/fork_transforms.rs | 18 ++++++++++-----
 juno_samples/dot/src/cpu.sch        | 35 ++++++++++++++++++-----------
 juno_scheduler/src/compile.rs       |  4 +++-
 3 files changed, 38 insertions(+), 19 deletions(-)

diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index e832e559..7f6dd1bc 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -250,11 +250,12 @@ pub fn ff_bufferize_any_fork<'a, 'b>(
 where
     'a: 'b,
 {
-    let forks: Vec<_> = loop_tree
+    let mut forks: Vec<_> = loop_tree
         .bottom_up_loops()
         .into_iter()
         .filter(|(k, _)| editor.func().nodes[k.idx()].is_fork())
         .collect();
+    forks.reverse();
 
     for l in forks {
         let fork_info = Loop {
@@ -1506,6 +1507,7 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
         let Some((_, init, reduct)) = nodes[id.idx()].try_reduce() else {
             continue;
         };
+        let out_uses: Vec<_> = editor.get_users(id).filter(|id| *id != reduct).collect();
 
         match nodes[reduct.idx()] {
             Node::Binary {
@@ -1519,12 +1521,15 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
                     let zero = edit.add_zero_constant(typing[init.idx()]);
                     let zero = edit.add_node(Node::Constant { id: zero });
                     edit = edit.replace_all_uses_where(init, zero, |u| *u == id)?;
-                    let final_add = edit.add_node(Node::Binary {
+                    let final_op = edit.add_node(Node::Binary {
                         op,
                         left: init,
                         right: id,
                     });
-                    edit.replace_all_uses_where(id, final_add, |u| *u != reduct && *u != final_add)
+                    for u in out_uses {
+                        edit.sub_edit(u, final_op);
+                    }
+                    edit.replace_all_uses_where(id, final_op, |u| *u != reduct && *u != final_op)
                 });
             }
             Node::Binary {
@@ -1538,12 +1543,15 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
                     let one = edit.add_one_constant(typing[init.idx()]);
                     let one = edit.add_node(Node::Constant { id: one });
                     edit = edit.replace_all_uses_where(init, one, |u| *u == id)?;
-                    let final_add = edit.add_node(Node::Binary {
+                    let final_op = edit.add_node(Node::Binary {
                         op,
                         left: init,
                         right: id,
                     });
-                    edit.replace_all_uses_where(id, final_add, |u| *u != reduct && *u != final_add)
+                    for u in out_uses {
+                        edit.sub_edit(u, final_op);
+                    }
+                    edit.replace_all_uses_where(id, final_op, |u| *u != reduct && *u != final_op)
                 });
             }
             _ => {}
diff --git a/juno_samples/dot/src/cpu.sch b/juno_samples/dot/src/cpu.sch
index 4e40e351..734054ab 100644
--- a/juno_samples/dot/src/cpu.sch
+++ b/juno_samples/dot/src/cpu.sch
@@ -1,24 +1,33 @@
-phi-elim(*);
+phi-elim(dot);
+ip-sroa(*);
+sroa(dot);
+dce(dot);
 
-forkify(*);
-fork-guard-elim(*);
-dce(*);
+forkify(dot);
+fork-guard-elim(dot);
+dce(dot);
 
-fork-tile[8, 0, false, true](*);
-fork-tile[32, 0, false, false](*);
-fork-split(*);
+fork-tile[8, 0, false, true](dot);
+fork-tile[32, 0, false, false](dot);
+let split_out = fork-split(dot);
 infer-schedules(*);
 clean-monoid-reduces(*);
 infer-schedules(*);
 clean-monoid-reduces(*);
 
-let out = auto-outline(*);
-cpu(out.dot);
+let out = outline(split_out.dot.fj1);
 ip-sroa(*);
-sroa(*);
-dce(*);
+sroa(dot);
+gvn(dot);
+dce(dot);
 
-xdot[true](*);
+let fission_out = fork-fission[out@loop](dot);
+simplify-cfg(dot);
+dce(dot);
+unforkify(fission_out.dot.fj_loop_bottom);
+ccp(dot);
+gvn(dot);
+dce(dot);
 
-unforkify(*);
+unforkify(out);
 gcm(*);
diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs
index 2e930639..88816562 100644
--- a/juno_scheduler/src/compile.rs
+++ b/juno_scheduler/src/compile.rs
@@ -125,7 +125,9 @@ impl FromStr for Appliable {
             "ip-sroa" | "interprocedural-sroa" => {
                 Ok(Appliable::Pass(ir::Pass::InterproceduralSROA))
             }
-            "fork-fission-bufferize" => Ok(Appliable::Pass(ir::Pass::ForkFissionBufferize)),
+            "fork-fission-bufferize" | "fork-fission" => {
+                Ok(Appliable::Pass(ir::Pass::ForkFissionBufferize))
+            }
             "fork-dim-merge" => Ok(Appliable::Pass(ir::Pass::ForkDimMerge)),
             "fork-interchange" => Ok(Appliable::Pass(ir::Pass::ForkInterchange)),
             "fork-chunk" | "fork-tile" => Ok(Appliable::Pass(ir::Pass::ForkChunk)),
-- 
GitLab


From e9e5aa319d3ad26b543f48519ecdc06c2f97c486 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sat, 15 Feb 2025 22:48:21 -0600
Subject: [PATCH 08/10] a bunch of stuff for dot

---
 hercules_opt/src/fork_transforms.rs |  2 ++
 hercules_opt/src/unforkify.rs       | 30 ++++++++++++++++++++++++++---
 juno_samples/dot/src/cpu.sch        | 13 +++++++++----
 juno_scheduler/src/compile.rs       |  1 +
 juno_scheduler/src/ir.rs            |  1 +
 juno_scheduler/src/pm.rs            | 22 +++++++++++++++++++++
 6 files changed, 62 insertions(+), 7 deletions(-)

diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index 7f6dd1bc..283734a0 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -1520,6 +1520,7 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
                 editor.edit(|mut edit| {
                     let zero = edit.add_zero_constant(typing[init.idx()]);
                     let zero = edit.add_node(Node::Constant { id: zero });
+                    edit.sub_edit(id, zero);
                     edit = edit.replace_all_uses_where(init, zero, |u| *u == id)?;
                     let final_op = edit.add_node(Node::Binary {
                         op,
@@ -1542,6 +1543,7 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
                 editor.edit(|mut edit| {
                     let one = edit.add_one_constant(typing[init.idx()]);
                     let one = edit.add_node(Node::Constant { id: one });
+                    edit.sub_edit(id, one);
                     edit = edit.replace_all_uses_where(init, one, |u| *u == id)?;
                     let final_op = edit.add_node(Node::Binary {
                         op,
diff --git a/hercules_opt/src/unforkify.rs b/hercules_opt/src/unforkify.rs
index 7451b0ad..b44ed8df 100644
--- a/hercules_opt/src/unforkify.rs
+++ b/hercules_opt/src/unforkify.rs
@@ -117,7 +117,31 @@ pub fn unforkify_all(
     }
 }
 
-pub fn unforkify(editor: &mut FunctionEditor, fork: NodeID, join: NodeID, loop_tree: &LoopTree) {
+pub fn unforkify_one(
+    editor: &mut FunctionEditor,
+    fork_join_map: &HashMap<NodeID, NodeID>,
+    loop_tree: &LoopTree,
+) {
+    for l in loop_tree.bottom_up_loops().into_iter().rev() {
+        if !editor.node(l.0).is_fork() {
+            continue;
+        }
+
+        let fork = l.0;
+        let join = fork_join_map[&fork];
+
+        if unforkify(editor, fork, join, loop_tree) {
+            break;
+        }
+    }
+}
+
+pub fn unforkify(
+    editor: &mut FunctionEditor,
+    fork: NodeID,
+    join: NodeID,
+    loop_tree: &LoopTree,
+) -> bool {
     let mut zero_cons_id = ConstantID::new(0);
     let mut one_cons_id = ConstantID::new(0);
     assert!(editor.edit(|mut edit| {
@@ -138,7 +162,7 @@ pub fn unforkify(editor: &mut FunctionEditor, fork: NodeID, join: NodeID, loop_t
     if factors.len() > 1 {
         // For now, don't convert multi-dimensional fork-joins. Rely on pass
         // that splits fork-joins.
-        return;
+        return false;
     }
     let join_control = nodes[join.idx()].try_join().unwrap();
     let tids: Vec<_> = editor
@@ -296,5 +320,5 @@ pub fn unforkify(editor: &mut FunctionEditor, fork: NodeID, join: NodeID, loop_t
         }
 
         Ok(edit)
-    });
+    })
 }
diff --git a/juno_samples/dot/src/cpu.sch b/juno_samples/dot/src/cpu.sch
index 734054ab..aa87972e 100644
--- a/juno_samples/dot/src/cpu.sch
+++ b/juno_samples/dot/src/cpu.sch
@@ -17,17 +17,22 @@ clean-monoid-reduces(*);
 
 let out = outline(split_out.dot.fj1);
 ip-sroa(*);
-sroa(dot);
-gvn(dot);
-dce(dot);
+sroa(*);
+gvn(*);
+dce(*);
 
 let fission_out = fork-fission[out@loop](dot);
 simplify-cfg(dot);
 dce(dot);
 unforkify(fission_out.dot.fj_loop_bottom);
 ccp(dot);
+simplify-cfg(dot);
 gvn(dot);
 dce(dot);
 
-unforkify(out);
+unforkify-one(out);
+ccp(out);
+simplify-cfg(out);
+gvn(out);
+dce(out);
 gcm(*);
diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs
index 88816562..fc2a729e 100644
--- a/juno_scheduler/src/compile.rs
+++ b/juno_scheduler/src/compile.rs
@@ -144,6 +144,7 @@ impl FromStr for Appliable {
             "slf" | "store-load-forward" => Ok(Appliable::Pass(ir::Pass::SLF)),
             "sroa" => Ok(Appliable::Pass(ir::Pass::SROA)),
             "unforkify" => Ok(Appliable::Pass(ir::Pass::Unforkify)),
+            "unforkify-one" => Ok(Appliable::Pass(ir::Pass::UnforkifyOne)),
             "fork-coalesce" => Ok(Appliable::Pass(ir::Pass::ForkCoalesce)),
             "verify" => Ok(Appliable::Pass(ir::Pass::Verify)),
             "xdot" => Ok(Appliable::Pass(ir::Pass::Xdot)),
diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs
index 11cf6b13..bf3fe037 100644
--- a/juno_scheduler/src/ir.rs
+++ b/juno_scheduler/src/ir.rs
@@ -38,6 +38,7 @@ pub enum Pass {
     Serialize,
     SimplifyCFG,
     Unforkify,
+    UnforkifyOne,
     Verify,
     WritePredication,
     Xdot,
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 34c2474b..8db79b46 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -2364,6 +2364,28 @@ fn run_pass(
             pm.delete_gravestones();
             pm.clear_analyses();
         }
+        Pass::UnforkifyOne => {
+            assert!(args.is_empty());
+            pm.make_fork_join_maps();
+            pm.make_loops();
+
+            let fork_join_maps = pm.fork_join_maps.take().unwrap();
+            let loops = pm.loops.take().unwrap();
+
+            for ((func, fork_join_map), loop_tree) in build_selection(pm, selection, false)
+                .into_iter()
+                .zip(fork_join_maps.iter())
+                .zip(loops.iter())
+            {
+                let Some(mut func) = func else {
+                    continue;
+                };
+                unforkify_one(&mut func, fork_join_map, loop_tree);
+                changed |= func.modified();
+            }
+            pm.delete_gravestones();
+            pm.clear_analyses();
+        }
         Pass::ForkChunk => {
             assert_eq!(args.len(), 4);
             let Some(Value::Integer { val: tile_size }) = args.get(0) else {
-- 
GitLab


From 5a1fa18bb347ff034c2adea0e924ecf2ee73425f Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sun, 16 Feb 2025 09:33:18 -0600
Subject: [PATCH 09/10] Lower read and write in rt backend

---
 hercules_cg/src/rt.rs | 34 +++++++++++++++++++++++++++++-----
 1 file changed, 29 insertions(+), 5 deletions(-)

diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs
index f758fed3..8c5775d8 100644
--- a/hercules_cg/src/rt.rs
+++ b/hercules_cg/src/rt.rs
@@ -622,9 +622,26 @@ impl<'a> RTContext<'a> {
             } => {
                 let block = &mut blocks.get_mut(&bb).unwrap().data;
                 let collect_ty = self.typing[collect.idx()];
-                let out_size = self.codegen_type_size(self.typing[id.idx()]);
+                let self_ty = self.typing[id.idx()];
                 let offset = self.codegen_index_math(collect_ty, indices, bb)?;
-                todo!();
+                if self.module.types[self_ty.idx()].is_primitive() {
+                    write!(
+                        block,
+                        "{} = ({}.byte_add({} as usize).0 as *mut {}).read();",
+                        self.get_value(id, bb, true),
+                        self.get_value(collect, bb, false),
+                        offset,
+                        self.get_type(self_ty)
+                    )?;
+                } else {
+                    write!(
+                        block,
+                        "{} = {}.byte_add({} as usize);",
+                        self.get_value(id, bb, true),
+                        self.get_value(collect, bb, false),
+                        offset,
+                    )?;
+                }
             }
             Node::Write {
                 collect,
@@ -633,11 +650,18 @@ impl<'a> RTContext<'a> {
             } => {
                 let block = &mut blocks.get_mut(&bb).unwrap().data;
                 let collect_ty = self.typing[collect.idx()];
-                let data_size = self.codegen_type_size(self.typing[data.idx()]);
-                let offset = self.codegen_index_math(collect_ty, indices, bb)?;
                 let data_ty = self.typing[data.idx()];
+                let data_size = self.codegen_type_size(data_ty);
+                let offset = self.codegen_index_math(collect_ty, indices, bb)?;
                 if self.module.types[data_ty.idx()].is_primitive() {
-                    todo!();
+                    write!(
+                        block,
+                        "({}.byte_add({} as usize).0 as *mut {}).write({});",
+                        self.get_value(collect, bb, false),
+                        offset,
+                        self.get_type(data_ty),
+                        self.get_value(data, bb, false),
+                    )?;
                 } else {
                     // If the data item being written is not a primitive type,
                     // then perform a memcpy from the data collection to the
-- 
GitLab


From 85ce14bde0ca675bfa967fb35dc746eea343ca86 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sun, 16 Feb 2025 10:15:36 -0600
Subject: [PATCH 10/10] enough rt backend stuff for multi-threaded dot

---
 hercules_cg/src/rt.rs        | 69 ++++++++++++++++++++++++++++++++----
 hercules_opt/src/gcm.rs      |  9 +++--
 juno_samples/dot/src/cpu.sch |  5 +--
 juno_samples/dot/src/main.rs |  2 +-
 4 files changed, 73 insertions(+), 12 deletions(-)

diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs
index 8c5775d8..4d9a6cf6 100644
--- a/hercules_cg/src/rt.rs
+++ b/hercules_cg/src/rt.rs
@@ -139,7 +139,10 @@ struct RTContext<'a> {
 struct RustBlock {
     prologue: String,
     data: String,
+    phi_tmp_assignments: String,
+    phi_assignments: String,
     epilogue: String,
+    join_epilogue: String,
 }
 
 impl<'a> RTContext<'a> {
@@ -251,7 +254,28 @@ impl<'a> RTContext<'a> {
         // fork and join nodes open and close environments, respectively.
         for id in rev_po.iter() {
             let block = &blocks[id];
-            write!(w, "{}{}{}", block.prologue, block.data, block.epilogue)?;
+            if func.nodes[id.idx()].is_join() {
+                write!(
+                    w,
+                    "{}{}{}{}{}{}",
+                    block.prologue,
+                    block.data,
+                    block.epilogue,
+                    block.phi_tmp_assignments,
+                    block.phi_assignments,
+                    block.join_epilogue
+                )?;
+            } else {
+                write!(
+                    w,
+                    "{}{}{}{}{}",
+                    block.prologue,
+                    block.data,
+                    block.phi_tmp_assignments,
+                    block.phi_assignments,
+                    block.epilogue
+                )?;
+            }
         }
 
         // Close the root environment.
@@ -367,7 +391,10 @@ impl<'a> RTContext<'a> {
 
                 // Close the branch inside the async closure.
                 let epilogue = &mut blocks.get_mut(&id).unwrap().epilogue;
-                write!(epilogue, "return;}}")?;
+                write!(
+                    epilogue,
+                    "::std::sync::atomic::fence(::std::sync::atomic::Ordering::Release);return;}}"
+                )?;
 
                 // Close the fork's environment.
                 self.codegen_close_environment(epilogue)?;
@@ -405,9 +432,10 @@ impl<'a> RTContext<'a> {
                     }
                 }
 
+                let join_epilogue = &mut blocks.get_mut(&id).unwrap().join_epilogue;
                 // Branch to the successor control node in the surrounding
                 // context, and close the branch for the join.
-                write!(epilogue, "control_token = {};}}", succ.idx())?;
+                write!(join_epilogue, "control_token = {};}}", succ.idx())?;
             }
             _ => panic!("PANIC: Can't lower {:?}.", func.nodes[id.idx()]),
         }
@@ -481,15 +509,39 @@ impl<'a> RTContext<'a> {
                 write!(block, ";")?;
             }
             Node::ThreadID { control, dimension } => {
+                assert_eq!(control, bb);
                 let block = &mut blocks.get_mut(&bb).unwrap().data;
                 write!(
                     block,
                     "{} = tid_{}_{};",
                     self.get_value(id, bb, true),
-                    control.idx(),
+                    bb.idx(),
                     dimension
                 )?;
             }
+            Node::Phi { control, ref data } => {
+                assert_eq!(control, bb);
+                // Phis aren't executable in their own basic block - predecessor
+                // blocks assign the phi's incoming values themselves. Assign
+                // temporary values first before assigning the phi itself, since
+                // there may be simultaneous inter-dependent phis.
+                for (data, pred) in zip(data.into_iter(), self.control_subgraph.preds(bb)) {
+                    let block = &mut blocks.get_mut(&pred).unwrap().phi_tmp_assignments;
+                    write!(
+                        block,
+                        "let {}_tmp = {};",
+                        self.get_value(id, pred, true),
+                        self.get_value(*data, pred, false),
+                    )?;
+                    let block = &mut blocks.get_mut(&pred).unwrap().phi_assignments;
+                    write!(
+                        block,
+                        "{} = {}_tmp;",
+                        self.get_value(id, pred, true),
+                        self.get_value(id, pred, false),
+                    )?;
+                }
+            }
             Node::Reduce {
                 control: _,
                 init: _,
@@ -498,11 +550,12 @@ impl<'a> RTContext<'a> {
                 assert!(func.schedules[id.idx()].contains(&Schedule::ParallelReduce));
             }
             Node::Call {
-                control: _,
+                control,
                 function: callee_id,
                 ref dynamic_constants,
                 ref args,
             } => {
+                assert_eq!(control, bb);
                 // The device backends ensure that device functions have the
                 // same interface as AsyncRust functions.
                 let block = &mut blocks.get_mut(&bb).unwrap().data;
@@ -975,7 +1028,9 @@ impl<'a> RTContext<'a> {
                     if is_reduce_on_child { "reduce" } else { "node" },
                     idx,
                     self.get_type(self.typing[idx]),
-                    if self.module.types[self.typing[idx].idx()].is_integer() {
+                    if self.module.types[self.typing[idx].idx()].is_bool() {
+                        "false"
+                    } else if self.module.types[self.typing[idx].idx()].is_integer() {
                         "0"
                     } else if self.module.types[self.typing[idx].idx()].is_float() {
                         "0.0"
@@ -1241,7 +1296,7 @@ impl<'a> RTContext<'a> {
             // Before using the value of a reduction outside the fork-join,
             // await the futures.
             format!(
-                "{{for fut in fork_{}.drain(..) {{ fut.await; }}; reduce_{}}}",
+                "{{for fut in fork_{}.drain(..) {{ fut.await; }}; ::std::sync::atomic::fence(::std::sync::atomic::Ordering::Acquire); reduce_{}}}",
                 fork.idx(),
                 id.idx()
             )
diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs
index 99c44d52..821d02ea 100644
--- a/hercules_opt/src/gcm.rs
+++ b/hercules_opt/src/gcm.rs
@@ -95,9 +95,11 @@ pub fn gcm(
 
     let bbs = basic_blocks(
         editor.func(),
+        editor.get_types(),
         editor.func_id(),
         def_use,
         reverse_postorder,
+        typing,
         dom,
         loops,
         reduce_cycles,
@@ -218,9 +220,11 @@ fn preliminary_fixups(
  */
 fn basic_blocks(
     function: &Function,
+    types: Ref<Vec<Type>>,
     func_id: FunctionID,
     def_use: &ImmutableDefUseMap,
     reverse_postorder: &Vec<NodeID>,
+    typing: &Vec<TypeID>,
     dom: &DomTree,
     loops: &LoopTree,
     reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
@@ -498,8 +502,9 @@ fn basic_blocks(
                 // control dependent as possible, even inside loops. In GPU
                 // functions specifically, lift constants that may be returned
                 // outside fork-joins.
-                let is_constant_or_undef =
-                    function.nodes[id.idx()].is_constant() || function.nodes[id.idx()].is_undef();
+                let is_constant_or_undef = (function.nodes[id.idx()].is_constant()
+                    || function.nodes[id.idx()].is_undef())
+                    && !types[typing[id.idx()].idx()].is_primitive();
                 let is_gpu_returned = devices[func_id.idx()] == Device::CUDA
                     && objects[&func_id]
                         .objects(id)
diff --git a/juno_samples/dot/src/cpu.sch b/juno_samples/dot/src/cpu.sch
index aa87972e..1f8953d9 100644
--- a/juno_samples/dot/src/cpu.sch
+++ b/juno_samples/dot/src/cpu.sch
@@ -8,7 +8,7 @@ fork-guard-elim(dot);
 dce(dot);
 
 fork-tile[8, 0, false, true](dot);
-fork-tile[32, 0, false, false](dot);
+fork-tile[8, 0, false, false](dot);
 let split_out = fork-split(dot);
 infer-schedules(*);
 clean-monoid-reduces(*);
@@ -29,8 +29,9 @@ ccp(dot);
 simplify-cfg(dot);
 gvn(dot);
 dce(dot);
+infer-schedules(dot);
 
-unforkify-one(out);
+unforkify(out);
 ccp(out);
 simplify-cfg(out);
 gvn(out);
diff --git a/juno_samples/dot/src/main.rs b/juno_samples/dot/src/main.rs
index b73f8710..bd887194 100644
--- a/juno_samples/dot/src/main.rs
+++ b/juno_samples/dot/src/main.rs
@@ -9,7 +9,7 @@ juno_build::juno!("dot");
 
 fn main() {
     async_std::task::block_on(async {
-        const N: u64 = 1024 * 1024;
+        const N: u64 = 1024 * 8;
         let a: Box<[i64]> = (0..N).map(|_| random::<i64>() % 100).collect();
         let b: Box<[i64]> = (0..N).map(|_| random::<i64>() % 100).collect();
         let a_herc = HerculesImmBox::from(&a as &[i64]);
-- 
GitLab