diff --git a/hercules_cg/src/gpu.rs b/hercules_cg/src/gpu.rs index 21d284b3336b03541cac170f9a3a8edde6dc366e..11a91c9aeae5cddbdabf366e02a901c6b61d8f29 100644 --- a/hercules_cg/src/gpu.rs +++ b/hercules_cg/src/gpu.rs @@ -134,51 +134,6 @@ pub fn gpu_codegen<W: Write>( } } - // Obtain the Return node and if it's a collection, use the collection objects - // analysis to determine the origin. Also save the return node id for later - // conversion of primitive Return into Parameter. - let (_, data_node_id) = { - let pos = function - .nodes - .iter() - .position(|node| { - matches!( - node, - Node::Return { - control: _, - data: _ - } - ) - }) - .expect("Function must have a return node"); - let Node::Return { control: _, data } = &function.nodes[pos] else { - panic!("Return node must be a return node"); - }; - (NodeID::new(pos), *data) - }; - - let return_type_id = typing[data_node_id.idx()]; - let return_type = &types[return_type_id.idx()]; - let return_param_idx = if !return_type.is_primitive() { - let objects = &collection_objects.objects(data_node_id); - let origin = collection_objects.origin(objects[0]); - if !objects - .iter() - .all(|obj| collection_objects.origin(*obj) == origin) - { - panic!( - "Returned data node {} has multiple collection objects with different origins", - data_node_id.idx() - ); - } - let CollectionObjectOrigin::Parameter(param_idx) = origin else { - panic!("Returns collection object that did not originate from a parameter"); - }; - Some(param_idx) - } else { - None - }; - // Temporary hardcoded values let kernel_params = &GPUKernelParams { max_num_blocks: 1024, @@ -217,8 +172,6 @@ pub fn gpu_codegen<W: Write>( fork_reduce_map, reduct_reduce_map, control_data_phi_map, - return_type_id, - return_param_idx, }; ctx.codegen_function(w) } @@ -244,8 +197,6 @@ struct GPUContext<'a> { fork_reduce_map: HashMap<NodeID, Vec<NodeID>>, reduct_reduce_map: HashMap<NodeID, Vec<NodeID>>, control_data_phi_map: HashMap<NodeID, Vec<(NodeID, NodeID)>>, - return_type_id: TypeID, - return_param_idx: Option<usize>, } /* @@ -416,17 +367,13 @@ namespace cg = cooperative_groups; }; write!(w, "{} p{}", param_type, idx)?; } - // Pull primitive return to a pointer parameter - if self.types[self.return_type_id.idx()].is_primitive() { - if !first_param { - write!(w, ", ")?; - } - write!( - w, - "{} __restrict__ ret", - self.get_type(self.return_type_id, true) - )?; + if !first_param { + write!(w, ", ")?; } + write!( + w, + "char* __restrict__ ret", + )?; // Type is char since it's simplest to use single bytes for indexing // and it's required for heterogeneous Product and Summation types. @@ -536,9 +483,8 @@ namespace cg = cooperative_groups; fn codegen_launch_code(&self, num_blocks: usize, num_threads: usize, dynamic_shared_offset: &str, w: &mut String) -> Result<(), Error> { // The following steps are for host-side C function arguments, but we also // need to pass arguments to kernel, so we keep track of the arguments here. + let ret_type = self.get_type(self.function.return_type, false); let mut pass_args = String::new(); - let ret_primitive = self.types[self.return_type_id.idx()].is_primitive(); - let ret_type = self.get_type(self.return_type_id, false); write!(w, " extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; // The first set of parameters are dynamic constants. @@ -566,25 +512,25 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; write!(pass_args, "p{}", idx)?; } write!(w, ") {{\n")?; - // Pull primitive return as pointer parameter for kernel - if ret_primitive { - let ret_type_pnt = self.get_type(self.return_type_id, true); - write!(w, "\t{} ret;\n", ret_type_pnt)?; + // Add return parameter, with allocation if primitive + let ret_type_pnt = self.get_type(self.function.return_type, true); + write!(w, "\t{} ret;\n", ret_type_pnt)?; + if !first_param { + write!(pass_args, ", ")?; + } + write!(pass_args, "ret")?; + if self.types[self.function.return_type.idx()].is_primitive() { write!(w, "\tcudaMalloc((void**)&ret, sizeof({}));\n", ret_type)?; - if !first_param { - write!(pass_args, ", ")?; - } - write!(pass_args, "ret")?; } write!(w, "\t{}_gpu<<<{}, {}, {}>>>({});\n", self.function.name, num_blocks, num_threads, dynamic_shared_offset, pass_args)?; write!(w, "\tcudaDeviceSynchronize();\n")?; write!(w, "\tfflush(stdout);\n")?; - if ret_primitive { + if self.types[self.function.return_type.idx()].is_primitive() { write!(w, "\t{} host_ret;\n", ret_type)?; write!(w, "\tcudaMemcpy(&host_ret, ret, sizeof({}), cudaMemcpyDeviceToHost);\n", ret_type)?; write!(w, "\treturn host_ret;\n")?; } else { - write!(w, "\treturn p{};\n", self.return_param_idx.unwrap())?; + write!(w, "\treturn ret;\n")?; } write!(w, "}}\n")?; Ok(())