From 58639fc70b216345e0c12c229c2454474c42b306 Mon Sep 17 00:00:00 2001 From: prrathi <prrathi10@gmail.com> Date: Tue, 21 Jan 2025 05:50:54 +0000 Subject: [PATCH] address mr comms --- hercules_cg/src/gpu.rs | 43 ++++++++++++++++++------------------------ 1 file changed, 18 insertions(+), 25 deletions(-) diff --git a/hercules_cg/src/gpu.rs b/hercules_cg/src/gpu.rs index 11a91c9a..fb9526bf 100644 --- a/hercules_cg/src/gpu.rs +++ b/hercules_cg/src/gpu.rs @@ -372,7 +372,7 @@ namespace cg = cooperative_groups; } write!( w, - "char* __restrict__ ret", + "void* __restrict__ ret", )?; // Type is char since it's simplest to use single bytes for indexing @@ -512,26 +512,20 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; write!(pass_args, "p{}", idx)?; } write!(w, ") {{\n")?; - // Add return parameter, with allocation if primitive + // Allocate return parameter and lift to kernel argument let ret_type_pnt = self.get_type(self.function.return_type, true); write!(w, "\t{} ret;\n", ret_type_pnt)?; if !first_param { write!(pass_args, ", ")?; } write!(pass_args, "ret")?; - if self.types[self.function.return_type.idx()].is_primitive() { - write!(w, "\tcudaMalloc((void**)&ret, sizeof({}));\n", ret_type)?; - } + write!(w, "\tcudaMalloc((void**)&ret, sizeof({}));\n", ret_type)?; write!(w, "\t{}_gpu<<<{}, {}, {}>>>({});\n", self.function.name, num_blocks, num_threads, dynamic_shared_offset, pass_args)?; write!(w, "\tcudaDeviceSynchronize();\n")?; - write!(w, "\tfflush(stdout);\n")?; - if self.types[self.function.return_type.idx()].is_primitive() { - write!(w, "\t{} host_ret;\n", ret_type)?; - write!(w, "\tcudaMemcpy(&host_ret, ret, sizeof({}), cudaMemcpyDeviceToHost);\n", ret_type)?; - write!(w, "\treturn host_ret;\n")?; - } else { - write!(w, "\treturn ret;\n")?; - } + // Copy return from device to host, whether it's primitive value or collection pointer + write!(w, "\t{} host_ret;\n", ret_type)?; + write!(w, "\tcudaMemcpy(&host_ret, ret, sizeof({}), cudaMemcpyDeviceToHost);\n", ret_type)?; + write!(w, "\treturn host_ret;\n")?; write!(w, "}}\n")?; Ok(()) } @@ -1426,12 +1420,11 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; Node::Return { control: _, data } => { // Since we lift originally primitive returns into a parameter, // we write to that parameter upon return. - if self.types[self.typing[data.idx()].idx()].is_primitive() { - let return_val = self.get_value(*data, false, false); - write!(w_term, "\tif (grid.thread_rank() == 0) {{\n")?; - write!(w_term, "\t\t*ret = {};\n", return_val)?; - write!(w_term, "\t}}\n")?; - } + let return_val = self.get_value(*data, false, false); + let return_type_ptr = self.get_type(self.function.return_type, true); + write!(w_term, "\tif (grid.thread_rank() == 0) {{\n")?; + write!(w_term, "\t\t*(reinterpret_cast<{}>(ret)) = {};\n", return_type_ptr, return_val)?; + write!(w_term, "\t}}\n")?; write!(w_term, "\treturn;\n")?; 1 } @@ -1542,10 +1535,10 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; }; for i in 0..constant_fields.len() { // For each field update offset and issue recursive call - let field_type = self.get_type(type_fields[i], true); let offset = self.get_size(type_fields[i], Some(i), extra_dim_collects); let field_constant = &self.constants[constant_fields[i].idx()]; if field_constant.is_scalar() { + let field_type = self.get_type(type_fields[i], true); self.codegen_constant( format!("*reinterpret_cast<{}>({}+{})", field_type, name, offset), constant_fields[i], @@ -1573,10 +1566,9 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; let Type::Summation(variants) = &self.types[type_id.idx()] else { panic!("Summation constant should have summation type") }; - let variant_type = - self.get_type(self.typing[variants[*variant as usize].idx()], true); let variant_constant = &self.constants[field.idx()]; if variant_constant.is_scalar() { + let variant_type = self.get_type(self.typing[variants[*variant as usize].idx()], true); self.codegen_constant( format!("*reinterpret_cast<{}>({})", variant_type, name), *field, @@ -1851,10 +1843,11 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; } fn get_type(&self, id: TypeID, make_pointer: bool) -> String { - if self.types[id.idx()].is_primitive() { - convert_type(&self.types[id.idx()], make_pointer) + let ty = &self.types[id.idx()]; + if ty.is_primitive() { + convert_type(ty, make_pointer) } else { - "char*".to_string() + format!("char*{}", if make_pointer { "*" } else { "" }) } } -- GitLab