From 2b60a19fddaf6cb7f3f2c46cff3212e73ca5e2a7 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Tue, 25 Feb 2025 11:50:59 -0600
Subject: [PATCH] Allocate more memory for calls inside fork-joins

---
 hercules_opt/src/gcm.rs | 23 +++++++++++++++++++++--
 1 file changed, 21 insertions(+), 2 deletions(-)

diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs
index 939f1502..b415371f 100644
--- a/hercules_opt/src/gcm.rs
+++ b/hercules_opt/src/gcm.rs
@@ -152,6 +152,7 @@ pub fn gcm(
     let backing_allocation = object_allocation(
         editor,
         typing,
+        fork_join_nest,
         &node_colors,
         &alignments,
         &liveness,
@@ -1148,6 +1149,7 @@ fn add_extra_collection_dims(
                     collect: new_cons,
                     indices: Box::new([Index::Position(tids.into_boxed_slice())]),
                 });
+                edit.sub_edit(id, new_cons);
                 edit = edit.replace_all_uses(id, read)?;
                 edit = edit.delete_node(id)?;
                 Ok(edit)
@@ -1639,6 +1641,7 @@ fn type_size(edit: &mut FunctionEdit, ty_id: TypeID, alignments: &Vec<usize>) ->
 fn object_allocation(
     editor: &mut FunctionEditor,
     typing: &Vec<TypeID>,
+    fork_join_nest: &HashMap<NodeID, Vec<NodeID>>,
     node_colors: &FunctionNodeColors,
     alignments: &Vec<usize>,
     _liveness: &Liveness,
@@ -1664,7 +1667,7 @@ fn object_allocation(
                     }
                 }
                 Node::Call {
-                    control: _,
+                    control,
                     function: callee,
                     ref dynamic_constants,
                     args: _,
@@ -1694,9 +1697,25 @@ fn object_allocation(
                                 callee_backing_size,
                                 &mut edit,
                             );
+                            // Multiply the backing allocation size of the
+                            // callee by the number of parallel threads that
+                            // will call the function.
+                            let forks = &fork_join_nest[&control];
+                            let factors: Vec<_> = forks
+                                .into_iter()
+                                .rev()
+                                .flat_map(|id| edit.get_node(*id).try_fork().unwrap().1.into_iter())
+                                .map(|dc| *dc)
+                                .collect();
+                            let mut multiplied_callee_backing_size = callee_backing_size;
+                            for factor in factors {
+                                multiplied_callee_backing_size = edit.add_dynamic_constant(
+                                    DynamicConstant::mul(multiplied_callee_backing_size, factor),
+                                );
+                            }
                             *total = edit.add_dynamic_constant(DynamicConstant::add(
                                 *total,
-                                callee_backing_size,
+                                multiplied_callee_backing_size,
                             ));
                         }
                     }
-- 
GitLab