From 58b7928b8e7cd5d197b06cddc81674dce3bf8468 Mon Sep 17 00:00:00 2001 From: prrathi <prrathi10@gmail.com> Date: Sat, 25 Jan 2025 18:22:09 +0000 Subject: [PATCH] things work --- hercules_cg/src/gpu.rs | 150 +++++++++++++++++++++-------------------- 1 file changed, 78 insertions(+), 72 deletions(-) diff --git a/hercules_cg/src/gpu.rs b/hercules_cg/src/gpu.rs index 074281f8..c3cb6634 100644 --- a/hercules_cg/src/gpu.rs +++ b/hercules_cg/src/gpu.rs @@ -79,6 +79,10 @@ pub fn gpu_codegen<W: Write>( * - Add float8, float16, bfloat16 dtypes if they come */ + // Temporary for matmul (both true) and dot (thread true) test while we don't have schedule annotations + let block_parallel_override = false; + let thread_parallel_override = false; + let reduce_nodes: Vec<NodeID> = (0..function.nodes.len()) .filter(|idx| function.nodes[*idx].is_reduce()) .map(NodeID::new) @@ -109,7 +113,6 @@ pub fn gpu_codegen<W: Write>( .entry(fork_node) .or_default() .push(*reduce_node); - println!("reduce_node: {:?}, fork_node: {:?}, join: {:?}", reduce_node, fork_node, control); } Node::Region { preds: _ } => { // TODO: map region node to fork node @@ -118,14 +121,10 @@ pub fn gpu_codegen<W: Write>( panic!("Reduce's control must be a join or region node"); } } - if !function.schedules[reduce_node.idx()].contains(&Schedule::ParallelReduce) - && !function.schedules[reduce_node.idx()].contains(&Schedule::TightAssociative) - { - reduct_reduce_map - .entry(*reduct) - .or_default() - .push(*reduce_node); - } + reduct_reduce_map + .entry(*reduct) + .or_default() + .push(*reduce_node); } } for idx in 0..function.nodes.len() { @@ -186,6 +185,8 @@ pub fn gpu_codegen<W: Write>( control_data_phi_map, return_parameter, kernel_params, + block_parallel_override, + thread_parallel_override, }; ctx.codegen_function(w) } @@ -214,6 +215,8 @@ struct GPUContext<'a> { control_data_phi_map: HashMap<NodeID, Vec<(NodeID, NodeID)>>, return_parameter: Option<usize>, kernel_params: &'a GPUKernelParams, + block_parallel_override: bool, + thread_parallel_override: bool, } /* @@ -298,6 +301,7 @@ impl GPUContext<'_> { let (thread_root_root_fork, thread_root_forks) = self.get_thread_root_forks(&root_forks, &self.fork_tree, is_block_parallel); writeln!(file, "thread_root_root_fork: {:?}", thread_root_root_fork).unwrap(); let (fork_thread_quota_map, num_threads) = self.get_thread_quotas(&self.fork_tree, thread_root_root_fork); + writeln!(file, "fork_thread_quota_map: {:?}", fork_thread_quota_map).unwrap(); // TODO: Uncomment and adjust once we know logic of extra dim // let extra_dim_collects = self.get_extra_dim_collects(&fork_control_map, &fork_thread_quota_map); let extra_dim_collects = HashSet::new(); @@ -586,7 +590,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; self.collection_objects.origin(*object).try_parameter().is_some() }) }), "All collection reduces in block fork must originate from parameters"); - if true || self.function.schedules[root_fork.idx()].contains(&Schedule::ParallelFork) + if self.block_parallel_override || self.function.schedules[root_fork.idx()].contains(&Schedule::ParallelFork) { let fork_size = factors.iter().map(|dc| format!("dc{}", dc.idx())).collect::<Vec<_>>().join(" * "); (root_forks, fork_size, true) @@ -693,7 +697,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; && fork_size.is_power_of_two() && reduces.iter().all(|&reduce| { self.function.schedules[reduce.idx()].contains(&Schedule::ParallelReduce) - || self.function.schedules[reduce.idx()].contains(&Schedule::TightAssociative) + || 
self.thread_parallel_override || self.function.schedules[reduce.idx()].contains(&Schedule::TightAssociative) }) { // If there's an associative Reduce, parallelize the larger factor @@ -706,7 +710,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; // restriction doesn't help for parallel Writes, so nested parallelization // is possible. if reduces.iter().any(|&reduce| { - self.function.schedules[reduce.idx()].contains(&Schedule::TightAssociative) + self.thread_parallel_override || self.function.schedules[reduce.idx()].contains(&Schedule::TightAssociative) }) || fork_size > self.kernel_params.max_num_threads / subtree_quota { if fork_size >= subtree_quota { (HashMap::new(), fork_size, true) @@ -1069,8 +1073,8 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; } }, Node::Binary { op, left, right } => { - let left_val = self.get_value(*left, false, false); - let right_val = self.get_value(*right, false, false); + let mut left_val = self.get_value(*left, false, false); + let mut right_val = self.get_value(*right, false, false); let id_type = self.typing[id.idx()]; if matches!(op, BinaryOperator::Add | BinaryOperator::Or | BinaryOperator::And | BinaryOperator::Xor) && is_special_reduct { @@ -1079,14 +1083,14 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; // supported. We need to use CGType::Use not CGType::UsePerId // because for parallelized reduction we only have one thread // per ThreadID and the reduction is over Use, not UsePerId. - let non_reduce_arg = if let Node::Reduce { control: _, init: _, reduct: _ } = &self.function.nodes[left.idx()] { - right_val + let (reduce_val, non_reduce_val) = if let Node::Reduce { control: _, init: _, reduct: _ } = &self.function.nodes[left.idx()] { + (left_val, right_val) } else { - left_val + (right_val, left_val) }; // Special reduct is only enabled for thread parallelization // so don't need to worry about grid and block cases - let cg_tile = self.get_cg_tile(id, CGType::Use); + let cg_tile = self.get_cg_tile(nesting_fork.unwrap(), CGType::Use); #[allow(unreachable_patterns)] let cg_op = match op { BinaryOperator::Add => "plus", @@ -1096,57 +1100,60 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; _ => unreachable!(), }; let id_type_name = self.get_type(id_type, false); - write!(w, "{}{} = cg::reduce({}, {}, cg::{}<{}>());\n", tabs, define_variable, cg_tile, non_reduce_arg, cg_op, id_type_name)?; - } else { - match (op, &self.types[id_type.idx()]) { - (BinaryOperator::Or, Type::Boolean) => write!( - w, - "{}{} = {} || {};\n", - tabs, define_variable, left_val, right_val, - )?, - (BinaryOperator::And, Type::Boolean) => write!( - w, - "{}{} = {} && {};\n", - tabs, define_variable, left_val, right_val, - )?, - (BinaryOperator::Rem, Type::Float32) => write!( - w, - "{}{} = fmodf({}, {});\n", - tabs, define_variable, left_val, right_val, - )?, - (BinaryOperator::Rem, Type::Float64) => write!( - w, - "{}{} = fmod({}, {});\n", - tabs, define_variable, left_val, right_val, - )?, - (op, _) => write!( - w, - "{}{} = {} {} {};\n", - tabs, - define_variable, - left_val, - match op { - BinaryOperator::Add => "+", - BinaryOperator::Sub => "-", - BinaryOperator::Mul => "*", - BinaryOperator::Div => "/", - BinaryOperator::Rem => "%", - BinaryOperator::LT => "<", - BinaryOperator::LTE => "<=", - BinaryOperator::GT => ">", - BinaryOperator::GTE => ">=", - BinaryOperator::EQ => "==", - BinaryOperator::NE => "!=", - BinaryOperator::Or => "|", - BinaryOperator::And => "&", - BinaryOperator::Xor => "^", - 
BinaryOperator::LSh => "<<", - BinaryOperator::RSh => ">>", - }, - right_val, - )?, - }; + write!(w, "{}{} = cg::reduce({}, {}, cg::{}<{}>());\n", tabs, define_variable, cg_tile, non_reduce_val, cg_op, id_type_name)?; + // Setup binop between reduce's init and reduced reduct. Since it's associative, + // we can change binop ordering + left_val = define_variable.clone(); + right_val = reduce_val; } + match (op, &self.types[id_type.idx()]) { + (BinaryOperator::Or, Type::Boolean) => write!( + w, + "{}{} = {} || {};\n", + tabs, define_variable, left_val, right_val, + )?, + (BinaryOperator::And, Type::Boolean) => write!( + w, + "{}{} = {} && {};\n", + tabs, define_variable, left_val, right_val, + )?, + (BinaryOperator::Rem, Type::Float32) => write!( + w, + "{}{} = fmodf({}, {});\n", + tabs, define_variable, left_val, right_val, + )?, + (BinaryOperator::Rem, Type::Float64) => write!( + w, + "{}{} = fmod({}, {});\n", + tabs, define_variable, left_val, right_val, + )?, + (op, _) => write!( + w, + "{}{} = {} {} {};\n", + tabs, + define_variable, + left_val, + match op { + BinaryOperator::Add => "+", + BinaryOperator::Sub => "-", + BinaryOperator::Mul => "*", + BinaryOperator::Div => "/", + BinaryOperator::Rem => "%", + BinaryOperator::LT => "<", + BinaryOperator::LTE => "<=", + BinaryOperator::GT => ">", + BinaryOperator::GTE => ">=", + BinaryOperator::EQ => "==", + BinaryOperator::NE => "!=", + BinaryOperator::Or => "|", + BinaryOperator::And => "&", + BinaryOperator::Xor => "^", + BinaryOperator::LSh => "<<", + BinaryOperator::RSh => ">>", + }, + right_val, + )?, + }; } Node::Ternary { op, @@ -1175,7 +1182,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; } else { self.get_value(args[0], false, false) }; - let cg_tile = self.get_cg_tile(id, CGType::Use); + let cg_tile = self.get_cg_tile(nesting_fork.unwrap(), CGType::Use); #[allow(unreachable_patterns)] let cg_op = match intrinsic { Intrinsic::Max => "max", @@ -1248,8 +1255,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; write!(w, "{}{} = {};\n", tabs, define_variable, collect_variable)?; } _ => { - println!("Unsupported data node type: {:?}", self.function.nodes[id.idx()]); - panic!("Unsupported data node type") + panic!("Unsupported data node type: {:?}", self.function.nodes[id.idx()]) } } // Since reducts are responsible for updating Reduce nodes, we check and @@ -1404,8 +1410,8 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; write!(w_term, "\tif (grid.thread_rank() == 0) {{\n")?; write!(w_term, "\t\t*(reinterpret_cast<{}>(ret)) = {};\n", return_type_ptr, return_val)?; write!(w_term, "\t}}\n")?; - write!(w_term, "\treturn;\n")?; } + write!(w_term, "\treturn;\n")?; 1 } _ => { -- GitLab