From ee408623f14eb01fe405642f73e7b93fe894fee6 Mon Sep 17 00:00:00 2001
From: Praneet Rathi <prrathi10@gmail.com>
Date: Fri, 17 Jan 2025 11:20:53 -0600
Subject: [PATCH] sm

---
 hercules_cg/src/gpu.rs | 33 +++++++++++++++++++--------------
 1 file changed, 19 insertions(+), 14 deletions(-)

diff --git a/hercules_cg/src/gpu.rs b/hercules_cg/src/gpu.rs
index 37f0cd31..f51479a7 100644
--- a/hercules_cg/src/gpu.rs
+++ b/hercules_cg/src/gpu.rs
@@ -291,9 +291,13 @@ enum CGType {
 
 impl GPUContext<'_> {
     fn codegen_function<W: Write>(&self, w: &mut W) -> Result<(), Error> {
+        // If run_debug, wrapping C host code is self-contained with malloc, etc,
+        // else it only does kernel launch.
+        let run_debug = false;
+
         // Emit all code up to the "goto" to Start's block
         let mut top = String::new();
-        self.codegen_kernel_begin(&mut top)?;
+        self.codegen_kernel_begin(run_debug, &mut top)?;
         let mut dynamic_shared_offset = "0".to_string();
         self.codegen_dynamic_constants(&mut top)?;
         self.codegen_declare_data(&mut top)?;
@@ -353,14 +357,14 @@ impl GPUContext<'_> {
 
         // Emit host launch code
         let mut host_launch = String::new();
-        self.codegen_launch_code(false, num_blocks, num_threads, &dynamic_shared_offset, &mut host_launch)?;
+        self.codegen_launch_code(run_debug, num_blocks, num_threads, &dynamic_shared_offset, &mut host_launch)?;
         write!(w, "{}", host_launch)?;
 
         Ok(())
     }
 
     // Emit kernel headers, signature, arguments, and dynamic shared memory declaration
-    fn codegen_kernel_begin(&self, w: &mut String) -> Result<(), Error> {
+    fn codegen_kernel_begin(&self, run_debug: bool, w: &mut String) -> Result<(), Error> {
         write!(w, "
 #include <assert.h>
 #include <stdio.h>
@@ -385,8 +389,8 @@ namespace cg = cooperative_groups;
 
         write!(
             w,
-            "__global__ void __launch_bounds__({}) {}(",
-            self.kernel_params.max_num_threads, self.function.name
+            "__global__ void __launch_bounds__({}) {}{}(",
+            self.kernel_params.max_num_threads, self.function.name, if run_debug { "" } else { "_gpu" }
         )?;
         // The first set of parameters are dynamic constants.
         let mut first_param = true;
@@ -519,13 +523,12 @@ namespace cg = cooperative_groups;
     }
 
     fn codegen_launch_code(&self, run_debug: bool, num_blocks: usize, num_threads: usize, dynamic_shared_offset: &str, w: &mut String) -> Result<(), Error> {
-        write!(w, "
-int main(")?;
         // The following steps are for host-side C function arguments, but we also
         // need to pass arguments to kernel, so we keep track of the arguments here.
         let mut pass_args = String::new();
         if run_debug {
-            write!(w, ") {{
+            write!(w, "
+int main() {{
 ")?;
             // The first set of parameters are dynamic constants.
             let mut first_param = true;
@@ -588,13 +591,13 @@ int main(")?;
                 }
             }
             if self.types[self.return_type_id.idx()].is_primitive() {
-                write!(w, "\tcudaFree(ret);\n");
+                write!(w, "\tcudaFree(ret);\n")?;
             }
-            write!(w, "\treturn 0;\n");
-            write!(w, "}}\n");
         }
 
         else {
+            write!(w, "
+extern \"C\" int {}(", self.function.name)?;
             // The first set of parameters are dynamic constants.
             let mut first_param = true;
             for idx in 0..self.function.num_dynamic_constants {
@@ -627,11 +630,13 @@ int main(")?;
                 write!(w, "{} ret", ret_type)?;
                 write!(pass_args, "ret")?;
             }
-            write!(w, ") {{
-    {}<<<{}, {}, {}>>>({});
-}}", self.function.name, num_blocks, num_threads, dynamic_shared_offset, pass_args);
+            write!(w, ") {{\n")?;
+            write!(w, "\t{}<<<{}_gpu, {}, {}>>>({});\n", self.function.name, num_blocks, num_threads, dynamic_shared_offset, pass_args)?;
+            write!(w, "\tcudaDeviceSynchronize();\n")?;
         }
 
+        write!(w, "\treturn 0;\n")?;
+        write!(w, "}}\n")?;
         Ok(())
     }
 
-- 
GitLab