From cd96d8334f53e5b9b7e1385b10679a85939f3559 Mon Sep 17 00:00:00 2001
From: Aaron Councilman <aaronjc4@illinois.edu>
Date: Wed, 12 Feb 2025 11:23:20 -0600
Subject: [PATCH] Remove extra_dim_collects

---
 hercules_cg/src/gpu.rs | 127 +++++++----------------------------------
 1 file changed, 22 insertions(+), 105 deletions(-)

diff --git a/hercules_cg/src/gpu.rs b/hercules_cg/src/gpu.rs
index c093949f..d6461a1e 100644
--- a/hercules_cg/src/gpu.rs
+++ b/hercules_cg/src/gpu.rs
@@ -284,11 +284,7 @@ impl GPUContext<'_> {
 
         // If there are no forks, fast forward to single-block, single-thread codegen
         let (num_blocks, num_threads) = if self.fork_join_map.is_empty() {
-            self.codegen_data_control_no_forks(
-                &HashSet::new(),
-                &mut dynamic_shared_offset,
-                &mut gotos,
-            )?;
+            self.codegen_data_control_no_forks(&mut dynamic_shared_offset, &mut gotos)?;
             ("1".to_string(), "1".to_string())
         } else {
             // Create structures and determine block and thread parallelization strategy
@@ -298,10 +294,6 @@ impl GPUContext<'_> {
                 self.get_thread_root_forks(&root_forks, self.fork_tree, is_block_parallel);
             let (fork_thread_quota_map, num_threads) =
                 self.get_thread_quotas(self.fork_tree, thread_root_root_fork);
-            // TODO: Uncomment and adjust once we know logic of extra dim. This will affect constant
-            // collections, reads, and writes.
-            // let extra_dim_collects = self.get_extra_dim_collects(&fork_control_map, &fork_thread_quota_map);
-            let extra_dim_collects = HashSet::new();
 
             // Core function for the CUDA code of all data and control nodes.
             self.codegen_data_control(
@@ -312,7 +304,6 @@ impl GPUContext<'_> {
                 },
                 &thread_root_forks,
                 &fork_thread_quota_map,
-                &extra_dim_collects,
                 &mut dynamic_shared_offset,
                 is_block_parallel,
                 num_threads,
@@ -859,25 +850,8 @@ extern \"C\" {} {}(",
         }
     }
 
-    /*
-     * All non reduced-over collections used in fork joins have an extra dimension.
-     * However, this is only useful if ThreadIDs run in parallel not serially,
-     * otherwise it's unnecessarily consuming shared memory. This function returns
-     * the set of collections that have an unnecessary extra dimension.
-     */
-    fn get_extra_dim_collects(
-        &self,
-        fork_control_map: &HashMap<NodeID, HashSet<NodeID>>,
-        fork_thread_quota_map: &HashMap<NodeID, (usize, usize, usize)>,
-    ) -> HashSet<TypeID> {
-        // Determine which fork each collection is used in, and check if it's
-        // parallelized via the fork_thread_quota_map.
-        todo!()
-    }
-
     fn codegen_data_control_no_forks(
         &self,
-        extra_dim_collects: &HashSet<TypeID>,
         dynamic_shared_offset: &mut String,
         gotos: &mut BTreeMap<NodeID, CudaGoto>,
     ) -> Result<(), Error> {
@@ -901,7 +875,6 @@ extern \"C\" {} {}(",
                         None,
                         None,
                         false,
-                        extra_dim_collects,
                         dynamic_shared_offset,
                         body,
                         &mut tabs,
@@ -919,7 +892,6 @@ extern \"C\" {} {}(",
         block_fork: Option<NodeID>,
         thread_root_forks: &HashSet<NodeID>,
         fork_thread_quota_map: &HashMap<NodeID, (usize, usize, usize)>,
-        extra_dim_collects: &HashSet<TypeID>,
         dynamic_shared_offset: &mut String,
         is_block_parallel: bool,
         num_threads: usize,
@@ -945,7 +917,6 @@ extern \"C\" {} {}(",
                     None,
                     None,
                     false,
-                    extra_dim_collects,
                     dynamic_shared_offset,
                     body,
                     &mut tabs,
@@ -979,7 +950,6 @@ extern \"C\" {} {}(",
                         None,
                         Some(block_fork.unwrap()),
                         false,
-                        extra_dim_collects,
                         dynamic_shared_offset,
                         body,
                         &mut tabs,
@@ -996,7 +966,6 @@ extern \"C\" {} {}(",
                 fork_thread_quota_map,
                 1,
                 num_threads,
-                extra_dim_collects,
                 dynamic_shared_offset,
                 gotos,
             )?;
