Skip to content
Snippets Groups Projects
Commit 0187323d authored by rarbore2's avatar rarbore2
Browse files

Add CUDA support to HerculesBox

parent 9db302b1
No related branches found
No related tags found
1 merge request!104Add CUDA support to HerculesBox
build-job:
stage: build
script:
- cargo build
test-job: test-job:
stage: test stage: test
script: script:
- cargo test - cargo test
- cargo test --features=cuda
...@@ -295,10 +295,19 @@ impl<'a> RTContext<'a> { ...@@ -295,10 +295,19 @@ impl<'a> RTContext<'a> {
ref dynamic_constants, ref dynamic_constants,
ref args, ref args,
} => { } => {
match self.devices[callee_id.idx()] { let device = self.devices[callee_id.idx()];
Device::LLVM => { match device {
// The device backends ensure that device functions have the
// same C interface.
Device::LLVM | Device::CUDA => {
let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap(); let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap();
let device = match device {
Device::LLVM => "cpu",
Device::CUDA => "cuda",
_ => panic!(),
};
// First, get the raw pointers to collections that the // First, get the raw pointers to collections that the
// device function takes as input. // device function takes as input.
let callee_objs = &self.collection_objects[&callee_id]; let callee_objs = &self.collection_objects[&callee_id];
...@@ -308,16 +317,18 @@ impl<'a> RTContext<'a> { ...@@ -308,16 +317,18 @@ impl<'a> RTContext<'a> {
if callee_objs.is_mutated(obj) { if callee_objs.is_mutated(obj) {
write!( write!(
block, block,
" let arg_tmp{} = unsafe {{ {}.__cpu_ptr_mut() }};\n", " let arg_tmp{} = unsafe {{ {}.__{}_ptr_mut() }};\n",
idx, idx,
self.get_value(*arg) self.get_value(*arg),
device
)?; )?;
} else { } else {
write!( write!(
block, block,
" let arg_tmp{} = unsafe {{ {}.__cpu_ptr() }};\n", " let arg_tmp{} = unsafe {{ {}.__{}_ptr() }};\n",
idx, idx,
self.get_value(*arg) self.get_value(*arg),
device
)?; )?;
} }
} else { } else {
...@@ -401,7 +412,6 @@ impl<'a> RTContext<'a> { ...@@ -401,7 +412,6 @@ impl<'a> RTContext<'a> {
} }
write!(block, ").await;\n")?; write!(block, ").await;\n")?;
} }
_ => todo!(),
} }
} }
_ => panic!( _ => panic!(
......
...@@ -329,7 +329,7 @@ pub enum Schedule { ...@@ -329,7 +329,7 @@ pub enum Schedule {
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Device { pub enum Device {
LLVM, LLVM,
NVVM, CUDA,
// Entry functions are lowered to async Rust code that calls device // Entry functions are lowered to async Rust code that calls device
// functions (leaf nodes in the call graph), possibly concurrently. // functions (leaf nodes in the call graph), possibly concurrently.
AsyncRust, AsyncRust,
......
...@@ -4,5 +4,8 @@ version = "0.1.0" ...@@ -4,5 +4,8 @@ version = "0.1.0"
authors = ["Russel Arbore <rarbore2@illinois.edu>"] authors = ["Russel Arbore <rarbore2@illinois.edu>"]
edition = "2021" edition = "2021"
[features]
cuda = []
[dependencies] [dependencies]
use std::env::var;
use std::path::Path;
use std::process::Command;
/// Build script: when the `cuda` feature is enabled, compile the CUDA
/// runtime shims (`src/rtdefs.cu`) with nvcc, archive them into a static
/// library, and emit the link directives for rustc.
fn main() {
    // Cargo compiles build scripts with the package's feature cfgs, so
    // `cfg!` works here; the documented alternative is checking the
    // CARGO_FEATURE_CUDA environment variable.
    if cfg!(feature = "cuda") {
        let out_dir = var("OUT_DIR").unwrap();

        // Compile the CUDA source to an object file in OUT_DIR.
        // `status()` only errors when the process can't be spawned, so the
        // exit code must be checked separately or a failed nvcc invocation
        // would be silently ignored.
        let status = Command::new("nvcc")
            .args(&["src/rtdefs.cu", "-c", "-o"])
            .arg(&format!("{}/rtdefs.o", out_dir))
            .status()
            .expect("PANIC: NVCC failed when building runtime. Is NVCC installed?");
        assert!(
            status.success(),
            "PANIC: NVCC returned a non-zero exit code while building runtime."
        );

        // Archive the object file into a static library rustc can link.
        let status = Command::new("ar")
            .args(&["crus", "librtdefs.a", "rtdefs.o"])
            .current_dir(&Path::new(&out_dir))
            .status()
            .expect("PANIC: ar failed when building runtime.");
        assert!(
            status.success(),
            "PANIC: ar returned a non-zero exit code while building runtime."
        );

        // Link against our static shim library and the CUDA runtime.
        println!("cargo::rustc-link-search=native={}", out_dir);
        println!("cargo::rustc-link-search=native=/usr/lib/x86_64-linux-gnu/");
        println!("cargo::rustc-link-lib=static=rtdefs");
        println!("cargo::rustc-link-lib=cudart");
        println!("cargo::rerun-if-changed=src/rtdefs.cu");
    }
}
...@@ -4,6 +4,16 @@ use std::mem::swap; ...@@ -4,6 +4,16 @@ use std::mem::swap;
use std::ptr::{copy_nonoverlapping, NonNull}; use std::ptr::{copy_nonoverlapping, NonNull};
use std::slice::from_raw_parts; use std::slice::from_raw_parts;
// FFI bindings to the CUDA helper shims defined in src/rtdefs.cu and
// linked in by build.rs when the "cuda" feature is enabled. All of these
// operate on raw device (or host) pointers and are unsafe to call.
#[cfg(feature = "cuda")]
extern "C" {
    // Allocate `size` bytes of device memory; returns null on failure.
    fn cuda_alloc(size: usize) -> *mut u8;
    // Allocate `size` bytes of device memory, zero-filled; null on failure.
    fn cuda_alloc_zeroed(size: usize) -> *mut u8;
    // Free a device allocation previously returned by cuda_alloc*.
    fn cuda_dealloc(ptr: *mut u8);
    // Copy `size` bytes from a host buffer to a device buffer.
    fn copy_cpu_to_cuda(dst: *mut u8, src: *mut u8, size: usize);
    // Copy `size` bytes from a device buffer to a host buffer.
    fn copy_cuda_to_cpu(dst: *mut u8, src: *mut u8, size: usize);
    // Copy `size` bytes between two device buffers.
    fn copy_cuda_to_cuda(dst: *mut u8, src: *mut u8, size: usize);
}
/* /*
* An in-memory collection object that can be used by functions compiled by the * An in-memory collection object that can be used by functions compiled by the
* Hercules compiler. * Hercules compiler.
...@@ -13,16 +23,23 @@ pub struct HerculesBox<'a> { ...@@ -13,16 +23,23 @@ pub struct HerculesBox<'a> {
cpu_exclusive: Option<NonNull<u8>>, cpu_exclusive: Option<NonNull<u8>>,
cpu_owned: Option<NonNull<u8>>, cpu_owned: Option<NonNull<u8>>,
#[cfg(feature = "cuda")]
cuda_owned: Option<NonNull<u8>>,
size: usize, size: usize,
_phantom: PhantomData<&'a u8>, _phantom: PhantomData<&'a u8>,
} }
impl<'a> HerculesBox<'a> { impl<'b, 'a: 'b> HerculesBox<'a> {
pub fn from_slice<T>(slice: &'a [T]) -> Self { pub fn from_slice<T>(slice: &'a [T]) -> Self {
HerculesBox { HerculesBox {
cpu_shared: Some(unsafe { NonNull::new_unchecked(slice.as_ptr() as *mut u8) }), cpu_shared: Some(unsafe { NonNull::new_unchecked(slice.as_ptr() as *mut u8) }),
cpu_exclusive: None, cpu_exclusive: None,
cpu_owned: None, cpu_owned: None,
#[cfg(feature = "cuda")]
cuda_owned: None,
size: slice.len() * size_of::<T>(), size: slice.len() * size_of::<T>(),
_phantom: PhantomData, _phantom: PhantomData,
} }
...@@ -33,36 +50,69 @@ impl<'a> HerculesBox<'a> { ...@@ -33,36 +50,69 @@ impl<'a> HerculesBox<'a> {
cpu_shared: None, cpu_shared: None,
cpu_exclusive: Some(unsafe { NonNull::new_unchecked(slice.as_mut_ptr() as *mut u8) }), cpu_exclusive: Some(unsafe { NonNull::new_unchecked(slice.as_mut_ptr() as *mut u8) }),
cpu_owned: None, cpu_owned: None,
#[cfg(feature = "cuda")]
cuda_owned: None,
size: slice.len() * size_of::<T>(), size: slice.len() * size_of::<T>(),
_phantom: PhantomData, _phantom: PhantomData,
} }
} }
pub fn as_slice<T>(&'a self) -> &'a [T] { pub fn as_slice<T>(&'b mut self) -> &'b [T] {
assert_eq!(self.size % size_of::<T>(), 0); assert_eq!(self.size % size_of::<T>(), 0);
unsafe { from_raw_parts(self.__cpu_ptr() as *const T, self.size / size_of::<T>()) } unsafe { from_raw_parts(self.__cpu_ptr() as *const T, self.size / size_of::<T>()) }
} }
unsafe fn into_cpu(&self) -> NonNull<u8> { unsafe fn get_cpu_ptr(&self) -> Option<NonNull<u8>> {
self.cpu_shared self.cpu_owned.or(self.cpu_exclusive).or(self.cpu_shared)
.or(self.cpu_exclusive) }
.or(self.cpu_owned)
.unwrap() #[cfg(feature = "cuda")]
unsafe fn get_cuda_ptr(&self) -> Option<NonNull<u8>> {
self.cuda_owned
} }
unsafe fn into_cpu_mut(&mut self) -> NonNull<u8> { unsafe fn allocate_cpu(&mut self) -> NonNull<u8> {
if let Some(ptr) = self.cpu_exclusive.or(self.cpu_owned) { if let Some(ptr) = self.cpu_owned {
ptr ptr
} else { } else {
let ptr = let ptr =
NonNull::new(alloc(Layout::from_size_align_unchecked(self.size, 16))).unwrap(); NonNull::new(alloc(Layout::from_size_align_unchecked(self.size, 16))).unwrap();
copy_nonoverlapping(self.cpu_shared.unwrap().as_ptr(), ptr.as_ptr(), self.size);
self.cpu_owned = Some(ptr); self.cpu_owned = Some(ptr);
self.cpu_shared = None;
ptr ptr
} }
} }
#[cfg(feature = "cuda")]
/// Return the device allocation backing this box, creating one of
/// `self.size` bytes if none exists yet. Panics if the CUDA allocation
/// fails (cuda_alloc returns null on failure).
unsafe fn allocate_cuda(&mut self) -> NonNull<u8> {
    if let Some(ptr) = self.cuda_owned {
        ptr
    } else {
        // Bind the NonNull once instead of re-unwrapping the Option we
        // just stored; panic with a clear message on a null return.
        let ptr = NonNull::new(cuda_alloc(self.size))
            .expect("PANIC: CUDA allocation failed.");
        self.cuda_owned = Some(ptr);
        ptr
    }
}
/// Free the CPU-owned buffer, if any, and clear the field.
unsafe fn deallocate_cpu(&mut self) {
    // take() clears cpu_owned and hands back the pointer in one step.
    if let Some(ptr) = self.cpu_owned.take() {
        dealloc(
            ptr.as_ptr(),
            // Must match the layout used at allocation time (16-byte align).
            Layout::from_size_align_unchecked(self.size, 16),
        );
    }
}
#[cfg(feature = "cuda")]
/// Free the device-owned buffer, if any, and clear the field.
unsafe fn deallocate_cuda(&mut self) {
    // take() clears cuda_owned and hands back the pointer in one step.
    if let Some(ptr) = self.cuda_owned.take() {
        cuda_dealloc(ptr.as_ptr());
    }
}
pub unsafe fn __zeros(size: u64) -> Self { pub unsafe fn __zeros(size: u64) -> Self {
assert_ne!(size, 0); assert_ne!(size, 0);
let size = size as usize; let size = size as usize;
...@@ -72,6 +122,10 @@ impl<'a> HerculesBox<'a> { ...@@ -72,6 +122,10 @@ impl<'a> HerculesBox<'a> {
cpu_owned: Some( cpu_owned: Some(
NonNull::new(alloc_zeroed(Layout::from_size_align_unchecked(size, 16))).unwrap(), NonNull::new(alloc_zeroed(Layout::from_size_align_unchecked(size, 16))).unwrap(),
), ),
#[cfg(feature = "cuda")]
cuda_owned: None,
size: size, size: size,
_phantom: PhantomData, _phantom: PhantomData,
} }
...@@ -82,6 +136,10 @@ impl<'a> HerculesBox<'a> { ...@@ -82,6 +136,10 @@ impl<'a> HerculesBox<'a> {
cpu_shared: None, cpu_shared: None,
cpu_exclusive: None, cpu_exclusive: None,
cpu_owned: None, cpu_owned: None,
#[cfg(feature = "cuda")]
cuda_owned: None,
size: 0, size: 0,
_phantom: PhantomData, _phantom: PhantomData,
} }
...@@ -93,24 +151,61 @@ impl<'a> HerculesBox<'a> { ...@@ -93,24 +151,61 @@ impl<'a> HerculesBox<'a> {
ret ret
} }
pub unsafe fn __cpu_ptr(&self) -> *mut u8 { pub unsafe fn __cpu_ptr(&mut self) -> *mut u8 {
self.into_cpu().as_ptr() if let Some(ptr) = self.get_cpu_ptr() {
return ptr.as_ptr();
}
#[cfg(feature = "cuda")]
{
let cuda_ptr = self.get_cuda_ptr().unwrap();
let cpu_ptr = self.allocate_cpu();
copy_cuda_to_cpu(cpu_ptr.as_ptr(), cuda_ptr.as_ptr(), self.size);
return cpu_ptr.as_ptr();
}
panic!()
} }
pub unsafe fn __cpu_ptr_mut(&mut self) -> *mut u8 { pub unsafe fn __cpu_ptr_mut(&mut self) -> *mut u8 {
self.into_cpu_mut().as_ptr() let cpu_ptr = self.__cpu_ptr();
if Some(cpu_ptr) == self.cpu_shared.map(|nn| nn.as_ptr()) {
self.allocate_cpu();
copy_nonoverlapping(cpu_ptr, self.cpu_owned.unwrap().as_ptr(), self.size);
}
self.cpu_shared = None;
self.cpu_exclusive = None;
#[cfg(feature = "cuda")]
self.deallocate_cuda();
cpu_ptr
}
#[cfg(feature = "cuda")]
/// Get a device pointer to this box's data, staging a copy from the CPU
/// side if no device-resident copy exists yet.
pub unsafe fn __cuda_ptr(&mut self) -> *mut u8 {
    match self.get_cuda_ptr() {
        Some(ptr) => ptr.as_ptr(),
        None => {
            // No device copy yet: allocate one and fill it from the
            // host-side buffer (which must exist at this point).
            let src = self.get_cpu_ptr().unwrap();
            let dst = self.allocate_cuda();
            copy_cpu_to_cuda(dst.as_ptr(), src.as_ptr(), self.size);
            dst.as_ptr()
        }
    }
}
#[cfg(feature = "cuda")]
/// Get a mutable device pointer to this box's data. Ensures a device
/// copy exists, then drops every host-side view so the device copy is
/// the sole live reference before it is mutated.
pub unsafe fn __cuda_ptr_mut(&mut self) -> *mut u8 {
    // Stage to the device first — this may read the CPU buffers.
    let device_ptr = self.__cuda_ptr();
    // Now invalidate all host views of the (about to be stale) data.
    self.cpu_shared = None;
    self.cpu_exclusive = None;
    self.deallocate_cpu();
    device_ptr
}
} }
impl<'a> Drop for HerculesBox<'a> { impl<'a> Drop for HerculesBox<'a> {
fn drop(&mut self) { fn drop(&mut self) {
if let Some(ptr) = self.cpu_owned { unsafe {
unsafe { self.deallocate_cpu();
dealloc( #[cfg(feature = "cuda")]
ptr.as_ptr(), self.deallocate_cuda();
Layout::from_size_align_unchecked(self.size, 16),
)
}
} }
} }
} }
extern "C" {
/* Allocate `size` bytes of device memory. Returns NULL on failure. */
void *cuda_alloc(size_t size) {
  void *ptr = NULL;
  /* Early-return NULL on failure; cudaMalloc may leave ptr untouched. */
  if (cudaMalloc(&ptr, size) != cudaSuccess) {
    return NULL;
  }
  return ptr;
}
/* Allocate `size` bytes of zero-filled device memory.
 * Returns NULL on failure. */
void *cuda_alloc_zeroed(size_t size) {
  void *ptr = cuda_alloc(size);
  if (!ptr) {
    return NULL;
  }
  cudaError_t res = cudaMemset(ptr, 0, size);
  if (res != cudaSuccess) {
    /* Don't leak the allocation when zeroing fails — the original code
     * returned NULL here without freeing ptr. */
    cudaFree(ptr);
    return NULL;
  }
  return ptr;
}
/* Free a device allocation; cudaFree(NULL) is a documented no-op, so a
 * NULL argument is safe. Errors are deliberately ignored. */
void cuda_dealloc(void *ptr) {
  cudaFree(ptr);
}
/* Copy `size` bytes from a host buffer to a device buffer.
 * NOTE(review): cudaMemcpy errors are ignored in all three helpers —
 * presumably intentional, but worth confirming. */
void copy_cpu_to_cuda(void *dst, void *src, size_t size) {
  cudaMemcpy(dst, src, size, cudaMemcpyHostToDevice);
}
/* Copy `size` bytes from a device buffer to a host buffer. */
void copy_cuda_to_cpu(void *dst, void *src, size_t size) {
  cudaMemcpy(dst, src, size, cudaMemcpyDeviceToHost);
}
/* Copy `size` bytes between two device buffers. */
void copy_cuda_to_cuda(void *dst, void *src, size_t size) {
  cudaMemcpy(dst, src, size, cudaMemcpyDeviceToDevice);
}
}
...@@ -23,7 +23,7 @@ fn main() { ...@@ -23,7 +23,7 @@ fn main() {
} }
let a = HerculesBox::from_slice_mut(&mut a); let a = HerculesBox::from_slice_mut(&mut a);
let b = HerculesBox::from_slice_mut(&mut b); let b = HerculesBox::from_slice_mut(&mut b);
let c = matmul(I as u64, J as u64, K as u64, a, b).await; let mut c = matmul(I as u64, J as u64, K as u64, a, b).await;
assert_eq!(c.as_slice::<i32>(), &*correct_c); assert_eq!(c.as_slice::<i32>(), &*correct_c);
}); });
} }
......
...@@ -21,17 +21,17 @@ fn main() { ...@@ -21,17 +21,17 @@ fn main() {
} }
} }
} }
let c = { let mut c = {
let a = HerculesBox::from_slice(&a); let a = HerculesBox::from_slice(&a);
let b = HerculesBox::from_slice(&b); let b = HerculesBox::from_slice(&b);
matmul(I as u64, J as u64, K as u64, a, b).await matmul(I as u64, J as u64, K as u64, a, b).await
}; };
let tiled_c = { assert_eq!(c.as_slice::<i32>(), &*correct_c);
let mut tiled_c = {
let a = HerculesBox::from_slice(&a); let a = HerculesBox::from_slice(&a);
let b = HerculesBox::from_slice(&b); let b = HerculesBox::from_slice(&b);
tiled_64_matmul(I as u64, J as u64, K as u64, a, b).await tiled_64_matmul(I as u64, J as u64, K as u64, a, b).await
}; };
assert_eq!(c.as_slice::<i32>(), &*correct_c);
assert_eq!(tiled_c.as_slice::<i32>(), &*correct_c); assert_eq!(tiled_c.as_slice::<i32>(), &*correct_c);
}); });
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment