From 70c06a3b461800b509649fd0ed94ac8b11de9847 Mon Sep 17 00:00:00 2001
From: prrathi <prrathi10@gmail.com>
Date: Wed, 22 Jan 2025 16:18:30 +0000
Subject: [PATCH] Factor fork tree construction into a shared module and precompute GPU codegen analyses

---
 hercules_cg/src/fork_tree.rs |  37 ++++++++++
 hercules_cg/src/gpu.rs       | 126 +++++++++++++----------------------
 hercules_cg/src/lib.rs       |   4 ++
 hercules_opt/src/pass.rs     |  52 ++++++++++++++-
 4 files changed, 137 insertions(+), 82 deletions(-)

diff --git a/hercules_cg/src/fork_tree.rs b/hercules_cg/src/fork_tree.rs
index e69de29b..da7f640a 100644
--- a/hercules_cg/src/fork_tree.rs
+++ b/hercules_cg/src/fork_tree.rs
@@ -0,0 +1,37 @@
+use std::collections::{HashMap, HashSet};
+
+use crate::*;
+
+/* Construct a map from each fork node F to all forks satisfying:
+ * a) domination by F
+ * b) no domination by F's join
+ * c) no domination by any other fork that's also dominated by F, where we don't count self-domination
+ * Note that fork_tree also includes the non-fork start node as the unique root node.
+ */
+pub fn fork_tree(function: &Function, fork_join_nesting: &HashMap<NodeID, Vec<NodeID>>) -> HashMap<NodeID, HashSet<NodeID>> {
+    let mut fork_tree = HashMap::new();
+    for (control, forks) in fork_join_nesting {
+        if function.nodes[control.idx()].is_fork() {
+            // Make sure the fork is in fork_tree even if it has no nested forks.
+            fork_tree.entry(*control).or_insert_with(HashSet::new);
+            // Get its nesting fork: index 1, so the fork doesn't count itself.
+            let nesting_fork = forks.get(1).copied().unwrap_or(NodeID::new(0));
+            fork_tree.entry(nesting_fork).or_insert_with(HashSet::new).insert(*control);
+        }
+    }
+    fork_tree
+}
+
+/*
+ * Construct a map from each fork node F to all control nodes (including F itself) satisfying:
+ * a) domination by F
+ * b) no domination by F's join
+ * c) no domination by any other fork that's also dominated by F, where we do count self-domination
+ * Here too we include the non-fork start node, as the key for all controls outside any fork.
+ * A minimal usage sketch appears in the test module at the end of this file.
+ */
+pub fn fork_control_map(fork_join_nesting: &HashMap<NodeID, Vec<NodeID>>) -> HashMap<NodeID, HashSet<NodeID>> {
+    let mut fork_control_map = HashMap::new();
+    for (control, forks) in fork_join_nesting {
+        // The innermost enclosing fork is the first element; controls outside
+        // any fork fall back to the start node.
+        let fork = forks.first().copied().unwrap_or(NodeID::new(0));
+        fork_control_map.entry(fork).or_insert_with(HashSet::new).insert(*control);
+    }
+    fork_control_map
+}
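+
+// A minimal sketch exercising fork_control_map on a hand-built nesting map.
+// No Function value is needed for this helper, so fork_tree is not covered
+// here; the node numbering below is purely illustrative.
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn groups_controls_by_innermost_fork() {
+        // Node 1 is a fork enclosing control node 2; node 3 lies outside
+        // any fork, so it lands under the start-node key (NodeID 0).
+        let mut nesting: HashMap<NodeID, Vec<NodeID>> = HashMap::new();
+        nesting.insert(NodeID::new(1), vec![NodeID::new(1)]);
+        nesting.insert(NodeID::new(2), vec![NodeID::new(1)]);
+        nesting.insert(NodeID::new(3), vec![]);
+        let map = fork_control_map(&nesting);
+        assert!(map[&NodeID::new(1)].contains(&NodeID::new(1)));
+        assert!(map[&NodeID::new(1)].contains(&NodeID::new(2)));
+        assert!(map[&NodeID::new(0)].contains(&NodeID::new(3)));
+    }
+}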
diff --git a/hercules_cg/src/gpu.rs b/hercules_cg/src/gpu.rs
index df95f63f..d960de89 100644
--- a/hercules_cg/src/gpu.rs
+++ b/hercules_cg/src/gpu.rs
@@ -22,7 +22,10 @@ pub fn gpu_codegen<W: Write>(
     control_subgraph: &Subgraph,
     bbs: &BasicBlocks,
     collection_objects: &FunctionCollectionObjects,
+    def_use_map: &ImmutableDefUseMap,
     fork_join_map: &HashMap<NodeID, NodeID>,
+    fork_control_map: &HashMap<NodeID, HashSet<NodeID>>,
+    fork_tree: &HashMap<NodeID, HashSet<NodeID>>,
     w: &mut W,
 ) -> Result<(), Error> {
     /*
@@ -133,13 +136,6 @@ pub fn gpu_codegen<W: Write>(
         }
     }
 
-    // Temporary hardcoded values
-    let kernel_params = &GPUKernelParams {
-        max_num_blocks: 1024,
-        max_num_threads: 1024,
-        threads_per_warp: 32,
-    };
-
     // Map from control to pairs of data to update phi
     // For each phi, we go to its region and get region's controls
     let mut control_data_phi_map: HashMap<NodeID, Vec<(NodeID, NodeID)>> = HashMap::new();
@@ -154,7 +150,18 @@ pub fn gpu_codegen<W: Write>(
         }
     }
 
-    let def_use_map = &def_use(function);
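+    // If the function returns exactly one collection object and that object
+    // originates from a parameter, the host wrapper can return that parameter
+    // directly instead of lifting the return into a kernel argument.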
+    let return_parameter = if collection_objects.returned_objects().len() == 1 {
+        Some(collection_objects.origin(*collection_objects.returned_objects()
+            .first().unwrap()).try_parameter().unwrap())
+    } else {
+        None
+    };
+
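+    // Temporary hardcoded values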
+    let kernel_params = &GPUKernelParams {
+        max_num_blocks: 1024,
+        max_num_threads: 1024,
+        threads_per_warp: 32,
+    };
 
     let ctx = GPUContext {
         function,
@@ -165,13 +172,16 @@ pub fn gpu_codegen<W: Write>(
         control_subgraph,
         bbs,
         collection_objects,
-        kernel_params,
         def_use_map,
         fork_join_map,
+        fork_control_map,
+        fork_tree,
         join_fork_map,
         fork_reduce_map,
         reduct_reduce_map,
         control_data_phi_map,
+        return_parameter,
+        kernel_params,
     };
     ctx.codegen_function(w)
 }
@@ -191,13 +201,16 @@ struct GPUContext<'a> {
     control_subgraph: &'a Subgraph,
     bbs: &'a BasicBlocks,
     collection_objects: &'a FunctionCollectionObjects,
-    kernel_params: &'a GPUKernelParams,
     def_use_map: &'a ImmutableDefUseMap,
     fork_join_map: &'a HashMap<NodeID, NodeID>,
+    fork_control_map: &'a HashMap<NodeID, HashSet<NodeID>>,
+    fork_tree: &'a HashMap<NodeID, HashSet<NodeID>>,
     join_fork_map: HashMap<NodeID, NodeID>,
     fork_reduce_map: HashMap<NodeID, Vec<NodeID>>,
     reduct_reduce_map: HashMap<NodeID, Vec<NodeID>>,
     control_data_phi_map: HashMap<NodeID, Vec<(NodeID, NodeID)>>,
+    return_parameter: Option<usize>,
+    kernel_params: &'a GPUKernelParams,
 }
 
 /*
@@ -248,13 +261,7 @@ impl GPUContext<'_> {
     fn codegen_function<W: Write>(&self, w: &mut W) -> Result<(), Error> {
         // Emit all code up to the "goto" to Start's block
         let mut top = String::new();
-        let return_parameter = if self.collection_objects.returned_objects().len() == 1 {
-            Some(self.collection_objects.origin(*self.collection_objects.returned_objects()
-                .first().unwrap()).try_parameter().unwrap())
-        } else {
-            None
-        };
-        self.codegen_kernel_begin(return_parameter.is_none(), &mut top)?;
+        self.codegen_kernel_begin(self.return_parameter.is_none(), &mut top)?;
         let mut dynamic_shared_offset = "0".to_string();
         self.codegen_dynamic_constants(&mut top)?;
         self.codegen_declare_data(&mut top)?;
@@ -278,11 +285,10 @@ impl GPUContext<'_> {
             (1, 1)
         } else {
             // Create structures and determine block and thread parallelization strategy
-            let (fork_tree, fork_control_map) = self.make_fork_structures(self.fork_join_map);
             let (root_forks, num_blocks) =
-                self.get_root_forks_and_num_blocks(&fork_tree, self.kernel_params.max_num_blocks);
-            let (thread_root_root_fork, thread_root_forks) = self.get_thread_root_forks(&root_forks, &fork_tree, num_blocks);
-            let (fork_thread_quota_map, num_threads) = self.get_thread_quotas(&fork_tree, thread_root_root_fork);
+                self.get_root_forks_and_num_blocks(&self.fork_tree, self.kernel_params.max_num_blocks);
+            let (thread_root_root_fork, thread_root_forks) = self.get_thread_root_forks(&root_forks, &self.fork_tree, num_blocks);
+            let (fork_thread_quota_map, num_threads) = self.get_thread_quotas(&self.fork_tree, thread_root_root_fork);
             // TODO: Uncomment and adjust once we know logic of extra dim
             // let extra_dim_collects = self.get_extra_dim_collects(&fork_control_map, &fork_thread_quota_map);
             let extra_dim_collects = HashSet::new();
@@ -295,8 +301,6 @@ impl GPUContext<'_> {
                     None
                 },
                 &thread_root_forks,
-                &fork_tree,
-                &fork_control_map,
                 &fork_thread_quota_map,
                 &extra_dim_collects,
                 &mut dynamic_shared_offset,
@@ -315,7 +319,7 @@ impl GPUContext<'_> {
 
         // Emit host launch code
         let mut host_launch = String::new();
-        self.codegen_launch_code(num_blocks, num_threads, &dynamic_shared_offset, return_parameter, &mut host_launch)?;
+        self.codegen_launch_code(num_blocks, num_threads, &dynamic_shared_offset, &mut host_launch)?;
         write!(w, "{}", host_launch)?;
 
         Ok(())
@@ -489,7 +493,7 @@ namespace cg = cooperative_groups;
         Ok(())
     }
 
-    fn codegen_launch_code(&self, num_blocks: usize, num_threads: usize, dynamic_shared_offset: &str, return_parameter: Option<usize>, w: &mut String) -> Result<(), Error> {
+    fn codegen_launch_code(&self, num_blocks: usize, num_threads: usize, dynamic_shared_offset: &str, w: &mut String) -> Result<(), Error> {
         // The following steps are for host-side C function arguments, but we also
         // need to pass arguments to kernel, so we keep track of the arguments here.
         let ret_type = self.get_type(self.function.return_type, false);
@@ -521,7 +525,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
             write!(pass_args, "p{}", idx)?;
         }
         write!(w, ") {{\n")?;
-        let has_ret_var = return_parameter.is_none();
+        let has_ret_var = self.return_parameter.is_none();
         if has_ret_var {
             // Allocate return parameter and lift to kernel argument
             let ret_type_pnt = self.get_type(self.function.return_type, true);
@@ -540,46 +544,12 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
             write!(w, "\tcudaMemcpy(&host_ret, ret, sizeof({}), cudaMemcpyDeviceToHost);\n", ret_type)?;
             write!(w, "\treturn host_ret;\n")?;
         } else {
-            write!(w, "\treturn p{};\n", return_parameter.unwrap())?;
+            write!(w, "\treturn p{};\n", self.return_parameter.unwrap())?;
         }
         write!(w, "}}\n")?;
         Ok(())
     }
 
-    /* Create fork_tree, a map from each fork node F to all forks satisfying:
-     * a) domination by F
-     * b) no domination by F's join
-     * c) no domination by any other fork that's also dominated by F, where we don't count self-domination
-     * Note that the fork_tree also includes the start node, to include all controls
-     * outside any fork.
-     *
-     * Second, fork_control_map is a map from fork node to all control nodes (including itself) satisfying:
-     * a) domination by F
-     * b) no domination by F's join
-     * c) no domination by any other fork that's also dominated by F, where we do count self-domination
-     */
-    fn make_fork_structures(&self, fork_join_map: &HashMap<NodeID, NodeID>) -> (HashMap<NodeID, HashSet<NodeID>>, HashMap<NodeID, HashSet<NodeID>>) {
-        let dom = dominator(self.control_subgraph, NodeID::new(0));
-        let fork_nesting = compute_fork_join_nesting(self.function, &dom, fork_join_map);
-        fork_nesting.into_iter().fold(
-            (HashMap::new(), HashMap::new()),
-            |(mut fork_tree, mut fork_control_map), (control, forks)| {
-                if self.function.nodes[control.idx()].is_fork() {
-                    // If control node is fork make sure it's in the fork_tree even
-                    // if has no nested forks.
-                    fork_tree.entry(control).or_insert_with(HashSet::new);
-                    // Then get it's nesting fork- index = 1 to not count itself!
-                    let nesting_fork = forks.get(1).copied().unwrap_or(NodeID::new(0));
-                    fork_tree.entry(nesting_fork).or_insert_with(HashSet::new).insert(control);
-                }
-                // Here the desired fork is always the first fork
-                let fork = forks.first().copied().unwrap_or(NodeID::new(0));
-                fork_control_map.entry(fork).or_insert_with(HashSet::new).insert(control);
-                (fork_tree, fork_control_map)
-            },
-        )
-    }
-
     /*
      * If tree has a single root fork of known size s <= max_num_blocks
      * with parallel-fork schedule, then set num_blocks to s, else set num_blocks
@@ -808,8 +778,6 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
         &self,
         block_fork: Option<NodeID>,
         thread_root_forks: &HashSet<NodeID>,
-        fork_tree: &HashMap<NodeID, HashSet<NodeID>>,
-        fork_control_map: &HashMap<NodeID, HashSet<NodeID>>,
         fork_thread_quota_map: &HashMap<NodeID, (usize, usize, usize)>,
         extra_dim_collects: &HashSet<TypeID>,
         dynamic_shared_offset: &mut String,
@@ -820,7 +788,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
         // First emit data and control gen for each control node outside any fork.
         // Recall that this was tracked through a fake fork node with NodeID 0.
         let mut state = KernelState::OutBlock;
-        for control in fork_control_map.get(&NodeID::new(0)).unwrap() {
+        for control in self.fork_control_map.get(&NodeID::new(0)).unwrap() {
             let goto = gotos.get_mut(control).unwrap();
             let init = &mut goto.init;
             let post_init = &mut goto.post_init;
@@ -834,7 +802,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
         // Then generate data and control for the single block fork if it exists
         if block_fork.is_some() {
             state = KernelState::InBlock;
-            for control in fork_control_map.get(&block_fork.unwrap()).unwrap() {
+            for control in self.fork_control_map.get(&block_fork.unwrap()).unwrap() {
                 let goto = gotos.get_mut(control).unwrap();
                 let init = &mut goto.init;
                 let post_init = &mut goto.post_init;
@@ -852,8 +820,6 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
             self.codegen_data_control_traverse(
                 root_fork,
                 state,
-                fork_tree,
-                fork_control_map,
                 fork_thread_quota_map,
                 1,
                 num_threads,
@@ -875,8 +841,6 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
         &self,
         curr_fork: NodeID,
         state: KernelState,
-        fork_tree: &HashMap<NodeID, HashSet<NodeID>>,
-        fork_control_map: &HashMap<NodeID, HashSet<NodeID>>,
         fork_thread_quota_map: &HashMap<NodeID, (usize, usize, usize)>,
         parent_quota: usize,
         num_threads: usize,
@@ -902,7 +866,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
         } else {
             HashSet::new()
         };
-        for control in fork_control_map.get(&curr_fork).unwrap() {
+        for control in self.fork_control_map.get(&curr_fork).unwrap() {
             let goto = gotos.get_mut(control).unwrap();
             let init = &mut goto.init;
             let post_init = &mut goto.post_init;
@@ -925,12 +889,10 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
                 )?;
             }
         }
-        for child in fork_tree.get(&curr_fork).unwrap() {
+        for child in self.fork_tree.get(&curr_fork).unwrap() {
             self.codegen_data_control_traverse(
                 *child,
                 state,
-                fork_tree,
-                fork_control_map,
                 fork_thread_quota_map,
                 use_thread_quota,
                 num_threads,
@@ -1424,14 +1386,16 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
                 tabs
             }
             Node::Return { control: _, data } => {
-                // Since we lift originally primitive returns into a parameter,
-                // we write to that parameter upon return.
-                let return_val = self.get_value(*data, false, false);
-                let return_type_ptr = self.get_type(self.function.return_type, true);
-                write!(w_term, "\tif (grid.thread_rank() == 0) {{\n")?;
-                write!(w_term, "\t\t*(reinterpret_cast<{}>(ret)) = {};\n", return_type_ptr, return_val)?;
-                write!(w_term, "\t}}\n")?;
-                write!(w_term, "\treturn;\n")?;
+                if self.return_parameter.is_none() {
+                    // Since we lift return into a kernel argument, we write to that
+                    // argument upon return.
+                    let return_val = self.get_value(*data, false, false);
+                    let return_type_ptr = self.get_type(self.function.return_type, true);
+                    write!(w_term, "\tif (grid.thread_rank() == 0) {{\n")?;
+                    write!(w_term, "\t\t*(reinterpret_cast<{}>(ret)) = {};\n", return_type_ptr, return_val)?;
+                    write!(w_term, "\t}}\n")?;
+                    write!(w_term, "\treturn;\n")?;
+                }
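+                // Otherwise the host wrapper returns the lifted parameter
+                // directly, so no device-side return write is needed.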
                 1
             }
             _ => {
diff --git a/hercules_cg/src/lib.rs b/hercules_cg/src/lib.rs
index e41f0205..fbab6dbc 100644
--- a/hercules_cg/src/lib.rs
+++ b/hercules_cg/src/lib.rs
@@ -5,11 +5,15 @@ pub mod gpu;
 pub mod device;
 pub mod rt;
 
+pub mod fork_tree;
+
 pub use crate::cpu::*;
 pub use crate::gpu::*;
 pub use crate::device::*;
 pub use crate::rt::*;
 
+pub use crate::fork_tree::*;
+
 use hercules_ir::*;
 
 /*
diff --git a/hercules_opt/src/pass.rs b/hercules_opt/src/pass.rs
index 9b4c09aa..295b8bcb 100644
--- a/hercules_opt/src/pass.rs
+++ b/hercules_opt/src/pass.rs
@@ -68,6 +68,8 @@ pub struct PassManager {
     pub postdoms: Option<Vec<DomTree>>,
     pub fork_join_maps: Option<Vec<HashMap<NodeID, NodeID>>>,
     pub fork_join_nests: Option<Vec<HashMap<NodeID, Vec<NodeID>>>>,
+    pub fork_control_maps: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>,
+    pub fork_trees: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>,
     pub loops: Option<Vec<LoopTree>>,
     pub reduce_cycles: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>,
     pub data_nodes_in_fork_joins: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>,
@@ -89,6 +91,8 @@ impl PassManager {
             postdoms: None,
             fork_join_maps: None,
             fork_join_nests: None,
+            fork_control_maps: None,
+            fork_trees: None,
             loops: None,
             reduce_cycles: None,
             data_nodes_in_fork_joins: None,
@@ -204,6 +208,31 @@ impl PassManager {
         }
     }
 
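+    // Lazily compute, for each function, the map from each fork (plus the
+    // start node) to the control nodes it immediately contains.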
+    pub fn make_fork_control_maps(&mut self) {
+        if self.fork_control_maps.is_none() {
+            self.make_fork_join_nests();
+            self.fork_control_maps = Some(
+                self.fork_join_nests.as_ref().unwrap().iter().map(fork_control_map).collect(),
+            );
+        }
+    }
+
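+    // Lazily compute, for each function, the fork nesting tree rooted at the
+    // start node.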
+    pub fn make_fork_trees(&mut self) {
+        if self.fork_trees.is_none() {
+            self.make_fork_join_nests();
+            self.fork_trees = Some(
+                zip(
+                    self.module.functions.iter(),
+                    self.fork_join_nests.as_ref().unwrap().iter(),
+                )
+                .map(|(function, fork_join_nesting)| {
+                    fork_tree(function, fork_join_nesting)
+                })
+                .collect(),
+            );
+        }
+    }
+
     pub fn make_loops(&mut self) {
         if self.loops.is_none() {
             self.make_control_subgraphs();
@@ -985,14 +1014,20 @@ impl PassManager {
                     self.make_collection_objects();
                     self.make_callgraph();
                     self.make_fork_join_maps();
+                    self.make_fork_control_maps();
+                    self.make_fork_trees();
+                    self.make_def_uses();
                     let typing = self.typing.as_ref().unwrap();
                     let control_subgraphs = self.control_subgraphs.as_ref().unwrap();
                     let bbs = self.bbs.as_ref().unwrap();
                     let collection_objects = self.collection_objects.as_ref().unwrap();
                     let callgraph = self.callgraph.as_ref().unwrap();
+                    let def_uses = self.def_uses.as_ref().unwrap();
                     let fork_join_maps = self.fork_join_maps.as_ref().unwrap();
+                    let fork_control_maps = self.fork_control_maps.as_ref().unwrap();
+                    let fork_trees = self.fork_trees.as_ref().unwrap();
 
-                    let devices = device_placement(&self.module.functions, &callgraph);
+                    let devices = device_placement(&self.module.functions, callgraph);
 
                     let mut rust_rt = String::new();
                     let mut llvm_ir = String::new();
@@ -1031,7 +1066,10 @@ impl PassManager {
                                 &control_subgraphs[idx],
                                 &bbs[idx],
                                 &collection_objects[&FunctionID::new(idx)],
+                                &def_uses[idx],
                                 &fork_join_maps[idx],
+                                &fork_control_maps[idx],
+                                &fork_trees[idx],
                                 &mut cuda_ir,
                             )
                             .unwrap(),
@@ -1082,6 +1120,12 @@ impl PassManager {
                         file.write_all(cuda_ir.as_bytes())
                             .expect("PANIC: Unable to write output CUDA IR file contents.");
 
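+                        // Also write the generated CUDA source to
+                        // <module_name>.cu so it can be inspected directly.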
+                        let cuda_text_path = format!("{}.cu", module_name);
+                        let mut cuda_text_file = File::create(&cuda_text_path)
+                            .expect("PANIC: Unable to open CUDA IR text file.");
+                        cuda_text_file.write_all(cuda_ir.as_bytes())
+                            .expect("PANIC: Unable to write CUDA IR text file contents.");
+
                         let mut nvcc_process = Command::new("nvcc")
                             .arg("-c")
                             .arg("-O3")
@@ -1109,6 +1153,12 @@ impl PassManager {
                     file.write_all(rust_rt.as_bytes())
                         .expect("PANIC: Unable to write output Rust runtime file contents.");
 
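+                    // Also write the generated Rust runtime to
+                    // <module_name>.hrt so it can be inspected directly.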
+                    let rt_text_path = format!("{}.hrt", module_name);
+                    let mut rt_text_file = File::create(&rt_text_path)
+                        .expect("PANIC: Unable to open Rust runtime text file.");
+                    rt_text_file.write_all(rust_rt.as_bytes())
+                        .expect("PANIC: Unable to write Rust runtime text file contents.");
+
                 }
                 Pass::Serialize(output_file) => {
                     let module_contents: Vec<u8> = postcard::to_allocvec(&self.module).unwrap();
-- 
GitLab