diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs index d3013239f5f78ce8e63181c181c0d5a8cbf77f81..26ca9d41dd4017af250aa0c297931f6eefdfa8d8 100644 --- a/hercules_cg/src/rt.rs +++ b/hercules_cg/src/rt.rs @@ -192,40 +192,70 @@ impl<'a> RTContext<'a> { } write!(w, "p{}: {}", idx, self.get_type(func.param_types[idx]))?; } - write!(w, ") -> {} {{", self.get_type(func.return_type))?; + write!(w, ") -> ")?; + self.write_rust_return_type(w, &func.return_types)?; + write!(w, " {{")?; // Dump signatures for called device functions. - write!(w, "extern \"C\" {{")?; + // For single-return functions we directly expose the device function + // while for multi-return functions we generate a wrapper which handles + // allocation of the return struct and extracting values from it. This + // ensures that device function signatures match what they would be in + // AsyncRust for callee_id in self.callgraph.get_callees(self.func_id) { if self.devices[callee_id.idx()] == Device::AsyncRust { continue; } let callee = &self.module.functions[callee_id.idx()]; - write!(w, "fn {}(", callee.name)?; - let mut first_param = true; - if self.backing_allocations[&callee_id].contains_key(&self.devices[callee_id.idx()]) { - first_param = false; - write!(w, "backing: ::hercules_rt::__RawPtrSendSync")?; + let is_single_return = callee.return_types.len() == 1; + if is_single_return { + write!(w, "extern \"C\" {{")?; } - for idx in 0..callee.num_dynamic_constants { - if first_param { + self.write_device_signature_async(w, *callee_id)?; + if is_single_return { + write!(w, ";}}")?; + } else { + // Generate the wrapper function for multi-return device functions + write!(w, " {{ ")?; + // Define the return struct + write!(w, "#[repr(C)] struct ReturnStruct {{ {} }} ", + callee.return_types + .iter() + .enumerate() + .map(|(idx, t)| format!("f{}: {}", idx, self.get_type(*t))) + .collect::<Vec<_>>() + .join(", "), + )?; + // Declare the extern function's signature + write!(w, "extern \"C\" {{ ")?; + self.write_device_signature(w, *callee_id)?; + write!(w, "; }}")?; + // Create the return struct + write!(w, "let mut ret_struct: ::std::mem::MaybeUninit<ReturnStruct> = ::std::mem::MaybeUninit::uninit();")?; + // Call the device function + write!(w, "{}(", callee.name)?; + if self.backing_allocations[&callee_id].contains_key(&self.devices[callee_id.idx()]) { first_param = false; - } else { - write!(w, ", ")?; + write!(w, "backing, ")?; } - write!(w, "dc{}: u64", idx)?; - } - for (idx, ty) in callee.param_types.iter().enumerate() { - if first_param { - first_param = false; - } else { - write!(w, ", ")?; + for idx in 0..callee.num_dynamic_constants { + write!(w, "dc{}, ", idx)?; } - write!(w, "p{}: {}", idx, self.get_type(*ty))?; + for idx in 0..callee.param_types.len() { + write!(w, "p{}, ", idx)?; + } + write!(w, "ret_struct.as_mut_ptr());")?; + // Extract the result into a Rust product + write!(w, "let ret_struct = ret_struct.assume_init();")?; + write!(w, "({})", + (0..callee.return_types.len()) + .map(|idx| format!("ret_struct.f{}", idx)) + .collect::<Vec<_>>() + .join(", "), + )?; + write!(w, "}}")?; } - write!(w, ") -> {};", self.get_type(callee.return_type))?; } - write!(w, "}}")?; // Set up the root environment for the function. An environment is set // up for every created task in async closures, and there needs to be a @@ -301,7 +331,7 @@ impl<'a> RTContext<'a> { // successor and are otherwise simple. Node::Start | Node::Region { preds: _ } - | Node::Projection { + | Node::ControlProjection { control: _, selection: _, } => { @@ -320,7 +350,7 @@ impl<'a> RTContext<'a> { let mut succs = self.control_subgraph.succs(id); let succ1 = succs.next().unwrap(); let succ2 = succs.next().unwrap(); - let succ1_is_true = func.nodes[succ1.idx()].try_projection(1).is_some(); + let succ1_is_true = func.nodes[succ1.idx()].try_control_projection(1).is_some(); write!( epilogue, "control_token = if {} {{{}}} else {{{}}};}}", @@ -329,11 +359,20 @@ impl<'a> RTContext<'a> { if succ1_is_true { succ2 } else { succ1 }.idx(), )?; } - Node::Return { control: _, data } => { + Node::Return { control: _, ref data } => { let prologue = &mut blocks.get_mut(&id).unwrap().prologue; write!(prologue, "{} => {{", id.idx())?; let epilogue = &mut blocks.get_mut(&id).unwrap().epilogue; - write!(epilogue, "return {};}}", self.get_value(data, id, false))?; + if data.len() == 1 { + write!(epilogue, "return {};}}", self.get_value(data[0], id, false))?; + } else { + write!(epilogue, "return ({});}}", + data.iter() + .map(|v| self.get_value(*v, id, false)) + .collect::<Vec<_>>() + .join(", "), + )?; + } } // Fork nodes open a new environment for defining an async closure. Node::Fork { @@ -574,8 +613,8 @@ impl<'a> RTContext<'a> { ref args, } => { assert_eq!(control, bb); - // The device backends ensure that device functions have the - // same interface as AsyncRust functions. + // The device backends and the wrappers we generated earlier ensure that device + // functions have the same interface as AsyncRust functions. let block = &mut blocks.get_mut(&bb).unwrap(); let block = &mut block.data; let is_async = func.schedules[id.idx()].contains(&Schedule::AsyncCall); @@ -628,6 +667,13 @@ impl<'a> RTContext<'a> { } write!(block, "){};", postfix)?; } + Node::DataProjection { data, selection } => { + let block = &mut blocks.get_mut(&bb).unwrap().data; + write!(block, "{} = {}.{};", + self.get_value(id, bb, true), + self.get_value(data, bb, false), + selection)?; + } Node::LibraryCall { library_function, ref args, @@ -1008,7 +1054,7 @@ impl<'a> RTContext<'a> { */ fn codegen_type_size(&self, ty: TypeID) -> String { match self.module.types[ty.idx()] { - Type::Control => panic!(), + Type::Control | Type::MultiReturn(_) => panic!(), Type::Boolean | Type::Integer8 | Type::UnsignedInteger8 | Type::Float8 => { "1".to_string() } @@ -1116,15 +1162,7 @@ impl<'a> RTContext<'a> { if is_reduce_on_child { "reduce" } else { "node" }, idx, self.get_type(self.typing[idx]), - if self.module.types[self.typing[idx].idx()].is_bool() { - "false" - } else if self.module.types[self.typing[idx].idx()].is_integer() { - "0" - } else if self.module.types[self.typing[idx].idx()].is_float() { - "0.0" - } else { - "::hercules_rt::__RawPtrSendSync(::core::ptr::null_mut())" - } + self.get_default_value(self.typing[idx]), )?; } } @@ -1402,8 +1440,101 @@ impl<'a> RTContext<'a> { } } - fn get_type(&self, id: TypeID) -> &'static str { - convert_type(&self.module.types[id.idx()]) + fn get_type(&self, id: TypeID) -> String { + convert_type(&self.module.types[id.idx()], &self.module.types) + } + + fn get_default_value(&self, idx: TypeID) -> String { + let typ = &self.module.types[idx.idx()]; + if typ.is_bool() { + "false".to_string() + } else if typ.is_integer() { + "0".to_string() + } else if typ.is_float() { + "0.0".to_string() + } else if let Some(ts) = typ.try_multi_return() { + format!("({})", ts.iter().map(|t| self.get_default_value(*t)).collect::<Vec<_>>().join(", ")) + } else { + "::hercules_rt::__RawPtrSendSync(::core::ptr::null_mut())".to_string() + } + } + + fn write_rust_return_type<W: Write>(&self, w: &mut W, tys: &[TypeID]) -> Result<(), Error> { + if tys.len() == 1 { + write!(w, "{}", self.get_type(tys[0])) + } else { + write!(w, "({})", + tys.iter() + .map(|t| self.get_type(*t)) + .collect::<Vec<_>>() + .join(", "), + ) + } + } + + // Writes the signature of a device function as if it were an async function, in particular + // this means that if the function is multi-return it will return a product in the produced + // Rust code + // Writes from the "fn" keyword up to the end of the return type + fn write_device_signature_async<W: Write>(&self, w: &mut W, func_id: FunctionID) -> Result<(), Error> { + let func = &self.module.functions[func_id.idx()]; + write!(w, "fn {}(", func.name)?; + let mut first_param = true; + if self.backing_allocations[&func_id].contains_key(&self.devices[func_id.idx()]) { + first_param = false; + write!(w, "backing: ::hercules_rt::__RawPtrSendSync")?; + } + for idx in 0..func.num_dynamic_constants { + if first_param { + first_param = false; + } else { + write!(w, ", ")?; + } + write!(w, "dc{}: u64", idx)?; + } + for (idx, ty) in func.param_types.iter().enumerate() { + if first_param { + first_param = false; + } else { + write!(w, ", ")?; + } + write!(w, "p{}: {}", idx, self.get_type(*ty))?; + } + write!(w, ") -> ")?; + self.write_rust_return_type(w, &func.return_types) + } + + // Writes the true signature of a device function + // Compared to the _async version this converts multi-return into a return struct + fn write_device_signature<W: Write>(&self, w: &mut W, func_id: FunctionID) -> Result<(), Error> { + let func = &self.module.functions[func_id.idx()]; + write!(w, "fn {}(", func.name)?; + let mut first_param = true; + if self.backing_allocations[&func_id].contains_key(&self.devices[func_id.idx()]) { + first_param = false; + write!(w, "backing: ::hercules_rt::__RawPtrSendSync")?; + } + for idx in 0..func.num_dynamic_constants { + if first_param { + first_param = false; + } else { + write!(w, ", ")?; + } + write!(w, "dc{}: u64", idx)?; + } + for (idx, ty) in func.param_types.iter().enumerate() { + if first_param { + first_param = false; + } else { + write!(w, ", ")?; + } + write!(w, "p{}: {}", idx, self.get_type(*ty))?; + } + if func.return_types.len() == 1 { + write!(w, ") -> {}", self.get_type(func.return_types[0])) + } else { + write!(w, ", ret_ptr: *mut ReturnStruct)") + } } fn library_prim_ty(&self, id: TypeID) -> &'static str { @@ -1426,21 +1557,24 @@ impl<'a> RTContext<'a> { } } -fn convert_type(ty: &Type) -> &'static str { +fn convert_type(ty: &Type, types: &[Type]) -> String { match ty { - Type::Boolean => "bool", - Type::Integer8 => "i8", - Type::Integer16 => "i16", - Type::Integer32 => "i32", - Type::Integer64 => "i64", - Type::UnsignedInteger8 => "u8", - Type::UnsignedInteger16 => "u16", - Type::UnsignedInteger32 => "u32", - Type::UnsignedInteger64 => "u64", - Type::Float32 => "f32", - Type::Float64 => "f64", + Type::Boolean => "bool".to_string(), + Type::Integer8 => "i8".to_string(), + Type::Integer16 => "i16".to_string(), + Type::Integer32 => "i32".to_string(), + Type::Integer64 => "i64".to_string(), + Type::UnsignedInteger8 => "u8".to_string(), + Type::UnsignedInteger16 => "u16".to_string(), + Type::UnsignedInteger32 => "u32".to_string(), + Type::UnsignedInteger64 => "u64".to_string(), + Type::Float32 => "f32".to_string(), + Type::Float64 => "f64".to_string(), Type::Product(_) | Type::Summation(_) | Type::Array(_, _) => { - "::hercules_rt::__RawPtrSendSync" + "::hercules_rt::__RawPtrSendSync".to_string() + } + Type::MultiReturn(ts) => { + format!("({})", ts.iter().map(|t| convert_type(&types[t.idx()], types)).collect::<Vec<_>>().join(", ")) } _ => panic!(), } diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index 3d625a3931b5125f2259b11f7833cc281518e881..d69c3cd7f41df0e5d9e5638a9693331d6e5a2ab7 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -979,6 +979,14 @@ impl Type { } } + pub fn try_multi_return(&self) -> Option<&[TypeID]> { + if let Type::MultiReturn(ts) = self { + Some(ts) + } else { + None + } + } + pub fn num_bits(&self) -> u8 { match self { Type::Boolean => 1,