From 977f1540607e53cbc8f1299fe9ce1a114d7bf1de Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Thu, 20 Feb 2025 20:34:48 -0600
Subject: [PATCH] Fix thread_block_tile emission in GPU backend

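Previously, the cg::thread_block_tile declarations for fork thread
tiles were emitted into each control node's init section. Because
control flow between basic blocks is lowered to CUDA gotos, the initial
jump emitted by codegen_goto_start could bypass those initializations,
which C++ rejects ("jump bypasses variable initialization"). Collect
all tile declarations into a dedicated thread_block_tiles buffer,
thread it through codegen_data_control(_no_forks) and
codegen_control_node, and emit it once at the top of the kernel,
before the goto to the start block.

Also simplify the max_gradient schedule in the edge_detection sample:
a single fork-tile[32, ...] plus fork-split replaces the two-level
tiling and the fork-fission-bufferize step.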
---
 hercules_cg/src/gpu.rs                  | 53 +++++++++++++++++++------
 juno_samples/edge_detection/src/gpu.sch |  8 +---
 2 files changed, 43 insertions(+), 18 deletions(-)
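
Notes (below the diffstat, so not part of the commit message): a
minimal CUDA sketch of the failure mode and the fix. The kernel and
label names (before_fix, after_fix, bb_0) are made up for illustration
and are not the backend's actual emitted identifiers.

    #include <cooperative_groups.h>
    namespace cg = cooperative_groups;

    __global__ void before_fix() {
        cg::thread_block block = cg::this_thread_block();
        goto bb_0; // the initial jump, as codegen_goto_start emitted it
        // Uncommenting the next line reproduces the error this patch
        // fixes ("jump bypasses initialization of ..."): the goto above
        // skips over an initialized declaration at the same scope.
        // cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block);
    bb_0:
        ;
    }

    __global__ void after_fix() {
        cg::thread_block block = cg::this_thread_block();
        // All tiles are now declared before any goto, so no jump can
        // bypass their initialization.
        cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block);
        goto bb_0;
    bb_0:
        tile.sync();
    }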

diff --git a/hercules_cg/src/gpu.rs b/hercules_cg/src/gpu.rs
index 6dc5d53e..931071cb 100644
--- a/hercules_cg/src/gpu.rs
+++ b/hercules_cg/src/gpu.rs
@@ -3,8 +3,6 @@ extern crate hercules_ir;
 
 use std::collections::{BTreeMap, HashMap, HashSet};
 use std::fmt::{Error, Write};
-use std::fs::{File, OpenOptions};
-use std::io::Write as _;
 
 use self::hercules_ir::*;
 
@@ -269,7 +267,6 @@ impl GPUContext<'_> {
         self.codegen_dynamic_constants(&mut top)?;
         self.codegen_declare_data(&mut top)?;
         self.codegen_helpers(&mut top)?;
-        self.codegen_goto_start(&mut top)?;
         write!(w, "{}", top)?;
 
         // Setup for CUDA's "goto" for control flow between basic blocks.
@@ -281,10 +278,15 @@ impl GPUContext<'_> {
                 (node_id, goto)
             })
             .collect();
+        let mut thread_block_tiles = String::new();
 
         // If there are no forks, fast forward to single-block, single-thread codegen
         let (num_blocks, num_threads) = if self.fork_join_map.is_empty() {
-            self.codegen_data_control_no_forks(&mut dynamic_shared_offset, &mut gotos)?;
+            self.codegen_data_control_no_forks(
+                &mut dynamic_shared_offset,
+                &mut thread_block_tiles,
+                &mut gotos,
+            )?;
             ("1".to_string(), "1".to_string())
         } else {
             // Create structures and determine block and thread parallelization strategy
@@ -307,12 +309,15 @@ impl GPUContext<'_> {
                 &mut dynamic_shared_offset,
                 is_block_parallel,
                 num_threads,
+                &mut thread_block_tiles,
                 &mut gotos,
             )?;
             (num_blocks, num_threads.to_string())
         };
 
         // Emit all GPU kernel code from previous steps
+        self.codegen_goto_start(&mut thread_block_tiles)?;
+        write!(w, "{}", thread_block_tiles)?;
         let mut kernel_body = String::new();
         let rev_po = self.control_subgraph.rev_po(NodeID::new(0));
         write!(w, "\n")?;
@@ -696,7 +701,7 @@ extern \"C\" {} {}(",
         let Node::Fork { factors, .. } = &self.function.nodes[root_fork.idx()] else {
             panic!("Expected fork node");
         };
-        let reduces = &self.fork_reduce_map[root_fork];
+        let _reduces = &self.fork_reduce_map[root_fork];
         if self.function.schedules[root_fork.idx()].contains(&Schedule::ParallelFork) {
             let fork_size = factors
                 .iter()
@@ -847,6 +852,7 @@ extern \"C\" {} {}(",
     fn codegen_data_control_no_forks(
         &self,
         dynamic_shared_offset: &mut String,
+        thread_block_tiles: &mut String,
         gotos: &mut BTreeMap<NodeID, CudaGoto>,
     ) -> Result<(), Error> {
         (0..self.function.nodes.len())
@@ -858,8 +864,16 @@ extern \"C\" {} {}(",
                 let post_init = &mut goto.post_init;
                 let body = &mut goto.body;
                 let term = &mut goto.term;
-                let mut tabs =
-                    self.codegen_control_node(control, None, None, None, init, post_init, term)?;
+                let mut tabs = self.codegen_control_node(
+                    control,
+                    None,
+                    None,
+                    None,
+                    thread_block_tiles,
+                    init,
+                    post_init,
+                    term,
+                )?;
                 for data in self.bbs.1[control.idx()].iter() {
                     self.codegen_data_node(
                         *data,
@@ -889,6 +903,7 @@ extern \"C\" {} {}(",
         dynamic_shared_offset: &mut String,
         is_block_parallel: bool,
         num_threads: usize,
+        thread_block_tiles: &mut String,
         gotos: &mut BTreeMap<NodeID, CudaGoto>,
     ) -> Result<(), Error> {
         // First emit data and control gen for each control node outside any fork.
@@ -900,8 +915,16 @@ extern \"C\" {} {}(",
             let post_init = &mut goto.post_init;
             let body = &mut goto.body;
             let term = &mut goto.term;
-            let mut tabs =
-                self.codegen_control_node(*control, None, None, None, init, post_init, term)?;
+            let mut tabs = self.codegen_control_node(
+                *control,
+                None,
+                None,
+                None,
+                thread_block_tiles,
+                init,
+                post_init,
+                term,
+            )?;
             for data in self.bbs.1[control.idx()].iter() {
                 self.codegen_data_node(
                     *data,
@@ -931,6 +954,7 @@ extern \"C\" {} {}(",
                     Some(num_threads),
                     Some(num_threads),
                     Some(1),
+                    thread_block_tiles,
                     init,
                     post_init,
                     term,
@@ -961,6 +985,7 @@ extern \"C\" {} {}(",
                 1,
                 num_threads,
                 dynamic_shared_offset,
+                thread_block_tiles,
                 gotos,
             )?;
         }
@@ -981,6 +1006,7 @@ extern \"C\" {} {}(",
         parent_quota: usize,
         num_threads: usize,
         dynamic_shared_offset: &mut String,
+        thread_block_tiles: &mut String,
         gotos: &mut BTreeMap<NodeID, CudaGoto>,
     ) -> Result<(), Error> {
         let (available_thread_quota, use_thread_quota, parallel_factor) = fork_thread_quota_map
@@ -1017,6 +1043,7 @@ extern \"C\" {} {}(",
                 Some(available_thread_quota),
                 Some(use_thread_quota),
                 parallel_factor,
+                thread_block_tiles,
                 init,
                 post_init,
                 term,
@@ -1044,6 +1071,7 @@ extern \"C\" {} {}(",
                 use_thread_quota,
                 num_threads,
                 dynamic_shared_offset,
+                thread_block_tiles,
                 gotos,
             )?;
         }
@@ -1504,6 +1532,7 @@ extern \"C\" {} {}(",
         available_thread_quota: Option<usize>,
         use_thread_quota: Option<usize>,
         parallel_factor: Option<usize>,
+        thread_block_tiles: &mut String,
         w_init: &mut String,
         w_post_init: &mut String,
         w_term: &mut String,
@@ -1579,20 +1608,20 @@ extern \"C\" {} {}(",
                         use_thread_quota
                     };
                     write!(
-                        w_init,
+                        thread_block_tiles,
                         "\tcg::thread_block_tile<{}> {} = cg::tiled_partition<{}>(block);\n",
                         use_thread_per_id, cg_tile, use_thread_per_id
                     )?;
                     let cg_tile_use = self.get_cg_tile(id, CGType::Use);
                     write!(
-                        w_init,
+                        thread_block_tiles,
                         "\tcg::thread_block_tile<{}> {} = cg::tiled_partition<{}>(block);\n",
                         use_thread_quota, cg_tile_use, use_thread_quota
                     )?;
                     let available_thread_quota = available_thread_quota.unwrap();
                     let cg_tile_available = self.get_cg_tile(id, CGType::Available);
                     write!(
-                        w_init,
+                        thread_block_tiles,
                         "\tcg::thread_block_tile<{}> {} = cg::tiled_partition<{}>(block);\n",
                         available_thread_quota, cg_tile_available, available_thread_quota
                     )?;
diff --git a/juno_samples/edge_detection/src/gpu.sch b/juno_samples/edge_detection/src/gpu.sch
index ed414084..a3c804d5 100644
--- a/juno_samples/edge_detection/src/gpu.sch
+++ b/juno_samples/edge_detection/src/gpu.sch
@@ -62,13 +62,9 @@ fixpoint {
 simpl!(max_gradient);
 fork-dim-merge(max_gradient);
 simpl!(max_gradient);
-fork-tile[1024, 0, false, true](max_gradient);
-let out = fork-split(max_gradient);
-fork-tile[32, 0, false, true](out._4_max_gradient.fj1);
-let out = fork-split(max_gradient);
-simpl!(max_gradient);
+fork-tile[32, 0, false, true](max_gradient);
+fork-split(max_gradient);
 clean-monoid-reduces(max_gradient);
-fork-fission-bufferize[out._4_max_gradient.fj0, out._4_max_gradient.fj1](max_gradient);
 simpl!(max_gradient);
 
 no-memset(reject_zero_crossings@res);
-- 
GitLab