diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs index 47b739dc58efcbb3ba7fa54b48c9f5597ecb7e52..305ecf9b153dbab61536d4cb94e26a4da5116233 100644 --- a/hercules_cg/src/rt.rs +++ b/hercules_cg/src/rt.rs @@ -24,6 +24,7 @@ pub fn rt_codegen<W: Write>( control_subgraph: &Subgraph, bbs: &Vec<NodeID>, callgraph: &CallGraph, + devices: &Vec<Device>, memory_objects: &Vec<MemoryObjects>, memory_objects_mutability: &MemoryObjectsMutability, w: &mut W, @@ -36,6 +37,7 @@ pub fn rt_codegen<W: Write>( control_subgraph, bbs, callgraph, + devices, memory_objects, _memory_objects_mutability: memory_objects_mutability, }; @@ -50,6 +52,7 @@ struct RTContext<'a> { control_subgraph: &'a Subgraph, bbs: &'a Vec<NodeID>, callgraph: &'a CallGraph, + devices: &'a Vec<Device>, memory_objects: &'a Vec<MemoryObjects>, // TODO: use once memory objects are passed in a custom type where this // actually matters. @@ -157,6 +160,9 @@ impl<'a> RTContext<'a> { // Dump signatures for called CPU functions. write!(w, " extern \"C\" {{\n")?; for callee in self.callgraph.get_callees(self.func_id) { + if self.devices[callee.idx()] != Device::LLVM { + continue; + } let callee = &self.module.functions[callee.idx()]; write!(w, " fn {}(", callee.name)?; let mut first_param = true; @@ -376,83 +382,112 @@ impl<'a> RTContext<'a> { ref dynamic_constants, ref args, } => { - let block = &mut blocks.get_mut(&self.bbs[id.idx()]).unwrap(); - write!( - block, - " {} = unsafe {{ {}(", - self.get_value(id), - self.module.functions[callee_id.idx()].name - )?; - for dc in dynamic_constants { - self.codegen_dynamic_constant(*dc, block)?; - write!(block, ", ")?; - } - for arg in args { - write!(block, "{}, ", self.get_value(*arg))?; - } - write!(block, ") }};\n")?; - - // When a CPU function is called that returns a memory object, - // that memory object must have come from one of its parameters. - // Dynamically figure out which one it came from, so that we can - // move it to the slot of the output memory object. - let call_memory_objects = - self.memory_objects[self.func_id.idx()].memory_objects(id); - if !call_memory_objects.is_empty() { - assert_eq!(call_memory_objects.len(), 1); - let call_memory_object = call_memory_objects[0]; - - let callee_returned_memory_objects = - self.memory_objects[callee_id.idx()].returned_memory_objects(); - let possible_params: Vec<_> = (0..self.module.functions[callee_id.idx()] - .param_types - .len()) - .filter(|idx| { - let memory_object_of_param = self.memory_objects[callee_id.idx()] - .memory_object_of_parameter(*idx); - // Look at parameters that could be the source of - // the memory object returned by the function. - memory_object_of_param - .map(|memory_object_of_param| { - callee_returned_memory_objects.contains(&memory_object_of_param) - }) - .unwrap_or(false) - }) - .collect(); - let arg_memory_objects = args - .into_iter() - .enumerate() - .filter(|(idx, _)| possible_params.contains(idx)) - .map(|(_, arg)| { - self.memory_objects[self.func_id.idx()] - .memory_objects(*arg) + match self.devices[callee_id.idx()] { + Device::LLVM => { + let block = &mut blocks.get_mut(&self.bbs[id.idx()]).unwrap(); + write!( + block, + " {} = unsafe {{ {}(", + self.get_value(id), + self.module.functions[callee_id.idx()].name + )?; + for dc in dynamic_constants { + self.codegen_dynamic_constant(*dc, block)?; + write!(block, ", ")?; + } + for arg in args { + write!(block, "{}, ", self.get_value(*arg))?; + } + write!(block, ") }};\n")?; + + // When a CPU function is called that returns a memory + // object, that memory object must have come from one of + // its parameters. Dynamically figure out which one it + // came from, so that we can move it to the slot of the + // output memory object. + let call_memory_objects = + self.memory_objects[self.func_id.idx()].memory_objects(id); + if !call_memory_objects.is_empty() { + assert_eq!(call_memory_objects.len(), 1); + let call_memory_object = call_memory_objects[0]; + + let callee_returned_memory_objects = + self.memory_objects[callee_id.idx()].returned_memory_objects(); + let possible_params: Vec<_> = + (0..self.module.functions[callee_id.idx()].param_types.len()) + .filter(|idx| { + let memory_object_of_param = self.memory_objects + [callee_id.idx()] + .memory_object_of_parameter(*idx); + // Look at parameters that could be the + // source of the memory object returned + // by the function. + memory_object_of_param + .map(|memory_object_of_param| { + callee_returned_memory_objects + .contains(&memory_object_of_param) + }) + .unwrap_or(false) + }) + .collect(); + let arg_memory_objects = args .into_iter() - }) - .flatten(); - - // Dynamically check which of the memory objects - // corresponding to arguments to the call was returned by - // the call. Move that memory object into the memory object - // of the call. - let mut first_obj = true; - for arg_memory_object in arg_memory_objects { - write!(block, " ")?; - if first_obj { - first_obj = false; - } else { - write!(block, "else ")?; + .enumerate() + .filter(|(idx, _)| possible_params.contains(idx)) + .map(|(_, arg)| { + self.memory_objects[self.func_id.idx()] + .memory_objects(*arg) + .into_iter() + }) + .flatten(); + + // Dynamically check which of the memory objects + // corresponding to arguments to the call was + // returned by the call. Move that memory object + // into the memory object of the call. + let mut first_obj = true; + for arg_memory_object in arg_memory_objects { + write!(block, " ")?; + if first_obj { + first_obj = false; + } else { + write!(block, "else ")?; + } + write!(block, "if let Some(mem_obj) = mem_obj{}.as_mut() && ::std::boxed::Box::as_mut_ptr(mem_obj) as *mut u8 == {} {{\n", arg_memory_object, self.get_value(id))?; + write!( + block, + " mem_obj{} = mem_obj{}.take();\n", + call_memory_object, arg_memory_object + )?; + write!(block, " }}\n")?; + } + write!(block, " else {{\n")?; + write!(block, " panic!(\"HERCULES PANIC: Pointer returned from called function doesn't match any known memory objects.\");\n")?; + write!(block, " }}\n")?; } - write!(block, "if let Some(mem_obj) = mem_obj{}.as_mut() && ::std::boxed::Box::as_mut_ptr(mem_obj) as *mut u8 == {} {{\n", arg_memory_object, self.get_value(id))?; + } + Device::AsyncRust => { + let block = &mut blocks.get_mut(&self.bbs[id.idx()]).unwrap(); write!( block, - " mem_obj{} = mem_obj{}.take();\n", - call_memory_object, arg_memory_object + " {} = {}(", + self.get_value(id), + self.module.functions[callee_id.idx()].name )?; - write!(block, " }}\n")?; + for dc in dynamic_constants { + self.codegen_dynamic_constant(*dc, block)?; + write!(block, ", ")?; + } + for arg in args { + if self.module.types[self.typing[arg.idx()].idx()].is_primitive() { + write!(block, "{}, ", self.get_value(*arg))?; + } else { + write!(block, "{}.take(), ", self.get_value(*arg))?; + } + } + write!(block, ").await;\n")?; } - write!(block, " else {{\n")?; - write!(block, " panic!(\"HERCULES PANIC: Pointer returned from called function doesn't match any known memory objects.\");\n")?; - write!(block, " }}\n")?; + _ => todo!(), } } _ => panic!( diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index c0faec591bdc6d7f3ba0079395fa855cbc9012c8..b46e2dda12b41fe5b2b3905d53a8bd11387eda4d 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -329,7 +329,7 @@ pub enum Schedule { * The authoritative enumeration of supported backends. Multiple backends may * correspond to the same kind of hardware. */ -#[derive(Debug, Clone, Copy, Serialize, Deserialize)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] pub enum Device { LLVM, NVVM, diff --git a/hercules_opt/src/pass.rs b/hercules_opt/src/pass.rs index b85835cf295620615aee586982d1d6179d3e8d29..24ed0e4eea64632e99e13cb9039512eb84f0fd1d 100644 --- a/hercules_opt/src/pass.rs +++ b/hercules_opt/src/pass.rs @@ -890,6 +890,7 @@ impl PassManager { &control_subgraphs[idx], &bbs[idx], &callgraph, + &devices, &memory_objects, &memory_objects_mutable, &mut rust_rt, diff --git a/hercules_samples/call/src/main.rs b/hercules_samples/call/src/main.rs index 3bbb634c7405dd9aff81dd9c3a0068b54df45a26..b5c999fdac3738f6d3fced2164fbf9320d5a1034 100644 --- a/hercules_samples/call/src/main.rs +++ b/hercules_samples/call/src/main.rs @@ -14,6 +14,6 @@ fn main() { } #[test] -fn dot_test() { +fn call_test() { main(); }