diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index ff0f0283767996914e8f9b2274ed9a6d538b1812..cb0e7de48cb0efe01a586f9ae633f66cb61703a7 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -578,7 +578,7 @@ pub fn fork_coalesce(
     // FIXME: This could give us two forks that aren't actually ancestors / related, but then the helper will just return false early.
     // something like: `fork_joins.postorder_iter().windows(2)` is ideal here.
     for (inner, outer) in fork_joins.iter().cartesian_product(fork_joins.iter()) {
-        if fork_coalesce_helper(editor, *outer, *inner, fork_join_map) {
+        if fork_coalesce_helper(editor, *outer, *inner, fork_join_map).is_some() {
             return true;
         }
     }
@@ -587,13 +587,15 @@ pub fn fork_coalesce(
 
 /** Opposite of fork split, takes two fork-joins
     with no control between them, and merges them into a single fork-join.
+    Returns None if the forks could not be merged and the NodeIDs of the
+    resulting fork and join if it succeeds in merging them.
 */
 pub fn fork_coalesce_helper(
     editor: &mut FunctionEditor,
     outer_fork: NodeID,
     inner_fork: NodeID,
     fork_join_map: &HashMap<NodeID, NodeID>,
-) -> bool {
+) -> Option<(NodeID, NodeID)> {
     // Check that all reduces in the outer fork are in *simple* cycles with a unique reduce of the inner fork.
 
     let outer_join = fork_join_map[&outer_fork];
@@ -621,19 +623,19 @@ pub fn fork_coalesce_helper(
             reduct: _,
         } = inner_reduce_node
         else {
-            return false;
+            return None;
         };
 
         // FIXME: check this condition better (i.e reduce might not be attached to join)
         if *inner_control != inner_join {
-            return false;
+            return None;
         };
         if *inner_init != outer_reduce {
-            return false;
+            return None;
         };
 
         if pairs.contains_left(&outer_reduce) || pairs.contains_right(&inner_reduce) {
-            return false;
+            return None;
         } else {
             pairs.insert(outer_reduce, inner_reduce);
         }
@@ -645,11 +647,11 @@ pub fn fork_coalesce_helper(
         .filter(|node| editor.func().nodes[node.idx()].is_control())
         .next()
     else {
-        return false;
+        return None;
     };
 
     if user != inner_fork {
-        return false;
+        return None;
     }
 
     let Some(user) = editor
@@ -657,11 +659,11 @@ pub fn fork_coalesce_helper(
         .filter(|node| editor.func().nodes[node.idx()].is_control())
         .next()
     else {
-        return false;
+        return None;
     };
 
     if user != outer_join {
-        return false;
+        return None;
     }
 
     // Checklist:
@@ -709,10 +711,10 @@ pub fn fork_coalesce_helper(
         let (_, inner_init, _) = editor.func().nodes[inner_reduce.idx()]
             .try_reduce()
             .unwrap();
-        editor.edit(|mut edit| {
+        let success = editor.edit(|mut edit| {
             // Set inner init to outer init.
             edit =
-                edit.replace_all_uses_where(inner_init, outer_init, |usee| *usee == inner_reduce)?;
+                edit.replace_all_uses_where(inner_init, outer_init, |usee| *usee == inner_reduce )?;
             edit = edit.replace_all_uses(outer_reduce, inner_reduce)?;
             edit = edit.delete_node(outer_reduce)?;
 
@@ -720,12 +722,15 @@ pub fn fork_coalesce_helper(
         });
     }
 
+    let mut new_fork = NodeID::new(0);
+    let new_join = inner_join; // We reuse the inner join as the join of the new fork
+
     editor.edit(|mut edit| {
-        let new_fork = Node::Fork {
+        let new_fork_node = Node::Fork {
             control: outer_pred,
             factors: new_factors.into(),
         };
-        let new_fork = edit.add_node(new_fork);
+        new_fork = edit.add_node(new_fork_node);
 
         edit = edit.replace_all_uses(inner_fork, new_fork)?;
         edit = edit.replace_all_uses(outer_fork, new_fork)?;
@@ -737,7 +742,7 @@ pub fn fork_coalesce_helper(
         Ok(edit)
     });
 
-    true
+    Some((new_fork, new_join))
 }
 
 pub fn split_any_fork(
@@ -760,7 +765,7 @@ pub fn split_any_fork(
  * Useful for code generation. A single iteration of `fork_split` only splits
  * at most one fork-join, it must be called repeatedly to split all fork-joins.
  */
-pub(crate) fn split_fork(
+pub fn split_fork(
     editor: &mut FunctionEditor,
     fork: NodeID,
     join: NodeID,
@@ -1215,13 +1220,13 @@ pub fn fork_interchange_all_forks(
     }
 }
 
-fn fork_interchange(
+pub fn fork_interchange(
     editor: &mut FunctionEditor,
     fork: NodeID,
     join: NodeID,
     first_dim: usize,
     second_dim: usize,
-) {
+) -> Option<NodeID> {
     // Check that every reduce on the join is parallel or associative.
     let nodes = &editor.func().nodes;
     let schedules = &editor.func().schedules;
@@ -1234,7 +1239,7 @@ fn fork_interchange(
         })
     {
         // If not, we can't necessarily do interchange.
-        return;
+        return None;
     }
 
     let Node::Fork {
@@ -1276,6 +1281,7 @@ fn fork_interchange(
     let mut factors = factors.clone();
     factors.swap(first_dim, second_dim);
     let new_fork = Node::Fork { control, factors };
+    let mut new_fork_id = None;
     editor.edit(|mut edit| {
         for (old_id, new_tid) in fix_tids {
             let new_id = edit.add_node(new_tid);
@@ -1283,9 +1289,12 @@ fn fork_interchange(
             edit = edit.delete_node(old_id)?;
         }
         let new_fork = edit.add_node(new_fork);
+        new_fork_id = Some(new_fork);
         edit = edit.replace_all_uses(fork, new_fork)?;
         edit.delete_node(fork)
     });
+
+    new_fork_id
 }
 
 /*