From 37567e75c3f7adbb96248a12084bb0ed51914f5b Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Tue, 25 Feb 2025 11:30:47 -0600
Subject: [PATCH] Add extra dimensions to arrays when they'd cause a race
 otherwise

---
 hercules_opt/src/gcm.rs | 54 +++++++++++++++++++++++++++++++++++------
 1 file changed, 46 insertions(+), 8 deletions(-)

diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs
index ae8801f6..939f1502 100644
--- a/hercules_opt/src/gcm.rs
+++ b/hercules_opt/src/gcm.rs
@@ -1053,19 +1053,22 @@ fn add_extra_collection_dims(
     devices: &Vec<Device>,
     bbs: &BasicBlocks,
 ) -> bool {
-    if devices[editor.func_id().idx()] == Device::AsyncRust
-        && editor.func().name == "_1_laplacian_estimate"
-    {
+    if devices[editor.func_id().idx()] == Device::AsyncRust {
         // Look for collection constant nodes inside fork-joins that are mutated
         // inside the fork-join, aren't involved in any of the reduces of the
         // fork-join, and have a user that isn't a direct read based on all of
         // the thread IDs.
-        let nodes = &editor.func().nodes;
         let fco = &objects[&editor.func_id()];
-        for id in editor.node_ids().filter(|id| {
-            nodes[id.idx()].is_constant() && !editor.get_type(typing[id.idx()]).is_primitive()
-        }) {
+        let candidates: Vec<_> = editor
+            .node_ids()
+            .filter(|id| {
+                editor.func().nodes[id.idx()].is_constant()
+                    && !editor.get_type(typing[id.idx()]).is_primitive()
+            })
+            .collect();
+        for id in candidates {
             // Check all of the above conditions.
+            let nodes = &editor.func().nodes;
             if editor.get_users(id).len() != 1 {
                 continue;
             }
@@ -1115,7 +1118,42 @@ fn add_extra_collection_dims(
 
             // We know that this collection needs to be replicated across the
             // fork-join dimensions, so do that.
-            todo!()
+            let ty = typing[id.idx()];
+            let num_dims: Vec<_> = forks
+                .into_iter()
+                .rev()
+                .map(|id| nodes[id.idx()].try_fork().unwrap().1.len())
+                .collect();
+            let factors = forks
+                .into_iter()
+                .rev()
+                .flat_map(|id| nodes[id.idx()].try_fork().unwrap().1.into_iter())
+                .map(|dc| *dc)
+                .collect();
+            let array_ty = Type::Array(ty, factors);
+            let success = editor.edit(|mut edit| {
+                let new_ty = edit.add_type(array_ty);
+                let new_cons = edit.add_zero_constant(new_ty);
+                let new_cons = edit.add_node(Node::Constant { id: new_cons });
+                let mut tids = vec![];
+                for (fork, num_dims) in forks.into_iter().rev().zip(num_dims) {
+                    for dim in 0..num_dims {
+                        tids.push(edit.add_node(Node::ThreadID {
+                            control: *fork,
+                            dimension: dim,
+                        }));
+                    }
+                }
+                let read = edit.add_node(Node::Read {
+                    collect: new_cons,
+                    indices: Box::new([Index::Position(tids.into_boxed_slice())]),
+                });
+                edit = edit.replace_all_uses(id, read)?;
+                edit = edit.delete_node(id)?;
+                Ok(edit)
+            });
+            assert!(success);
+            return true;
         }
     }
     false
-- 
GitLab