From bc473a6fde0f42f1f514237ce2dac6a0c023323a Mon Sep 17 00:00:00 2001
From: Aaron Councilman <aaronjc4@illinois.edu>
Date: Fri, 21 Feb 2025 08:53:45 -0600
Subject: [PATCH] Multi-return for gpu

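The GPU backend previously assumed a single return value: either it was
always the same parameter collection (and the host wrapper returned that
parameter directly) or it was lifted into a single void* kernel argument.
With this change the backend emits a per-function return_<name> struct
covering every return type. Return values that do not always alias a
parameter are written into that struct by the thread with grid rank 0;
the host wrapper allocates a device copy of the struct, passes it as an
extra kernel argument, and copies it back with cudaMemcpy. Single-return
functions keep their old extern "C" signature (returning the value, or
the parameter directly when the return always aliases one), while
multi-return functions return void and fill a caller-provided
return_<name>* output pointer, with parameter-aliasing fields filled in
on the host.

As a rough sketch of the emitted code, for a hypothetical entry point
"foo" that returns a device-computed f32 together with an array that
always aliases parameter p1 (names, launch configuration, and the kernel
body are illustrative; error handling and the cooperative-groups
boilerplate are elided):

    struct return_foo { float f0; float* f1; };

    __global__ void __launch_bounds__(1024) foo_gpu(
        unsigned long long n, float* __restrict__ p1,
        return_foo* __restrict__ ret) {
      // hypothetical body: reduce p1 into a single float
      float sum = 0.0f;
      for (unsigned long long i = 0; i < n; i++) sum += p1[i];
      // only returns that are not parameters are written on the device
      if (blockIdx.x == 0 && threadIdx.x == 0) ret->f0 = sum;
    }

    extern "C" void foo(unsigned long long n, float* p1,
                        return_foo* ret_ptr) {
      return_foo* ret_cuda;
      cudaMalloc((void**)&ret_cuda, sizeof(return_foo));
      foo_gpu<<<1, 1024>>>(n, p1, ret_cuda);
      // device-computed fields come back through the struct...
      cudaMemcpy(ret_ptr, ret_cuda, sizeof(return_foo),
                 cudaMemcpyDeviceToHost);
      // ...and parameter-aliasing fields are filled in on the host
      ret_ptr->f1 = p1;
      return;
    }
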
---
 hercules_cg/src/gpu.rs                        | 181 ++++++++++++------
 juno_samples/multi_return/src/gpu.sch         |   4 +
 juno_samples/multi_return/src/main.rs         |  28 ++-
 juno_samples/multi_return/src/multi_return.jn |   8 +-
 juno_samples/products/src/gpu.sch             |   1 -
 5 files changed, 155 insertions(+), 67 deletions(-)

diff --git a/hercules_cg/src/gpu.rs b/hercules_cg/src/gpu.rs
index d1a31d47..453d33d5 100644
--- a/hercules_cg/src/gpu.rs
+++ b/hercules_cg/src/gpu.rs
@@ -150,13 +150,17 @@ pub fn gpu_codegen<W: Write>(
         }
     }
 
-    let return_parameter = if collection_objects.returned_objects().len() == 1 {
-        collection_objects
-            .origin(*collection_objects.returned_objects().first().unwrap())
-            .try_parameter()
-    } else {
-        None
-    };
+    // Track, for each return value, whether it is always the same parameter
+    // collection (and if so, which parameter it is)
+    let return_parameters = (0..function.return_types.len())
+        .map(|idx| if collection_objects.returned_objects(idx).len() == 1 {
+            collection_objects
+                .origin(*collection_objects.returned_objects(idx).first().unwrap())
+                .try_parameter()
+        } else {
+            None
+        })
+        .collect::<Vec<_>>();
 
     let kernel_params = &GPUKernelParams {
         max_num_threads: 1024,
@@ -181,7 +185,7 @@ pub fn gpu_codegen<W: Write>(
         fork_reduce_map,
         reduct_reduce_map,
         control_data_phi_map,
-        return_parameter,
+        return_parameters,
         kernel_params,
     };
     ctx.codegen_function(w)
@@ -210,7 +214,7 @@ struct GPUContext<'a> {
     fork_reduce_map: HashMap<NodeID, Vec<NodeID>>,
     reduct_reduce_map: HashMap<NodeID, Vec<NodeID>>,
     control_data_phi_map: HashMap<NodeID, Vec<(NodeID, NodeID)>>,
-    return_parameter: Option<usize>,
+    return_parameters: Vec<Option<usize>>,
     kernel_params: &'a GPUKernelParams,
 }
 
@@ -262,7 +266,9 @@ impl GPUContext<'_> {
     fn codegen_function<W: Write>(&self, w: &mut W) -> Result<(), Error> {
         // Emit all code up to the "goto" to Start's block
         let mut top = String::new();
-        self.codegen_kernel_begin(self.return_parameter.is_none(), &mut top)?;
+        self.codegen_kernel_preamble(&mut top)?;
+        self.codegen_return_struct(&mut top)?;
+        self.codegen_kernel_begin(&mut top)?;
         let mut dynamic_shared_offset = "0".to_string();
         self.codegen_dynamic_constants(&mut top)?;
         self.codegen_declare_data(&mut top)?;
@@ -339,10 +345,7 @@ impl GPUContext<'_> {
         Ok(())
     }
 
-    /*
-     * Emit kernel headers, signature, arguments, and dynamic shared memory declaration
-     */
-    fn codegen_kernel_begin(&self, has_ret_var: bool, w: &mut String) -> Result<(), Error> {
+    fn codegen_kernel_preamble<W: Write>(&self, w: &mut W) -> Result<(), Error> {
         write!(
             w,
             "
@@ -366,8 +369,23 @@ namespace cg = cooperative_groups;
 #define isqrt(a) ((int)sqrtf((float)(a)))
 
 ",
-        )?;
+        )
+    }
+
+    fn codegen_return_struct<W: Write>(&self, w: &mut W) -> Result<(), Error> {
+        write!(w, "struct return_{} {{ {} }};\n",
+               self.function.name,
+               self.function.return_types.iter().enumerate()
+               .map(|(idx, typ)| format!("{} f{};", self.get_type(*typ, false), idx))
+               .collect::<Vec<_>>()
+               .join(" "),
+        )
+    }
 
+    /*
+     * Emit kernel signature, arguments, and dynamic shared memory declaration
+     */
+    fn codegen_kernel_begin<W: Write>(&self, w: &mut W) -> Result<(), Error> {
         write!(
             w,
             "__global__ void __launch_bounds__({}) {}_gpu(",
@@ -403,11 +421,21 @@ namespace cg = cooperative_groups;
             };
             write!(w, "{} p{}", param_type, idx)?;
         }
-        if has_ret_var {
+        let ret_fields = self
+            .return_parameters
+            .iter()
+            .enumerate()
+            .filter_map(|(idx, param)| if param.is_some() {
+                None
+            } else {
+                Some((idx, self.function.return_types[idx]))
+            })
+            .collect::<Vec<(usize, TypeID)>>();
+        if !ret_fields.is_empty() {
             if !first_param {
                 write!(w, ", ")?;
             }
-            write!(w, "void* __restrict__ ret",)?;
+            write!(w, "return_{}* __restrict__ ret", self.function.name)?;
         }
 
         // Type is char since it's simplest to use single bytes for indexing
@@ -599,17 +627,17 @@ namespace cg = cooperative_groups;
         dynamic_shared_offset: &str,
         w: &mut String,
     ) -> Result<(), Error> {
-        // The following steps are for host-side C function arguments, but we also
-        // need to pass arguments to kernel, so we keep track of the arguments here.
-        let ret_type = self.get_type(self.function.return_type, false);
         let mut pass_args = String::new();
-        write!(
-            w,
-            "
-extern \"C\" {} {}(",
-            ret_type.clone(),
-            self.function.name
-        )?;
+
+        let is_multi_return = self.function.return_types.len() != 1;
+        write!(w, "extern \"C\" ")?;
+        if is_multi_return {
+            write!(w, "void")?;
+        } else {
+            write!(w, "{}", self.get_type(self.function.return_types[0], false))?;
+        }
+        write!(w, " {}(", self.function.name)?;
+
         let mut first_param = true;
         // The first parameter is a pointer to GPU backing memory, if it's
         // needed.
@@ -641,20 +669,42 @@ extern \"C\" {} {}(",
             write!(w, "{} p{}", param_type, idx)?;
             write!(pass_args, "p{}", idx)?;
         }
+        // If the function is multi-return, the last argument is the return pointer.
+        // This is a CPU pointer; we allocate a separate device pointer to hold the
+        // kernel's return values (if any)
+        if is_multi_return {
+            if !first_param {
+                write!(w, ", ")?;
+            }
+            write!(w, "return_{}* ret_ptr", self.function.name)?;
+        }
         write!(w, ") {{\n")?;
         // For case of dynamic block count
         self.codegen_dynamic_constants(w)?;
-        let has_ret_var = self.return_parameter.is_none();
-        if has_ret_var {
-            // Allocate return parameter and lift to kernel argument
-            let ret_type_pnt = self.get_type(self.function.return_type, true);
-            write!(w, "\t{} ret;\n", ret_type_pnt)?;
+
+        let (kernel_returns, param_returns) =
+            self.return_parameters.iter().enumerate()
+                .fold((vec![], vec![]),
+                    |(mut kernel_returns, mut param_returns), (idx, param)| {
+                    if let Some(param_idx) = param {
+                        param_returns.push((idx, param_idx));
+                    } else {
+                        kernel_returns.push((idx, self.function.return_types[idx]));
+                    }
+                    (kernel_returns, param_returns)
+                });
+
+        if !kernel_returns.is_empty() {
+            // Allocate kernel return struct
+            write!(w, "\treturn_{}* ret_cuda;\n", self.function.name)?;
+            write!(w, "\tcudaMalloc((void**)&ret_cuda, sizeof(return_{}));\n", self.function.name)?;
+            // Add the return pointer to the kernel arguments
             if !first_param {
                 write!(pass_args, ", ")?;
             }
-            write!(pass_args, "ret")?;
-            write!(w, "\tcudaMalloc((void**)&ret, sizeof({}));\n", ret_type)?;
+            write!(pass_args, "ret_cuda")?;
         }
+
         write!(w, "\tcudaError_t err;\n")?;
         write!(
             w,
@@ -666,18 +716,38 @@ extern \"C\" {} {}(",
             w,
             "\tif (cudaSuccess != err) {{ printf(\"Error1: %s\\n\", cudaGetErrorString(err)); }}\n"
         )?;
-        if has_ret_var {
-            // Copy return from device to host, whether it's primitive value or collection pointer
-            write!(w, "\t{} host_ret;\n", ret_type)?;
-            write!(
-                w,
-                "\tcudaMemcpy(&host_ret, ret, sizeof({}), cudaMemcpyDeviceToHost);\n",
-                ret_type
-            )?;
-            write!(w, "\treturn host_ret;\n")?;
+
+        if !is_multi_return {
+            if kernel_returns.is_empty() {
+                // A single return of a parameter, we can just return it directly
+                write!(w, "\treturn p{};\n", param_returns[0].1)?;
+            } else {
+                // A single return of a value computed on the device, we create a stack allocation
+                // and retrieve the value from the device and then return it
+                write!(w, "\t return_{} ret_host;\n", self.function.name)?;
+                write!(w,
+                    "\tcudaMemcpy(&ret_host, ret_cuda, sizeof(return_{}), cudaMemcpyDeviceToHost);\n",
+                    self.function.name,
+                )?;
+                write!(w, "\treturn ret_host.f0;\n")?;
+            }
         } else {
-            write!(w, "\treturn p{};\n", self.return_parameter.unwrap())?;
+            // Multi-return is handled via an output pointer provided to this function.
+            // If there are kernel returns, we copy those back from the device and then
+            // fill in the parameter returns
+            if !kernel_returns.is_empty() {
+                // Copy from the device directly into the output struct
+                write!(w,
+                    "\tcudaMemcpy(ret_ptr, ret_cuda, sizeof(return_{}), cudaMemcpyDeviceToHost);\n",
+                    self.function.name,
+                )?;
+            }
+            for (field_idx, param_idx) in param_returns {
+                write!(w, "\tret_ptr->f{} = p{};\n", field_idx, param_idx)?;
+            }
+            write!(w, "\treturn;\n")?;
         }
+
         write!(w, "}}\n")?;
         Ok(())
     }
@@ -1710,20 +1780,17 @@ extern \"C\" {} {}(",
                 }
                 tabs
             }
-            Node::Return { control: _, data } => {
-                if self.return_parameter.is_none() {
-                    // Since we lift return into a kernel argument, we write to that
-                    // argument upon return.
-                    let return_val = self.get_value(*data, false, false);
-                    let return_type_ptr = self.get_type(self.function.return_type, true);
-                    write!(w_term, "\tif (grid.thread_rank() == 0) {{\n")?;
-                    write!(
-                        w_term,
-                        "\t\t*(reinterpret_cast<{}>(ret)) = {};\n",
-                        return_type_ptr, return_val
-                    )?;
-                    write!(w_term, "\t}}\n")?;
+            Node::Return { control: _, ref data } => {
+                write!(w_term, "\tif (grid.thread_rank() == 0) {{\n")?;
+                for (idx, (data, param)) in data.iter().zip(self.return_parameters.iter()).enumerate() {
+                    // For return values that are not identical to some parameter, we
+                    // write the value into the output struct
+                    if param.is_none() {
+                        write!(w_term, "\t\tret->f{} = {};\n", idx,
+                               self.get_value(*data, false, false))?;
+                    }
                 }
+                write!(w_term, "\t}}\n")?;
                 write!(w_term, "\treturn;\n")?;
                 1
             }
diff --git a/juno_samples/multi_return/src/gpu.sch b/juno_samples/multi_return/src/gpu.sch
index e733551d..0c0b569a 100644
--- a/juno_samples/multi_return/src/gpu.sch
+++ b/juno_samples/multi_return/src/gpu.sch
@@ -4,6 +4,10 @@ dce(*);
 
 ip-sroa(*);
 sroa(*);
+
+ip-sroa[true](rolling_sum);
+sroa[true](rolling_sum, rolling_sum_prod);
+
 dce(*);
 
 forkify(*);
diff --git a/juno_samples/multi_return/src/main.rs b/juno_samples/multi_return/src/main.rs
index 6966e0df..0b3508a7 100644
--- a/juno_samples/multi_return/src/main.rs
+++ b/juno_samples/multi_return/src/main.rs
@@ -2,17 +2,35 @@
 
 juno_build::juno!("multi_return");
 
-use hercules_rt::{runner, HerculesImmBox, HerculesImmBoxTo};
+use hercules_rt::{runner, HerculesImmBox, HerculesImmBoxTo, HerculesMutBox};
 
 fn main() {
     const N: usize = 32;
     let a: Box<[f32]> = (1..=N).map(|i| i as f32).collect();
-    let a = HerculesImmBox::from(a.as_ref());
+    let arg = HerculesImmBox::from(a.as_ref());
     let mut r = runner!(rolling_sum_prod);
-    let (sums, prods) = async_std::task::block_on(async { r.run(N as u64, a.to()).await });
+    let (sums, sum, prods, prod) = async_std::task::block_on(async { r.run(N as u64, arg.to()).await });
 
-    println!("Partial Sums: {:?}", sums.as_slice::<f32>());
-    println!("Partial Prods: {:?}", prods.as_slice::<f32>());
+    let mut sums = HerculesMutBox::<f32>::from(sums);
+    let mut prods = HerculesMutBox::<f32>::from(prods);
+
+    let (expected_sums, expected_sum) = a.iter()
+        .fold((vec![0.0], 0.0), |(mut sums, sum), v| {
+            let new_sum = sum + v;
+            sums.push(new_sum);
+            (sums, new_sum)
+        });
+    let (expected_prods, expected_prod) = a.iter()
+        .fold((vec![1.0], 1.0), |(mut prods, prod), v| {
+            let new_prod = prod * v;
+            prods.push(new_prod);
+            (prods, new_prod)
+        });
+
+    assert_eq!(sum, expected_sum);
+    assert_eq!(sums.as_slice(), expected_sums.as_slice());
+    assert_eq!(prod, expected_prod);
+    assert_eq!(prods.as_slice(), expected_prods.as_slice());
 }
 
 #[test]
diff --git a/juno_samples/multi_return/src/multi_return.jn b/juno_samples/multi_return/src/multi_return.jn
index 84bab015..30b5576e 100644
--- a/juno_samples/multi_return/src/multi_return.jn
+++ b/juno_samples/multi_return/src/multi_return.jn
@@ -25,8 +25,8 @@ fn rolling_prod<t: number, n: usize>(x: t[n]) -> t, t[n + 1] {
 }
 
 #[entry]
-fn rolling_sum_prod<n: usize>(x: f32[n]) -> f32[n + 1], f32[n + 1] {
-  let rsum = rolling_sum::<_, n>(x).1;
-  let _, rprod = rolling_prod::<_, n>(x);
-  return rsum, rprod;
+fn rolling_sum_prod<n: usize>(x: f32[n]) -> f32[n + 1], f32, f32[n + 1], f32 {
+  let (sum, rsum) = rolling_sum::<_, n>(x);
+  let prod, rprod = rolling_prod::<_, n>(x);
+  return rsum, sum, rprod, prod;
 }
diff --git a/juno_samples/products/src/gpu.sch b/juno_samples/products/src/gpu.sch
index 5ef4c479..0a734bb2 100644
--- a/juno_samples/products/src/gpu.sch
+++ b/juno_samples/products/src/gpu.sch
@@ -5,7 +5,6 @@ dce(*);
 let out = auto-outline(*);
 gpu(out.product_read);
 
-ip-sroa(*);
 sroa(*);
 reuse-products(*);
 crc(*);
-- 
GitLab