From 90a1c4af468263be6b70083f5806c200c8bc577c Mon Sep 17 00:00:00 2001
From: Praneet Rathi <prrathi10@gmail.com>
Date: Fri, 3 Jan 2025 23:18:20 -0800
Subject: [PATCH] theoretically just special case left

---
 hercules_cg/src/gpu.rs | 91 ++++++++++++++++++++----------------------
 1 file changed, 43 insertions(+), 48 deletions(-)

diff --git a/hercules_cg/src/gpu.rs b/hercules_cg/src/gpu.rs
index f2cfd9cf..122af64b 100644
--- a/hercules_cg/src/gpu.rs
+++ b/hercules_cg/src/gpu.rs
@@ -47,12 +47,17 @@ pub fn gpu_codegen<W: Write>(
         .map(NodeID::new)
         .collect();
 
+    // Fork reduce map should have all reduces contained in some key
     let fork_reduce_map: &mut HashMap<NodeID, Vec<NodeID>> = &mut HashMap::new();
+    // Reduct reduce map should have all non-parallel and non-associative
+    // reduces contained in some key. Unlike the fork map, the reduct map
+    // isn't involved in any assertions; it's built here for convenience
+    // but could be moved.
+    let reduct_reduce_map: &mut HashMap<NodeID, Vec<NodeID>> = &mut HashMap::new();
     for reduce_node in &reduce_nodes {
         if let Node::Reduce {
             control,
             init: _,
-            reduct: _,
+            reduct,
         } = &function.nodes[reduce_node.idx()]
         {
             match function.nodes[control.idx()] {
@@ -71,6 +76,13 @@ pub fn gpu_codegen<W: Write>(
                     panic!("Reduce's control must be a join or region node");
                 }
             }
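+            // ParallelReduce and Associative reduces are excluded since they
+            // aren't updated through a plain reduct assignment.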
+            if !function.schedules[reduce_node.idx()].contains(&Schedule::ParallelReduce)
+                && !function.schedules[reduce_node.idx()].contains(&Schedule::Associative) {
+                reduct_reduce_map
+                    .entry(*reduct)
+                    .or_default()
+                    .push(*reduce_node);
+            }
         }
     }
     for idx in 0..function.nodes.len() {
@@ -160,6 +172,7 @@ pub fn gpu_codegen<W: Write>(
         bbs,
         kernel_params,
         fork_reduce_map,
+        reduct_reduce_map,
         label_data_for_phi,
         return_type_id,
     };
@@ -187,6 +200,7 @@ struct GPUContext<'a> {
     bbs: &'a Vec<NodeID>,
     kernel_params: &'a GPUKernelParams,
     fork_reduce_map: &'a HashMap<NodeID, Vec<NodeID>>,
+    reduct_reduce_map: &'a HashMap<NodeID, Vec<NodeID>>,
     label_data_for_phi: &'a HashMap<NodeID, Vec<NodeID>>,
     return_type_id: &'a TypeID,
 }
@@ -367,7 +381,9 @@ impl GPUContext<'_> {
     }
 
     // Emit helper registers that are used throughout the kernel- alignment
-    // is for proper dynamic shared memory allocation
+    // is for proper dynamic shared memory allocation, and max_variant_size
+    // is for variant selection during read/write copies. We don't keep the
+    // tag: it isn't needed, and storing it can double total memory usage
+    // due to alignment.
     fn codegen_helpers(&self, w: &mut String) -> Result<(), Error> {
         write!(w, "\tsize_t alignment;\n")?;
         write!(w, "\tsize_t max_variant_size;\n")?;
@@ -555,45 +571,6 @@ impl GPUContext<'_> {
         }
     }
 
-    // /*
-    //  * For each parallel reduce with a reduct write, meaning it's at the end of
-    //  * a potential parallel reduction chain, we walk back to beginning of chain
-    //  * and update the write's collect to be the beginning's init.
-    //  */
-    // fn update_write_collects(&self) -> HashMap<NodeID, NodeID> {
-    //     let mut write_collect_map = HashMap::new();
-    //     let mut parallel_reduces: HashSet<NodeID> = (0..self.function.nodes.len())
-    //         .map(NodeID::new)
-    //         .filter(|&node_id| {
-    //             self.function.schedules[node_id.idx()].contains(&Schedule::ParallelReduce)
-    //         })
-    //         .collect();
-    //     for reduce in parallel_reduces.clone() {
-    //         if let Node::Reduce {
-    //             control: _,
-    //             init,
-    //             reduct,
-    //         } = &self.function.nodes[reduce.idx()]
-    //             && let Node::Write { .. } = &self.function.nodes[reduct.idx()]
-    //         {
-    //             parallel_reduces.remove(&reduce);
-    //             while parallel_reduces.contains(&init) {
-    //                 let Node::Reduce {
-    //                     control: _,
-    //                     init,
-    //                     reduct: _,
-    //                 } = &self.function.nodes[init.idx()]
-    //                 else {
-    //                     panic!("Expected reduce node");
-    //                 };
-    //                 parallel_reduces.remove(&init);
-    //             }
-    //             write_collect_map.insert(*reduct, *init);
-    //         }
-    //     }
-    //     write_collect_map
-    // }
-
     fn codegen_data_control(
         &self,
         root_forks: &Vec<NodeID>,
@@ -684,10 +661,8 @@ impl GPUContext<'_> {
                     _ => { panic!("Unsupported state for ThreadID") }
                 }
             }
-            Node::Reduce { control: _, init, reduct: _ } => {
-                let init_val = self.get_value(*init, false, false);
-                write!(w, "{}{} = {};\n", tabs, declare_variable, init_val)?;
-            }
+            // The fork initializes the reduce variable and the reduct's
+            // emission updates it, so nothing is emitted here.
+            Node::Reduce { control: _, init: _, reduct: _ } => {}
             // Parameters emitted at top
             Node::Parameter { index: _ } => {}
             Node::Constant { id: cons_id } => {
@@ -889,16 +864,25 @@ impl GPUContext<'_> {
             }
         }
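+        // If this value feeds any phis (per label_data_for_phi), copy it
+        // into each phi's variable as soon as it is emitted.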
         if let Some(phis) = self.label_data_for_phi.get(&id) {
+            let val = self.get_value(id, false, false);
             for phi in phis {
+                let phi_val = self.get_value(*phi, false, false);
                 write!(
                     w,
                     "{}{} = {};\n",
                     tabs,
-                    self.get_value(*phi, false, false),
-                    self.get_value(id, false, false),
+                    phi_val,
+                    val,
                 )?;
             }
         }
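+        // Similarly, if this value is the reduct of any serial reduces,
+        // write it back into each reduce's variable.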
+        if let Some(reduces) = self.reduct_reduce_map.get(&id) {
+            let val = self.get_value(id, true, false);
+            for reduce in reduces {
+                let reduce_val = self.get_value(*reduce, false, false);
+                write!(w, "{}{} = {};\n", tabs, reduce_val, val)?;
+            }
+        }
         Ok(())
     }
 
@@ -937,7 +921,18 @@ impl GPUContext<'_> {
             Node::Fork {
                 control: _,
                 factors: _,
-            } => {}
+            } => {
+                // Emitting each reduce's initialization before the fork
+                // allows the reduce to be used outside of the fork.
+                for &reduce in self.fork_reduce_map.get(&id).unwrap() {
+                    let reduce_val = self.get_value(reduce, true, false);
+                    let Node::Reduce { control: _, init, reduct: _ } = &self.function.nodes[reduce.idx()] else {
+                        panic!("Expected reduce node");
+                    };
+                    let init_val = self.get_value(*init, true, false);
+                    write!(w, "{}{} = {};\n", tabs, reduce_val, init_val)?;
+                }
+            }
             Node::Join { control: _ } => {}
             Node::Return { control: _, data } => {
                 if self.types[self.typing[data.idx()].idx()].is_primitive() {
-- 
GitLab