From 58639fc70b216345e0c12c229c2454474c42b306 Mon Sep 17 00:00:00 2001
From: prrathi <prrathi10@gmail.com>
Date: Tue, 21 Jan 2025 05:50:54 +0000
Subject: [PATCH] Address MR comments

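Lift the return value into a device-allocated kernel parameter for all
return types, not just primitives: the host wrapper now always
cudaMallocs a return slot, passes it to the kernel as void*, and copies
the result back to the host after the launch. Also drop the stray
fflush(stdout) after cudaDeviceSynchronize, narrow the field/variant
type lookups in constant codegen to the scalar branches where they are
used, and make get_type honor make_pointer for collection types
("char**" instead of "char*").

For reference, the emitted host wrapper and kernel epilogue now follow
roughly this shape. This is a hand-written sketch for a hypothetical
function returning a 32-bit int; the names, launch configuration, and
the one-thread guard (the generated code uses grid.thread_rank() == 0)
are illustrative, not literal generated output:

    #include <cuda_runtime.h>

    __global__ void example_gpu(void* __restrict__ ret) {
        int val = 42;  // stand-in for the kernel's computed return value
        if (blockIdx.x == 0 && threadIdx.x == 0) {
            // write the result through the lifted return parameter
            *(reinterpret_cast<int*>(ret)) = val;
        }
    }

    extern "C" int example() {
        int* ret;
        // return slot is now allocated for every return type, not just primitives
        cudaMalloc((void**)&ret, sizeof(int));
        example_gpu<<<1, 32>>>(ret);
        cudaDeviceSynchronize();
        int host_ret;
        // primitive values and collection pointers take the same copy-back path
        cudaMemcpy(&host_ret, ret, sizeof(int), cudaMemcpyDeviceToHost);
        return host_ret;
    }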
---
 hercules_cg/src/gpu.rs | 47 ++++++++++++++++++++---------------------------
 1 file changed, 20 insertions(+), 27 deletions(-)

diff --git a/hercules_cg/src/gpu.rs b/hercules_cg/src/gpu.rs
index 11a91c9a..fb9526bf 100644
--- a/hercules_cg/src/gpu.rs
+++ b/hercules_cg/src/gpu.rs
@@ -372,7 +372,7 @@ namespace cg = cooperative_groups;
         }
         write!(
             w,
-            "char* __restrict__ ret",
+            "void* __restrict__ ret",
         )?;
 
         // Type is char since it's simplest to use single bytes for indexing
@@ -512,26 +512,20 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
             write!(pass_args, "p{}", idx)?;
         }
         write!(w, ") {{\n")?;
-        // Add return parameter, with allocation if primitive
+        // Allocate device memory for the return value and lift it to a kernel argument
         let ret_type_pnt = self.get_type(self.function.return_type, true);
         write!(w, "\t{} ret;\n", ret_type_pnt)?;
         if !first_param {
             write!(pass_args, ", ")?;
         }
         write!(pass_args, "ret")?;
-        if self.types[self.function.return_type.idx()].is_primitive() {
-            write!(w, "\tcudaMalloc((void**)&ret, sizeof({}));\n", ret_type)?;
-        }
+        write!(w, "\tcudaMalloc((void**)&ret, sizeof({}));\n", ret_type)?;
         write!(w, "\t{}_gpu<<<{}, {}, {}>>>({});\n", self.function.name, num_blocks, num_threads, dynamic_shared_offset, pass_args)?;
         write!(w, "\tcudaDeviceSynchronize();\n")?;
-        write!(w, "\tfflush(stdout);\n")?;
-        if self.types[self.function.return_type.idx()].is_primitive() {
-            write!(w, "\t{} host_ret;\n", ret_type)?;
-            write!(w, "\tcudaMemcpy(&host_ret, ret, sizeof({}), cudaMemcpyDeviceToHost);\n", ret_type)?;
-            write!(w, "\treturn host_ret;\n")?;
-        } else {
-            write!(w, "\treturn ret;\n")?;
-        }
+        // Copy the return value from device to host, whether it's a primitive value or a collection pointer
+        write!(w, "\t{} host_ret;\n", ret_type)?;
+        write!(w, "\tcudaMemcpy(&host_ret, ret, sizeof({}), cudaMemcpyDeviceToHost);\n", ret_type)?;
+        write!(w, "\treturn host_ret;\n")?;
         write!(w, "}}\n")?;
         Ok(())
     }
@@ -1426,12 +1420,11 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
             Node::Return { control: _, data } => {
-                // Since we lift originally primitive returns into a parameter,
-                // we write to that parameter upon return.
+                // Since we lift all returns (primitive or not) into a parameter,
+                // we write to that parameter upon return.
-                if self.types[self.typing[data.idx()].idx()].is_primitive() {
-                    let return_val = self.get_value(*data, false, false);
-                    write!(w_term, "\tif (grid.thread_rank() == 0) {{\n")?;
-                    write!(w_term, "\t\t*ret = {};\n", return_val)?;
-                    write!(w_term, "\t}}\n")?;
-                }
+                let return_val = self.get_value(*data, false, false);
+                let return_type_ptr = self.get_type(self.function.return_type, true);
+                write!(w_term, "\tif (grid.thread_rank() == 0) {{\n")?;
+                write!(w_term, "\t\t*(reinterpret_cast<{}>(ret)) = {};\n", return_type_ptr, return_val)?;
+                write!(w_term, "\t}}\n")?;
                 write!(w_term, "\treturn;\n")?;
                 1
             }
@@ -1542,10 +1535,10 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
                 };
                 for i in 0..constant_fields.len() {
                     // For each field update offset and issue recursive call
-                    let field_type = self.get_type(type_fields[i], true);
                     let offset = self.get_size(type_fields[i], Some(i), extra_dim_collects);
                     let field_constant = &self.constants[constant_fields[i].idx()];
                     if field_constant.is_scalar() {
+                        let field_type = self.get_type(type_fields[i], true);
                         self.codegen_constant(
                             format!("*reinterpret_cast<{}>({}+{})", field_type, name, offset),
                             constant_fields[i],
@@ -1573,10 +1566,9 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
                 let Type::Summation(variants) = &self.types[type_id.idx()] else {
                     panic!("Summation constant should have summation type")
                 };
-                let variant_type =
-                    self.get_type(self.typing[variants[*variant as usize].idx()], true);
                 let variant_constant = &self.constants[field.idx()];
                 if variant_constant.is_scalar() {
+                    let variant_type = self.get_type(self.typing[variants[*variant as usize].idx()], true);
                     self.codegen_constant(
                         format!("*reinterpret_cast<{}>({})", variant_type, name),
                         *field,
@@ -1851,10 +1843,11 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
     }
 
     fn get_type(&self, id: TypeID, make_pointer: bool) -> String {
-        if self.types[id.idx()].is_primitive() {
-            convert_type(&self.types[id.idx()], make_pointer)
+        let ty = &self.types[id.idx()];
+        if ty.is_primitive() {
+            convert_type(ty, make_pointer)
         } else {
-            "char*".to_string()
+            format!("char*{}", if make_pointer { "*" } else { "" })
         }
     }
 
-- 
GitLab