From e936110be361a33ecd4e6f3756ca94a44dd9a341 Mon Sep 17 00:00:00 2001
From: rarbore2 <rarbore2@illinois.edu>
Date: Fri, 14 Feb 2025 19:12:18 -0600
Subject: [PATCH] Some Cava optimization

---
 hercules_cg/src/rt.rs               |   2 +-
 hercules_ir/src/dot.rs              |  13 ++-
 hercules_ir/src/einsum.rs           |  31 +++++++
 hercules_opt/src/fork_transforms.rs | 120 ++++++++++++++++++++++---
 hercules_opt/src/utils.rs           |   7 ++
 juno_samples/cava/build.rs          |   2 +
 juno_samples/cava/src/cava.jn       |  97 ++++++++++----------
 juno_samples/cava/src/cpu.sch       | 134 ++++++++++++++++++++++++++++
 juno_samples/cava/src/image_proc.rs |  37 +++++---
 juno_samples/cava/src/main.rs       |  83 ++++++++++-------
 juno_scheduler/src/compile.rs       |   3 +
 juno_scheduler/src/ir.rs            |  12 ++-
 juno_scheduler/src/lang.l           |   4 +-
 juno_scheduler/src/pm.rs            |  24 +++++
 14 files changed, 457 insertions(+), 112 deletions(-)
 create mode 100644 juno_samples/cava/src/cpu.sch

diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs
index 090253d4..f758fed3 100644
--- a/hercules_cg/src/rt.rs
+++ b/hercules_cg/src/rt.rs
@@ -533,7 +533,7 @@ impl<'a> RTContext<'a> {
                 {
                     write!(block, "backing_{}.byte_add(", device.name())?;
                     self.codegen_dynamic_constant(offset, block)?;
-                    write!(block, "), ")?
+                    write!(block, " as usize), ")?
                 }
                 for dc in dynamic_constants {
                     self.codegen_dynamic_constant(*dc, block)?;
diff --git a/hercules_ir/src/dot.rs b/hercules_ir/src/dot.rs
index 9c6c5f17..0e084085 100644
--- a/hercules_ir/src/dot.rs
+++ b/hercules_ir/src/dot.rs
@@ -24,8 +24,8 @@ pub fn xdot_module(
     bbs: Option<&Vec<BasicBlocks>>,
 ) {
     let mut tmp_path = temp_dir();
-    let mut rng = rand::thread_rng();
-    let num: u64 = rng.gen();
+    let mut rng = rand::rng();
+    let num: u64 = rng.random();
     tmp_path.push(format!("hercules_dot_{}.dot", num));
     let mut file = File::create(&tmp_path).expect("PANIC: Unable to open output file.");
     let mut contents = String::new();
@@ -43,11 +43,11 @@ pub fn xdot_module(
     .expect("PANIC: Unable to generate output file contents.");
     file.write_all(contents.as_bytes())
         .expect("PANIC: Unable to write output file contents.");
+    println!("Graphviz written to: {}", tmp_path.display());
     Command::new("xdot")
         .args([&tmp_path])
         .output()
         .expect("PANIC: Couldn't execute xdot. Is xdot installed?");
-    println!("Graphviz written to: {}", tmp_path.display());
 }
 
 /*
@@ -108,6 +108,13 @@ pub fn write_dot<W: Write>(
 
         for node_id in (0..function.nodes.len()).map(NodeID::new) {
             let node = &function.nodes[node_id.idx()];
+            let skip = node.is_constant()
+                || node.is_dynamic_constant()
+                || node.is_undef()
+                || node.is_parameter();
+            if skip {
+                continue;
+            }
             let dst_control = node.is_control();
             for u in def_use::get_uses(&node).as_ref() {
                 let src_control = function.nodes[u.idx()].is_control();
diff --git a/hercules_ir/src/einsum.rs b/hercules_ir/src/einsum.rs
index 5dd0fe5d..b222e1bc 100644
--- a/hercules_ir/src/einsum.rs
+++ b/hercules_ir/src/einsum.rs
@@ -34,6 +34,7 @@ pub enum MathExpr {
     Unary(UnaryOperator, MathID),
     Binary(BinaryOperator, MathID, MathID),
     Ternary(TernaryOperator, MathID, MathID, MathID),
+    IntrinsicFunc(Intrinsic, Box<[MathID]>),
 }
 
 pub type MathEnv = Vec<MathExpr>;
@@ -224,6 +225,16 @@ impl<'a> EinsumContext<'a> {
                 let third = self.compute_math_expr(third);
                 MathExpr::Ternary(op, first, second, third)
             }
+            Node::IntrinsicCall {
+                intrinsic,
+                ref args,
+            } => {
+                let args = args
+                    .into_iter()
+                    .map(|id| self.compute_math_expr(*id))
+                    .collect();
+                MathExpr::IntrinsicFunc(intrinsic, args)
+            }
             Node::Read {
                 collect,
                 ref indices,
@@ -322,6 +333,14 @@ impl<'a> EinsumContext<'a> {
                 let third = self.substitute_new_dims(third);
                 self.intern_math_expr(MathExpr::Ternary(op, first, second, third))
             }
+            MathExpr::IntrinsicFunc(intrinsic, ref args) => {
+                let args = args
+                    .clone()
+                    .iter()
+                    .map(|id| self.substitute_new_dims(*id))
+                    .collect();
+                self.intern_math_expr(MathExpr::IntrinsicFunc(intrinsic, args))
+            }
             _ => id,
         }
     }
@@ -355,6 +374,9 @@ pub fn opaque_nodes_in_expr(env: &MathEnv, id: MathID) -> HashSet<NodeID> {
                 stack.push(second);
                 stack.push(third);
             }
+            MathExpr::IntrinsicFunc(_, ref args) => {
+                stack.extend(args);
+            }
         }
     }
     set
@@ -412,5 +434,15 @@ pub fn debug_print_math_expr(id: MathID, env: &MathEnv) {
             debug_print_math_expr(third, env);
             print!(")");
         }
+        MathExpr::IntrinsicFunc(intrinsic, ref args) => {
+            print!("{}(", intrinsic.lower_case_name());
+            for (idx, arg) in args.iter().enumerate() {
+                if idx != 0 {
+                    print!(", ");
+                }
+                debug_print_math_expr(*arg, env);
+            }
+            print!(")");
+        }
     }
 }
diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index 72b5716f..c32a517e 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -912,16 +912,16 @@ pub fn chunk_fork_unguarded(
                 let outer = DynamicConstant::div(new_factors[dim_idx], tile_size);
                 new_factors.insert(dim_idx + 1, tile_size);
                 new_factors[dim_idx] = edit.add_dynamic_constant(outer);
-        
+
                 let new_fork = Node::Fork {
                     control: old_control,
                     factors: new_factors.into(),
                 };
                 let new_fork = edit.add_node(new_fork);
-        
+
                 edit = edit.replace_all_uses(fork, new_fork)?;
                 edit.sub_edit(fork, new_fork);
-        
+
                 for (tid, node) in fork_users {
                     let Node::ThreadID {
                         control: _,
@@ -945,7 +945,7 @@ pub fn chunk_fork_unguarded(
                             dimension: tid_dim + 1,
                         };
                         let tile_tid = edit.add_node(tile_tid);
-        
+
                         let tile_size = edit.add_node(Node::DynamicConstant { id: tile_size });
                         let mul = edit.add_node(Node::Binary {
                             left: tid,
@@ -965,23 +965,23 @@ pub fn chunk_fork_unguarded(
                 edit = edit.delete_node(fork)?;
                 Ok(edit)
             });
-        },
+        }
         TileOrder::TileOuter => {
             editor.edit(|mut edit| {
                 let inner = DynamicConstant::div(new_factors[dim_idx], tile_size);
                 new_factors.insert(dim_idx, tile_size);
-                let inner_dc_id =  edit.add_dynamic_constant(inner);
+                let inner_dc_id = edit.add_dynamic_constant(inner);
                 new_factors[dim_idx + 1] = inner_dc_id;
-        
+
                 let new_fork = Node::Fork {
                     control: old_control,
                     factors: new_factors.into(),
                 };
                 let new_fork = edit.add_node(new_fork);
-        
+
                 edit = edit.replace_all_uses(fork, new_fork)?;
                 edit.sub_edit(fork, new_fork);
-        
+
                 for (tid, node) in fork_users {
                     let Node::ThreadID {
                         control: _,
@@ -1000,13 +1000,12 @@ pub fn chunk_fork_unguarded(
                         edit.sub_edit(tid, new_tid);
                         edit = edit.delete_node(tid)?;
                     } else if tid_dim == dim_idx {
-
                         let tile_tid = Node::ThreadID {
                             control: new_fork,
                             dimension: tid_dim,
                         };
                         let tile_tid = edit.add_node(tile_tid);
-                        let inner_dc = edit.add_node(Node::DynamicConstant { id: inner_dc_id } );
+                        let inner_dc = edit.add_node(Node::DynamicConstant { id: inner_dc_id });
                         let mul = edit.add_node(Node::Binary {
                             left: tid,
                             right: inner_dc,
@@ -1027,7 +1026,6 @@ pub fn chunk_fork_unguarded(
             });
         }
     }
-    
 }
 
 pub fn merge_all_fork_dims(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>) {
@@ -1350,3 +1348,101 @@ pub fn fork_unroll(
         Ok(edit)
     })
 }
+
+/*
+ * Looks for fork-joins that are next to each other, not inter-dependent, and
+ * have the same bounds. These fork-joins can be fused, pooling together all
+ * their reductions.
+ */
+pub fn fork_fusion_all_forks(
+    editor: &mut FunctionEditor,
+    fork_join_map: &HashMap<NodeID, NodeID>,
+    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
+) {
+    for (fork, join) in fork_join_map {
+        if editor.is_mutable(*fork)
+            && fork_fusion(editor, *fork, *join, fork_join_map, nodes_in_fork_joins)
+        {
+            break;
+        }
+    }
+}
+
+/*
+ * Tries to fuse a given fork join with the immediately following fork-join, if
+ * it exists.
+ */
+fn fork_fusion(
+    editor: &mut FunctionEditor,
+    top_fork: NodeID,
+    top_join: NodeID,
+    fork_join_map: &HashMap<NodeID, NodeID>,
+    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
+) -> bool {
+    let nodes = &editor.func().nodes;
+    // Rust operator precedence is not such that these can be put in one big
+    // let-else statement. Sad!
+    let Some(bottom_fork) = editor
+        .get_users(top_join)
+        .filter(|id| nodes[id.idx()].is_control())
+        .next()
+    else {
+        return false;
+    };
+    let Some(bottom_join) = fork_join_map.get(&bottom_fork) else {
+        return false;
+    };
+    let (_, top_factors) = nodes[top_fork.idx()].try_fork().unwrap();
+    let (bottom_fork_pred, bottom_factors) = nodes[bottom_fork.idx()].try_fork().unwrap();
+    assert_eq!(bottom_fork_pred, top_join);
+    let top_join_pred = nodes[top_join.idx()].try_join().unwrap();
+    let bottom_join_pred = nodes[bottom_join.idx()].try_join().unwrap();
+
+    // The fork factors must be identical.
+    if top_factors != bottom_factors {
+        return false;
+    }
+
+    // Check that no iterated users of the top's reduces are in the bottom fork-
+    // join (iteration stops at a phi or reduce outside the bottom fork-join).
+    for reduce in editor
+        .get_users(top_join)
+        .filter(|id| nodes[id.idx()].is_reduce())
+    {
+        let mut visited = HashSet::new();
+        visited.insert(reduce);
+        let mut workset = vec![reduce];
+        while let Some(pop) = workset.pop() {
+            for u in editor.get_users(pop) {
+                if nodes_in_fork_joins[&bottom_fork].contains(&u) {
+                    return false;
+                } else if (nodes[u.idx()].is_phi() || nodes[u.idx()].is_reduce())
+                    && !nodes_in_fork_joins[&top_fork].contains(&u)
+                {
+                } else if !visited.contains(&u) && !nodes_in_fork_joins[&top_fork].contains(&u) {
+                    visited.insert(u);
+                    workset.push(u);
+                }
+            }
+        }
+    }
+
+    // Perform the fusion.
+    editor.edit(|mut edit| {
+        if bottom_join_pred != bottom_fork {
+            // If there is control flow in the bottom fork-join, stitch it into
+            // the top fork-join.
+            edit = edit.replace_all_uses_where(bottom_fork, top_join_pred, |id| {
+                nodes_in_fork_joins[&bottom_fork].contains(id)
+            })?;
+            edit =
+                edit.replace_all_uses_where(top_join_pred, bottom_join_pred, |id| *id == top_join)?;
+        }
+        // Replace the bottom fork and join with the top fork and join.
+        edit = edit.replace_all_uses(bottom_fork, top_fork)?;
+        edit = edit.replace_all_uses(*bottom_join, top_join)?;
+        edit = edit.delete_node(bottom_fork)?;
+        edit = edit.delete_node(*bottom_join)?;
+        Ok(edit)
+    })
+}
diff --git a/hercules_opt/src/utils.rs b/hercules_opt/src/utils.rs
index c5e0d934..3f12ad7c 100644
--- a/hercules_opt/src/utils.rs
+++ b/hercules_opt/src/utils.rs
@@ -474,6 +474,13 @@ pub fn materialize_simple_einsum_expr(
                 third,
             })
         }
+        MathExpr::IntrinsicFunc(intrinsic, ref args) => {
+            let args = args
+                .into_iter()
+                .map(|id| materialize_simple_einsum_expr(edit, *id, env, dim_substs))
+                .collect();
+            edit.add_node(Node::IntrinsicCall { intrinsic, args })
+        }
         _ => panic!(),
     }
 }
