diff --git a/hercules_cg/src/cpu.rs b/hercules_cg/src/cpu.rs index 45c0f46701c24560b1a70a2602738c08dc21dffa..509b98c54644992a65714d58d674d4866476f4ae 100644 --- a/hercules_cg/src/cpu.rs +++ b/hercules_cg/src/cpu.rs @@ -88,6 +88,11 @@ impl<'a> CPUContext<'a> { .collect::<Vec<_>>() .join(", "), )?; + write!( + w, + "define dso_local void @{}(", + self.function.name, + )?; } let mut first_param = true; // The first parameter is a pointer to CPU backing memory, if it's diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs index 26ca9d41dd4017af250aa0c297931f6eefdfa8d8..8a41b08d42355fd6ea4d33b769f227e801c57069 100644 --- a/hercules_cg/src/rt.rs +++ b/hercules_cg/src/rt.rs @@ -211,7 +211,7 @@ impl<'a> RTContext<'a> { if is_single_return { write!(w, "extern \"C\" {{")?; } - self.write_device_signature_async(w, *callee_id)?; + self.write_device_signature_async(w, *callee_id, !is_single_return)?; if is_single_return { write!(w, ";}}")?; } else { @@ -1200,9 +1200,9 @@ impl<'a> RTContext<'a> { // references. let func = self.get_func(); let param_devices = &self.node_colors.1; - let return_device = self.node_colors.2; + let return_devices = &self.node_colors.2; let mut param_muts = vec![false; func.param_types.len()]; - let mut return_mut = true; + let mut return_muts = vec![true; func.return_types.len()]; let objects = &self.collection_objects[&self.func_id]; for idx in 0..func.param_types.len() { if let Some(object) = objects.param_to_object(idx) @@ -1211,11 +1211,14 @@ impl<'a> RTContext<'a> { param_muts[idx] = true; } } - for object in objects.returned_objects() { - if let Some(idx) = objects.origin(*object).try_parameter() - && !param_muts[idx] - { - return_mut = false; + let num_returns = func.return_types.len(); + for idx in 0..num_returns { + for object in objects.returned_objects(idx) { + if let Some(param_idx) = objects.origin(*object).try_parameter() + && !param_muts[param_idx] + { + return_muts[idx] = false; + } } } @@ -1245,27 +1248,38 @@ impl<'a> RTContext<'a> { } write!(w, "}}}}")?; - // Every reference that may be returned has the same lifetime. Every - // other reference gets its own unique lifetime. - let returned_origins: HashSet<_> = self.collection_objects[&self.func_id] - .returned_objects() - .into_iter() - .map(|obj| self.collection_objects[&self.func_id].origin(*obj)) - .collect(); - - write!(w, "async fn run<'runner, 'returned")?; - for idx in 0..func.param_types.len() { - write!(w, ", 'p{}", idx)?; + // Each returned reference, input reference, and the runner will have + // its own lifetime. We use lifetime bounds to ensure that the runner + // and parameters are borrowed for the lifetimes needed by the outputs + let returned_origins: Vec<HashSet<_>> = (0..num_returns) + .map(|idx| objects.returned_objects(idx) + .iter() + .map(|obj| objects.origin(*obj)) + .collect() + ).collect(); + + write!(w, "async fn run<'runner:")?; + for (ret_idx, origins) in returned_origins.iter().enumerate() { + if origins.iter().any(|origin| !origin.is_parameter()) { + write!(w, " 'r{} +", ret_idx)?; + } } - write!( - w, - ">(&'{} mut self", - if returned_origins.iter().any(|origin| !origin.is_parameter()) { - "returned" - } else { - "runner" + for idx in 0..num_returns { + write!(w, ", 'r{}", idx)?; + } + for idx in 0..func.param_types.len() { + write!(w, ", 'p{}:", idx)?; + for (ret_idx, origins) in returned_origins.iter().enumerate() { + if origins.iter().any(|origin| origin + .try_parameter() + .map(|oidx| idx == oidx) + .unwrap_or(false)) + { + write!(w, " 'r{} +", ret_idx)?; + } } - )?; + } + write!(w, ">(&'runner mut self")?; for idx in 0..func.num_dynamic_constants { write!(w, ", dc_p{}: u64", idx)?; } @@ -1281,37 +1295,35 @@ impl<'a> RTContext<'a> { let mutability = if param_muts[idx] { "Mut" } else { "" }; write!( w, - ", p{}: ::hercules_rt::Hercules{}Ref{}<'{}>", + ", p{}: ::hercules_rt::Hercules{}Ref{}<'p{}>", idx, device, mutability, - if returned_origins.iter().any(|origin| origin - .try_parameter() - .map(|oidx| idx == oidx) - .unwrap_or(false)) - { - "returned".to_string() - } else { - format!("p{}", idx) - } + idx, )?; } } - if self.module.types[func.return_type.idx()].is_primitive() { - write!(w, ") -> {} {{", self.get_type(func.return_type))?; - } else { - let device = match return_device { - Some(Device::LLVM) | None => "CPU", - Some(Device::CUDA) => "CUDA", - _ => panic!(), - }; - let mutability = if return_mut { "Mut" } else { "" }; - write!( - w, - ") -> ::hercules_rt::Hercules{}Ref{}<'returned> {{", - device, mutability - )?; - } + write!(w, ") -> {}{}{} {{", + if num_returns != 1 { "(" } else { "" }, + func.return_types.iter().enumerate() + .map(|(ret_idx, typ)| + if self.module.types[typ.idx()].is_primitive() { + self.get_type(*typ) + } else { + let device = match return_devices[ret_idx] { + Some(Device::LLVM) | None => "CPU", + Some(Device::CUDA) => "CUDA", + _ => panic!(), + }; + let mutability = if return_muts[ret_idx] { "Mut" } else { "" }; + format!("::hercules_rt::Hercules{}Ref{}<'r{}>", + device, mutability, ret_idx) + } + ) + .collect::<Vec<_>>() + .join(", "), + if num_returns != 1 { ")" } else { "" }, + )?; // Start with possibly re-allocating the backing memory if it's not // large enough. @@ -1367,22 +1379,48 @@ impl<'a> RTContext<'a> { write!(w, "p{}, ", idx)?; } write!(w, ").await;")?; - if self.module.types[func.return_type.idx()].is_primitive() { - write!(w, " ret")?; + // Return the result, appropriately wrapping pointers + if num_returns == 1 { + if self.module.types[func.return_types[0].idx()].is_primitive() { + write!(w, "ret")?; + } else { + let device = match return_devices[0] { + Some(Device::LLVM) | None => "CPU", + Some(Device::CUDA) => "CUDA", + _ => panic!(), + }; + let mutability = if return_muts[0] { "Mut" } else { "" }; + write!( + w, + "::hercules_rt::Hercules{}Ref{}::__from_parts(ret.0, {} as usize)", + device, + mutability, + self.codegen_type_size(func.return_types[0]) + )?; + } } else { - let device = match return_device { - Some(Device::LLVM) | None => "CPU", - Some(Device::CUDA) => "CUDA", - _ => panic!(), - }; - let mutability = if return_mut { "Mut" } else { "" }; - write!( - w, - "::hercules_rt::Hercules{}Ref{}::__from_parts(ret.0, {} as usize)", - device, - mutability, - self.codegen_type_size(func.return_type) - )?; + write!(w, "(")?; + for (idx, typ) in func.return_types.iter().enumerate() { + if self.module.types[typ.idx()].is_primitive() { + write!(w, "ret.{},", idx)?; + } else { + let device = match return_devices[idx] { + Some(Device::LLVM) | None => "CPU", + Some(Device::CUDA) => "CUDA", + _ => panic!(), + }; + let mutability = if return_muts[idx] { "Mut" } else { "" }; + write!( + w, + "::hercules_rt::Hercules{}Ref{}::__from_parts(ret.{}.0, {} as usize),", + device, + mutability, + idx, + self.codegen_type_size(func.return_types[idx]), + )?; + } + } + write!(w, ")")?; } write!(w, "}}}}")?; @@ -1476,9 +1514,9 @@ impl<'a> RTContext<'a> { // this means that if the function is multi-return it will return a product in the produced // Rust code // Writes from the "fn" keyword up to the end of the return type - fn write_device_signature_async<W: Write>(&self, w: &mut W, func_id: FunctionID) -> Result<(), Error> { + fn write_device_signature_async<W: Write>(&self, w: &mut W, func_id: FunctionID, is_unsafe: bool) -> Result<(), Error> { let func = &self.module.functions[func_id.idx()]; - write!(w, "fn {}(", func.name)?; + write!(w, "{}fn {}(", if is_unsafe { "unsafe " } else { "" }, func.name)?; let mut first_param = true; if self.backing_allocations[&func_id].contains_key(&self.devices[func_id.idx()]) { first_param = false; diff --git a/juno_samples/multi_return/src/cpu.sch b/juno_samples/multi_return/src/cpu.sch index 972405f50bb0f2a6a0574807ff5aeb64decef4fd..d89b3b8bb819f3254491b94a4b26d205bfbc9afa 100644 --- a/juno_samples/multi_return/src/cpu.sch +++ b/juno_samples/multi_return/src/cpu.sch @@ -33,4 +33,3 @@ ccp(*); dce(*); gcm(*); -xdot[true](*); diff --git a/juno_samples/multi_return/src/main.rs b/juno_samples/multi_return/src/main.rs index 63479dbab1b3f4aaede401566ea00f3917cfa3b9..6966e0df503ee5e2c9ca3e88e4b64cdfe5a26d43 100644 --- a/juno_samples/multi_return/src/main.rs +++ b/juno_samples/multi_return/src/main.rs @@ -1,22 +1,21 @@ #![feature(concat_idents)] -juno_build::juno!("median"); +juno_build::juno!("multi_return"); use hercules_rt::{runner, HerculesImmBox, HerculesImmBoxTo}; fn main() { - let m = vec![ - 86, 72, 14, 5, 55, 25, 98, 89, 3, 66, 44, 81, 27, 3, 40, 18, 4, 57, 93, 34, 70, 50, 50, 18, - 34, - ]; - let m = HerculesImmBox::from(m.as_slice()); + const N: usize = 32; + let a: Box<[f32]> = (1..=N).map(|i| i as f32).collect(); + let a = HerculesImmBox::from(a.as_ref()); + let mut r = runner!(rolling_sum_prod); + let (sums, prods) = async_std::task::block_on(async { r.run(N as u64, a.to()).await }); - let mut r = runner!(median_window); - let res = async_std::task::block_on(async { r.run(m.to()).await }); - assert_eq!(res, 57); + println!("Partial Sums: {:?}", sums.as_slice::<f32>()); + println!("Partial Prods: {:?}", prods.as_slice::<f32>()); } #[test] -fn test_median_window() { +fn test_multi_return() { main() }