diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index 72b5716f6ab56b260cc1d0d0cb8c78fa2da60d42..c32a517e46ab52a69f13feacb94382a8edaaf731 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -912,16 +912,16 @@ pub fn chunk_fork_unguarded(
                 let outer = DynamicConstant::div(new_factors[dim_idx], tile_size);
                 new_factors.insert(dim_idx + 1, tile_size);
                 new_factors[dim_idx] = edit.add_dynamic_constant(outer);
-        
+
                 let new_fork = Node::Fork {
                     control: old_control,
                     factors: new_factors.into(),
                 };
                 let new_fork = edit.add_node(new_fork);
-        
+
                 edit = edit.replace_all_uses(fork, new_fork)?;
                 edit.sub_edit(fork, new_fork);
-        
+
                 for (tid, node) in fork_users {
                     let Node::ThreadID {
                         control: _,
@@ -945,7 +945,7 @@ pub fn chunk_fork_unguarded(
                             dimension: tid_dim + 1,
                         };
                         let tile_tid = edit.add_node(tile_tid);
-        
+
                         let tile_size = edit.add_node(Node::DynamicConstant { id: tile_size });
                         let mul = edit.add_node(Node::Binary {
                             left: tid,
@@ -965,23 +965,23 @@ pub fn chunk_fork_unguarded(
                 edit = edit.delete_node(fork)?;
                 Ok(edit)
             });
-        },
+        }
         TileOrder::TileOuter => {
             editor.edit(|mut edit| {
                 let inner = DynamicConstant::div(new_factors[dim_idx], tile_size);
                 new_factors.insert(dim_idx, tile_size);
-                let inner_dc_id =  edit.add_dynamic_constant(inner);
+                let inner_dc_id = edit.add_dynamic_constant(inner);
                 new_factors[dim_idx + 1] = inner_dc_id;
-        
+
                 let new_fork = Node::Fork {
                     control: old_control,
                     factors: new_factors.into(),
                 };
                 let new_fork = edit.add_node(new_fork);
-        
+
                 edit = edit.replace_all_uses(fork, new_fork)?;
                 edit.sub_edit(fork, new_fork);
-        
+
                 for (tid, node) in fork_users {
                     let Node::ThreadID {
                         control: _,
@@ -1000,13 +1000,12 @@ pub fn chunk_fork_unguarded(
                         edit.sub_edit(tid, new_tid);
                         edit = edit.delete_node(tid)?;
                     } else if tid_dim == dim_idx {
-
                         let tile_tid = Node::ThreadID {
                             control: new_fork,
                             dimension: tid_dim,
                         };
                         let tile_tid = edit.add_node(tile_tid);
-                        let inner_dc = edit.add_node(Node::DynamicConstant { id: inner_dc_id } );
+                        let inner_dc = edit.add_node(Node::DynamicConstant { id: inner_dc_id });
                         let mul = edit.add_node(Node::Binary {
                             left: tid,
                             right: inner_dc,
@@ -1027,7 +1026,6 @@ pub fn chunk_fork_unguarded(
             });
         }
     }
-    
 }
 
 pub fn merge_all_fork_dims(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>) {
@@ -1350,3 +1348,101 @@ pub fn fork_unroll(
         Ok(edit)
     })
 }
+
+/*
+ * Looks for fork-joins that are next to each other, not inter-dependent, and
+ * have the same bounds. These fork-joins can be fused, pooling together all
+ * their reductions.
+ */
+pub fn fork_fusion_all_forks(
+    editor: &mut FunctionEditor,
+    fork_join_map: &HashMap<NodeID, NodeID>,
+    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
+) {
+    for (fork, join) in fork_join_map {
+        if editor.is_mutable(*fork)
+            && fork_fusion(editor, *fork, *join, fork_join_map, nodes_in_fork_joins)
+        {
+            break;
+        }
+    }
+}
+
+/*
+ * Tries to fuse a given fork join with the immediately following fork-join, if
+ * it exists.
+ */
+fn fork_fusion(
+    editor: &mut FunctionEditor,
+    top_fork: NodeID,
+    top_join: NodeID,
+    fork_join_map: &HashMap<NodeID, NodeID>,
+    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
+) -> bool {
+    let nodes = &editor.func().nodes;
+    // Rust operator precedence is not such that these can be put in one big
+    // let-else statement. Sad!
+    let Some(bottom_fork) = editor
+        .get_users(top_join)
+        .filter(|id| nodes[id.idx()].is_control())
+        .next()
+    else {
+        return false;
+    };
+    let Some(bottom_join) = fork_join_map.get(&bottom_fork) else {
+        return false;
+    };
+    let (_, top_factors) = nodes[top_fork.idx()].try_fork().unwrap();
+    let (bottom_fork_pred, bottom_factors) = nodes[bottom_fork.idx()].try_fork().unwrap();
+    assert_eq!(bottom_fork_pred, top_join);
+    let top_join_pred = nodes[top_join.idx()].try_join().unwrap();
+    let bottom_join_pred = nodes[bottom_join.idx()].try_join().unwrap();
+
+    // The fork factors must be identical.
+    if top_factors != bottom_factors {
+        return false;
+    }
+
+    // Check that no iterated users of the top's reduces are in the bottom fork-
+    // join (iteration stops at a phi or reduce outside the bottom fork-join).
+    for reduce in editor
+        .get_users(top_join)
+        .filter(|id| nodes[id.idx()].is_reduce())
+    {
+        let mut visited = HashSet::new();
+        visited.insert(reduce);
+        let mut workset = vec![reduce];
+        while let Some(pop) = workset.pop() {
+            for u in editor.get_users(pop) {
+                if nodes_in_fork_joins[&bottom_fork].contains(&u) {
+                    return false;
+                } else if (nodes[u.idx()].is_phi() || nodes[u.idx()].is_reduce())
+                    && !nodes_in_fork_joins[&top_fork].contains(&u)
+                {
+                } else if !visited.contains(&u) && !nodes_in_fork_joins[&top_fork].contains(&u) {
+                    visited.insert(u);
+                    workset.push(u);
+                }
+            }
+        }
+    }
+
+    // Perform the fusion.
+    editor.edit(|mut edit| {
+        if bottom_join_pred != bottom_fork {
+            // If there is control flow in the bottom fork-join, stitch it into
+            // the top fork-join.
+            edit = edit.replace_all_uses_where(bottom_fork, top_join_pred, |id| {
+                nodes_in_fork_joins[&bottom_fork].contains(id)
+            })?;
+            edit =
+                edit.replace_all_uses_where(top_join_pred, bottom_join_pred, |id| *id == top_join)?;
+        }
+        // Replace the bottom fork and join with the top fork and join.
+        edit = edit.replace_all_uses(bottom_fork, top_fork)?;
+        edit = edit.replace_all_uses(*bottom_join, top_join)?;
+        edit = edit.delete_node(bottom_fork)?;
+        edit = edit.delete_node(*bottom_join)?;
+        Ok(edit)
+    })
+}
diff --git a/juno_samples/cava/src/cava.jn b/juno_samples/cava/src/cava.jn
index e6961faf01bc90114186e3ea72b6b9671e2758d4..8158bf0a9bec25d57e1136bbf040ae0f561dcc3f 100644
--- a/juno_samples/cava/src/cava.jn
+++ b/juno_samples/cava/src/cava.jn
@@ -152,7 +152,7 @@ fn gamut<row : usize, col : usize, num_ctrl_pts : usize>(
         let v  = v1 * v1 + v2 * v2 + v3 * v3;
         l2_dist[cp] = sqrt!::<f32>(v);
       }
-
+      
       @channel_loop for chan = 0 to CHAN {
         let chan_val : f32 = 0.0;
         for cp = 0 to num_ctrl_pts {
diff --git a/juno_samples/cava/src/cpu.sch b/juno_samples/cava/src/cpu.sch
index 1ae1dc132942ff868a25080d297a25f06c8395a2..0b8869d55de67ab8e796397bc25d5c15cc926d59 100644
--- a/juno_samples/cava/src/cpu.sch
+++ b/juno_samples/cava/src/cpu.sch
@@ -101,7 +101,9 @@ fixpoint {
 simpl!(fuse4);
 fork-unroll(fuse4@channel_loop);
 simpl!(fuse4);
-//fork-fusion(fuse4@channel_loop);
+fixpoint {
+  fork-fusion(fuse4@channel_loop);
+}
 simpl!(fuse4);
 array-slf(fuse4);
 simpl!(fuse4);
diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs
index 912cc91f7f5fd968374a50d304dfb4e5cce1654f..7c92e00d93dc100c6aec313c68ec17e42f7c44e9 100644
--- a/juno_scheduler/src/compile.rs
+++ b/juno_scheduler/src/compile.rs
@@ -129,6 +129,7 @@ impl FromStr for Appliable {
             "fork-interchange" => Ok(Appliable::Pass(ir::Pass::ForkInterchange)),
             "fork-chunk" | "fork-tile" => Ok(Appliable::Pass(ir::Pass::ForkChunk)),
             "fork-unroll" | "unroll" => Ok(Appliable::Pass(ir::Pass::ForkUnroll)),
+            "fork-fusion" | "fusion" => Ok(Appliable::Pass(ir::Pass::ForkFusion)),
             "lift-dc-math" => Ok(Appliable::Pass(ir::Pass::LiftDCMath)),
             "outline" => Ok(Appliable::Pass(ir::Pass::Outline)),
             "phi-elim" => Ok(Appliable::Pass(ir::Pass::PhiElim)),
diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs
index 8ad923242a3b73f8b9379f1f9b51070245d2030c..205cd70b710dc810a89e88a3683fde4a9f695e5b 100644
--- a/juno_scheduler/src/ir.rs
+++ b/juno_scheduler/src/ir.rs
@@ -13,6 +13,7 @@ pub enum Pass {
     ForkCoalesce,
     ForkDimMerge,
     ForkFissionBufferize,
+    ForkFusion,
     ForkGuardElim,
     ForkInterchange,
     ForkSplit,
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index b26d1720fc192378a214e3951d15e957c317b77f..a4783a93701564f041237f2fecef8a22c62abe5c 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -2526,6 +2526,27 @@ fn run_pass(
                 fields: new_fork_joins,
             };
         }
+        Pass::ForkFusion => {
+            assert!(args.is_empty());
+            pm.make_fork_join_maps();
+            pm.make_nodes_in_fork_joins();
+            let fork_join_maps = pm.fork_join_maps.take().unwrap();
+            let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap();
+            for ((func, fork_join_map), nodes_in_fork_joins) in
+                build_selection(pm, selection, false)
+                    .into_iter()
+                    .zip(fork_join_maps.iter())
+                    .zip(nodes_in_fork_joins.iter())
+            {
+                let Some(mut func) = func else {
+                    continue;
+                };
+                fork_fusion_all_forks(&mut func, fork_join_map, nodes_in_fork_joins);
+                changed |= func.modified();
+            }
+            pm.delete_gravestones();
+            pm.clear_analyses();
+        }
         Pass::ForkDimMerge => {
             assert!(args.is_empty());
             pm.make_fork_join_maps();