From af3175773528d77fa06b280e027029943e2539da Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Tue, 25 Feb 2025 12:34:37 -0600
Subject: [PATCH] Use different allocations in parallel calls

---
 hercules_cg/src/cpu.rs                  |  2 +-
 hercules_cg/src/gpu.rs                  |  2 +-
 hercules_cg/src/lib.rs                  | 11 +++++++---
 hercules_cg/src/rt.rs                   | 29 +++++++++++++++++++------
 hercules_opt/src/gcm.rs                 |  6 ++---
 juno_samples/edge_detection/src/cpu.sch |  4 ++--
 juno_scheduler/src/pm.rs                |  4 +++-
 7 files changed, 40 insertions(+), 18 deletions(-)

diff --git a/hercules_cg/src/cpu.rs b/hercules_cg/src/cpu.rs
index 6ad38fc0..552dc3a3 100644
--- a/hercules_cg/src/cpu.rs
+++ b/hercules_cg/src/cpu.rs
@@ -334,7 +334,7 @@ impl<'a> CPUContext<'a> {
                     }
                 } else {
                     let (_, offsets) = &self.backing_allocation[&Device::LLVM];
-                    let offset = offsets[&id];
+                    let offset = offsets[&id].0;
                     write!(
                         body,
                         "  {} = getelementptr i8, ptr %backing, i64 %dc{}\n",
diff --git a/hercules_cg/src/gpu.rs b/hercules_cg/src/gpu.rs
index 76aba7e0..5f2feedd 100644
--- a/hercules_cg/src/gpu.rs
+++ b/hercules_cg/src/gpu.rs
@@ -1290,7 +1290,7 @@ namespace cg = cooperative_groups;
                 if !is_primitive && state == KernelState::OutBlock {
                     assert!(self.function.schedules[id.idx()].contains(&Schedule::NoResetConstant), "PANIC: The CUDA backend cannot lower a global memory constant that has to be reset to zero. This is because we cannot efficiently implement a memset to the underlying memory of the constant due to the need for a grid level sync. Consider floating this collection outside the CUDA function and into an AsyncRust function, or attaching the NoResetConstant schedule to indicate that no memset is semantically necessary.");
                     let (_, offsets) = &self.backing_allocation[&Device::CUDA];
-                    let offset = offsets[&id];
+                    let offset = offsets[&id].0;
                     write!(
                         w,
                         "{}{} = backing + dc{};\n",
diff --git a/hercules_cg/src/lib.rs b/hercules_cg/src/lib.rs
index 98f91e1f..9866400c 100644
--- a/hercules_cg/src/lib.rs
+++ b/hercules_cg/src/lib.rs
@@ -53,10 +53,15 @@ pub type NodeColors = BTreeMap<FunctionID, FunctionNodeColors>;
 /*
  * The allocation information of each function is a size of the backing memory
  * needed and offsets into that backing memory per constant object and call node
- * in the function.
+ * in the function, each paired with the size of its individual allocation.
  */
-pub type FunctionBackingAllocation =
-    BTreeMap<Device, (DynamicConstantID, BTreeMap<NodeID, DynamicConstantID>)>;
+pub type FunctionBackingAllocation = BTreeMap<
+    Device,
+    (
+        DynamicConstantID,
+        BTreeMap<NodeID, (DynamicConstantID, DynamicConstantID)>,
+    ),
+>;
 pub type BackingAllocations = BTreeMap<FunctionID, FunctionBackingAllocation>;
 pub const BACKED_DEVICES: [Device; 2] = [Device::LLVM, Device::CUDA];
 
diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs
index ddfa9503..3db0f16f 100644
--- a/hercules_cg/src/rt.rs
+++ b/hercules_cg/src/rt.rs
@@ -80,7 +80,7 @@ pub fn rt_codegen<W: Write>(
     typing: &Vec<TypeID>,
     control_subgraph: &Subgraph,
     fork_join_map: &HashMap<NodeID, NodeID>,
-    fork_control_map: &HashMap<NodeID, HashSet<NodeID>>,
+    fork_join_nest: &HashMap<NodeID, Vec<NodeID>>,
     fork_tree: &HashMap<NodeID, HashSet<NodeID>>,
     nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
     collection_objects: &CollectionObjects,
@@ -103,7 +103,7 @@ pub fn rt_codegen<W: Write>(
         control_subgraph,
         fork_join_map,
         join_fork_map: &join_fork_map,
-        fork_control_map,
+        fork_join_nest,
         fork_tree,
         nodes_in_fork_joins,
         collection_objects,
@@ -124,7 +124,7 @@ struct RTContext<'a> {
     control_subgraph: &'a Subgraph,
     fork_join_map: &'a HashMap<NodeID, NodeID>,
     join_fork_map: &'a HashMap<NodeID, NodeID>,
-    fork_control_map: &'a HashMap<NodeID, HashSet<NodeID>>,
+    fork_join_nest: &'a HashMap<NodeID, Vec<NodeID>>,
     fork_tree: &'a HashMap<NodeID, HashSet<NodeID>>,
     nodes_in_fork_joins: &'a HashMap<NodeID, HashSet<NodeID>>,
     collection_objects: &'a CollectionObjects,
@@ -559,7 +559,7 @@ impl<'a> RTContext<'a> {
                     Constant::Product(ty, _)
                     | Constant::Summation(ty, _, _)
                     | Constant::Array(ty) => {
-                        let (device, offset) = self.backing_allocations[&self.func_id]
+                        let (device, (offset, _)) = self.backing_allocations[&self.func_id]
                             .iter()
                             .filter_map(|(device, (_, offsets))| {
                                 offsets.get(&id).map(|id| (*device, *id))
@@ -676,13 +676,28 @@ impl<'a> RTContext<'a> {
                     prefix,
                     self.module.functions[callee_id.idx()].name
                 )?;
-                for (device, offset) in self.backing_allocations[&self.func_id]
+                for (device, (offset, size)) in self.backing_allocations[&self.func_id]
                     .iter()
                     .filter_map(|(device, (_, offsets))| offsets.get(&id).map(|id| (*device, *id)))
                 {
-                    write!(block, "backing_{}.byte_add(", device.name())?;
+                    write!(block, "backing_{}.byte_add(((", device.name())?;
                     self.codegen_dynamic_constant(offset, block)?;
-                    write!(block, " as usize), ")?
+                    let forks = &self.fork_join_nest[&bb];
+                    if !forks.is_empty() {
+                        write!(block, ") + ")?;
+                        let mut linear_thread = "0".to_string();
+                        for fork in forks {
+                            let factors = func.nodes[fork.idx()].try_fork().unwrap().1;
+                            for (factor_idx, factor) in factors.into_iter().enumerate() {
+                                linear_thread = format!("({} *", linear_thread);
+                                self.codegen_dynamic_constant(*factor, &mut linear_thread)?;
+                                write!(linear_thread, " + tid_{}_{})", fork.idx(), factor_idx)?;
+                            }
+                        }
+                        write!(block, "{} * (", linear_thread)?;
+                        self.codegen_dynamic_constant(size, block)?;
+                    }
+                    write!(block, ")) as usize), ")?
                 }
                 for dc in dynamic_constants {
                     self.codegen_dynamic_constant(*dc, block)?;
diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs
index b415371f..c612acac 100644
--- a/hercules_opt/src/gcm.rs
+++ b/hercules_opt/src/gcm.rs
@@ -1647,7 +1647,7 @@ fn object_allocation(
     _liveness: &Liveness,
     backing_allocations: &BackingAllocations,
 ) -> FunctionBackingAllocation {
-    let mut fba = BTreeMap::new();
+    let mut fba = FunctionBackingAllocation::new();
 
     let node_ids = editor.node_ids();
     editor.edit(|mut edit| {
@@ -1661,8 +1661,8 @@ fn object_allocation(
                         let (total, offsets) =
                             fba.entry(device).or_insert_with(|| (zero, BTreeMap::new()));
                         *total = align(&mut edit, *total, alignments[typing[id.idx()].idx()]);
-                        offsets.insert(id, *total);
                         let type_size = type_size(&mut edit, typing[id.idx()], alignments);
+                        offsets.insert(id, (*total, type_size));
                         *total = edit.add_dynamic_constant(DynamicConstant::add(*total, type_size));
                     }
                 }
@@ -1689,7 +1689,6 @@ fn object_allocation(
                             // We don't know the alignment requirement of the memory
                             // in the callee, so just assume the largest alignment.
                             *total = align(&mut edit, *total, LARGEST_ALIGNMENT);
-                            offsets.insert(id, *total);
                             // Substitute the dynamic constant parameters in the
                             // callee's backing size.
                             callee_backing_size = substitute_dynamic_constants(
@@ -1697,6 +1696,7 @@ fn object_allocation(
                                 callee_backing_size,
                                 &mut edit,
                             );
+                            offsets.insert(id, (*total, callee_backing_size));
                             // Multiply the backing allocation size of the
                             // callee by the number of parallel threads that
                             // will call the function.
diff --git a/juno_samples/edge_detection/src/cpu.sch b/juno_samples/edge_detection/src/cpu.sch
index 6f1ee14b..4bd3254b 100644
--- a/juno_samples/edge_detection/src/cpu.sch
+++ b/juno_samples/edge_detection/src/cpu.sch
@@ -107,8 +107,8 @@ simpl!(reject_zero_crossings);
 
 async-call(edge_detection@le, edge_detection@zc);
 
-fork-split(gaussian_smoothing_body, laplacian_estimate, laplacian_estimate_body, zero_crossings, zero_crossings_body, gradient, reject_zero_crossings);
-unforkify(gaussian_smoothing_body, laplacian_estimate, laplacian_estimate_body, zero_crossings, zero_crossings_body, gradient, reject_zero_crossings);
+fork-split(gaussian_smoothing_body, laplacian_estimate_body, zero_crossings_body, gradient, reject_zero_crossings);
+unforkify(gaussian_smoothing_body, laplacian_estimate_body, zero_crossings_body, gradient, reject_zero_crossings);
 
 simpl!(*);
 
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 77437a61..5f2fa4cc 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -900,6 +900,7 @@ impl PassManager {
         self.make_typing();
         self.make_control_subgraphs();
         self.make_fork_join_maps();
+        self.make_fork_join_nests();
         self.make_fork_control_maps();
         self.make_fork_trees();
         self.make_nodes_in_fork_joins();
@@ -917,6 +918,7 @@ impl PassManager {
             typing: Some(typing),
             control_subgraphs: Some(control_subgraphs),
             fork_join_maps: Some(fork_join_maps),
+            fork_join_nests: Some(fork_join_nests),
             fork_control_maps: Some(fork_control_maps),
             fork_trees: Some(fork_trees),
             nodes_in_fork_joins: Some(nodes_in_fork_joins),
@@ -990,7 +992,7 @@ impl PassManager {
                     &typing[idx],
                     &control_subgraphs[idx],
                     &fork_join_maps[idx],
-                    &fork_control_maps[idx],
+                    &fork_join_nests[idx],
                     &fork_trees[idx],
                     &nodes_in_fork_joins[idx],
                     &collection_objects,
-- 
GitLab