From ae8f17eb21ca53bbd1d9318518c4c2d62a3e1437 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Tue, 25 Feb 2025 10:53:22 -0600
Subject: [PATCH] Look for arrays that are used in parallel

---
 hercules_cg/src/lib.rs         |  2 +-
 hercules_ir/src/collections.rs |  3 +-
 hercules_opt/src/gcm.rs        | 94 ++++++++++++++++++++++++++++++++++
 hercules_rt/src/lib.rs         |  2 +-
 juno_scheduler/src/pm.rs       |  3 ++
 5 files changed, 101 insertions(+), 3 deletions(-)

diff --git a/hercules_cg/src/lib.rs b/hercules_cg/src/lib.rs
index 446231de..98f91e1f 100644
--- a/hercules_cg/src/lib.rs
+++ b/hercules_cg/src/lib.rs
@@ -16,7 +16,7 @@ use std::collections::BTreeMap;
 
 use hercules_ir::*;
 
-pub const LARGEST_ALIGNMENT: usize = 32;
+pub const LARGEST_ALIGNMENT: usize = 64;
 
 /*
  * The alignment of a type does not depend on dynamic constants.
diff --git a/hercules_ir/src/collections.rs b/hercules_ir/src/collections.rs
index 60f4fb1c..a8fc3c67 100644
--- a/hercules_ir/src/collections.rs
+++ b/hercules_ir/src/collections.rs
@@ -202,7 +202,7 @@ pub fn collection_objects(
                     let fco = &collection_objects[&callee];
                     if fco.returned[*selection]
                         .iter()
-                        .any(|returned| fco.origins[returned.idx()].try_parameter().is_some())
+                        .any(|returned| fco.origins[returned.idx()].try_parameter().is_none())
                     {
                         // If the callee may return a new collection object, then
                         // this data projection node originates a single collection object. The
@@ -412,6 +412,7 @@ pub fn collection_objects(
             }
         }
 
+        assert_eq!(objects_per_node.len(), func.nodes.len());
         let fco = FunctionCollectionObjects {
             objects_per_node,
             mutated,
diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs
index d3119705..ae8801f6 100644
--- a/hercules_opt/src/gcm.rs
+++ b/hercules_opt/src/gcm.rs
@@ -82,6 +82,7 @@ pub fn gcm(
     control_subgraph: &Subgraph,
     dom: &DomTree,
     fork_join_map: &HashMap<NodeID, NodeID>,
+    fork_join_nest: &HashMap<NodeID, Vec<NodeID>>,
     loops: &LoopTree,
     reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
     objects: &CollectionObjects,
@@ -120,6 +121,18 @@ pub fn gcm(
         return None;
     }
 
+    if add_extra_collection_dims(
+        editor,
+        typing,
+        fork_join_map,
+        fork_join_nest,
+        objects,
+        devices,
+        &bbs,
+    ) {
+        return None;
+    }
+
     let Some(node_colors) = color_nodes(editor, typing, &objects, &devices, node_colors) else {
         return None;
     };
@@ -1027,6 +1040,87 @@ fn spill_clones(
     }
 }
 
+/*
+ * Look for mutated collections placed inside fork-joins in AsyncRust functions.
+ * These collections should be duplicated across the size of the fork-join.
+ */
+fn add_extra_collection_dims(
+    editor: &mut FunctionEditor,
+    typing: &Vec<TypeID>,
+    fork_join_map: &HashMap<NodeID, NodeID>,
+    fork_join_nest: &HashMap<NodeID, Vec<NodeID>>,
+    objects: &CollectionObjects,
+    devices: &Vec<Device>,
+    bbs: &BasicBlocks,
+) -> bool {
+    if devices[editor.func_id().idx()] == Device::AsyncRust
+        && editor.func().name == "_1_laplacian_estimate"
+    {
+        // Look for collection constant nodes inside fork-joins that are mutated
+        // inside the fork-join, aren't involved in any of the reduces of the
+        // fork-join, and have a user that isn't a direct read based on all of
+        // the thread IDs.
+        let nodes = &editor.func().nodes;
+        let fco = &objects[&editor.func_id()];
+        for id in editor.node_ids().filter(|id| {
+            nodes[id.idx()].is_constant() && !editor.get_type(typing[id.idx()]).is_primitive()
+        }) {
+            // Check all of the above conditions.
+            if editor.get_users(id).len() != 1 {
+                continue;
+            }
+            let forks = &fork_join_nest[&bbs.0[id.idx()]];
+            if forks.is_empty() {
+                continue;
+            }
+            let object = fco.objects(id)[0];
+            let mutated_inside = fco
+                .mutators(object)
+                .into_iter()
+                .any(|id| &fork_join_nest[&bbs.0[id.idx()]] == forks);
+            if !mutated_inside {
+                continue;
+            }
+            let in_reduce = forks.into_iter().any(|id| {
+                let join = fork_join_map[id];
+                let mut reduces = editor
+                    .get_users(join)
+                    .filter(|id| nodes[id.idx()].is_reduce());
+                reduces.any(|id| fco.objects(id).contains(&object))
+            });
+            if in_reduce {
+                continue;
+            }
+            if let Node::Read {
+                collect: _,
+                ref indices,
+            } = nodes[editor.get_users(id).next().unwrap().idx()]
+                && let Index::Position(ref pos) = indices[0]
+                && {
+                    let tid_pos: BTreeSet<(NodeID, usize)> = pos
+                        .into_iter()
+                        .filter_map(|id| nodes[id.idx()].try_thread_id())
+                        .collect();
+                    let reference: BTreeSet<(NodeID, usize)> = forks
+                        .into_iter()
+                        .flat_map(|id| {
+                            (0..nodes[id.idx()].try_fork().unwrap().1.len()).map(|dim| (*id, dim))
+                        })
+                        .collect();
+                    tid_pos == reference
+                }
+            {
+                continue;
+            }
+
+            // We know that this collection needs to be replicated across the
+            // fork-join dimensions, so do that.
+            todo!()
+        }
+    }
+    false
+}
+
 type Liveness = BTreeMap<NodeID, Vec<BTreeSet<NodeID>>>;
 
 /*
diff --git a/hercules_rt/src/lib.rs b/hercules_rt/src/lib.rs
index df53a0e9..a5954ca0 100644
--- a/hercules_rt/src/lib.rs
+++ b/hercules_rt/src/lib.rs
@@ -13,7 +13,7 @@ use std::sync::OnceLock;
  * src/rt.rs (the RT backend).
  */
 
