From d391d16354863a51c1d4ee8e3de4bce44b54402e Mon Sep 17 00:00:00 2001
From: prrathi <prrathi10@gmail.com>
Date: Wed, 29 Jan 2025 16:25:22 +0000
Subject: [PATCH] gpu: remove matmul/dot parallelization overrides and debug output

The block/thread parallel overrides in gpu.rs were a stopgap for the matmul
(both true) and dot (thread true) tests while schedule annotations were
unavailable. Parallelization decisions now come solely from the
Schedule::ParallelFork, Schedule::ParallelReduce, and
Schedule::TightAssociative annotations, and the stray debug writes to
out.txt are dropped. The matmul sample returns to random inputs at larger
sizes (I=256, J=8, K=128).

---
 hercules_cg/src/gpu.rs              | 29 ++++++-----------------------
 hercules_samples/matmul/src/main.rs | 12 +++++-------
 2 files changed, 11 insertions(+), 30 deletions(-)
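
Reviewer note (this section is ignored by `git am`): with the temporary
overrides removed, thread parallelization is gated purely on schedule
annotations. Below is a minimal standalone sketch of the quota decision the
surviving code makes; the function name, parameters, and the final nesting
branch are assumptions inferred from the hunk's comments, not the actual
GPUContext implementation:

    // Hypothetical sketch of the schedule-gated thread-quota choice.
    fn pick_thread_quota(
        fork_size: usize,            // product of the fork's factors (power of two)
        subtree_quota: usize,        // threads already claimed beneath this fork
        max_num_threads: usize,      // kernel_params.max_num_threads
        has_tight_associative: bool, // any reduce with Schedule::TightAssociative
    ) -> (usize, bool) {
        if has_tight_associative || fork_size > max_num_threads / subtree_quota {
            // An associative reduce (or the thread budget) rules out nesting:
            // parallelize only the larger of the two factors.
            if fork_size >= subtree_quota {
                (fork_size, true) // the fork itself takes the threads
            } else {
                (subtree_quota, false) // the subtree keeps its quota
            }
        } else {
            // Parallel writes/reduces: nested parallelization is possible
            // (assumed branch; truncated in the hunk below).
            (fork_size * subtree_quota, true)
        }
    }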

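The matmul sample is likewise restored to random inputs and scaled up to
I=256, J=8, K=128. A self-contained sketch of the sample's CPU reference
check at the new sizes; the inner accumulation over J is assumed from the
loop that the hunk truncates:

    use rand::random;

    fn main() {
        const I: usize = 256;
        const J: usize = 8; // must match the hardcoded constant in matmul.hir
        const K: usize = 128;
        let a: Box<[i32]> = (0..I * J).map(|_| random::<i32>() % 100).collect();
        let b: Box<[i32]> = (0..J * K).map(|_| random::<i32>() % 100).collect();
        let mut correct_c: Box<[i32]> = (0..I * K).map(|_| 0).collect();
        // Reference row-major matmul: c[i][k] += a[i][j] * b[j][k].
        for i in 0..I {
            for k in 0..K {
                for j in 0..J {
                    correct_c[i * K + k] += a[i * J + j] * b[j * K + k];
                }
            }
        }
        println!("checksum: {}", correct_c.iter().map(|&x| x as i64).sum::<i64>());
    }
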
diff --git a/hercules_cg/src/gpu.rs b/hercules_cg/src/gpu.rs
index c3cb6634..ce52a20e 100644
--- a/hercules_cg/src/gpu.rs
+++ b/hercules_cg/src/gpu.rs
@@ -79,10 +79,6 @@ pub fn gpu_codegen<W: Write>(
      * - Add float8, float16, bfloat16 dtypes if they come
      */
 
-    // Temporary for matmul (both true) and dot (thread true) test while we don't have schedule annotations
-    let block_parallel_override = false;
-    let thread_parallel_override = false;
-
     let reduce_nodes: Vec<NodeID> = (0..function.nodes.len())
         .filter(|idx| function.nodes[*idx].is_reduce())
         .map(NodeID::new)
@@ -164,8 +160,6 @@ pub fn gpu_codegen<W: Write>(
         threads_per_warp: 32,
     };
 
-    std::fs::write("out.txt", "debug\n\n").unwrap();
-
     let ctx = GPUContext {
         function,
         types,
@@ -185,8 +179,6 @@ pub fn gpu_codegen<W: Write>(
         control_data_phi_map,
         return_parameter,
         kernel_params,
-        block_parallel_override,
-        thread_parallel_override,
     };
     ctx.codegen_function(w)
 }
@@ -215,8 +207,6 @@ struct GPUContext<'a> {
     control_data_phi_map: HashMap<NodeID, Vec<(NodeID, NodeID)>>,
     return_parameter: Option<usize>,
     kernel_params: &'a GPUKernelParams,
-    block_parallel_override: bool,
-    thread_parallel_override: bool,
 }
 
 /*
@@ -265,8 +255,6 @@ enum CGType {
 
 impl GPUContext<'_> {
     fn codegen_function<W: Write>(&self, w: &mut W) -> Result<(), Error> {
-        let mut file = OpenOptions::new().append(true).open("out.txt").unwrap();
-
         // Emit all code up to the "goto" to Start's block
         let mut top = String::new();
         self.codegen_kernel_begin(self.return_parameter.is_none(), &mut top)?;
@@ -289,19 +277,14 @@ impl GPUContext<'_> {
 
         // If there are no forks, fast forward to single-block, single-thread codegen
         let (num_blocks, num_threads) = if self.fork_join_map.is_empty() {
-            writeln!(file, "shortcut to 1b1t").unwrap();
             self.codegen_data_control_no_forks(&HashSet::new(), &mut dynamic_shared_offset, &mut gotos)?;
             ("1".to_string(), "1".to_string())
         } else {
-            writeln!(file, "no shortcut! fork tree: {:?}", self.fork_tree).unwrap();
             // Create structures and determine block and thread parallelization strategy
             let (root_forks, num_blocks, is_block_parallel) =
-                self.get_root_forks_and_num_blocks(&self.fork_tree);
-            writeln!(file, "is_block_parallel: {}", is_block_parallel).unwrap();
-            let (thread_root_root_fork, thread_root_forks) = self.get_thread_root_forks(&root_forks, &self.fork_tree, is_block_parallel);
-            writeln!(file, "thread_root_root_fork: {:?}", thread_root_root_fork).unwrap();
-            let (fork_thread_quota_map, num_threads) = self.get_thread_quotas(&self.fork_tree, thread_root_root_fork);
-            writeln!(file, "fork_thread_quota_map: {:?}", fork_thread_quota_map).unwrap();
+                self.get_root_forks_and_num_blocks(self.fork_tree);
+            let (thread_root_root_fork, thread_root_forks) = self.get_thread_root_forks(&root_forks, self.fork_tree, is_block_parallel);
+            let (fork_thread_quota_map, num_threads) = self.get_thread_quotas(self.fork_tree, thread_root_root_fork);
             // TODO: Uncomment and adjust once we know logic of extra dim
             // let extra_dim_collects = self.get_extra_dim_collects(&fork_control_map, &fork_thread_quota_map);
             let extra_dim_collects = HashSet::new();
@@ -590,7 +573,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
                 self.collection_objects.origin(*object).try_parameter().is_some()
             })
         }), "All collection reduces in block fork must originate from parameters");
-        if self.block_parallel_override || self.function.schedules[root_fork.idx()].contains(&Schedule::ParallelFork)
+        if self.function.schedules[root_fork.idx()].contains(&Schedule::ParallelFork)
         {
             let fork_size = factors.iter().map(|dc| format!("dc{}", dc.idx())).collect::<Vec<_>>().join(" * ");
             (root_forks, fork_size, true)
@@ -697,7 +680,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
             && fork_size.is_power_of_two()
             && reduces.iter().all(|&reduce| {
                 self.function.schedules[reduce.idx()].contains(&Schedule::ParallelReduce)
-                    || self.thread_parallel_override || self.function.schedules[reduce.idx()].contains(&Schedule::TightAssociative)
+                    || self.function.schedules[reduce.idx()].contains(&Schedule::TightAssociative)
             })
         {
             // If there's an associative Reduce, parallelize the larger factor
@@ -710,7 +693,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
             // restriction doesn't help for parallel Writes, so nested parallelization
             // is possible.
             if reduces.iter().any(|&reduce| {
-                self.thread_parallel_override || self.function.schedules[reduce.idx()].contains(&Schedule::TightAssociative)
+                self.function.schedules[reduce.idx()].contains(&Schedule::TightAssociative)
             }) || fork_size > self.kernel_params.max_num_threads / subtree_quota {
                 if fork_size >= subtree_quota {
                     (HashMap::new(), fork_size, true)
diff --git a/hercules_samples/matmul/src/main.rs b/hercules_samples/matmul/src/main.rs
index 9421c773..7b6cfe79 100644
--- a/hercules_samples/matmul/src/main.rs
+++ b/hercules_samples/matmul/src/main.rs
@@ -10,13 +10,11 @@ juno_build::juno!("matmul");
 
 fn main() {
     async_std::task::block_on(async {
-        const I: usize = 4;
-        const J: usize = 2;
-        const K: usize = 8;
-        // let mut a: Box<[i32]> = (0..I * J).map(|_| random::<i32>() % 100).collect();
-        // let mut b: Box<[i32]> = (0..J * K).map(|_| random::<i32>() % 100).collect();
-        let mut a: Box<[i32]> = (0..I * J).map(|i| (i as i32) % 100).collect();
-        let mut b: Box<[i32]> = (0..J * K).map(|i| (i as i32) % 100).collect();
+        const I: usize = 256;
+        const J: usize = 8; // hardcoded constant in matmul.hir
+        const K: usize = 128;
+        let mut a: Box<[i32]> = (0..I * J).map(|_| random::<i32>() % 100).collect();
+        let mut b: Box<[i32]> = (0..J * K).map(|_| random::<i32>() % 100).collect();
         let mut correct_c: Box<[i32]> = (0..I * K).map(|_| 0).collect();
         for i in 0..I {
             for k in 0..K {
-- 
GitLab