From d137cb23704c9e921c770ac8149fc825c4990404 Mon Sep 17 00:00:00 2001
From: prrathi <prrathi10@gmail.com>
Date: Sun, 19 Jan 2025 06:44:54 +0000
Subject: [PATCH] itm

---
 hercules_cg/src/gpu.rs         | 104 +++++++++++++++++++--------------
 hercules_opt/src/pass.rs       |   7 +--
 juno_samples/cava/Cargo.toml   |   3 +
 juno_samples/concat/Cargo.toml |   3 +
 4 files changed, 67 insertions(+), 50 deletions(-)

diff --git a/hercules_cg/src/gpu.rs b/hercules_cg/src/gpu.rs
index a0281795..efd7ba4b 100644
--- a/hercules_cg/src/gpu.rs
+++ b/hercules_cg/src/gpu.rs
@@ -8,6 +8,9 @@ use self::hercules_ir::*;
 
 use crate::*;
 
+use std::fs::OpenOptions;
+use std::io::Write as IoWrite;
+
 /*
  * The top level function to compile a Hercules IR function into CUDA
  * kernel for execution on the GPU. We generate CUDA C textually, with a lot
@@ -157,24 +160,25 @@ pub fn gpu_codegen<W: Write>(
 
     let return_type_id = &typing[data_node_id.idx()];
     let return_type = &types[return_type_id.idx()];
-    if return_type.is_array() || return_type.is_product() || return_type.is_summation() {
+    let return_param_idx = if !return_type.is_primitive() {
         let objects = &collection_objects.objects(data_node_id);
-        if objects.len() > 1 {
-            let origin = collection_objects.origin(objects[0]);
-            if !objects
-                .iter()
-                .all(|obj| collection_objects.origin(*obj) == origin)
-            {
-                panic!(
-                    "Returned data node {} has multiple collection objects with different origins",
-                    data_node_id.idx()
-                );
-            }
-            if !matches!(origin, CollectionObjectOrigin::Parameter(..)) {
-                panic!("Returns collection object that did not originate from a parameter");
-            }
+        let origin = collection_objects.origin(objects[0]);
+        if !objects
+            .iter()
+            .all(|obj| collection_objects.origin(*obj) == origin)
+        {
+            panic!(
+                "Returned data node {} has multiple collection objects with different origins",
+                data_node_id.idx()
+            );
         }
-    }
+        let CollectionObjectOrigin::Parameter(param_idx) = origin else {
+            panic!("Returns collection object that did not originate from a parameter");
+        };
+        Some(param_idx)
+    } else {
+        None
+    };
 
     // Temporary hardcoded values
     let kernel_params = &GPUKernelParams {
@@ -192,7 +196,7 @@ pub fn gpu_codegen<W: Write>(
                 panic!("Phi's control must be a region node");
             };
             for (i, &pred) in preds.iter().enumerate() {
-                control_data_phi_map.entry(pred).or_insert(vec![]).push((data[i], NodeID::new(idx)));
+                control_data_phi_map.entry(pred).or_default().push((data[i], NodeID::new(idx)));
             }
         }
     }
@@ -215,6 +219,7 @@ pub fn gpu_codegen<W: Write>(
         reduct_reduce_map,
         control_data_phi_map,
         return_type_id,
+        return_param_idx,
     };
     ctx.codegen_function(w)
 }
@@ -241,6 +246,7 @@ struct GPUContext<'a> {
     reduct_reduce_map: &'a HashMap<NodeID, Vec<NodeID>>,
     control_data_phi_map: &'a HashMap<NodeID, Vec<(NodeID, NodeID)>>,
     return_type_id: &'a TypeID,
+    return_param_idx: Option<usize>,
 }
 
 /*
@@ -372,7 +378,6 @@ impl GPUContext<'_> {
 #include <math_constants.h>
 #include <mma.h>
 #include <cooperative_groups.h>
-#include <cooperative_groups/memcpy_async.h>
 #include <cooperative_groups/reduce.h>
 namespace cg = cooperative_groups;
 
@@ -608,7 +613,7 @@ int main() {{
             let ret_primitive = self.types[self.return_type_id.idx()].is_primitive();
             let ret_type = self.get_type(*self.return_type_id, false);
             write!(w, "
-extern \"C\" {} {}(", if ret_primitive { ret_type.clone() } else { "void".to_string() }, self.function.name)?;
+extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
             // The first set of parameters are dynamic constants.
             let mut first_param = true;
             for idx in 0..self.function.num_dynamic_constants {
@@ -650,6 +655,8 @@ extern \"C\" {} {}(", if ret_primitive { ret_type.clone() } else { "void".to_str
                 write!(w, "\t{} host_ret;\n", ret_type)?;
                 write!(w, "\tcudaMemcpy(&host_ret, ret, sizeof({}), cudaMemcpyDeviceToHost);\n", ret_type)?;
                 write!(w, "\treturn host_ret;\n")?;
+            } else {
+                write!(w, "\treturn p{};\n", self.return_param_idx.unwrap())?;
             }
         }
 
@@ -1286,7 +1293,7 @@ extern \"C\" {} {}(", if ret_primitive { ret_type.clone() } else { "void".to_str
                 TernaryOperator::Select => {
                     write!(
                         w,
-                        "{}{} = {} ? {} : {};",
+                        "{}{} = {} ? {} : {};\n",
                         tabs,
                         define_variable,
                         self.get_value(*first, false, false),
@@ -1315,26 +1322,23 @@ extern \"C\" {} {}(", if ret_primitive { ret_type.clone() } else { "void".to_str
                     write!(w, "{}{} = cg::reduce({}, {}, cg::{}<{}>());\n", tabs, define_variable, non_reduce_arg, cg_tile, cg_op, id_type_name)?;
                 } else {
                     let ty = &self.types[id_type.idx()];
-                    let func_name = self.codegen_intrinsic(intrinsic, ty);
+                    let intrinsic = self.codegen_intrinsic(intrinsic, ty);
+                    let args = args.iter()
+                        .map(|arg| self.get_value(*arg, false, false))
+                        .collect::<Vec<_>>()
+                        .join(", ");
                     write!(
                         w,
                         "{}{} = {}({});\n",
                         tabs,
                         define_variable,
-                        func_name,
-                        self.get_value(args[0], false, false),
+                        intrinsic,
+                        args,
                     )?;
                 }
             }
-            // For read, all the cases are:
-            // 1. Reading collection from/to global to/from shared
-            // 2. Reading primitive from/to global to/from shared
-            // 3. Reading primitive from/to global to/from register
-            // 4. Reading primitive from/to shared to/from register
-            // The first three can all use cooperative groups memcpy and the last
-            // one can't. However, the C++/CUDA semantics for the last three are
-            // identical, so we differentiate the cases by data type instead of
-            // data src/dest, with only collection type using collective group.
+            // If we read a collection, distribute its elements among threads,
+            // with a cg sync after. If we read a primitive, copy the read on all threads.
             Node::Read { collect, indices } => {
                 let is_char = self.is_char(self.typing[collect.idx()]);
                 let collect_with_indices = self.codegen_collect(*collect, indices, is_char, extra_dim_collects.contains(&self.typing[collect.idx()]));
@@ -1347,20 +1351,27 @@ extern \"C\" {} {}(", if ret_primitive { ret_type.clone() } else { "void".to_str
                         write!(w, "{}{} = *({});\n", tabs, define_variable, collect_with_indices)?;
                     }
                 } else {
-                    let nested_fork = nesting_fork.unwrap();
+                    // Divide up the "elements" (the collection size in bytes
+                    // divided by the element size) among the threads.
                     let cg_tile = match state {
                         KernelState::OutBlock => "grid".to_string(),
                         KernelState::InBlock => "block".to_string(),
-                        KernelState::InThread => self.get_cg_tile(nested_fork, CGType::UsePerId),
+                        KernelState::InThread => self.get_cg_tile(nesting_fork.unwrap(), CGType::UsePerId),
                     };
                     let data_size = self.get_size(data_type_id, None, Some(extra_dim_collects));
-                    write!(w, "{}cg::memcpy_async({}, {}, {}, {});\n", tabs, cg_tile, define_variable, collect_with_indices, data_size)?;
-                    write!(w, "{}cg::wait({});\n", tabs, cg_tile)?;
+                    let data_type = self.get_type(data_type_id, false);
+                    let num_elements = format!("(({}) / sizeof({}))", data_size, data_type.strip_suffix('*').unwrap());
+                    write!(w, "{}for (int i = {}.thread_rank(); i < {}; i += {} / {}.size()) {{\n", tabs, cg_tile, num_elements, num_elements, cg_tile)?;
+                    write!(w, "{}\t*({} + i) = *({} + i);\n", tabs, define_variable, collect_with_indices)?;
+                    write!(w, "{}}}\n", tabs)?;
+                    write!(w, "{}if ({}.thread_rank() < {} % {}.size()) {{\n", tabs, cg_tile, num_elements, cg_tile)?;
+                    write!(w, "{}\t*({} + {}.size() * ({} / {}.size()) + {}.thread_rank()) = *({} + {}.size() * ({} / {}.size()) + {}.thread_rank());\n", tabs, define_variable, cg_tile, num_elements, cg_tile, cg_tile, collect_with_indices, cg_tile, num_elements, cg_tile, cg_tile)?;
+                    write!(w, "{}}}\n", tabs)?;
+                    write!(w, "{}{}.sync();\n", tabs, cg_tile)?;
                 }
             }
-            // For write, the cases are the same, but when using C++ dereference
-            // semantics, we need to gate the write with a thread rank check for
-            // thread safety.
+            // Write is the same as read, but when writing a primitive, we need
+            // to gate the write with a thread rank check.
             Node::Write {
                 collect,
                 data,
@@ -1373,9 +1384,7 @@ extern \"C\" {} {}(", if ret_primitive { ret_type.clone() } else { "void".to_str
                 let cg_tile = match state {
                     KernelState::OutBlock => "grid".to_string(),
                     KernelState::InBlock => "block".to_string(),
-                    KernelState::InThread => {
-                        self.get_cg_tile(nesting_fork.unwrap(), CGType::UsePerId)
-                    }
+                    KernelState::InThread => self.get_cg_tile(nesting_fork.unwrap(), CGType::UsePerId),
                 };
                 if self.types[data_type_id.idx()].is_primitive() {
                     write!(w, "{}if ({}.thread_rank() == 0) {{\n", tabs, cg_tile)?;
@@ -1388,8 +1397,15 @@ extern \"C\" {} {}(", if ret_primitive { ret_type.clone() } else { "void".to_str
                     write!(w, "{}}}\n", tabs)?;
                 } else {
                     let data_size = self.get_size(data_type_id, None, Some(extra_dim_collects));
-                    write!(w, "{}cg::memcpy_async({}, {}, {}, {});\n", tabs, cg_tile, collect_with_indices, data_variable, data_size)?;
-                    write!(w, "{}cg::wait({});\n", tabs, cg_tile)?;
+                    let data_type = self.get_type(data_type_id, false);
+                    let num_elements = format!("(({}) / sizeof({}))", data_size, data_type.strip_suffix('*').unwrap());
+                    write!(w, "{}for (int i = {}.thread_rank(); i < {}; i += {} / {}.size()) {{\n", tabs, cg_tile, num_elements, num_elements, cg_tile)?;
+                    write!(w, "{}\t*({} + i) = *({} + i);\n", tabs, collect_with_indices, data_variable)?;
+                    write!(w, "{}}}\n", tabs)?;
+                    write!(w, "{}if ({}.thread_rank() < {} % {}.size()) {{\n", tabs, cg_tile, num_elements, cg_tile)?;
+                    write!(w, "{}\t*({} + {}.size() * ({} / {}.size()) + {}.thread_rank()) = *({} + {}.size() * ({} / {}.size()) + {}.thread_rank());\n", tabs, collect_with_indices, cg_tile, num_elements, cg_tile, cg_tile, data_variable, cg_tile, num_elements, cg_tile, cg_tile)?;
+                    write!(w, "{}}}\n", tabs)?;
+                    write!(w, "{}{}.sync();\n", tabs, cg_tile)?;
                 }
                 let collect_variable = self.get_value(*collect, false, false);
                 write!(w, "{}{} = {};\n", tabs, define_variable, collect_variable)?;
diff --git a/hercules_opt/src/pass.rs b/hercules_opt/src/pass.rs
index dbc24016..bb70bf08 100644
--- a/hercules_opt/src/pass.rs
+++ b/hercules_opt/src/pass.rs
@@ -1079,12 +1079,6 @@ impl PassManager {
                         file.write_all(cuda_ir.as_bytes())
                             .expect("PANIC: Unable to write output CUDA IR file contents.");
 
-                        let cuda_text_path = format!("{}.cu", module_name);
-                        let mut cuda_text_file = File::create(&cuda_text_path)
-                            .expect("PANIC: Unable to open CUDA IR text file.");
-                        cuda_text_file.write_all(cuda_ir.as_bytes())
-                            .expect("PANIC: Unable to write CUDA IR text file contents.");
-
                         let mut nvcc_process = Command::new("nvcc")
                             .arg("-c")
                             .arg("-O3")
@@ -1111,6 +1105,7 @@ impl PassManager {
                         .expect("PANIC: Unable to open output Rust runtime file.");
                     file.write_all(rust_rt.as_bytes())
                         .expect("PANIC: Unable to write output Rust runtime file contents.");
+
                 }
                 Pass::Serialize(output_file) => {
                     let module_contents: Vec<u8> = postcard::to_allocvec(&self.module).unwrap();
diff --git a/juno_samples/cava/Cargo.toml b/juno_samples/cava/Cargo.toml
index ff375d80..63b6b2ac 100644
--- a/juno_samples/cava/Cargo.toml
+++ b/juno_samples/cava/Cargo.toml
@@ -8,6 +8,9 @@ edition = "2021"
 name = "juno_cava"
 path = "src/main.rs"
 
+[features]
+cuda = ["juno_build/cuda", "hercules_rt/cuda"]
+
 [build-dependencies]
 juno_build = { path = "../../juno_build" }
 
diff --git a/juno_samples/concat/Cargo.toml b/juno_samples/concat/Cargo.toml
index 24ba1acf..f2f90237 100644
--- a/juno_samples/concat/Cargo.toml
+++ b/juno_samples/concat/Cargo.toml
@@ -8,6 +8,9 @@ edition = "2021"
 name = "juno_concat"
 path = "src/main.rs"
 
+[features]
+cuda = ["juno_build/cuda", "hercules_rt/cuda"]
+
 [build-dependencies]
 juno_build = { path = "../../juno_build" }
 
-- 
GitLab