-pub const LARGEST_ALIGNMENT: usize = 32;
+pub const LARGEST_ALIGNMENT: usize = 64;
 
 pub unsafe fn __cpu_alloc(size: usize) -> *mut u8 {
     let ptr = alloc(Layout::from_size_align(size, LARGEST_ALIGNMENT).unwrap());
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 4656d841..77437a61 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -2105,6 +2105,7 @@ fn run_pass(
                 pm.make_control_subgraphs();
                 pm.make_doms();
                 pm.make_fork_join_maps();
+                pm.make_fork_join_nests();
                 pm.make_loops();
                 pm.make_reduce_cycles();
                 pm.make_collection_objects();
@@ -2115,6 +2116,7 @@ fn run_pass(
                 let typing = pm.typing.take().unwrap();
                 let doms = pm.doms.take().unwrap();
                 let fork_join_maps = pm.fork_join_maps.take().unwrap();
+                let fork_join_nests = pm.fork_join_nests.take().unwrap();
                 let loops = pm.loops.take().unwrap();
                 let reduce_cycles = pm.reduce_cycles.take().unwrap();
                 let control_subgraphs = pm.control_subgraphs.take().unwrap();
@@ -2136,6 +2138,7 @@ fn run_pass(
                         &control_subgraphs[id.idx()],
                         &doms[id.idx()],
                         &fork_join_maps[id.idx()],
+                        &fork_join_nests[id.idx()],
                         &loops[id.idx()],
                         &reduce_cycles[id.idx()],
                         &collection_objects,
-- 
GitLab