Skip to content
Snippets Groups Projects

Add CUDA support to HerculesBox

Merged rarbore2 requested to merge cudart into main
9 files
+ 205
38
Compare changes
  • Side-by-side
  • Inline
Files
9
+ 17
7
@@ -295,10 +295,19 @@ impl<'a> RTContext<'a> {
ref dynamic_constants,
ref args,
} => {
match self.devices[callee_id.idx()] {
Device::LLVM => {
let device = self.devices[callee_id.idx()];
match device {
// The device backends ensure that device functions have the
// same C interface.
Device::LLVM | Device::CUDA => {
let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap();
let device = match device {
Device::LLVM => "cpu",
Device::CUDA => "cuda",
_ => panic!(),
};
// First, get the raw pointers to collections that the
// device function takes as input.
let callee_objs = &self.collection_objects[&callee_id];
@@ -308,16 +317,18 @@ impl<'a> RTContext<'a> {
if callee_objs.is_mutated(obj) {
write!(
block,
" let arg_tmp{} = unsafe {{ {}.__cpu_ptr_mut() }};\n",
" let arg_tmp{} = unsafe {{ {}.__{}_ptr_mut() }};\n",
idx,
self.get_value(*arg)
self.get_value(*arg),
device
)?;
} else {
write!(
block,
" let arg_tmp{} = unsafe {{ {}.__cpu_ptr() }};\n",
" let arg_tmp{} = unsafe {{ {}.__{}_ptr() }};\n",
idx,
self.get_value(*arg)
self.get_value(*arg),
device
)?;
}
} else {
@@ -401,7 +412,6 @@ impl<'a> RTContext<'a> {
}
write!(block, ").await;\n")?;
}
_ => todo!(),
}
}
_ => panic!(
Loading