Skip to content
Snippets Groups Projects
Commit d29ce73b authored by Russel Arbore
Browse files

Use cuda ptr functions for cuda device functions

parent 0640c58f
No related branches found
No related tags found
1 merge request: !104 "Add CUDA support to HerculesBox"
Pipeline #200969 passed
......@@ -295,10 +295,19 @@ impl<'a> RTContext<'a> {
ref dynamic_constants,
ref args,
} => {
match self.devices[callee_id.idx()] {
Device::LLVM => {
let device = self.devices[callee_id.idx()];
match device {
// The device backends ensure that device functions have the
// same C interface.
Device::LLVM | Device::CUDA => {
let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap();
let device = match device {
Device::LLVM => "cpu",
Device::CUDA => "cuda",
_ => panic!(),
};
// First, get the raw pointers to collections that the
// device function takes as input.
let callee_objs = &self.collection_objects[&callee_id];
......@@ -308,16 +317,18 @@ impl<'a> RTContext<'a> {
if callee_objs.is_mutated(obj) {
write!(
block,
" let arg_tmp{} = unsafe {{ {}.__cpu_ptr_mut() }};\n",
" let arg_tmp{} = unsafe {{ {}.__{}_ptr_mut() }};\n",
idx,
self.get_value(*arg)
self.get_value(*arg),
device
)?;
} else {
write!(
block,
" let arg_tmp{} = unsafe {{ {}.__cpu_ptr() }};\n",
" let arg_tmp{} = unsafe {{ {}.__{}_ptr() }};\n",
idx,
self.get_value(*arg)
self.get_value(*arg),
device
)?;
}
} else {
......@@ -401,7 +412,6 @@ impl<'a> RTContext<'a> {
}
write!(block, ").await;\n")?;
}
_ => todo!(),
}
}
_ => panic!(
......
......@@ -329,7 +329,7 @@ pub enum Schedule {
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Device {
LLVM,
NVVM,
CUDA,
// Entry functions are lowered to async Rust code that calls device
// functions (leaf nodes in the call graph), possibly concurrently.
AsyncRust,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment