From d29ce73b52edd6c6a1894d9f908565d457b817da Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Mon, 13 Jan 2025 20:01:49 -0800 Subject: [PATCH] Use cuda ptr functions for cuda device functions --- hercules_cg/src/rt.rs | 24 +++++++++++++++++------- hercules_ir/src/ir.rs | 2 +- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs index 65f71215..13370c45 100644 --- a/hercules_cg/src/rt.rs +++ b/hercules_cg/src/rt.rs @@ -295,10 +295,19 @@ impl<'a> RTContext<'a> { ref dynamic_constants, ref args, } => { - match self.devices[callee_id.idx()] { - Device::LLVM => { + let device = self.devices[callee_id.idx()]; + match device { + // The device backends ensure that device functions have the + // same C interface. + Device::LLVM | Device::CUDA => { let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap(); + let device = match device { + Device::LLVM => "cpu", + Device::CUDA => "cuda", + _ => panic!(), + }; + // First, get the raw pointers to collections that the // device function takes as input. let callee_objs = &self.collection_objects[&callee_id]; @@ -308,16 +317,18 @@ impl<'a> RTContext<'a> { if callee_objs.is_mutated(obj) { write!( block, - " let arg_tmp{} = unsafe {{ {}.__cpu_ptr_mut() }};\n", + " let arg_tmp{} = unsafe {{ {}.__{}_ptr_mut() }};\n", idx, - self.get_value(*arg) + self.get_value(*arg), + device )?; } else { write!( block, - " let arg_tmp{} = unsafe {{ {}.__cpu_ptr() }};\n", + " let arg_tmp{} = unsafe {{ {}.__{}_ptr() }};\n", idx, - self.get_value(*arg) + self.get_value(*arg), + device )?; } } else { @@ -401,7 +412,6 @@ impl<'a> RTContext<'a> { } write!(block, ").await;\n")?; } - _ => todo!(), } } _ => panic!( diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index 3c5ed0bd..4fd0cf0b 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -329,7 +329,7 @@ pub enum Schedule { #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] pub enum Device { LLVM, - NVVM, + CUDA, // Entry functions are lowered to async Rust code that calls device // functions (leaf nodes in the call graph), possibly concurrently. AsyncRust, -- GitLab