diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index 1df5338ddb3a4fa1a3ff65db27a6c9571dd21bb0..7367a63e51fad5dc3fb5dfb5de0e2426f46b41e3 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -310,23 +310,33 @@ pub fn fork_fission<'a>(
     reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
     loop_tree: &LoopTree,
     fork_join_map: &HashMap<NodeID, NodeID>,
+    fork_label: LabelID,
 ) -> Vec<NodeID> {
-    let mut forks: Vec<_> = loop_tree
+    let forks: Vec<_> = loop_tree
         .bottom_up_loops()
         .into_iter()
         .filter(|(k, _)| editor.func().nodes[k.idx()].is_fork())
         .collect();
 
     let mut created_forks = Vec::new();
+    
     // This does the reduction fission 
-    for fork in forks.clone() {
+    for fork in forks {
         let join = fork_join_map[&fork.0];
+
+        // FIXME: Don't split reduces that are in a cycle with each other into separate forks.
         let reduce_partition = default_reduce_partition(editor, fork.0, join);
 
+        if !editor.func().labels[fork.0.idx()].contains(&fork_label) {
+            continue;
+        }
+
         if editor.is_mutable(fork.0) {
             created_forks = fork_reduce_fission_helper(editor, fork_join_map, reduce_partition, nodes_in_fork_joins, fork.0);
-            if !created_forks.is_empty() {
-                break;
+            if created_forks.is_empty() {
+                continue;
+            } else {
+                return created_forks;
             }
         }
             
diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch
index 71a6eb928fd54afd7f6e79e41f1f611bda8ed5ca..a3bb146b3f8d456083daeb4ba62b8bb5369b7a71 100644
--- a/juno_samples/fork_join_tests/src/cpu.sch
+++ b/juno_samples/fork_join_tests/src/cpu.sch
@@ -98,15 +98,11 @@ dce(auto.test8);
 no-memset(test9@const);
 
 fork-split(auto.test10);
-xdot[true](auto.test10);
-fork-fission-reduces(auto.test10);
-dce(auto.test10);
-xdot[false](auto.test10);
+let fission_out = fork-fission-reduces[test10@loop3](auto.test10);
+fork-fusion(fission_out.test10_8.fj0, fission_out.test10_8.fj1);
 dce(auto.test10);
 simplify-cfg(auto.test10);
 dce(auto.test10);
-xdot[true](auto.test10);
-
 fork-split(auto.test10);
 
 unforkify(auto.test10);
diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch
index 81dc8d9854776931f4598a9010837008796baaf8..0d828f489cea8251b0b008efb1a3ddfae857547f 100644
--- a/juno_samples/fork_join_tests/src/gpu.sch
+++ b/juno_samples/fork_join_tests/src/gpu.sch
@@ -13,7 +13,7 @@ gvn(*);
 phi-elim(*);
 dce(*);
 
-let auto = auto-outline(test1, test2, test3, test4, test5, test7, test8, test9);
+let auto = auto-outline(test1, test2, test3, test4, test5, test7, test8, test9, test10);
 gpu(auto.test1);
 gpu(auto.test2);
 gpu(auto.test3);
@@ -22,6 +22,7 @@ gpu(auto.test5);
 gpu(auto.test7);
 gpu(auto.test8);
 gpu(auto.test9);
+gpu(auto.test10);
 
 ip-sroa(*);
 sroa(*);
@@ -75,6 +76,12 @@ dce(auto.test8);
 
 no-memset(test9@const);
 
+fork-split(auto.test10);
+fork-fission-reduces[test10@loop3](auto.test10);
+dce(auto.test10);
+simplify-cfg(auto.test10);
+dce(auto.test10);
+
 ip-sroa(*);
 sroa(*);
 dce(*);
diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs
index d381e861c323f5d1326e26cccb2cf916a0bf735f..dd7a76a7da6b58bb5bc723116e84318b4b6346dc 100644
--- a/juno_scheduler/src/ir.rs
+++ b/juno_scheduler/src/ir.rs
@@ -60,6 +60,7 @@ impl Pass {
             Pass::ForkChunk => num == 4,
             Pass::ForkExtend => num == 1,
             Pass::ForkFissionBufferize => num == 2 || num == 1,
+            Pass::ForkFission => num == 1,
             Pass::ForkInterchange => num == 2,
             Pass::ForkReshape => true,
             Pass::InterproceduralSROA => num == 0 || num == 1,
@@ -78,6 +79,7 @@ impl Pass {
             Pass::ForkChunk => "4",
             Pass::ForkExtend => "1",
             Pass::ForkFissionBufferize => "1 or 2",
+            Pass::ForkFission => "1",
             Pass::ForkInterchange => "2",
             Pass::ForkReshape => "any",
             Pass::InterproceduralSROA => "0 or 1",
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 9281925a96dda6fbad44c746bf824206f13c34d7..666ebdcd726ff3347fedac45471d20e33ed9d995 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -2742,7 +2742,18 @@ fn run_pass(
             pm.clear_analyses();
         },
         Pass::ForkFission => {
-            assert!(args.len() == 0);
+            assert!(args.len() == 1);
+
+            let Some(Value::Label {
+                labels: fork_labels,
+            }) = args.get(0)
+            else {
+                return Err(SchedulerError::PassError {
+                    pass: "forkFission".to_string(),
+                    error: "expected label argument".to_string(),
+                });
+            };
+
 
             pm.make_fork_join_maps();
             pm.make_loops();
@@ -2755,6 +2766,17 @@ fn run_pass(
 
             let mut created_fork_joins = vec![vec![]; pm.functions.len()];
 
+            // Assert that at most one function is in the selection.
+            let num_functions = build_selection(pm, selection.clone(), false)
+                .iter()
+                .filter(|func| func.is_some())
+                .count();
+
+            assert!(num_functions <= 1);
+            assert_eq!(fork_labels.len(), 1);
+
+            let fork_label = fork_labels[0].label;
+
             for ((((func, fork_join_map), loop_tree), reduce_cycles), nodes_in_fork_joins,
             ) in build_selection(pm, selection, false)
                 .into_iter()
@@ -2773,6 +2795,7 @@ fn run_pass(
                     reduce_cycles,
                     loop_tree,
                     fork_join_map,
+                    fork_label,
                 );
                 
                 let created_fork_joins = &mut created_fork_joins[func.func_id().idx()];
@@ -2810,7 +2833,7 @@ fn run_pass(
                 // level of the split fork-joins being referred to.
                 let mut func_record = HashMap::new();
                 for (idx, label) in labels {
-                    let fmt = if idx % 2 == 0 { "fj_top" } else { "fj_bottom" };
+                    let fmt = format!("fj{idx}");
                     func_record.insert(
                         fmt.to_string(),
                         Value::Label {
@@ -2874,14 +2897,6 @@ fn run_pass(
 
             let fork_label = fork_labels[0].label;
 
-            // assert only one function is in the selection.
-            let num_functions = build_selection(pm, selection.clone(), false)
-            .iter()
-            .filter(|func| func.is_some())
-            .count();
-        
-            assert!(num_functions <= 1);
-
             for (
                 ((((func, fork_join_map), loop_tree), typing), reduce_cycles),
                 nodes_in_fork_joins,