From 32ad1e8299a1b903b2855c72aa9809e520280f74 Mon Sep 17 00:00:00 2001
From: Praneet Rathi <prrathi10@gmail.com>
Date: Mon, 20 Jan 2025 12:16:55 -0600
Subject: [PATCH] gpu: own analysis maps in GPUContext and lower all collections to char* (untested)

---
 hercules_cg/src/gpu.rs | 149 ++++++++++++++++-------------------------
 1 file changed, 57 insertions(+), 92 deletions(-)

diff --git a/hercules_cg/src/gpu.rs b/hercules_cg/src/gpu.rs
index bb28db8a..be797b2a 100644
--- a/hercules_cg/src/gpu.rs
+++ b/hercules_cg/src/gpu.rs
@@ -80,17 +80,17 @@ pub fn gpu_codegen<W: Write>(
         .collect();
 
 
-    let fork_join_map = &fork_join_map(function, control_subgraph);
-    let join_fork_map: &HashMap<NodeID, NodeID> = &fork_join_map
-        .into_iter()
+    let fork_join_map = fork_join_map(function, control_subgraph);
+    let join_fork_map: HashMap<NodeID, NodeID> = fork_join_map
+        .iter()
         .map(|(fork, join)| (*join, *fork))
         .collect();
     // Fork Reduce map should have all reduces contained in some key
-    let fork_reduce_map: &mut HashMap<NodeID, Vec<NodeID>> = &mut HashMap::new();
+    let mut fork_reduce_map: HashMap<NodeID, Vec<NodeID>> = HashMap::new();
     // Reduct Reduce map should have all non-parallel and non-associative reduces
     // contained in some key. Unlike Fork, Reduct is not involved in any assertions.
     // It's placed here for convenience but can be moved.
-    let reduct_reduce_map: &mut HashMap<NodeID, Vec<NodeID>> = &mut HashMap::new();
+    let mut reduct_reduce_map: HashMap<NodeID, Vec<NodeID>> = HashMap::new();
     for reduce_node in &reduce_nodes {
         if let Node::Reduce {
             control,
@@ -124,11 +124,13 @@ pub fn gpu_codegen<W: Write>(
         }
     }
     for idx in 0..function.nodes.len() {
-        if function.nodes[idx].is_fork()
-            && fork_reduce_map
-            .get(&NodeID::new(idx)).is_none_or(|reduces| reduces.is_empty())
-        {
-            panic!("Fork node {} has no reduce nodes", idx);
+        if function.nodes[idx].is_fork() {
+            assert!(fork_reduce_map
+                .get(&NodeID::new(idx))
+                .is_some_and(|reduces| !reduces.is_empty()),
+                "Fork node {} has no reduce nodes",
+                idx
+            );
         }
     }
 
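
Note on the hunk above: `assert!` panics when its condition is *false*, while
the replaced code panicked when the bad state was *true*, so the condition must
be the good state (`is_some_and` with a non-empty check), not the original
`is_none_or` test. A minimal sketch with stand-in types (the real map holds a
`Vec<NodeID>` per fork):

    // assert! panics on false, so we assert the good state: an entry with
    // at least one reduce. is_none_or captures the bad state instead.
    fn check(reduces: Option<&Vec<u32>>, idx: usize) {
        let bad = reduces.is_none_or(|r| r.is_empty());
        let good = reduces.is_some_and(|r| !r.is_empty());
        debug_assert_eq!(bad, !good);
        assert!(good, "Fork node {} has no reduce nodes", idx);
    }
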
@@ -155,7 +157,7 @@ pub fn gpu_codegen<W: Write>(
         (NodeID::new(pos), *data)
     };
 
-    let return_type_id = &typing[data_node_id.idx()];
+    let return_type_id = typing[data_node_id.idx()];
     let return_type = &types[return_type_id.idx()];
     let return_param_idx = if !return_type.is_primitive() {
         let objects = &collection_objects.objects(data_node_id);
@@ -186,7 +188,7 @@ pub fn gpu_codegen<W: Write>(
 
     // Map from control to pairs of data to update phi
     // For each phi, we go to its region and get region's controls
-    let control_data_phi_map: &mut HashMap<NodeID, Vec<(NodeID, NodeID)>> = &mut HashMap::new();
+    let mut control_data_phi_map: HashMap<NodeID, Vec<(NodeID, NodeID)>> = HashMap::new();
     for (idx, node) in function.nodes.iter().enumerate() {
         if let Node::Phi { control, data } = node {
             let Node::Region { preds } = &function.nodes[control.idx()] else {
@@ -237,12 +239,12 @@ struct GPUContext<'a> {
     bbs: &'a BasicBlocks,
     kernel_params: &'a GPUKernelParams,
     def_use_map: &'a ImmutableDefUseMap,
-    fork_join_map: &'a HashMap<NodeID, NodeID>,
-    join_fork_map: &'a HashMap<NodeID, NodeID>,
-    fork_reduce_map: &'a HashMap<NodeID, Vec<NodeID>>,
-    reduct_reduce_map: &'a HashMap<NodeID, Vec<NodeID>>,
-    control_data_phi_map: &'a HashMap<NodeID, Vec<(NodeID, NodeID)>>,
-    return_type_id: &'a TypeID,
+    fork_join_map: HashMap<NodeID, NodeID>,
+    join_fork_map: HashMap<NodeID, NodeID>,
+    fork_reduce_map: HashMap<NodeID, Vec<NodeID>>,
+    reduct_reduce_map: HashMap<NodeID, Vec<NodeID>>,
+    control_data_phi_map: HashMap<NodeID, Vec<(NodeID, NodeID)>>,
+    return_type_id: TypeID,
     return_param_idx: Option<usize>,
 }
 
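
This hunk moves GPUContext from borrowed to owned analysis maps, trading the
`'a` references for moves at construction time. A minimal sketch of the
ownership pattern, with hypothetical names standing in for the real struct:

    use std::collections::HashMap;

    // Hypothetical stand-in for GPUContext: owning the map unties the
    // struct from the lifetime of locals built inside gpu_codegen.
    struct Ctx {
        fork_join_map: HashMap<u32, u32>,
    }

    fn build() -> Ctx {
        let fork_join_map: HashMap<u32, u32> = [(1, 2)].into_iter().collect();
        Ctx { fork_join_map } // moved in, not borrowed
    }

    fn lookup(ctx: &Ctx) -> Option<&u32> {
        // Call sites now borrow fields explicitly (&self.fork_join_map),
        // as make_fork_structures does later in this patch.
        ctx.fork_join_map.get(&1)
    }
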
@@ -318,7 +320,7 @@ impl GPUContext<'_> {
             (1, 1)
         } else {
             // Create structures and determine block and thread parallelization strategy
-            let (fork_tree, fork_control_map) = self.make_fork_structures(self.fork_join_map);
+            let (fork_tree, fork_control_map) = self.make_fork_structures(&self.fork_join_map);
             let (root_forks, num_blocks) =
                 self.get_root_forks_and_num_blocks(&fork_tree, self.kernel_params.max_num_blocks);
             let (thread_root_root_fork, thread_root_forks) = self.get_thread_root_forks(&root_forks, &fork_tree, num_blocks);
@@ -422,7 +424,7 @@ namespace cg = cooperative_groups;
             write!(
                 w,
                 "{} __restrict__ ret",
-                self.get_type(*self.return_type_id, true)
+                self.get_type(self.return_type_id, true)
             )?;
         }
 
@@ -536,7 +538,7 @@ namespace cg = cooperative_groups;
         // need to pass arguments to kernel, so we keep track of the arguments here.
         let mut pass_args = String::new();
         let ret_primitive = self.types[self.return_type_id.idx()].is_primitive();
-        let ret_type = self.get_type(*self.return_type_id, false);
+        let ret_type = self.get_type(self.return_type_id, false);
         write!(w, "
 extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
         // The first set of parameters are dynamic constants.
@@ -566,7 +568,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
         write!(w, ") {{\n")?;
         // Pull primitive return as pointer parameter for kernel
         if ret_primitive {
-            let ret_type_pnt = self.get_type(*self.return_type_id, true);
+            let ret_type_pnt = self.get_type(self.return_type_id, true);
             write!(w, "\t{} ret;\n", ret_type_pnt)?;
             write!(w, "\tcudaMalloc((void**)&ret, sizeof({}));\n", ret_type)?;
             if !first_param {
@@ -1267,16 +1269,11 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
             // If we read collection, distribute elements among threads with cg
             // sync after. If we read primitive, copy read on all threads.
             Node::Read { collect, indices } => {
-                let is_char = self.is_char(self.typing[collect.idx()]);
-                let collect_with_indices = self.codegen_collect(*collect, indices, is_char, extra_dim_collects.contains(&self.typing[collect.idx()]));
+                let collect_with_indices = self.codegen_collect(*collect, indices, extra_dim_collects.contains(&self.typing[collect.idx()]));
                 let data_type_id = self.typing[id.idx()];
                 if self.types[data_type_id.idx()].is_primitive() {
-                    if is_char {
-                        let type_name = self.get_type(data_type_id, true);
-                        write!(w, "{}{} = *reinterpret_cast<{}>({});\n", tabs, define_variable, type_name, collect_with_indices)?;
-                    } else {
-                        write!(w, "{}{} = *({});\n", tabs, define_variable, collect_with_indices)?;
-                    }
+                    let type_name = self.get_type(data_type_id, true);
+                    write!(w, "{}{} = *reinterpret_cast<{}>({});\n", tabs, define_variable, type_name, collect_with_indices)?;
                 } else {
                     if KernelState::OutBlock == state && num_blocks.unwrap() > 1 {
                         panic!("GPU can't guarantee correctness for multi-block collection reads");
@@ -1287,13 +1284,12 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
                     };
                     // Divide up "elements", which are collection size divided
                     // by element size, among threads.
-                    let data_size = self.get_size(data_type_id, None, Some(extra_dim_collects), Some(true));
-                    let num_elements = format!("({})", data_size);
-                    write!(w, "{}for (int i = {}.thread_rank(); i < {}; i += {}.size()) {{\n", tabs, cg_tile, num_elements, cg_tile)?;
+                    let data_size = self.get_size(data_type_id, None, Some(extra_dim_collects));
+                    write!(w, "{}for (int i = {}.thread_rank(); i < {}; i += {}.size()) {{\n", tabs, cg_tile, data_size, cg_tile)?;
                     write!(w, "{}\t*({} + i) = *({} + i);\n", tabs, define_variable, collect_with_indices)?;
                     write!(w, "{}}}\n", tabs)?;
-                    write!(w, "{}if ({}.thread_rank() < {} % {}.size()) {{\n", tabs, cg_tile, num_elements, cg_tile)?;
-                    write!(w, "{}\t*({} + {}.size() * ({} / {}.size()) + {}.thread_rank()) = *({} + {}.size() * ({} / {}.size()) + {}.thread_rank());\n", tabs, define_variable, cg_tile, num_elements, cg_tile, cg_tile, collect_with_indices, cg_tile, num_elements, cg_tile, cg_tile)?;
+                    write!(w, "{}if ({}.thread_rank() < {} % {}.size()) {{\n", tabs, cg_tile, data_size, cg_tile)?;
+                    write!(w, "{}\t*({} + {}.size() * ({} / {}.size()) + {}.thread_rank()) = *({} + {}.size() * ({} / {}.size()) + {}.thread_rank());\n", tabs, define_variable, cg_tile, data_size, cg_tile, cg_tile, collect_with_indices, cg_tile, data_size, cg_tile, cg_tile)?;
                     write!(w, "{}}}\n", tabs)?;
                     write!(w, "{}{}.sync();\n", tabs, cg_tile)?;
                 }
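
For collection reads, the byte size from get_size now feeds straight into a
tile-strided copy. A rough Rust sketch of just the loop-plus-sync emission
(the writer helper and its name are stand-ins, not the real function):

    use std::fmt::{Error, Write};

    // Emit a cooperative-groups tile copy: each thread copies a strided
    // range of bytes, then the tile synchronizes.
    fn emit_tile_copy<W: Write>(
        w: &mut W, tabs: &str, cg_tile: &str, dst: &str, src: &str, size: &str,
    ) -> Result<(), Error> {
        write!(w, "{}for (int i = {}.thread_rank(); i < {}; i += {}.size()) {{\n",
            tabs, cg_tile, size, cg_tile)?;
        write!(w, "{}\t*({} + i) = *({} + i);\n", tabs, dst, src)?;
        write!(w, "{}}}\n", tabs)?;
        write!(w, "{}{}.sync();\n", tabs, cg_tile)
    }

Since the strided loop already visits every byte below the bound, the separate
remainder guard the generator also emits appears redundant rather than
load-bearing, though it copies the same data and is harmless.
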
@@ -1305,8 +1301,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
                 data,
                 indices,
             } => {
-                let is_char = self.is_char(self.typing[collect.idx()]);
-                let collect_with_indices = self.codegen_collect(*collect, indices, is_char, extra_dim_collects.contains(&self.typing[collect.idx()]));
+                let collect_with_indices = self.codegen_collect(*collect, indices, extra_dim_collects.contains(&self.typing[collect.idx()]));
                 let data_variable = self.get_value(*data, false, false);
                 let data_type_id = self.typing[data.idx()];
                 if KernelState::OutBlock == state && num_blocks.unwrap() > 1 {
@@ -1318,21 +1313,16 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
                 };
                 if self.types[data_type_id.idx()].is_primitive() {
                     write!(w, "{}if ({}.thread_rank() == 0) {{\n", tabs, cg_tile)?;
-                    if is_char {
-                        let type_name = self.get_type(data_type_id, true);
-                        write!(w, "{}\t*reinterpret_cast<{}>({}) = {};\n", tabs, type_name, collect_with_indices, data_variable)?;
-                    } else {
-                        write!(w, "{}\t*({}) = {};\n", tabs, collect_with_indices, data_variable)?;
-                    }
+                    let type_name = self.get_type(data_type_id, true);
+                    write!(w, "{}\t*reinterpret_cast<{}>({}) = {};\n", tabs, type_name, collect_with_indices, data_variable)?;
                     write!(w, "{}}}\n", tabs)?;
                 } else {
-                    let data_size = self.get_size(data_type_id, None, Some(extra_dim_collects), Some(true));
-                    let num_elements = format!("({})", data_size);
-                    write!(w, "{}for (int i = {}.thread_rank(); i < {}; i += {}.size()) {{\n", tabs, cg_tile, num_elements, cg_tile)?;
+                    let data_size = self.get_size(data_type_id, None, Some(extra_dim_collects));
+                    write!(w, "{}for (int i = {}.thread_rank(); i < {}; i += {}.size()) {{\n", tabs, cg_tile, data_size, cg_tile)?;
                     write!(w, "{}\t*({} + i) = *({} + i);\n", tabs, collect_with_indices, data_variable)?;
                     write!(w, "{}}}\n", tabs)?;
-                    write!(w, "{}if ({}.thread_rank() < {} % {}.size()) {{\n", tabs, cg_tile, num_elements, cg_tile)?;
-                    write!(w, "{}\t*({} + {}.size() * ({} / {}.size()) + {}.thread_rank()) = *({} + {}.size() * ({} / {}.size()) + {}.thread_rank());\n", tabs, collect_with_indices, cg_tile, num_elements, cg_tile, cg_tile, data_variable, cg_tile, num_elements, cg_tile, cg_tile)?;
+                    write!(w, "{}if ({}.thread_rank() < {} % {}.size()) {{\n", tabs, cg_tile, data_size, cg_tile)?;
+                    write!(w, "{}\t*({} + {}.size() * ({} / {}.size()) + {}.thread_rank()) = *({} + {}.size() * ({} / {}.size()) + {}.thread_rank());\n", tabs, collect_with_indices, cg_tile, data_size, cg_tile, cg_tile, data_variable, cg_tile, data_size, cg_tile, cg_tile)?;
                     write!(w, "{}}}\n", tabs)?;
                 }
                 write!(w, "{}{}.sync();\n", tabs, cg_tile)?;
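
Primitive reads and writes now always go through a reinterpret_cast, because
every collection is handed around as a char pointer. A Rust analogue of that
cast, assuming (as the byte-offset math is meant to guarantee) that the offset
preserves the element's natural alignment:

    // Byte pointer in, typed load out: the emitted CUDA does the same via
    // reinterpret_cast<float*>(collect + offset). Alignment is assumed.
    unsafe fn read_f32(collect: *const u8, byte_off: usize) -> f32 {
        *(collect.add(byte_off) as *const f32)
    }
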
@@ -1508,18 +1498,15 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
 
     /*
      * This function emits collection name + pointer math for the provided indices.
-     * One nuance is whether the collection is represented as char pointer or
-     * the original primitive pointer. For Field, it's always char, for Variant,
-     * it doesn't matter here, and for Array, it depends- so we may need to tack
-     * on the element size to the index math.
+     * All collections are lowered to char pointers, so index math is in bytes.
      */
-    fn codegen_collect(&self, collect: NodeID, indices: &[Index], is_char: bool, has_extra_dim: bool) -> String {
+    fn codegen_collect(&self, collect: NodeID, indices: &[Index], has_extra_dim: bool) -> String {
         let mut index_ptr = "0".to_string();
         let type_id = self.typing[collect.idx()];
         for index in indices {
             match index {
                 Index::Field(field) => {
-                    self.get_size(type_id, Some(*field), None, None);
+                    self.get_size(type_id, Some(*field), None);
                 }
                 // Variants of summations have zero offset
                 Index::Variant(_) => {}
@@ -1550,10 +1537,8 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
                         cumulative_offset,
                         ")".repeat(array_indices.len() - if has_extra_dim { 1 } else { 0 })
                     ));
-                    if is_char {
-                        let element_size = self.get_size(*element_type, None, None, None);
-                        index_ptr.push_str(&format!(" * {}", element_size));
-                    }
+                    let element_size = self.get_size(*element_type, None, None);
+                    index_ptr.push_str(&format!(" * {}", element_size));
                 }
             }
         }
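
With is_char gone, codegen_collect unconditionally multiplies the accumulated
array index by the element size. A worked sketch of the row-major byte offset
this produces, with concrete numbers assumed purely for illustration:

    // Row-major byte offset on a char* array: fold the indices across the
    // extents, then scale by the element size (now always applied).
    fn byte_offset(indices: &[usize], extents: &[usize], element_size: usize) -> usize {
        let linear = indices.iter().zip(extents).fold(0, |acc, (i, d)| acc * d + i);
        linear * element_size
    }

    fn main() {
        // A [3][4] array of 8-byte elements: [2][1] -> (2*4 + 1) * 8 = 72.
        assert_eq!(byte_offset(&[2, 1], &[3, 4], 8), 72);
    }
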
@@ -1600,7 +1585,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
             Constant::Product(type_id, constant_fields) => {
                 if allow_allocate {
                     let alignment = self.get_alignment(*type_id);
-                    let size = self.get_size(*type_id, None, extra_dim_collects, None);
+                    let size = self.get_size(*type_id, None, extra_dim_collects);
                     *dynamic_shared_offset = format!("(({} + {} - 1) / {}) * {}", dynamic_shared_offset, alignment, alignment, alignment);
                     write!(w, "{}dynamic_shared_offset = {};\n", tabs, dynamic_shared_offset)?;
                     write!(w, "{}{} = dynamic_shared + dynamic_shared_offset;\n", tabs, name)?;
@@ -1612,7 +1597,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
                 for i in 0..constant_fields.len() {
                     // For each field update offset and issue recursive call
                     let field_type = self.get_type(type_fields[i], true);
-                    let offset = self.get_size(type_fields[i], Some(i), extra_dim_collects, None);
+                    let offset = self.get_size(type_fields[i], Some(i), extra_dim_collects);
                     let field_constant = &self.constants[constant_fields[i].idx()];
                     if field_constant.is_scalar() {
                         self.codegen_constant(
@@ -1632,7 +1617,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
             Constant::Summation(type_id, variant, field) => {
                 if allow_allocate {
                     let alignment = self.get_alignment(*type_id);
-                    let size = self.get_size(*type_id, None, extra_dim_collects, None);
+                    let size = self.get_size(*type_id, None, extra_dim_collects);
                     *dynamic_shared_offset = format!("(({} + {} - 1) / {}) * {}", dynamic_shared_offset, alignment, alignment, alignment);
                     write!(w, "{}dynamic_shared_offset = {};\n", tabs, dynamic_shared_offset)?;
                     write!(w, "{}{} = dynamic_shared + dynamic_shared_offset;\n", tabs, name)?;
@@ -1660,18 +1645,14 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
                 };
             }
             Constant::Array(type_id) => {
-                let Type::Array(element_type, _) = &self.types[type_id.idx()] else {
-                    panic!("Expected array type")
-                };
                 if !allow_allocate {
                     panic!("Nested array constant should not be re-allocated");
                 }
                 let alignment = self.get_alignment(*type_id);
-                let size = self.get_size(*type_id, None, extra_dim_collects, None);
-                let element_type = self.get_type(*element_type, true);
+                let size = self.get_size(*type_id, None, extra_dim_collects);
                 *dynamic_shared_offset = format!("(({} + {} - 1) / {}) * {}", dynamic_shared_offset, alignment, alignment, alignment);
                 write!(w, "{}dynamic_shared_offset = {};\n", tabs, dynamic_shared_offset)?;
-                write!(w, "{}{} = reinterpret_cast<{}>(dynamic_shared + dynamic_shared_offset);\n", tabs, name, element_type)?;
+                write!(w, "{}{} = dynamic_shared + dynamic_shared_offset;\n", tabs, name)?;
                 *dynamic_shared_offset = format!("{} + {}", dynamic_shared_offset, size);
             }
         }
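
Constant allocation keeps bumping dynamic_shared_offset with the usual
round-up-then-advance pattern; the emitted string expression
((off + a - 1) / a) * a is the integer align-up. Worked with hypothetical
numbers:

    // Integer form of the emitted alignment expression (align nonzero).
    fn align_up(off: usize, align: usize) -> usize {
        ((off + align - 1) / align) * align
    }

    fn main() {
        // Offset 13 rounded up to 8-byte alignment lands at 16; the
        // allocation then advances by the constant's size (24 here).
        let base = align_up(13, 8);
        assert_eq!(base, 16);
        assert_eq!(base + 24, 40);
    }
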
@@ -1684,15 +1665,11 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
      * and offset to 2nd field. This is useful for constant initialization and read/write
      * index math.
      */
-    fn get_size(&self, type_id: TypeID, num_fields: Option<usize>, extra_dim_collects: Option<&HashSet<TypeID>>, exclude_element_size: Option<bool>) -> String {
+    fn get_size(&self, type_id: TypeID, num_fields: Option<usize>, extra_dim_collects: Option<&HashSet<TypeID>>) -> String {
         match &self.types[type_id.idx()] {
             Type::Array(element_type, extents) => {
                 let array_size = multiply_dcs(if extra_dim_collects.is_some() && extra_dim_collects.unwrap().contains(&type_id) { &extents[1..] } else { extents });
-                if exclude_element_size.unwrap_or(false) {
-                    array_size
-                } else {
-                    format!("{} * {}", self.get_alignment(*element_type), array_size)
-                }
+                format!("{} * {}", self.get_alignment(*element_type), array_size)
             }
             Type::Product(fields) => {
                 let num_fields = &num_fields.unwrap_or(fields.len());
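
With exclude_element_size removed, the array branch of get_size always folds
the element stride (taken from get_alignment) into the extent product, and
drops the leading extent when the collection carries an extra fork dimension.
A small sketch of the string it builds, with made-up names:

    // Sketch of the array-size expression: stride * d0 * d1 * ..., minus
    // d0 when extra_dim marks an added fork dimension. Names illustrative.
    fn array_size(stride: &str, extents: &[&str], extra_dim: bool) -> String {
        let dims = if extra_dim { &extents[1..] } else { extents };
        format!("{} * {}", stride, dims.join(" * "))
    }

    fn main() {
        assert_eq!(array_size("8", &["dc0", "dc1"], false), "8 * dc0 * dc1");
        assert_eq!(array_size("8", &["dc0", "dc1"], true), "8 * dc1");
    }
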
@@ -1700,7 +1677,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
                     .iter()
                     .enumerate()
                     .filter(|(i, _)| i < num_fields)
-                    .map(|(_, id)| (self.get_size(*id, None, extra_dim_collects, None), self.get_alignment(*id)))
+                    .map(|(_, id)| (self.get_size(*id, None, extra_dim_collects), self.get_alignment(*id)))
                     .fold(String::from("0"), |acc, (size, align)| {
                         if acc == "0" {
                             size
@@ -1715,7 +1692,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
                     format!(
                         "{} - {}",
                         with_field,
-                        self.get_size(fields[*num_fields], None, extra_dim_collects, None)
+                        self.get_size(fields[*num_fields], None, extra_dim_collects)
                     )
                 } else {
                     with_field
@@ -1725,7 +1702,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
                 // The argmax variant by size is not guaranteed to be same as
                 // argmax variant by alignment, eg product of 3 4-byte primitives
                 // vs 1 8-byte primitive, so we need to calculate both.
-                let max_size = variants.iter().map(|id| self.get_size(*id, None, extra_dim_collects, None)).fold(
+                let max_size = variants.iter().map(|id| self.get_size(*id, None, extra_dim_collects)).fold(
                     String::from("0"),
                     |acc, x| {
                         if acc == "0" {
@@ -1880,16 +1857,6 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
         func_name.to_string()
     }
 
-    // Check if a type should be represented as char*. Must be a product,
-    // summation, or array of product/summation types.
-    fn is_char(&self, type_id: TypeID) -> bool {
-        match &self.types[type_id.idx()] {
-            Type::Product(_) | Type::Summation(_) => true,
-            Type::Array(element_type, _) => self.is_char(*element_type),
-            _ => false,
-        }
-    }
-
     fn get_cg_tile(&self, fork: NodeID, cg_type: CGType) -> String {
         format!("cg_{}{}", self.get_value(fork, false, false), if cg_type == CGType::Use { "_use" } else if cg_type == CGType::Available { "_available" } else { "" })
     }
@@ -1938,12 +1905,10 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
     }
 
     fn get_type(&self, id: TypeID, make_pointer: bool) -> String {
-        match &self.types[id.idx()] {
-            // Product and summation collections are char* for 1 byte-addressability
-            // since we can have variable type fields
-            Type::Product(_) | Type::Summation(_) => "char*".to_string(),
-            Type::Array(element_type, _) => self.get_type(*element_type, true),
-            _ => convert_type(&self.types[id.idx()], make_pointer),
+        if self.types[id.idx()].is_primitive() {
+            convert_type(&self.types[id.idx()], make_pointer)
+        } else {
+            "char*".to_string()
         }
     }
 
-- 
GitLab