From ae1863d9976aad58abc1023e667424bf762f8aa4 Mon Sep 17 00:00:00 2001
From: Praneet Rathi <prrathi10@gmail.com>
Date: Mon, 20 Jan 2025 12:19:10 -0600
Subject: [PATCH] still untested

---
 hercules_cg/src/gpu.rs   | 4 ++--
 hercules_opt/src/pass.rs | 3 +++
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/hercules_cg/src/gpu.rs b/hercules_cg/src/gpu.rs
index be797b2a..21d284b3 100644
--- a/hercules_cg/src/gpu.rs
+++ b/hercules_cg/src/gpu.rs
@@ -22,6 +22,7 @@ pub fn gpu_codegen<W: Write>(
     control_subgraph: &Subgraph,
     bbs: &BasicBlocks,
     collection_objects: &FunctionCollectionObjects,
+    fork_join_map: &HashMap<NodeID, NodeID>,
     w: &mut W,
 ) -> Result<(), Error> {
     /*
@@ -80,7 +81,6 @@ pub fn gpu_codegen<W: Write>(
         .collect();
 
 
-    let fork_join_map = fork_join_map(function, control_subgraph);
     let join_fork_map: HashMap<NodeID, NodeID> = fork_join_map
         .iter()
         .map(|(fork, join)| (*join, *fork))
@@ -239,7 +239,7 @@ struct GPUContext<'a> {
     bbs: &'a BasicBlocks,
     kernel_params: &'a GPUKernelParams,
     def_use_map: &'a ImmutableDefUseMap,
-    fork_join_map: HashMap<NodeID, NodeID>,
+    fork_join_map: &'a HashMap<NodeID, NodeID>,
     join_fork_map: HashMap<NodeID, NodeID>,
     fork_reduce_map: HashMap<NodeID, Vec<NodeID>>,
     reduct_reduce_map: HashMap<NodeID, Vec<NodeID>>,
diff --git a/hercules_opt/src/pass.rs b/hercules_opt/src/pass.rs
index bb70bf08..9b4c09aa 100644
--- a/hercules_opt/src/pass.rs
+++ b/hercules_opt/src/pass.rs
@@ -984,11 +984,13 @@ impl PassManager {
                     self.make_control_subgraphs();
                     self.make_collection_objects();
                     self.make_callgraph();
+                    self.make_fork_join_maps();
                     let typing = self.typing.as_ref().unwrap();
                     let control_subgraphs = self.control_subgraphs.as_ref().unwrap();
                     let bbs = self.bbs.as_ref().unwrap();
                     let collection_objects = self.collection_objects.as_ref().unwrap();
                     let callgraph = self.callgraph.as_ref().unwrap();
+                    let fork_join_maps = self.fork_join_maps.as_ref().unwrap();
 
                     let devices = device_placement(&self.module.functions, &callgraph);
 
@@ -1029,6 +1031,7 @@ impl PassManager {
                                 &control_subgraphs[idx],
                                 &bbs[idx],
                                 &collection_objects[&FunctionID::new(idx)],
+                                &fork_join_maps[idx],
                                 &mut cuda_ir,
                             )
                             .unwrap(),
-- 
GitLab