From f6674a965fc503820effb50481f266ffafc9e972 Mon Sep 17 00:00:00 2001
From: rarbore2 <rarbore2@illinois.edu>
Date: Tue, 4 Feb 2025 20:07:08 -0600
Subject: [PATCH] Misc. preparations for multi-core support.

---
 hercules_cg/src/rt.rs                         |  72 ++++++++++
 hercules_ir/src/dot.rs                        |  35 ++++-
 hercules_ir/src/einsum.rs                     |   6 +-
 hercules_ir/src/fork_join_analysis.rs         |  18 +--
 hercules_opt/src/ccp.rs                       |  10 +-
 hercules_opt/src/fork_transforms.rs           |  14 +-
 hercules_opt/src/forkify.rs                   |  10 +-
 hercules_rt/Cargo.toml                        |   1 +
 hercules_rt/src/lib.rs                        |  87 +++++++++++-
 hercules_rt/src/rtdefs.cu                     |  12 +-
 juno_samples/fork_join_tests/src/cpu.sch      |  30 +++-
 .../fork_join_tests/src/fork_join_tests.jn    |   9 ++
 juno_samples/fork_join_tests/src/gpu.sch      |  28 +++-
 juno_samples/fork_join_tests/src/main.rs      |  25 ++--
 juno_scheduler/src/lib.rs                     |   2 +
 juno_scheduler/src/pm.rs                      | 133 ++++++++++++++----
 16 files changed, 403 insertions(+), 89 deletions(-)

diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs
index cbef5a00..35334a14 100644
--- a/hercules_cg/src/rt.rs
+++ b/hercules_cg/src/rt.rs
@@ -407,6 +407,78 @@ impl<'a> RTContext<'a> {
                     write!(block, ");\n")?;
                 }
             }