@@ -1017,7 +986,6 @@ extern \"C\" {} {}(",
         fork_thread_quota_map: &HashMap<NodeID, (usize, usize, usize)>,
         parent_quota: usize,
         num_threads: usize,
-        extra_dim_collections: &HashSet<TypeID>,
         dynamic_shared_offset: &mut String,
         gotos: &mut BTreeMap<NodeID, CudaGoto>,
     ) -> Result<(), Error> {
@@ -1068,7 +1036,6 @@ extern \"C\" {} {}(",
                     parallel_factor,
                     Some(curr_fork),
                     reducts.contains(data),
-                    extra_dim_collections,
                     dynamic_shared_offset,
                     body,
                     &mut tabs,
@@ -1082,7 +1049,6 @@ extern \"C\" {} {}(",
                 fork_thread_quota_map,
                 use_thread_quota,
                 num_threads,
-                extra_dim_collections,
                 dynamic_shared_offset,
                 gotos,
             )?;
@@ -1099,7 +1065,6 @@ extern \"C\" {} {}(",
         parallel_factor: Option<usize>,
         nesting_fork: Option<NodeID>,
         is_special_reduct: bool,
-        extra_dim_collects: &HashSet<TypeID>,
         dynamic_shared_offset: &mut String,
         w: &mut String,
         num_tabs: &mut usize,
@@ -1206,7 +1171,6 @@ extern \"C\" {} {}(",
                         define_variable.clone(),
                         *cons_id,
                         true,
-                        Some(extra_dim_collects),
                         dynamic_shared_offset,
                         w,
                         *num_tabs,