diff --git a/juno_samples/cava/build.rs b/juno_samples/cava/build.rs
index 1b6dddf4..ff7e2b6b 100644
--- a/juno_samples/cava/build.rs
+++ b/juno_samples/cava/build.rs
@@ -13,6 +13,8 @@ fn main() {
     JunoCompiler::new()
         .file_in_src("cava.jn")
         .unwrap()
+        .schedule_in_src("cpu.sch")
+        .unwrap()
         .build()
         .unwrap();
 }
diff --git a/juno_samples/cava/src/cava.jn b/juno_samples/cava/src/cava.jn
index 95a73f5b..8158bf0a 100644
--- a/juno_samples/cava/src/cava.jn
+++ b/juno_samples/cava/src/cava.jn
@@ -1,47 +1,35 @@
 fn medianMatrix<a : number, rows, cols : usize>(m : a[rows, cols]) -> a {
-  const n : usize = rows * cols;
-
-  let tmp : a[rows * cols];
-  for i = 0 to rows * cols {
-    tmp[i] = m[i / cols, i % cols];
-  }
-
-  for i = 0 to n - 1 {
-    for j = 0 to n - i - 1 {
-      if tmp[j] > tmp[j+1] {
-        let t : a = tmp[j];
-        tmp[j] = tmp[j+1];
-        tmp[j+1] = t;
+  @median {
+    const n : usize = rows * cols;
+    
+    @tmp let tmp : a[rows * cols];
+    for i = 0 to rows * cols {
+      tmp[i] = m[i / cols, i % cols];
+    }
+    
+    @medianOuter for i = 0 to n - 1 {
+      for j = 0 to n - i - 1 {
+        if tmp[j] > tmp[j+1] {
+          let t : a = tmp[j];
+          tmp[j] = tmp[j+1];
+          tmp[j+1] = t;
+        }
       }
     }
+    
+    return tmp[n / 2];
   }
-
-  return tmp[n / 2];
 }
 
 const CHAN : u64 = 3;
 
 fn scale<row : usize, col : usize>(input : u8[CHAN, row, col]) -> f32[CHAN, row, col] {
-  let res : f32[CHAN, row, col];
+  @res1 let res : f32[CHAN, row, col];
 
   for chan = 0 to CHAN {
     for r = 0 to row {
       for c = 0 to col {
-        res[chan, r, c] = input[chan, r, c] as f32 * 1.0 / 255;
-      }
-    }
-  }
-
-  return res;
-}
-
-fn descale<row : usize, col : usize>(input : f32[CHAN, row, col]) -> u8[CHAN, row, col] {
-  let res : u8[CHAN, row, col];
-
-  for chan = 0 to CHAN {
-    for r = 0 to row {
-      for c = 0 to col {
-        res[chan, r, c] = min!::<f32>(max!::<f32>(input[chan, r, c] * 255, 0), 255) as u8;
+        res[chan, r, c] = input[chan, r, c] as f32 / 255.0;
       }
     }
   }
@@ -50,7 +38,7 @@ fn descale<row : usize, col : usize>(input : f32[CHAN, row, col]) -> u8[CHAN, ro
 }
 
 fn demosaic<row : usize, col : usize>(input : f32[CHAN, row, col]) -> f32[CHAN, row, col] {
-  let res : f32[CHAN, row, col];
+  @res2 let res : f32[CHAN, row, col];
 
   for r = 1 to row-1 {
     for c = 1 to col-1 {
@@ -102,13 +90,13 @@ fn demosaic<row : usize, col : usize>(input : f32[CHAN, row, col]) -> f32[CHAN,
 }
 
 fn denoise<row : usize, col : usize>(input : f32[CHAN, row, col]) -> f32[CHAN, row, col] {
-  let res : f32[CHAN, row, col];
+  @res let res : f32[CHAN, row, col];
 
   for chan = 0 to CHAN {
     for r = 0 to row {
       for c = 0 to col {
         if r >= 1 && r < row - 1 && c >= 1 && c < col - 1 {
-          let filter : f32[3][3]; // same as [3, 3]
+          @filter let filter : f32[3][3]; // same as [3, 3]
           for i = 0 to 3 by 1 {
             for j = 0 to 3 by 1 {
               filter[i, j] = input[chan, r + i - 1, c + j - 1];
@@ -129,7 +117,7 @@ fn denoise<row : usize, col : usize>(input : f32[CHAN, row, col]) -> f32[CHAN, r
 fn transform<row : usize, col : usize>
   (input : f32[CHAN, row, col], tstw_trans : f32[CHAN, CHAN])
     -> f32[CHAN, row, col] {
-  let result : f32[CHAN, row, col];
+  @res let result : f32[CHAN, row, col];
 
   for chan = 0 to CHAN {
     for r = 0 to row {
@@ -152,11 +140,11 @@ fn gamut<row : usize, col : usize, num_ctrl_pts : usize>(
   weights  : f32[num_ctrl_pts, CHAN],
   coefs : f32[4, CHAN]
 ) -> f32[CHAN, row, col] {
-  let result : f32[CHAN, row, col];
-  let l2_dist : f32[num_ctrl_pts];
+  @res let result : f32[CHAN, row, col];
 
   for r = 0 to row {
     for c = 0 to col {
+      @l2 let l2_dist : f32[num_ctrl_pts];
       for cp = 0 to num_ctrl_pts {
         let v1 = input[0, r, c] - ctrl_pts[cp, 0];
         let v2 = input[1, r, c] - ctrl_pts[cp, 1];
@@ -164,8 +152,8 @@ fn gamut<row : usize, col : usize, num_ctrl_pts : usize>(
         let v  = v1 * v1 + v2 * v2 + v3 * v3;
         l2_dist[cp] = sqrt!::<f32>(v);
       }
-
-      for chan = 0 to CHAN {
+      
+      @channel_loop for chan = 0 to CHAN {
         let chan_val : f32 = 0.0;
         for cp = 0 to num_ctrl_pts {
           chan_val += l2_dist[cp] * weights[cp, chan];
@@ -184,7 +172,7 @@ fn gamut<row : usize, col : usize, num_ctrl_pts : usize>(
 
 fn tone_map<row : usize, col:usize>
   (input : f32[CHAN, row, col], tone_map : f32[256, CHAN]) -> f32[CHAN, row, col] {
-  let result : f32[CHAN, row, col];
+  @res1 let result : f32[CHAN, row, col];
 
   for chan = 0 to CHAN {
     for r = 0 to row {
@@ -198,6 +186,20 @@ fn tone_map<row : usize, col:usize>
   return result;
 }
 
+fn descale<row : usize, col : usize>(input : f32[CHAN, row, col]) -> u8[CHAN, row, col] {
+  @res2 let res : u8[CHAN, row, col];
+
+  for chan = 0 to CHAN {
+    for r = 0 to row {
+      for c = 0 to col {
+        res[chan, r, c] = min!::<f32>(max!::<f32>(input[chan, r, c] * 255, 0), 255) as u8;
+      }
+    }
+  }
+
+  return res;
+}
+
 #[entry]
 fn cava<r, c, num_ctrl_pts : usize>(
   input : u8[CHAN, r, c],
@@ -207,11 +209,12 @@ fn cava<r, c, num_ctrl_pts : usize>(
   coefs : f32[4, CHAN],
   tonemap : f32[256, CHAN],
 ) -> u8[CHAN, r, c] {
-  let scaled = scale::<r, c>(input);
-  let demosc = demosaic::<r, c>(scaled);
-  let denosd = denoise::<r, c>(demosc);
-  let transf = transform::<r, c>(denosd, TsTw);
-  let gamutd = gamut::<r, c, num_ctrl_pts>(transf, ctrl_pts, weights, coefs);
-  let tonemd = tone_map::<r, c>(gamutd, tonemap);
-  return descale::<r, c>(tonemd);
+  @fuse1 let scaled = scale::<r, c>(input);
+  @fuse1 let demosc = demosaic::<r, c>(scaled);
+  @fuse2 let denosd = denoise::<r, c>(demosc);
+  @fuse3 let transf = transform::<r, c>(denosd, TsTw);
+  @fuse4 let gamutd = gamut::<r, c, num_ctrl_pts>(transf, ctrl_pts, weights, coefs);
+  @fuse5 let tonemd = tone_map::<r, c>(gamutd, tonemap);
+  @fuse5 let dscald = descale::<r, c>(tonemd);
+  return dscald;
 }
diff --git a/juno_samples/cava/src/cpu.sch b/juno_samples/cava/src/cpu.sch
new file mode 100644
index 00000000..623dcfbc
--- /dev/null
+++ b/juno_samples/cava/src/cpu.sch
@@ -0,0 +1,134 @@
+macro simpl!(X) {
+  ccp(X);
+  simplify-cfg(X);
+  lift-dc-math(X);
+  gvn(X);
+  phi-elim(X);
+  dce(X);
+  infer-schedules(X);
+}
+
+simpl!(*);
+
+let fuse1 = outline(cava@fuse1);
+inline(fuse1);
+
+let fuse2 = outline(cava@fuse2);
+inline(fuse2);
+
+let fuse3 = outline(cava@fuse3);
+inline(fuse3);
+
+let fuse4 = outline(cava@fuse4);
+inline(fuse4);
+
+let fuse5 = outline(cava@fuse5);
+inline(fuse5);
+
+ip-sroa(*);
+sroa(*);
+simpl!(*);
+
+no-memset(fuse1@res1);
+no-memset(fuse1@res2);
+fixpoint {
+  forkify(fuse1);
+  fork-guard-elim(fuse1);
+  fork-coalesce(fuse1);
+}
+simpl!(fuse1);
+array-slf(fuse1);
+
+inline(fuse2);
+no-memset(fuse2@res);
+no-memset(fuse2@filter);
+no-memset(fuse2@tmp);
+fixpoint {
+  forkify(fuse2);
+  fork-guard-elim(fuse2);
+  fork-coalesce(fuse2);
+}
+simpl!(fuse2);
+predication(fuse2);
+simpl!(fuse2);
+
+let median = outline(fuse2@median);
+fork-unroll(median@medianOuter);
+simpl!(median);
+fixpoint {
+  forkify(median);
+  fork-guard-elim(median);
+}
+simpl!(median);
+fixpoint {
+  fork-unroll(median);
+}
+ccp(median);
+array-to-product(median);
+sroa(median);
+phi-elim(median);
+predication(median);
+simpl!(median);
+
+inline(fuse2);
+ip-sroa(*);
+sroa(*);
+array-slf(fuse2);
+write-predication(fuse2);
+simpl!(fuse2);
+
+no-memset(fuse3@res);
+fixpoint {
+  forkify(fuse3);
+  fork-guard-elim(fuse3);
+  fork-coalesce(fuse3);
+}
+simpl!(fuse3);
+
+no-memset(fuse4@res);
+no-memset(fuse4@l2);
+fixpoint {
+  forkify(fuse4);
+  fork-guard-elim(fuse4);
+  fork-coalesce(fuse4);
+}
+simpl!(fuse4);
+fork-unroll(fuse4@channel_loop);
+simpl!(fuse4);
+fixpoint {
+  fork-fusion(fuse4@channel_loop);
+}
+simpl!(fuse4);
+array-slf(fuse4);
+simpl!(fuse4);
+
+no-memset(fuse5@res1);
+no-memset(fuse5@res2);
+fixpoint {
+  forkify(fuse5);
+  fork-guard-elim(fuse5);
+  fork-coalesce(fuse5);
+}
+simpl!(fuse5);
+array-slf(fuse5);
+simpl!(fuse5);
+
+delete-uncalled(*);
+simpl!(*);
+xdot[true](*);
+
+simpl!(fuse1);
+unforkify(fuse1);
+fork-split(fuse2);
+unforkify(fuse2);
+fork-split(fuse3);
+unforkify(fuse3);
+fork-split(fuse4);
+unforkify(fuse4);
+fork-split(fuse5);
+unforkify(fuse5);
+
+simpl!(*);
+
+delete-uncalled(*);
+gcm(*);
diff --git a/juno_samples/cava/src/image_proc.rs b/juno_samples/cava/src/image_proc.rs
index 94a7c14f..7115bac8 100644
--- a/juno_samples/cava/src/image_proc.rs
+++ b/juno_samples/cava/src/image_proc.rs
@@ -17,41 +17,54 @@ fn read_bin_u32(f: &mut File) -> Result<u32, Error> {
     Ok(u32::from_le_bytes(buffer))
 }
 
-pub fn read_raw<P: AsRef<Path>>(image: P) -> Result<RawImage, Error> {
+pub fn read_raw<P: AsRef<Path>>(
+    image: P,
+    crop_rows: Option<usize>,
+    crop_cols: Option<usize>,
+) -> Result<RawImage, Error> {
     let mut file = File::open(image)?;
-    let rows = read_bin_u32(&mut file)? as usize;
-    let cols = read_bin_u32(&mut file)? as usize;
+    let in_rows = read_bin_u32(&mut file)? as usize;
+    let in_cols = read_bin_u32(&mut file)? as usize;
+    let out_rows = crop_rows.unwrap_or(in_rows);
+    let out_cols = crop_cols.unwrap_or(in_cols);
     let chans = read_bin_u32(&mut file)? as usize;
 
     assert!(
         chans == CHAN,
         "Channel size read from the binary file doesn't match the default value"
     );
+    assert!(in_rows >= out_rows);
+    assert!(in_cols >= out_cols);
 
-    let size: usize = rows * cols * CHAN;
-    let mut buffer: Vec<u8> = vec![0; size];
+    let in_size: usize = in_rows * in_cols * CHAN;
+    let mut buffer: Vec<u8> = vec![0; in_size];
     // The file has pixels is a 3D array with dimensions height, width, channel
     file.read_exact(&mut buffer)?;
 
     let chunked_channels = buffer.chunks(CHAN).collect::<Vec<_>>();
-    let chunked = chunked_channels.chunks(cols).collect::<Vec<_>>();
+    let chunked = chunked_channels.chunks(in_cols).collect::<Vec<_>>();
     let input: &[&[&[u8]]] = chunked.as_slice();
 
     // We need the image in a 3D array with dimensions channel, height, width
-    let mut pixels: Vec<u8> = vec![0; size];
-    let mut pixels_columns = pixels.chunks_mut(cols).collect::<Vec<_>>();
-    let mut pixels_chunked = pixels_columns.chunks_mut(rows).collect::<Vec<_>>();
+    let out_size: usize = out_rows * out_cols * CHAN;
+    let mut pixels: Vec<u8> = vec![0; out_size];
+    let mut pixels_columns = pixels.chunks_mut(out_cols).collect::<Vec<_>>();
+    let mut pixels_chunked = pixels_columns.chunks_mut(out_rows).collect::<Vec<_>>();
     let result: &mut [&mut [&mut [u8]]] = pixels_chunked.as_mut_slice();
 
-    for row in 0..rows {
-        for col in 0..cols {
+    for row in 0..out_rows {
+        for col in 0..out_cols {
             for chan in 0..CHAN {
                 result[chan][row][col] = input[row][col][chan];
             }
         }
     }
 
-    Ok(RawImage { rows, cols, pixels })
+    Ok(RawImage {
+        rows: out_rows,
+        cols: out_cols,
+        pixels,
+    })
 }
 
 pub fn extern_image(rows: usize, cols: usize, image: &[u8]) -> RgbImage {
diff --git a/juno_samples/cava/src/main.rs b/juno_samples/cava/src/main.rs
index b4a0f6fd..142ed703 100644
--- a/juno_samples/cava/src/main.rs
+++ b/juno_samples/cava/src/main.rs
@@ -19,24 +19,31 @@ juno_build::juno!("cava");
 // Individual lifetimes are not needed in this example but should probably be generated for
 // flexibility
 async fn safe_run<'a, 'b: 'a, 'c: 'a, 'd: 'a, 'e: 'a, 'f: 'a, 'g: 'a>(
-    runner: &'a mut HerculesRunner_cava, r: u64, c: u64, num_ctrl_pts: u64,
-    input: &'b HerculesImmBox<'b, u8>, tstw: &'c HerculesImmBox<'c, f32>,
-    ctrl_pts: &'d HerculesImmBox<'d, f32>, weights: &'e HerculesImmBox<'e, f32>,
-    coefs: &'f HerculesImmBox<'f, f32>, tonemap: &'g HerculesImmBox<'g, f32>,
+    runner: &'a mut HerculesRunner_cava,
+    r: u64,
+    c: u64,
+    num_ctrl_pts: u64,
+    input: &'b HerculesImmBox<'b, u8>,
+    tstw: &'c HerculesImmBox<'c, f32>,
+    ctrl_pts: &'d HerculesImmBox<'d, f32>,
+    weights: &'e HerculesImmBox<'e, f32>,
+    coefs: &'f HerculesImmBox<'f, f32>,
+    tonemap: &'g HerculesImmBox<'g, f32>,
 ) -> HerculesMutBox<'a, u8> {
     HerculesMutBox::from(
-        runner.run(
-            r,
-            c,
-            num_ctrl_pts,
-            input.to(),
-            tstw.to(),
-            ctrl_pts.to(),
-            weights.to(),
-            coefs.to(),
-            tonemap.to()
-        )
-        .await
+        runner
+            .run(
+                r,
+                c,
+                num_ctrl_pts,
+                input.to(),
+                tstw.to(),
+                ctrl_pts.to(),
+                weights.to(),
+                coefs.to(),
+                tonemap.to(),
+            )
+            .await,
     )
 }
 
@@ -68,16 +75,17 @@ fn run_cava(
     let mut r = runner!(cava);
 
     async_std::task::block_on(async {
-        safe_run(&mut r,
-                 rows as u64,
-                 cols as u64,
-                 num_ctrl_pts as u64,
-                 &image,
-                 &tstw,
-                 &ctrl_pts,
-                 &weights,
-                 &coefs,
-                 &tonemap,
+        safe_run(
+            &mut r,
+            rows as u64,
+            cols as u64,
+            num_ctrl_pts as u64,
+            &image,
+            &tstw,
+            &ctrl_pts,
+            &weights,
+            &coefs,
+            &tonemap,
         )
         .await
     })
@@ -139,6 +147,10 @@ struct CavaInputs {
     #[clap(long = "output-verify", value_name = "PATH")]
     output_verify: Option<String>,
     cam_model: String,
+    #[clap(short, long)]
+    crop_rows: Option<usize>,
+    #[clap(short, long)]
+    crop_cols: Option<usize>,
 }
 
 fn cava_harness(args: CavaInputs) {
@@ -148,8 +160,11 @@ fn cava_harness(args: CavaInputs) {
         verify,
         output_verify,
         cam_model,
+        crop_rows,
+        crop_cols,
     } = args;
-    let RawImage { rows, cols, pixels } = read_raw(input).expect("Error loading image");
+    let RawImage { rows, cols, pixels } =
+        read_raw(input, crop_rows, crop_cols).expect("Error loading image");
     let CamModel {
         tstw,
         num_ctrl_pts,
@@ -159,6 +174,10 @@ fn cava_harness(args: CavaInputs) {
         tonemap,
     } = load_cam_model(cam_model, CHAN).expect("Error loading camera model");
 
+    println!(
+        "Running cava with {} rows, {} columns, and {} control points.",
+        rows, cols, num_ctrl_pts
+    );
     let result = run_cava(
         rows,
         cols,
@@ -224,19 +243,21 @@ fn cava_test_small() {
         verify: true,
         output_verify: None,
         cam_model: "cam_models/NikonD7000".to_string(),
+        crop_rows: Some(144),
+        crop_cols: Some(192),
     });
 }
 
-// Disabling the larger test because of how long it takes
-/*
 #[test]
-fn cava_test() {
+#[ignore]
+fn cava_test_full() {
     cava_harness(CavaInputs {
         input: "examples/raw_tulips.bin".to_string(),
         output: None,
         verify: true,
         output_verify: None,
         cam_model: "cam_models/NikonD7000".to_string(),
+        crop_rows: None,
+        crop_cols: None,
     });
 }
-*/
diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs
index 0652f8f2..7c92e00d 100644
--- a/juno_scheduler/src/compile.rs
+++ b/juno_scheduler/src/compile.rs
@@ -129,6 +129,7 @@ impl FromStr for Appliable {
             "fork-interchange" => Ok(Appliable::Pass(ir::Pass::ForkInterchange)),
             "fork-chunk" | "fork-tile" => Ok(Appliable::Pass(ir::Pass::ForkChunk)),
             "fork-unroll" | "unroll" => Ok(Appliable::Pass(ir::Pass::ForkUnroll)),
+            "fork-fusion" | "fusion" => Ok(Appliable::Pass(ir::Pass::ForkFusion)),
             "lift-dc-math" => Ok(Appliable::Pass(ir::Pass::LiftDCMath)),
             "outline" => Ok(Appliable::Pass(ir::Pass::Outline)),
             "phi-elim" => Ok(Appliable::Pass(ir::Pass::PhiElim)),
@@ -146,6 +147,8 @@ impl FromStr for Appliable {
             "serialize" => Ok(Appliable::Pass(ir::Pass::Serialize)),
             "write-predication" => Ok(Appliable::Pass(ir::Pass::WritePredication)),
 
+            "print" => Ok(Appliable::Pass(ir::Pass::Print)),
+
             "cpu" | "llvm" => Ok(Appliable::Device(Device::LLVM)),
             "gpu" | "cuda" | "nvidia" => Ok(Appliable::Device(Device::CUDA)),
             "host" | "rust" | "rust-async" => Ok(Appliable::Device(Device::AsyncRust)),
diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs
index e92f1d37..205cd70b 100644
--- a/juno_scheduler/src/ir.rs
+++ b/juno_scheduler/src/ir.rs
@@ -13,6 +13,7 @@ pub enum Pass {
     ForkCoalesce,
     ForkDimMerge,
     ForkFissionBufferize,
+    ForkFusion,
     ForkGuardElim,
     ForkInterchange,
     ForkSplit,
@@ -27,6 +28,7 @@ pub enum Pass {
     Outline,
     PhiElim,
     Predication,
+    Print,
     ReduceSLF,
     Rename,
     ReuseProducts,
@@ -44,11 +46,12 @@ impl Pass {
     pub fn is_valid_num_args(&self, num: usize) -> bool {
         match self {
             Pass::ArrayToProduct => num == 0 || num == 1,
-            Pass::Rename => num == 1,
-            Pass::Xdot => num == 0 || num == 1,
             Pass::ForkChunk => num == 4,
             Pass::ForkFissionBufferize => num == 2,
             Pass::ForkInterchange => num == 2,
+            Pass::Print => num == 1,
+            Pass::Rename => num == 1,
+            Pass::Xdot => num == 0 || num == 1,
             _ => num == 0,
         }
     }
@@ -56,11 +59,12 @@ impl Pass {
     pub fn valid_arg_nums(&self) -> &'static str {
         match self {
             Pass::ArrayToProduct => "0 or 1",
-            Pass::Rename => "1",
-            Pass::Xdot => "0 or 1",
             Pass::ForkChunk => "4",
             Pass::ForkFissionBufferize => "2",
             Pass::ForkInterchange => "2",
+            Pass::Print => "1",
+            Pass::Rename => "1",
+            Pass::Xdot => "0 or 1",
             _ => "0",
         }
     }
diff --git a/juno_scheduler/src/lang.l b/juno_scheduler/src/lang.l
index afe596b2..ca75276e 100644
--- a/juno_scheduler/src/lang.l
+++ b/juno_scheduler/src/lang.l
@@ -43,8 +43,8 @@ panic[\t \n\r]+after "panic_after"
 print[\t \n\r]+iter  "print_iter"
 stop[\t \n\r]+after  "stop_after"
 
-[a-zA-Z][a-zA-Z0-9_\-]*! "MACRO"
-[a-zA-Z][a-zA-Z0-9_\-]*  "ID"
+[a-zA-Z_][a-zA-Z0-9_\-]*! "MACRO"
+[a-zA-Z_][a-zA-Z0-9_\-]*  "ID"
 [0-9]+                   "INT"
 \"[a-zA-Z0-9_\-\s\.]*\"  "STRING"
 
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 8c2ecb19..a4783a93 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -2526,6 +2526,27 @@ fn run_pass(
                 fields: new_fork_joins,
             };
         }
+        Pass::ForkFusion => {
+            assert!(args.is_empty());
+            pm.make_fork_join_maps();
+            pm.make_nodes_in_fork_joins();
+            let fork_join_maps = pm.fork_join_maps.take().unwrap();
+            let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap();
+            for ((func, fork_join_map), nodes_in_fork_joins) in
+                build_selection(pm, selection, false)
+                    .into_iter()
+                    .zip(fork_join_maps.iter())
+                    .zip(nodes_in_fork_joins.iter())
+            {
+                let Some(mut func) = func else {
+                    continue;
+                };
+                fork_fusion_all_forks(&mut func, fork_join_map, nodes_in_fork_joins);
+                changed |= func.modified();
+            }
+            pm.delete_gravestones();
+            pm.clear_analyses();
+        }
         Pass::ForkDimMerge => {
             assert!(args.is_empty());
             pm.make_fork_join_maps();
@@ -2657,6 +2678,9 @@ fn run_pass(
             // Put BasicBlocks back, since it's needed for Codegen.
             pm.bbs = bbs;
         }
+        Pass::Print => {
+            println!("{:?}", args.get(0));
+        }
     }
     println!("Ran Pass: {:?}", pass);
 
-- 
GitLab