From f1f319bf28aed5a1b988813aa166fb43b267227f Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Thu, 13 Feb 2025 14:26:08 -0600
Subject: [PATCH 01/15] multi-core scale

---
 juno_samples/cava/build.rs    |  2 ++
 juno_samples/cava/src/cava.jn |  4 ++--
 juno_samples/cava/src/cpu.sch | 44 +++++++++++++++++++++++++++++++++++
 juno_samples/cava/src/main.rs |  6 ++---
 juno_scheduler/src/compile.rs |  2 ++
 juno_scheduler/src/ir.rs      |  6 ++---
 juno_scheduler/src/lang.l     |  4 ++--
 juno_scheduler/src/pm.rs      |  3 +++
 8 files changed, 60 insertions(+), 11 deletions(-)
 create mode 100644 juno_samples/cava/src/cpu.sch

diff --git a/juno_samples/cava/build.rs b/juno_samples/cava/build.rs
index 1b6dddf4..ff7e2b6b 100644
--- a/juno_samples/cava/build.rs
+++ b/juno_samples/cava/build.rs
@@ -13,6 +13,8 @@ fn main() {
     JunoCompiler::new()
         .file_in_src("cava.jn")
         .unwrap()
+        .schedule_in_src("cpu.sch")
+        .unwrap()
         .build()
         .unwrap();
 }