@@ -1232,8 +1196,7 @@ extern \"C\" {} {}(",
                 if !is_primitive
                     && (state != KernelState::OutBlock || !is_block_parallel.unwrap_or(false))
                 {
-                    let data_size =
-                        self.get_size(self.typing[id.idx()], None, Some(extra_dim_collects));
+                    let data_size = self.get_size(self.typing[id.idx()], None);
                     write!(
                         w,
                         "{}for (int i = {}.thread_rank(); i < {}; i += {}.size()) {{\n",
@@ -1453,8 +1416,7 @@ extern \"C\" {} {}(",
             }
             // Read of primitive requires load after pointer math.
             Node::Read { collect, indices } => {
-                let collect_with_indices =
-                    self.codegen_collect(*collect, indices, extra_dim_collects);
+                let collect_with_indices = self.codegen_collect(*collect, indices);
                 let data_type_id = self.typing[id.idx()];
                 if self.types[data_type_id.idx()].is_primitive() {
                     let type_name = self.get_type(data_type_id, true);
@@ -1478,8 +1440,7 @@ extern \"C\" {} {}(",
                 data,
                 indices,
             } => {
-                let collect_with_indices =
-                    self.codegen_collect(*collect, indices, extra_dim_collects);
+                let collect_with_indices = self.codegen_collect(*collect, indices);
                 let data_variable = self.get_value(*data, false, false);
                 let data_type_id = self.typing[data.idx()];
                 let cg_tile = match state {
@@ -1498,7 +1459,7 @@ extern \"C\" {} {}(",
                     )?;
                     write!(w, "{}}}\n", tabs)?;
                 } else {
-                    let data_size = self.get_size(data_type_id, None, Some(extra_dim_collects));
+                    let data_size = self.get_size(data_type_id, None);
                     write!(
                         w,
                         "{}for (int i = {}.thread_rank(); i < {}; i += {}.size()) {{\n",
@@ -1754,21 +1715,13 @@ extern \"C\" {} {}(",
      * This function emits collection name + pointer math for the provided indices.
      * All collection types use char pointers.
      */
-    fn codegen_collect(
-        &self,
-        collect: NodeID,
-        indices: &[Index],
-        extra_dim_collects: &HashSet<TypeID>,
-    ) -> String {
+    fn codegen_collect(&self, collect: NodeID, indices: &[Index]) -> String {
         let mut index_ptr = "0".to_string();
         let mut type_id = self.typing[collect.idx()];
         for index in indices {
             match index {
                 Index::Field(field) => {
-                    index_ptr.push_str(&format!(
-                        " + ({})",
-                        self.get_size(type_id, Some(*field), Some(extra_dim_collects))
-                    ));
+                    index_ptr.push_str(&format!(" + ({})", self.get_size(type_id, Some(*field))));
                     type_id = if let Type::Product(fields) = &self.types[type_id.idx()] {
                         fields[*field]
                     } else {
@@ -1786,10 +1739,6 @@ extern \"C\" {} {}(",
                 // Convert multi-d array index to 1-d index, and optionally
                 // convert to single-byte index by multiplying by element size
                 Index::Position(array_indices) => {
-                    let has_extra_dim = extra_dim_collects.contains(&type_id);
-                    if has_extra_dim {
-                        continue;
-                    }
                     let Type::Array(element_type, extents) = &self.types[type_id.idx()] else {
                         panic!("Expected array type")
                     };
@@ -1812,7 +1761,7 @@ extern \"C\" {} {}(",
                         cumulative_offset,
                         ")".repeat(array_indices.len())
                     ));
-                    let element_size = self.get_size(*element_type, None, Some(extra_dim_collects));
+                    let element_size = self.get_size(*element_type, None);
                     index_ptr.push_str(&format!(" * ({})", element_size));
                     type_id = *element_type;
                 }
@@ -1838,7 +1787,6 @@ extern \"C\" {} {}(",
         name: String,
         cons_id: ConstantID,
         allow_allocate: bool,
-        extra_dim_collects: Option<&HashSet<TypeID>>,
         dynamic_shared_offset: &mut String,
         w: &mut String,
         num_tabs: usize,
@@ -1863,7 +1811,7 @@ extern \"C\" {} {}(",
             Constant::Product(type_id, constant_fields) => {
                 if allow_allocate {
                     let alignment = self.get_alignment(*type_id);
-                    let size = self.get_size(*type_id, None, extra_dim_collects);
+                    let size = self.get_size(*type_id, None);
                     *dynamic_shared_offset = format!(
                         "(({} + {} - 1) / {}) * {}",
                         dynamic_shared_offset, alignment, alignment, alignment
@@ -1885,7 +1833,7 @@ extern \"C\" {} {}(",
                 };
                 for i in 0..constant_fields.len() {
                     // For each field update offset and issue recursive call
-                    let offset = self.get_size(type_fields[i], Some(i), extra_dim_collects);
+                    let offset = self.get_size(type_fields[i], Some(i));
                     let field_constant = &self.constants[constant_fields[i].idx()];
                     if field_constant.is_scalar() {
                         let field_type = self.get_type(type_fields[i], true);
@@ -1893,7 +1841,6 @@ extern \"C\" {} {}(",
                             format!("*reinterpret_cast<{}>({}+{})", field_type, name, offset),
                             constant_fields[i],
                             false,
-                            None,
                             dynamic_shared_offset,
                             w,
                             num_tabs,
@@ -1903,7 +1850,6 @@ extern \"C\" {} {}(",
                             format!("{}+{}", name, offset),
                             constant_fields[i],
                             false,
-                            extra_dim_collects,
                             dynamic_shared_offset,
                             w,
                             num_tabs,
@@ -1914,7 +1860,7 @@ extern \"C\" {} {}(",
             Constant::Summation(type_id, variant, field) => {
                 if allow_allocate {
                     let alignment = self.get_alignment(*type_id);
-                    let size = self.get_size(*type_id, None, extra_dim_collects);
+                    let size = self.get_size(*type_id, None);
                     *dynamic_shared_offset = format!(
                         "(({} + {} - 1) / {}) * {}",
                         dynamic_shared_offset, alignment, alignment, alignment
@@ -1943,21 +1889,12 @@ extern \"C\" {} {}(",
                         format!("*reinterpret_cast<{}>({})", variant_type, name),
                         *field,
                         false,
-                        extra_dim_collects,
                         dynamic_shared_offset,
                         w,
                         num_tabs,
                     )?;
                 } else if !variant_constant.is_array() {
-                    self.codegen_constant(
-                        name,
-                        *field,
-                        false,
-                        extra_dim_collects,
-                        dynamic_shared_offset,
-                        w,
-                        num_tabs,
-                    )?;
+                    self.codegen_constant(name, *field, false, dynamic_shared_offset, w, num_tabs)?;
                 };
             }
             Constant::Array(type_id) => {
@@ -1965,7 +1902,7 @@ extern \"C\" {} {}(",
                     panic!("Nested array constant should not be re-allocated");
                 }
                 let alignment = self.get_alignment(*type_id);
-                let size = self.get_size(*type_id, None, extra_dim_collects);
+                let size = self.get_size(*type_id, None);
                 *dynamic_shared_offset = format!(
                     "(({} + {} - 1) / {}) * {}",
                     dynamic_shared_offset, alignment, alignment, alignment
@@ -1992,39 +1929,19 @@ extern \"C\" {} {}(",
      * and offset to 2nd field. This is useful for constant initialization and read/write
      * index math.
      */
-    fn get_size(
-        &self,
-        type_id: TypeID,
-        num_fields: Option<usize>,
-        extra_dim_collects: Option<&HashSet<TypeID>>,
-    ) -> String {
+    fn get_size(&self, type_id: TypeID, num_fields: Option<usize>) -> String {
         match &self.types[type_id.idx()] {
             Type::Array(element_type, extents) => {
                 assert!(num_fields.is_none());
-                let array_size = if extra_dim_collects.is_some()
-                    && extra_dim_collects.unwrap().contains(&type_id)
-                {
-                    "1".to_string()
-                } else {
-                    multiply_dcs(extents)
-                };
-                format!(
-                    "{} * {}",
-                    self.get_size(*element_type, None, extra_dim_collects),
-                    array_size
-                )
+                let array_size = multiply_dcs(extents);
+                format!("{} * {}", self.get_size(*element_type, None), array_size)
             }
             Type::Product(fields) => {
                 let num_fields = num_fields.unwrap_or(fields.len());
                 fields
                     .iter()
                     .take(num_fields)
-                    .map(|id| {
-                        (
-                            self.get_size(*id, None, extra_dim_collects),
-                            self.get_alignment(*id),
-                        )
-                    })
+                    .map(|id| (self.get_size(*id, None), self.get_alignment(*id)))
                     .fold(String::from("0"), |acc, (size, align)| {
                         if acc == "0" {
                             size
@@ -2041,16 +1958,16 @@ extern \"C\" {} {}(",
                 // The argmax variant by size is not guaranteed to be same as
                 // argmax variant by alignment, eg product of 3 4-byte primitives
                 // vs 1 8-byte primitive, so we need to calculate both.
-                let max_size = variants
-                    .iter()
-                    .map(|id| self.get_size(*id, None, extra_dim_collects))
-                    .fold(String::from("0"), |acc, x| {
+                let max_size = variants.iter().map(|id| self.get_size(*id, None)).fold(
+                    String::from("0"),
+                    |acc, x| {
                         if acc == "0" {
                             x
                         } else {
                             format!("umax({}, {})", acc, x)
                         }
-                    });
+                    },
+                );
                 let max_alignment = variants
                     .iter()
                     .map(|id| self.get_alignment(*id))
-- 
GitLab