diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index 12b91194749b6125418b3639fa1d29d6b0397fc1..9a16c99c7d3decceff04cc7a4ce8b13a149f96ad 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -676,50 +676,47 @@ pub fn fork_coalesce_helper(
     // CHECKME / FIXME: Might need to be added the other way.
     new_factors.append(&mut inner_dims.to_vec());
 
-    for tid in inner_tids {
-        let (fork, dim) = editor.func().nodes[tid.idx()].try_thread_id().unwrap();
-        let new_tid = Node::ThreadID {
-            control: fork,
-            dimension: dim + num_outer_dims,
-        };
+    let mut new_fork = NodeID::new(0);
+    let new_join = inner_join; // We'll reuse the inner join as the join of the new fork
+
+    let success = editor.edit(|mut edit| {
+        for tid in inner_tids {
+            let (fork, dim) = edit.get_node(tid).try_thread_id().unwrap();
+            let new_tid = Node::ThreadID {
+                control: fork,
+                dimension: dim + num_outer_dims,
+            };
 
-        editor.edit(|mut edit| {
             let new_tid = edit.add_node(new_tid);
-            let mut edit = edit.replace_all_uses(tid, new_tid)?;
+            edit = edit.replace_all_uses(tid, new_tid)?;
             edit.sub_edit(tid, new_tid);
-            Ok(edit)
-        });
-    }
-
-    // Fuse Reductions
-    for (outer_reduce, inner_reduce) in pairs {
-        let (_, outer_init, _) = editor.func().nodes[outer_reduce.idx()]
-            .try_reduce()
-            .unwrap();
-        let (_, inner_init, _) = editor.func().nodes[inner_reduce.idx()]
-            .try_reduce()
-            .unwrap();
-        let success = editor.edit(|mut edit| {
+        }
+        // Fuse Reductions
+        for (outer_reduce, inner_reduce) in pairs {
+            let (_, outer_init, _) = edit.get_node(outer_reduce)
+                .try_reduce()
+                .unwrap();
+            let (_, inner_init, _) = edit.get_node(inner_reduce)
+                .try_reduce()
+                .unwrap();
             // Set inner init to outer init.
             edit =
                 edit.replace_all_uses_where(inner_init, outer_init, |usee| *usee == inner_reduce )?;
             edit = edit.replace_all_uses(outer_reduce, inner_reduce)?;
             edit = edit.delete_node(outer_reduce)?;
+        }
 
-            Ok(edit)
-        });
-    }
-
-    let mut new_fork = NodeID::new(0);
-    let new_join = inner_join; // We reuse the inner join as the join of the new fork
-
-    editor.edit(|mut edit| {
         let new_fork_node = Node::Fork {
             control: outer_pred,
             factors: new_factors.into(),
         };
         new_fork = edit.add_node(new_fork_node);
 
+        if edit.get_schedule(outer_fork).contains(&Schedule::ParallelFork)
+            && edit.get_schedule(inner_fork).contains(&Schedule::ParallelFork) {
+            edit = edit.add_schedule(new_fork, Schedule::ParallelFork)?;
+        }
+
         edit = edit.replace_all_uses(inner_fork, new_fork)?;
         edit = edit.replace_all_uses(outer_fork, new_fork)?;
         edit = edit.replace_all_uses(outer_join, inner_join)?;
@@ -730,7 +727,11 @@ pub fn fork_coalesce_helper(
         Ok(edit)
     });
 
-    Some((new_fork, new_join))
+    if success {
+        Some((new_fork, new_join))
+    } else {
+        None
+    }
 }
 
 pub fn split_any_fork(
@@ -1277,9 +1278,14 @@ pub fn fork_interchange(
             edit = edit.delete_node(old_id)?;
         }
         let new_fork = edit.add_node(new_fork);
-        new_fork_id = Some(new_fork);
+        if edit.get_schedule(fork).contains(&Schedule::ParallelFork) {
+            edit = edit.add_schedule(new_fork, Schedule::ParallelFork)?;
+        }
         edit = edit.replace_all_uses(fork, new_fork)?;
-        edit.delete_node(fork)
+        edit = edit.delete_node(fork)?;
+
+        new_fork_id = Some(new_fork);
+        Ok(edit)
     });
 
     new_fork_id