+            Node::Unary { op, input } => {
+                let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap();
+                match op {
+                    UnaryOperator::Not => write!(
+                        block,
+                        "                {} = !{};\n",
+                        self.get_value(id),
+                        self.get_value(input)
+                    )?,
+                    UnaryOperator::Neg => write!(
+                        block,
+                        "                {} = -{};\n",
+                        self.get_value(id),
+                        self.get_value(input)
+                    )?,
+                    UnaryOperator::Cast(ty) => write!(
+                        block,
+                        "                {} = {} as {};\n",
+                        self.get_value(id),
+                        self.get_value(input),
+                        self.get_type(ty)
+                    )?,
+                };
+            }
+            Node::Binary { op, left, right } => {
+                let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap();
+                let op = match op {
+                    BinaryOperator::Add => "+",
+                    BinaryOperator::Sub => "-",
+                    BinaryOperator::Mul => "*",
+                    BinaryOperator::Div => "/",
+                    BinaryOperator::Rem => "%",
+                    BinaryOperator::LT => "<",
+                    BinaryOperator::LTE => "<=",
+                    BinaryOperator::GT => ">",
+                    BinaryOperator::GTE => ">=",
+                    BinaryOperator::EQ => "==",
+                    BinaryOperator::NE => "!=",
+                    BinaryOperator::Or => "|",
+                    BinaryOperator::And => "&",
+                    BinaryOperator::Xor => "^",
+                    BinaryOperator::LSh => "<<",
+                    BinaryOperator::RSh => ">>",
+                };
+
+                write!(
+                    block,
+                    "                {} = {} {} {};\n",
+                    self.get_value(id),
+                    self.get_value(left),
+                    op,
+                    self.get_value(right)
+                )?;
+            }
+            Node::Ternary {
+                op,
+                first,
+                second,
+                third,
+            } => {
+                let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap();
+                match op {
+                    TernaryOperator::Select => write!(
+                        block,
+                        "                {} = if {} {{ {} }} else {{ {} }};\n",
+                        self.get_value(id),
+                        self.get_value(first),
+                        self.get_value(second),
+                        self.get_value(third),
+                    )?,
+                };
+            }
             Node::Read {
                 collect,
                 ref indices,
diff --git a/hercules_ir/src/dot.rs b/hercules_ir/src/dot.rs
index 5ccda9dc..7ad8c6df 100644
--- a/hercules_ir/src/dot.rs
+++ b/hercules_ir/src/dot.rs
@@ -343,30 +343,47 @@ fn write_node<W: Write>(
     }
 
     let mut iter = schedules.into_iter();
-    if let Some(first) = iter.next() {
-        let schedules = iter.fold(format!("{:?}", first), |b, i| format!("{}, {:?}", b, i));
+    let schedules = if let Some(first) = iter.next() {
+        iter.fold(format!("{:?}", first), |b, i| format!("{}, {:?}", b, i))
+    } else {
+        String::new()
+    };
+    if tylabel.is_empty() && schedules.is_empty() {
         write!(
             w,
-            "{}_{}_{} [xlabel={}, label=<{}<BR /><FONT POINT-SIZE=\"8\">{}</FONT><BR /><FONT POINT-SIZE=\"8\">{}</FONT>>, color={}];\n",
+            "{}_{}_{} [xlabel={}, label=<{}>, color={}];\n",
+            node.lower_case_name(),
+            function_id.idx(),
+            node_id.idx(),
+            xlabel,
+            label,
+            color
+        )?;
+    } else if tylabel.is_empty() || schedules.is_empty() {
+        // Exactly one annotation is non-empty here; both cases share a layout.
+        let sublabel = if tylabel.is_empty() { &schedules[..] } else { &tylabel[..] };
+        write!(
+            w,
+            "{}_{}_{} [xlabel={}, label=<{}<BR /><FONT POINT-SIZE=\"8\">{}</FONT>>, color={}];\n",
             node.lower_case_name(),
             function_id.idx(),
             node_id.idx(),
             xlabel,
             label,
-            tylabel,
-            schedules,
+            sublabel,
             color
         )?;
     } else {
         write!(
             w,
-            "{}_{}_{} [xlabel={}, label=<{}<BR /><FONT POINT-SIZE=\"8\">{}</FONT>>, color={}];\n",
+            "{}_{}_{} [xlabel={}, label=<{}<BR /><FONT POINT-SIZE=\"8\">{}</FONT><BR /><FONT POINT-SIZE=\"8\">{}</FONT>>, color={}];\n",
             node.lower_case_name(),
             function_id.idx(),
             node_id.idx(),
             xlabel,
             label,
             tylabel,
+            schedules,
             color
         )?;
     }
diff --git a/hercules_ir/src/einsum.rs b/hercules_ir/src/einsum.rs
index 8d3bec3a..25e15d63 100644
--- a/hercules_ir/src/einsum.rs
+++ b/hercules_ir/src/einsum.rs
@@ -60,7 +60,7 @@ pub fn einsum(
     typing: &Vec<TypeID>,
     fork_join_map: &HashMap<NodeID, NodeID>,
     fork_join_nest: &HashMap<NodeID, Vec<NodeID>>,
-    data_nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
+    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
 ) -> (MathEnv, HashMap<NodeID, MathID>) {
     let mut env = vec![];
     let mut rev_env = HashMap::new();
@@ -101,7 +101,7 @@ pub fn einsum(
             function,
             typing,
             constants,
-            data_nodes_in_fork_joins,
+            nodes_in_fork_joins,
             fork,
             factors,
             thread_ids: &thread_ids,
@@ -185,7 +185,7 @@ struct EinsumContext<'a> {
     function: &'a Function,
     typing: &'a Vec<TypeID>,
     constants: &'a Vec<Constant>,
-    data_nodes_in_fork_joins: &'a HashMap<NodeID, HashSet<NodeID>>,
+    nodes_in_fork_joins: &'a HashMap<NodeID, HashSet<NodeID>>,
     fork: NodeID,
     factors: &'a [DynamicConstantID],
     thread_ids: &'a Vec<(NodeID, usize)>,
diff --git a/hercules_ir/src/fork_join_analysis.rs b/hercules_ir/src/fork_join_analysis.rs
index 7a098a35..ad3125ba 100644
--- a/hercules_ir/src/fork_join_analysis.rs
+++ b/hercules_ir/src/fork_join_analysis.rs
@@ -165,10 +165,9 @@ fn reduce_cycle_dfs_helper(
 }
 
 /*
- * Top level function to calculate which data nodes are "inside" a fork-join,
- * not including its reduces.
+ * Calculate the nodes "inside" each fork-join (fork, join, reduces included).
  */
-pub fn data_nodes_in_fork_joins(
+pub fn nodes_in_fork_joins(
     function: &Function,
     def_use: &ImmutableDefUseMap,
     fork_join_map: &HashMap<NodeID, NodeID>,
@@ -178,22 +177,19 @@ pub fn data_nodes_in_fork_joins(
     for (fork, join) in fork_join_map {
         let mut worklist = vec![*fork];
         let mut set = HashSet::new();
+        set.insert(*fork);
 
         while let Some(item) = worklist.pop() {
             for u in def_use.get_users(item) {
-                if function.nodes[u.idx()].is_control()
+                let terminate = *u == *join
                     || function.nodes[u.idx()]
                         .try_reduce()
                         .map(|(control, _, _)| control == *join)
-                        .unwrap_or(false)
-                {
-                    // Ignore control users and reduces of the fork-join.
-                    continue;
-                }
-                if !set.contains(u) {
-                    set.insert(*u);
+                        .unwrap_or(false);
+                if !set.contains(u) && !terminate {
                     worklist.push(*u);
                 }
+                set.insert(*u);
             }
         }
 
diff --git a/hercules_opt/src/ccp.rs b/hercules_opt/src/ccp.rs
index 92d52a71..9768198c 100644
--- a/hercules_opt/src/ccp.rs
+++ b/hercules_opt/src/ccp.rs
@@ -482,9 +482,13 @@ fn ccp_flow_function(
             reachability: ReachabilityLattice::bottom(),
             constant: ConstantLattice::Constant(editor.get_constant(*id).clone()),
         },
-        // TODO: This should really be constant interpreted, since dynamic
-        // constants as values are used frequently.
-        Node::DynamicConstant { id: _ } => CCPLattice::bottom(),
+        Node::DynamicConstant { id } => match *editor.get_dynamic_constant(*id) {
+            DynamicConstant::Constant(value) => CCPLattice {
+                reachability: ReachabilityLattice::bottom(),
+                constant: ConstantLattice::Constant(Constant::UnsignedInteger64(value as u64)),
+            },
+            _ => CCPLattice::bottom(),
+        },
         // Interpret unary op on constant.
         Node::Unary { input, op } => {
             let CCPLattice {
diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index c4a6ba7f..456f670e 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -541,18 +541,25 @@ pub fn fork_coalesce_helper(
     true
 }
 
-pub fn split_all_forks(
+/*
+ * Try to split the fork-joins in `fork_join_map` one at a time, in the map's
+ * iteration order, stopping at the first `split_fork` call that produces more
+ * than one fork. Returns the forks and joins reported by that call, or None if
+ * no call produced more than one fork.
+ */
+pub fn split_any_fork(
     editor: &mut FunctionEditor,
     fork_join_map: &HashMap<NodeID, NodeID>,
     reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
-) {
+) -> Option<(Vec<NodeID>, Vec<NodeID>)> {
     for (fork, join) in fork_join_map {
-        if let Some((forks, _)) = split_fork(editor, *fork, *join, reduce_cycles)
+        if let Some((forks, joins)) = split_fork(editor, *fork, *join, reduce_cycles)
             && forks.len() > 1
         {
-            break;
+            return Some((forks, joins));
         }
     }
+    None
 }
 
 /*
@@ -689,6 +690,7 @@ pub(crate) fn split_fork(
         Ok(edit)
     });
     if success {
+        new_joins.reverse();
         Some((new_forks, new_joins))
     } else {
         None
@@ -749,6 +751,7 @@ pub fn chunk_fork_unguarded(
         let new_fork = edit.add_node(new_fork);
 
         edit = edit.replace_all_uses(fork, new_fork)?;
+        edit.sub_edit(fork, new_fork);
 
         for (tid, node) in fork_users {
             let Node::ThreadID {
@@ -765,6 +768,7 @@ pub fn chunk_fork_unguarded(
                 };
                 let new_tid = edit.add_node(new_tid);
                 edit = edit.replace_all_uses(tid, new_tid)?;
+                edit.sub_edit(tid, new_tid);
                 edit = edit.delete_node(tid)?;
             } else if tid_dim == dim_idx {
                 let tile_tid = Node::ThreadID {
@@ -784,6 +788,8 @@ pub fn chunk_fork_unguarded(
                     right: tile_tid,
                     op: BinaryOperator::Add,
                 });
+                edit.sub_edit(tid, add);
+                edit.sub_edit(tid, tile_tid);
                 edit = edit.replace_all_uses_where(tid, add, |usee| *usee != mul)?;
             }
         }
diff --git a/hercules_opt/src/forkify.rs b/hercules_opt/src/forkify.rs
index f6db06ca..2adfddd8 100644
--- a/hercules_opt/src/forkify.rs
+++ b/hercules_opt/src/forkify.rs
@@ -298,15 +298,13 @@ pub fn forkify_loop(
     let (_, factors) = function.nodes[fork_id.idx()].try_fork().unwrap();
     let dimension = factors.len() - 1;
 
-    // Start failable edit:
-
     let redcutionable_phis_and_init: Vec<(_, NodeID)> = reductionable_phis
         .iter()
         .map(|reduction_phi| {
             let LoopPHI::Reductionable {
                 phi,
                 data_cycle: _,
-                continue_latch,
+                continue_latch: _,
                 is_associative: _,
             } = reduction_phi
             else {
@@ -328,6 +326,7 @@ pub fn forkify_loop(
         })
         .collect();
 
+    // Start failable edit:
     editor.edit(|mut edit| {
         let thread_id = Node::ThreadID {
             control: fork_id,
@@ -339,6 +338,7 @@ pub fn forkify_loop(
         edit = edit.replace_all_uses_where(canonical_iv.phi(), thread_id_id, |node| {
             loop_nodes.contains(node)
         })?;
+        edit.sub_edit(canonical_iv.phi(), thread_id_id);
 
         edit = edit.delete_node(canonical_iv.phi())?;
 
@@ -386,12 +386,16 @@ pub fn forkify_loop(
             edit = edit.replace_all_uses_where(continue_latch, reduce_id, |usee| {
                 !loop_nodes.contains(usee) && *usee != reduce_id
             })?;
+            edit.sub_edit(phi, reduce_id);
             edit = edit.delete_node(phi)?
         }
 
         edit = edit.replace_all_uses(l.header, fork_id)?;
         edit = edit.replace_all_uses(loop_continue_projection, fork_id)?;
         edit = edit.replace_all_uses(loop_exit_projection, join_id)?;
+        edit.sub_edit(l.header, fork_id);
+        edit.sub_edit(loop_continue_projection, fork_id);
+        edit.sub_edit(loop_exit_projection, join_id);
 
         edit = edit.delete_node(loop_continue_projection)?;
         edit = edit.delete_node(condition_node)?; // Might have to get rid of other users of this.
diff --git a/hercules_rt/Cargo.toml b/hercules_rt/Cargo.toml
index c4678b18..46886b12 100644
--- a/hercules_rt/Cargo.toml
+++ b/hercules_rt/Cargo.toml
@@ -6,6 +6,7 @@ edition = "2021"
 
 [features]
 cuda = []
+debug = []
 
 [dependencies]
 
diff --git a/hercules_rt/src/lib.rs b/hercules_rt/src/lib.rs
index ed5dca1d..2ad72043 100644
--- a/hercules_rt/src/lib.rs
+++ b/hercules_rt/src/lib.rs
@@ -10,29 +10,101 @@ use std::slice::{from_raw_parts, from_raw_parts_mut};
  */
 
 pub unsafe fn __cpu_alloc(size: usize) -> *mut u8 {
-    alloc(Layout::from_size_align(size, 16).unwrap())
+    let ptr = alloc(Layout::from_size_align(size, 16).unwrap());
+    if cfg!(feature = "debug") {
+        eprintln!("__cpu_alloc: {:?}, {}", ptr, size);
+        assert!(!ptr.is_null() || size == 0);
+    }
+    ptr
 }
 
 pub unsafe fn __cpu_dealloc(ptr: *mut u8, size: usize) {
+    if cfg!(feature = "debug") {
+        eprintln!("__cpu_dealloc: {:?}, {}", ptr, size);
+        assert!(!ptr.is_null() || size == 0);
+    }
     dealloc(ptr, Layout::from_size_align(size, 16).unwrap())
 }
 
 pub unsafe fn __cpu_zero_mem(ptr: *mut u8, size: usize) {
+    if cfg!(feature = "debug") {
+        eprintln!("__cpu_zero_mem: {:?}, {}", ptr, size);
+        assert!(!ptr.is_null() || size == 0);
+    }
     write_bytes(ptr, 0, size);
 }
 
 pub unsafe fn __copy_cpu_to_cpu(dst: *mut u8, src: *mut u8, size: usize) {
+    if cfg!(feature = "debug") {
+        eprintln!("__copy_cpu_to_cpu: {:?}, {:?}, {}", dst, src, size);
+        assert!((!dst.is_null() && !src.is_null()) || size == 0);
+    }
     copy_nonoverlapping(src, dst, size);
 }
 
+#[cfg(feature = "cuda")]
+pub unsafe fn __cuda_alloc(size: usize) -> *mut u8 {
+    let ptr = ___cuda_alloc(size);
+    if cfg!(feature = "debug") {
+        eprintln!("__cuda_alloc: {:?}, {}", ptr, size);
+        assert!(!ptr.is_null() || size == 0);
+    }
+    ptr
+}
+
+#[cfg(feature = "cuda")]
+pub unsafe fn __cuda_dealloc(ptr: *mut u8, size: usize) {
+    if cfg!(feature = "debug") {
+        eprintln!("__cuda_dealloc: {:?}, {}", ptr, size);
+        assert!(!ptr.is_null() || size == 0);
+    }
+    ___cuda_dealloc(ptr, size);
+}
+
+#[cfg(feature = "cuda")]
+pub unsafe fn __cuda_zero_mem(ptr: *mut u8, size: usize) {
+    if cfg!(feature = "debug") {
+        eprintln!("__cuda_zero_mem: {:?}, {}", ptr, size);
+        assert!(!ptr.is_null() || size == 0);
+    }
+    ___cuda_zero_mem(ptr, size);
+}
+
+#[cfg(feature = "cuda")]
+pub unsafe fn __copy_cpu_to_cuda(dst: *mut u8, src: *mut u8, size: usize) {
+    if cfg!(feature = "debug") {
+        eprintln!("__copy_cpu_to_cuda: {:?}, {:?}, {}", dst, src, size);
+        assert!((!dst.is_null() && !src.is_null()) || size == 0);
+    }
+    ___copy_cpu_to_cuda(dst, src, size);
+}
+
+#[cfg(feature = "cuda")]
+pub unsafe fn __copy_cuda_to_cpu(dst: *mut u8, src: *mut u8, size: usize) {
+    if cfg!(feature = "debug") {
+        eprintln!("__copy_cuda_to_cpu: {:?}, {:?}, {}", dst, src, size);
+        assert!((!dst.is_null() && !src.is_null()) || size == 0);
+    }
+    ___copy_cuda_to_cpu(dst, src, size);
+}
+
+#[cfg(feature = "cuda")]
+pub unsafe fn __copy_cuda_to_cuda(dst: *mut u8, src: *mut u8, size: usize) {
+    if cfg!(feature = "debug") {
+        eprintln!("__copy_cuda_to_cuda: {:?}, {:?}, {}", dst, src, size);
+        assert!((!dst.is_null() && !src.is_null()) || size == 0);
+    }
+    ___copy_cuda_to_cuda(dst, src, size);
+}
+
 #[cfg(feature = "cuda")]
 extern "C" {
-    pub fn __cuda_alloc(size: usize) -> *mut u8;
-    pub fn __cuda_dealloc(ptr: *mut u8, size: usize);
-    pub fn __cuda_zero_mem(ptr: *mut u8, size: usize);
-    pub fn __copy_cpu_to_cuda(dst: *mut u8, src: *mut u8, size: usize);
-    pub fn __copy_cuda_to_cpu(dst: *mut u8, src: *mut u8, size: usize);
-    pub fn __copy_cuda_to_cuda(dst: *mut u8, src: *mut u8, size: usize);
+    pub fn ___cuda_alloc(size: usize) -> *mut u8;
+    pub fn ___cuda_dealloc(ptr: *mut u8, size: usize);
+    pub fn ___cuda_zero_mem(ptr: *mut u8, size: usize);
+    pub fn ___copy_cpu_to_cuda(dst: *mut u8, src: *mut u8, size: usize);
+    pub fn ___copy_cuda_to_cpu(dst: *mut u8, src: *mut u8, size: usize);
+    pub fn ___copy_cuda_to_cuda(dst: *mut u8, src: *mut u8, size: usize);
 }
 
 #[derive(Clone, Debug)]
@@ -155,6 +227,7 @@ impl<'a> HerculesCUDARef<'a> {
     pub fn to_cpu_ref<'b, T>(self, dst: &'b mut [T]) -> HerculesCPURefMut<'b> {
         unsafe {
             let size = self.size;
+            assert_eq!(size, dst.len() * size_of::<T>());
             let ptr = NonNull::new(dst.as_ptr() as *mut u8).unwrap();
             __copy_cuda_to_cpu(ptr.as_ptr(), self.ptr.as_ptr(), size);
             HerculesCPURefMut {
diff --git a/hercules_rt/src/rtdefs.cu b/hercules_rt/src/rtdefs.cu
index 534b297d..50e11fa6 100644
--- a/hercules_rt/src/rtdefs.cu
+++ b/hercules_rt/src/rtdefs.cu
@@ -1,5 +1,5 @@
 extern "C" {
-	void *__cuda_alloc(size_t size) {
+	void *___cuda_alloc(size_t size) {
 		void *ptr = NULL;
 		cudaError_t res = cudaMalloc(&ptr, size);
 		if (res != cudaSuccess) {
@@ -8,24 +8,24 @@ extern "C" {
 		return ptr;
 	}
 
-	void __cuda_dealloc(void *ptr, size_t size) {
+	void ___cuda_dealloc(void *ptr, size_t size) {
 		(void) size;
 		cudaFree(ptr);
 	}
 
-	void __cuda_zero_mem(void *ptr, size_t size) {
+	void ___cuda_zero_mem(void *ptr, size_t size) {
 		cudaMemset(ptr, 0, size);
 	}
 
-	void __copy_cpu_to_cuda(void *dst, void *src, size_t size) {
+	void ___copy_cpu_to_cuda(void *dst, void *src, size_t size) {
 		cudaMemcpy(dst, src, size, cudaMemcpyHostToDevice);
 	}
 
-	void __copy_cuda_to_cpu(void *dst, void *src, size_t size) {
+	void ___copy_cuda_to_cpu(void *dst, void *src, size_t size) {
 		cudaMemcpy(dst, src, size, cudaMemcpyDeviceToHost);
 	}
 
-	void __copy_cuda_to_cuda(void *dst, void *src, size_t size) {
+	void ___copy_cuda_to_cuda(void *dst, void *src, size_t size) {
 		cudaMemcpy(dst, src, size, cudaMemcpyDeviceToDevice);
 	}
 }
diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch
index 0263c275..38010004 100644
--- a/juno_samples/fork_join_tests/src/cpu.sch
+++ b/juno_samples/fork_join_tests/src/cpu.sch
@@ -1,15 +1,21 @@
+no-memset(test6@const);
+
+ccp(*);
 gvn(*);
 phi-elim(*);
 dce(*);
 
-let out = auto-outline(*);
+let out = auto-outline(test1, test2, test3, test4, test5);
 cpu(out.test1);
 cpu(out.test2);
 cpu(out.test3);
+cpu(out.test4);
+cpu(out.test5);
 
 ip-sroa(*);
 sroa(*);
 dce(*);
+ccp(*);
 gvn(*);
 phi-elim(*);
 dce(*);
@@ -23,20 +29,30 @@ gvn(*);
 phi-elim(*);
 dce(*);
 
-gvn(*);
-phi-elim(*);
-dce(*);
-
 fixpoint panic after 20 {
   infer-schedules(*);
 }
-fork-split(*);
+fork-split(out.test1, out.test2, out.test3, out.test4, out.test5);
 gvn(*);
 phi-elim(*);
 dce(*);
-unforkify(*);
+unforkify(out.test1, out.test2, out.test3, out.test4, out.test5);
+ccp(*);
 gvn(*);
 phi-elim(*);
 dce(*);
 
+fork-tile[32, 0, true](test6@loop);
+let out = fork-split(test6@loop);
+//let out = outline(out.test6.fj1);
+let out = auto-outline(test6);
+cpu(out.test6);
+ip-sroa(*);
+sroa(*);
+unforkify(out.test6);
+dce(*);
+ccp(*);
+gvn(*);
+phi-elim(*);
+dce(*);
 gcm(*);
diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn
index 6e5db4cb..806cb0f1 100644
--- a/juno_samples/fork_join_tests/src/fork_join_tests.jn
+++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn
@@ -72,3 +72,12 @@ fn test5(input : i32) -> i32[4] {
   }
   return arr1;
 }
+
+#[entry]
+fn test6(input: i32) -> i32[1024] {
+  @const let arr : i32[1024];
+  @loop for i = 0 to 1024 {
+    arr[i] = i as i32 + input;
+  }
+  return arr;
+}
diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch
index f096ea50..f108e2c1 100644
--- a/juno_samples/fork_join_tests/src/gpu.sch
+++ b/juno_samples/fork_join_tests/src/gpu.sch
@@ -4,17 +4,18 @@ no-memset(test1@const);
 no-memset(test3@const1);
 no-memset(test3@const2);
 no-memset(test3@const3);
+no-memset(test6@const);
 
 gvn(*);
 phi-elim(*);
 dce(*);
 
-let out = auto-outline(*);
-gpu(out.test1);
-gpu(out.test2);
-gpu(out.test3);
-gpu(out.test4);
-gpu(out.test5);
+let auto = auto-outline(test1, test2, test3, test4, test5);
+gpu(auto.test1);
+gpu(auto.test2);
+gpu(auto.test3);
+gpu(auto.test4);
+gpu(auto.test5);
 
 ip-sroa(*);
 sroa(*);
@@ -37,5 +38,18 @@ fixpoint panic after 20 {
   infer-schedules(*);
 }
 
-float-collections(test2, out.test2, test4, out.test4, test5, out.test5);
+fork-tile[32, 0, true](test6@loop);
+let out = fork-split(test6@loop);
+let out = auto-outline(test6);
+gpu(out.test6);
+ip-sroa(*);
+sroa(*);
+dce(*);
+ccp(*);
+gvn(*);
+phi-elim(*);
+dce(*);
+gcm(*);
+
+float-collections(test2, auto.test2, test4, auto.test4, test5, auto.test5);
 gcm(*);
diff --git a/juno_samples/fork_join_tests/src/main.rs b/juno_samples/fork_join_tests/src/main.rs
index 5e848ade..19838fd7 100644
--- a/juno_samples/fork_join_tests/src/main.rs
+++ b/juno_samples/fork_join_tests/src/main.rs
@@ -6,42 +6,47 @@ juno_build::juno!("fork_join_tests");
 
 fn main() {
     #[cfg(not(feature = "cuda"))]
-    let assert = |correct, output: hercules_rt::HerculesCPURefMut<'_>| {
-        assert_eq!(output.as_slice::<i32>(), &correct);
+    let assert = |correct: &Vec<i32>, output: hercules_rt::HerculesCPURefMut<'_>| {
+        assert_eq!(output.as_slice::<i32>(), correct);
     };
 
     #[cfg(feature = "cuda")]
-    let assert = |correct, output: hercules_rt::HerculesCUDARefMut<'_>| {
-        let mut dst = vec![0i32; 16];
+    let assert = |correct: &Vec<i32>, output: hercules_rt::HerculesCUDARefMut<'_>| {
+        let mut dst = vec![0i32; correct.len()];
         let output = output.to_cpu_ref(&mut dst);
-        assert_eq!(output.as_slice::<i32>(), &correct);
+        assert_eq!(output.as_slice::<i32>(), correct);
     };
 
     async_std::task::block_on(async {
         let mut r = runner!(test1);
         let output = r.run(5).await;
         let correct = vec![5i32; 16];
-        assert(correct, output);
+        assert(&correct, output);
 
         let mut r = runner!(test2);
         let output = r.run(3).await;
         let correct = vec![24i32; 16];
-        assert(correct, output);
+        assert(&correct, output);
 
         let mut r = runner!(test3);
         let output = r.run(0).await;
         let correct = vec![11, 10, 9, 10, 9, 8, 9, 8, 7];
-        assert(correct, output);
+        assert(&correct, output);
 
         let mut r = runner!(test4);
         let output = r.run(9).await;
         let correct = vec![63i32; 16];
-        assert(correct, output);
+        assert(&correct, output);
 
         let mut r = runner!(test5);
         let output = r.run(4).await;
         let correct = vec![7i32; 4];
-        assert(correct, output);
+        assert(&correct, output);
+
+        let mut r = runner!(test6);
+        let output = r.run(73).await;
+        let correct = (73i32..73i32+1024i32).collect();
+        assert(&correct, output);
     });
 }
 
diff --git a/juno_scheduler/src/lib.rs b/juno_scheduler/src/lib.rs
index ad9195fb..d4ab432a 100644
--- a/juno_scheduler/src/lib.rs
+++ b/juno_scheduler/src/lib.rs
@@ -1,3 +1,5 @@
+#![feature(exact_size_is_empty)]
+
 use std::collections::{HashMap, HashSet};
 use std::fs::File;
 use std::io::Read;
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index b2845913..9478eb9b 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -182,7 +182,7 @@ pub struct PassManager {
     pub fork_trees: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>,
     pub loops: Option<Vec<LoopTree>>,
     pub reduce_cycles: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>,
-    pub data_nodes_in_fork_joins: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>,
+    pub nodes_in_fork_joins: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>,
     pub reduce_einsums: Option<Vec<(MathEnv, HashMap<NodeID, MathID>)>>,
     pub no_reset_constants: Option<Vec<BTreeSet<NodeID>>>,
     pub collection_objects: Option<CollectionObjects>,
@@ -221,7 +221,7 @@ impl PassManager {
             fork_trees: None,
             loops: None,
             reduce_cycles: None,
-            data_nodes_in_fork_joins: None,
+            nodes_in_fork_joins: None,
             reduce_einsums: None,
             no_reset_constants: None,
             collection_objects: None,
@@ -408,11 +408,11 @@ impl PassManager {
         }
     }
 
-    pub fn make_data_nodes_in_fork_joins(&mut self) {
-        if self.data_nodes_in_fork_joins.is_none() {
+    pub fn make_nodes_in_fork_joins(&mut self) {
+        if self.nodes_in_fork_joins.is_none() {
             self.make_def_uses();
             self.make_fork_join_maps();
-            self.data_nodes_in_fork_joins = Some(
+            self.nodes_in_fork_joins = Some(
                 zip(
                     self.functions.iter(),
                     zip(
@@ -421,7 +421,7 @@ impl PassManager {
                     ),
                 )
                 .map(|(function, (def_use, fork_join_map))| {
-                    data_nodes_in_fork_joins(function, def_use, fork_join_map)
+                    nodes_in_fork_joins(function, def_use, fork_join_map)
                 })
                 .collect(),
             );
@@ -434,12 +434,12 @@ impl PassManager {
             self.make_typing();
             self.make_fork_join_maps();
             self.make_fork_join_nests();
-            self.make_data_nodes_in_fork_joins();
+            self.make_nodes_in_fork_joins();
             let def_uses = self.def_uses.as_ref().unwrap().iter();
             let typing = self.typing.as_ref().unwrap().iter();
             let fork_join_maps = self.fork_join_maps.as_ref().unwrap().iter();
             let fork_join_nests = self.fork_join_nests.as_ref().unwrap().iter();
-            let data_nodes_in_fork_joins = self.data_nodes_in_fork_joins.as_ref().unwrap().iter();
+            let nodes_in_fork_joins = self.nodes_in_fork_joins.as_ref().unwrap().iter();
             self.reduce_einsums = Some(
                 self.functions
                     .iter()
@@ -447,11 +447,11 @@ impl PassManager {
                     .zip(typing)
                     .zip(fork_join_maps)
                     .zip(fork_join_nests)
-                    .zip(data_nodes_in_fork_joins)
+                    .zip(nodes_in_fork_joins)
                     .map(
                         |(
                             ((((function, def_use), typing), fork_join_map), fork_join_nest),
-                            data_nodes_in_fork_joins,
+                            nodes_in_fork_joins,
                         )| {
                             einsum(
                                 function,
@@ -461,7 +461,7 @@ impl PassManager {
                                 typing,
                                 fork_join_map,
                                 fork_join_nest,
-                                data_nodes_in_fork_joins,
+                                nodes_in_fork_joins,
                             )
                         },
                     )
@@ -579,7 +579,7 @@ impl PassManager {
         self.fork_trees = None;
         self.loops = None;
         self.reduce_cycles = None;
-        self.data_nodes_in_fork_joins = None;
+        self.nodes_in_fork_joins = None;
         self.reduce_einsums = None;
         self.no_reset_constants = None;
         self.collection_objects = None;
@@ -809,6 +809,8 @@ impl PassManager {
             let mut nvcc_process = Command::new("nvcc")
                 .arg("-c")
                 .arg("-O3")
+                .arg("-diag-suppress")
+                .arg("177")
                 .arg("-o")
                 .arg(&cuda_object)
                 .arg(&cuda_path)
@@ -1515,6 +1517,7 @@ fn run_pass(
         }
         Pass::ForkSplit => {
             assert!(args.is_empty());
+            let mut created_fork_joins = vec![vec![vec![]]; pm.functions.len()];
             loop {
                 let mut inner_changed = false;
                 pm.make_fork_join_maps();
@@ -1529,17 +1532,80 @@ fn run_pass(
                     let Some(mut func) = func else {
                         continue;
                     };
-                    split_all_forks(&mut func, fork_join_map, reduce_cycles);
+                    if let Some((forks, joins)) =
+                        split_any_fork(&mut func, fork_join_map, reduce_cycles)
+                    {
+                        let created_fork_joins = &mut created_fork_joins[func.func_id().idx()];
+                        if forks.len() > created_fork_joins.len() {
+                            created_fork_joins.resize(forks.len(), vec![]);
+                        }
+                        for (idx, (fork, join)) in zip(forks, joins).enumerate() {
+                            created_fork_joins[idx].push((fork, join));
+                        }
+                    }
                     changed |= func.modified();
                     inner_changed |= func.modified();
                 }
-                pm.delete_gravestones();
                 pm.clear_analyses();
 
                 if !inner_changed {
                     break;
                 }
             }
+
+            pm.make_nodes_in_fork_joins();
+            let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap();
+            let mut new_fork_joins = HashMap::new();
+            for (mut func, created_fork_joins) in
+                build_editors(pm).into_iter().zip(created_fork_joins)
+            {
+                // For every function, create a label for every level of fork-
+                // joins resulting from the split.
+                let name = func.func().name.clone();
+                let func_id = func.func_id();
+                let labels = create_labels_for_node_sets(
+                    &mut func,
+                    created_fork_joins.into_iter().map(|level_fork_joins| {
+                        level_fork_joins
+                            .into_iter()
+                            .map(|(fork, _)| {
+                                nodes_in_fork_joins[func_id.idx()][&fork]
+                                    .iter()
+                                    .map(|id| *id)
+                            })
+                            .flatten()
+                    }),
+                );
+
+                // Assemble those labels into a record for this function. The
+                // format of the records is <function>.<fjN>, where N is the
+                // level of the split fork-joins being referred to.
+                let mut func_record = HashMap::new();
+                for (idx, label) in labels {
+                    func_record.insert(
+                        format!("fj{}", idx),
+                        Value::Label {
+                            labels: vec![LabelInfo {
+                                func: func_id,
+                                label: label,
+                            }],
+                        },
+                    );
+                }
+
+                // Try to avoid creating unnecessary record entries.
+                if !func_record.is_empty() {
+                    new_fork_joins.entry(name).insert_entry(Value::Record {
+                        fields: func_record,
+                    });
+                }
+            }
+
+            pm.delete_gravestones();
+            pm.clear_analyses();
+            result = Value::Record {
+                fields: new_fork_joins,
+            };
         }
         Pass::Forkify => {
             assert!(args.is_empty());
@@ -1914,27 +1980,22 @@ fn run_pass(
         }
         Pass::ForkChunk => {
             assert_eq!(args.len(), 3);
-            let tile_size = args.get(0);
-            let dim_idx = args.get(1);
-
-            let Some(Value::Boolean { val: guarded_flag }) = args.get(2) else {
+            let Some(Value::Integer { val: tile_size }) = args.get(0) else {
                 return Err(SchedulerError::PassError {
                     pass: "forkChunk".to_string(),
-                    error: "expected boolean argument".to_string(),
+                    error: "expected integer argument".to_string(),
                 });
             };
-
             let Some(Value::Integer { val: dim_idx }) = args.get(1) else {
                 return Err(SchedulerError::PassError {
                     pass: "forkChunk".to_string(),
                     error: "expected integer argument".to_string(),
                 });
             };
-
-            let Some(Value::Integer { val: tile_size }) = args.get(0) else {
+            let Some(Value::Boolean { val: guarded_flag }) = args.get(2) else {
                 return Err(SchedulerError::PassError {
                     pass: "forkChunk".to_string(),
-                    error: "expected integer argument".to_string(),
+                    error: "expected boolean argument".to_string(),
                 });
             };
 
@@ -2068,9 +2129,33 @@ fn run_pass(
             // Put BasicBlocks back, since it's needed for Codegen.
             pm.bbs = bbs;
         }
-        Pass::ForkChunk => todo!(),
     }
     println!("Ran Pass: {:?}", pass);
 
     Ok((result, changed))
 }
+
+/// Gives every non-empty node set a fresh label shared by all of its nodes.
+/// Returns `(set index, label)` pairs; empty sets get no label and no entry.
+fn create_labels_for_node_sets<I, J>(
+    editor: &mut FunctionEditor,
+    node_sets: I,
+) -> Vec<(usize, LabelID)>
+where
+    I: Iterator<Item = J>,
+    J: Iterator<Item = NodeID>,
+{
+    let mut created = vec![];
+    editor.edit(|mut edit| {
+        for (idx, mut nodes) in node_sets.map(Iterator::peekable).enumerate() {
+            if nodes.peek().is_none() {
+                continue;
+            }
+            let label = edit.fresh_label();
+            edit = nodes.fold(edit, |edit, node| edit.add_label(node, label).unwrap());
+            created.push((idx, label));
+        }
+        Ok(edit)
+    });
+    created
+}
-- 
GitLab