From 387098dcc3e326527b8631c8e3fa6cd44eefc994 Mon Sep 17 00:00:00 2001
From: rarbore2 <rarbore2@illinois.edu>
Date: Sun, 12 Jan 2025 22:08:50 -0600
Subject: [PATCH] Hercules Box

---
 Cargo.lock                             |  11 +
 Cargo.toml                             |   1 +
 hercules_cg/src/rt.rs                  | 266 +++++++++++--------------
 hercules_rt/Cargo.toml                 |   7 +
 hercules_rt/src/lib.rs                 | 116 +++++++++++
 hercules_samples/dot/Cargo.toml        |   1 +
 hercules_samples/dot/src/main.rs       |  25 +--
 hercules_samples/matmul/Cargo.toml     |   1 +
 hercules_samples/matmul/src/main.rs    |  37 +---
 juno_samples/antideps/Cargo.toml       |   1 +
 juno_samples/implicit_clone/Cargo.toml |   1 +
 juno_samples/matmul/Cargo.toml         |   1 +
 juno_samples/matmul/src/main.rs        |  61 ++----
 juno_samples/nested_ccp/Cargo.toml     |   1 +
 juno_samples/nested_ccp/src/main.rs    |  27 +--
 juno_samples/simple3/Cargo.toml        |   1 +
 juno_samples/simple3/src/main.rs       |  25 +--
 17 files changed, 302 insertions(+), 281 deletions(-)
 create mode 100644 hercules_rt/Cargo.toml
 create mode 100644 hercules_rt/src/lib.rs

diff --git a/Cargo.lock b/Cargo.lock
index 7b70c0b0..7e99e454 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -395,6 +395,7 @@ version = "0.1.0"
 dependencies = [
  "async-std",
  "clap",
+ "hercules_rt",
  "juno_build",
  "rand",
  "with_builtin_macros",
@@ -651,6 +652,10 @@ dependencies = [
  "take_mut",
 ]
 
+[[package]]
+name = "hercules_rt"
+version = "0.1.0"
+
 [[package]]
 name = "hermit-abi"
 version = "0.4.0"
@@ -702,6 +707,7 @@ name = "juno_antideps"
 version = "0.1.0"
 dependencies = [
  "async-std",
+ "hercules_rt",
  "juno_build",
  "with_builtin_macros",
 ]
@@ -746,6 +752,7 @@ name = "juno_implicit_clone"
 version = "0.1.0"
 dependencies = [
  "async-std",
+ "hercules_rt",
  "juno_build",
  "with_builtin_macros",
 ]
@@ -755,6 +762,7 @@ name = "juno_matmul"
 version = "0.1.0"
 dependencies = [
  "async-std",
+ "hercules_rt",
  "juno_build",
  "rand",
  "with_builtin_macros",
@@ -765,6 +773,7 @@ name = "juno_nested_ccp"
 version = "0.1.0"
 dependencies = [
  "async-std",
+ "hercules_rt",
  "juno_build",
  "with_builtin_macros",
 ]
@@ -784,6 +793,7 @@ name = "juno_simple3"
 version = "0.1.0"
 dependencies = [
  "async-std",
+ "hercules_rt",
  "juno_build",
  "with_builtin_macros",
 ]
@@ -905,6 +915,7 @@ version = "0.1.0"
 dependencies = [
  "async-std",
  "clap",
+ "hercules_rt",
  "juno_build",
  "rand",
  "with_builtin_macros",
diff --git a/Cargo.toml b/Cargo.toml
index dc0c6478..86307fd8 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -4,6 +4,7 @@ members = [
 	"hercules_cg",
 	"hercules_ir",
 	"hercules_opt",
+	"hercules_rt",
 	
 	"hercules_tools/hercules_driver",
 
diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs
index e484729d..6278b790 100644
--- a/hercules_cg/src/rt.rs
+++ b/hercules_cg/src/rt.rs
@@ -58,7 +58,7 @@ impl<'a> RTContext<'a> {
         // Dump the function signature.
         write!(
             w,
-            "#[allow(unused_variables,unused_mut)]\nasync fn {}(",
+            "#[allow(unused_variables,unused_mut,unused_parens)]\nasync fn {}<'a>(",
             func.name
         )?;
         let mut first_param = true;
@@ -81,75 +81,29 @@ impl<'a> RTContext<'a> {
             if !self.module.types[func.param_types[idx].idx()].is_primitive() {
                 write!(w, "mut ")?;
             }
-            write!(
-                w,
-                "p_i{}: {}",
-                idx,
-                self.get_type_interface(func.param_types[idx])
-            )?;
-        }
-        write!(w, ") -> {} {{\n", self.get_type_interface(func.return_type))?;
-
-        // Copy the "interface" parameters to "non-interface" parameters.
-        // The purpose of this is to convert collection objects from a Box<[u8]>
-        // type to a *mut u8 type. This name copying is done so that we can
-        // easily construct objects just after this by moving the "inferface"
-        // parameters.
-        for (idx, ty) in func.param_types.iter().enumerate() {
-            if self.module.types[ty.idx()].is_primitive() {
-                write!(w, "    let p{} = p_i{};\n", idx, idx)?;
-            } else {
-                write!(
-                    w,
-                    "    let p{} = ::std::boxed::Box::as_mut_ptr(&mut p_i{}) as *mut u8;\n",
-                    idx, idx
-                )?;
-            }
+            write!(w, "p{}: {}", idx, self.get_type(func.param_types[idx]))?;
         }
+        write!(w, ") -> {} {{\n", self.get_type(func.return_type))?;
 
-        // Collect the boxes representing ownership over collection objects for
-        // this function. The actual emitted computation is done entirely using
-        // pointers, so these get emitted to hold onto ownership over the
-        // underlying memory and to automatically clean them up when this
-        // function returns. Collection objects are inside Options, since their
-        // ownership may get passed to other called RT functions. If this
-        // function returns a collection object, then at the very end, right
-        // before the return, the to-be-returned pointer is compared against the
-        // owned collection objects - it should match exactly one of those
-        // objects, and that box is what's actually returned.
-        let mem_obj_ty = "::core::option::Option<::std::boxed::Box<[u8]>>";
+        // Allocate collection constants.
         for object in self.collection_objects[&self.func_id].iter_objects() {
-            match self.collection_objects[&self.func_id].origin(object) {
-                CollectionObjectOrigin::Parameter(index) => write!(
-                    w,
-                    "    let mut obj{}: {} = Some(p_i{});\n",
-                    object.idx(),
-                    mem_obj_ty,
-                    index
-                )?,
-                CollectionObjectOrigin::Constant(id) => {
-                    let size = self.codegen_type_size(self.typing[id.idx()]);
-                    write!(
-                        w,
-                        "    let mut obj{}: {} = Some((0..{}).map(|_| 0u8).collect());\n",
-                        object.idx(),
-                        mem_obj_ty,
-                        size
-                    )?
-                }
-                CollectionObjectOrigin::Call(_) | CollectionObjectOrigin::Undef(_) => write!(
+            if let CollectionObjectOrigin::Constant(id) =
+                self.collection_objects[&self.func_id].origin(object)
+            {
+                let size = self.codegen_type_size(self.typing[id.idx()]);
+                write!(
                     w,
-                    "    let mut obj{}: {} = None;\n",
+                    "    let mut obj{}: ::hercules_rt::HerculesBox = unsafe {{ ::hercules_rt::HerculesBox::__zeros({}) }};\n",
                     object.idx(),
-                    mem_obj_ty,
-                )?,
+                    size
+                )?
             }
         }
 
-        // Dump signatures for called CPU functions.
+        // Dump signatures for called device functions.
         write!(w, "    extern \"C\" {{\n")?;
         for callee in self.callgraph.get_callees(self.func_id) {
-            if self.devices[callee.idx()] != Device::LLVM {
+            if self.devices[callee.idx()] == Device::AsyncRust {
                 continue;
             }
             let callee = &self.module.functions[callee.idx()];
@@ -169,9 +123,9 @@ impl<'a> RTContext<'a> {
                 } else {
                     write!(w, ", ")?;
                 }
-                write!(w, "p{}: {}", idx, self.get_type(*ty))?;
+                write!(w, "p{}: {}", idx, self.device_get_type(*ty))?;
             }
-            write!(w, ") -> {};\n", self.get_type(callee.return_type))?;
+            write!(w, ") -> {};\n", self.device_get_type(callee.return_type))?;
         }
         write!(w, "    }}\n")?;
 
@@ -190,7 +144,7 @@ impl<'a> RTContext<'a> {
                 } else if self.module.types[self.typing[idx].idx()].is_float() {
                     "0.0"
                 } else {
-                    "::core::ptr::null::<u8>() as _"
+                    "unsafe { ::hercules_rt::HerculesBox::__null() }"
                 }
             )?;
         }
@@ -281,20 +235,7 @@ impl<'a> RTContext<'a> {
             }
             Node::Return { control: _, data } => {
                 let block = &mut blocks.get_mut(&id).unwrap();
-                let objects = self.collection_objects[&self.func_id].objects(data);
-                if objects.is_empty() {
-                    write!(block, "                return {};\n", self.get_value(data))?
-                } else {
-                    // If the value to return is a collection object, figure out
-                    // which object it actually is at runtime and return that
-                    // box.
-                    for object in objects {
-                        write!(block, "                if let Some(mut obj) = obj{} && ::std::boxed::Box::as_mut_ptr(&mut obj) as *mut u8 == {} {{\n", object.idx(), self.get_value(data))?;
-                        write!(block, "                    return obj;\n")?;
-                        write!(block, "                }}\n")?;
-                    }
-                    write!(block, "                panic!(\"HERCULES PANIC: Pointer to be returned doesn't match any known collection objects.\");\n")?
-                }
+                write!(block, "                return {};\n", self.get_value(data))?
             }
             _ => panic!("PANIC: Can't lower {:?}.", func.nodes[id.idx()]),
         }
@@ -313,12 +254,21 @@ impl<'a> RTContext<'a> {
         match func.nodes[id.idx()] {
             Node::Parameter { index } => {
                 let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap();
-                write!(
-                    block,
-                    "                {} = p{};\n",
-                    self.get_value(id),
-                    index
-                )?
+                if self.module.types[self.typing[id.idx()].idx()].is_primitive() {
+                    write!(
+                        block,
+                        "                {} = p{};\n",
+                        self.get_value(id),
+                        index
+                    )?
+                } else {
+                    write!(
+                        block,
+                        "                {} = unsafe {{ p{}.__take() }};\n",
+                        self.get_value(id),
+                        index
+                    )?
+                }
             }
             Node::Constant { id: cons_id } => {
                 let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap();
@@ -339,11 +289,7 @@ impl<'a> RTContext<'a> {
                         let objects = self.collection_objects[&self.func_id].objects(id);
                         assert_eq!(objects.len(), 1);
                         let object = objects[0];
-                        write!(
-                            block,
-                            "::std::boxed::Box::as_mut_ptr(obj{}.as_mut().unwrap()) as *mut u8",
-                            object.idx()
-                        )?
+                        write!(block, "unsafe {{ obj{}.__take() }}", object.idx())?
                     }
                 }
                 write!(block, ";\n")?
@@ -357,83 +303,86 @@ impl<'a> RTContext<'a> {
                 match self.devices[callee_id.idx()] {
                     Device::LLVM => {
                         let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap();
+
+                        // First, get the raw pointers to collections that the
+                        // device function takes as input.
+                        let callee_objs = &self.collection_objects[&callee_id];
+                        for (idx, arg) in args.into_iter().enumerate() {
+                            if let Some(obj) = callee_objs.param_to_object(idx) {
+                                // Extract a raw pointer from the HerculesBox.
+                                if callee_objs.is_mutated(obj) {
+                                    write!(
+                                        block,
+                                        "                let arg_tmp{} = unsafe {{ {}.__cpu_ptr_mut() }};\n",
+                                        idx,
+                                        self.get_value(*arg)
+                                    )?;
+                                } else {
+                                    write!(
+                                        block,
+                                        "                let arg_tmp{} = unsafe {{ {}.__cpu_ptr() }};\n",
+                                        idx,
+                                        self.get_value(*arg)
+                                    )?;
+                                }
+                            } else {
+                                write!(
+                                    block,
+                                    "                let arg_tmp{} = {};\n",
+                                    idx,
+                                    self.get_value(*arg)
+                                )?;
+                            }
+                        }
+
+                        // Emit the call.
                         write!(
                             block,
-                            "                {} = unsafe {{ {}(",
-                            self.get_value(id),
+                            "                let call_tmp = unsafe {{ {}(",
                             self.module.functions[callee_id.idx()].name
                         )?;
                         for dc in dynamic_constants {
                             self.codegen_dynamic_constant(*dc, block)?;
                             write!(block, ", ")?;
                         }
-                        for arg in args {
-                            write!(block, "{}, ", self.get_value(*arg))?;
+                        for idx in 0..args.len() {
+                            write!(block, "arg_tmp{}, ", idx)?;
                         }
                         write!(block, ") }};\n")?;
 
-                        // When a CPU function is called that returns a
+                        // When a device function is called that returns a
                         // collection object, that object must have come from
                         // one of its parameters. Dynamically figure out which
                         // one it came from, so that we can move it to the slot
                         // of the output object.
-                        let call_objects = self.collection_objects[&self.func_id].objects(id);
-                        if !call_objects.is_empty() {
-                            assert_eq!(call_objects.len(), 1);
-                            let call_object = call_objects[0];
-
-                            let callee_returned_objects =
-                                self.collection_objects[&callee_id].returned_objects();
-                            let possible_params: Vec<_> =
-                                (0..self.module.functions[callee_id.idx()].param_types.len())
-                                    .filter(|idx| {
-                                        let object_of_param = self.collection_objects[&callee_id]
-                                            .param_to_object(*idx);
-                                        // Look at parameters that could be the
-                                        // source of the memory object returned
-                                        // by the function.
-                                        object_of_param
-                                            .map(|object_of_param| {
-                                                callee_returned_objects.contains(&object_of_param)
-                                            })
-                                            .unwrap_or(false)
-                                    })
-                                    .collect();
-                            let arg_objects = args
-                                .into_iter()
-                                .enumerate()
-                                .filter(|(idx, _)| possible_params.contains(idx))
-                                .map(|(_, arg)| {
-                                    self.collection_objects[&self.func_id]
-                                        .objects(*arg)
-                                        .into_iter()
-                                })
-                                .flatten();
-
-                            // Dynamically check which of the memory objects
-                            // corresponding to arguments to the call was
-                            // returned by the call. Move that memory object
-                            // into the memory object of the call.
-                            let mut first_obj = true;
-                            for arg_object in arg_objects {
-                                write!(block, "                ")?;
-                                if first_obj {
-                                    first_obj = false;
-                                } else {
-                                    write!(block, "else ")?;
+                        let caller_objects = self.collection_objects[&self.func_id].objects(id);
+                        if !caller_objects.is_empty() {
+                            for (idx, arg) in args.into_iter().enumerate() {
+                                if idx != 0 {
+                                    write!(block, "                else\n")?;
                                 }
-                                write!(block, "if let Some(obj) = obj{}.as_mut() && ::std::boxed::Box::as_mut_ptr(obj) as *mut u8 == {} {{\n", arg_object.idx(), self.get_value(id))?;
                                 write!(
                                     block,
-                                    "                    obj{} = obj{}.take();\n",
-                                    call_object.idx(),
-                                    arg_object.idx()
+                                    "                if call_tmp == arg_tmp{} {{\n",
+                                    idx
+                                )?;
+                                write!(
+                                    block,
+                                    "                    {} = unsafe {{ {}.__take() }};\n",
+                                    self.get_value(id),
+                                    self.get_value(*arg)
                                 )?;
-                                write!(block, "                }}\n")?;
+                                write!(block, "                }}")?;
                             }
                             write!(block, "                else {{\n")?;
-                            write!(block, "                    panic!(\"HERCULES PANIC: Pointer returned from called function doesn't match any known collection objects.\");\n")?;
+                            write!(block, "                    panic!(\"HERCULES PANIC: Pointer returned from device function doesn't match an argument pointer.\");\n")?;
                             write!(block, "                }}\n")?;
+                        } else {
+                            write!(
+                                block,
+                                "                {} = call_tmp;\n",
+                                self.get_value(id)
+                            )?;
                         }
                     }
                     Device::AsyncRust => {
@@ -452,7 +401,7 @@ impl<'a> RTContext<'a> {
                             if self.module.types[self.typing[arg.idx()].idx()].is_primitive() {
                                 write!(block, "{}, ", self.get_value(*arg))?;
                             } else {
-                                write!(block, "{}.take(), ", self.get_value(*arg))?;
+                                write!(block, "unsafe {{ {}.__take() }}, ", self.get_value(*arg))?;
                             }
                         }
                         write!(block, ").await;\n")?;
@@ -603,8 +552,8 @@ impl<'a> RTContext<'a> {
         convert_type(&self.module.types[id.idx()])
     }
 
-    fn get_type_interface(&self, id: TypeID) -> &'static str {
-        convert_type_interface(&self.module.types[id.idx()])
+    fn device_get_type(&self, id: TypeID) -> &'static str {
+        device_convert_type(&self.module.types[id.idx()])
     }
 }
 
@@ -621,18 +570,27 @@ fn convert_type(ty: &Type) -> &'static str {
         Type::UnsignedInteger64 => "u64",
         Type::Float32 => "f32",
         Type::Float64 => "f64",
-        Type::Product(_) | Type::Summation(_) | Type::Array(_, _) => "*mut u8",
+        Type::Product(_) | Type::Summation(_) | Type::Array(_, _) => {
+            "::hercules_rt::HerculesBox<'a>"
+        }
         _ => panic!(),
     }
 }
 
-/*
- * Collection types are passed to / returned from runtime functions through a
- * wrapper type for ownership tracking reasons.
- */
-fn convert_type_interface(ty: &Type) -> &'static str {
+fn device_convert_type(ty: &Type) -> &'static str {
     match ty {
-        Type::Product(_) | Type::Summation(_) | Type::Array(_, _) => "Box<[u8]>",
-        _ => convert_type(ty),
+        Type::Boolean => "bool",
+        Type::Integer8 => "i8",
+        Type::Integer16 => "i16",
+        Type::Integer32 => "i32",
+        Type::Integer64 => "i64",
+        Type::UnsignedInteger8 => "u8",
+        Type::UnsignedInteger16 => "u16",
+        Type::UnsignedInteger32 => "u32",
+        Type::UnsignedInteger64 => "u64",
+        Type::Float32 => "f32",
+        Type::Float64 => "f64",
+        Type::Product(_) | Type::Summation(_) | Type::Array(_, _) => "*mut u8",
+        _ => panic!(),
     }
 }
diff --git a/hercules_rt/Cargo.toml b/hercules_rt/Cargo.toml
new file mode 100644
index 00000000..0bf19adf
--- /dev/null
+++ b/hercules_rt/Cargo.toml
@@ -0,0 +1,7 @@
+[package]
+name = "hercules_rt"
+version = "0.1.0"
+authors = ["Russel Arbore <rarbore2@illinois.edu>"]
+
+[dependencies]
+
diff --git a/hercules_rt/src/lib.rs b/hercules_rt/src/lib.rs
new file mode 100644
index 00000000..50ac260c
--- /dev/null
+++ b/hercules_rt/src/lib.rs
@@ -0,0 +1,116 @@
+use std::alloc::{alloc, alloc_zeroed, dealloc, Layout};
+use std::marker::PhantomData;
+use std::mem::swap;
+use std::ptr::{copy_nonoverlapping, NonNull};
+use std::slice::from_raw_parts;
+
+/*
+ * An in-memory collection object that can be used by functions compiled by the
+ * Hercules compiler.
+ */
+pub struct HerculesBox<'a> {
+    cpu_shared: Option<NonNull<u8>>,
+    cpu_exclusive: Option<NonNull<u8>>,
+    cpu_owned: Option<NonNull<u8>>,
+
+    size: usize,
+    _phantom: PhantomData<&'a u8>,
+}
+
+impl<'a> HerculesBox<'a> {
+    pub fn from_slice<T>(slice: &'a [T]) -> Self {
+        HerculesBox {
+            cpu_shared: Some(unsafe { NonNull::new_unchecked(slice.as_ptr() as *mut u8) }),
+            cpu_exclusive: None,
+            cpu_owned: None,
+            size: slice.len() * size_of::<T>(),
+            _phantom: PhantomData,
+        }
+    }
+
+    pub fn from_slice_mut<T>(slice: &'a mut [T]) -> Self {
+        HerculesBox {
+            cpu_shared: None,
+            cpu_exclusive: Some(unsafe { NonNull::new_unchecked(slice.as_mut_ptr() as *mut u8) }),
+            cpu_owned: None,
+            size: slice.len() * size_of::<T>(),
+            _phantom: PhantomData,
+        }
+    }
+
+    pub fn as_slice<T>(&'a self) -> &'a [T] {
+        assert_eq!(self.size % size_of::<T>(), 0);
+        unsafe { from_raw_parts(self.__cpu_ptr() as *const T, self.size / size_of::<T>()) }
+    }
+
+    unsafe fn into_cpu(&self) -> NonNull<u8> {
+        self.cpu_shared
+            .or(self.cpu_exclusive)
+            .or(self.cpu_owned)
+            .unwrap()
+    }
+
+    unsafe fn into_cpu_mut(&mut self) -> NonNull<u8> {
+        if let Some(ptr) = self.cpu_exclusive.or(self.cpu_owned) {
+            ptr
+        } else {
+            let ptr =
+                NonNull::new(alloc(Layout::from_size_align_unchecked(self.size, 16))).unwrap();
+            copy_nonoverlapping(self.cpu_shared.unwrap().as_ptr(), ptr.as_ptr(), self.size);
+            self.cpu_owned = Some(ptr);
+            self.cpu_shared = None;
+            ptr
+        }
+    }
+
+    pub unsafe fn __zeros(size: u64) -> Self {
+        assert_ne!(size, 0);
+        let size = size as usize;
+        HerculesBox {
+            cpu_shared: None,
+            cpu_exclusive: None,
+            cpu_owned: Some(
+                NonNull::new(alloc_zeroed(Layout::from_size_align_unchecked(size, 16))).unwrap(),
+            ),
+            size: size,
+            _phantom: PhantomData,
+        }
+    }
+
+    pub unsafe fn __null() -> Self {
+        HerculesBox {
+            cpu_shared: None,
+            cpu_exclusive: None,
+            cpu_owned: None,
+            size: 0,
+            _phantom: PhantomData,
+        }
+    }
+
+    pub unsafe fn __take(&mut self) -> Self {
+        let mut ret = Self::__null();
+        swap(&mut ret, self);
+        ret
+    }
+
+    pub unsafe fn __cpu_ptr(&self) -> *mut u8 {
+        self.into_cpu().as_ptr()
+    }
+
+    pub unsafe fn __cpu_ptr_mut(&mut self) -> *mut u8 {
+        self.into_cpu_mut().as_ptr()
+    }
+}
+
+impl<'a> Drop for HerculesBox<'a> {
+    fn drop(&mut self) {
+        if let Some(ptr) = self.cpu_owned {
+            unsafe {
+                dealloc(
+                    ptr.as_ptr(),
+                    Layout::from_size_align_unchecked(self.size, 16),
+                )
+            }
+        }
+    }
+}
diff --git a/hercules_samples/dot/Cargo.toml b/hercules_samples/dot/Cargo.toml
index f74ab1f6..69cd39e3 100644
--- a/hercules_samples/dot/Cargo.toml
+++ b/hercules_samples/dot/Cargo.toml
@@ -10,6 +10,7 @@ juno_build = { path = "../../juno_build" }
 [dependencies]
 clap = { version = "*", features = ["derive"] }
 juno_build = { path = "../../juno_build" }
+hercules_rt = { path = "../../hercules_rt" }
 rand = "*"
 async-std = "*"
 with_builtin_macros = "0.1.0"
diff --git a/hercules_samples/dot/src/main.rs b/hercules_samples/dot/src/main.rs
index 0f5ee518..34d397ef 100644
--- a/hercules_samples/dot/src/main.rs
+++ b/hercules_samples/dot/src/main.rs
@@ -1,31 +1,20 @@
 #![feature(box_as_ptr, let_chains)]
 
 extern crate async_std;
+extern crate hercules_rt;
 extern crate juno_build;
 
-use core::ptr::copy_nonoverlapping;
+use hercules_rt::HerculesBox;
 
 juno_build::juno!("dot");
 
 fn main() {
     async_std::task::block_on(async {
-        let a: Box<[f32]> = Box::new([0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0]);
-        let b: Box<[f32]> = Box::new([0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0]);
-        let mut a_bytes: Box<[u8]> = Box::new([0; 32]);
-        let mut b_bytes: Box<[u8]> = Box::new([0; 32]);
-        unsafe {
-            copy_nonoverlapping(
-                Box::as_ptr(&a) as *const u8,
-                Box::as_mut_ptr(&mut a_bytes) as *mut u8,
-                32,
-            );
-            copy_nonoverlapping(
-                Box::as_ptr(&b) as *const u8,
-                Box::as_mut_ptr(&mut b_bytes) as *mut u8,
-                32,
-            );
-        };
-        let c = dot(8, a_bytes, b_bytes).await;
+        let a: [f32; 8] = [0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0];
+        let b: [f32; 8] = [0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0];
+        let a = HerculesBox::from_slice(&a);
+        let b = HerculesBox::from_slice(&b);
+        let c = dot(8, a, b).await;
         println!("{}", c);
         assert_eq!(c, 70.0);
     });
diff --git a/hercules_samples/matmul/Cargo.toml b/hercules_samples/matmul/Cargo.toml
index d3975c5c..9066c153 100644
--- a/hercules_samples/matmul/Cargo.toml
+++ b/hercules_samples/matmul/Cargo.toml
@@ -10,6 +10,7 @@ juno_build = { path = "../../juno_build" }
 [dependencies]
 clap = { version = "*", features = ["derive"] }
 juno_build = { path = "../../juno_build" }
+hercules_rt = { path = "../../hercules_rt" }
 rand = "*"
 async-std = "*"
 with_builtin_macros = "0.1.0"
diff --git a/hercules_samples/matmul/src/main.rs b/hercules_samples/matmul/src/main.rs
index 93d007c7..34612801 100644
--- a/hercules_samples/matmul/src/main.rs
+++ b/hercules_samples/matmul/src/main.rs
@@ -1,13 +1,14 @@
 #![feature(box_as_ptr, let_chains)]
 
 extern crate async_std;
+extern crate hercules_rt;
 extern crate juno_build;
 extern crate rand;
 
-use core::ptr::copy_nonoverlapping;
-
 use rand::random;
 
+use hercules_rt::HerculesBox;
+
 juno_build::juno!("matmul");
 
 fn main() {
@@ -15,31 +16,8 @@ fn main() {
         const I: usize = 256;
         const J: usize = 64;
         const K: usize = 128;
-        let a: Box<[i32]> = (0..I * J).map(|_| random::<i32>() % 100).collect();
-        let b: Box<[i32]> = (0..J * K).map(|_| random::<i32>() % 100).collect();
-        let mut a_bytes: Box<[u8]> = Box::new([0; I * J * 4]);
-        let mut b_bytes: Box<[u8]> = Box::new([0; J * K * 4]);
-        unsafe {
-            copy_nonoverlapping(
-                Box::as_ptr(&a) as *const u8,
-                Box::as_mut_ptr(&mut a_bytes) as *mut u8,
-                I * J * 4,
-            );
-            copy_nonoverlapping(
-                Box::as_ptr(&b) as *const u8,
-                Box::as_mut_ptr(&mut b_bytes) as *mut u8,
-                J * K * 4,
-            );
-        };
-        let c_bytes = matmul(I as u64, J as u64, K as u64, a_bytes, b_bytes).await;
-        let mut c: Box<[i32]> = (0..I * K).map(|_| 0).collect();
-        unsafe {
-            copy_nonoverlapping(
-                Box::as_ptr(&c_bytes) as *const u8,
-                Box::as_mut_ptr(&mut c) as *mut u8,
-                I * K * 4,
-            );
-        };
+        let mut a: Box<[i32]> = (0..I * J).map(|_| random::<i32>() % 100).collect();
+        let mut b: Box<[i32]> = (0..J * K).map(|_| random::<i32>() % 100).collect();
         let mut correct_c: Box<[i32]> = (0..I * K).map(|_| 0).collect();
         for i in 0..I {
             for k in 0..K {
@@ -48,7 +26,10 @@ fn main() {
                 }
             }
         }
-        assert_eq!(c, correct_c);
+        let a = HerculesBox::from_slice_mut(&mut a);
+        let b = HerculesBox::from_slice_mut(&mut b);
+        let c = matmul(I as u64, J as u64, K as u64, a, b).await;
+        assert_eq!(c.as_slice::<i32>(), &*correct_c);
     });
 }
 
diff --git a/juno_samples/antideps/Cargo.toml b/juno_samples/antideps/Cargo.toml
index 40b4d47c..9bd1d5a0 100644
--- a/juno_samples/antideps/Cargo.toml
+++ b/juno_samples/antideps/Cargo.toml
@@ -13,5 +13,6 @@ juno_build = { path = "../../juno_build" }
 
 [dependencies]
 juno_build = { path = "../../juno_build" }
+hercules_rt = { path = "../../hercules_rt" }
 with_builtin_macros = "0.1.0"
 async-std = "*"
diff --git a/juno_samples/implicit_clone/Cargo.toml b/juno_samples/implicit_clone/Cargo.toml
index 928fa1f2..b312f5de 100644
--- a/juno_samples/implicit_clone/Cargo.toml
+++ b/juno_samples/implicit_clone/Cargo.toml
@@ -13,5 +13,6 @@ juno_build = { path = "../../juno_build" }
 
 [dependencies]
 juno_build = { path = "../../juno_build" }
+hercules_rt = { path = "../../hercules_rt" }
 with_builtin_macros = "0.1.0"
 async-std = "*"
diff --git a/juno_samples/matmul/Cargo.toml b/juno_samples/matmul/Cargo.toml
index ea705ddd..8ad95853 100644
--- a/juno_samples/matmul/Cargo.toml
+++ b/juno_samples/matmul/Cargo.toml
@@ -13,6 +13,7 @@ juno_build = { path = "../../juno_build" }
 
 [dependencies]
 juno_build = { path = "../../juno_build" }
+hercules_rt = { path = "../../hercules_rt" }
 with_builtin_macros = "0.1.0"
 async-std = "*"
 rand = "*"
diff --git a/juno_samples/matmul/src/main.rs b/juno_samples/matmul/src/main.rs
index 1c5b9d42..11066e8b 100644
--- a/juno_samples/matmul/src/main.rs
+++ b/juno_samples/matmul/src/main.rs
@@ -1,13 +1,14 @@
-#![feature(future_join, box_as_ptr, let_chains)]
+#![feature(box_as_ptr, let_chains)]
 
 extern crate async_std;
+extern crate hercules_rt;
 extern crate juno_build;
 extern crate rand;
 
-use core::ptr::copy_nonoverlapping;
-
 use rand::random;
 
+use hercules_rt::HerculesBox;
+
 juno_build::juno!("matmul");
 
 fn main() {
@@ -17,45 +18,6 @@ fn main() {
         const K: usize = 128;
         let a: Box<[i32]> = (0..I * J).map(|_| random::<i32>() % 100).collect();
         let b: Box<[i32]> = (0..J * K).map(|_| random::<i32>() % 100).collect();
-        let mut a_bytes: Box<[u8]> = Box::new([0; I * J * 4]);
-        let mut b_bytes: Box<[u8]> = Box::new([0; J * K * 4]);
-        unsafe {
-            copy_nonoverlapping(
-                Box::as_ptr(&a) as *const u8,
-                Box::as_mut_ptr(&mut a_bytes) as *mut u8,
-                I * J * 4,
-            );
-            copy_nonoverlapping(
-                Box::as_ptr(&b) as *const u8,
-                Box::as_mut_ptr(&mut b_bytes) as *mut u8,
-                J * K * 4,
-            );
-        };
-        let c_bytes = matmul(
-            I as u64,
-            J as u64,
-            K as u64,
-            a_bytes.clone(),
-            b_bytes.clone(),
-        )
-        .await;
-        let mut c: Box<[i32]> = (0..I * K).map(|_| 0).collect();
-        unsafe {
-            copy_nonoverlapping(
-                Box::as_ptr(&c_bytes) as *const u8,
-                Box::as_mut_ptr(&mut c) as *mut u8,
-                I * K * 4,
-            );
-        };
-        let tiled_c_bytes = tiled_64_matmul(I as u64, J as u64, K as u64, a_bytes, b_bytes).await;
-        let mut tiled_c: Box<[i32]> = (0..I * K).map(|_| 0).collect();
-        unsafe {
-            copy_nonoverlapping(
-                Box::as_ptr(&tiled_c_bytes) as *const u8,
-                Box::as_mut_ptr(&mut tiled_c) as *mut u8,
-                I * K * 4,
-            );
-        };
         let mut correct_c: Box<[i32]> = (0..I * K).map(|_| 0).collect();
         for i in 0..I {
             for k in 0..K {
@@ -64,8 +26,18 @@ fn main() {
                 }
             }
         }
-        assert_eq!(c, correct_c);
-        assert_eq!(tiled_c, correct_c);
+        let c = {
+            let a = HerculesBox::from_slice(&a);
+            let b = HerculesBox::from_slice(&b);
+            matmul(I as u64, J as u64, K as u64, a, b).await
+        };
+        let tiled_c = {
+            let a = HerculesBox::from_slice(&a);
+            let b = HerculesBox::from_slice(&b);
+            tiled_64_matmul(I as u64, J as u64, K as u64, a, b).await
+        };
+        assert_eq!(c.as_slice::<i32>(), &*correct_c);
+        assert_eq!(tiled_c.as_slice::<i32>(), &*correct_c);
     });
 }
 
@@ -73,3 +45,4 @@ fn main() {
 fn matmul_test() {
     main();
 }
+
diff --git a/juno_samples/nested_ccp/Cargo.toml b/juno_samples/nested_ccp/Cargo.toml
index 7ffc13f2..8c9b969d 100644
--- a/juno_samples/nested_ccp/Cargo.toml
+++ b/juno_samples/nested_ccp/Cargo.toml
@@ -13,5 +13,6 @@ juno_build = { path = "../../juno_build" }
 
 [dependencies]
 juno_build = { path = "../../juno_build" }
+hercules_rt = { path = "../../hercules_rt" }
 with_builtin_macros = "0.1.0"
 async-std = "*"
diff --git a/juno_samples/nested_ccp/src/main.rs b/juno_samples/nested_ccp/src/main.rs
index 83132aca..11561eb8 100644
--- a/juno_samples/nested_ccp/src/main.rs
+++ b/juno_samples/nested_ccp/src/main.rs
@@ -1,32 +1,21 @@
 #![feature(box_as_ptr, let_chains)]
 
 extern crate async_std;
+extern crate hercules_rt;
 extern crate juno_build;
 
-use core::ptr::copy_nonoverlapping;
+use hercules_rt::HerculesBox;
 
 juno_build::juno!("nested_ccp");
 
 fn main() {
     async_std::task::block_on(async {
-        let a: Box<[f32]> = Box::new([17.0, 18.0, 19.0]);
-        let b: Box<[i32]> = Box::new([12, 16, 4, 18, 23, 56, 93, 22, 14]);
-        let mut a_bytes: Box<[u8]> = Box::new([0; 12]);
-        let mut b_bytes: Box<[u8]> = Box::new([0; 36]);
-        unsafe {
-            copy_nonoverlapping(
-                Box::as_ptr(&a) as *const u8,
-                Box::as_mut_ptr(&mut a_bytes) as *mut u8,
-                12,
-            );
-            copy_nonoverlapping(
-                Box::as_ptr(&b) as *const u8,
-                Box::as_mut_ptr(&mut b_bytes) as *mut u8,
-                36,
-            );
-        };
-        let output_example = ccp_example(a_bytes).await;
-        let output_median = median_array(9, b_bytes).await;
+        let mut a: Box<[f32]> = Box::new([17.0, 18.0, 19.0]);
+        let mut b: Box<[i32]> = Box::new([12, 16, 4, 18, 23, 56, 93, 22, 14]);
+        let a = HerculesBox::from_slice_mut(&mut a);
+        let b = HerculesBox::from_slice_mut(&mut b);
+        let output_example = ccp_example(a).await;
+        let output_median = median_array(9, b).await;
         println!("{}", output_example);
         println!("{}", output_median);
         assert_eq!(output_example, 1.0);
diff --git a/juno_samples/simple3/Cargo.toml b/juno_samples/simple3/Cargo.toml
index 201c8d37..8060c5b3 100644
--- a/juno_samples/simple3/Cargo.toml
+++ b/juno_samples/simple3/Cargo.toml
@@ -13,5 +13,6 @@ juno_build = { path = "../../juno_build" }
 
 [dependencies]
 juno_build = { path = "../../juno_build" }
+hercules_rt = { path = "../../hercules_rt" }
 with_builtin_macros = "0.1.0"
 async-std = "*"
diff --git a/juno_samples/simple3/src/main.rs b/juno_samples/simple3/src/main.rs
index 89be5527..8ca54344 100644
--- a/juno_samples/simple3/src/main.rs
+++ b/juno_samples/simple3/src/main.rs
@@ -1,31 +1,20 @@
 #![feature(box_as_ptr, let_chains)]
 
 extern crate async_std;
+extern crate hercules_rt;
 extern crate juno_build;
 
-use core::ptr::copy_nonoverlapping;
+use hercules_rt::HerculesBox;
 
 juno_build::juno!("simple3");
 
 fn main() {
     async_std::task::block_on(async {
-        let a: Box<[u32]> = Box::new([1, 2, 3, 4, 5, 6, 7, 8]);
-        let b: Box<[u32]> = Box::new([8, 7, 6, 5, 4, 3, 2, 1]);
-        let mut a_bytes: Box<[u8]> = Box::new([0; 32]);
-        let mut b_bytes: Box<[u8]> = Box::new([0; 32]);
-        unsafe {
-            copy_nonoverlapping(
-                Box::as_ptr(&a) as *const u8,
-                Box::as_mut_ptr(&mut a_bytes) as *mut u8,
-                32,
-            );
-            copy_nonoverlapping(
-                Box::as_ptr(&b) as *const u8,
-                Box::as_mut_ptr(&mut b_bytes) as *mut u8,
-                32,
-            );
-        };
-        let c = simple3(8, a_bytes, b_bytes).await;
+        let mut a: Box<[u32]> = Box::new([1, 2, 3, 4, 5, 6, 7, 8]);
+        let mut b: Box<[u32]> = Box::new([8, 7, 6, 5, 4, 3, 2, 1]);
+        let a = HerculesBox::from_slice_mut(&mut a);
+        let b = HerculesBox::from_slice_mut(&mut b);
+        let c = simple3(8, a, b).await;
         println!("{}", c);
         assert_eq!(c, 120);
     });
-- 
GitLab