diff --git a/Cargo.lock b/Cargo.lock index 7cac317a31d70bdd9019f15c2b40dc54baf050ce..53e26bbb3a7fe7062df1012ecf93fb8706c6baad 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -910,6 +910,7 @@ version = "0.1.0" dependencies = [ "async-std", "clap", + "hercules_rt", "juno_build", "rand", "with_builtin_macros", diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs index a4c6169e9a003695b2072b8f6151a71015c0974e..6278b790d849237b758c1c8ab011d5894731d7cb 100644 --- a/hercules_cg/src/rt.rs +++ b/hercules_cg/src/rt.rs @@ -58,7 +58,7 @@ impl<'a> RTContext<'a> { // Dump the function signature. write!( w, - "#[allow(unused_variables,unused_mut)]\nasync fn {}<'a>(", + "#[allow(unused_variables,unused_mut,unused_parens)]\nasync fn {}<'a>(", func.name )?; let mut first_param = true; @@ -93,7 +93,7 @@ impl<'a> RTContext<'a> { let size = self.codegen_type_size(self.typing[id.idx()]); write!( w, - " let mut obj{}: ::hercules_rt::HerculesBox = ::hercules_rt::HerculesBox::__zeros({});\n", + " let mut obj{}: ::hercules_rt::HerculesBox = unsafe {{ ::hercules_rt::HerculesBox::__zeros({}) }};\n", object.idx(), size )? @@ -361,11 +361,15 @@ impl<'a> RTContext<'a> { if idx != 0 { write!(block, " else\n")?; } - write!(block, " if call_tmp == arg_tmp{} {{", idx)?; write!( block, - " {} = {}.__take();\n", - idx, + " if call_tmp == arg_tmp{} {{\n", + idx + )?; + write!( + block, + " {} = unsafe {{ {}.__take() }};\n", + self.get_value(id), self.get_value(*arg) )?; write!(block, " }}")?; @@ -397,7 +401,7 @@ impl<'a> RTContext<'a> { if self.module.types[self.typing[arg.idx()].idx()].is_primitive() { write!(block, "{}, ", self.get_value(*arg))?; } else { - write!(block, "{}.__take(), ", self.get_value(*arg))?; + write!(block, "unsafe {{ {}.__take() }}, ", self.get_value(*arg))?; } } write!(block, ").await;\n")?; diff --git a/hercules_rt/src/lib.rs b/hercules_rt/src/lib.rs index c1f075bbc2cb1248473f1e928b34b762848bf549..50ac260c160deb37c40707809438870a1f3acf57 100644 --- a/hercules_rt/src/lib.rs +++ b/hercules_rt/src/lib.rs @@ -2,6 +2,7 @@ use std::alloc::{alloc, alloc_zeroed, dealloc, Layout}; use std::marker::PhantomData; use std::mem::swap; use std::ptr::{copy_nonoverlapping, NonNull}; +use std::slice::from_raw_parts; /* * An in-memory collection object that can be used by functions compiled by the @@ -22,7 +23,7 @@ impl<'a> HerculesBox<'a> { cpu_shared: Some(unsafe { NonNull::new_unchecked(slice.as_ptr() as *mut u8) }), cpu_exclusive: None, cpu_owned: None, - size: slice.len() * align_of::<T>(), + size: slice.len() * size_of::<T>(), _phantom: PhantomData, } } @@ -32,12 +33,17 @@ impl<'a> HerculesBox<'a> { cpu_shared: None, cpu_exclusive: Some(unsafe { NonNull::new_unchecked(slice.as_mut_ptr() as *mut u8) }), cpu_owned: None, - size: slice.len() * align_of::<T>(), + size: slice.len() * size_of::<T>(), _phantom: PhantomData, } } - unsafe fn into_cpu(&mut self) -> NonNull<u8> { + pub fn as_slice<T>(&'a self) -> &'a [T] { + assert_eq!(self.size % size_of::<T>(), 0); + unsafe { from_raw_parts(self.__cpu_ptr() as *const T, self.size / size_of::<T>()) } + } + + unsafe fn into_cpu(&self) -> NonNull<u8> { self.cpu_shared .or(self.cpu_exclusive) .or(self.cpu_owned) @@ -57,8 +63,9 @@ impl<'a> HerculesBox<'a> { } } - pub unsafe fn __zeros(size: usize) -> Self { + pub unsafe fn __zeros(size: u64) -> Self { assert_ne!(size, 0); + let size = size as usize; HerculesBox { cpu_shared: None, cpu_exclusive: None, @@ -86,7 +93,7 @@ impl<'a> HerculesBox<'a> { ret } - pub unsafe fn __cpu_ptr(&mut self) -> *mut u8 { + pub unsafe fn __cpu_ptr(&self) -> *mut u8 { self.into_cpu().as_ptr() } diff --git a/hercules_samples/matmul/Cargo.toml b/hercules_samples/matmul/Cargo.toml index d3975c5ca58b68cdb3fef0f6d8a3cf8e106408d6..9066c1535e2c40400bdb3b5ca20a3e38237ef597 100644 --- a/hercules_samples/matmul/Cargo.toml +++ b/hercules_samples/matmul/Cargo.toml @@ -10,6 +10,7 @@ juno_build = { path = "../../juno_build" } [dependencies] clap = { version = "*", features = ["derive"] } juno_build = { path = "../../juno_build" } +hercules_rt = { path = "../../hercules_rt" } rand = "*" async-std = "*" with_builtin_macros = "0.1.0" diff --git a/hercules_samples/matmul/src/main.rs b/hercules_samples/matmul/src/main.rs index 93d007c791579a75dea65d5680ab3018e9b00085..34612801b4d27c456251cda01550c7cbd0700ebf 100644 --- a/hercules_samples/matmul/src/main.rs +++ b/hercules_samples/matmul/src/main.rs @@ -1,13 +1,14 @@ #![feature(box_as_ptr, let_chains)] extern crate async_std; +extern crate hercules_rt; extern crate juno_build; extern crate rand; -use core::ptr::copy_nonoverlapping; - use rand::random; +use hercules_rt::HerculesBox; + juno_build::juno!("matmul"); fn main() { @@ -15,31 +16,8 @@ fn main() { const I: usize = 256; const J: usize = 64; const K: usize = 128; - let a: Box<[i32]> = (0..I * J).map(|_| random::<i32>() % 100).collect(); - let b: Box<[i32]> = (0..J * K).map(|_| random::<i32>() % 100).collect(); - let mut a_bytes: Box<[u8]> = Box::new([0; I * J * 4]); - let mut b_bytes: Box<[u8]> = Box::new([0; J * K * 4]); - unsafe { - copy_nonoverlapping( - Box::as_ptr(&a) as *const u8, - Box::as_mut_ptr(&mut a_bytes) as *mut u8, - I * J * 4, - ); - copy_nonoverlapping( - Box::as_ptr(&b) as *const u8, - Box::as_mut_ptr(&mut b_bytes) as *mut u8, - J * K * 4, - ); - }; - let c_bytes = matmul(I as u64, J as u64, K as u64, a_bytes, b_bytes).await; - let mut c: Box<[i32]> = (0..I * K).map(|_| 0).collect(); - unsafe { - copy_nonoverlapping( - Box::as_ptr(&c_bytes) as *const u8, - Box::as_mut_ptr(&mut c) as *mut u8, - I * K * 4, - ); - }; + let mut a: Box<[i32]> = (0..I * J).map(|_| random::<i32>() % 100).collect(); + let mut b: Box<[i32]> = (0..J * K).map(|_| random::<i32>() % 100).collect(); let mut correct_c: Box<[i32]> = (0..I * K).map(|_| 0).collect(); for i in 0..I { for k in 0..K { @@ -48,7 +26,10 @@ fn main() { } } } - assert_eq!(c, correct_c); + let a = HerculesBox::from_slice_mut(&mut a); + let b = HerculesBox::from_slice_mut(&mut b); + let c = matmul(I as u64, J as u64, K as u64, a, b).await; + assert_eq!(c.as_slice::<i32>(), &*correct_c); }); }