From 52c0eb8a37c28f91e9ad6c431bbfdf9242631256 Mon Sep 17 00:00:00 2001
From: rarbore2 <rarbore2@illinois.edu>
Date: Sun, 2 Feb 2025 16:54:06 -0600
Subject: [PATCH] More fork tests, bug fixes in fork passes and GPU backend

---
 Cargo.lock                                    |  10 +
 Cargo.toml                                    |   1 +
 hercules_cg/src/gpu.rs                        |  85 +++--
 hercules_ir/src/fork_join_analysis.rs         |  34 +-
 hercules_ir/src/ir.rs                         |   2 +-
 hercules_ir/src/loops.rs                      |   4 +
 hercules_opt/src/float_collections.rs         |  27 +-
 hercules_opt/src/fork_concat_split.rs         | 141 --------
 hercules_opt/src/fork_transforms.rs           | 155 ++++++++
 hercules_opt/src/gcm.rs                       |  77 +++-
 hercules_opt/src/lib.rs                       |   2 -
 hercules_opt/src/unforkify.rs                 | 335 +++++++++---------
 juno_samples/fork_join_tests/Cargo.toml       |  21 ++
 juno_samples/fork_join_tests/build.rs         |  24 ++
 juno_samples/fork_join_tests/src/cpu.sch      |  42 +++
 .../fork_join_tests/src/fork_join_tests.jn    |  61 ++++
 juno_samples/fork_join_tests/src/gpu.sch      |  33 ++
 juno_samples/fork_join_tests/src/main.rs      |  46 +++
 juno_scheduler/src/pm.rs                      |  26 +-
 19 files changed, 752 insertions(+), 374 deletions(-)
 delete mode 100644 hercules_opt/src/fork_concat_split.rs
 create mode 100644 juno_samples/fork_join_tests/Cargo.toml
 create mode 100644 juno_samples/fork_join_tests/build.rs
 create mode 100644 juno_samples/fork_join_tests/src/cpu.sch
 create mode 100644 juno_samples/fork_join_tests/src/fork_join_tests.jn
 create mode 100644 juno_samples/fork_join_tests/src/gpu.sch
 create mode 100644 juno_samples/fork_join_tests/src/main.rs

diff --git a/Cargo.lock b/Cargo.lock
index b8bf2278..af7902c6 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1130,6 +1130,16 @@ dependencies = [
  "with_builtin_macros",
 ]
 
+[[package]]
+name = "juno_fork_join_tests"
+version = "0.1.0"
+dependencies = [
+ "async-std",
+ "hercules_rt",
+ "juno_build",
+ "with_builtin_macros",
+]
+
 [[package]]
 name = "juno_frontend"
 version = "0.1.0"
diff --git a/Cargo.toml b/Cargo.toml
index f7b9322a..890d7924 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -31,4 +31,5 @@ members = [
 	"juno_samples/concat",
 	"juno_samples/schedule_test",
 	"juno_samples/edge_detection",
+	"juno_samples/fork_join_tests",
 ]
diff --git a/hercules_cg/src/gpu.rs b/hercules_cg/src/gpu.rs
index 81e31396..1086d1aa 100644
--- a/hercules_cg/src/gpu.rs
+++ b/hercules_cg/src/gpu.rs
@@ -323,7 +323,10 @@ impl GPUContext<'_> {
 
         // Emit all GPU kernel code from previous steps
         let mut kernel_body = String::new();
-        self.codegen_gotos(false, &mut gotos, &mut kernel_body)?;
+        let rev_po = self.control_subgraph.rev_po(NodeID::new(0));
+        write!(w, "\n")?;
+        self.codegen_goto(false, &mut gotos, NodeID::new(0), &mut kernel_body)?;
+        self.codegen_gotos(false, &mut gotos, &rev_po, NodeID::new(0), &mut kernel_body)?;
         write!(w, "{}", kernel_body)?;
         write!(w, "}}\n")?;
 
@@ -536,27 +539,62 @@ namespace cg = cooperative_groups;
         &self,
         goto_debug: bool,
         gotos: &mut BTreeMap<NodeID, CudaGoto>,
+        rev_po: &Vec<NodeID>,
+        root: NodeID,
         w: &mut String,
     ) -> Result<(), Error> {
-        write!(w, "\n")?;
-        for (id, goto) in gotos.iter() {
-            let goto_block = self.get_block_name(*id, false);
-            write!(w, "{}:\n", goto_block)?;
-            if goto_debug {
-                write!(w, "\tprintf(\"goto {}\\n\");\n", goto_block)?;
-            }
-            write!(w, "{}", goto.init)?;
-            if !goto.post_init.is_empty() {
-                let goto_block = self.get_block_name(*id, true);
-                write!(w, "{}:\n", goto_block)?;
-                write!(w, "{}", goto.post_init)?;
+        // Print the blocks in a kind of silly way to avoid errors aroun
+        // initialization of fork variables and gotos.
+        let mut not_forks = vec![];
+        let mut forks = vec![];
+        let not_fork_controls = &self.fork_control_map[&root];
+        for bb in rev_po
+            .into_iter()
+            .filter(|id| not_fork_controls.contains(id) && **id != root)
+        {
+            not_forks.push(*bb);
+        }
+        if let Some(fork_controls) = &self.fork_tree.get(&root) {
+            for bb in rev_po
+                .into_iter()
+                .filter(|id| fork_controls.contains(id) && **id != root)
+            {
+                forks.push(*bb);
             }
-            write!(w, "{}", goto.body)?;
-            write!(w, "{}\n", goto.term)?;
+        }
+        for id in not_forks {
+            self.codegen_goto(goto_debug, gotos, id, w)?;
+        }
+        for root in forks {
+            self.codegen_goto(goto_debug, gotos, root, w)?;
+            self.codegen_gotos(goto_debug, gotos, rev_po, root, w)?;
         }
         Ok(())
     }
 
+    fn codegen_goto(
+        &self,
+        goto_debug: bool,
+        gotos: &mut BTreeMap<NodeID, CudaGoto>,
+        bb: NodeID,
+        w: &mut String,
+    ) -> Result<(), Error> {
+        let goto = &gotos[&bb];
+        let goto_block = self.get_block_name(bb, false);
+        write!(w, "{}:\n", goto_block)?;
+        if goto_debug {
+            write!(w, "\tprintf(\"goto {}\\n\");\n", goto_block)?;
+        }
+        write!(w, "{}", goto.init)?;
+        if !goto.post_init.is_empty() {
+            let goto_block = self.get_block_name(bb, true);
+            write!(w, "{}:\n", goto_block)?;
+            write!(w, "{}", goto.post_init)?;
+        }
+        write!(w, "{}", goto.body)?;
+        write!(w, "{}\n", goto.term)
+    }
+
     fn codegen_launch_code(
         &self,
         num_blocks: String,
@@ -620,23 +658,23 @@ extern \"C\" {} {}(",
             write!(pass_args, "ret")?;
             write!(w, "\tcudaMalloc((void**)&ret, sizeof({}));\n", ret_type)?;
         }
-        write!(w, "\tcudaError_t err;\n");
+        write!(w, "\tcudaError_t err;\n")?;
         write!(
             w,
             "\t{}_gpu<<<{}, {}, {}>>>({});\n",
             self.function.name, num_blocks, num_threads, dynamic_shared_offset, pass_args
         )?;
-        write!(w, "\terr = cudaGetLastError();\n");
+        write!(w, "\terr = cudaGetLastError();\n")?;
         write!(
             w,
             "\tif (cudaSuccess != err) {{ printf(\"Error1: %s\\n\", cudaGetErrorString(err)); }}\n"
-        );
+        )?;
         write!(w, "\tcudaDeviceSynchronize();\n")?;
-        write!(w, "\terr = cudaGetLastError();\n");
+        write!(w, "\terr = cudaGetLastError();\n")?;
         write!(
             w,
             "\tif (cudaSuccess != err) {{ printf(\"Error2: %s\\n\", cudaGetErrorString(err)); }}\n"
-        );
+        )?;
         if has_ret_var {
             // Copy return from device to host, whether it's primitive value or collection pointer
             write!(w, "\t{} host_ret;\n", ret_type)?;
@@ -1148,7 +1186,8 @@ extern \"C\" {} {}(",
             // for all threads. Otherwise, it can be inside or outside block fork.
             // If inside, it's stored in shared memory so we "allocate" it once
             // and parallelize memset to 0. If outside, we initialize as offset
-            // to backing, but if multi-block grid, don't memset to avoid grid-level sync.
+            // to backing, but if multi-block grid, don't memset to avoid grid-
+            // level sync.
             Node::Constant { id: cons_id } => {
                 let is_primitive = self.types[self.typing[id.idx()].idx()].is_primitive();
                 let cg_tile = match state {
@@ -1190,9 +1229,7 @@ extern \"C\" {} {}(",
                     )?;
                 }
                 if !is_primitive
-                    && (state != KernelState::OutBlock
-                        || is_block_parallel.is_none()
-                        || !is_block_parallel.unwrap())
+                    && (state != KernelState::OutBlock || !is_block_parallel.unwrap_or(false))
                 {
                     let data_size =
                         self.get_size(self.typing[id.idx()], None, Some(extra_dim_collects));
diff --git a/hercules_ir/src/fork_join_analysis.rs b/hercules_ir/src/fork_join_analysis.rs
index 263fa952..7a098a35 100644
--- a/hercules_ir/src/fork_join_analysis.rs
+++ b/hercules_ir/src/fork_join_analysis.rs
@@ -140,23 +140,23 @@ fn reduce_cycle_dfs_helper(
     }
 
     current_visited.insert(iter);
-    let found_reduce = get_uses(&function.nodes[iter.idx()])
-        .as_ref()
-        .into_iter()
-        .any(|u| {
-            !current_visited.contains(u)
-                && !function.nodes[u.idx()].is_control()
-                && isnt_outside_fork_join(*u)
-                && reduce_cycle_dfs_helper(
-                    function,
-                    *u,
-                    fork,
-                    reduce,
-                    current_visited,
-                    in_cycle,
-                    fork_join_nest,
-                )
-        });
+    let mut found_reduce = false;
+
+    // This doesn't short circuit on purpose.
+    for u in get_uses(&function.nodes[iter.idx()]).as_ref() {
+        found_reduce |= !current_visited.contains(u)
+            && !function.nodes[u.idx()].is_control()
+            && isnt_outside_fork_join(*u)
+            && reduce_cycle_dfs_helper(
+                function,
+                *u,
+                fork,
+                reduce,
+                current_visited,
+                in_cycle,
+                fork_join_nest,
+            )
+    }
     if found_reduce {
         in_cycle.insert(iter);
     }
diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs
index 846347b0..5c575ea1 100644
--- a/hercules_ir/src/ir.rs
+++ b/hercules_ir/src/ir.rs
@@ -1857,7 +1857,7 @@ pub fn indices_may_overlap(indices1: &[Index], indices2: &[Index]) -> bool {
  * list of indices B.
  */
 pub fn indices_contain_other_indices(indices_a: &[Index], indices_b: &[Index]) -> bool {
-    if indices_a.len() < indices_b.len() {
+    if indices_a.len() > indices_b.len() {
         return false;
     }
 
diff --git a/hercules_ir/src/loops.rs b/hercules_ir/src/loops.rs
index a425c442..daf85968 100644
--- a/hercules_ir/src/loops.rs
+++ b/hercules_ir/src/loops.rs
@@ -40,6 +40,10 @@ impl LoopTree {
         self.loops[&header].0.iter_ones().map(NodeID::new)
     }
 
+    pub fn nodes_in_loop_bitvec(&self, header: NodeID) -> &BitVec<u8, Lsb0> {
+        &self.loops[&header].0
+    }
+
     pub fn is_in_loop(&self, header: NodeID, is_in: NodeID) -> bool {
         header == self.root || self.loops[&header].0[is_in.idx()]
     }
diff --git a/hercules_opt/src/float_collections.rs b/hercules_opt/src/float_collections.rs
index faa38375..6ef050c2 100644
--- a/hercules_opt/src/float_collections.rs
+++ b/hercules_opt/src/float_collections.rs
@@ -1,3 +1,5 @@
+use std::collections::BTreeMap;
+
 use hercules_ir::*;
 
 use crate::*;
@@ -7,27 +9,36 @@ use crate::*;
  * allowed.
  */
 pub fn float_collections(
-    editors: &mut [FunctionEditor],
+    editors: &mut BTreeMap<FunctionID, FunctionEditor>,
     typing: &ModuleTyping,
     callgraph: &CallGraph,
     devices: &Vec<Device>,
 ) {
-    let topo = callgraph.topo();
+    let topo: Vec<_> = callgraph
+        .topo()
+        .into_iter()
+        .filter(|id| editors.contains_key(&id))
+        .collect();
     for to_float_id in topo {
         // Collection constants float until reaching an AsyncRust function.
         if devices[to_float_id.idx()] == Device::AsyncRust {
             continue;
         }
 
+        // Check that all callers are in the selection as well.
+        for caller in callgraph.get_callers(to_float_id) {
+            assert!(editors.contains_key(&caller), "PANIC: FloatCollections called where a function ({:?}, {:?}) is in the selection but one of its callers ({:?}) is not. This means no collections will be floated from the callee, since the caller can't be modified to hold floated collections.", to_float_id, editors[&to_float_id].func().name, caller);
+        }
+
         // Find the target constant nodes in the function.
-        let cons: Vec<(NodeID, Node)> = editors[to_float_id.idx()]
+        let cons: Vec<(NodeID, Node)> = editors[&to_float_id]
             .func()
             .nodes
             .iter()
             .enumerate()
             .filter(|(_, node)| {
                 node.try_constant()
-                    .map(|cons_id| !editors[to_float_id.idx()].get_constant(cons_id).is_scalar())
+                    .map(|cons_id| !editors[&to_float_id].get_constant(cons_id).is_scalar())
                     .unwrap_or(false)
             })
             .map(|(idx, node)| (NodeID::new(idx), node.clone()))
@@ -37,12 +48,12 @@ pub fn float_collections(
         }
 
         // Each constant node becomes a new parameter.
-        let mut new_param_types = editors[to_float_id.idx()].func().param_types.clone();
+        let mut new_param_types = editors[&to_float_id].func().param_types.clone();
         let old_num_params = new_param_types.len();
         for (id, _) in cons.iter() {
             new_param_types.push(typing[to_float_id.idx()][id.idx()]);
         }
-        let success = editors[to_float_id.idx()].edit(|mut edit| {
+        let success = editors.get_mut(&to_float_id).unwrap().edit(|mut edit| {
             for (idx, (id, _)) in cons.iter().enumerate() {
                 let param = edit.add_node(Node::Parameter {
                     index: idx + old_num_params,
@@ -59,7 +70,7 @@ pub fn float_collections(
 
         // Add constants in callers and pass them into calls.
         for caller in callgraph.get_callers(to_float_id) {
-            let calls: Vec<(NodeID, Node)> = editors[caller.idx()]
+            let calls: Vec<(NodeID, Node)> = editors[&caller]
                 .func()
                 .nodes
                 .iter()
@@ -71,7 +82,7 @@ pub fn float_collections(
                 })
                 .map(|(idx, node)| (NodeID::new(idx), node.clone()))
                 .collect();
-            let success = editors[caller.idx()].edit(|mut edit| {
+            let success = editors.get_mut(&caller).unwrap().edit(|mut edit| {
                 let cons_ids: Vec<_> = cons
                     .iter()
                     .map(|(_, node)| edit.add_node(node.clone()))
diff --git a/hercules_opt/src/fork_concat_split.rs b/hercules_opt/src/fork_concat_split.rs
deleted file mode 100644
index bb3a2cff..00000000
--- a/hercules_opt/src/fork_concat_split.rs
+++ /dev/null
@@ -1,141 +0,0 @@
-use std::collections::{HashMap, HashSet};
-use std::iter::zip;
-
-use hercules_ir::ir::*;
-
-use crate::*;
-
-/*
- * Split multi-dimensional fork-joins into separate one-dimensional fork-joins.
- * Useful for code generation. A single iteration of `fork_split` only splits
- * at most one fork-join, it must be called repeatedly to split all fork-joins.
- */
-pub fn fork_split(
-    editor: &mut FunctionEditor,
-    fork_join_map: &HashMap<NodeID, NodeID>,
-    reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
-) {
-    // A single multi-dimensional fork becomes multiple forks, a join becomes
-    // multiple joins, a thread ID becomes a thread ID on the correct
-    // fork, and a reduce becomes multiple reduces to shuffle the reduction
-    // value through the fork-join nest.
-    for (fork, join) in fork_join_map {
-        let nodes = &editor.func().nodes;
-        let (fork_control, factors) = nodes[fork.idx()].try_fork().unwrap();
-        if factors.len() < 2 {
-            continue;
-        }
-        let factors: Box<[DynamicConstantID]> = factors.into();
-        let join_control = nodes[join.idx()].try_join().unwrap();
-        let tids: Vec<_> = editor
-            .get_users(*fork)
-            .filter(|id| nodes[id.idx()].is_thread_id())
-            .collect();
-        let reduces: Vec<_> = editor
-            .get_users(*join)
-            .filter(|id| nodes[id.idx()].is_reduce())
-            .collect();
-
-        let data_in_reduce_cycle: HashSet<(NodeID, NodeID)> = reduces
-            .iter()
-            .map(|reduce| editor.get_users(*reduce).map(move |user| (user, *reduce)))
-            .flatten()
-            .filter(|(user, reduce)| reduce_cycles[&reduce].contains(&user))
-            .collect();
-
-        editor.edit(|mut edit| {
-            // Create the forks and a thread ID per fork.
-            let mut acc_fork = fork_control;
-            let mut new_tids = vec![];
-            for factor in factors {
-                acc_fork = edit.add_node(Node::Fork {
-                    control: acc_fork,
-                    factors: Box::new([factor]),
-                });
-                edit.sub_edit(*fork, acc_fork);
-                new_tids.push(edit.add_node(Node::ThreadID {
-                    control: acc_fork,
-                    dimension: 0,
-                }));
-            }
-
-            // Create the joins.
-            let mut acc_join = if join_control == *fork {
-                acc_fork
-            } else {
-                join_control
-            };
-            let mut joins = vec![];
-            for _ in new_tids.iter() {
-                acc_join = edit.add_node(Node::Join { control: acc_join });
-                edit.sub_edit(*join, acc_join);
-                joins.push(acc_join);
-            }
-
-            // Create the reduces.
-            let mut new_reduces = vec![];
-            for reduce in reduces.iter() {
-                let (_, init, reduct) = edit.get_node(*reduce).try_reduce().unwrap();
-                let num_nodes = edit.num_node_ids();
-                let mut inner_reduce = NodeID::new(0);
-                let mut outer_reduce = NodeID::new(0);
-                for (join_idx, join) in joins.iter().enumerate() {
-                    let init = if join_idx == joins.len() - 1 {
-                        init
-                    } else {
-                        NodeID::new(num_nodes + join_idx + 1)
-                    };
-                    let reduct = if join_idx == 0 {
-                        reduct
-                    } else {
-                        NodeID::new(num_nodes + join_idx - 1)
-                    };
-                    let new_reduce = edit.add_node(Node::Reduce {
-                        control: *join,
-                        init,
-                        reduct,
-                    });
-                    assert_eq!(new_reduce, NodeID::new(num_nodes + join_idx));
-                    edit.sub_edit(*reduce, new_reduce);
-                    if join_idx == 0 {
-                        inner_reduce = new_reduce;
-                    }
-                    if join_idx == joins.len() - 1 {
-                        outer_reduce = new_reduce;
-                    }
-                }
-                new_reduces.push((inner_reduce, outer_reduce));
-            }
-
-            // Replace everything.
-            edit = edit.replace_all_uses(*fork, acc_fork)?;
-            edit = edit.replace_all_uses(*join, acc_join)?;
-            for tid in tids.iter() {
-                let dim = edit.get_node(*tid).try_thread_id().unwrap().1;
-                edit.sub_edit(*tid, new_tids[dim]);
-                edit = edit.replace_all_uses(*tid, new_tids[dim])?;
-            }
-            for (reduce, (inner_reduce, outer_reduce)) in zip(reduces.iter(), new_reduces) {
-                edit = edit.replace_all_uses_where(*reduce, inner_reduce, |id| {
-                    data_in_reduce_cycle.contains(&(*id, *reduce))
-                })?;
-                edit = edit.replace_all_uses_where(*reduce, outer_reduce, |id| {
-                    !data_in_reduce_cycle.contains(&(*id, *reduce))
-                })?;
-            }
-
-            // Delete all the old stuff.
-            edit = edit.delete_node(*fork)?;
-            edit = edit.delete_node(*join)?;
-            for tid in tids {
-                edit = edit.delete_node(tid)?;
-            }
-            for reduce in reduces {
-                edit = edit.delete_node(reduce)?;
-            }
-
-            Ok(edit)
-        });
-        break;
-    }
-}
diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index a4605bec..e23f586f 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -1,4 +1,5 @@
 use std::collections::{HashMap, HashSet};
+use std::iter::zip;
 
 use bimap::BiMap;
 use itertools::Itertools;
@@ -538,3 +539,157 @@ pub fn fork_coalesce_helper(
 
     true
 }
+
+pub fn split_all_forks(
+    editor: &mut FunctionEditor,
+    fork_join_map: &HashMap<NodeID, NodeID>,
+    reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
+) {
+    for (fork, join) in fork_join_map {
+        if let Some((forks, _)) = split_fork(editor, *fork, *join, reduce_cycles)
+            && forks.len() > 1
+        {
+            break;
+        }
+    }
+}
+
+/*
+ * Split multi-dimensional fork-joins into separate one-dimensional fork-joins.
+ * Useful for code generation. A single iteration of `fork_split` only splits
+ * at most one fork-join, it must be called repeatedly to split all fork-joins.
+ */
+pub(crate) fn split_fork(
+    editor: &mut FunctionEditor,
+    fork: NodeID,
+    join: NodeID,
+    reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
+) -> Option<(Vec<NodeID>, Vec<NodeID>)> {
+    // A single multi-dimensional fork becomes multiple forks, a join becomes
+    // multiple joins, a thread ID becomes a thread ID on the correct
+    // fork, and a reduce becomes multiple reduces to shuffle the reduction
+    // value through the fork-join nest.
+    let nodes = &editor.func().nodes;
+    let (fork_control, factors) = nodes[fork.idx()].try_fork().unwrap();
+    if factors.len() < 2 {
+        return Some((vec![fork], vec![join]));
+    }
+    let factors: Box<[DynamicConstantID]> = factors.into();
+    let join_control = nodes[join.idx()].try_join().unwrap();
+    let tids: Vec<_> = editor
+        .get_users(fork)
+        .filter(|id| nodes[id.idx()].is_thread_id())
+        .collect();
+    let reduces: Vec<_> = editor
+        .get_users(join)
+        .filter(|id| nodes[id.idx()].is_reduce())
+        .collect();
+
+    let data_in_reduce_cycle: HashSet<(NodeID, NodeID)> = reduces
+        .iter()
+        .map(|reduce| editor.get_users(*reduce).map(move |user| (user, *reduce)))
+        .flatten()
+        .filter(|(user, reduce)| reduce_cycles[&reduce].contains(&user))
+        .collect();
+
+    let mut new_forks = vec![];
+    let mut new_joins = vec![];
+    let success = editor.edit(|mut edit| {
+        // Create the forks and a thread ID per fork.
+        let mut acc_fork = fork_control;
+        let mut new_tids = vec![];
+        for factor in factors {
+            acc_fork = edit.add_node(Node::Fork {
+                control: acc_fork,
+                factors: Box::new([factor]),
+            });
+            new_forks.push(acc_fork);
+            edit.sub_edit(fork, acc_fork);
+            new_tids.push(edit.add_node(Node::ThreadID {
+                control: acc_fork,
+                dimension: 0,
+            }));
+        }
+
+        // Create the joins.
+        let mut acc_join = if join_control == fork {
+            acc_fork
+        } else {
+            join_control
+        };
+        for _ in new_tids.iter() {
+            acc_join = edit.add_node(Node::Join { control: acc_join });
+            edit.sub_edit(join, acc_join);
+            new_joins.push(acc_join);
+        }
+
+        // Create the reduces.
+        let mut new_reduces = vec![];
+        for reduce in reduces.iter() {
+            let (_, init, reduct) = edit.get_node(*reduce).try_reduce().unwrap();
+            let num_nodes = edit.num_node_ids();
+            let mut inner_reduce = NodeID::new(0);
+            let mut outer_reduce = NodeID::new(0);
+            for (join_idx, join) in new_joins.iter().enumerate() {
+                let init = if join_idx == new_joins.len() - 1 {
+                    init
+                } else {
+                    NodeID::new(num_nodes + join_idx + 1)
+                };
+                let reduct = if join_idx == 0 {
+                    reduct
+                } else {
+                    NodeID::new(num_nodes + join_idx - 1)
+                };
+                let new_reduce = edit.add_node(Node::Reduce {
+                    control: *join,
+                    init,
+                    reduct,
+                });
+                assert_eq!(new_reduce, NodeID::new(num_nodes + join_idx));
+                edit.sub_edit(*reduce, new_reduce);
+                if join_idx == 0 {
+                    inner_reduce = new_reduce;
+                }
+                if join_idx == new_joins.len() - 1 {
+                    outer_reduce = new_reduce;
+                }
+            }
+            new_reduces.push((inner_reduce, outer_reduce));
+        }
+
+        // Replace everything.
+        edit = edit.replace_all_uses(fork, acc_fork)?;
+        edit = edit.replace_all_uses(join, acc_join)?;
+        for tid in tids.iter() {
+            let dim = edit.get_node(*tid).try_thread_id().unwrap().1;
+            edit.sub_edit(*tid, new_tids[dim]);
+            edit = edit.replace_all_uses(*tid, new_tids[dim])?;
+        }
+        for (reduce, (inner_reduce, outer_reduce)) in zip(reduces.iter(), new_reduces) {
+            edit = edit.replace_all_uses_where(*reduce, inner_reduce, |id| {
+                data_in_reduce_cycle.contains(&(*id, *reduce))
+            })?;
+            edit = edit.replace_all_uses_where(*reduce, outer_reduce, |id| {
+                !data_in_reduce_cycle.contains(&(*id, *reduce))
+            })?;
+        }
+
+        // Delete all the old stuff.
+        edit = edit.delete_node(fork)?;
+        edit = edit.delete_node(join)?;
+        for tid in tids {
+            edit = edit.delete_node(tid)?;
+        }
+        for reduce in reduces {
+            edit = edit.delete_node(reduce)?;
+        }
+
+        Ok(edit)
+    });
+    if success {
+        Some((new_forks, new_joins))
+    } else {
+        None
+    }
+}
diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs
index 271bfaf1..d9505fde 100644
--- a/hercules_opt/src/gcm.rs
+++ b/hercules_opt/src/gcm.rs
@@ -1,5 +1,5 @@
 use std::cell::Ref;
-use std::collections::{BTreeMap, BTreeSet, HashMap, VecDeque};
+use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet, VecDeque};
 use std::iter::{empty, once, zip, FromIterator};
 
 use bitvec::prelude::*;
@@ -40,6 +40,13 @@ use crate::*;
  * necessarily valid anymore, so this function is called in a loop in the pass
  * manager until no more spills are found.
  *
+ * GCM also generally tries to massage the code to be properly formed for the
+ * device backends in other ways. For example, reduction cycles through the
+ * `init` inputs of an inner reduce and a use of a non-parallel outer reduce in
+ * nested fork-joins is not schedulable, since the outer reduce doesn't dominate
+ * the fork corresponding to the inner reduce. In such cases, the outer fork-
+ * join must be split and unforkified.
+ *
  * GCM is additionally complicated by the need to generate code that references
  * objects across multiple devices. In particular, GCM makes sure that every
  * object lives on exactly one device, so that references to that object always
@@ -76,11 +83,16 @@ pub fn gcm(
     dom: &DomTree,
     fork_join_map: &HashMap<NodeID, NodeID>,
     loops: &LoopTree,
+    reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
     objects: &CollectionObjects,
     devices: &Vec<Device>,
     object_device_demands: &FunctionObjectDeviceDemands,
     backing_allocations: &BackingAllocations,
 ) -> Option<(BasicBlocks, FunctionNodeColors, FunctionBackingAllocation)> {
+    if preliminary_fixups(editor, fork_join_map, loops, reduce_cycles) {
+        return None;
+    }
+
     let bbs = basic_blocks(
         editor.func(),
         editor.func_id(),
@@ -88,8 +100,10 @@ pub fn gcm(
         reverse_postorder,
         dom,
         loops,
+        reduce_cycles,
         fork_join_map,
         objects,
+        devices,
     );
 
     let liveness = liveness_dataflow(
@@ -156,6 +170,37 @@ pub fn gcm(
     Some((bbs, node_colors, backing_allocation))
 }
 
+/*
+ * Do misc. fixups on the IR, such as unforkifying sequential outer forks with
+ * problematic reduces.
+ */
+fn preliminary_fixups(
+    editor: &mut FunctionEditor,
+    fork_join_map: &HashMap<NodeID, NodeID>,
+    loops: &LoopTree,
+    reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
+) -> bool {
+    let nodes = &editor.func().nodes;
+    for (reduce, cycle) in reduce_cycles {
+        if cycle.into_iter().any(|id| nodes[id.idx()].is_reduce()) {
+            let join = nodes[reduce.idx()].try_reduce().unwrap().0;
+            let fork = fork_join_map
+                .into_iter()
+                .filter(|(_, j)| join == **j)
+                .map(|(f, _)| *f)
+                .next()
+                .unwrap();
+            let (forks, _) = split_fork(editor, fork, join, reduce_cycles).unwrap();
+            if forks.len() > 1 {
+                return true;
+            }
+            unforkify(editor, fork, join, loops);
+            return true;
+        }
+    }
+    false
+}
+
 /*
  * Top level global code motion function. Assigns each data node to one of its
  * immediate control use / user nodes, forming (unordered) basic blocks. Returns
@@ -172,8 +217,10 @@ fn basic_blocks(
     reverse_postorder: &Vec<NodeID>,
     dom: &DomTree,
     loops: &LoopTree,
+    reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
     fork_join_map: &HashMap<NodeID, NodeID>,
     objects: &CollectionObjects,
+    devices: &Vec<Device>,
 ) -> BasicBlocks {
     let mut bbs: Vec<Option<NodeID>> = vec![None; function.nodes.len()];
 
@@ -244,6 +291,9 @@ fn basic_blocks(
     //    but not forwarding read - forwarding reads are collapsed, and the
     //    bottom read is treated as reading from the transitive parent of the
     //    forwarding read(s).
+    // 3: If the node producing the collection is a reduce node, then any read
+    //    users that aren't in the reduce's cycle shouldn't anti-depend user any
+    //    mutators in the reduce cycle.
     let mut antideps = BTreeSet::new();
     for id in reverse_postorder.iter() {
         // Find a terminating read node and the collections it reads.
@@ -269,6 +319,10 @@ fn basic_blocks(
             // TODO: make this less outrageously inefficient.
             let func_objects = &objects[&func_id];
             for root in roots.iter() {
+                let root_is_reduce_and_read_isnt_in_cycle = reduce_cycles
+                    .get(root)
+                    .map(|cycle| !cycle.contains(&id))
+                    .unwrap_or(false);
                 let root_early = schedule_early[root.idx()].unwrap();
                 let mut root_block_iterated_users: BTreeSet<NodeID> = BTreeSet::new();
                 let mut workset = BTreeSet::new();
@@ -296,6 +350,11 @@ fn basic_blocks(
                         && mutating_objects(function, func_id, *mutator, objects)
                             .any(|mutated| read_objs.contains(&mutated))
                         && id != mutator
+                        && (!root_is_reduce_and_read_isnt_in_cycle
+                            || !reduce_cycles
+                                .get(root)
+                                .map(|cycle| cycle.contains(mutator))
+                                .unwrap_or(false))
                     {
                         antideps.insert((*id, *mutator));
                     }
@@ -421,9 +480,18 @@ fn basic_blocks(
                 // If the next node further up the dominator tree is in a shallower
                 // loop nest or if we can get out of a reduce loop when we don't
                 // need to be in one, place this data node in a higher-up location.
-                // Only do this is the node isn't a constant or undef.
+                // Only do this is the node isn't a constant or undef - if a
+                // node is a constant or undef, we want its placement to be as
+                // control dependent as possible, even inside loops. In GPU
+                // functions specifically, lift constants that may be returned
+                // outside fork-joins.
                 let is_constant_or_undef =
                     function.nodes[id.idx()].is_constant() || function.nodes[id.idx()].is_undef();
+                let is_gpu_returned = devices[func_id.idx()] == Device::CUDA
+                    && objects[&func_id]
+                        .objects(id)
+                        .into_iter()
+                        .any(|obj| objects[&func_id].returned_objects().contains(obj));
                 let old_nest = loops
                     .header_of(location)
                     .map(|header| loops.nesting(header).unwrap());
@@ -444,7 +512,10 @@ fn basic_blocks(
                 // loop use the reduce node forming the loop, so the dominator chain
                 // will consist of one block, and this loop won't ever iterate.
                 let currently_at_join = function.nodes[location.idx()].is_join();
-                if !is_constant_or_undef && (shallower_nest || currently_at_join) {
+
+                if (!is_constant_or_undef || is_gpu_returned)
+                    && (shallower_nest || currently_at_join)
+                {
                     location = control_node;
                 }
             }
diff --git a/hercules_opt/src/lib.rs b/hercules_opt/src/lib.rs
index e3cca161..48475f2f 100644
--- a/hercules_opt/src/lib.rs
+++ b/hercules_opt/src/lib.rs
@@ -6,7 +6,6 @@ pub mod dce;
 pub mod delete_uncalled;
 pub mod editor;
 pub mod float_collections;
-pub mod fork_concat_split;
 pub mod fork_guard_elim;
 pub mod fork_transforms;
 pub mod forkify;
@@ -31,7 +30,6 @@ pub use crate::dce::*;
 pub use crate::delete_uncalled::*;
 pub use crate::editor::*;
 pub use crate::float_collections::*;
-pub use crate::fork_concat_split::*;
 pub use crate::fork_guard_elim::*;
 pub use crate::fork_transforms::*;
 pub use crate::forkify::*;
diff --git a/hercules_opt/src/unforkify.rs b/hercules_opt/src/unforkify.rs
index 85ffd233..7451b0ad 100644
--- a/hercules_opt/src/unforkify.rs
+++ b/hercules_opt/src/unforkify.rs
@@ -100,11 +100,24 @@ pub fn calculate_fork_nodes(
  */
 
 // FIXME: Only works on fully split fork nests.
-pub fn unforkify(
+pub fn unforkify_all(
     editor: &mut FunctionEditor,
     fork_join_map: &HashMap<NodeID, NodeID>,
     loop_tree: &LoopTree,
 ) {
+    for l in loop_tree.bottom_up_loops().into_iter().rev() {
+        if !editor.node(l.0).is_fork() {
+            continue;
+        }
+
+        let fork = l.0;
+        let join = fork_join_map[&fork];
+
+        unforkify(editor, fork, join, loop_tree);
+    }
+}
+
+pub fn unforkify(editor: &mut FunctionEditor, fork: NodeID, join: NodeID, loop_tree: &LoopTree) {
     let mut zero_cons_id = ConstantID::new(0);
     let mut one_cons_id = ConstantID::new(0);
     assert!(editor.edit(|mut edit| {
@@ -118,180 +131,170 @@ pub fn unforkify(
     // control insides of the fork-join should become the successor of the true
     // projection node, and what was the use of the join should become a use of
     // the new region.
-    for l in loop_tree.bottom_up_loops().into_iter().rev() {
-        if !editor.node(l.0).is_fork() {
-            continue;
-        }
-
-        let fork = &l.0;
-        let join = &fork_join_map[&fork];
-
-        let fork_nodes = calculate_fork_nodes(editor, l.1, *fork);
+    let fork_nodes = calculate_fork_nodes(editor, loop_tree.nodes_in_loop_bitvec(fork), fork);
 
-        let nodes = &editor.func().nodes;
-        let (fork_control, factors) = nodes[fork.idx()].try_fork().unwrap();
-        if factors.len() > 1 {
-            // For now, don't convert multi-dimensional fork-joins. Rely on pass
-            // that splits fork-joins.
-            continue;
-        }
-        let join_control = nodes[join.idx()].try_join().unwrap();
-        let tids: Vec<_> = editor
-            .get_users(*fork)
-            .filter(|id| nodes[id.idx()].is_thread_id())
-            .collect();
-        let reduces: Vec<_> = editor
-            .get_users(*join)
-            .filter(|id| nodes[id.idx()].is_reduce())
-            .collect();
+    let nodes = &editor.func().nodes;
+    let (fork_control, factors) = nodes[fork.idx()].try_fork().unwrap();
+    if factors.len() > 1 {
+        // For now, don't convert multi-dimensional fork-joins. Rely on pass
+        // that splits fork-joins.
+        return;
+    }
+    let join_control = nodes[join.idx()].try_join().unwrap();
+    let tids: Vec<_> = editor
+        .get_users(fork)
+        .filter(|id| nodes[id.idx()].is_thread_id())
+        .collect();
+    let reduces: Vec<_> = editor
+        .get_users(join)
+        .filter(|id| nodes[id.idx()].is_reduce())
+        .collect();
 
-        let num_nodes = editor.node_ids().len();
-        let region_id = NodeID::new(num_nodes);
-        let if_id = NodeID::new(num_nodes + 1);
-        let proj_back_id = NodeID::new(num_nodes + 2);
-        let proj_exit_id = NodeID::new(num_nodes + 3);
-        let zero_id = NodeID::new(num_nodes + 4);
-        let one_id = NodeID::new(num_nodes + 5);
-        let indvar_id = NodeID::new(num_nodes + 6);
-        let add_id = NodeID::new(num_nodes + 7);
-        let dc_id = NodeID::new(num_nodes + 8);
-        let neq_id = NodeID::new(num_nodes + 9);
+    let num_nodes = editor.node_ids().len();
+    let region_id = NodeID::new(num_nodes);
+    let if_id = NodeID::new(num_nodes + 1);
+    let proj_back_id = NodeID::new(num_nodes + 2);
+    let proj_exit_id = NodeID::new(num_nodes + 3);
+    let zero_id = NodeID::new(num_nodes + 4);
+    let one_id = NodeID::new(num_nodes + 5);
+    let indvar_id = NodeID::new(num_nodes + 6);
+    let add_id = NodeID::new(num_nodes + 7);
+    let dc_id = NodeID::new(num_nodes + 8);
+    let neq_id = NodeID::new(num_nodes + 9);
 
-        let guard_if_id = NodeID::new(num_nodes + 10);
-        let guard_join_id = NodeID::new(num_nodes + 11);
-        let guard_taken_proj_id = NodeID::new(num_nodes + 12);
-        let guard_skipped_proj_id = NodeID::new(num_nodes + 13);
-        let guard_cond_id = NodeID::new(num_nodes + 14);
+    let guard_if_id = NodeID::new(num_nodes + 10);
+    let guard_join_id = NodeID::new(num_nodes + 11);
+    let guard_taken_proj_id = NodeID::new(num_nodes + 12);
+    let guard_skipped_proj_id = NodeID::new(num_nodes + 13);
+    let guard_cond_id = NodeID::new(num_nodes + 14);
 
-        let phi_ids = (num_nodes + 15..num_nodes + 15 + reduces.len()).map(NodeID::new);
-        let s = num_nodes + 15 + reduces.len();
-        let join_phi_ids = (s..s + reduces.len()).map(NodeID::new);
+    let phi_ids = (num_nodes + 15..num_nodes + 15 + reduces.len()).map(NodeID::new);
+    let s = num_nodes + 15 + reduces.len();
+    let join_phi_ids = (s..s + reduces.len()).map(NodeID::new);
 
-        let guard_cond = Node::Binary {
-            left: zero_id,
-            right: dc_id,
-            op: BinaryOperator::LT,
-        };
-        let guard_if = Node::If {
-            control: fork_control,
-            cond: guard_cond_id,
-        };
-        let guard_taken_proj = Node::Projection {
-            control: guard_if_id,
-            selection: 1,
-        };
-        let guard_skipped_proj = Node::Projection {
-            control: guard_if_id,
-            selection: 0,
-        };
-        let guard_join = Node::Region {
-            preds: Box::new([guard_skipped_proj_id, proj_exit_id]),
-        };
+    let guard_cond = Node::Binary {
+        left: zero_id,
+        right: dc_id,
+        op: BinaryOperator::LT,
+    };
+    let guard_if = Node::If {
+        control: fork_control,
+        cond: guard_cond_id,
+    };
+    let guard_taken_proj = Node::Projection {
+        control: guard_if_id,
+        selection: 1,
+    };
+    let guard_skipped_proj = Node::Projection {
+        control: guard_if_id,
+        selection: 0,
+    };
+    let guard_join = Node::Region {
+        preds: Box::new([guard_skipped_proj_id, proj_exit_id]),
+    };
 
-        let region = Node::Region {
-            preds: Box::new([guard_taken_proj_id, proj_back_id]),
-        };
-        let if_node = Node::If {
-            control: join_control,
-            cond: neq_id,
-        };
-        let proj_back = Node::Projection {
-            control: if_id,
-            selection: 1,
-        };
-        let proj_exit = Node::Projection {
-            control: if_id,
-            selection: 0,
-        };
-        let zero = Node::Constant { id: zero_cons_id };
-        let one = Node::Constant { id: one_cons_id };
-        let indvar = Node::Phi {
-            control: region_id,
-            data: Box::new([zero_id, add_id]),
-        };
-        let add = Node::Binary {
-            op: BinaryOperator::Add,
-            left: indvar_id,
-            right: one_id,
-        };
-        let dc = Node::DynamicConstant { id: factors[0] };
-        let neq = Node::Binary {
-            op: BinaryOperator::NE,
-            left: add_id,
-            right: dc_id,
-        };
-        let (phis, join_phis): (Vec<_>, Vec<_>) = reduces
-            .iter()
-            .map(|reduce_id| {
-                let (_, init, reduct) = nodes[reduce_id.idx()].try_reduce().unwrap();
-                (
-                    Node::Phi {
-                        control: region_id,
-                        data: Box::new([init, reduct]),
-                    },
-                    Node::Phi {
-                        control: guard_join_id,
-                        data: Box::new([init, reduct]),
-                    },
-                )
-            })
-            .unzip();
+    let region = Node::Region {
+        preds: Box::new([guard_taken_proj_id, proj_back_id]),
+    };
+    let if_node = Node::If {
+        control: join_control,
+        cond: neq_id,
+    };
+    let proj_back = Node::Projection {
+        control: if_id,
+        selection: 1,
+    };
+    let proj_exit = Node::Projection {
+        control: if_id,
+        selection: 0,
+    };
+    let zero = Node::Constant { id: zero_cons_id };
+    let one = Node::Constant { id: one_cons_id };
+    let indvar = Node::Phi {
+        control: region_id,
+        data: Box::new([zero_id, add_id]),
+    };
+    let add = Node::Binary {
+        op: BinaryOperator::Add,
+        left: indvar_id,
+        right: one_id,
+    };
+    let dc = Node::DynamicConstant { id: factors[0] };
+    let neq = Node::Binary {
+        op: BinaryOperator::NE,
+        left: add_id,
+        right: dc_id,
+    };
+    let (phis, join_phis): (Vec<_>, Vec<_>) = reduces
+        .iter()
+        .map(|reduce_id| {
+            let (_, init, reduct) = nodes[reduce_id.idx()].try_reduce().unwrap();
+            (
+                Node::Phi {
+                    control: region_id,
+                    data: Box::new([init, reduct]),
+                },
+                Node::Phi {
+                    control: guard_join_id,
+                    data: Box::new([init, reduct]),
+                },
+            )
+        })
+        .unzip();
 
-        editor.edit(|mut edit| {
-            assert_eq!(edit.add_node(region), region_id);
-            assert_eq!(edit.add_node(if_node), if_id);
-            assert_eq!(edit.add_node(proj_back), proj_back_id);
-            assert_eq!(edit.add_node(proj_exit), proj_exit_id);
-            assert_eq!(edit.add_node(zero), zero_id);
-            assert_eq!(edit.add_node(one), one_id);
-            assert_eq!(edit.add_node(indvar), indvar_id);
-            assert_eq!(edit.add_node(add), add_id);
-            assert_eq!(edit.add_node(dc), dc_id);
-            assert_eq!(edit.add_node(neq), neq_id);
-            assert_eq!(edit.add_node(guard_if), guard_if_id);
-            assert_eq!(edit.add_node(guard_join), guard_join_id);
-            assert_eq!(edit.add_node(guard_taken_proj), guard_taken_proj_id);
-            assert_eq!(edit.add_node(guard_skipped_proj), guard_skipped_proj_id);
-            assert_eq!(edit.add_node(guard_cond), guard_cond_id);
+    editor.edit(|mut edit| {
+        assert_eq!(edit.add_node(region), region_id);
+        assert_eq!(edit.add_node(if_node), if_id);
+        assert_eq!(edit.add_node(proj_back), proj_back_id);
+        assert_eq!(edit.add_node(proj_exit), proj_exit_id);
+        assert_eq!(edit.add_node(zero), zero_id);
+        assert_eq!(edit.add_node(one), one_id);
+        assert_eq!(edit.add_node(indvar), indvar_id);
+        assert_eq!(edit.add_node(add), add_id);
+        assert_eq!(edit.add_node(dc), dc_id);
+        assert_eq!(edit.add_node(neq), neq_id);
+        assert_eq!(edit.add_node(guard_if), guard_if_id);
+        assert_eq!(edit.add_node(guard_join), guard_join_id);
+        assert_eq!(edit.add_node(guard_taken_proj), guard_taken_proj_id);
+        assert_eq!(edit.add_node(guard_skipped_proj), guard_skipped_proj_id);
+        assert_eq!(edit.add_node(guard_cond), guard_cond_id);
 
-            for (phi_id, phi) in zip(phi_ids.clone(), &phis) {
-                assert_eq!(edit.add_node(phi.clone()), phi_id);
-            }
-            for (phi_id, phi) in zip(join_phi_ids.clone(), &join_phis) {
-                assert_eq!(edit.add_node(phi.clone()), phi_id);
-            }
+        for (phi_id, phi) in zip(phi_ids.clone(), &phis) {
+            assert_eq!(edit.add_node(phi.clone()), phi_id);
+        }
+        for (phi_id, phi) in zip(join_phi_ids.clone(), &join_phis) {
+            assert_eq!(edit.add_node(phi.clone()), phi_id);
+        }
 
-            edit = edit.replace_all_uses(*fork, region_id)?;
-            edit = edit.replace_all_uses_where(*join, guard_join_id, |usee| *usee != if_id)?;
-            edit.sub_edit(*fork, region_id);
-            edit.sub_edit(*join, if_id);
-            for tid in tids.iter() {
-                edit.sub_edit(*tid, indvar_id);
-                edit = edit.replace_all_uses(*tid, indvar_id)?;
-            }
-            for (((reduce, phi_id), phi), join_phi_id) in
-                zip(reduces.iter(), phi_ids).zip(phis).zip(join_phi_ids)
-            {
-                edit.sub_edit(*reduce, phi_id);
-                let Node::Phi { control: _, data } = phi else {
-                    panic!()
-                };
-                edit = edit.replace_all_uses_where(*reduce, join_phi_id, |usee| {
-                    !fork_nodes.contains(usee)
-                })?; //, |usee| *usee != *reduct)?;
-                edit = edit.replace_all_uses_where(*reduce, phi_id, |usee| {
-                    fork_nodes.contains(usee) || *usee == data[1]
-                })?;
-                edit = edit.delete_node(*reduce)?;
-            }
+        edit = edit.replace_all_uses(fork, region_id)?;
+        edit = edit.replace_all_uses_where(join, guard_join_id, |usee| *usee != if_id)?;
+        edit.sub_edit(fork, region_id);
+        edit.sub_edit(join, if_id);
+        for tid in tids.iter() {
+            edit.sub_edit(*tid, indvar_id);
+            edit = edit.replace_all_uses(*tid, indvar_id)?;
+        }
+        for (((reduce, phi_id), phi), join_phi_id) in
+            zip(reduces.iter(), phi_ids).zip(phis).zip(join_phi_ids)
+        {
+            edit.sub_edit(*reduce, phi_id);
+            let Node::Phi { control: _, data } = phi else {
+                panic!()
+            };
+            edit = edit
+                .replace_all_uses_where(*reduce, join_phi_id, |usee| !fork_nodes.contains(usee))?; //, |usee| *usee != *reduct)?;
+            edit = edit.replace_all_uses_where(*reduce, phi_id, |usee| {
+                fork_nodes.contains(usee) || *usee == data[1]
+            })?;
+            edit = edit.delete_node(*reduce)?;
+        }
 
-            edit = edit.delete_node(*fork)?;
-            edit = edit.delete_node(*join)?;
-            for tid in tids {
-                edit = edit.delete_node(tid)?;
-            }
+        edit = edit.delete_node(fork)?;
+        edit = edit.delete_node(join)?;
+        for tid in tids {
+            edit = edit.delete_node(tid)?;
+        }
 
-            Ok(edit)
-        });
-    }
+        Ok(edit)
+    });
 }
diff --git a/juno_samples/fork_join_tests/Cargo.toml b/juno_samples/fork_join_tests/Cargo.toml
new file mode 100644
index 00000000..a109e782
--- /dev/null
+++ b/juno_samples/fork_join_tests/Cargo.toml
@@ -0,0 +1,21 @@
+[package]
+name = "juno_fork_join_tests"
+version = "0.1.0"
+authors = ["Russel Arbore <rarbore2@illinois.edu>"]
+edition = "2021"
+
+[[bin]]
+name = "juno_fork_join_tests"
+path = "src/main.rs"
+
+[features]
+cuda = ["juno_build/cuda", "hercules_rt/cuda"]
+
+[build-dependencies]
+juno_build = { path = "../../juno_build" }
+
+[dependencies]
+juno_build = { path = "../../juno_build" }
+hercules_rt = { path = "../../hercules_rt" }
+with_builtin_macros = "0.1.0"
+async-std = "*"
diff --git a/juno_samples/fork_join_tests/build.rs b/juno_samples/fork_join_tests/build.rs
new file mode 100644
index 00000000..796e9f32
--- /dev/null
+++ b/juno_samples/fork_join_tests/build.rs
@@ -0,0 +1,24 @@
+use juno_build::JunoCompiler;
+
+fn main() {
+    #[cfg(not(feature = "cuda"))]
+    {
+        JunoCompiler::new()
+            .file_in_src("fork_join_tests.jn")
+            .unwrap()
+            .schedule_in_src("cpu.sch")
+            .unwrap()
+            .build()
+            .unwrap();
+    }
+    #[cfg(feature = "cuda")]
+    {
+        JunoCompiler::new()
+            .file_in_src("fork_join_tests.jn")
+            .unwrap()
+            .schedule_in_src("gpu.sch")
+            .unwrap()
+            .build()
+            .unwrap();
+    }
+}
diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch
new file mode 100644
index 00000000..0263c275
--- /dev/null
+++ b/juno_samples/fork_join_tests/src/cpu.sch
@@ -0,0 +1,42 @@
+gvn(*);
+phi-elim(*);
+dce(*);
+
+let out = auto-outline(*);
+cpu(out.test1);
+cpu(out.test2);
+cpu(out.test3);
+
+ip-sroa(*);
+sroa(*);
+dce(*);
+gvn(*);
+phi-elim(*);
+dce(*);
+
+fixpoint panic after 20 {
+  forkify(*);
+  fork-guard-elim(*);
+  fork-coalesce(*);
+}
+gvn(*);
+phi-elim(*);
+dce(*);
+
+gvn(*);
+phi-elim(*);
+dce(*);
+
+fixpoint panic after 20 {
+  infer-schedules(*);
+}
+fork-split(*);
+gvn(*);
+phi-elim(*);
+dce(*);
+unforkify(*);
+gvn(*);
+phi-elim(*);
+dce(*);
+
+gcm(*);
diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn
new file mode 100644
index 00000000..55e0a37e
--- /dev/null
+++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn
@@ -0,0 +1,61 @@
+#[entry]
+fn test1(input : i32) -> i32[4, 4] {
+  let arr : i32[4, 4];
+  for i = 0 to 4 {
+    for j = 0 to 4 {
+      arr[i, j] = input;
+    }
+  }
+  return arr;
+}
+
+#[entry]
+fn test2(input : i32) -> i32[4, 4] {
+  let arr : i32[4, 4];
+  for i = 0 to 8 {
+    for j = 0 to 4 {
+      for k = 0 to 4 {
+        arr[j, k] += input;
+      }
+    }
+  }
+  return arr;
+}
+
+#[entry]
+fn test3(input : i32) -> i32[3, 3] {
+  let arr1 : i32[3, 3];
+  for i = 0 to 3 {
+    for j = 0 to 3 {
+      arr1[i, j] = (i + j) as i32 + input;
+    }
+  }
+  let arr2 : i32[3, 3];
+  for i = 0 to 3 {
+    for j = 0 to 3 {
+      arr2[i, j] = arr1[2 - i, 2 - j];
+    }
+  }
+  let arr3 : i32[3, 3];
+  for i = 0 to 3 {
+    for j = 0 to 3 {
+      arr3[i, j] = arr2[i, j] + 7;
+    }
+  }
+  return arr3;
+}
+
+#[entry]
+fn test4(input : i32) -> i32[4, 4] {
+  let arr : i32[4, 4];
+  for i = 0 to 4 {
+    for j = 0 to 4 {
+      let acc = arr[i, j];
+      for k = 0 to 7 {
+        acc += input;
+      }
+      arr[i, j] = acc;
+    }
+  }
+  return arr;
+}
diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch
new file mode 100644
index 00000000..701c347c
--- /dev/null
+++ b/juno_samples/fork_join_tests/src/gpu.sch
@@ -0,0 +1,33 @@
+gvn(*);
+phi-elim(*);
+dce(*);
+
+let out = auto-outline(*);
+gpu(out.test1);
+gpu(out.test2);
+gpu(out.test3);
+gpu(out.test4);
+
+ip-sroa(*);
+sroa(*);
+dce(*);
+gvn(*);
+phi-elim(*);
+dce(*);
+
+fixpoint panic after 20 {
+  forkify(*);
+  fork-guard-elim(*);
+  fork-coalesce(*);
+}
+
+gvn(*);
+phi-elim(*);
+dce(*);
+
+fixpoint panic after 20 {
+  infer-schedules(*);
+}
+
+float-collections(test2, out.test2, test4, out.test4);
+gcm(*);
diff --git a/juno_samples/fork_join_tests/src/main.rs b/juno_samples/fork_join_tests/src/main.rs
new file mode 100644
index 00000000..cbd42c50
--- /dev/null
+++ b/juno_samples/fork_join_tests/src/main.rs
@@ -0,0 +1,46 @@
+#![feature(concat_idents)]
+
+use hercules_rt::runner;
+
+juno_build::juno!("fork_join_tests");
+
+fn main() {
+    #[cfg(not(feature = "cuda"))]
+    let assert = |correct, output: hercules_rt::HerculesCPURefMut<'_>| {
+        assert_eq!(output.as_slice::<i32>(), &correct);
+    };
+
+    #[cfg(feature = "cuda")]
+    let assert = |correct, output: hercules_rt::HerculesCUDARefMut<'_>| {
+        let mut dst = vec![0i32; 16];
+        let output = output.to_cpu_ref(&mut dst);
+        assert_eq!(output.as_slice::<i32>(), &correct);
+    };
+
+    async_std::task::block_on(async {
+        let mut r = runner!(test1);
+        let output = r.run(5).await;
+        let correct = vec![5i32; 16];
+        assert(correct, output);
+
+        let mut r = runner!(test2);
+        let output = r.run(3).await;
+        let correct = vec![24i32; 16];
+        assert(correct, output);
+
+        let mut r = runner!(test3);
+        let output = r.run(0).await;
+        let correct = vec![11, 10, 9, 10, 9, 8, 9, 8, 7];
+        assert(correct, output);
+
+        let mut r = runner!(test4);
+        let output = r.run(9).await;
+        let correct = vec![63i32; 16];
+        assert(correct, output);
+    });
+}
+
+#[test]
+fn implicit_clone_test() {
+    main();
+}
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index f6fe2fc1..9c818707 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -575,6 +575,8 @@ impl PassManager {
         self.postdoms = None;
         self.fork_join_maps = None;
         self.fork_join_nests = None;
+        self.fork_control_maps = None;
+        self.fork_trees = None;
         self.loops = None;
         self.reduce_cycles = None;
         self.data_nodes_in_fork_joins = None;
@@ -1303,7 +1305,7 @@ fn run_pass(
     pm: &mut PassManager,
     pass: Pass,
     args: Vec<Value>,
-    selection: Option<Vec<CodeLocation>>,
+    mut selection: Option<Vec<CodeLocation>>,
 ) -> Result<(Value, bool), SchedulerError> {
     let mut result = Value::Record {
         fields: HashMap::new(),
@@ -1439,13 +1441,6 @@ fn run_pass(
         }
         Pass::FloatCollections => {
             assert!(args.is_empty());
-            if let Some(_) = selection {
-                return Err(SchedulerError::PassError {
-                    pass: "floatCollections".to_string(),
-                    error: "must be applied to the entire module".to_string(),
-                });
-            }
-
             pm.make_typing();
             pm.make_callgraph();
             pm.make_devices();
@@ -1453,11 +1448,15 @@ fn run_pass(
             let callgraph = pm.callgraph.take().unwrap();
             let devices = pm.devices.take().unwrap();
 
-            let mut editors = build_editors(pm);
+            // Modify the selection to include callers of selected functions.
+            let mut editors = build_selection(pm, selection)
+                .into_iter()
+                .filter_map(|editor| editor.map(|editor| (editor.func_id(), editor)))
+                .collect();
             float_collections(&mut editors, &typing, &callgraph, &devices);
 
             for func in editors {
-                changed |= func.modified();
+                changed |= func.1.modified();
             }
 
             pm.delete_gravestones();
@@ -1506,7 +1505,7 @@ fn run_pass(
                     let Some(mut func) = func else {
                         continue;
                     };
-                    fork_split(&mut func, fork_join_map, reduce_cycles);
+                    split_all_forks(&mut func, fork_join_map, reduce_cycles);
                     changed |= func.modified();
                     inner_changed |= func.modified();
                 }
@@ -1568,6 +1567,7 @@ fn run_pass(
                 pm.make_doms();
                 pm.make_fork_join_maps();
                 pm.make_loops();
+                pm.make_reduce_cycles();
                 pm.make_collection_objects();
                 pm.make_devices();
                 pm.make_object_device_demands();
@@ -1578,6 +1578,7 @@ fn run_pass(
                 let doms = pm.doms.take().unwrap();
                 let fork_join_maps = pm.fork_join_maps.take().unwrap();
                 let loops = pm.loops.take().unwrap();
+                let reduce_cycles = pm.reduce_cycles.take().unwrap();
                 let control_subgraphs = pm.control_subgraphs.take().unwrap();
                 let collection_objects = pm.collection_objects.take().unwrap();
                 let devices = pm.devices.take().unwrap();
@@ -1599,6 +1600,7 @@ fn run_pass(
                         &doms[id.idx()],
                         &fork_join_maps[id.idx()],
                         &loops[id.idx()],
+                        &reduce_cycles[id.idx()],
                         &collection_objects,
                         &devices,
                         &object_device_demands[id.idx()],
@@ -1874,7 +1876,7 @@ fn run_pass(
                 let Some(mut func) = func else {
                     continue;
                 };
-                unforkify(&mut func, fork_join_map, loop_tree);
+                unforkify_all(&mut func, fork_join_map, loop_tree);
                 changed |= func.modified();
             }
             pm.delete_gravestones();
-- 
GitLab