diff --git a/juno_samples/cava/src/cava.jn b/juno_samples/cava/src/cava.jn
index 95a73f5b..bb8afded 100644
--- a/juno_samples/cava/src/cava.jn
+++ b/juno_samples/cava/src/cava.jn
@@ -22,12 +22,12 @@ fn medianMatrix<a : number, rows, cols : usize>(m : a[rows, cols]) -> a {
 const CHAN : u64 = 3;
 
 fn scale<row : usize, col : usize>(input : u8[CHAN, row, col]) -> f32[CHAN, row, col] {
-  let res : f32[CHAN, row, col];
+  @const let res : f32[CHAN, row, col];
 
   for chan = 0 to CHAN {
     for r = 0 to row {
       for c = 0 to col {
-        res[chan, r, c] = input[chan, r, c] as f32 * 1.0 / 255;
+        res[chan, r, c] = input[chan, r, c] as f32 / 255;
       }
     }
   }
diff --git a/juno_samples/cava/src/cpu.sch b/juno_samples/cava/src/cpu.sch
new file mode 100644
index 00000000..0913d4b3
--- /dev/null
+++ b/juno_samples/cava/src/cpu.sch
@@ -0,0 +1,44 @@
+macro simpl!(X) {
+  ccp(X);
+  gvn(X);
+  phi-elim(X);
+  dce(X);
+  infer-schedules(X);
+}
+
+simpl!(*);
+
+inline(denoise);
+cpu(scale, demosaic, denoise, transform, gamut, tone_map, descale);
+
+ip-sroa(*);
+sroa(*);
+simpl!(*);
+
+no-memset(scale@const);
+fixpoint {
+  forkify(scale);
+  fork-guard-elim(scale);
+  fork-coalesce(scale);
+}
+simpl!(*);
+fork-dim-merge(scale);
+simpl!(*);
+fork-tile[2048, 0, false](scale);
+simpl!(*);
+let out = fork-split(scale);
+simpl!(*);
+let out = outline(out._0_scale.fj1);
+ip-sroa(*);
+sroa(*);
+simpl!(*);
+host(scale);
+unforkify(out);
+xdot[true](scale, out);
+
+gcm(*);
+fixpoint {
+  float-collections(*);
+  dce(*);
+  gcm(*);
+}
diff --git a/juno_samples/cava/src/main.rs b/juno_samples/cava/src/main.rs
index b4a0f6fd..a940d6eb 100644
--- a/juno_samples/cava/src/main.rs
+++ b/juno_samples/cava/src/main.rs
@@ -159,6 +159,7 @@ fn cava_harness(args: CavaInputs) {
         tonemap,
     } = load_cam_model(cam_model, CHAN).expect("Error loading camera model");
 
+    println!("Running cava with {} rows, {} columns, and {} control points.", rows, cols, num_ctrl_pts);
     let result = run_cava(
         rows,
         cols,
@@ -227,10 +228,8 @@ fn cava_test_small() {
     });
 }
 
-// Disabling the larger test because of how long it takes
-/*
 #[test]
-fn cava_test() {
+fn cava_test_full() {
     cava_harness(CavaInputs {
         input: "examples/raw_tulips.bin".to_string(),
         output: None,
@@ -239,4 +238,3 @@ fn cava_test() {
         cam_model: "cam_models/NikonD7000".to_string(),
     });
 }
-*/
diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs
index 7887b9b3..237ff3b9 100644
--- a/juno_scheduler/src/compile.rs
+++ b/juno_scheduler/src/compile.rs
@@ -130,6 +130,8 @@ impl FromStr for Appliable {
             "serialize" => Ok(Appliable::Pass(ir::Pass::Serialize)),
             "write-predication" => Ok(Appliable::Pass(ir::Pass::WritePredication)),
 
+            "print" => Ok(Appliable::Pass(ir::Pass::Print)),
+
             "cpu" | "llvm" => Ok(Appliable::Device(Device::LLVM)),
             "gpu" | "cuda" | "nvidia" => Ok(Appliable::Device(Device::CUDA)),
             "host" | "rust" | "rust-async" => Ok(Appliable::Device(Device::AsyncRust)),
diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs
index a9ee7956..4b88d6a2 100644
--- a/juno_scheduler/src/ir.rs
+++ b/juno_scheduler/src/ir.rs
@@ -27,6 +27,7 @@ pub enum Pass {
     Outline,
     PhiElim,
     Predication,
+    Print,
     ReduceSLF,
     ReuseProducts,
     SLF,
@@ -42,10 +43,9 @@ pub enum Pass {
 impl Pass {
     pub fn num_args(&self) -> usize {
         match self {
-            Pass::Xdot => 1,
+            Pass::Xdot | Pass::Print => 1,
+            Pass::ForkFissionBufferize | Pass::ForkInterchange => 2,
             Pass::ForkChunk => 4,
-            Pass::ForkFissionBufferize => 2,
-            Pass::ForkInterchange => 2,
             _ => 0,
         }
     }
diff --git a/juno_scheduler/src/lang.l b/juno_scheduler/src/lang.l
index 9d4c34bf..2f34f01f 100644
--- a/juno_scheduler/src/lang.l
+++ b/juno_scheduler/src/lang.l
@@ -43,8 +43,8 @@ panic[\t \n\r]+after "panic_after"
 print[\t \n\r]+iter  "print_iter"
 stop[\t \n\r]+after  "stop_after"
 
-[a-zA-Z][a-zA-Z0-9_\-]*! "MACRO"
-[a-zA-Z][a-zA-Z0-9_\-]*  "ID"
+[a-zA-Z_][a-zA-Z0-9_\-]*! "MACRO"
+[a-zA-Z_][a-zA-Z0-9_\-]*  "ID"
 [0-9]+                   "INT"
 
 .                     "UNMATCHED"
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index de725608..45b36fcc 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -2423,6 +2423,9 @@ fn run_pass(
             // Put BasicBlocks back, since it's needed for Codegen.
             pm.bbs = bbs;
         }
+        Pass::Print => {
+            println!("{:?}", args.get(0));
+        }
     }
     println!("Ran Pass: {:?}", pass);
 
-- 
GitLab


From aedcd82e7c8e65195cf92ed73c59941e886cd9b2 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Thu, 13 Feb 2025 22:11:27 -0600
Subject: [PATCH 02/15] Fuse scale and demosaic

---
 juno_samples/cava/src/cava.jn |  4 ++--
 juno_samples/cava/src/cpu.sch | 33 ++++++++++++++-------------------
 2 files changed, 16 insertions(+), 21 deletions(-)

diff --git a/juno_samples/cava/src/cava.jn b/juno_samples/cava/src/cava.jn
index bb8afded..a06c8d7b 100644
--- a/juno_samples/cava/src/cava.jn
+++ b/juno_samples/cava/src/cava.jn
@@ -207,8 +207,8 @@ fn cava<r, c, num_ctrl_pts : usize>(
   coefs : f32[4, CHAN],
   tonemap : f32[256, CHAN],
 ) -> u8[CHAN, r, c] {
-  let scaled = scale::<r, c>(input);
-  let demosc = demosaic::<r, c>(scaled);
+  @fuse1 let scaled = scale::<r, c>(input);
+  @fuse1 let demosc = demosaic::<r, c>(scaled);
   let denosd = denoise::<r, c>(demosc);
   let transf = transform::<r, c>(denosd, TsTw);
   let gamutd = gamut::<r, c, num_ctrl_pts>(transf, ctrl_pts, weights, coefs);
diff --git a/juno_samples/cava/src/cpu.sch b/juno_samples/cava/src/cpu.sch
index 0913d4b3..f090eab3 100644
--- a/juno_samples/cava/src/cpu.sch
+++ b/juno_samples/cava/src/cpu.sch
@@ -1,5 +1,6 @@
 macro simpl!(X) {
   ccp(X);
+  simplify-cfg(X);
   gvn(X);
   phi-elim(X);
   dce(X);
@@ -8,33 +9,27 @@ macro simpl!(X) {
 
 simpl!(*);
 
+let fuse1 = outline(cava@fuse1);
+inline(fuse1);
+
 inline(denoise);
-cpu(scale, demosaic, denoise, transform, gamut, tone_map, descale);
+cpu(denoise, transform, gamut, tone_map, descale);
 
 ip-sroa(*);
 sroa(*);
 simpl!(*);
 
-no-memset(scale@const);
+no-memset(fuse1@const);
 fixpoint {
-  forkify(scale);
-  fork-guard-elim(scale);
-  fork-coalesce(scale);
+  forkify(fuse1);
+  fork-guard-elim(fuse1);
+  fork-coalesce(fuse1);
 }
-simpl!(*);
-fork-dim-merge(scale);
-simpl!(*);
-fork-tile[2048, 0, false](scale);
-simpl!(*);
-let out = fork-split(scale);
-simpl!(*);
-let out = outline(out._0_scale.fj1);
-ip-sroa(*);
-sroa(*);
-simpl!(*);
-host(scale);
-unforkify(out);
-xdot[true](scale, out);
+simpl!(fuse1);
+array-slf(fuse1);
+simpl!(fuse1);
+xdot[true](fuse1);
+unforkify(fuse1);
 
 gcm(*);
 fixpoint {
-- 
GitLab


From 619fad446643b683f3cb6c1622821fce0500a9a2 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Thu, 13 Feb 2025 22:31:29 -0600
Subject: [PATCH 03/15] Start optimizing denoise, need forkify fixes +
 array2prod still

---
 juno_samples/cava/src/cava.jn |  8 ++++----
 juno_samples/cava/src/cpu.sch | 22 +++++++++++++++++++---
 2 files changed, 23 insertions(+), 7 deletions(-)

diff --git a/juno_samples/cava/src/cava.jn b/juno_samples/cava/src/cava.jn
index a06c8d7b..720629e7 100644
--- a/juno_samples/cava/src/cava.jn
+++ b/juno_samples/cava/src/cava.jn
@@ -1,7 +1,7 @@
 fn medianMatrix<a : number, rows, cols : usize>(m : a[rows, cols]) -> a {
   const n : usize = rows * cols;
 
-  let tmp : a[rows * cols];
+  @tmp let tmp : a[rows * cols];
   for i = 0 to rows * cols {
     tmp[i] = m[i / cols, i % cols];
   }
@@ -102,13 +102,13 @@ fn demosaic<row : usize, col : usize>(input : f32[CHAN, row, col]) -> f32[CHAN,
 }
 
 fn denoise<row : usize, col : usize>(input : f32[CHAN, row, col]) -> f32[CHAN, row, col] {
-  let res : f32[CHAN, row, col];
+  @res let res : f32[CHAN, row, col];
 
   for chan = 0 to CHAN {
     for r = 0 to row {
       for c = 0 to col {
         if r >= 1 && r < row - 1 && c >= 1 && c < col - 1 {
-          let filter : f32[3][3]; // same as [3, 3]
+          @filter let filter : f32[3][3]; // same as [3, 3]
           for i = 0 to 3 by 1 {
             for j = 0 to 3 by 1 {
               filter[i, j] = input[chan, r + i - 1, c + j - 1];
@@ -209,7 +209,7 @@ fn cava<r, c, num_ctrl_pts : usize>(
 ) -> u8[CHAN, r, c] {
   @fuse1 let scaled = scale::<r, c>(input);
   @fuse1 let demosc = demosaic::<r, c>(scaled);
-  let denosd = denoise::<r, c>(demosc);
+  @fuse2 let denosd = denoise::<r, c>(demosc);
   let transf = transform::<r, c>(denosd, TsTw);
   let gamutd = gamut::<r, c, num_ctrl_pts>(transf, ctrl_pts, weights, coefs);
   let tonemd = tone_map::<r, c>(gamutd, tonemap);
diff --git a/juno_samples/cava/src/cpu.sch b/juno_samples/cava/src/cpu.sch
index f090eab3..b0479f5c 100644
--- a/juno_samples/cava/src/cpu.sch
+++ b/juno_samples/cava/src/cpu.sch
@@ -1,6 +1,7 @@
 macro simpl!(X) {
   ccp(X);
   simplify-cfg(X);
+  lift-dc-math(X);
   gvn(X);
   phi-elim(X);
   dce(X);
@@ -12,8 +13,8 @@ simpl!(*);
 let fuse1 = outline(cava@fuse1);
 inline(fuse1);
 
-inline(denoise);
-cpu(denoise, transform, gamut, tone_map, descale);
+let fuse2 = outline(cava@fuse2);
+inline(fuse2);
 
 ip-sroa(*);
 sroa(*);
@@ -28,9 +29,24 @@ fixpoint {
 simpl!(fuse1);
 array-slf(fuse1);
 simpl!(fuse1);
-xdot[true](fuse1);
 unforkify(fuse1);
 
+inline(fuse2);
+no-memset(fuse2@res);
+no-memset(fuse2@filter);
+no-memset(fuse2@tmp);
+fixpoint {
+  forkify(fuse2);
+  fork-guard-elim(fuse2);
+  fork-coalesce(fuse2);
+}
+simpl!(fuse2);
+array-slf(fuse2);
+simpl!(fuse2);
+array-slf(fuse2);
+simpl!(fuse2);
+xdot[true](fuse2);
+
 gcm(*);
 fixpoint {
   float-collections(*);
-- 
GitLab


From 8e8ed5ef8321f9312251262ce5636daf234275e5 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Fri, 14 Feb 2025 09:33:49 -0600
Subject: [PATCH 04/15] fixes

---
 hercules_cg/src/rt.rs         |  2 +-
 juno_samples/cava/src/cava.jn |  3 +-
 juno_samples/cava/src/cpu.sch | 13 +++----
 juno_samples/cava/src/main.rs | 66 +++++++++++++++++++++--------------
 4 files changed, 46 insertions(+), 38 deletions(-)

diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs
index 090253d4..f758fed3 100644
--- a/hercules_cg/src/rt.rs
+++ b/hercules_cg/src/rt.rs
@@ -533,7 +533,7 @@ impl<'a> RTContext<'a> {
                 {
                     write!(block, "backing_{}.byte_add(", device.name())?;
                     self.codegen_dynamic_constant(offset, block)?;
-                    write!(block, "), ")?
+                    write!(block, " as usize), ")?
                 }
                 for dc in dynamic_constants {
                     self.codegen_dynamic_constant(*dc, block)?;
diff --git a/juno_samples/cava/src/cava.jn b/juno_samples/cava/src/cava.jn
index 720629e7..de21ba78 100644
--- a/juno_samples/cava/src/cava.jn
+++ b/juno_samples/cava/src/cava.jn
@@ -213,5 +213,6 @@ fn cava<r, c, num_ctrl_pts : usize>(
   let transf = transform::<r, c>(denosd, TsTw);
   let gamutd = gamut::<r, c, num_ctrl_pts>(transf, ctrl_pts, weights, coefs);
   let tonemd = tone_map::<r, c>(gamutd, tonemap);
-  return descale::<r, c>(tonemd);
+  let dscald = descale::<r, c>(tonemd);
+  return dscald;
 }
diff --git a/juno_samples/cava/src/cpu.sch b/juno_samples/cava/src/cpu.sch
index b0479f5c..c75260ca 100644
--- a/juno_samples/cava/src/cpu.sch
+++ b/juno_samples/cava/src/cpu.sch
@@ -43,13 +43,8 @@ fixpoint {
 simpl!(fuse2);
 array-slf(fuse2);
 simpl!(fuse2);
-array-slf(fuse2);
-simpl!(fuse2);
-xdot[true](fuse2);
+unforkify(fuse2);
 
-gcm(*);
-fixpoint {
-  float-collections(*);
-  dce(*);
-  gcm(*);
-}
+delete-uncalled(*);
+
+gcm(*);
\ No newline at end of file
diff --git a/juno_samples/cava/src/main.rs b/juno_samples/cava/src/main.rs
index a940d6eb..700b74e6 100644
--- a/juno_samples/cava/src/main.rs
+++ b/juno_samples/cava/src/main.rs
@@ -19,24 +19,31 @@ juno_build::juno!("cava");
 // Individual lifetimes are not needed in this example but should probably be generated for
 // flexibility
 async fn safe_run<'a, 'b: 'a, 'c: 'a, 'd: 'a, 'e: 'a, 'f: 'a, 'g: 'a>(
-    runner: &'a mut HerculesRunner_cava, r: u64, c: u64, num_ctrl_pts: u64,
-    input: &'b HerculesImmBox<'b, u8>, tstw: &'c HerculesImmBox<'c, f32>,
-    ctrl_pts: &'d HerculesImmBox<'d, f32>, weights: &'e HerculesImmBox<'e, f32>,
-    coefs: &'f HerculesImmBox<'f, f32>, tonemap: &'g HerculesImmBox<'g, f32>,
+    runner: &'a mut HerculesRunner_cava,
+    r: u64,
+    c: u64,
+    num_ctrl_pts: u64,
+    input: &'b HerculesImmBox<'b, u8>,
+    tstw: &'c HerculesImmBox<'c, f32>,
+    ctrl_pts: &'d HerculesImmBox<'d, f32>,
+    weights: &'e HerculesImmBox<'e, f32>,
+    coefs: &'f HerculesImmBox<'f, f32>,
+    tonemap: &'g HerculesImmBox<'g, f32>,
 ) -> HerculesMutBox<'a, u8> {
     HerculesMutBox::from(
-        runner.run(
-            r,
-            c,
-            num_ctrl_pts,
-            input.to(),
-            tstw.to(),
-            ctrl_pts.to(),
-            weights.to(),
-            coefs.to(),
-            tonemap.to()
-        )
-        .await
+        runner
+            .run(
+                r,
+                c,
+                num_ctrl_pts,
+                input.to(),
+                tstw.to(),
+                ctrl_pts.to(),
+                weights.to(),
+                coefs.to(),
+                tonemap.to(),
+            )
+            .await,
     )
 }
 
@@ -68,16 +75,17 @@ fn run_cava(
     let mut r = runner!(cava);
 
     async_std::task::block_on(async {
-        safe_run(&mut r,
-                 rows as u64,
-                 cols as u64,
-                 num_ctrl_pts as u64,
-                 &image,
-                 &tstw,
-                 &ctrl_pts,
-                 &weights,
-                 &coefs,
-                 &tonemap,
+        safe_run(
+            &mut r,
+            rows as u64,
+            cols as u64,
+            num_ctrl_pts as u64,
+            &image,
+            &tstw,
+            &ctrl_pts,
+            &weights,
+            &coefs,
+            &tonemap,
         )
         .await
     })
@@ -159,7 +167,10 @@ fn cava_harness(args: CavaInputs) {
         tonemap,
     } = load_cam_model(cam_model, CHAN).expect("Error loading camera model");
 
-    println!("Running cava with {} rows, {} columns, and {} control points.", rows, cols, num_ctrl_pts);
+    println!(
+        "Running cava with {} rows, {} columns, and {} control points.",
+        rows, cols, num_ctrl_pts
+    );
     let result = run_cava(
         rows,
         cols,
@@ -229,6 +240,7 @@ fn cava_test_small() {
 }
 
 #[test]
+#[ignore]
 fn cava_test_full() {
     cava_harness(CavaInputs {
         input: "examples/raw_tulips.bin".to_string(),
-- 
GitLab


From 5ba0636be767b7421c6db89ce6773b71bc792b4c Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Fri, 14 Feb 2025 09:43:11 -0600
Subject: [PATCH 05/15] I am surprised these forkify

---
 juno_samples/cava/src/cava.jn | 14 +++++++-------
 juno_samples/cava/src/cpu.sch | 29 +++++++++++++++++++++++++++--
 2 files changed, 34 insertions(+), 9 deletions(-)

diff --git a/juno_samples/cava/src/cava.jn b/juno_samples/cava/src/cava.jn
index de21ba78..29fc4df5 100644
--- a/juno_samples/cava/src/cava.jn
+++ b/juno_samples/cava/src/cava.jn
@@ -22,7 +22,7 @@ fn medianMatrix<a : number, rows, cols : usize>(m : a[rows, cols]) -> a {
 const CHAN : u64 = 3;
 
 fn scale<row : usize, col : usize>(input : u8[CHAN, row, col]) -> f32[CHAN, row, col] {
-  @const let res : f32[CHAN, row, col];
+  @res1 let res : f32[CHAN, row, col];
 
   for chan = 0 to CHAN {
     for r = 0 to row {
@@ -50,7 +50,7 @@ fn descale<row : usize, col : usize>(input : f32[CHAN, row, col]) -> u8[CHAN, ro
 }
 
 fn demosaic<row : usize, col : usize>(input : f32[CHAN, row, col]) -> f32[CHAN, row, col] {
-  let res : f32[CHAN, row, col];
+  @res2 let res : f32[CHAN, row, col];
 
   for r = 1 to row-1 {
     for c = 1 to col-1 {
@@ -129,7 +129,7 @@ fn denoise<row : usize, col : usize>(input : f32[CHAN, row, col]) -> f32[CHAN, r
 fn transform<row : usize, col : usize>
   (input : f32[CHAN, row, col], tstw_trans : f32[CHAN, CHAN])
     -> f32[CHAN, row, col] {
-  let result : f32[CHAN, row, col];
+  @res let result : f32[CHAN, row, col];
 
   for chan = 0 to CHAN {
     for r = 0 to row {
@@ -152,11 +152,11 @@ fn gamut<row : usize, col : usize, num_ctrl_pts : usize>(
   weights  : f32[num_ctrl_pts, CHAN],
   coefs : f32[4, CHAN]
 ) -> f32[CHAN, row, col] {
-  let result : f32[CHAN, row, col];
-  let l2_dist : f32[num_ctrl_pts];
+  @res let result : f32[CHAN, row, col];
 
   for r = 0 to row {
     for c = 0 to col {
+      @l2 let l2_dist : f32[num_ctrl_pts];
       for cp = 0 to num_ctrl_pts {
         let v1 = input[0, r, c] - ctrl_pts[cp, 0];
         let v2 = input[1, r, c] - ctrl_pts[cp, 1];
@@ -210,8 +210,8 @@ fn cava<r, c, num_ctrl_pts : usize>(
   @fuse1 let scaled = scale::<r, c>(input);
   @fuse1 let demosc = demosaic::<r, c>(scaled);
   @fuse2 let denosd = denoise::<r, c>(demosc);
-  let transf = transform::<r, c>(denosd, TsTw);
-  let gamutd = gamut::<r, c, num_ctrl_pts>(transf, ctrl_pts, weights, coefs);
+  @fuse3 let transf = transform::<r, c>(denosd, TsTw);
+  @fuse4 let gamutd = gamut::<r, c, num_ctrl_pts>(transf, ctrl_pts, weights, coefs);
   let tonemd = tone_map::<r, c>(gamutd, tonemap);
   let dscald = descale::<r, c>(tonemd);
   return dscald;
diff --git a/juno_samples/cava/src/cpu.sch b/juno_samples/cava/src/cpu.sch
index c75260ca..8099c0ba 100644
--- a/juno_samples/cava/src/cpu.sch
+++ b/juno_samples/cava/src/cpu.sch
@@ -16,11 +16,18 @@ inline(fuse1);
 let fuse2 = outline(cava@fuse2);
 inline(fuse2);
 
+let fuse3 = outline(cava@fuse3);
+inline(fuse3);
+
+let fuse4 = outline(cava@fuse4);
+inline(fuse4);
+
 ip-sroa(*);
 sroa(*);
 simpl!(*);
 
-no-memset(fuse1@const);
+no-memset(fuse1@res1);
+no-memset(fuse1@res2);
 fixpoint {
   forkify(fuse1);
   fork-guard-elim(fuse1);
@@ -45,6 +52,24 @@ array-slf(fuse2);
 simpl!(fuse2);
 unforkify(fuse2);
 
-delete-uncalled(*);
+no-memset(fuse3@res);
+fixpoint {
+  forkify(fuse3);
+  fork-guard-elim(fuse3);
+  fork-coalesce(fuse3);
+}
+fork-split(fuse3);
+unforkify(fuse3);
 
+no-memset(fuse4@res);
+no-memset(fuse4@l2);
+fixpoint {
+  forkify(fuse4);
+  fork-guard-elim(fuse4);
+  fork-coalesce(fuse4);
+}
+fork-split(fuse4);
+unforkify(fuse4);
+
+delete-uncalled(*);
 gcm(*);
\ No newline at end of file
-- 
GitLab


From 1848d531851e3a815d2545a8e7471ef74711095c Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Fri, 14 Feb 2025 09:48:40 -0600
Subject: [PATCH 06/15] Chefs kiss

---
 juno_samples/cava/src/cava.jn | 34 +++++++++++++++++-----------------
 juno_samples/cava/src/cpu.sch | 19 +++++++++++++++++++
 2 files changed, 36 insertions(+), 17 deletions(-)

diff --git a/juno_samples/cava/src/cava.jn b/juno_samples/cava/src/cava.jn
index 29fc4df5..366792c3 100644
--- a/juno_samples/cava/src/cava.jn
+++ b/juno_samples/cava/src/cava.jn
@@ -35,20 +35,6 @@ fn scale<row : usize, col : usize>(input : u8[CHAN, row, col]) -> f32[CHAN, row,
   return res;
 }
 
-fn descale<row : usize, col : usize>(input : f32[CHAN, row, col]) -> u8[CHAN, row, col] {
-  let res : u8[CHAN, row, col];
-
-  for chan = 0 to CHAN {
-    for r = 0 to row {
-      for c = 0 to col {
-        res[chan, r, c] = min!::<f32>(max!::<f32>(input[chan, r, c] * 255, 0), 255) as u8;
-      }
-    }
-  }
-
-  return res;
-}
-
 fn demosaic<row : usize, col : usize>(input : f32[CHAN, row, col]) -> f32[CHAN, row, col] {
   @res2 let res : f32[CHAN, row, col];
 
@@ -184,7 +170,7 @@ fn gamut<row : usize, col : usize, num_ctrl_pts : usize>(
 
 fn tone_map<row : usize, col:usize>
   (input : f32[CHAN, row, col], tone_map : f32[256, CHAN]) -> f32[CHAN, row, col] {
-  let result : f32[CHAN, row, col];
+  @res1 let result : f32[CHAN, row, col];
 
   for chan = 0 to CHAN {
     for r = 0 to row {
@@ -198,6 +184,20 @@ fn tone_map<row : usize, col:usize>
   return result;
 }
 
+fn descale<row : usize, col : usize>(input : f32[CHAN, row, col]) -> u8[CHAN, row, col] {
+  @res2 let res : u8[CHAN, row, col];
+
+  for chan = 0 to CHAN {
+    for r = 0 to row {
+      for c = 0 to col {
+        res[chan, r, c] = min!::<f32>(max!::<f32>(input[chan, r, c] * 255, 0), 255) as u8;
+      }
+    }
+  }
+
+  return res;
+}
+
 #[entry]
 fn cava<r, c, num_ctrl_pts : usize>(
   input : u8[CHAN, r, c],
@@ -212,7 +212,7 @@ fn cava<r, c, num_ctrl_pts : usize>(
   @fuse2 let denosd = denoise::<r, c>(demosc);
   @fuse3 let transf = transform::<r, c>(denosd, TsTw);
   @fuse4 let gamutd = gamut::<r, c, num_ctrl_pts>(transf, ctrl_pts, weights, coefs);
-  let tonemd = tone_map::<r, c>(gamutd, tonemap);
-  let dscald = descale::<r, c>(tonemd);
+  @fuse5 let tonemd = tone_map::<r, c>(gamutd, tonemap);
+  @fuse5 let dscald = descale::<r, c>(tonemd);
   return dscald;
 }
diff --git a/juno_samples/cava/src/cpu.sch b/juno_samples/cava/src/cpu.sch
index 8099c0ba..d51b8b94 100644
--- a/juno_samples/cava/src/cpu.sch
+++ b/juno_samples/cava/src/cpu.sch
@@ -22,6 +22,9 @@ inline(fuse3);
 let fuse4 = outline(cava@fuse4);
 inline(fuse4);
 
+let fuse5 = outline(cava@fuse5);
+inline(fuse5);
+
 ip-sroa(*);
 sroa(*);
 simpl!(*);
@@ -58,6 +61,7 @@ fixpoint {
   fork-guard-elim(fuse3);
   fork-coalesce(fuse3);
 }
+simpl!(fuse3);
 fork-split(fuse3);
 unforkify(fuse3);
 
@@ -68,8 +72,23 @@ fixpoint {
   fork-guard-elim(fuse4);
   fork-coalesce(fuse4);
 }
+simpl!(fuse4);
 fork-split(fuse4);
 unforkify(fuse4);
 
+no-memset(fuse5@res1);
+no-memset(fuse5@res2);
+fixpoint {
+  forkify(fuse5);
+  fork-guard-elim(fuse5);
+  fork-coalesce(fuse5);
+}
+simpl!(fuse5);
+array-slf(fuse5);
+simpl!(fuse5);
+fork-split(fuse5);
+unforkify(fuse5);
+
+simpl!(*);
 delete-uncalled(*);
 gcm(*);
\ No newline at end of file
-- 
GitLab


From cb35f6a1b0b4748ceac295075f5ab138cbde5424 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Fri, 14 Feb 2025 14:17:06 -0600
Subject: [PATCH 07/15] Works!

---
 juno_samples/cava/src/cpu.sch | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/juno_samples/cava/src/cpu.sch b/juno_samples/cava/src/cpu.sch
index 2ef7152f..246d10d4 100644
--- a/juno_samples/cava/src/cpu.sch
+++ b/juno_samples/cava/src/cpu.sch
@@ -75,7 +75,8 @@ inline(fuse2);
 ip-sroa(*);
 sroa(*);
 array-slf(fuse2);
-xdot[true](fuse2);
+write-predication(fuse2);
+simpl!(fuse2);
 fork-split(fuse2);
 unforkify(fuse2);
 
-- 
GitLab


From c5d8d401362ca444ecb73afb47c52f021d360b61 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Fri, 14 Feb 2025 14:21:43 -0600
Subject: [PATCH 08/15] whoops

---
 juno_samples/cava/src/cava.jn |   2 +-
 juno_samples/cava/src/cpu.sch | 140 +++++++++++++++++-----------------
 2 files changed, 71 insertions(+), 71 deletions(-)

diff --git a/juno_samples/cava/src/cava.jn b/juno_samples/cava/src/cava.jn
index f95f1ebc..839e80e2 100644
--- a/juno_samples/cava/src/cava.jn
+++ b/juno_samples/cava/src/cava.jn
@@ -29,7 +29,7 @@ fn scale<row : usize, col : usize>(input : u8[CHAN, row, col]) -> f32[CHAN, row,
   for chan = 0 to CHAN {
     for r = 0 to row {
       for c = 0 to col {
-        res[chan, r, c] = input[chan, r, c] as f32 * (1.0 / 255.0);
+        res[chan, r, c] = input[chan, r, c] as f32 / 255.0;
       }
     }
   }
diff --git a/juno_samples/cava/src/cpu.sch b/juno_samples/cava/src/cpu.sch
index 246d10d4..5c5428ae 100644
--- a/juno_samples/cava/src/cpu.sch
+++ b/juno_samples/cava/src/cpu.sch
@@ -27,92 +27,92 @@ inline(fuse5);
 
 ip-sroa(*);
 sroa(*);
-simpl!(*);
+//simpl!(*);
 
 no-memset(fuse1@res1);
 no-memset(fuse1@res2);
-fixpoint {
-  forkify(fuse1);
-  fork-guard-elim(fuse1);
-  fork-coalesce(fuse1);
-}
-simpl!(fuse1);
-array-slf(fuse1);
-simpl!(fuse1);
-unforkify(fuse1);
+//fixpoint {
+//  forkify(fuse1);
+//  fork-guard-elim(fuse1);
+//  fork-coalesce(fuse1);
+//}
+//simpl!(fuse1);
+//array-slf(fuse1);
+//simpl!(fuse1);
+//unforkify(fuse1);
 
 inline(fuse2);
 no-memset(fuse2@res);
 no-memset(fuse2@filter);
 no-memset(fuse2@tmp);
-fixpoint {
-  forkify(fuse2);
-  fork-guard-elim(fuse2);
-  fork-coalesce(fuse2);
-}
-simpl!(fuse2);
-predication(fuse2);
-simpl!(fuse2);
-
-let median = outline(fuse2@median);
-fork-unroll(median@medianOuter);
-simpl!(median);
-fixpoint {
-  forkify(median);
-  fork-guard-elim(median);
-}
-simpl!(median);
-fixpoint {
-  fork-unroll(median);
-}
-ccp(median);
-array-to-product(median);
-sroa(median);
-predication(median);
-simpl!(median);
-
-inline(fuse2);
-ip-sroa(*);
-sroa(*);
-array-slf(fuse2);
-write-predication(fuse2);
-simpl!(fuse2);
-fork-split(fuse2);
-unforkify(fuse2);
+//fixpoint {
+//  forkify(fuse2);
+//  fork-guard-elim(fuse2);
+//  fork-coalesce(fuse2);
+//}
+//simpl!(fuse2);
+//predication(fuse2);
+//simpl!(fuse2);
+//
+//let median = outline(fuse2@median);
+//fork-unroll(median@medianOuter);
+//simpl!(median);
+//fixpoint {
+//  forkify(median);
+//  fork-guard-elim(median);
+//}
+//simpl!(median);
+//fixpoint {
+//  fork-unroll(median);
+//}
+//ccp(median);
+//array-to-product(median);
+//sroa(median);
+//predication(median);
+//simpl!(median);
+//
+//inline(fuse2);
+//ip-sroa(*);
+//sroa(*);
+//array-slf(fuse2);
+//write-predication(fuse2);
+//simpl!(fuse2);
+//fork-split(fuse2);
+//unforkify(fuse2);
 
 no-memset(fuse3@res);
-fixpoint {
-  forkify(fuse3);
-  fork-guard-elim(fuse3);
-  fork-coalesce(fuse3);
-}
-simpl!(fuse3);
-fork-split(fuse3);
-unforkify(fuse3);
+//fixpoint {
+//  forkify(fuse3);
+//  fork-guard-elim(fuse3);
+//  fork-coalesce(fuse3);
+//}
+//simpl!(fuse3);
+//fork-split(fuse3);
+//unforkify(fuse3);
 
 no-memset(fuse4@res);
 no-memset(fuse4@l2);
-fixpoint {
-  forkify(fuse4);
-  fork-guard-elim(fuse4);
-  fork-coalesce(fuse4);
-}
-simpl!(fuse4);
-fork-split(fuse4);
-unforkify(fuse4);
+//fixpoint {
+//  forkify(fuse4);
+//  fork-guard-elim(fuse4);
+//  fork-coalesce(fuse4);
+//}
+//simpl!(fuse4);
+//fork-split(fuse4);
+//unforkify(fuse4);
 
 no-memset(fuse5@res1);
 no-memset(fuse5@res2);
-fixpoint {
-  forkify(fuse5);
-  fork-guard-elim(fuse5);
-  fork-coalesce(fuse5);
-}
-simpl!(fuse5);
-array-slf(fuse5);
-simpl!(fuse5);
-fork-split(fuse5);
-unforkify(fuse5);
+//fixpoint {
+//  forkify(fuse5);
+//  fork-guard-elim(fuse5);
+//  fork-coalesce(fuse5);
+//}
+//simpl!(fuse5);
+//array-slf(fuse5);
+//simpl!(fuse5);
+//fork-split(fuse5);
+//unforkify(fuse5);
 
 simpl!(*);
 delete-uncalled(*);
-- 
GitLab


From cd0f1635fc0c2beb088ab11b470f092e2b6949b0 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Fri, 14 Feb 2025 14:22:45 -0600
Subject: [PATCH 09/15] actually works

---
 juno_samples/cava/src/cpu.sch | 140 +++++++++++++++++-----------------
 1 file changed, 70 insertions(+), 70 deletions(-)

diff --git a/juno_samples/cava/src/cpu.sch b/juno_samples/cava/src/cpu.sch
index 5c5428ae..246d10d4 100644
--- a/juno_samples/cava/src/cpu.sch
+++ b/juno_samples/cava/src/cpu.sch
@@ -27,92 +27,92 @@ inline(fuse5);
 
 ip-sroa(*);
 sroa(*);
-//simpl!(*);
+simpl!(*);
 
 no-memset(fuse1@res1);
 no-memset(fuse1@res2);
-//fixpoint {
-//  forkify(fuse1);
-//  fork-guard-elim(fuse1);
-//  fork-coalesce(fuse1);
-//}
-//simpl!(fuse1);
-//array-slf(fuse1);
-//simpl!(fuse1);
-//unforkify(fuse1);
+fixpoint {
+  forkify(fuse1);
+  fork-guard-elim(fuse1);
+  fork-coalesce(fuse1);
+}
+simpl!(fuse1);
+array-slf(fuse1);
+simpl!(fuse1);
+unforkify(fuse1);
 
 inline(fuse2);
 no-memset(fuse2@res);
 no-memset(fuse2@filter);
 no-memset(fuse2@tmp);
-//fixpoint {
-//  forkify(fuse2);
-//  fork-guard-elim(fuse2);
-//  fork-coalesce(fuse2);
-//}
-//simpl!(fuse2);
-//predication(fuse2);
-//simpl!(fuse2);
-//
-//let median = outline(fuse2@median);
-//fork-unroll(median@medianOuter);
-//simpl!(median);
-//fixpoint {
-//  forkify(median);
-//  fork-guard-elim(median);
-//}
-//simpl!(median);
-//fixpoint {
-//  fork-unroll(median);
-//}
-//ccp(median);
-//array-to-product(median);
-//sroa(median);
-//predication(median);
-//simpl!(median);
-//
-//inline(fuse2);
-//ip-sroa(*);
-//sroa(*);
-//array-slf(fuse2);
-//write-predication(fuse2);
-//simpl!(fuse2);
-//fork-split(fuse2);
-//unforkify(fuse2);
+fixpoint {
+  forkify(fuse2);
+  fork-guard-elim(fuse2);
+  fork-coalesce(fuse2);
+}
+simpl!(fuse2);
+predication(fuse2);
+simpl!(fuse2);
+
+let median = outline(fuse2@median);
+fork-unroll(median@medianOuter);
+simpl!(median);
+fixpoint {
+  forkify(median);
+  fork-guard-elim(median);
+}
+simpl!(median);
+fixpoint {
+  fork-unroll(median);
+}
+ccp(median);
+array-to-product(median);
+sroa(median);
+predication(median);
+simpl!(median);
+
+inline(fuse2);
+ip-sroa(*);
+sroa(*);
+array-slf(fuse2);
+write-predication(fuse2);
+simpl!(fuse2);
+fork-split(fuse2);
+unforkify(fuse2);
 
 no-memset(fuse3@res);
-//fixpoint {
-//  forkify(fuse3);
-//  fork-guard-elim(fuse3);
-//  fork-coalesce(fuse3);
-//}
-//simpl!(fuse3);
-//fork-split(fuse3);
-//unforkify(fuse3);
+fixpoint {
+  forkify(fuse3);
+  fork-guard-elim(fuse3);
+  fork-coalesce(fuse3);
+}
+simpl!(fuse3);
+fork-split(fuse3);
+unforkify(fuse3);
 
 no-memset(fuse4@res);
 no-memset(fuse4@l2);
-//fixpoint {
-//  forkify(fuse4);
-//  fork-guard-elim(fuse4);
-//  fork-coalesce(fuse4);
-//}
-//simpl!(fuse4);
-//fork-split(fuse4);
-//unforkify(fuse4);
+fixpoint {
+  forkify(fuse4);
+  fork-guard-elim(fuse4);
+  fork-coalesce(fuse4);
+}
+simpl!(fuse4);
+fork-split(fuse4);
+unforkify(fuse4);
 
 no-memset(fuse5@res1);
 no-memset(fuse5@res2);
-//fixpoint {
-//  forkify(fuse5);
-//  fork-guard-elim(fuse5);
-//  fork-coalesce(fuse5);
-//}
-//simpl!(fuse5);
-//array-slf(fuse5);
-//simpl!(fuse5);
-//fork-split(fuse5);
-//unforkify(fuse5);
+fixpoint {
+  forkify(fuse5);
+  fork-guard-elim(fuse5);
+  fork-coalesce(fuse5);
+}
+simpl!(fuse5);
+array-slf(fuse5);
+simpl!(fuse5);
+fork-split(fuse5);
+unforkify(fuse5);
 
 simpl!(*);
 delete-uncalled(*);
-- 
GitLab


From 81f524ca76868a04d1642895fe974f6f58ca8484 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Fri, 14 Feb 2025 14:27:55 -0600
Subject: [PATCH 10/15] This eliminates most selects

---
 juno_samples/cava/src/cpu.sch | 1 +
 1 file changed, 1 insertion(+)

diff --git a/juno_samples/cava/src/cpu.sch b/juno_samples/cava/src/cpu.sch
index 246d10d4..e9e3e683 100644
--- a/juno_samples/cava/src/cpu.sch
+++ b/juno_samples/cava/src/cpu.sch
@@ -68,6 +68,7 @@ fixpoint {
 ccp(median);
 array-to-product(median);
 sroa(median);
+phi-elim(median);
 predication(median);
 simpl!(median);
 
-- 
GitLab


From 0958995eb70b45d0b7d71994b5f2de47fb8828e3 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Fri, 14 Feb 2025 14:40:24 -0600
Subject: [PATCH 11/15] unroll channel loop in gamut with intention of fusing
 the ctrl pts loop across channels

---
 juno_samples/cava/src/cava.jn | 2 +-
 juno_samples/cava/src/cpu.sch | 3 +++
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/juno_samples/cava/src/cava.jn b/juno_samples/cava/src/cava.jn
index 839e80e2..e6961faf 100644
--- a/juno_samples/cava/src/cava.jn
+++ b/juno_samples/cava/src/cava.jn
@@ -153,7 +153,7 @@ fn gamut<row : usize, col : usize, num_ctrl_pts : usize>(
         l2_dist[cp] = sqrt!::<f32>(v);
       }
 
-      for chan = 0 to CHAN {
+      @channel_loop for chan = 0 to CHAN {
         let chan_val : f32 = 0.0;
         for cp = 0 to num_ctrl_pts {
           chan_val += l2_dist[cp] * weights[cp, chan];
diff --git a/juno_samples/cava/src/cpu.sch b/juno_samples/cava/src/cpu.sch
index e9e3e683..3021f6a0 100644
--- a/juno_samples/cava/src/cpu.sch
+++ b/juno_samples/cava/src/cpu.sch
@@ -99,6 +99,9 @@ fixpoint {
   fork-coalesce(fuse4);
 }
 simpl!(fuse4);
+fork-unroll(fuse4@channel_loop);
+simpl!(fuse4);
+//fork-fusion(fuse4@channel_loop);
 fork-split(fuse4);
 unforkify(fuse4);
 
-- 
GitLab


From 0b1a9eb3c1a023f695fad05fd070c62610cd2a22 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Fri, 14 Feb 2025 15:02:58 -0600
Subject: [PATCH 12/15] Add intrinsics to math exprs, array slf gets rid of
 first ctrl pts loop in gamut

---
 hercules_ir/src/einsum.rs     | 31 +++++++++++++++++++++++++++++++
 hercules_opt/src/utils.rs     |  7 +++++++
 juno_samples/cava/src/cpu.sch |  3 +++
 3 files changed, 41 insertions(+)

diff --git a/hercules_ir/src/einsum.rs b/hercules_ir/src/einsum.rs
index 5dd0fe5d..b222e1bc 100644
--- a/hercules_ir/src/einsum.rs
+++ b/hercules_ir/src/einsum.rs
@@ -34,6 +34,7 @@ pub enum MathExpr {
     Unary(UnaryOperator, MathID),
     Binary(BinaryOperator, MathID, MathID),
     Ternary(TernaryOperator, MathID, MathID, MathID),
+    IntrinsicFunc(Intrinsic, Box<[MathID]>),
 }
 
 pub type MathEnv = Vec<MathExpr>;
@@ -224,6 +225,16 @@ impl<'a> EinsumContext<'a> {
                 let third = self.compute_math_expr(third);
                 MathExpr::Ternary(op, first, second, third)
             }
+            Node::IntrinsicCall {
+                intrinsic,
+                ref args,
+            } => {
+                let args = args
+                    .into_iter()
+                    .map(|id| self.compute_math_expr(*id))
+                    .collect();
+                MathExpr::IntrinsicFunc(intrinsic, args)
+            }
             Node::Read {
                 collect,
                 ref indices,
@@ -322,6 +333,14 @@ impl<'a> EinsumContext<'a> {
                 let third = self.substitute_new_dims(third);
                 self.intern_math_expr(MathExpr::Ternary(op, first, second, third))
             }
+            MathExpr::IntrinsicFunc(intrinsic, ref args) => {
+                let args = args
+                    .clone()
+                    .iter()
+                    .map(|id| self.substitute_new_dims(*id))
+                    .collect();
+                self.intern_math_expr(MathExpr::IntrinsicFunc(intrinsic, args))
+            }
             _ => id,
         }
     }
@@ -355,6 +374,9 @@ pub fn opaque_nodes_in_expr(env: &MathEnv, id: MathID) -> HashSet<NodeID> {
                 stack.push(second);
                 stack.push(third);
             }
+            MathExpr::IntrinsicFunc(_, ref args) => {
+                stack.extend(args);
+            }
         }
     }
     set
@@ -412,5 +434,14 @@ pub fn debug_print_math_expr(id: MathID, env: &MathEnv) {
             debug_print_math_expr(third, env);
             print!(")");
         }
+        MathExpr::IntrinsicFunc(intrinsic, ref args) => {
+            // Print only the arguments: recursing on `id` itself never terminates.
+            print!("{}(", intrinsic.lower_case_name());
+            for (idx, arg) in args.iter().enumerate() {
+                print!("{}", if idx > 0 { ", " } else { "" });
+                debug_print_math_expr(*arg, env);
+            }
+            print!(")");
+        }
     }
 }
diff --git a/hercules_opt/src/utils.rs b/hercules_opt/src/utils.rs
index c5e0d934..3f12ad7c 100644
--- a/hercules_opt/src/utils.rs
+++ b/hercules_opt/src/utils.rs
@@ -474,6 +474,13 @@ pub fn materialize_simple_einsum_expr(
                 third,
             })
         }
+        MathExpr::IntrinsicFunc(intrinsic, ref args) => {
+            let args = args
+                .into_iter()
+                .map(|id| materialize_simple_einsum_expr(edit, *id, env, dim_substs))
+                .collect();
+            edit.add_node(Node::IntrinsicCall { intrinsic, args })
+        }
         _ => panic!(),
     }
 }
diff --git a/juno_samples/cava/src/cpu.sch b/juno_samples/cava/src/cpu.sch
index 3021f6a0..1ae1dc13 100644
--- a/juno_samples/cava/src/cpu.sch
+++ b/juno_samples/cava/src/cpu.sch
@@ -102,6 +102,9 @@ simpl!(fuse4);
 fork-unroll(fuse4@channel_loop);
 simpl!(fuse4);
 //fork-fusion(fuse4@channel_loop);
+simpl!(fuse4);
+array-slf(fuse4);
+simpl!(fuse4);
 fork-split(fuse4);
 unforkify(fuse4);
 
-- 
GitLab


From 9397f55db557497e0fbf6934af646625668a8bc3 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Fri, 14 Feb 2025 17:17:02 -0600
Subject: [PATCH 13/15] Fork fusion

---
 hercules_opt/src/fork_transforms.rs | 120 +++++++++++++++++++++++++---
 juno_samples/cava/src/cava.jn       |   2 +-
 juno_samples/cava/src/cpu.sch       |   4 +-
 juno_scheduler/src/compile.rs       |   1 +
 juno_scheduler/src/ir.rs            |   1 +
 juno_scheduler/src/pm.rs            |  21 +++++
 6 files changed, 135 insertions(+), 14 deletions(-)

diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index 72b5716f..c32a517e 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -912,16 +912,16 @@ pub fn chunk_fork_unguarded(
                 let outer = DynamicConstant::div(new_factors[dim_idx], tile_size);
                 new_factors.insert(dim_idx + 1, tile_size);
                 new_factors[dim_idx] = edit.add_dynamic_constant(outer);
-        
+
                 let new_fork = Node::Fork {
                     control: old_control,
                     factors: new_factors.into(),
                 };
                 let new_fork = edit.add_node(new_fork);
-        
+
                 edit = edit.replace_all_uses(fork, new_fork)?;
                 edit.sub_edit(fork, new_fork);
-        
+
                 for (tid, node) in fork_users {
                     let Node::ThreadID {
                         control: _,
@@ -945,7 +945,7 @@ pub fn chunk_fork_unguarded(
                             dimension: tid_dim + 1,
                         };
                         let tile_tid = edit.add_node(tile_tid);
-        
+
                         let tile_size = edit.add_node(Node::DynamicConstant { id: tile_size });
                         let mul = edit.add_node(Node::Binary {
                             left: tid,
@@ -965,23 +965,23 @@ pub fn chunk_fork_unguarded(
                 edit = edit.delete_node(fork)?;
                 Ok(edit)
             });
-        },
+        }
         TileOrder::TileOuter => {
             editor.edit(|mut edit| {
                 let inner = DynamicConstant::div(new_factors[dim_idx], tile_size);
                 new_factors.insert(dim_idx, tile_size);
-                let inner_dc_id =  edit.add_dynamic_constant(inner);
+                let inner_dc_id = edit.add_dynamic_constant(inner);
                 new_factors[dim_idx + 1] = inner_dc_id;
-        
+
                 let new_fork = Node::Fork {
                     control: old_control,
                     factors: new_factors.into(),
                 };
                 let new_fork = edit.add_node(new_fork);
-        
+
                 edit = edit.replace_all_uses(fork, new_fork)?;
                 edit.sub_edit(fork, new_fork);
-        
+
                 for (tid, node) in fork_users {
                     let Node::ThreadID {
                         control: _,
@@ -1000,13 +1000,12 @@ pub fn chunk_fork_unguarded(
                         edit.sub_edit(tid, new_tid);
                         edit = edit.delete_node(tid)?;
                     } else if tid_dim == dim_idx {
-
                         let tile_tid = Node::ThreadID {
                             control: new_fork,
                             dimension: tid_dim,
                         };
                         let tile_tid = edit.add_node(tile_tid);
-                        let inner_dc = edit.add_node(Node::DynamicConstant { id: inner_dc_id } );
+                        let inner_dc = edit.add_node(Node::DynamicConstant { id: inner_dc_id });
                         let mul = edit.add_node(Node::Binary {
                             left: tid,
                             right: inner_dc,
@@ -1027,7 +1026,6 @@ pub fn chunk_fork_unguarded(
             });
         }
     }
-    
 }
 
 pub fn merge_all_fork_dims(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>) {
@@ -1350,3 +1348,101 @@ pub fn fork_unroll(
         Ok(edit)
     })
 }
+
+/*
+ * Looks for fork-joins that are next to each other, not inter-dependent, and
+ * have the same bounds. These fork-joins can be fused, pooling together all
+ * their reductions.
+ */
+pub fn fork_fusion_all_forks(
+    editor: &mut FunctionEditor,
+    fork_join_map: &HashMap<NodeID, NodeID>,
+    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
+) {
+    for (fork, join) in fork_join_map {
+        if editor.is_mutable(*fork)
+            && fork_fusion(editor, *fork, *join, fork_join_map, nodes_in_fork_joins)
+        {
+            break;
+        }
+    }
+}
+
+/*
+ * Tries to fuse a given fork join with the immediately following fork-join, if
+ * it exists.
+ */
+fn fork_fusion(
+    editor: &mut FunctionEditor,
+    top_fork: NodeID,
+    top_join: NodeID,
+    fork_join_map: &HashMap<NodeID, NodeID>,
+    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
+) -> bool {
+    let nodes = &editor.func().nodes;
+    // The fusion candidate is the first control user of the top join. Bail out
+    // unless it has an entry in the fork-join map, i.e. it is a fork whose join
+    // is known.
+    let Some(bottom_fork) = editor
+        .get_users(top_join)
+        .find(|id| nodes[id.idx()].is_control())
+    else {
+        return false;
+    };
+    let Some(bottom_join) = fork_join_map.get(&bottom_fork) else {
+        return false;
+    };
+    let (_, top_factors) = nodes[top_fork.idx()].try_fork().unwrap();
+    let (bottom_fork_pred, bottom_factors) = nodes[bottom_fork.idx()].try_fork().unwrap();
+    assert_eq!(bottom_fork_pred, top_join);
+    let top_join_pred = nodes[top_join.idx()].try_join().unwrap();
+    let bottom_join_pred = nodes[bottom_join.idx()].try_join().unwrap();
+
+    // The fork factors must be identical.
+    if top_factors != bottom_factors {
+        return false;
+    }
+
+    // Check that no iterated users of the top's reduces are in the bottom fork-
+    // join. The walk never enters the top fork-join and never continues past a
+    // phi or reduce.
+    for reduce in editor
+        .get_users(top_join)
+        .filter(|id| nodes[id.idx()].is_reduce())
+    {
+        let mut visited = HashSet::new();
+        visited.insert(reduce);
+        let mut workset = vec![reduce];
+        while let Some(pop) = workset.pop() {
+            for u in editor.get_users(pop) {
+                if nodes_in_fork_joins[&bottom_fork].contains(&u) {
+                    return false;
+                }
+                let is_barrier = nodes[u.idx()].is_phi() || nodes[u.idx()].is_reduce();
+                let in_top = nodes_in_fork_joins[&top_fork].contains(&u);
+                if !is_barrier && !in_top && visited.insert(u) {
+                    workset.push(u);
+                }
+            }
+        }
+    }
+
+    // Perform the fusion.
+    editor.edit(|mut edit| {
+        if bottom_join_pred != bottom_fork {
+            // If there is control flow in the bottom fork-join, stitch it into
+            // the top fork-join.
+            edit = edit.replace_all_uses_where(bottom_fork, top_join_pred, |id| {
+                nodes_in_fork_joins[&bottom_fork].contains(id)
+            })?;
+            edit =
+                edit.replace_all_uses_where(top_join_pred, bottom_join_pred, |id| *id == top_join)?;
+        }
+        // Replace the bottom fork and join with the top fork and join.
+        edit = edit.replace_all_uses(bottom_fork, top_fork)?;
+        edit = edit.replace_all_uses(*bottom_join, top_join)?;
+        edit = edit.delete_node(bottom_fork)?;
+        edit = edit.delete_node(*bottom_join)?;
+        Ok(edit)
+    })
+}
diff --git a/juno_samples/cava/src/cava.jn b/juno_samples/cava/src/cava.jn
index e6961faf..8158bf0a 100644
--- a/juno_samples/cava/src/cava.jn
+++ b/juno_samples/cava/src/cava.jn
@@ -152,7 +152,7 @@ fn gamut<row : usize, col : usize, num_ctrl_pts : usize>(
         let v  = v1 * v1 + v2 * v2 + v3 * v3;
         l2_dist[cp] = sqrt!::<f32>(v);
       }
-
+      
       @channel_loop for chan = 0 to CHAN {
         let chan_val : f32 = 0.0;
         for cp = 0 to num_ctrl_pts {
diff --git a/juno_samples/cava/src/cpu.sch b/juno_samples/cava/src/cpu.sch
index 1ae1dc13..0b8869d5 100644
--- a/juno_samples/cava/src/cpu.sch
+++ b/juno_samples/cava/src/cpu.sch
@@ -101,7 +101,9 @@ fixpoint {
 simpl!(fuse4);
 fork-unroll(fuse4@channel_loop);
 simpl!(fuse4);
-//fork-fusion(fuse4@channel_loop);
+fixpoint {
+  fork-fusion(fuse4@channel_loop);
+}
 simpl!(fuse4);
 array-slf(fuse4);
 simpl!(fuse4);
diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs
index 912cc91f..7c92e00d 100644
--- a/juno_scheduler/src/compile.rs
+++ b/juno_scheduler/src/compile.rs
@@ -129,6 +129,7 @@ impl FromStr for Appliable {
             "fork-interchange" => Ok(Appliable::Pass(ir::Pass::ForkInterchange)),
             "fork-chunk" | "fork-tile" => Ok(Appliable::Pass(ir::Pass::ForkChunk)),
             "fork-unroll" | "unroll" => Ok(Appliable::Pass(ir::Pass::ForkUnroll)),
+            "fork-fusion" | "fusion" => Ok(Appliable::Pass(ir::Pass::ForkFusion)),
             "lift-dc-math" => Ok(Appliable::Pass(ir::Pass::LiftDCMath)),
             "outline" => Ok(Appliable::Pass(ir::Pass::Outline)),
             "phi-elim" => Ok(Appliable::Pass(ir::Pass::PhiElim)),
diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs
index 8ad92324..205cd70b 100644
--- a/juno_scheduler/src/ir.rs
+++ b/juno_scheduler/src/ir.rs
@@ -13,6 +13,7 @@ pub enum Pass {
     ForkCoalesce,
     ForkDimMerge,
     ForkFissionBufferize,
+    ForkFusion,
     ForkGuardElim,
     ForkInterchange,
     ForkSplit,
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index b26d1720..a4783a93 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -2526,6 +2526,27 @@ fn run_pass(
                 fields: new_fork_joins,
             };
         }
+        Pass::ForkFusion => {
+            assert!(args.is_empty());
+            pm.make_fork_join_maps();
+            pm.make_nodes_in_fork_joins();
+            let fork_join_maps = pm.fork_join_maps.take().unwrap();
+            let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap();
+            for ((func, fork_join_map), nodes_in_fork_joins) in
+                build_selection(pm, selection, false)
+                    .into_iter()
+                    .zip(fork_join_maps.iter())
+                    .zip(nodes_in_fork_joins.iter())
+            {
+                let Some(mut func) = func else {
+                    continue;
+                };
+                fork_fusion_all_forks(&mut func, fork_join_map, nodes_in_fork_joins);
+                changed |= func.modified();
+            }
+            pm.delete_gravestones();
+            pm.clear_analyses();
+        }
         Pass::ForkDimMerge => {
             assert!(args.is_empty());
             pm.make_fork_join_maps();
-- 
GitLab


From bb50794136a039f29236fb70385ac1abbdd4d1c3 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Fri, 14 Feb 2025 17:31:03 -0600
Subject: [PATCH 14/15] Crop the small example image to round dimensions

---
 juno_samples/cava/src/image_proc.rs | 37 +++++++++++++++++++----------
 juno_samples/cava/src/main.rs       | 13 +++++++++-
 2 files changed, 37 insertions(+), 13 deletions(-)

diff --git a/juno_samples/cava/src/image_proc.rs b/juno_samples/cava/src/image_proc.rs
index 94a7c14f..7115bac8 100644
--- a/juno_samples/cava/src/image_proc.rs
+++ b/juno_samples/cava/src/image_proc.rs
@@ -17,41 +17,54 @@ fn read_bin_u32(f: &mut File) -> Result<u32, Error> {
     Ok(u32::from_le_bytes(buffer))
 }
 
-pub fn read_raw<P: AsRef<Path>>(image: P) -> Result<RawImage, Error> {
+pub fn read_raw<P: AsRef<Path>>(
+    image: P,
+    crop_rows: Option<usize>,
+    crop_cols: Option<usize>,
+) -> Result<RawImage, Error> {
     let mut file = File::open(image)?;
-    let rows = read_bin_u32(&mut file)? as usize;
-    let cols = read_bin_u32(&mut file)? as usize;
+    let in_rows = read_bin_u32(&mut file)? as usize;
+    let in_cols = read_bin_u32(&mut file)? as usize;
+    let out_rows = crop_rows.unwrap_or(in_rows);
+    let out_cols = crop_cols.unwrap_or(in_cols);
     let chans = read_bin_u32(&mut file)? as usize;
 
     assert!(
         chans == CHAN,
         "Channel size read from the binary file doesn't match the default value"
     );
+    assert!(in_rows >= out_rows);
+    assert!(in_cols >= out_cols);
 
-    let size: usize = rows * cols * CHAN;
-    let mut buffer: Vec<u8> = vec![0; size];
+    let in_size: usize = in_rows * in_cols * CHAN;
+    let mut buffer: Vec<u8> = vec![0; in_size];
     // The file has pixels is a 3D array with dimensions height, width, channel
     file.read_exact(&mut buffer)?;
 
     let chunked_channels = buffer.chunks(CHAN).collect::<Vec<_>>();
-    let chunked = chunked_channels.chunks(cols).collect::<Vec<_>>();
+    let chunked = chunked_channels.chunks(in_cols).collect::<Vec<_>>();
     let input: &[&[&[u8]]] = chunked.as_slice();
 
     // We need the image in a 3D array with dimensions channel, height, width
-    let mut pixels: Vec<u8> = vec![0; size];
-    let mut pixels_columns = pixels.chunks_mut(cols).collect::<Vec<_>>();
-    let mut pixels_chunked = pixels_columns.chunks_mut(rows).collect::<Vec<_>>();
+    let out_size: usize = out_rows * out_cols * CHAN;
+    let mut pixels: Vec<u8> = vec![0; out_size];
+    let mut pixels_columns = pixels.chunks_mut(out_cols).collect::<Vec<_>>();
+    let mut pixels_chunked = pixels_columns.chunks_mut(out_rows).collect::<Vec<_>>();
     let result: &mut [&mut [&mut [u8]]] = pixels_chunked.as_mut_slice();
 
-    for row in 0..rows {
-        for col in 0..cols {
+    for row in 0..out_rows {
+        for col in 0..out_cols {
             for chan in 0..CHAN {
                 result[chan][row][col] = input[row][col][chan];
             }
         }
     }
 
-    Ok(RawImage { rows, cols, pixels })
+    Ok(RawImage {
+        rows: out_rows,
+        cols: out_cols,
+        pixels,
+    })
 }
 
 pub fn extern_image(rows: usize, cols: usize, image: &[u8]) -> RgbImage {
diff --git a/juno_samples/cava/src/main.rs b/juno_samples/cava/src/main.rs
index 700b74e6..142ed703 100644
--- a/juno_samples/cava/src/main.rs
+++ b/juno_samples/cava/src/main.rs
@@ -147,6 +147,10 @@ struct CavaInputs {
     #[clap(long = "output-verify", value_name = "PATH")]
     output_verify: Option<String>,
     cam_model: String,
+    #[clap(long)]
+    crop_rows: Option<usize>,
+    #[clap(long)]
+    crop_cols: Option<usize>,
 }
 
 fn cava_harness(args: CavaInputs) {
@@ -156,8 +160,11 @@ fn cava_harness(args: CavaInputs) {
         verify,
         output_verify,
         cam_model,
+        crop_rows,
+        crop_cols,
     } = args;
-    let RawImage { rows, cols, pixels } = read_raw(input).expect("Error loading image");
+    let RawImage { rows, cols, pixels } =
+        read_raw(input, crop_rows, crop_cols).expect("Error loading image");
     let CamModel {
         tstw,
         num_ctrl_pts,
@@ -236,6 +243,8 @@ fn cava_test_small() {
         verify: true,
         output_verify: None,
         cam_model: "cam_models/NikonD7000".to_string(),
+        crop_rows: Some(144),
+        crop_cols: Some(192),
     });
 }
 
@@ -248,5 +257,7 @@ fn cava_test_full() {
         verify: true,
         output_verify: None,
         cam_model: "cam_models/NikonD7000".to_string(),
+        crop_rows: None,
+        crop_cols: None,
     });
 }
-- 
GitLab


From a3ec1d56780b9afc79bbf8fadc7268620cef6e9a Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Fri, 14 Feb 2025 17:35:49 -0600
Subject: [PATCH 15/15] Reorganize cpu schedule

---
 juno_samples/cava/src/cpu.sch | 22 ++++++++++++++--------
 1 file changed, 14 insertions(+), 8 deletions(-)

diff --git a/juno_samples/cava/src/cpu.sch b/juno_samples/cava/src/cpu.sch
index 0b8869d5..623dcfbc 100644
--- a/juno_samples/cava/src/cpu.sch
+++ b/juno_samples/cava/src/cpu.sch
@@ -38,8 +38,6 @@ fixpoint {
 }
 simpl!(fuse1);
 array-slf(fuse1);
-simpl!(fuse1);
-unforkify(fuse1);
 
 inline(fuse2);
 no-memset(fuse2@res);
@@ -78,8 +76,6 @@ sroa(*);
 array-slf(fuse2);
 write-predication(fuse2);
 simpl!(fuse2);
-fork-split(fuse2);
-unforkify(fuse2);
 
 no-memset(fuse3@res);
 fixpoint {
@@ -88,8 +84,6 @@ fixpoint {
   fork-coalesce(fuse3);
 }
 simpl!(fuse3);
-fork-split(fuse3);
-unforkify(fuse3);
 
 no-memset(fuse4@res);
 no-memset(fuse4@l2);
@@ -107,8 +101,6 @@ fixpoint {
 simpl!(fuse4);
 array-slf(fuse4);
 simpl!(fuse4);
-fork-split(fuse4);
-unforkify(fuse4);
 
 no-memset(fuse5@res1);
 no-memset(fuse5@res2);
@@ -120,9 +112,23 @@ fixpoint {
 simpl!(fuse5);
 array-slf(fuse5);
 simpl!(fuse5);
+
+delete-uncalled(*);
+simpl!(*);
+//xdot[true](*);
+
+simpl!(fuse1);
+unforkify(fuse1);
+fork-split(fuse2);
+unforkify(fuse2);
+fork-split(fuse3);
+unforkify(fuse3);
+fork-split(fuse4);
+unforkify(fuse4);
 fork-split(fuse5);
 unforkify(fuse5);
 
 simpl!(*);
+
 delete-uncalled(*);
 gcm(*);
-- 
GitLab