diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs index 65f7121576c8b7c5d703a140ad90ea6262005cd8..13370c4580a9365f0a9013b573e558133656d407 100644 --- a/hercules_cg/src/rt.rs +++ b/hercules_cg/src/rt.rs @@ -295,10 +295,19 @@ impl<'a> RTContext<'a> { ref dynamic_constants, ref args, } => { - match self.devices[callee_id.idx()] { - Device::LLVM => { + let device = self.devices[callee_id.idx()]; + match device { + // The device backends ensure that device functions have the + // same C interface. + Device::LLVM | Device::CUDA => { let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap(); + let device = match device { + Device::LLVM => "cpu", + Device::CUDA => "cuda", + _ => panic!(), + }; + // First, get the raw pointers to collections that the // device function takes as input. let callee_objs = &self.collection_objects[&callee_id]; @@ -308,16 +317,18 @@ impl<'a> RTContext<'a> { if callee_objs.is_mutated(obj) { write!( block, - " let arg_tmp{} = unsafe {{ {}.__cpu_ptr_mut() }};\n", + " let arg_tmp{} = unsafe {{ {}.__{}_ptr_mut() }};\n", idx, - self.get_value(*arg) + self.get_value(*arg), + device )?; } else { write!( block, - " let arg_tmp{} = unsafe {{ {}.__cpu_ptr() }};\n", + " let arg_tmp{} = unsafe {{ {}.__{}_ptr() }};\n", idx, - self.get_value(*arg) + self.get_value(*arg), + device )?; } } else { @@ -401,7 +412,6 @@ impl<'a> RTContext<'a> { } write!(block, ").await;\n")?; } - _ => todo!(), } } _ => panic!( diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index 3c5ed0bdfa9c6a7e7c9c55c9da16782ff5cc7cde..4fd0cf0b11a55beb3320e4211b2c52f7c0e4e38d 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -329,7 +329,7 @@ pub enum Schedule { #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] pub enum Device { LLVM, - NVVM, + CUDA, // Entry functions are lowered to async Rust code that calls device // functions (leaf nodes in the call graph), possibly concurrently. AsyncRust,