Skip to content
Snippets Groups Projects
Commit 67ac9a19 authored by Russel Arbore's avatar Russel Arbore
Browse files

HerculesBox works with dot test

parent 07b96113
No related branches found
No related tags found
1 merge request!100Hercules Box
Pipeline #200937 failed
......@@ -395,6 +395,7 @@ version = "0.1.0"
dependencies = [
"async-std",
"clap",
"hercules_rt",
"juno_build",
"rand",
"with_builtin_macros",
......
......@@ -58,7 +58,7 @@ impl<'a> RTContext<'a> {
// Dump the function signature.
write!(
w,
"#[allow(unused_variables,unused_mut)]\nasync fn {}(",
"#[allow(unused_variables,unused_mut)]\nasync fn {}<'a>(",
func.name
)?;
let mut first_param = true;
......@@ -81,75 +81,29 @@ impl<'a> RTContext<'a> {
if !self.module.types[func.param_types[idx].idx()].is_primitive() {
write!(w, "mut ")?;
}
write!(
w,
"p_i{}: {}",
idx,
self.get_type_interface(func.param_types[idx])
)?;
}
write!(w, ") -> {} {{\n", self.get_type_interface(func.return_type))?;
// Copy the "interface" parameters to "non-interface" parameters.
// The purpose of this is to convert collection objects from a Box<[u8]>
// type to a *mut u8 type. This name copying is done so that we can
// easily construct objects just after this by moving the "interface"
// parameters.
for (idx, ty) in func.param_types.iter().enumerate() {
if self.module.types[ty.idx()].is_primitive() {
write!(w, " let p{} = p_i{};\n", idx, idx)?;
} else {
write!(
w,
" let p{} = ::std::boxed::Box::as_mut_ptr(&mut p_i{}) as *mut u8;\n",
idx, idx
)?;
}
write!(w, "p{}: {}", idx, self.get_type(func.param_types[idx]))?;
}
write!(w, ") -> {} {{\n", self.get_type(func.return_type))?;
// Collect the boxes representing ownership over collection objects for
// this function. The actual emitted computation is done entirely using
// pointers, so these get emitted to hold onto ownership over the
// underlying memory and to automatically clean them up when this
// function returns. Collection objects are inside Options, since their
// ownership may get passed to other called RT functions. If this
// function returns a collection object, then at the very end, right
// before the return, the to-be-returned pointer is compared against the
// owned collection objects - it should match exactly one of those
// objects, and that box is what's actually returned.
let mem_obj_ty = "::core::option::Option<::std::boxed::Box<[u8]>>";
// Allocate collection constants.
for object in self.collection_objects[&self.func_id].iter_objects() {
match self.collection_objects[&self.func_id].origin(object) {
CollectionObjectOrigin::Parameter(index) => write!(
w,
" let mut obj{}: {} = Some(p_i{});\n",
object.idx(),
mem_obj_ty,
index
)?,
CollectionObjectOrigin::Constant(id) => {
let size = self.codegen_type_size(self.typing[id.idx()]);
write!(
w,
" let mut obj{}: {} = Some((0..{}).map(|_| 0u8).collect());\n",
object.idx(),
mem_obj_ty,
size
)?
}
CollectionObjectOrigin::Call(_) | CollectionObjectOrigin::Undef(_) => write!(
if let CollectionObjectOrigin::Constant(id) =
self.collection_objects[&self.func_id].origin(object)
{
let size = self.codegen_type_size(self.typing[id.idx()]);
write!(
w,
" let mut obj{}: {} = None;\n",
" let mut obj{}: ::hercules_rt::HerculesBox = ::hercules_rt::HerculesBox::__zeros({});\n",
object.idx(),
mem_obj_ty,
)?,
size
)?
}
}
// Dump signatures for called CPU functions.
// Dump signatures for called device functions.
write!(w, " extern \"C\" {{\n")?;
for callee in self.callgraph.get_callees(self.func_id) {
if self.devices[callee.idx()] != Device::LLVM {
if self.devices[callee.idx()] == Device::AsyncRust {
continue;
}
let callee = &self.module.functions[callee.idx()];
......@@ -169,9 +123,9 @@ impl<'a> RTContext<'a> {
} else {
write!(w, ", ")?;
}
write!(w, "p{}: {}", idx, self.get_type(*ty))?;
write!(w, "p{}: {}", idx, self.device_get_type(*ty))?;
}
write!(w, ") -> {};\n", self.get_type(callee.return_type))?;
write!(w, ") -> {};\n", self.device_get_type(callee.return_type))?;
}
write!(w, " }}\n")?;
......@@ -190,7 +144,7 @@ impl<'a> RTContext<'a> {
} else if self.module.types[self.typing[idx].idx()].is_float() {
"0.0"
} else {
"::core::ptr::null::<u8>() as _"
"unsafe { ::hercules_rt::HerculesBox::__null() }"
}
)?;
}
......@@ -281,20 +235,7 @@ impl<'a> RTContext<'a> {
}
Node::Return { control: _, data } => {
let block = &mut blocks.get_mut(&id).unwrap();
let objects = self.collection_objects[&self.func_id].objects(data);
if objects.is_empty() {
write!(block, " return {};\n", self.get_value(data))?
} else {
// If the value to return is a collection object, figure out
// which object it actually is at runtime and return that
// box.
for object in objects {
write!(block, " if let Some(mut obj) = obj{} && ::std::boxed::Box::as_mut_ptr(&mut obj) as *mut u8 == {} {{\n", object.idx(), self.get_value(data))?;
write!(block, " return obj;\n")?;
write!(block, " }}\n")?;
}
write!(block, " panic!(\"HERCULES PANIC: Pointer to be returned doesn't match any known collection objects.\");\n")?
}
write!(block, " return {};\n", self.get_value(data))?
}
_ => panic!("PANIC: Can't lower {:?}.", func.nodes[id.idx()]),
}
......@@ -313,12 +254,21 @@ impl<'a> RTContext<'a> {
match func.nodes[id.idx()] {
Node::Parameter { index } => {
let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap();
write!(
block,
" {} = p{};\n",
self.get_value(id),
index
)?
if self.module.types[self.typing[id.idx()].idx()].is_primitive() {
write!(
block,
" {} = p{};\n",
self.get_value(id),
index
)?
} else {
write!(
block,
" {} = unsafe {{ p{}.__take() }};\n",
self.get_value(id),
index
)?
}
}
Node::Constant { id: cons_id } => {
let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap();
......@@ -339,11 +289,7 @@ impl<'a> RTContext<'a> {
let objects = self.collection_objects[&self.func_id].objects(id);
assert_eq!(objects.len(), 1);
let object = objects[0];
write!(
block,
"::std::boxed::Box::as_mut_ptr(obj{}.as_mut().unwrap()) as *mut u8",
object.idx()
)?
write!(block, "unsafe {{ obj{}.__take() }}", object.idx())?
}
}
write!(block, ";\n")?
......@@ -357,83 +303,82 @@ impl<'a> RTContext<'a> {
match self.devices[callee_id.idx()] {
Device::LLVM => {
let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap();
// First, get the raw pointers to collections that the
// device function takes as input.
let callee_objs = &self.collection_objects[&callee_id];
for (idx, arg) in args.into_iter().enumerate() {
if let Some(obj) = callee_objs.param_to_object(idx) {
// Extract a raw pointer from the HerculesBox.
if callee_objs.is_mutated(obj) {
write!(
block,
" let arg_tmp{} = unsafe {{ {}.__cpu_ptr_mut() }};\n",
idx,
self.get_value(*arg)
)?;
} else {
write!(
block,
" let arg_tmp{} = unsafe {{ {}.__cpu_ptr() }};\n",
idx,
self.get_value(*arg)
)?;
}
} else {
write!(
block,
" let arg_tmp{} = {};\n",
idx,
self.get_value(*arg)
)?;
}
}
// Emit the call.
write!(
block,
" {} = unsafe {{ {}(",
self.get_value(id),
" let call_tmp = unsafe {{ {}(",
self.module.functions[callee_id.idx()].name
)?;
for dc in dynamic_constants {
self.codegen_dynamic_constant(*dc, block)?;
write!(block, ", ")?;
}
for arg in args {
write!(block, "{}, ", self.get_value(*arg))?;
for idx in 0..args.len() {
write!(block, "arg_tmp{}, ", idx)?;
}
write!(block, ") }};\n")?;
// When a CPU function is called that returns a
// When a device function is called that returns a
// collection object, that object must have come from
// one of its parameters. Dynamically figure out which
// one it came from, so that we can move it to the slot
// of the output object.
let call_objects = self.collection_objects[&self.func_id].objects(id);
if !call_objects.is_empty() {
assert_eq!(call_objects.len(), 1);
let call_object = call_objects[0];
let callee_returned_objects =
self.collection_objects[&callee_id].returned_objects();
let possible_params: Vec<_> =
(0..self.module.functions[callee_id.idx()].param_types.len())
.filter(|idx| {
let object_of_param = self.collection_objects[&callee_id]
.param_to_object(*idx);
// Look at parameters that could be the
// source of the memory object returned
// by the function.
object_of_param
.map(|object_of_param| {
callee_returned_objects.contains(&object_of_param)
})
.unwrap_or(false)
})
.collect();
let arg_objects = args
.into_iter()
.enumerate()
.filter(|(idx, _)| possible_params.contains(idx))
.map(|(_, arg)| {
self.collection_objects[&self.func_id]
.objects(*arg)
.into_iter()
})
.flatten();
// Dynamically check which of the memory objects
// corresponding to arguments to the call was
// returned by the call. Move that memory object
// into the memory object of the call.
let mut first_obj = true;
for arg_object in arg_objects {
write!(block, " ")?;
if first_obj {
first_obj = false;
} else {
write!(block, "else ")?;
let caller_objects = self.collection_objects[&self.func_id].objects(id);
if !caller_objects.is_empty() {
for (idx, arg) in args.into_iter().enumerate() {
if idx != 0 {
write!(block, " else\n")?;
}
write!(block, "if let Some(obj) = obj{}.as_mut() && ::std::boxed::Box::as_mut_ptr(obj) as *mut u8 == {} {{\n", arg_object.idx(), self.get_value(id))?;
write!(block, " if call_tmp == arg_tmp{} {{", idx)?;
write!(
block,
" obj{} = obj{}.take();\n",
call_object.idx(),
arg_object.idx()
" {} = {}.__take();\n",
idx,
self.get_value(*arg)
)?;
write!(block, " }}\n")?;
write!(block, " }}")?;
}
write!(block, " else {{\n")?;
write!(block, " panic!(\"HERCULES PANIC: Pointer returned from called function doesn't match any known collection objects.\");\n")?;
write!(block, " panic!(\"HERCULES PANIC: Pointer returned from device function doesn't match an argument pointer.\");\n")?;
write!(block, " }}\n")?;
} else {
write!(
block,
" {} = call_tmp;\n",
self.get_value(id)
)?;
}
}
Device::AsyncRust => {
......@@ -452,7 +397,7 @@ impl<'a> RTContext<'a> {
if self.module.types[self.typing[arg.idx()].idx()].is_primitive() {
write!(block, "{}, ", self.get_value(*arg))?;
} else {
write!(block, "{}.take(), ", self.get_value(*arg))?;
write!(block, "{}.__take(), ", self.get_value(*arg))?;
}
}
write!(block, ").await;\n")?;
......@@ -603,8 +548,8 @@ impl<'a> RTContext<'a> {
convert_type(&self.module.types[id.idx()])
}
fn get_type_interface(&self, id: TypeID) -> &'static str {
convert_type_interface(&self.module.types[id.idx()])
fn device_get_type(&self, id: TypeID) -> &'static str {
device_convert_type(&self.module.types[id.idx()])
}
}
......@@ -621,18 +566,27 @@ fn convert_type(ty: &Type) -> &'static str {
Type::UnsignedInteger64 => "u64",
Type::Float32 => "f32",
Type::Float64 => "f64",
Type::Product(_) | Type::Summation(_) | Type::Array(_, _) => "*mut u8",
Type::Product(_) | Type::Summation(_) | Type::Array(_, _) => {
"::hercules_rt::HerculesBox<'a>"
}
_ => panic!(),
}
}
/*
* Collection types are passed to / returned from runtime functions through a
* wrapper type for ownership tracking reasons.
*/
fn convert_type_interface(ty: &Type) -> &'static str {
fn device_convert_type(ty: &Type) -> &'static str {
match ty {
Type::Product(_) | Type::Summation(_) | Type::Array(_, _) => "Box<[u8]>",
_ => convert_type(ty),
Type::Boolean => "bool",
Type::Integer8 => "i8",
Type::Integer16 => "i16",
Type::Integer32 => "i32",
Type::Integer64 => "i64",
Type::UnsignedInteger8 => "u8",
Type::UnsignedInteger16 => "u16",
Type::UnsignedInteger32 => "u32",
Type::UnsignedInteger64 => "u64",
Type::Float32 => "f32",
Type::Float64 => "f64",
Type::Product(_) | Type::Summation(_) | Type::Array(_, _) => "*mut u8",
_ => panic!(),
}
}
use std::alloc::{alloc, alloc_zeroed, dealloc, Layout};
use std::marker::PhantomData;
use std::mem::swap;
use std::ptr::{copy_nonoverlapping, NonNull};
/*
......@@ -57,6 +58,7 @@ impl<'a> HerculesBox<'a> {
}
pub unsafe fn __zeros(size: usize) -> Self {
assert_ne!(size, 0);
HerculesBox {
cpu_shared: None,
cpu_exclusive: None,
......@@ -68,6 +70,22 @@ impl<'a> HerculesBox<'a> {
}
}
pub unsafe fn __null() -> Self {
HerculesBox {
cpu_shared: None,
cpu_exclusive: None,
cpu_owned: None,
size: 0,
_phantom: PhantomData,
}
}
pub unsafe fn __take(&mut self) -> Self {
let mut ret = Self::__null();
swap(&mut ret, self);
ret
}
/// Returns a raw byte pointer for this box, obtained by calling `into_cpu`
/// on `self` and then `as_ptr` on the result.
///
/// # Safety
///
/// NOTE(review): `into_cpu` is defined outside this view; presumably it
/// yields a non-null handle to CPU-accessible storage — confirm. The
/// returned raw pointer carries no lifetime, so the caller must make sure
/// it is not used after the underlying storage goes away.
pub unsafe fn __cpu_ptr(&mut self) -> *mut u8 {
    let cpu_handle = self.into_cpu();
    cpu_handle.as_ptr()
}
......
......@@ -10,6 +10,7 @@ juno_build = { path = "../../juno_build" }
[dependencies]
clap = { version = "*", features = ["derive"] }
juno_build = { path = "../../juno_build" }
hercules_rt = { path = "../../hercules_rt" }
rand = "*"
async-std = "*"
with_builtin_macros = "0.1.0"
#![feature(box_as_ptr, let_chains)]
extern crate async_std;
extern crate hercules_rt;
extern crate juno_build;
use core::ptr::copy_nonoverlapping;
use hercules_rt::HerculesBox;
juno_build::juno!("dot");
fn main() {
async_std::task::block_on(async {
let a: Box<[f32]> = Box::new([0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0]);
let b: Box<[f32]> = Box::new([0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0]);
let mut a_bytes: Box<[u8]> = Box::new([0; 32]);
let mut b_bytes: Box<[u8]> = Box::new([0; 32]);
unsafe {
copy_nonoverlapping(
Box::as_ptr(&a) as *const u8,
Box::as_mut_ptr(&mut a_bytes) as *mut u8,
32,
);
copy_nonoverlapping(
Box::as_ptr(&b) as *const u8,
Box::as_mut_ptr(&mut b_bytes) as *mut u8,
32,
);
};
let c = dot(8, a_bytes, b_bytes).await;
let a: [f32; 8] = [0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0];
let b: [f32; 8] = [0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0];
let a = HerculesBox::from_slice(&a);
let b = HerculesBox::from_slice(&b);
let c = dot(8, a, b).await;
println!("{}", c);
assert_eq!(c, 70.0);
});
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment