From 0973399d45e4a85e4fd0653f61f6a381e13e122f Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sat, 25 Jan 2025 15:05:05 -0600
Subject: [PATCH 01/24] Plan

---
 hercules_cg/src/rt.rs  | 61 ++++++++++++++++++++++++++++++++++++++++++
 hercules_rt/src/lib.rs |  4 +++
 2 files changed, 65 insertions(+)

diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs
index d093b2b0..a80e6b07 100644
--- a/hercules_cg/src/rt.rs
+++ b/hercules_cg/src/rt.rs
@@ -10,6 +10,67 @@ use crate::*;
  * Entry Hercules functions are lowered to async Rust code to achieve easy task
  * level parallelism. This Rust is generated textually, and is included via a
  * procedural macro in the user's Rust code.
+ *
+ * Generating Rust that properly handles memory across devices is tricky. In
+ * particular, the following elements are challenges:
+ *
+ * 1. RT functions may return objects that were not first passed in as
+ *    parameters, via object constants.
+ * 2. We want to allocate as much memory upfront as possible. Our goal is that
+ *    for each call to a RT function from host Rust code, there is at most one
+ *    memory allocation per device.
+ * 3. We want to statically determine when cross-device communication is
+ *    necessary - this should be a separate concern from allocation.
+ * 4. At the boundary between host Rust / RT functions, we want to encode
+ *    lifetime rules so that the Hercules API is memory safe.
+ * 5. We want to support efficient composition of Hercules RT functions in both
+ *    synchronous and asynchronous contexts.
+ *
+ * Challenges #1 and #2 require that RT functions themselves do not allocate
+ * memory. Instead, for every entry point function, a "runner" type will be
+ * generated. The host Rust code must instantiate a runner object to invoke an
+ * entry point function. This runner object contains a "backing" memory that is
+ * the single allocation of memories for this function. The runner object can be
+ * used to call the same entry point multiple times, and to the extent possible
+ * the backing memory will be re-used. The size of the backing memory depends on
+ * the dynamic constants passed in to the entry point, so it's lazily allocated
+ * to the needed size on calls to the entry point. To address challenge #4, any
+ * returned objects will be lifetime-bound to the runner object instance,
+ * ensuring that the reference cannot be used after the runner object has de-
+ * allocated the backing memory. This also ensures that the runner can't be run
+ * again while a returned object from a previous iteration is still live, since
+ * the entry point method requires an exclusive reference to the runner.
+ *
+ * Addressing challenge #3 requires we determine what objects are live on what
+ * devices at what times. This can be done fairly easily by coloring nodes by
+ * what device they produce their result on and inserting inter-device transfers
+ * along edges connecting nodes of different colors.
+ *
+ * Addressing challenge #5 requires runner objects for entry points accept and
+ * return objects that are not in their own backing memory and potentially on
+ * any device. For this reason, parameter and return nodes are not necessarily
+ * CPU colored. Instead, runners take and return Hercules reference objects that
+ * refer to memory of unknown origin on some device. Hercules reference
+ * objects have a lifetime parameter, and when a runner may return a Hercules
+ * reference that refers to its backing memory, the lifetime of the Hercules
+ * reference is the same as the lifetime of the mutable reference of the runner
+ * used in the entry point signature. In other words, the RT backend infers the
+ * proper lifetime bounds on parameter and returned Hercules reference objects
+ * in relation to the runner's self reference using the collection objects
+ * analysis. There are the following kinds of Hercules reference objects:
+ *
+ * - HerculesCPURef
+ * - HerculesCPURefMut
+ * - HerculesCUDARef
+ * - HerculesCUDARefMut
+ *
+ * Essentially, there are types for each device, one for immutable refernences
+ * and one for exclusive references. Mutable references can decay into immutable
+ * references, and immutable references can be cloned. The CPU reference types
+ * can be created from normal Rust references. The CUDA reference types can't be
+ * created from normal Rust references - for that purpose, an additional type is
+ * given, CUDABox, which essentially allows the user to manually allocate and
+ * set some CUDA memory - the user can then take a CUDA reference to that box.
  */
 pub fn rt_codegen<W: Write>(
     func_id: FunctionID,
diff --git a/hercules_rt/src/lib.rs b/hercules_rt/src/lib.rs
index 60d3470e..ad989048 100644
--- a/hercules_rt/src/lib.rs
+++ b/hercules_rt/src/lib.rs
@@ -4,6 +4,10 @@ use std::ptr::{copy_nonoverlapping, NonNull};
 use std::slice::from_raw_parts;
 use std::sync::atomic::{AtomicUsize, Ordering};
 
+/*
+ * Define supporting types, functions, and macros for Hercules RT functions.
+ */
+
 #[cfg(feature = "cuda")]
 extern "C" {
     fn cuda_alloc(size: usize) -> *mut u8;
-- 
GitLab
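
To make the runner contract described in the comment above concrete, here is a
minimal hand-written sketch. The ExampleRunner type, its run signature, and the
Vec<f32> backing store are illustrative assumptions, not the code the RT
backend actually generates; the point is only the lazily sized backing
allocation and the lifetime binding of returned references to the runner.

    // Illustrative stand-in for a generated runner; names are hypothetical.
    struct ExampleRunner {
        backing: Vec<f32>, // stand-in for the single lazily-sized backing allocation
    }

    impl ExampleRunner {
        fn new() -> Self {
            ExampleRunner { backing: Vec::new() }
        }

        // "Entry point": sized lazily from the dynamic constant `n`. The
        // returned slice borrows from `&mut self`, so it cannot outlive the
        // runner, and the runner cannot be called again while a previous
        // result is still live.
        fn run<'a>(&'a mut self, n: usize, input: &[f32]) -> &'a [f32] {
            if self.backing.len() < n {
                self.backing.resize(n, 0.0);
            }
            for i in 0..n {
                self.backing[i] = input[i % input.len()] * 2.0;
            }
            &self.backing[..n]
        }
    }

    fn demo() {
        let mut runner = ExampleRunner::new();
        let out = runner.run(3, &[1.0, 2.0, 3.0]);
        assert_eq!(out, &[2.0, 4.0, 6.0]);
        // A second `runner.run(...)` here while `out` is live would be
        // rejected by the borrow checker, which is the memory-safety property
        // described above.
    }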


From a992d07d38617885efa3b5d5c9b211f5a821ec65 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sat, 25 Jan 2025 16:03:18 -0600
Subject: [PATCH 02/24] Re-do hercules_rt

---
 hercules_cg/src/rt.rs     |   7 +-
 hercules_rt/src/lib.rs    | 390 ++++++++++++--------------------------
 hercules_rt/src/rtdefs.cu |  23 +--
 3 files changed, 135 insertions(+), 285 deletions(-)

diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs
index a80e6b07..5e40882e 100644
--- a/hercules_cg/src/rt.rs
+++ b/hercules_cg/src/rt.rs
@@ -44,7 +44,12 @@ use crate::*;
  * Addressing challenge #3 requires we determine what objects are live on what
  * devices at what times. This can be done fairly easily by coloring nodes by
  * what device they produce their result on and inserting inter-device transfers
- * along edges connecting nodes of different colors.
+ * along edges connecting nodes of different colors. Nodes may have multiple
+ * colors if the corresponding value is available on multiple devices. For
+ * example, if a call to a GPU function requires read-only access to an array
+ * and that array was originally on the CPU, after a transfer to the GPU, the
+ * array is still available on the CPU, since the GPU function doesn't modify
+ * the value.
  *
  * Addressing challenge #5 requires runner objects for entry points accept and
  * return objects that are not in their own backing memory and potentially on
diff --git a/hercules_rt/src/lib.rs b/hercules_rt/src/lib.rs
index ad989048..759bed0d 100644
--- a/hercules_rt/src/lib.rs
+++ b/hercules_rt/src/lib.rs
@@ -1,341 +1,197 @@
-use std::alloc::{alloc, alloc_zeroed, dealloc, Layout};
+use std::alloc::{alloc, dealloc, Layout};
 use std::marker::PhantomData;
 use std::ptr::{copy_nonoverlapping, NonNull};
-use std::slice::from_raw_parts;
-use std::sync::atomic::{AtomicUsize, Ordering};
+use std::slice::{from_raw_parts, from_raw_parts_mut};
 
 /*
- * Define supporting types, functions, and macros for Hercules RT functions.
+ * Define supporting types, functions, and macros for Hercules RT functions. For
+ * a more in-depth discussion of the design of these utilities, see hercules_cg/
+ * src/rt.rs (the RT backend).
  */
 
-#[cfg(feature = "cuda")]
-extern "C" {
-    fn cuda_alloc(size: usize) -> *mut u8;
-    fn cuda_alloc_zeroed(size: usize) -> *mut u8;
-    fn cuda_dealloc(ptr: *mut u8);
-    fn copy_cpu_to_cuda(dst: *mut u8, src: *mut u8, size: usize);
-    fn copy_cuda_to_cpu(dst: *mut u8, src: *mut u8, size: usize);
-    fn copy_cuda_to_cuda(dst: *mut u8, src: *mut u8, size: usize);
+pub unsafe fn __cpu_alloc(size: usize) -> *mut u8 {
+    alloc(Layout::from_size_align(size, 16).unwrap())
 }
 
-/*
- * Each object needs to get assigned a unique ID.
- */
-static NUM_OBJECTS: AtomicUsize = AtomicUsize::new(1);
+pub unsafe fn __cpu_dealloc(ptr: *mut u8, size: usize) {
+    dealloc(ptr, Layout::from_size_align(size, 16).unwrap())
+}
 
-/*
- * An in-memory collection object that can be used by functions compiled by the
- * Hercules compiler. Memory objects can be in these states:
- *
- * 1. Shared CPU - the object has a shared reference to some CPU memory, usually
- *    from the programmer using the Hercules RT API.
- * 2. Exclusive CPU - the object has an exclusive reference to some CPU memory,
- *    usually from the programmer using the Hercules RT API.
- * 3. Owned CPU - the object owns some allocated CPU memory.
- * 4. Owned GPU - the object owns some allocated GPU memory.
- *
- * A single object can be in some combination of these objects at the same time.
- * Only some combinations are valid, because only some combinations are
- * reachable. Under this assumption, we can model an object's placement as a
- * state machine, where states are combinations of the aforementioned states,
- * and actions are requests on the CPU or GPU, immutably or mutably. Here's the
- * state transition table:
- *
- * Shared CPU = CS
- * Exclusive CPU = CE
- * Owned CPU = CO
- * Owned GPU = GO
- *
- *          CPU     Mut CPU     GPU     Mut GPU
- *       *---------------------------------------
- *  CS   |   CS        CO      CS,GO       GO
- *  CE   |   CE        CE      CE,GO       GO
- *  CO   |   CO        CO      CO,GO       GO
- *  GO   |   CO        CO        GO        GO
- * CS,GO | CS,GO       CO      CS,GO       GO
- * CE,GO | CE,GO       CE      CE,GO       GO
- * CO,GO | CO,GO       CO      CO,GO       GO
- *       |
- *
- * A HerculesBox cannot be cloned, because it may have be a mutable reference to
- * some CPU memory.
- */
-#[derive(Debug)]
-pub struct HerculesBox<'a> {
-    cpu_shared: Option<NonOwned<'a>>,
-    cpu_exclusive: Option<NonOwned<'a>>,
-    cpu_owned: Option<Owned>,
+pub unsafe fn __copy_cpu_to_cpu(dst: *mut u8, src: *mut u8, size: usize) {
+    copy_nonoverlapping(src, dst, size);
+}
 
-    #[cfg(feature = "cuda")]
-    cuda_owned: Option<Owned>,
+#[cfg(feature = "cuda")]
+extern "C" {
+    pub fn __cuda_alloc(size: usize) -> *mut u8;
+    pub fn __cuda_dealloc(ptr: *mut u8, size: usize);
+    pub fn __copy_cpu_to_cuda(dst: *mut u8, src: *mut u8, size: usize);
+    pub fn __copy_cuda_to_cpu(dst: *mut u8, src: *mut u8, size: usize);
+    pub fn __copy_cuda_to_cuda(dst: *mut u8, src: *mut u8, size: usize);
+}
 
+#[derive(Clone, Debug)]
+pub struct HerculesCPURef<'a> {
+    ptr: NonNull<u8>,
     size: usize,
-    id: usize,
+    _phantom: PhantomData<&'a u8>,
 }
 
-#[derive(Clone, Debug)]
-struct NonOwned<'a> {
+#[derive(Debug)]
+pub struct HerculesCPURefMut<'a> {
     ptr: NonNull<u8>,
-    offset: usize,
+    size: usize,
     _phantom: PhantomData<&'a u8>,
 }
 
+#[cfg(feature = "cuda")]
 #[derive(Clone, Debug)]
-struct Owned {
+pub struct HerculesCUDARef<'a> {
     ptr: NonNull<u8>,
-    alloc_size: usize,
-    offset: usize,
+    size: usize,
+    _phantom: PhantomData<&'a u8>,
 }
 
-impl<'b, 'a: 'b> HerculesBox<'a> {
-    pub fn from_slice<T>(slice: &'a [T]) -> Self {
-        let ptr = unsafe { NonNull::new_unchecked(slice.as_ptr() as *mut u8) };
-        let size = slice.len() * size_of::<T>();
-        let id = NUM_OBJECTS.fetch_add(1, Ordering::Relaxed);
-        HerculesBox {
-            cpu_shared: Some(NonOwned {
-                ptr,
-                offset: 0,
-                _phantom: PhantomData,
-            }),
-            cpu_exclusive: None,
-            cpu_owned: None,
-
-            #[cfg(feature = "cuda")]
-            cuda_owned: None,
+#[cfg(feature = "cuda")]
+#[derive(Debug)]
+pub struct HerculesCUDARefMut<'a> {
+    ptr: NonNull<u8>,
+    size: usize,
+    _phantom: PhantomData<&'a u8>,
+}
 
-            size,
-            id,
-        }
-    }
+#[cfg(feature = "cuda")]
+#[derive(Debug)]
+pub struct CUDABox {
+    ptr: NonNull<u8>,
+    size: usize,
+}
 
-    pub fn from_slice_mut<T>(slice: &'a mut [T]) -> Self {
+impl<'a> HerculesCPURef<'a> {
+    pub fn from_slice<T>(slice: &'a [T]) -> Self {
         let ptr = unsafe { NonNull::new_unchecked(slice.as_ptr() as *mut u8) };
         let size = slice.len() * size_of::<T>();
-        let id = NUM_OBJECTS.fetch_add(1, Ordering::Relaxed);
-        HerculesBox {
-            cpu_shared: None,
-            cpu_exclusive: Some(NonOwned {
-                ptr,
-                offset: 0,
-                _phantom: PhantomData,
-            }),
-            cpu_owned: None,
-
-            #[cfg(feature = "cuda")]
-            cuda_owned: None,
-
+        Self {
+            ptr,
             size,
-            id,
+            _phantom: PhantomData,
         }
     }
 
-    pub fn as_slice<T>(&'b mut self) -> &'b [T] {
+    pub fn as_slice<T>(self) -> &'a [T] {
+        let ptr = self.ptr.as_ptr() as *const T;
         assert_eq!(self.size % size_of::<T>(), 0);
-        unsafe { from_raw_parts(self.__cpu_ptr() as *const T, self.size / size_of::<T>()) }
+        assert!(ptr.is_aligned());
+        unsafe { from_raw_parts(ptr, self.size / size_of::<T>()) }
     }
 
-    unsafe fn get_cpu_ptr(&self) -> Option<NonNull<u8>> {
-        self.cpu_owned
-            .as_ref()
-            .map(|obj| obj.ptr.byte_add(obj.offset))
-            .or(self
-                .cpu_exclusive
-                .as_ref()
-                .map(|obj| obj.ptr.byte_add(obj.offset)))
-            .or(self
-                .cpu_shared
-                .as_ref()
-                .map(|obj| obj.ptr.byte_add(obj.offset)))
+    pub unsafe fn __ptr(&self) -> *mut u8 {
+        self.ptr.as_ptr() as *mut u8
     }
 
-    #[cfg(feature = "cuda")]
-    unsafe fn get_cuda_ptr(&self) -> Option<NonNull<u8>> {
-        self.cuda_owned
-            .as_ref()
-            .map(|obj| obj.ptr.byte_add(obj.offset))
+    pub unsafe fn __size(&self) -> usize {
+        self.size
     }
+}
 
-    unsafe fn allocate_cpu(&mut self) -> NonNull<u8> {
-        if let Some(obj) = self.cpu_owned.as_ref() {
-            obj.ptr
-        } else {
-            let ptr =
-                NonNull::new(alloc(Layout::from_size_align_unchecked(self.size, 16))).unwrap();
-            self.cpu_owned = Some(Owned {
-                ptr,
-                alloc_size: self.size,
-                offset: 0,
-            });
-            ptr
+impl<'a> HerculesCPURefMut<'a> {
+    pub fn from_slice<T>(slice: &'a mut [T]) -> Self {
+        let ptr = unsafe { NonNull::new_unchecked(slice.as_ptr() as *mut u8) };
+        let size = slice.len() * size_of::<T>();
+        Self {
+            ptr,
+            size,
+            _phantom: PhantomData,
         }
     }
 
-    #[cfg(feature = "cuda")]
-    unsafe fn allocate_cuda(&mut self) -> NonNull<u8> {
-        if let Some(obj) = self.cuda_owned.as_ref() {
-            obj.ptr
-        } else {
-            let ptr = NonNull::new(cuda_alloc(self.size)).unwrap();
-            self.cuda_owned = Some(Owned {
-                ptr,
-                alloc_size: self.size,
-                offset: 0,
-            });
-            ptr
-        }
+    pub fn as_slice<T>(self) -> &'a mut [T] {
+        let ptr = self.ptr.as_ptr() as *mut T;
+        assert_eq!(self.size % size_of::<T>(), 0);
+        assert!(ptr.is_aligned());
+        unsafe { from_raw_parts_mut(ptr, self.size / size_of::<T>()) }
     }
 
-    unsafe fn deallocate_cpu(&mut self) {
-        if let Some(obj) = self.cpu_owned.take() {
-            dealloc(
-                obj.ptr.as_ptr(),
-                Layout::from_size_align_unchecked(obj.alloc_size, 16),
-            );
-        }
+    pub unsafe fn __ptr(&self) -> *mut u8 {
+        self.ptr.as_ptr()
     }
 
-    #[cfg(feature = "cuda")]
-    unsafe fn deallocate_cuda(&mut self) {
-        if let Some(obj) = self.cuda_owned.take() {
-            cuda_dealloc(obj.ptr.as_ptr());
-        }
+    pub unsafe fn __size(&self) -> usize {
+        self.size
     }
+}
 
-    pub unsafe fn __zeros(size: u64) -> Self {
-        let size = size as usize;
-        let id = NUM_OBJECTS.fetch_add(1, Ordering::Relaxed);
-        HerculesBox {
-            cpu_shared: None,
-            cpu_exclusive: None,
-            cpu_owned: Some(Owned {
-                ptr: NonNull::new(alloc_zeroed(Layout::from_size_align_unchecked(size, 16)))
-                    .unwrap(),
-                alloc_size: size,
-                offset: 0,
-            }),
-
-            #[cfg(feature = "cuda")]
-            cuda_owned: None,
-
-            size,
-            id,
-        }
+#[cfg(feature = "cuda")]
+impl<'a> HerculesCUDARef<'a> {
+    pub unsafe fn __ptr(&self) -> *mut u8 {
+        self.ptr.as_ptr()
     }
 
-    pub unsafe fn __null() -> Self {
-        HerculesBox {
-            cpu_shared: None,
-            cpu_exclusive: None,
-            cpu_owned: None,
-
-            #[cfg(feature = "cuda")]
-            cuda_owned: None,
-
-            size: 0,
-            id: 0,
-        }
+    pub unsafe fn __size(&self) -> usize {
+        self.size
     }
+}
 
-    pub unsafe fn __cpu_ptr(&mut self) -> *mut u8 {
-        if let Some(ptr) = self.get_cpu_ptr() {
-            return ptr.as_ptr();
-        }
-        #[cfg(feature = "cuda")]
-        {
-            let cuda_ptr = self.get_cuda_ptr().unwrap();
-            let cpu_ptr = self.allocate_cpu();
-            copy_cuda_to_cpu(cpu_ptr.as_ptr(), cuda_ptr.as_ptr(), self.size);
-            return cpu_ptr.as_ptr();
-        }
-        panic!()
+#[cfg(feature = "cuda")]
+impl<'a> HerculesCUDARefMut<'a> {
+    pub unsafe fn __ptr(&self) -> *mut u8 {
+        self.ptr.as_ptr()
     }
 
-    pub unsafe fn __cpu_ptr_mut(&mut self) -> *mut u8 {
-        let cpu_ptr = self.__cpu_ptr();
-        if Some(cpu_ptr) == self.cpu_shared.as_ref().map(|obj| obj.ptr.as_ptr()) {
-            self.allocate_cpu();
-            copy_nonoverlapping(
-                cpu_ptr,
-                self.cpu_owned.as_ref().unwrap().ptr.as_ptr(),
-                self.size,
-            );
-        }
-        self.cpu_shared = None;
-        #[cfg(feature = "cuda")]
-        self.deallocate_cuda();
-        cpu_ptr
+    pub unsafe fn __size(&self) -> usize {
+        self.size
     }
+}
 
-    #[cfg(feature = "cuda")]
-    pub unsafe fn __cuda_ptr(&mut self) -> *mut u8 {
-        if let Some(ptr) = self.get_cuda_ptr() {
-            ptr.as_ptr()
-        } else {
-            let cpu_ptr = self.get_cpu_ptr().unwrap();
-            let cuda_ptr = self.allocate_cuda();
-            copy_cpu_to_cuda(cuda_ptr.as_ptr(), cpu_ptr.as_ptr(), self.size);
-            cuda_ptr.as_ptr()
+#[cfg(feature = "cuda")]
+impl CUDABox {
+    pub fn from_cpu_ref(cpu_ref: HerculesCPURef) -> Self {
+        unsafe {
+            let size = cpu_ref.size;
+            let ptr = NonNull::new(__cuda_alloc(size)).unwrap();
+            __copy_cpu_to_cuda(ptr.as_ptr(), cpu_ref.ptr.as_ptr(), size);
+            Self { ptr, size }
         }
     }
 
-    #[cfg(feature = "cuda")]
-    pub unsafe fn __cuda_ptr_mut(&mut self) -> *mut u8 {
-        let cuda_ptr = self.__cuda_ptr();
-        self.cpu_shared = None;
-        self.cpu_exclusive = None;
-        self.deallocate_cpu();
-        cuda_ptr
-    }
-
-    pub unsafe fn __clone(&self) -> Self {
-        Self {
-            cpu_shared: self.cpu_shared.clone(),
-            cpu_exclusive: self.cpu_exclusive.clone(),
-            cpu_owned: self.cpu_owned.clone(),
-            #[cfg(feature = "cuda")]
-            cuda_owned: self.cuda_owned.clone(),
-            size: self.size,
-            id: self.id,
+    pub fn from_cuda_ref(cuda_ref: HerculesCUDARef) -> Self {
+        unsafe {
+            let size = cuda_ref.size;
+            let ptr = NonNull::new(__cuda_alloc(size)).unwrap();
+            __copy_cuda_to_cuda(ptr.as_ptr(), cuda_ref.ptr.as_ptr(), size);
+            Self { ptr, size }
         }
     }
 
-    pub unsafe fn __forget(&mut self) {
-        self.cpu_owned = None;
-        #[cfg(feature = "cuda")]
-        {
-            self.cuda_owned = None;
+    pub fn get_ref<'a>(&'a self) -> HerculesCUDARef<'a> {
+        HerculesCUDARef {
+            ptr: self.ptr,
+            size: self.size,
+            _phantom: PhantomData,
         }
     }
 
-    pub unsafe fn __offset(&mut self, offset: u64, size: u64) {
-        if let Some(obj) = self.cpu_shared.as_mut() {
-            obj.offset += offset as usize;
-        }
-        if let Some(obj) = self.cpu_exclusive.as_mut() {
-            obj.offset += offset as usize;
-        }
-        if let Some(obj) = self.cpu_owned.as_mut() {
-            obj.offset += offset as usize;
-        }
-        #[cfg(feature = "cuda")]
-        if let Some(obj) = self.cuda_owned.as_mut() {
-            obj.offset += offset as usize;
+    pub fn get_ref_mut<'a>(&'a mut self) -> HerculesCUDARefMut<'a> {
+        HerculesCUDARefMut {
+            ptr: self.ptr,
+            size: self.size,
+            _phantom: PhantomData,
         }
-        self.size = size as usize;
     }
+}
 
-    pub unsafe fn __cmp_ids(&self, other: &HerculesBox<'_>) -> bool {
-        self.id == other.id
+#[cfg(feature = "cuda")]
+impl Clone for CUDABox {
+    fn clone(&self) -> Self {
+        Self::from_cuda_ref(self.get_ref())
     }
 }
 
-impl<'a> Drop for HerculesBox<'a> {
+#[cfg(feature = "cuda")]
+impl Drop for CUDABox {
     fn drop(&mut self) {
         unsafe {
-            self.deallocate_cpu();
-            #[cfg(feature = "cuda")]
-            self.deallocate_cuda();
+            __cuda_dealloc(self.ptr.as_ptr(), self.size);
         }
     }
 }
diff --git a/hercules_rt/src/rtdefs.cu b/hercules_rt/src/rtdefs.cu
index b7378d81..ab67ec98 100644
--- a/hercules_rt/src/rtdefs.cu
+++ b/hercules_rt/src/rtdefs.cu
@@ -1,5 +1,5 @@
 extern "C" {
-	void *cuda_alloc(size_t size) {
+	void *__cuda_alloc(size_t size) {
 		void *ptr = NULL;
 		cudaError_t res = cudaMalloc(&ptr, size);
 		if (res != cudaSuccess) {
@@ -8,31 +8,20 @@ extern "C" {
 		return ptr;
 	}
 	
-	void *cuda_alloc_zeroed(size_t size) {
-		void *ptr = cuda_alloc(size);
-		if (!ptr) {
-			return NULL;
-		}
-		cudaError_t res = cudaMemset(ptr, 0, size);
-		if (res != cudaSuccess) {
-			return NULL;
-		}
-		return ptr;
-	}
-	
-	void cuda_dealloc(void *ptr) {
+	void __cuda_dealloc(void *ptr, size_t size) {
+		(void) size;
 		cudaFree(ptr);
 	}
 	
-	void copy_cpu_to_cuda(void *dst, void *src, size_t size) {
+	void __copy_cpu_to_cuda(void *dst, void *src, size_t size) {
 		cudaMemcpy(dst, src, size, cudaMemcpyHostToDevice);
 	}
 	
-	void copy_cuda_to_cpu(void *dst, void *src, size_t size) {
+	void __copy_cuda_to_cpu(void *dst, void *src, size_t size) {
 		cudaMemcpy(dst, src, size, cudaMemcpyDeviceToHost);
 	}
 	
-	void copy_cuda_to_cuda(void *dst, void *src, size_t size) {
+	void __copy_cuda_to_cuda(void *dst, void *src, size_t size) {
 		cudaMemcpy(dst, src, size, cudaMemcpyDeviceToDevice);
 	}
 }
-- 
GitLab
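
A short host-side usage sketch of the reference types introduced in this patch
(illustrative only; it uses just the constructors and accessors added above,
and the device half assumes the cuda feature is enabled):

    use hercules_rt::{HerculesCPURef, HerculesCPURefMut};

    fn cpu_refs() {
        // CPU references wrap ordinary Rust slices without taking ownership;
        // the lifetime parameter ties them to the borrowed slice.
        let xs = [1.0f32, 2.0, 3.0];
        let shared = HerculesCPURef::from_slice(&xs);
        assert_eq!(shared.clone().as_slice::<f32>(), &[1.0, 2.0, 3.0]);

        let mut ys = [0u8; 16];
        let exclusive = HerculesCPURefMut::from_slice(&mut ys);
        exclusive.as_slice::<u8>()[0] = 42;
    }

    // CUDA references can't be made from Rust references directly; a CUDABox
    // owns a device allocation (freed on drop) and hands out views into it.
    #[cfg(feature = "cuda")]
    fn cuda_box(host: HerculesCPURef<'_>) {
        use hercules_rt::CUDABox;
        let dev = CUDABox::from_cpu_ref(host); // allocates and copies host -> device
        let _view = dev.get_ref();             // immutable device-side view of `dev`
    }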


From 4b107497702b0df571fee75ca0a584cabe5c1478 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sat, 25 Jan 2025 16:35:46 -0600
Subject: [PATCH 03/24] Move analysis

---
 hercules_ir/src/dom.rs                | 54 +++++++++++++++++++++++++++
 hercules_ir/src/fork_join_analysis.rs | 54 ---------------------------
 2 files changed, 54 insertions(+), 54 deletions(-)

diff --git a/hercules_ir/src/dom.rs b/hercules_ir/src/dom.rs
index 2c0f085d..2cfd3b09 100644
--- a/hercules_ir/src/dom.rs
+++ b/hercules_ir/src/dom.rs
@@ -1,5 +1,7 @@
 use std::collections::HashMap;
 
+use bitvec::prelude::*;
+
 use crate::*;
 
 /*
@@ -304,3 +306,55 @@ pub fn postdominator(subgraph: Subgraph, fake_root: NodeID) -> DomTree {
     // root as the root of the dominator analysis.
     dominator(&reversed_subgraph, fake_root)
 }
+
+/*
+ * Check if a data node dominates a control node. This involves checking all
+ * immediate control uses to see if they dominate the queried control node.
+ */
+pub fn does_data_dom_control(
+    function: &Function,
+    data: NodeID,
+    control: NodeID,
+    dom: &DomTree,
+) -> bool {
+    let mut stack = vec![data];
+    let mut visited = bitvec![u8, Lsb0; 0; function.nodes.len()];
+    visited.set(data.idx(), true);
+
+    while let Some(pop) = stack.pop() {
+        let node = &function.nodes[pop.idx()];
+
+        let imm_control = match node {
+            Node::Phi { control, data: _ }
+            | Node::Reduce {
+                control,
+                init: _,
+                reduct: _,
+            }
+            | Node::Call {
+                control,
+                function: _,
+                dynamic_constants: _,
+                args: _,
+            } => Some(*control),
+            _ if node.is_control() => Some(pop),
+            _ => {
+                for u in get_uses(node).as_ref() {
+                    if !visited[u.idx()] {
+                        visited.set(u.idx(), true);
+                        stack.push(*u);
+                    }
+                }
+                None
+            }
+        };
+
+        if let Some(imm_control) = imm_control
+            && !dom.does_dom(imm_control, control)
+        {
+            return false;
+        }
+    }
+
+    true
+}
diff --git a/hercules_ir/src/fork_join_analysis.rs b/hercules_ir/src/fork_join_analysis.rs
index 130bc2ed..1d089d76 100644
--- a/hercules_ir/src/fork_join_analysis.rs
+++ b/hercules_ir/src/fork_join_analysis.rs
@@ -1,7 +1,5 @@
 use std::collections::{HashMap, HashSet};
 
-use bitvec::prelude::*;
-
 use crate::*;
 
 /*
@@ -75,55 +73,3 @@ pub fn compute_fork_join_nesting(
         })
         .collect()
 }
-
-/*
- * Check if a data node dominates a control node. This involves checking all
- * immediate control uses to see if they dominate the queried control node.
- */
-pub fn does_data_dom_control(
-    function: &Function,
-    data: NodeID,
-    control: NodeID,
-    dom: &DomTree,
-) -> bool {
-    let mut stack = vec![data];
-    let mut visited = bitvec![u8, Lsb0; 0; function.nodes.len()];
-    visited.set(data.idx(), true);
-
-    while let Some(pop) = stack.pop() {
-        let node = &function.nodes[pop.idx()];
-
-        let imm_control = match node {
-            Node::Phi { control, data: _ }
-            | Node::Reduce {
-                control,
-                init: _,
-                reduct: _,
-            }
-            | Node::Call {
-                control,
-                function: _,
-                dynamic_constants: _,
-                args: _,
-            } => Some(*control),
-            _ if node.is_control() => Some(pop),
-            _ => {
-                for u in get_uses(node).as_ref() {
-                    if !visited[u.idx()] {
-                        visited.set(u.idx(), true);
-                        stack.push(*u);
-                    }
-                }
-                None
-            }
-        };
-
-        if let Some(imm_control) = imm_control
-            && !dom.does_dom(imm_control, control)
-        {
-            return false;
-        }
-    }
-
-    true
-}
-- 
GitLab


From efa307a4c9169e8fecefd4a118dffb977878d068 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sun, 26 Jan 2025 11:01:28 -0600
Subject: [PATCH 04/24] Object device demand analysis

---
 hercules_cg/src/device.rs            |  22 -----
 hercules_cg/src/lib.rs               |   2 -
 hercules_cg/src/rt.rs                |   8 +-
 hercules_ir/src/device.rs            | 132 +++++++++++++++++++++++++++
 hercules_ir/src/ir.rs                |   2 +-
 hercules_ir/src/lib.rs               |   2 +
 hercules_opt/src/device_placement.rs |   3 +
 hercules_opt/src/lib.rs              |   2 +
 8 files changed, 142 insertions(+), 31 deletions(-)
 delete mode 100644 hercules_cg/src/device.rs
 create mode 100644 hercules_ir/src/device.rs
 create mode 100644 hercules_opt/src/device_placement.rs

diff --git a/hercules_cg/src/device.rs b/hercules_cg/src/device.rs
deleted file mode 100644
index 866fa6ad..00000000
--- a/hercules_cg/src/device.rs
+++ /dev/null
@@ -1,22 +0,0 @@
-use hercules_ir::*;
-
-/*
- * Top level function to definitively place functions onto devices. A function
- * may store a device placement, but only optionally - this function assigns
- * devices to the rest of the functions.
- */
-pub fn device_placement(functions: &Vec<Function>, callgraph: &CallGraph) -> Vec<Device> {
-    let mut devices = vec![];
-
-    for (idx, function) in functions.into_iter().enumerate() {
-        if let Some(device) = function.device {
-            devices.push(device);
-        } else if function.entry || callgraph.num_callees(FunctionID::new(idx)) != 0 {
-            devices.push(Device::AsyncRust);
-        } else {
-            devices.push(Device::LLVM);
-        }
-    }
-
-    devices
-}
diff --git a/hercules_cg/src/lib.rs b/hercules_cg/src/lib.rs
index 8aaab214..47039737 100644
--- a/hercules_cg/src/lib.rs
+++ b/hercules_cg/src/lib.rs
@@ -1,11 +1,9 @@
 #![feature(if_let_guard, let_chains)]
 
 pub mod cpu;
-pub mod device;
 pub mod rt;
 
 pub use crate::cpu::*;
-pub use crate::device::*;
 pub use crate::rt::*;
 
 use hercules_ir::*;
diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs
index 5e40882e..edf2273d 100644
--- a/hercules_cg/src/rt.rs
+++ b/hercules_cg/src/rt.rs
@@ -44,12 +44,8 @@ use crate::*;
  * Addressing challenge #3 requires we determine what objects are live on what
  * devices at what times. This can be done fairly easily by coloring nodes by
  * what device they produce their result on and inserting inter-device transfers
- * along edges connecting nodes of different colors. Nodes may have multiple
- * colors if the corresponding value is available on multiple devices. For
- * example, if a call to a GPU function requires read-only access to an array
- * and that array was originally on the CPU, after a transfer to the GPU, the
- * array is still available on the CPU, since the GPU function doesn't modify
- * the value.
+ * along edges connecting nodes of different colors. Nodes can only have a
+ * single color - this is enforced by the DevicePlacement pass.
  *
  * Addressing challenge #5 requires runner objects for entry points accept and
  * return objects that are not in their own backing memory and potentially on
diff --git a/hercules_ir/src/device.rs b/hercules_ir/src/device.rs
new file mode 100644
index 00000000..5a6e13a0
--- /dev/null
+++ b/hercules_ir/src/device.rs
@@ -0,0 +1,132 @@
+use std::collections::BTreeSet;
+use std::mem::take;
+
+use crate::*;
+
+/*
+ * Top level function to definitively place functions onto devices. A function
+ * may store a device placement, but only optionally - this function assigns
+ * devices to the rest of the functions.
+ */
+pub fn device_placement(functions: &Vec<Function>, callgraph: &CallGraph) -> Vec<Device> {
+    let mut devices = vec![];
+
+    for (idx, function) in functions.into_iter().enumerate() {
+        if let Some(device) = function.device {
+            devices.push(device);
+        } else if function.entry || callgraph.num_callees(FunctionID::new(idx)) != 0 {
+            devices.push(Device::AsyncRust);
+        } else {
+            devices.push(Device::LLVM);
+        }
+    }
+
+    devices
+}
+
+pub type ObjectDeviceDemands = Vec<Vec<BTreeSet<Device>>>;
+
+/*
+ * This analysis figures out which device each collection object may be on. At
+ * first, an object may need to be on different devices at different times. This
+ * is fine during optimization.
+ */
+pub fn object_device_demands(
+    functions: &Vec<Function>,
+    types: &Vec<Type>,
+    typing: &ModuleTyping,
+    callgraph: &CallGraph,
+    objects: &CollectionObjects,
+    devices: &Vec<Device>,
+) -> ObjectDeviceDemands {
+    // An object is "demanded" on a device when:
+    // 1. The object is used by a primitive read node or write node in a device
+    //    function. This includes objects on the `data` input to write nodes.
+    //    Non-primitive reads don't demand an object on a device since they are
+    //    lowered to pointer math and involve no actual memory transfers.
+    // 2. The object is passed as input to a call node where the corresponding
+    //    object in the callee is demanded on a device.
+    // 3. The object is returned from a call node where the corresponding object
+    //    in the callee is demanded on a device.
+    // Note that reads and writes in a RT function don't induce a device demand.
+    // This is because RT functions can call device functions as necessary to
+    // arbitrarily move data onto / off of devices (though this may be slow).
+    // Traverse the functions in a module in reverse topological order, since
+    // the analysis of a function depends on all functions it calls.
+    let mut demands: ObjectDeviceDemands = vec![vec![]; functions.len()];
+    let topo = callgraph.topo();
+
+    for func_id in topo {
+        let function = &functions[func_id.idx()];
+        let typing = &typing[func_id.idx()];
+        let device = devices[func_id.idx()];
+
+        demands[func_id.idx()].resize(objects[&func_id].num_objects(), BTreeSet::new());
+        match device {
+            Device::LLVM | Device::CUDA => {
+                for (idx, node) in function.nodes.iter().enumerate() {
+                    // Condition #1.
+                    match node {
+                        Node::Read {
+                            collect,
+                            indices: _,
+                        } if types[typing[idx].idx()].is_primitive() => {
+                            for object in objects[&func_id].objects(*collect) {
+                                demands[func_id.idx()][object.idx()].insert(device);
+                            }
+                        }
+                        Node::Write {
+                            collect,
+                            data,
+                            indices: _,
+                        } => {
+                            for object in objects[&func_id]
+                                .objects(*collect)
+                                .into_iter()
+                                .chain(objects[&func_id].objects(*data).into_iter())
+                            {
+                                demands[func_id.idx()][object.idx()].insert(device);
+                            }
+                        }
+                        _ => {}
+                    }
+                }
+            }
+            Device::AsyncRust => {
+                for (idx, node) in function.nodes.iter().enumerate() {
+                    if let Node::Call {
+                        control: _,
+                        function: callee,
+                        dynamic_constants: _,
+                        args,
+                    } = node
+                    {
+                        // Condition #2.
+                        for (param_idx, arg) in args.into_iter().enumerate() {
+                            if let Some(callee_obj) = objects[callee].param_to_object(param_idx) {
+                                let callee_demands =
+                                    take(&mut demands[callee.idx()][callee_obj.idx()]);
+                                for object in objects[&func_id].objects(*arg) {
+                                    demands[func_id.idx()][object.idx()]
+                                        .extend(callee_demands.iter());
+                                }
+                                demands[callee.idx()][callee_obj.idx()] = callee_demands;
+                            }
+                        }
+
+                        // Condition #3.
+                        for callee_obj in objects[callee].returned_objects() {
+                            let callee_demands = take(&mut demands[callee.idx()][callee_obj.idx()]);
+                            for object in objects[&func_id].objects(NodeID::new(idx)) {
+                                demands[func_id.idx()][object.idx()].extend(callee_demands.iter());
+                            }
+                            demands[callee.idx()][callee_obj.idx()] = callee_demands;
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    demands
+}
diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs
index 46d35f25..5577228f 100644
--- a/hercules_ir/src/ir.rs
+++ b/hercules_ir/src/ir.rs
@@ -332,7 +332,7 @@ pub enum Schedule {
  * The authoritative enumeration of supported backends. Multiple backends may
  * correspond to the same kind of hardware.
  */
-#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
 pub enum Device {
     LLVM,
     CUDA,
diff --git a/hercules_ir/src/lib.rs b/hercules_ir/src/lib.rs
index 32bbf631..85dc277f 100644
--- a/hercules_ir/src/lib.rs
+++ b/hercules_ir/src/lib.rs
@@ -11,6 +11,7 @@ pub mod callgraph;
 pub mod collections;
 pub mod dataflow;
 pub mod def_use;
+pub mod device;
 pub mod dom;
 pub mod dot;
 pub mod fork_join_analysis;
@@ -26,6 +27,7 @@ pub use crate::callgraph::*;
 pub use crate::collections::*;
 pub use crate::dataflow::*;
 pub use crate::def_use::*;
+pub use crate::device::*;
 pub use crate::dom::*;
 pub use crate::dot::*;
 pub use crate::fork_join_analysis::*;
diff --git a/hercules_opt/src/device_placement.rs b/hercules_opt/src/device_placement.rs
new file mode 100644
index 00000000..2badd69d
--- /dev/null
+++ b/hercules_opt/src/device_placement.rs
@@ -0,0 +1,3 @@
+use hercules_ir::ir::*;
+
+use crate::*;
diff --git a/hercules_opt/src/lib.rs b/hercules_opt/src/lib.rs
index 0b10bdae..4a90f698 100644
--- a/hercules_opt/src/lib.rs
+++ b/hercules_opt/src/lib.rs
@@ -4,6 +4,7 @@ pub mod ccp;
 pub mod crc;
 pub mod dce;
 pub mod delete_uncalled;
+pub mod device_placement;
 pub mod editor;
 pub mod float_collections;
 pub mod fork_concat_split;
@@ -27,6 +28,7 @@ pub use crate::ccp::*;
 pub use crate::crc::*;
 pub use crate::dce::*;
 pub use crate::delete_uncalled::*;
+pub use crate::device_placement::*;
 pub use crate::editor::*;
 pub use crate::float_collections::*;
 pub use crate::fork_concat_split::*;
-- 
GitLab
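
ObjectDeviceDemands is indexed first by function and then by collection object,
with each entry holding the set of demanding devices. A hedged sketch of
consuming the analysis, assuming the module-level inputs (typing, callgraph,
collection objects, device placement) have already been computed as in the
pass manager:

    use hercules_ir::*;

    // Assumes `functions`, `types`, `typing`, `callgraph`, `objects`, and
    // `devices` were already computed for the module.
    fn report_demands(
        functions: &Vec<Function>,
        types: &Vec<Type>,
        typing: &ModuleTyping,
        callgraph: &CallGraph,
        objects: &CollectionObjects,
        devices: &Vec<Device>,
    ) {
        let demands =
            object_device_demands(functions, types, typing, callgraph, objects, devices);
        for (func_idx, per_object) in demands.iter().enumerate() {
            for (obj_idx, devs) in per_object.iter().enumerate() {
                // More than one device in the set means the object is demanded
                // on multiple devices; later passes must resolve this.
                println!("function {} object {}: {:?}", func_idx, obj_idx, devs);
            }
        }
    }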


From 25987ace02b4e70b5c023736afb8d6669df6d0e1 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sun, 26 Jan 2025 11:39:43 -0600
Subject: [PATCH 05/24] Integrate analyses into pm

---
 hercules_cg/src/cpu.rs    |  7 ++++++
 hercules_cg/src/rt.rs     |  9 ++++++++
 hercules_ir/src/device.rs |  3 ++-
 juno_scheduler/src/pm.rs  | 48 +++++++++++++++++++++++++++++++++++----
 4 files changed, 61 insertions(+), 6 deletions(-)

diff --git a/hercules_cg/src/cpu.rs b/hercules_cg/src/cpu.rs
index 3750c4f6..47016dda 100644
--- a/hercules_cg/src/cpu.rs
+++ b/hercules_cg/src/cpu.rs
@@ -22,8 +22,15 @@ pub fn cpu_codegen<W: Write>(
     typing: &Vec<TypeID>,
     control_subgraph: &Subgraph,
     bbs: &BasicBlocks,
+    object_device_demands: &FunctionObjectDeviceDemands,
     w: &mut W,
 ) -> Result<(), Error> {
+    // Check that every object that has a demand in this function is only
+    // demanded on the CPU.
+    for demands in object_device_demands {
+        assert!(demands.is_empty() || (demands.len() == 1 && demands.contains(&Device::LLVM)))
+    }
+
     let ctx = CPUContext {
         function,
         types,
diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs
index edf2273d..97cbf608 100644
--- a/hercules_cg/src/rt.rs
+++ b/hercules_cg/src/rt.rs
@@ -82,8 +82,15 @@ pub fn rt_codegen<W: Write>(
     collection_objects: &CollectionObjects,
     callgraph: &CallGraph,
     devices: &Vec<Device>,
+    object_device_demands: &FunctionObjectDeviceDemands,
     w: &mut W,
 ) -> Result<(), Error> {
+    // Check that every object that has a demand in this function only has a
+    // demand from one device.
+    for demands in object_device_demands {
+        assert!(demands.len() <= 1);
+    }
+
     let ctx = RTContext {
         func_id,
         module,
@@ -93,6 +100,7 @@ pub fn rt_codegen<W: Write>(
         collection_objects,
         callgraph,
         devices,
+        object_device_demands,
     };
     ctx.codegen_function(w)
 }
@@ -106,6 +114,7 @@ struct RTContext<'a> {
     collection_objects: &'a CollectionObjects,
     callgraph: &'a CallGraph,
     devices: &'a Vec<Device>,
+    object_device_demands: &'a FunctionObjectDeviceDemands,
 }
 
 impl<'a> RTContext<'a> {
diff --git a/hercules_ir/src/device.rs b/hercules_ir/src/device.rs
index 5a6e13a0..46a1af02 100644
--- a/hercules_ir/src/device.rs
+++ b/hercules_ir/src/device.rs
@@ -24,7 +24,8 @@ pub fn device_placement(functions: &Vec<Function>, callgraph: &CallGraph) -> Vec
     devices
 }
 
-pub type ObjectDeviceDemands = Vec<Vec<BTreeSet<Device>>>;
+pub type FunctionObjectDeviceDemands = Vec<BTreeSet<Device>>;
+pub type ObjectDeviceDemands = Vec<FunctionObjectDeviceDemands>;
 
 /*
  * This analysis figures out which device each collection object may be on. At
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 452c1995..b35fe2c1 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -17,7 +17,6 @@ use juno_utils::stringtab::StringTable;
 
 use std::cell::RefCell;
 use std::collections::{BTreeSet, HashMap, HashSet};
-use std::env::temp_dir;
 use std::fmt;
 use std::fs::File;
 use std::io::Write;
@@ -189,6 +188,8 @@ struct PassManager {
     pub bbs: Option<Vec<BasicBlocks>>,
     pub collection_objects: Option<CollectionObjects>,
     pub callgraph: Option<CallGraph>,
+    pub devices: Option<Vec<Device>>,
+    pub object_device_demands: Option<ObjectDeviceDemands>,
 }
 
 impl PassManager {
@@ -220,6 +221,8 @@ impl PassManager {
             bbs: None,
             collection_objects: None,
             callgraph: None,
+            devices: None,
+            object_device_demands: None,
         }
     }
 
@@ -406,6 +409,35 @@ impl PassManager {
         }
     }
 
+    pub fn make_devices(&mut self) {
+        if self.devices.is_none() {
+            self.make_callgraph();
+            let callgraph = self.callgraph.as_ref().unwrap();
+            self.devices = Some(device_placement(&self.functions, callgraph));
+        }
+    }
+
+    pub fn make_object_device_demands(&mut self) {
+        if self.object_device_demands.is_none() {
+            self.make_typing();
+            self.make_callgraph();
+            self.make_collection_objects();
+            self.make_devices();
+            let typing = self.typing.as_ref().unwrap();
+            let callgraph = self.callgraph.as_ref().unwrap();
+            let collection_objects = self.collection_objects.as_ref().unwrap();
+            let devices = self.devices.as_ref().unwrap();
+            self.object_device_demands = Some(object_device_demands(
+                &self.functions,
+                &self.types.borrow(),
+                typing,
+                callgraph,
+                collection_objects,
+                devices,
+            ));
+        }
+    }
+
     pub fn delete_gravestones(&mut self) {
         for func in self.functions.iter_mut() {
             func.delete_gravestones();
@@ -427,6 +459,8 @@ impl PassManager {
         self.bbs = None;
         self.collection_objects = None;
         self.callgraph = None;
+        self.devices = None;
+        self.object_device_demands = None;
     }
 
     fn with_mod<B, F>(&mut self, mut f: F) -> B
@@ -464,6 +498,8 @@ impl PassManager {
         self.make_control_subgraphs();
         self.make_collection_objects();
         self.make_callgraph();
+        self.make_devices();
+        self.make_object_device_demands();
 
         let PassManager {
             functions,
@@ -476,6 +512,8 @@ impl PassManager {
             bbs: Some(bbs),
             collection_objects: Some(collection_objects),
             callgraph: Some(callgraph),
+            devices: Some(devices),
+            object_device_demands: Some(object_device_demands),
             ..
         } = self
         else {
@@ -493,8 +531,6 @@ impl PassManager {
             labels: labels.into_inner(),
         };
 
-        let devices = device_placement(&module.functions, &callgraph);
-
         let mut rust_rt = String::new();
         let mut llvm_ir = String::new();
         for idx in 0..module.functions.len() {
@@ -507,6 +543,7 @@ impl PassManager {
                     &typing[idx],
                     &control_subgraphs[idx],
                     &bbs[idx],
+                    &object_device_demands[idx],
                     &mut llvm_ir,
                 )
                 .map_err(|e| SchedulerError::PassError {
@@ -522,6 +559,7 @@ impl PassManager {
                     &collection_objects,
                     &callgraph,
                     &devices,
+                    &object_device_demands[idx],
                     &mut rust_rt,
                 )
                 .map_err(|e| SchedulerError::PassError {
@@ -1178,10 +1216,10 @@ fn run_pass(
 
             pm.make_typing();
             pm.make_callgraph();
+            pm.make_devices();
             let typing = pm.typing.take().unwrap();
             let callgraph = pm.callgraph.take().unwrap();
-
-            let devices = device_placement(&pm.functions, &callgraph);
+            let devices = pm.devices.take().unwrap();
 
             let mut editors = build_editors(pm);
             float_collections(&mut editors, &typing, &callgraph, &devices);
-- 
GitLab
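
For reference, the consumption pattern inside pm.rs follows the existing
make-then-take convention; a hedged sketch (the helper name is hypothetical,
and only methods and fields added in this patch are used):

    // Hypothetical helper inside pm.rs; mirrors how run_pass uses the new
    // analyses: force them with make_*, then take ownership of the cached
    // results before building editors.
    fn take_device_analyses(
        pm: &mut PassManager,
    ) -> (Vec<Device>, ObjectDeviceDemands) {
        pm.make_devices(); // forces make_callgraph() transitively
        pm.make_object_device_demands(); // forces typing, collection objects, devices
        let devices = pm.devices.take().unwrap();
        let demands = pm.object_device_demands.take().unwrap();
        (devices, demands)
    }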


From 5b8fcfb1c96e91c349104737f21e3a6c73c84290 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sun, 26 Jan 2025 17:05:03 -0600
Subject: [PATCH 06/24] Analyze node colors in GCM

---
 hercules_cg/src/cpu.rs   |  7 ----
 hercules_cg/src/rt.rs    | 10 +----
 hercules_opt/src/gcm.rs  | 90 +++++++++++++++++++++++++++++++++++++---
 hercules_rt/src/lib.rs   | 16 +++++++
 juno_scheduler/src/pm.rs | 58 ++++++++++++--------------
 5 files changed, 129 insertions(+), 52 deletions(-)

diff --git a/hercules_cg/src/cpu.rs b/hercules_cg/src/cpu.rs
index 47016dda..3750c4f6 100644
--- a/hercules_cg/src/cpu.rs
+++ b/hercules_cg/src/cpu.rs
@@ -22,15 +22,8 @@ pub fn cpu_codegen<W: Write>(
     typing: &Vec<TypeID>,
     control_subgraph: &Subgraph,
     bbs: &BasicBlocks,
-    object_device_demands: &FunctionObjectDeviceDemands,
     w: &mut W,
 ) -> Result<(), Error> {
-    // Check that every object that has a demand in this function is only
-    // demanded on the CPU.
-    for demands in object_device_demands {
-        assert!(demands.is_empty() || (demands.len() == 1 && demands.contains(&Device::LLVM)))
-    }
-
     let ctx = CPUContext {
         function,
         types,
diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs
index 97cbf608..e281fecb 100644
--- a/hercules_cg/src/rt.rs
+++ b/hercules_cg/src/rt.rs
@@ -45,7 +45,7 @@ use crate::*;
  * devices at what times. This can be done fairly easily by coloring nodes by
  * what device they produce their result on and inserting inter-device transfers
  * along edges connecting nodes of different colors. Nodes can only have a
- * single color - this is enforced by the DevicePlacement pass.
+ * single color - this is enforced by the GCM pass.
  *
  * Addressing challenge #5 requires runner objects for entry points accept and
  * return objects that are not in their own backing memory and potentially on
@@ -65,7 +65,7 @@ use crate::*;
  * - HerculesCUDARef
  * - HerculesCUDARefMut
  *
- * Essentially, there are types for each device, one for immutable refernences
+ * Essentially, there are types for each device, one for immutable references
  * and one for exclusive references. Mutable references can decay into immutable
  * references, and immutable references can be cloned. The CPU reference types
  * can be created from normal Rust references. The CUDA reference types can't be
@@ -85,12 +85,6 @@ pub fn rt_codegen<W: Write>(
     object_device_demands: &FunctionObjectDeviceDemands,
     w: &mut W,
 ) -> Result<(), Error> {
-    // Check that every object that has a demand in this function only has a
-    // demand from one device.
-    for demands in object_device_demands {
-        assert!(demands.len() <= 1);
-    }
-
     let ctx = RTContext {
         func_id,
         module,
diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs
index 1323d5a0..d4f4a92d 100644
--- a/hercules_opt/src/gcm.rs
+++ b/hercules_opt/src/gcm.rs
@@ -35,8 +35,20 @@ use crate::*;
  * liveness analysis result, so every spill restarts the process of checking for
  * spills. Once no more spills are found, the process terminates. When a spill
  * is found, the basic block assignments, and all the other analyses, are not
- * necessarily valid anymore, so this function is called in a loop in pass.rs
- * until no more spills are found.
+ * necessarily valid anymore, so this function is called in a loop in the pass
+ * manager until no more spills are found.
+ *
+ * GCM is additionally complicated by the need to generate code that references
+ * objects across multiple devices. In particular, GCM makes sure that every
+ * object lives on exactly one device, so that references to that object always
+ * live on a single device. Additionally, GCM makes sure that the objects that a
+ * node may produce are all on the same device, so that a pointer produced by,
+ * for example, a select node can only refer to memory on a single device. Extra
+ * collection constants and potentially inter-device copies are inserted as
+ * necessary to make sure this is true - an inter-device copy is represented by
+ * a write where the `collect` and `data` inputs are on different devices. This
+ * is only valid in RT functions - it is asserted that this isn't necessary in
+ * device functions. This process "colors" the nodes in the function.
  */
 pub fn gcm(
     editor: &mut FunctionEditor,
@@ -48,6 +60,8 @@ pub fn gcm(
     fork_join_map: &HashMap<NodeID, NodeID>,
     loops: &LoopTree,
     objects: &CollectionObjects,
+    devices: &Vec<Device>,
+    object_device_demands: &FunctionObjectDeviceDemands,
 ) -> Option<BasicBlocks> {
     let bbs = basic_blocks(
         editor.func(),
@@ -59,11 +73,40 @@ pub fn gcm(
         fork_join_map,
         objects,
     );
+
     if spill_clones(editor, typing, control_subgraph, objects, &bbs) {
-        None
-    } else {
-        Some(bbs)
+        return None;
     }
+
+    let func_id = editor.func_id();
+    let Some(node_colors) = color_nodes(
+        editor,
+        reverse_postorder,
+        &objects[&func_id],
+        &object_device_demands,
+    ) else {
+        return None;
+    };
+
+    let device = devices[func_id.idx()];
+    match device {
+        Device::LLVM | Device::CUDA => {
+            // Check that every object that has a demand in this function is
+            // only demanded on this device.
+            for demands in object_device_demands {
+                assert!(demands.is_empty() || (demands.len() == 1 && demands.contains(&device)))
+            }
+        }
+        Device::AsyncRust => {
+            // Check that every object that has a demand in this function only
+            // has a demand from one device.
+            for demands in object_device_demands {
+                assert!(demands.len() <= 1);
+            }
+        }
+    }
+
+    Some(bbs)
 }
 
 /*
@@ -938,3 +981,40 @@ fn liveness_dataflow(
         }
     }
 }
+
+/*
+ * Determine what device each node produces its collection result on. Insert
+ * inter-device clones when a single node's result may be on different devices.
+ */
+fn color_nodes(
+    editor: &mut FunctionEditor,
+    reverse_postorder: &Vec<NodeID>,
+    objects: &FunctionCollectionObjects,
+    object_device_demands: &FunctionObjectDeviceDemands,
+) -> Option<BTreeMap<NodeID, Device>> {
+    // First, try to give each node a single color.
+    let mut colors = BTreeMap::new();
+    let mut bad_node = None;
+    'nodes: for id in reverse_postorder {
+        let mut device = None;
+        for object in objects.objects(*id) {
+            for demand in object_device_demands[object.idx()].iter() {
+                if let Some(device) = device
+                    && device != *demand
+                {
+                    bad_node = Some(id);
+                    break 'nodes;
+                }
+                device = Some(*demand);
+            }
+        }
+        if let Some(device) = device {
+            colors.insert(*id, device);
+        }
+    }
+    if bad_node.is_some() {
+        todo!("Deal with inter-device demands.")
+    }
+
+    Some(colors)
+}
diff --git a/hercules_rt/src/lib.rs b/hercules_rt/src/lib.rs
index 759bed0d..c244a611 100644
--- a/hercules_rt/src/lib.rs
+++ b/hercules_rt/src/lib.rs
@@ -112,6 +112,14 @@ impl<'a> HerculesCPURefMut<'a> {
         unsafe { from_raw_parts_mut(ptr, self.size / size_of::<T>()) }
     }
 
+    pub fn as_ref(self) -> HerculesCPURef<'a> {
+        HerculesCPURef {
+            ptr: self.ptr,
+            size: self.size,
+            _phantom: PhantomData,
+        }
+    }
+
     pub unsafe fn __ptr(&self) -> *mut u8 {
         self.ptr.as_ptr()
     }
@@ -134,6 +142,14 @@ impl<'a> HerculesCUDARef<'a> {
 
 #[cfg(feature = "cuda")]
 impl<'a> HerculesCUDARefMut<'a> {
+    pub fn as_ref(self) -> HerculesCUDARef<'a> {
+        HerculesCUDARef {
+            ptr: self.ptr,
+            size: self.size,
+            _phantom: PhantomData,
+        }
+    }
+
     pub unsafe fn __ptr(&self) -> *mut u8 {
         self.ptr.as_ptr()
     }
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index b35fe2c1..96db42af 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -543,7 +543,6 @@ impl PassManager {
                     &typing[idx],
                     &control_subgraphs[idx],
                     &bbs[idx],
-                    &object_device_demands[idx],
                     &mut llvm_ir,
                 )
                 .map_err(|e| SchedulerError::PassError {
@@ -1266,6 +1265,12 @@ fn run_pass(
                 });
             }
 
+            // Iterate functions in reverse topological order, since inter-
+            // device copies introduced in a callee may affect demands in a
+            // caller.
+            pm.make_callgraph();
+            let callgraph = pm.callgraph.take().unwrap();
+            let topo = callgraph.topo();
             loop {
                 pm.make_def_uses();
                 pm.make_reverse_postorders();
@@ -1275,6 +1280,8 @@ fn run_pass(
                 pm.make_fork_join_maps();
                 pm.make_loops();
                 pm.make_collection_objects();
+                pm.make_devices();
+                pm.make_object_device_demands();
 
                 let def_uses = pm.def_uses.take().unwrap();
                 let reverse_postorders = pm.reverse_postorders.take().unwrap();
@@ -1284,42 +1291,29 @@ fn run_pass(
                 let loops = pm.loops.take().unwrap();
                 let control_subgraphs = pm.control_subgraphs.take().unwrap();
                 let collection_objects = pm.collection_objects.take().unwrap();
+                let devices = pm.devices.take().unwrap();
+                let object_device_demands = pm.object_device_demands.take().unwrap();
 
-                let mut bbs = vec![];
-
-                for (
-                    (
-                        (
-                            ((((mut func, def_use), reverse_postorder), typing), control_subgraph),
-                            doms,
-                        ),
-                        fork_join_map,
-                    ),
-                    loops,
-                ) in build_editors(pm)
-                    .into_iter()
-                    .zip(def_uses.iter())
-                    .zip(reverse_postorders.iter())
-                    .zip(typing.iter())
-                    .zip(control_subgraphs.iter())
-                    .zip(doms.iter())
-                    .zip(fork_join_maps.iter())
-                    .zip(loops.iter())
-                {
+                let mut bbs = vec![(vec![], vec![]); topo.len()];
+                let mut editors = build_editors(pm);
+                for id in topo.iter() {
+                    let editor = &mut editors[id.idx()];
                     if let Some(bb) = gcm(
-                        &mut func,
-                        def_use,
-                        reverse_postorder,
-                        typing,
-                        control_subgraph,
-                        doms,
-                        fork_join_map,
-                        loops,
+                        editor,
+                        &def_uses[id.idx()],
+                        &reverse_postorders[id.idx()],
+                        &typing[id.idx()],
+                        &control_subgraphs[id.idx()],
+                        &doms[id.idx()],
+                        &fork_join_maps[id.idx()],
+                        &loops[id.idx()],
                         &collection_objects,
+                        &devices,
+                        &object_device_demands[id.idx()],
                     ) {
-                        bbs.push(bb);
+                        bbs[id.idx()] = bb;
                     }
-                    changed |= func.modified();
+                    changed |= editor.modified();
                 }
                 pm.delete_gravestones();
                 pm.clear_analyses();
-- 
GitLab
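
The reverse-topological iteration added in this patch matters because results
computed for a callee (inter-device copies here, and the callee's backing-memory
size later in this series) must be available before any of its callers are
processed. Below is a minimal sketch of that ordering over a toy call graph,
using plain indices instead of the real FunctionID/CallGraph types; the names
and the "backing size" folding are illustrative, not the Hercules API.

    use std::collections::BTreeMap;

    type FuncIdx = usize;

    // Simple DFS post-order: callees are pushed before their callers.
    fn callee_first_order(callees: &[Vec<FuncIdx>]) -> Vec<FuncIdx> {
        fn dfs(
            f: FuncIdx,
            callees: &[Vec<FuncIdx>],
            seen: &mut Vec<bool>,
            out: &mut Vec<FuncIdx>,
        ) {
            if seen[f] {
                return;
            }
            seen[f] = true;
            for &c in &callees[f] {
                dfs(c, callees, seen, out);
            }
            out.push(f);
        }
        let mut seen = vec![false; callees.len()];
        let mut out = Vec::new();
        for f in 0..callees.len() {
            dfs(f, callees, &mut seen, &mut out);
        }
        out
    }

    fn main() {
        // callees[f] lists the functions f calls: 0 -> {1, 2}, 1 -> {2}.
        let callees: Vec<Vec<FuncIdx>> = vec![vec![1, 2], vec![2], vec![]];
        let order = callee_first_order(&callees);
        assert_eq!(order, vec![2, 1, 0]);

        // With that order, a per-function result that folds in callee results
        // (here, a made-up "backing size") is always ready when needed.
        let own_size = [16usize, 8, 32];
        let mut total: BTreeMap<FuncIdx, usize> = BTreeMap::new();
        for f in order {
            let from_callees: usize = callees[f].iter().map(|c| total[c]).sum();
            total.insert(f, own_size[f] + from_callees);
        }
        assert_eq!(total[&0], 16 + 40 + 32);
    }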


From 1a97fe8b43216e9e77bce8ec185efcec02c3bafe Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sun, 26 Jan 2025 18:27:27 -0600
Subject: [PATCH 07/24] Get type sizes in dynamic constants

---
 hercules_cg/src/rt.rs      |   2 +-
 hercules_opt/src/editor.rs |   8 ++
 hercules_opt/src/gcm.rs    | 163 +++++++++++++++++++++++++++++++++----
 3 files changed, 155 insertions(+), 18 deletions(-)

diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs
index e281fecb..6ee513b8 100644
--- a/hercules_cg/src/rt.rs
+++ b/hercules_cg/src/rt.rs
@@ -786,7 +786,7 @@ impl<'a> RTContext<'a> {
                     acc_size = format!("::core::cmp::max({}, {})", acc_size, variant);
                 }
 
-                // No alignment is necessary for the 1 byte discriminant.
+                // No alignment is necessary before the 1 byte discriminant.
                 let total_align = get_type_alignment(&self.module.types, ty);
                 format!(
                     "(({} + 1 + {}) & !{})",
diff --git a/hercules_opt/src/editor.rs b/hercules_opt/src/editor.rs
index 60745f21..8b90710e 100644
--- a/hercules_opt/src/editor.rs
+++ b/hercules_opt/src/editor.rs
@@ -335,6 +335,14 @@ impl<'a: 'b, 'b> FunctionEditor<'a> {
         self.function_id
     }
 
+    pub fn get_types(&self) -> Ref<'_, Vec<Type>> {
+        self.types.borrow()
+    }
+
+    pub fn get_constants(&self) -> Ref<'_, Vec<Constant>> {
+        self.constants.borrow()
+    }
+
     pub fn get_dynamic_constants(&self) -> Ref<'_, Vec<DynamicConstant>> {
         self.dynamic_constants.borrow()
     }
diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs
index d4f4a92d..392f7de7 100644
--- a/hercules_opt/src/gcm.rs
+++ b/hercules_opt/src/gcm.rs
@@ -1,3 +1,4 @@
+use std::cell::Ref;
 use std::collections::{BTreeMap, BTreeSet, HashMap, VecDeque};
 use std::iter::{empty, once, zip, FromIterator};
 
@@ -5,6 +6,7 @@ use bitvec::prelude::*;
 use either::Either;
 use union_find::{QuickFindUf, UnionBySize, UnionFind};
 
+use hercules_cg::get_type_alignment;
 use hercules_ir::*;
 
 use crate::*;
@@ -49,6 +51,21 @@ use crate::*;
  * a write where the `collect` and `data` inputs are on different devices. This
  * is only valid in RT functions - it is asserted that this isn't necessary in
  * device functions. This process "colors" the nodes in the function.
+ *
+ * GCM has one final responsibility - object allocation. Each Hercules function
+ * receives a pointer to a "backing" memory where collection constants live. The
+ * backing memory a function receives is for the constants in that function and
+ * the constants of every called function. Concretely, a function passes
+ * sub-regions of its backing memory to a callee, and during the call those
+ * sub-regions serve as the callee's backing memory. Object allocation consists
+ * of finding the required sizes of all collection constants and called
+ * functions in terms of dynamic constants
+ * (dynamic constant math is expressive enough to represent sizes of types,
+ * which is very convenient) and determining the concrete offsets into the
+ * backing memory where constants and callee sub-regions live. When two users of
+ * backing memory are never live at once, they may share backing memory. This is
+ * done after nodes are given a single device color, since we need to know what
+ * values are on what devices before we can allocate them to backing memory, as
+ * there is a separate backing memory per device.
  */
 pub fn gcm(
     editor: &mut FunctionEditor,
@@ -74,7 +91,15 @@ pub fn gcm(
         objects,
     );
 
-    if spill_clones(editor, typing, control_subgraph, objects, &bbs) {
+    let liveness = liveness_dataflow(
+        editor.func(),
+        editor.func_id(),
+        control_subgraph,
+        objects,
+        &bbs,
+    );
+
+    if spill_clones(editor, typing, control_subgraph, objects, &bbs, &liveness) {
         return None;
     }
 
@@ -106,6 +131,8 @@ pub fn gcm(
         }
     }
 
+    let object_sizes = object_sizes(editor, typing, &objects[&func_id]);
+
     Some(bbs)
 }
 
@@ -623,8 +650,6 @@ fn mutating_writes<'a>(
     }
 }
 
-type Liveness = BTreeMap<NodeID, Vec<BTreeSet<NodeID>>>;
-
 /*
  * Top level function to find implicit clones that need to be spilled. Returns
  * whether a clone was spilled, in which case the whole scheduling process must
@@ -636,19 +661,9 @@ fn spill_clones(
     control_subgraph: &Subgraph,
     objects: &CollectionObjects,
     bbs: &BasicBlocks,
+    liveness: &Liveness,
 ) -> bool {
-    // Step 1: compute a liveness analysis of collection values in the IR. This
-    // requires a dataflow analysis over the scheduled IR, which is not a common
-    // need in Hercules, so just hardcode the analysis.
-    let liveness = liveness_dataflow(
-        editor.func(),
-        editor.func_id(),
-        control_subgraph,
-        objects,
-        bbs,
-    );
-
-    // Step 2: compute an interference graph from the liveness result. This
+    // Step 1: compute an interference graph from the liveness result. This
     // graph contains a vertex per node ID producing a collection value and an
     // edge per pair of node IDs that interfere. Nodes A and B interfere if node
     // A is defined right above a point where node B is live and A != B. Extra
@@ -695,7 +710,7 @@ fn spill_clones(
         }
     }
 
-    // Step 3: filter edges (A, B) to just see edges where A uses B and A
+    // Step 2: filter edges (A, B) to just see edges where A uses B and A
     // mutates B. These are the edges that may require a spill.
     let mut spill_edges = edges.into_iter().filter(|(a, b)| {
         mutating_writes(editor.func(), *a, objects).any(|id| id == *b)
@@ -707,7 +722,7 @@ fn spill_clones(
                     || editor.func().nodes[a.idx()].is_reduce()))
     });
 
-    // Step 4: if there is a spill edge, spill it and return true. Otherwise,
+    // Step 3: if there is a spill edge, spill it and return true. Otherwise,
     // return false.
     if let Some((user, obj)) = spill_edges.next() {
         // Figure out the most immediate dominating region for every basic
@@ -861,6 +876,8 @@ fn spill_clones(
     }
 }
 
+type Liveness = BTreeMap<NodeID, Vec<BTreeSet<NodeID>>>;
+
 /*
  * Liveness dataflow analysis on scheduled Hercules IR. Just look at nodes that
  * involve collections.
@@ -1010,6 +1027,8 @@ fn color_nodes(
         }
         if let Some(device) = device {
             colors.insert(*id, device);
+        } else {
+            assert!(objects.objects(*id).is_empty(), "PANIC: Found an object with no device demands. This is technically possible and is easily supported by just picking an arbitrary device for this object. This assert exists because I'm curious to see where this will be needed first, and if that use is frivolous or not.");
         }
     }
     if bad_node.is_some() {
@@ -1018,3 +1037,113 @@ fn color_nodes(
 
     Some(colors)
 }
+
+/*
+ * Determine the size of objects in terms of dynamic constants based on typing.
+ */
+fn object_sizes(
+    editor: &mut FunctionEditor,
+    typing: &Vec<TypeID>,
+    objects: &FunctionCollectionObjects,
+) -> Vec<DynamicConstantID> {
+    // First, compute the alignments of every type.
+    let mut alignments = vec![];
+    Ref::map(editor.get_types(), |types| {
+        for idx in 0..types.len() {
+            alignments.push(get_type_alignment(types, TypeID::new(idx)));
+        }
+        &()
+    });
+
+    // Second, actually compute object sizes.
+    let mut sizes = vec![];
+    for id in objects.iter_objects() {
+        let ty_id = match objects.origin(id) {
+            CollectionObjectOrigin::Parameter(idx) => editor.func().param_types[idx],
+            CollectionObjectOrigin::Constant(id)
+            | CollectionObjectOrigin::Call(id)
+            | CollectionObjectOrigin::Undef(id) => typing[id.idx()],
+        };
+        let success = editor.edit(|mut edit| {
+            sizes.push(type_size(&mut edit, ty_id, &alignments));
+            Ok(edit)
+        });
+        assert!(success);
+    }
+
+    sizes
+}
+
+/*
+ * Determine the size of a type in terms of dynamic constants.
+ */
+fn type_size(edit: &mut FunctionEdit, ty_id: TypeID, alignments: &Vec<usize>) -> DynamicConstantID {
+    let align =
+        |edit: &mut FunctionEdit, mut acc: DynamicConstantID, align: usize| -> DynamicConstantID {
+            assert_ne!(align, 0);
+            if align != 1 {
+                let align_dc = edit.add_dynamic_constant(DynamicConstant::Constant(align));
+                let align_m1_dc = edit.add_dynamic_constant(DynamicConstant::Constant(align - 1));
+                acc = edit.add_dynamic_constant(DynamicConstant::Add(acc, align_m1_dc));
+                acc = edit.add_dynamic_constant(DynamicConstant::Div(acc, align_dc));
+                acc = edit.add_dynamic_constant(DynamicConstant::Mul(acc, align_dc));
+            }
+            acc
+        };
+
+    let ty = edit.get_type(ty_id).clone();
+    let size = match ty {
+        Type::Control => panic!(),
+        Type::Boolean | Type::Integer8 | Type::UnsignedInteger8 => {
+            edit.add_dynamic_constant(DynamicConstant::Constant(1))
+        }
+        Type::Integer16 | Type::UnsignedInteger16 => {
+            edit.add_dynamic_constant(DynamicConstant::Constant(2))
+        }
+        Type::Integer32 | Type::UnsignedInteger32 | Type::Float32 => {
+            edit.add_dynamic_constant(DynamicConstant::Constant(4))
+        }
+        Type::Integer64 | Type::UnsignedInteger64 | Type::Float64 => {
+            edit.add_dynamic_constant(DynamicConstant::Constant(8))
+        }
+        Type::Product(fields) => {
+            // The layout of product types is like the C-style layout.
+            let mut acc_size = edit.add_dynamic_constant(DynamicConstant::Constant(0));
+            for field in fields {
+                // Round up to the alignment of the field, then add the size of
+                // the field.
+                let field_size = type_size(edit, field, alignments);
+                acc_size = align(edit, acc_size, alignments[field.idx()]);
+                acc_size = edit.add_dynamic_constant(DynamicConstant::Add(acc_size, field_size));
+            }
+            // Finally, round up to the alignment of the whole product, since
+            // the size needs to be a multiple of the alignment.
+            acc_size = align(edit, acc_size, alignments[ty_id.idx()]);
+            acc_size
+        }
+        Type::Summation(variants) => {
+            // A summation holds every variant in the same memory.
+            let mut acc_size = edit.add_dynamic_constant(DynamicConstant::Constant(0));
+            for variant in variants {
+                // Pick the size of the largest variant, since that's the most
+                // memory we would need.
+                let variant_size = type_size(edit, variant, alignments);
+                acc_size = edit.add_dynamic_constant(DynamicConstant::Max(acc_size, variant_size));
+            }
+            // Add one byte for the discriminant and align the whole summation.
+            let one = edit.add_dynamic_constant(DynamicConstant::Constant(1));
+            acc_size = edit.add_dynamic_constant(DynamicConstant::Add(acc_size, one));
+            acc_size = align(edit, acc_size, alignments[ty_id.idx()]);
+            acc_size
+        }
+        Type::Array(elem, bounds) => {
+            // The layout of an array is row-major linear in memory.
+            let mut acc_size = type_size(edit, elem, alignments);
+            for bound in bounds {
+                acc_size = edit.add_dynamic_constant(DynamicConstant::Mul(acc_size, bound));
+            }
+            acc_size
+        }
+    };
+    size
+}
-- 
GitLab
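
The size computation introduced in this patch follows fixed layout rules:
products use C-style layout (pad up to each field's alignment, then pad the
total), a summation is the largest variant plus a one-byte discriminant, and
arrays are row-major products of their bounds. The sketch below applies the same
rules over plain usizes, detached from the FunctionEdit/DynamicConstant
machinery; the Ty enum and the max-of-members alignment rule are stand-ins,
not the Hercules IR types.

    #[derive(Clone)]
    enum Ty {
        Prim { size: usize, align: usize },
        Product(Vec<Ty>),
        Summation(Vec<Ty>),
        Array(Box<Ty>, Vec<usize>),
    }

    fn align_of(ty: &Ty) -> usize {
        match ty {
            Ty::Prim { align, .. } => *align,
            Ty::Product(fields) => fields.iter().map(align_of).max().unwrap_or(1),
            Ty::Summation(variants) => variants.iter().map(align_of).max().unwrap_or(1),
            Ty::Array(elem, _) => align_of(elem),
        }
    }

    // Round x up to the next multiple of `align` using only +, /, *, the same
    // shape as the Add/Div/Mul dynamic constants built by the pass.
    fn round_up(x: usize, align: usize) -> usize {
        (x + align - 1) / align * align
    }

    fn size_of(ty: &Ty) -> usize {
        match ty {
            Ty::Prim { size, .. } => *size,
            Ty::Product(fields) => {
                // Pad to each field's alignment, add its size, then pad the
                // total to the product's own alignment.
                let mut acc = 0;
                for f in fields {
                    acc = round_up(acc, align_of(f)) + size_of(f);
                }
                round_up(acc, align_of(ty))
            }
            Ty::Summation(variants) => {
                // Largest variant plus one byte of discriminant, padded.
                let max = variants.iter().map(size_of).max().unwrap_or(0);
                round_up(max + 1, align_of(ty))
            }
            Ty::Array(elem, bounds) => {
                // Row-major: element size times every bound.
                bounds.iter().product::<usize>() * size_of(elem)
            }
        }
    }

    fn main() {
        let i32_ty = Ty::Prim { size: 4, align: 4 };
        let i64_ty = Ty::Prim { size: 8, align: 8 };
        // (i32, i64): 4 bytes, pad to 8, plus 8, total 16.
        assert_eq!(size_of(&Ty::Product(vec![i32_ty.clone(), i64_ty.clone()])), 16);
        // i32 | i64: max(4, 8) + 1 discriminant = 9, padded to 16.
        assert_eq!(size_of(&Ty::Summation(vec![i32_ty.clone(), i64_ty])), 16);
        // [i32; 3 x 5]: 4 * 15 = 60.
        assert_eq!(size_of(&Ty::Array(Box::new(i32_ty), vec![3, 5])), 60);
    }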


From cbf77ac824590180b9a398104817c4c72811369d Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sun, 26 Jan 2025 22:12:30 -0600
Subject: [PATCH 08/24] Fix

---
 hercules_opt/src/gcm.rs | 31 +++++++++++++++++--------------
 1 file changed, 17 insertions(+), 14 deletions(-)

diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs
index 392f7de7..ab2275ff 100644
--- a/hercules_opt/src/gcm.rs
+++ b/hercules_opt/src/gcm.rs
@@ -1050,7 +1050,11 @@ fn object_sizes(
     let mut alignments = vec![];
     Ref::map(editor.get_types(), |types| {
         for idx in 0..types.len() {
-            alignments.push(get_type_alignment(types, TypeID::new(idx)));
+            if types[idx].is_control() {
+                alignments.push(0);
+            } else {
+                alignments.push(get_type_alignment(types, TypeID::new(idx)));
+            }
         }
         &()
     });
@@ -1074,23 +1078,22 @@ fn object_sizes(
     sizes
 }
 
+fn align(edit: &mut FunctionEdit, mut acc: DynamicConstantID, align: usize) -> DynamicConstantID {
+    assert_ne!(align, 0);
+    if align != 1 {
+        let align_dc = edit.add_dynamic_constant(DynamicConstant::Constant(align));
+        let align_m1_dc = edit.add_dynamic_constant(DynamicConstant::Constant(align - 1));
+        acc = edit.add_dynamic_constant(DynamicConstant::Add(acc, align_m1_dc));
+        acc = edit.add_dynamic_constant(DynamicConstant::Div(acc, align_dc));
+        acc = edit.add_dynamic_constant(DynamicConstant::Mul(acc, align_dc));
+    }
+    acc
+}
+
 /*
  * Determine the size of a type in terms of dynamic constants.
  */
 fn type_size(edit: &mut FunctionEdit, ty_id: TypeID, alignments: &Vec<usize>) -> DynamicConstantID {
-    let align =
-        |edit: &mut FunctionEdit, mut acc: DynamicConstantID, align: usize| -> DynamicConstantID {
-            assert_ne!(align, 0);
-            if align != 1 {
-                let align_dc = edit.add_dynamic_constant(DynamicConstant::Constant(align));
-                let align_m1_dc = edit.add_dynamic_constant(DynamicConstant::Constant(align - 1));
-                acc = edit.add_dynamic_constant(DynamicConstant::Add(acc, align_m1_dc));
-                acc = edit.add_dynamic_constant(DynamicConstant::Div(acc, align_dc));
-                acc = edit.add_dynamic_constant(DynamicConstant::Mul(acc, align_dc));
-            }
-            acc
-        };
-
     let ty = edit.get_type(ty_id).clone();
     let size = match ty {
         Type::Control => panic!(),
-- 
GitLab


From 9b8c0969ece0237ca210b97534a09d6df0f3d34b Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Mon, 27 Jan 2025 09:02:00 -0600
Subject: [PATCH 09/24] skeleton

---
 hercules_opt/src/gcm.rs  | 36 ++++++++++++++++++++++++++++++++++--
 juno_scheduler/src/pm.rs |  7 +++++--
 2 files changed, 39 insertions(+), 4 deletions(-)

diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs
index ab2275ff..c235b2b5 100644
--- a/hercules_opt/src/gcm.rs
+++ b/hercules_opt/src/gcm.rs
@@ -79,7 +79,12 @@ pub fn gcm(
     objects: &CollectionObjects,
     devices: &Vec<Device>,
     object_device_demands: &FunctionObjectDeviceDemands,
-) -> Option<BasicBlocks> {
+    backing_allocations: &BackingAllocations,
+) -> Option<(
+    BasicBlocks,
+    BTreeMap<NodeID, Device>,
+    FunctionBackingAllocation,
+)> {
     let bbs = basic_blocks(
         editor.func(),
         editor.func_id(),
@@ -132,8 +137,9 @@ pub fn gcm(
     }
 
     let object_sizes = object_sizes(editor, typing, &objects[&func_id]);
+    let backing_allocation = object_allocation(editor, &object_sizes, backing_allocations);
 
-    Some(bbs)
+    Some((bbs, node_colors, backing_allocation))
 }
 
 /*
@@ -1150,3 +1156,29 @@ fn type_size(edit: &mut FunctionEdit, ty_id: TypeID, alignments: &Vec<usize>) ->
     };
     size
 }
+
+/*
+ * The allocation information of each function is a size of the backing memory
+ * needed and offsets into that backing memory per constant object and call in
+ * the function. This is determined per device.
+ */
+pub type FunctionBackingAllocation = BTreeMap<
+    Device,
+    (
+        DynamicConstantID,
+        BTreeMap<CollectionObjectID, DynamicConstantID>,
+    ),
+>;
+pub type BackingAllocations = BTreeMap<FunctionID, FunctionBackingAllocation>;
+
+/*
+ * Allocate objects in a function. Relies on the allocations of all called
+ * functions.
+ */
+fn object_allocation(
+    editor: &mut FunctionEditor,
+    object_sizes: &Vec<DynamicConstantID>,
+    backing_allocations: &BackingAllocations,
+) -> FunctionBackingAllocation {
+    todo!()
+}
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 96db42af..ae1f3dea 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -16,7 +16,7 @@ use juno_utils::env::Env;
 use juno_utils::stringtab::StringTable;
 
 use std::cell::RefCell;
-use std::collections::{BTreeSet, HashMap, HashSet};
+use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
 use std::fmt;
 use std::fs::File;
 use std::io::Write;
@@ -1295,10 +1295,11 @@ fn run_pass(
                 let object_device_demands = pm.object_device_demands.take().unwrap();
 
                 let mut bbs = vec![(vec![], vec![]); topo.len()];
+                let mut backing_allocations = BTreeMap::new();
                 let mut editors = build_editors(pm);
                 for id in topo.iter() {
                     let editor = &mut editors[id.idx()];
-                    if let Some(bb) = gcm(
+                    if let Some((bb, node_colors, backing_allocation)) = gcm(
                         editor,
                         &def_uses[id.idx()],
                         &reverse_postorders[id.idx()],
@@ -1310,8 +1311,10 @@ fn run_pass(
                         &collection_objects,
                         &devices,
                         &object_device_demands[id.idx()],
+                        &backing_allocations,
                     ) {
                         bbs[id.idx()] = bb;
+                        backing_allocations.insert(*id, backing_allocation);
                     }
                     changed |= editor.modified();
                 }
-- 
GitLab
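
The two type aliases added here record, per device, a total backing size plus an
offset for each allocated object, and then collect those per function. A toy
illustration of that nesting with plain integers standing in for the real ID
types (Device, DynamicConstantID, and the per-object key); only the shape of the
maps is meaningful here.

    use std::collections::BTreeMap;

    // Stand-ins for the real ID types; only the nesting structure matters.
    type Device = &'static str;
    type Size = u64;   // plays the role of a DynamicConstantID-denoted size
    type ObjKey = u32; // plays the role of the per-object key
    type Func = u32;

    type FunctionBackingAllocation = BTreeMap<Device, (Size, BTreeMap<ObjKey, Size>)>;
    type BackingAllocations = BTreeMap<Func, FunctionBackingAllocation>;

    fn main() {
        // One callee with a 64-byte CPU backing region holding two objects.
        let mut callee: FunctionBackingAllocation = BTreeMap::new();
        callee.insert("cpu", (64, BTreeMap::from([(7, 0), (9, 32)])));

        let mut all: BackingAllocations = BTreeMap::new();
        all.insert(1, callee);

        // A caller folding a callee's total into its own allocation is the
        // core step the (still todo!()) object_allocation pass has to perform.
        let callee_cpu_total = all[&1]["cpu"].0;
        assert_eq!(callee_cpu_total, 64);
    }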


From 429d98a80fbf3533f3f4f63e34b8572e6dd6e4e0 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Mon, 27 Jan 2025 10:11:26 -0600
Subject: [PATCH 10/24] Naive object allocation

---
 hercules_opt/src/gcm.rs | 65 ++++++++++++++++++++++++++++++++++++++---
 1 file changed, 61 insertions(+), 4 deletions(-)

diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs
index c235b2b5..63123c08 100644
--- a/hercules_opt/src/gcm.rs
+++ b/hercules_opt/src/gcm.rs
@@ -137,7 +137,14 @@ pub fn gcm(
     }
 
     let object_sizes = object_sizes(editor, typing, &objects[&func_id]);
-    let backing_allocation = object_allocation(editor, &object_sizes, backing_allocations);
+    let backing_allocation = object_allocation(
+        editor,
+        &objects[&func_id],
+        &object_device_demands,
+        &object_sizes,
+        &liveness,
+        backing_allocations,
+    );
 
     Some((bbs, node_colors, backing_allocation))
 }
@@ -1159,8 +1166,8 @@ fn type_size(edit: &mut FunctionEdit, ty_id: TypeID, alignments: &Vec<usize>) ->
 
 /*
  * The allocation information of each function is a size of the backing memory
- * needed and offsets into that backing memory per constant object and call in
- * the function. This is determined per device.
+ * needed and offsets into that backing memory per constant object and call
+ * object in the function. This is determined per device.
  */
 pub type FunctionBackingAllocation = BTreeMap<
     Device,
@@ -1177,8 +1184,58 @@ pub type BackingAllocations = BTreeMap<FunctionID, FunctionBackingAllocation>;
  */
 fn object_allocation(
     editor: &mut FunctionEditor,
+    objects: &FunctionCollectionObjects,
+    object_device_demands: &FunctionObjectDeviceDemands,
     object_sizes: &Vec<DynamicConstantID>,
+    liveness: &Liveness,
     backing_allocations: &BackingAllocations,
 ) -> FunctionBackingAllocation {
-    todo!()
+    let mut fba = BTreeMap::new();
+
+    editor.edit(|mut edit| {
+        // For now, just allocate each object to its own slot.
+        for id in objects.iter_objects() {
+            match objects.origin(id) {
+                CollectionObjectOrigin::Constant(_) => {
+                    let demands = &object_device_demands[id.idx()];
+                    assert_eq!(demands.len(), 1);
+                    let device = *demands.first().unwrap();
+                    let (total, offsets) = fba.entry(device).or_insert_with(|| {
+                        (
+                            edit.add_dynamic_constant(DynamicConstant::Constant(0)),
+                            BTreeMap::new(),
+                        )
+                    });
+                    *total = align(&mut edit, *total, 8);
+                    offsets.insert(id, *total);
+
+                    *total = edit
+                        .add_dynamic_constant(DynamicConstant::Add(*total, object_sizes[id.idx()]));
+                }
+                CollectionObjectOrigin::Call(node_id) => {
+                    let demands = &object_device_demands[id.idx()];
+                    assert_eq!(demands.len(), 1);
+                    let device = *demands.first().unwrap();
+                    let (total, offsets) = fba.entry(device).or_insert_with(|| {
+                        (
+                            edit.add_dynamic_constant(DynamicConstant::Constant(0)),
+                            BTreeMap::new(),
+                        )
+                    });
+                    *total = align(&mut edit, *total, 8);
+                    offsets.insert(id, *total);
+
+                    let callee = edit.get_node(node_id).try_call().unwrap().1;
+                    *total = edit.add_dynamic_constant(DynamicConstant::Add(
+                        *total,
+                        backing_allocations[&callee][&device].0,
+                    ));
+                }
+                _ => {}
+            }
+        }
+        Ok(edit)
+    });
+
+    fba
 }
-- 
GitLab
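
The naive strategy in this patch is a per-device bump allocator expressed in
dynamic-constant arithmetic: round the device's running total up to the object's
alignment, record that as the object's offset, then add the object's size. The
same arithmetic over plain usizes, as a hypothetical helper rather than the pass
itself:

    use std::collections::BTreeMap;

    // Per-device running total and per-object offsets, mirroring the naive
    // strategy but over plain usizes instead of dynamic constants.
    fn bump_allocate(
        objects: &[(&'static str, usize, usize)], // (device, size, align) per object
    ) -> BTreeMap<&'static str, (usize, Vec<usize>)> {
        let mut fba: BTreeMap<&'static str, (usize, Vec<usize>)> = BTreeMap::new();
        for &(device, size, align) in objects {
            let (total, offsets) = fba.entry(device).or_insert((0, vec![]));
            // Round the running total up to this object's alignment, record
            // the offset, then bump by the object's size.
            *total = (*total + align - 1) / align * align;
            offsets.push(*total);
            *total += size;
        }
        fba
    }

    fn main() {
        let alloc = bump_allocate(&[("cpu", 4, 4), ("cpu", 16, 8), ("cuda", 12, 4)]);
        // cpu: offsets 0 and 8, total 24; cuda: offset 0, total 12.
        assert_eq!(alloc["cpu"], (24, vec![0, 8]));
        assert_eq!(alloc["cuda"], (12, vec![0]));
    }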


From 8c7e027868e328e5ca1086df778193dc8e60b520 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Mon, 27 Jan 2025 10:48:18 -0600
Subject: [PATCH 11/24] Use devices analysis in xdot

---
 hercules_ir/src/dot.rs              | 13 +++++++++++--
 hercules_opt/src/gcm.rs             |  8 ++------
 hercules_samples/matmul/src/cpu.sch |  4 ++++
 hercules_samples/matmul/src/gpu.sch |  3 +++
 juno_scheduler/src/pm.rs            | 29 ++++++++++++++++-------------
 5 files changed, 36 insertions(+), 21 deletions(-)

diff --git a/hercules_ir/src/dot.rs b/hercules_ir/src/dot.rs
index 22cd0beb..b67fecdd 100644
--- a/hercules_ir/src/dot.rs
+++ b/hercules_ir/src/dot.rs
@@ -18,6 +18,7 @@ pub fn xdot_module(
     reverse_postorders: &Vec<Vec<NodeID>>,
     doms: Option<&Vec<DomTree>>,
     fork_join_maps: Option<&Vec<HashMap<NodeID, NodeID>>>,
+    devices: Option<&Vec<Device>>,
     bbs: Option<&Vec<BasicBlocks>>,
 ) {
     let mut tmp_path = temp_dir();
@@ -31,6 +32,7 @@ pub fn xdot_module(
         &reverse_postorders,
         doms,
         fork_join_maps,
+        devices,
         bbs,
         &mut contents,
     )
@@ -53,6 +55,7 @@ pub fn write_dot<W: Write>(
     reverse_postorders: &Vec<Vec<NodeID>>,
     doms: Option<&Vec<DomTree>>,
     fork_join_maps: Option<&Vec<HashMap<NodeID, NodeID>>>,
+    devices: Option<&Vec<Device>>,
     bbs: Option<&Vec<BasicBlocks>>,
     w: &mut W,
 ) -> std::fmt::Result {
@@ -65,7 +68,12 @@ pub fn write_dot<W: Write>(
         for (idx, id) in reverse_postorder.iter().enumerate() {
             reverse_postorder_node_numbers[id.idx()] = idx;
         }
-        write_subgraph_header(function_id, module, w)?;
+        write_subgraph_header(
+            function_id,
+            module,
+            devices.map(|devices| devices[function_id.idx()]),
+            w,
+        )?;
 
         // Step 1: draw IR graph itself. This includes all IR nodes and all edges
         // between IR nodes.
@@ -204,6 +212,7 @@ fn write_digraph_header<W: Write>(w: &mut W) -> std::fmt::Result {
 fn write_subgraph_header<W: Write>(
     function_id: FunctionID,
     module: &Module,
+    device: Option<Device>,
     w: &mut W,
 ) -> std::fmt::Result {
     let function = &module.functions[function_id.idx()];
@@ -219,7 +228,7 @@ fn write_subgraph_header<W: Write>(
     } else {
         write!(w, "label=\"{}\"\n", function.name)?;
     }
-    let color = match function.device {
+    let color = match device.or(function.device) {
         Some(Device::LLVM) => "paleturquoise1",
         Some(Device::CUDA) => "darkseagreen1",
         Some(Device::AsyncRust) => "peachpuff1",
diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs
index 63123c08..e30eaacb 100644
--- a/hercules_opt/src/gcm.rs
+++ b/hercules_opt/src/gcm.rs
@@ -80,11 +80,7 @@ pub fn gcm(
     devices: &Vec<Device>,
     object_device_demands: &FunctionObjectDeviceDemands,
     backing_allocations: &BackingAllocations,
-) -> Option<(
-    BasicBlocks,
-    BTreeMap<NodeID, Device>,
-    FunctionBackingAllocation,
-)> {
+) -> Option<(BasicBlocks, FunctionBackingAllocation)> {
     let bbs = basic_blocks(
         editor.func(),
         editor.func_id(),
@@ -146,7 +142,7 @@ pub fn gcm(
         backing_allocations,
     );
 
-    Some((bbs, node_colors, backing_allocation))
+    Some((bbs, backing_allocation))
 }
 
 /*
diff --git a/hercules_samples/matmul/src/cpu.sch b/hercules_samples/matmul/src/cpu.sch
index 42dda6e3..a275a236 100644
--- a/hercules_samples/matmul/src/cpu.sch
+++ b/hercules_samples/matmul/src/cpu.sch
@@ -10,5 +10,9 @@ fork-split(*);
 unforkify(*);
 dce(*);
 float-collections(*);
+gvn(*);
+phi-elim(*);
+dce(*);
 
 gcm(*);
+xdot[true](*);
diff --git a/hercules_samples/matmul/src/gpu.sch b/hercules_samples/matmul/src/gpu.sch
index 9067a190..99ac21a6 100644
--- a/hercules_samples/matmul/src/gpu.sch
+++ b/hercules_samples/matmul/src/gpu.sch
@@ -10,6 +10,9 @@ ip-sroa(*);
 sroa(*);
 dce(*);
 float-collections(*);
+gvn(*);
+phi-elim(*);
+dce(*);
 
 gcm(*);
 xdot[true](*);
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index ae1f3dea..be30197a 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -2,13 +2,7 @@ use crate::ir::*;
 use crate::labels::*;
 use hercules_cg::*;
 use hercules_ir::*;
-use hercules_opt::FunctionEditor;
-use hercules_opt::{
-    ccp, collapse_returns, crc, dce, dumb_outline, ensure_between_control_flow, float_collections,
-    fork_split, gcm, gvn, infer_parallel_fork, infer_parallel_reduce, infer_tight_associative,
-    infer_vectorizable, inline, interprocedural_sroa, lift_dc_math, outline, phi_elim, predication,
-    slf, sroa, unforkify, write_predication,
-};
+use hercules_opt::*;
 
 use tempfile::TempDir;
 
@@ -185,11 +179,12 @@ struct PassManager {
     pub loops: Option<Vec<LoopTree>>,
     pub reduce_cycles: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>,
     pub data_nodes_in_fork_joins: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>,
-    pub bbs: Option<Vec<BasicBlocks>>,
     pub collection_objects: Option<CollectionObjects>,
     pub callgraph: Option<CallGraph>,
     pub devices: Option<Vec<Device>>,
     pub object_device_demands: Option<ObjectDeviceDemands>,
+    pub bbs: Option<Vec<BasicBlocks>>,
+    pub backing_allocations: Option<BackingAllocations>,
 }
 
 impl PassManager {
@@ -218,11 +213,12 @@ impl PassManager {
             loops: None,
             reduce_cycles: None,
             data_nodes_in_fork_joins: None,
-            bbs: None,
             collection_objects: None,
             callgraph: None,
             devices: None,
             object_device_demands: None,
+            bbs: None,
+            backing_allocations: None,
         }
     }
 
@@ -456,11 +452,11 @@ impl PassManager {
         self.loops = None;
         self.reduce_cycles = None;
         self.data_nodes_in_fork_joins = None;
-        self.bbs = None;
         self.collection_objects = None;
         self.callgraph = None;
         self.devices = None;
         self.object_device_demands = None;
+        self.bbs = None;
     }
 
     fn with_mod<B, F>(&mut self, mut f: F) -> B
@@ -509,11 +505,11 @@ impl PassManager {
             labels,
             typing: Some(typing),
             control_subgraphs: Some(control_subgraphs),
-            bbs: Some(bbs),
             collection_objects: Some(collection_objects),
             callgraph: Some(callgraph),
             devices: Some(devices),
             object_device_demands: Some(object_device_demands),
+            bbs: Some(bbs),
             ..
         } = self
         else {
@@ -1297,9 +1293,10 @@ fn run_pass(
                 let mut bbs = vec![(vec![], vec![]); topo.len()];
                 let mut backing_allocations = BTreeMap::new();
                 let mut editors = build_editors(pm);
+                let mut any_failed = false;
                 for id in topo.iter() {
                     let editor = &mut editors[id.idx()];
-                    if let Some((bb, node_colors, backing_allocation)) = gcm(
+                    if let Some((bb, backing_allocation)) = gcm(
                         editor,
                         &def_uses[id.idx()],
                         &reverse_postorders[id.idx()],
@@ -1315,13 +1312,16 @@ fn run_pass(
                     ) {
                         bbs[id.idx()] = bb;
                         backing_allocations.insert(*id, backing_allocation);
+                    } else {
+                        any_failed = true;
                     }
                     changed |= editor.modified();
                 }
                 pm.delete_gravestones();
                 pm.clear_analyses();
-                if bbs.len() == pm.functions.len() {
+                if !any_failed {
                     pm.bbs = Some(bbs);
+                    pm.backing_allocations = Some(backing_allocations);
                     break;
                 }
             }
@@ -1619,11 +1619,13 @@ fn run_pass(
             if force_analyses {
                 pm.make_doms();
                 pm.make_fork_join_maps();
+                pm.make_devices();
             }
 
             let reverse_postorders = pm.reverse_postorders.take().unwrap();
             let doms = pm.doms.take();
             let fork_join_maps = pm.fork_join_maps.take();
+            let devices = pm.devices.take();
             let bbs = pm.bbs.take();
             pm.with_mod(|module| {
                 xdot_module(
@@ -1631,6 +1633,7 @@ fn run_pass(
                     &reverse_postorders,
                     doms.as_ref(),
                     fork_join_maps.as_ref(),
+                    devices.as_ref(),
                     bbs.as_ref(),
                 )
             });
-- 
GitLab


From 9646a82c37d763d753ea6670d061c59d21acb785 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Mon, 27 Jan 2025 10:48:36 -0600
Subject: [PATCH 12/24] whoops

---
 hercules_samples/matmul/src/cpu.sch | 1 -
 1 file changed, 1 deletion(-)

diff --git a/hercules_samples/matmul/src/cpu.sch b/hercules_samples/matmul/src/cpu.sch
index a275a236..f7891b9b 100644
--- a/hercules_samples/matmul/src/cpu.sch
+++ b/hercules_samples/matmul/src/cpu.sch
@@ -15,4 +15,3 @@ phi-elim(*);
 dce(*);
 
 gcm(*);
-xdot[true](*);
-- 
GitLab


From 438b8c640adc4b2f077884b98a840ede0a428d07 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Mon, 27 Jan 2025 11:15:42 -0600
Subject: [PATCH 13/24] Re-do object allocation

---
 hercules_cg/src/lib.rs   |   2 +
 hercules_ir/src/dot.rs   |   6 +-
 hercules_opt/src/gcm.rs  | 150 +++++++++++++++------------------------
 juno_scheduler/src/pm.rs |   4 +-
 4 files changed, 67 insertions(+), 95 deletions(-)

diff --git a/hercules_cg/src/lib.rs b/hercules_cg/src/lib.rs
index 47039737..8f613408 100644
--- a/hercules_cg/src/lib.rs
+++ b/hercules_cg/src/lib.rs
@@ -8,6 +8,8 @@ pub use crate::rt::*;
 
 use hercules_ir::*;
 
+pub const LARGEST_ALIGNMENT: usize = 8;
+
 /*
  * The alignment of a type does not depend on dynamic constants.
  */
diff --git a/hercules_ir/src/dot.rs b/hercules_ir/src/dot.rs
index b67fecdd..8efabd7a 100644
--- a/hercules_ir/src/dot.rs
+++ b/hercules_ir/src/dot.rs
@@ -176,7 +176,7 @@ pub fn write_dot<W: Write>(
             }
         }
 
-        // Step 4: draw basic block edges in indigo.
+        // Step 4: draw basic block edges in blue.
         if let Some(bbs) = bbs {
             let bbs = &bbs[function_id.idx()].0;
             for (idx, bb) in bbs.into_iter().enumerate() {
@@ -187,7 +187,7 @@ pub fn write_dot<W: Write>(
                         *bb,
                         function_id,
                         true,
-                        "indigo",
+                        "lightslateblue",
                         "dotted",
                         &module,
                         w,
@@ -229,7 +229,7 @@ fn write_subgraph_header<W: Write>(
         write!(w, "label=\"{}\"\n", function.name)?;
     }
     let color = match device.or(function.device) {
-        Some(Device::LLVM) => "paleturquoise1",
+        Some(Device::LLVM) => "slategray1",
         Some(Device::CUDA) => "darkseagreen1",
         Some(Device::AsyncRust) => "peachpuff1",
         None => "ivory2",
diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs
index e30eaacb..188ac65e 100644
--- a/hercules_opt/src/gcm.rs
+++ b/hercules_opt/src/gcm.rs
@@ -6,7 +6,7 @@ use bitvec::prelude::*;
 use either::Either;
 use union_find::{QuickFindUf, UnionBySize, UnionFind};
 
-use hercules_cg::get_type_alignment;
+use hercules_cg::{get_type_alignment, LARGEST_ALIGNMENT};
 use hercules_ir::*;
 
 use crate::*;
@@ -132,12 +132,23 @@ pub fn gcm(
         }
     }
 
-    let object_sizes = object_sizes(editor, typing, &objects[&func_id]);
+    let mut alignments = vec![];
+    Ref::map(editor.get_types(), |types| {
+        for idx in 0..types.len() {
+            if types[idx].is_control() {
+                alignments.push(0);
+            } else {
+                alignments.push(get_type_alignment(types, TypeID::new(idx)));
+            }
+        }
+        &()
+    });
+
     let backing_allocation = object_allocation(
         editor,
-        &objects[&func_id],
-        &object_device_demands,
-        &object_sizes,
+        typing,
+        &node_colors,
+        &alignments,
         &liveness,
         backing_allocations,
     );
@@ -1047,46 +1058,6 @@ fn color_nodes(
     Some(colors)
 }
 
-/*
- * Determine the size of objects in terms of dynamic constants based on typing.
- */
-fn object_sizes(
-    editor: &mut FunctionEditor,
-    typing: &Vec<TypeID>,
-    objects: &FunctionCollectionObjects,
-) -> Vec<DynamicConstantID> {
-    // First, compute the alignments of every type.
-    let mut alignments = vec![];
-    Ref::map(editor.get_types(), |types| {
-        for idx in 0..types.len() {
-            if types[idx].is_control() {
-                alignments.push(0);
-            } else {
-                alignments.push(get_type_alignment(types, TypeID::new(idx)));
-            }
-        }
-        &()
-    });
-
-    // Second, actually compute object sizes.
-    let mut sizes = vec![];
-    for id in objects.iter_objects() {
-        let ty_id = match objects.origin(id) {
-            CollectionObjectOrigin::Parameter(idx) => editor.func().param_types[idx],
-            CollectionObjectOrigin::Constant(id)
-            | CollectionObjectOrigin::Call(id)
-            | CollectionObjectOrigin::Undef(id) => typing[id.idx()],
-        };
-        let success = editor.edit(|mut edit| {
-            sizes.push(type_size(&mut edit, ty_id, &alignments));
-            Ok(edit)
-        });
-        assert!(success);
-    }
-
-    sizes
-}
-
 fn align(edit: &mut FunctionEdit, mut acc: DynamicConstantID, align: usize) -> DynamicConstantID {
     assert_ne!(align, 0);
     if align != 1 {
@@ -1162,16 +1133,11 @@ fn type_size(edit: &mut FunctionEdit, ty_id: TypeID, alignments: &Vec<usize>) ->
 
 /*
  * The allocation information of each function is a size of the backing memory
- * needed and offsets into that backing memory per constant object and call
- * object in the function. This is determined per device.
+ * needed and offsets into that backing memory per constant object and call node
+ * in the function.
  */
-pub type FunctionBackingAllocation = BTreeMap<
-    Device,
-    (
-        DynamicConstantID,
-        BTreeMap<CollectionObjectID, DynamicConstantID>,
-    ),
->;
+pub type FunctionBackingAllocation =
+    BTreeMap<Device, (DynamicConstantID, BTreeMap<NodeID, DynamicConstantID>)>;
 pub type BackingAllocations = BTreeMap<FunctionID, FunctionBackingAllocation>;
 
 /*
@@ -1180,52 +1146,54 @@ pub type BackingAllocations = BTreeMap<FunctionID, FunctionBackingAllocation>;
  */
 fn object_allocation(
     editor: &mut FunctionEditor,
-    objects: &FunctionCollectionObjects,
-    object_device_demands: &FunctionObjectDeviceDemands,
-    object_sizes: &Vec<DynamicConstantID>,
+    typing: &Vec<TypeID>,
+    node_colors: &BTreeMap<NodeID, Device>,
+    alignments: &Vec<usize>,
     liveness: &Liveness,
     backing_allocations: &BackingAllocations,
 ) -> FunctionBackingAllocation {
     let mut fba = BTreeMap::new();
 
+    let devices = &[Device::LLVM, Device::CUDA];
+    let node_ids = editor.node_ids();
     editor.edit(|mut edit| {
         // For now, just allocate each object to its own slot.
-        for id in objects.iter_objects() {
-            match objects.origin(id) {
-                CollectionObjectOrigin::Constant(_) => {
-                    let demands = &object_device_demands[id.idx()];
-                    assert_eq!(demands.len(), 1);
-                    let device = *demands.first().unwrap();
-                    let (total, offsets) = fba.entry(device).or_insert_with(|| {
-                        (
-                            edit.add_dynamic_constant(DynamicConstant::Constant(0)),
-                            BTreeMap::new(),
-                        )
-                    });
-                    *total = align(&mut edit, *total, 8);
+        let zero = edit.add_dynamic_constant(DynamicConstant::Constant(0));
+        for id in node_ids {
+            match *edit.get_node(id) {
+                Node::Constant { id: _ } => {
+                    let device = node_colors[&id];
+                    let (total, offsets) =
+                        fba.entry(device).or_insert_with(|| (zero, BTreeMap::new()));
+                    *total = align(&mut edit, *total, alignments[typing[id.idx()].idx()]);
                     offsets.insert(id, *total);
-
-                    *total = edit
-                        .add_dynamic_constant(DynamicConstant::Add(*total, object_sizes[id.idx()]));
+                    let type_size = type_size(&mut edit, typing[id.idx()], alignments);
+                    *total = edit.add_dynamic_constant(DynamicConstant::Add(*total, type_size));
                 }
-                CollectionObjectOrigin::Call(node_id) => {
-                    let demands = &object_device_demands[id.idx()];
-                    assert_eq!(demands.len(), 1);
-                    let device = *demands.first().unwrap();
-                    let (total, offsets) = fba.entry(device).or_insert_with(|| {
-                        (
-                            edit.add_dynamic_constant(DynamicConstant::Constant(0)),
-                            BTreeMap::new(),
-                        )
-                    });
-                    *total = align(&mut edit, *total, 8);
-                    offsets.insert(id, *total);
-
-                    let callee = edit.get_node(node_id).try_call().unwrap().1;
-                    *total = edit.add_dynamic_constant(DynamicConstant::Add(
-                        *total,
-                        backing_allocations[&callee][&device].0,
-                    ));
+                Node::Call {
+                    control: _,
+                    function: callee,
+                    dynamic_constants: _,
+                    args: _,
+                } => {
+                    for device in devices {
+                        let (total, offsets) = fba
+                            .entry(*device)
+                            .or_insert_with(|| (zero, BTreeMap::new()));
+                        if let Some(callee_backing_size) = backing_allocations[&callee]
+                            .get(&device)
+                            .map(|(callee_total, _)| *callee_total)
+                        {
+                            // We don't know the alignment requirement of the memory
+                            // in the callee, so just assume the largest alignment.
+                            *total = align(&mut edit, *total, LARGEST_ALIGNMENT);
+                            offsets.insert(id, *total);
+                            *total = edit.add_dynamic_constant(DynamicConstant::Add(
+                                *total,
+                                callee_backing_size,
+                            ));
+                        }
+                    }
                 }
                 _ => {}
             }
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index be30197a..bb4425c9 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -457,6 +457,7 @@ impl PassManager {
         self.devices = None;
         self.object_device_demands = None;
         self.bbs = None;
+        self.backing_allocations = None;
     }
 
     fn with_mod<B, F>(&mut self, mut f: F) -> B
@@ -510,12 +511,13 @@ impl PassManager {
             devices: Some(devices),
             object_device_demands: Some(object_device_demands),
             bbs: Some(bbs),
+            backing_allocations: Some(backing_allocations),
             ..
         } = self
         else {
             return Err(SchedulerError::PassError {
                 pass: "codegen".to_string(),
-                error: "Missing basic blocks".to_string(),
+                error: "Missing basic blocks or backing allocations".to_string(),
             });
         };
 
-- 
GitLab
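
With the reworked allocation, a call node reserves a sub-region on every device
where the callee has a backing allocation, aligned conservatively to
LARGEST_ALIGNMENT because the callee's internal alignment requirements are not
tracked across the call boundary. A usize sketch of that reservation step; the
maps and the "call#5" key are illustrative, not the real NodeID-keyed structures.

    use std::collections::BTreeMap;

    const LARGEST_ALIGNMENT: usize = 8;

    fn round_up(x: usize, a: usize) -> usize {
        (x + a - 1) / a * a
    }

    fn main() {
        // Callee backing sizes per device (what backing_allocations provides).
        let callee: BTreeMap<&str, usize> = BTreeMap::from([("cpu", 24), ("cuda", 12)]);

        // Caller's running totals per device, with one object already placed.
        let mut caller: BTreeMap<&str, (usize, BTreeMap<&str, usize>)> =
            BTreeMap::from([("cpu", (4, BTreeMap::new()))]);

        // Reserve a sub-region for the call on every device the callee uses.
        for (&device, &callee_total) in &callee {
            let (total, offsets) = caller.entry(device).or_insert((0, BTreeMap::new()));
            // Conservative alignment: the callee's internal needs are unknown.
            *total = round_up(*total, LARGEST_ALIGNMENT);
            offsets.insert("call#5", *total);
            *total += callee_total;
        }

        assert_eq!(caller["cpu"].0, 8 + 24); // 4 rounded up to 8, plus the callee's 24
        assert_eq!(caller["cuda"].0, 12);    // new device entry starting at offset 0
    }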


From 53a49888503c55e10b8d84ecbb84445a1aa92ea4 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Mon, 27 Jan 2025 11:18:28 -0600
Subject: [PATCH 14/24] fix

---
 hercules_opt/src/gcm.rs | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs
index 188ac65e..35608743 100644
--- a/hercules_opt/src/gcm.rs
+++ b/hercules_opt/src/gcm.rs
@@ -1162,13 +1162,15 @@ fn object_allocation(
         for id in node_ids {
             match *edit.get_node(id) {
                 Node::Constant { id: _ } => {
-                    let device = node_colors[&id];
-                    let (total, offsets) =
-                        fba.entry(device).or_insert_with(|| (zero, BTreeMap::new()));
-                    *total = align(&mut edit, *total, alignments[typing[id.idx()].idx()]);
-                    offsets.insert(id, *total);
-                    let type_size = type_size(&mut edit, typing[id.idx()], alignments);
-                    *total = edit.add_dynamic_constant(DynamicConstant::Add(*total, type_size));
+                    if !edit.get_type(typing[id.idx()]).is_primitive() {
+                        let device = node_colors[&id];
+                        let (total, offsets) =
+                            fba.entry(device).or_insert_with(|| (zero, BTreeMap::new()));
+                        *total = align(&mut edit, *total, alignments[typing[id.idx()].idx()]);
+                        offsets.insert(id, *total);
+                        let type_size = type_size(&mut edit, typing[id.idx()], alignments);
+                        *total = edit.add_dynamic_constant(DynamicConstant::Add(*total, type_size));
+                    }
                 }
                 Node::Call {
                     control: _,
-- 
GitLab


From c97eb8a43e006d8e21f6c502959a8988b6f32971 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Mon, 27 Jan 2025 13:18:28 -0600
Subject: [PATCH 15/24] Emit RT functions correctly

---
 hercules_cg/src/lib.rs   |  12 ++
 hercules_cg/src/rt.rs    | 353 ++++++++-------------------------------
 hercules_ir/src/ir.rs    |  10 ++
 hercules_opt/src/gcm.rs  |  19 +--
 juno_scheduler/src/pm.rs |   6 +-
 5 files changed, 100 insertions(+), 300 deletions(-)

diff --git a/hercules_cg/src/lib.rs b/hercules_cg/src/lib.rs
index 8f613408..ae95ae78 100644
--- a/hercules_cg/src/lib.rs
+++ b/hercules_cg/src/lib.rs
@@ -6,6 +6,8 @@ pub mod rt;
 pub use crate::cpu::*;
 pub use crate::rt::*;
 
+use std::collections::BTreeMap;
+
 use hercules_ir::*;
 
 pub const LARGEST_ALIGNMENT: usize = 8;
@@ -28,3 +30,13 @@ pub fn get_type_alignment(types: &Vec<Type>, ty: TypeID) -> usize {
         Type::Array(elem, _) => get_type_alignment(types, elem),
     }
 }
+
+/*
+ * The allocation information of each function is a size of the backing memory
+ * needed and offsets into that backing memory per constant object and call node
+ * in the function.
+ */
+pub type FunctionBackingAllocation =
+    BTreeMap<Device, (DynamicConstantID, BTreeMap<NodeID, DynamicConstantID>)>;
+pub type BackingAllocations = BTreeMap<FunctionID, FunctionBackingAllocation>;
+pub const BACKED_DEVICES: [Device; 2] = [Device::LLVM, Device::CUDA];
diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs
index 6ee513b8..692fce78 100644
--- a/hercules_cg/src/rt.rs
+++ b/hercules_cg/src/rt.rs
@@ -78,11 +78,11 @@ pub fn rt_codegen<W: Write>(
     module: &Module,
     typing: &Vec<TypeID>,
     control_subgraph: &Subgraph,
-    bbs: &BasicBlocks,
     collection_objects: &CollectionObjects,
     callgraph: &CallGraph,
     devices: &Vec<Device>,
-    object_device_demands: &FunctionObjectDeviceDemands,
+    bbs: &BasicBlocks,
+    backing_allocation: &FunctionBackingAllocation,
     w: &mut W,
 ) -> Result<(), Error> {
     let ctx = RTContext {
@@ -90,11 +90,11 @@ pub fn rt_codegen<W: Write>(
         module,
         typing,
         control_subgraph,
-        bbs,
         collection_objects,
         callgraph,
         devices,
-        object_device_demands,
+        bbs,
+        backing_allocation,
     };
     ctx.codegen_function(w)
 }
@@ -104,11 +104,11 @@ struct RTContext<'a> {
     module: &'a Module,
     typing: &'a Vec<TypeID>,
     control_subgraph: &'a Subgraph,
-    bbs: &'a BasicBlocks,
     collection_objects: &'a CollectionObjects,
     callgraph: &'a CallGraph,
     devices: &'a Vec<Device>,
-    object_device_demands: &'a FunctionObjectDeviceDemands,
+    bbs: &'a BasicBlocks,
+    backing_allocation: &'a FunctionBackingAllocation,
 }
 
 impl<'a> RTContext<'a> {
@@ -118,11 +118,20 @@ impl<'a> RTContext<'a> {
         // Dump the function signature.
         write!(
             w,
-            "#[allow(unused_variables,unused_mut,unused_parens,unused_unsafe)]\nasync fn {}<'a>(",
+            "#[allow(unused_variables,unused_mut,unused_parens,unused_unsafe)]\nasync unsafe fn {}(",
             func.name
         )?;
         let mut first_param = true;
-        // The first set of parameters are dynamic constants.
+        // The first set of parameters are pointers to backing memories.
+        for (device, _) in self.backing_allocation {
+            if first_param {
+                first_param = false;
+            } else {
+                write!(w, ", ")?;
+            }
+            write!(w, "backing_{}: *mut u8", device.name())?;
+        }
+        // The second set of parameters are dynamic constants.
         for idx in 0..func.num_dynamic_constants {
             if first_param {
                 first_param = false;
@@ -131,7 +140,7 @@ impl<'a> RTContext<'a> {
             }
             write!(w, "dc_p{}: u64", idx)?;
         }
-        // The second set of parameters are normal parameters.
+        // The third set of parameters are normal parameters.
         for idx in 0..func.param_types.len() {
             if first_param {
                 first_param = false;
@@ -145,21 +154,6 @@ impl<'a> RTContext<'a> {
         }
         write!(w, ") -> {} {{\n", self.get_type(func.return_type))?;
 
-        // Allocate collection constants.
-        for object in self.collection_objects[&self.func_id].iter_objects() {
-            if let CollectionObjectOrigin::Constant(id) =
-                self.collection_objects[&self.func_id].origin(object)
-            {
-                let size = self.codegen_type_size(self.typing[id.idx()]);
-                write!(
-                    w,
-                    "    let mut obj{}: ::hercules_rt::HerculesBox = unsafe {{ ::hercules_rt::HerculesBox::__zeros({}) }};\n",
-                    object.idx(),
-                    size
-                )?
-            }
-        }
-
         // Dump signatures for called device functions.
         write!(w, "    extern \"C\" {{\n")?;
         for callee in self.callgraph.get_callees(self.func_id) {
@@ -183,9 +177,9 @@ impl<'a> RTContext<'a> {
                 } else {
                     write!(w, ", ")?;
                 }
-                write!(w, "p{}: {}", idx, self.device_get_type(*ty))?;
+                write!(w, "p{}: {}", idx, self.get_type(*ty))?;
             }
-            write!(w, ") -> {};\n", self.device_get_type(callee.return_type))?;
+            write!(w, ") -> {};\n", self.get_type(callee.return_type))?;
         }
         write!(w, "    }}\n")?;
 
@@ -204,7 +198,7 @@ impl<'a> RTContext<'a> {
                 } else if self.module.types[self.typing[idx].idx()].is_float() {
                     "0.0"
                 } else {
-                    "unsafe { ::hercules_rt::HerculesBox::__null() }"
+                    "::core::ptr::null::<u8>() as _"
                 }
             )?;
         }
@@ -214,7 +208,7 @@ impl<'a> RTContext<'a> {
         // blocks to drive execution.
         write!(
             w,
-            "    let mut control_token: i8 = 0;\n    let return_value = loop {{\n        match control_token {{\n",
+            "    let mut control_token: i8 = 0;\n    loop {{\n        match control_token {{\n",
         )?;
 
         let mut blocks: BTreeMap<_, _> = (0..func.nodes.len())
@@ -248,39 +242,7 @@ impl<'a> RTContext<'a> {
         }
 
         // Close the match and loop.
-        write!(w, "            _ => panic!()\n        }}\n    }};\n")?;
-
-        // Emit the epilogue of the function.
-        write!(w, "    unsafe {{\n")?;
-        for idx in 0..func.param_types.len() {
-            if !self.module.types[func.param_types[idx].idx()].is_primitive() {
-                write!(w, "        p{}.__forget();\n", idx)?;
-            }
-        }
-        if !self.module.types[func.return_type.idx()].is_primitive() {
-            for object in self.collection_objects[&self.func_id].iter_objects() {
-                if let CollectionObjectOrigin::Constant(_) =
-                    self.collection_objects[&self.func_id].origin(object)
-                {
-                    write!(
-                        w,
-                        "        if obj{}.__cmp_ids(&return_value) {{\n",
-                        object.idx()
-                    )?;
-                    write!(w, "            obj{}.__forget();\n", object.idx())?;
-                    write!(w, "        }}\n")?;
-                }
-            }
-        }
-        for idx in 0..func.nodes.len() {
-            if !func.nodes[idx].is_control()
-                && !self.module.types[self.typing[idx].idx()].is_primitive()
-            {
-                write!(w, "        node_{}.__forget();\n", idx)?;
-            }
-        }
-        write!(w, "    }}\n")?;
-        write!(w, "    return_value\n")?;
+        write!(w, "            _ => panic!()\n        }}\n    }}\n")?;
         write!(w, "}}\n")?;
         Ok(())
     }
@@ -328,15 +290,7 @@ impl<'a> RTContext<'a> {
             }
             Node::Return { control: _, data } => {
                 let block = &mut blocks.get_mut(&id).unwrap();
-                if self.module.types[self.typing[data.idx()].idx()].is_primitive() {
-                    write!(block, "                break {};\n", self.get_value(data))?
-                } else {
-                    write!(
-                        block,
-                        "                break unsafe {{ {}.__clone() }};\n",
-                        self.get_value(data)
-                    )?
-                }
+                write!(block, "                return {};\n", self.get_value(data))?
             }
             _ => panic!("PANIC: Can't lower {:?}.", func.nodes[id.idx()]),
         }
@@ -355,21 +309,12 @@ impl<'a> RTContext<'a> {
         match func.nodes[id.idx()] {
             Node::Parameter { index } => {
                 let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap();
-                if self.module.types[self.typing[id.idx()].idx()].is_primitive() {
-                    write!(
-                        block,
-                        "                {} = p{};\n",
-                        self.get_value(id),
-                        index
-                    )?
-                } else {
-                    write!(
-                        block,
-                        "                {} = unsafe {{ p{}.__clone() }};\n",
-                        self.get_value(id),
-                        index
-                    )?
-                }
+                write!(
+                    block,
+                    "                {} = p{};\n",
+                    self.get_value(id),
+                    index
+                )?
             }
             Node::Constant { id: cons_id } => {
                 let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap();
@@ -387,10 +332,17 @@ impl<'a> RTContext<'a> {
                     Constant::Float32(val) => write!(block, "{}f32", val)?,
                     Constant::Float64(val) => write!(block, "{}f64", val)?,
                     Constant::Product(_, _) | Constant::Summation(_, _, _) | Constant::Array(_) => {
-                        let objects = self.collection_objects[&self.func_id].objects(id);
-                        assert_eq!(objects.len(), 1);
-                        let object = objects[0];
-                        write!(block, "unsafe {{ obj{}.__clone() }}", object.idx())?
+                        let (device, offset) = self
+                            .backing_allocation
+                            .into_iter()
+                            .filter_map(|(device, (_, offsets))| {
+                                offsets.get(&id).map(|id| (*device, *id))
+                            })
+                            .next()
+                            .unwrap();
+                        write!(block, "backing_{}.byte_add(", device.name())?;
+                        self.codegen_dynamic_constant(offset, block)?;
+                        write!(block, ")")?
                     }
                 }
                 write!(block, ";\n")?
@@ -401,123 +353,36 @@ impl<'a> RTContext<'a> {
                 ref dynamic_constants,
                 ref args,
             } => {
+                // The device backends ensure that device functions have the
+                // same interface as AsyncRust functions.
+                let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap();
+                write!(
+                    block,
+                    "                {} = {}(",
+                    self.get_value(id),
+                    self.module.functions[callee_id.idx()].name
+                )?;
+                for (device, offset) in self
+                    .backing_allocation
+                    .into_iter()
+                    .filter_map(|(device, (_, offsets))| offsets.get(&id).map(|id| (*device, *id)))
+                {
+                    write!(block, "backing_{}.byte_add(", device.name())?;
+                    self.codegen_dynamic_constant(offset, block)?;
+                    write!(block, "), ")?
+                }
+                for dc in dynamic_constants {
+                    self.codegen_dynamic_constant(*dc, block)?;
+                    write!(block, ", ")?;
+                }
+                for arg in args {
+                    write!(block, "{}, ", self.get_value(*arg))?;
+                }
                 let device = self.devices[callee_id.idx()];
-                match device {
-                    // The device backends ensure that device functions have the
-                    // same C interface.
-                    Device::LLVM | Device::CUDA => {
-                        let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap();
-
-                        let device = match device {
-                            Device::LLVM => "cpu",
-                            Device::CUDA => "cuda",
-                            _ => panic!(),
-                        };
-
-                        // First, get the raw pointers to collections that the
-                        // device function takes as input.
-                        let callee_objs = &self.collection_objects[&callee_id];
-                        for (idx, arg) in args.into_iter().enumerate() {
-                            if let Some(obj) = callee_objs.param_to_object(idx) {
-                                // Extract a raw pointer from the HerculesBox.
-                                if callee_objs.is_mutated(obj) {
-                                    write!(
-                                        block,
-                                        "                let arg_tmp{} = unsafe {{ {}.__{}_ptr_mut() }};\n",
-                                        idx,
-                                        self.get_value(*arg),
-                                        device
-                                    )?;
-                                } else {
-                                    write!(
-                                        block,
-                                        "                let arg_tmp{} = unsafe {{ {}.__{}_ptr() }};\n",
-                                        idx,
-                                        self.get_value(*arg),
-                                        device
-                                    )?;
-                                }
-                            } else {
-                                write!(
-                                    block,
-                                    "                let arg_tmp{} = {};\n",
-                                    idx,
-                                    self.get_value(*arg)
-                                )?;
-                            }
-                        }
-
-                        // Emit the call.
-                        write!(
-                            block,
-                            "                let call_tmp = unsafe {{ {}(",
-                            self.module.functions[callee_id.idx()].name
-                        )?;
-                        for dc in dynamic_constants {
-                            self.codegen_dynamic_constant(*dc, block)?;
-                            write!(block, ", ")?;
-                        }
-                        for idx in 0..args.len() {
-                            write!(block, "arg_tmp{}, ", idx)?;
-                        }
-                        write!(block, ") }};\n")?;
-
-                        // When a device function is called that returns a
-                        // collection object, that object must have come from
-                        // one of its parameters. Dynamically figure out which
-                        // one it came from, so that we can move it to the slot
-                        // of the output object.
-                        let caller_objects = self.collection_objects[&self.func_id].objects(id);
-                        if !caller_objects.is_empty() {
-                            for (idx, arg) in args.into_iter().enumerate() {
-                                if idx != 0 {
-                                    write!(block, "                else\n")?;
-                                }
-                                write!(
-                                    block,
-                                    "                if call_tmp == arg_tmp{} {{\n",
-                                    idx
-                                )?;
-                                write!(
-                                    block,
-                                    "                    {} = unsafe {{ {}.__clone() }};\n",
-                                    self.get_value(id),
-                                    self.get_value(*arg)
-                                )?;
-                                write!(block, "                }}")?;
-                            }
-                            write!(block, "                else {{\n")?;
-                            write!(block, "                    panic!(\"HERCULES PANIC: Pointer returned from device function doesn't match an argument pointer.\");\n")?;
-                            write!(block, "                }}\n")?;
-                        } else {
-                            write!(
-                                block,
-                                "                {} = call_tmp;\n",
-                                self.get_value(id)
-                            )?;
-                        }
-                    }
-                    Device::AsyncRust => {
-                        let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap();
-                        write!(
-                            block,
-                            "                {} = {}(",
-                            self.get_value(id),
-                            self.module.functions[callee_id.idx()].name
-                        )?;
-                        for dc in dynamic_constants {
-                            self.codegen_dynamic_constant(*dc, block)?;
-                            write!(block, ", ")?;
-                        }
-                        for arg in args {
-                            if self.module.types[self.typing[arg.idx()].idx()].is_primitive() {
-                                write!(block, "{}, ", self.get_value(*arg))?;
-                            } else {
-                                write!(block, "unsafe {{ {}.__clone() }}, ", self.get_value(*arg))?;
-                            }
-                        }
-                        write!(block, ").await;\n")?;
-                    }
+                if device == Device::AsyncRust {
+                    write!(block, ").await;\n")?;
+                } else {
+                    write!(block, ");\n")?;
                 }
             }
             Node::Read {
@@ -528,33 +393,7 @@ impl<'a> RTContext<'a> {
                 let collect_ty = self.typing[collect.idx()];
                 let out_size = self.codegen_type_size(self.typing[id.idx()]);
                 let offset = self.codegen_index_math(collect_ty, indices)?;
-                write!(
-                    block,
-                    "                let mut read_offset_obj = unsafe {{ {}.__clone() }};\n",
-                    self.get_value(collect)
-                )?;
-                write!(
-                    block,
-                    "                unsafe {{ read_offset_obj.__offset({}, {}) }};\n",
-                    offset, out_size,
-                )?;
-                if self.module.types[self.typing[id.idx()].idx()].is_primitive() {
-                    write!(
-                        block,
-                        "                {} = unsafe {{ *(read_offset_obj.__cpu_ptr() as *const _) }};\n",
-                        self.get_value(id)
-                    )?;
-                    write!(
-                        block,
-                        "                unsafe {{ read_offset_obj.__forget() }};\n",
-                    )?;
-                } else {
-                    write!(
-                        block,
-                        "                {} = read_offset_obj;\n",
-                        self.get_value(id)
-                    )?;
-                }
+                todo!();
             }
             Node::Write {
                 collect,
@@ -565,31 +404,7 @@ impl<'a> RTContext<'a> {
                 let collect_ty = self.typing[collect.idx()];
                 let data_size = self.codegen_type_size(self.typing[data.idx()]);
                 let offset = self.codegen_index_math(collect_ty, indices)?;
-                write!(
-                    block,
-                    "                let mut write_offset_obj = unsafe {{ {}.__clone() }};\n",
-                    self.get_value(collect)
-                )?;
-                write!(block, "                let write_offset_ptr = unsafe {{ write_offset_obj.__cpu_ptr_mut().byte_add({}) }};\n", offset)?;
-                if self.module.types[self.typing[data.idx()].idx()].is_primitive() {
-                    write!(
-                        block,
-                        "                unsafe {{ *(write_offset_ptr as *mut _) = {} }};\n",
-                        self.get_value(data)
-                    )?;
-                } else {
-                    write!(
-                        block,
-                        "                unsafe {{ ::core::ptr::copy_nonoverlapping({}.__cpu_ptr(), write_offset_ptr as *mut _, {} as usize) }};\n",
-                        self.get_value(data),
-                        data_size,
-                    )?;
-                }
-                write!(
-                    block,
-                    "                {} = write_offset_obj;\n",
-                    self.get_value(id),
-                )?;
+                todo!();
             }
             _ => panic!(
                 "PANIC: Can't lower {:?} in {}.",
@@ -819,33 +634,9 @@ impl<'a> RTContext<'a> {
     fn get_type(&self, id: TypeID) -> &'static str {
         convert_type(&self.module.types[id.idx()])
     }
-
-    fn device_get_type(&self, id: TypeID) -> &'static str {
-        device_convert_type(&self.module.types[id.idx()])
-    }
 }
 
 fn convert_type(ty: &Type) -> &'static str {
-    match ty {
-        Type::Boolean => "bool",
-        Type::Integer8 => "i8",
-        Type::Integer16 => "i16",
-        Type::Integer32 => "i32",
-        Type::Integer64 => "i64",
-        Type::UnsignedInteger8 => "u8",
-        Type::UnsignedInteger16 => "u16",
-        Type::UnsignedInteger32 => "u32",
-        Type::UnsignedInteger64 => "u64",
-        Type::Float32 => "f32",
-        Type::Float64 => "f64",
-        Type::Product(_) | Type::Summation(_) | Type::Array(_, _) => {
-            "::hercules_rt::HerculesBox<'a>"
-        }
-        _ => panic!(),
-    }
-}
-
-fn device_convert_type(ty: &Type) -> &'static str {
     match ty {
         Type::Boolean => "bool",
         Type::Integer8 => "i8",
diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs
index 5577228f..d8a124e2 100644
--- a/hercules_ir/src/ir.rs
+++ b/hercules_ir/src/ir.rs
@@ -1710,6 +1710,16 @@ impl Intrinsic {
     }
 }
 
+impl Device {
+    pub fn name(&self) -> &'static str {
+        match self {
+            Device::LLVM => "cpu",
+            Device::CUDA => "cuda",
+            Device::AsyncRust => "rt",
+        }
+    }
+}
+
 /*
  * Rust things to make newtyped IDs usable.
  */
diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs
index 35608743..f85ac18a 100644
--- a/hercules_opt/src/gcm.rs
+++ b/hercules_opt/src/gcm.rs
@@ -6,7 +6,7 @@ use bitvec::prelude::*;
 use either::Either;
 use union_find::{QuickFindUf, UnionBySize, UnionFind};
 
-use hercules_cg::{get_type_alignment, LARGEST_ALIGNMENT};
+use hercules_cg::*;
 use hercules_ir::*;
 
 use crate::*;
@@ -1131,15 +1131,6 @@ fn type_size(edit: &mut FunctionEdit, ty_id: TypeID, alignments: &Vec<usize>) ->
     size
 }
 
-/*
- * The allocation information of each function is a size of the backing memory
- * needed and offsets into that backing memory per constant object and call node
- * in the function.
- */
-pub type FunctionBackingAllocation =
-    BTreeMap<Device, (DynamicConstantID, BTreeMap<NodeID, DynamicConstantID>)>;
-pub type BackingAllocations = BTreeMap<FunctionID, FunctionBackingAllocation>;
-
 /*
  * Allocate objects in a function. Relies on the allocations of all called
  * functions.
@@ -1154,7 +1145,6 @@ fn object_allocation(
 ) -> FunctionBackingAllocation {
     let mut fba = BTreeMap::new();
 
-    let devices = &[Device::LLVM, Device::CUDA];
     let node_ids = editor.node_ids();
     editor.edit(|mut edit| {
         // For now, just allocate each object to its own slot.
@@ -1178,14 +1168,13 @@ fn object_allocation(
                     dynamic_constants: _,
                     args: _,
                 } => {
-                    for device in devices {
-                        let (total, offsets) = fba
-                            .entry(*device)
-                            .or_insert_with(|| (zero, BTreeMap::new()));
+                    for device in BACKED_DEVICES {
                         if let Some(callee_backing_size) = backing_allocations[&callee]
                             .get(&device)
                             .map(|(callee_total, _)| *callee_total)
                         {
+                            let (total, offsets) =
+                                fba.entry(device).or_insert_with(|| (zero, BTreeMap::new()));
                             // We don't know the alignment requirement of the memory
                             // in the callee, so just assume the largest alignment.
                             *total = align(&mut edit, *total, LARGEST_ALIGNMENT);
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index bb4425c9..5bf89f4d 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -496,7 +496,6 @@ impl PassManager {
         self.make_collection_objects();
         self.make_callgraph();
         self.make_devices();
-        self.make_object_device_demands();
 
         let PassManager {
             functions,
@@ -509,7 +508,6 @@ impl PassManager {
             collection_objects: Some(collection_objects),
             callgraph: Some(callgraph),
             devices: Some(devices),
-            object_device_demands: Some(object_device_demands),
             bbs: Some(bbs),
             backing_allocations: Some(backing_allocations),
             ..
@@ -552,11 +550,11 @@ impl PassManager {
                     &module,
                     &typing[idx],
                     &control_subgraphs[idx],
-                    &bbs[idx],
                     &collection_objects,
                     &callgraph,
                     &devices,
-                    &object_device_demands[idx],
+                    &bbs[idx],
+                    &backing_allocations[&FunctionID::new(idx)],
                     &mut rust_rt,
                 )
                 .map_err(|e| SchedulerError::PassError {
-- 
GitLab
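
For orientation, the reworked lowering above flattens a call into a single argument list: the relevant backing pointer(s) first, then dynamic constants, then value arguments, with .await appended only when the callee is an AsyncRust function. A hedged sketch of the emitted text, using hypothetical node, callee, and dynamic-constant names:

    // Call to a CPU or CUDA device function: backing pointer, dynamic
    // constants, then arguments (trailing separators as emitted).
    node_7 = bar(backing_cpu.byte_add(dc5 as usize), dc_p0, node_3, );
    // Call to another AsyncRust function: same argument order, but awaited.
    node_8 = baz(dc_p0, node_3, ).await;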


From 48f38e306373d2c28b2912a430e3e574c0461606 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Mon, 27 Jan 2025 17:04:37 -0600
Subject: [PATCH 16/24] Skeleton of runner type

---
 hercules_cg/src/lib.rs   |  7 ++++
 hercules_cg/src/rt.rs    | 80 +++++++++++++++++++++++++++++++++++++++-
 hercules_opt/src/gcm.rs  |  8 ++--
 juno_scheduler/src/pm.rs | 13 ++++++-
 4 files changed, 100 insertions(+), 8 deletions(-)

diff --git a/hercules_cg/src/lib.rs b/hercules_cg/src/lib.rs
index ae95ae78..6a12901f 100644
--- a/hercules_cg/src/lib.rs
+++ b/hercules_cg/src/lib.rs
@@ -31,6 +31,13 @@ pub fn get_type_alignment(types: &Vec<Type>, ty: TypeID) -> usize {
     }
 }
 
+/*
+ * Nodes producing collection values are "colored" by the device their
+ * underlying memory lives on.
+ */
+pub type FunctionNodeColors = BTreeMap<NodeID, Device>;
+pub type NodeColors = Vec<FunctionNodeColors>;
+
 /*
  * The allocation information of each function is a size of the backing memory
  * needed and offsets into that backing memory per constant object and call node
diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs
index 692fce78..8800e04c 100644
--- a/hercules_cg/src/rt.rs
+++ b/hercules_cg/src/rt.rs
@@ -82,6 +82,7 @@ pub fn rt_codegen<W: Write>(
     callgraph: &CallGraph,
     devices: &Vec<Device>,
     bbs: &BasicBlocks,
+    node_colors: &FunctionNodeColors,
     backing_allocation: &FunctionBackingAllocation,
     w: &mut W,
 ) -> Result<(), Error> {
@@ -94,6 +95,7 @@ pub fn rt_codegen<W: Write>(
         callgraph,
         devices,
         bbs,
+        node_colors,
         backing_allocation,
     };
     ctx.codegen_function(w)
@@ -108,14 +110,18 @@ struct RTContext<'a> {
     callgraph: &'a CallGraph,
     devices: &'a Vec<Device>,
     bbs: &'a BasicBlocks,
+    node_colors: &'a FunctionNodeColors,
     backing_allocation: &'a FunctionBackingAllocation,
 }
 
 impl<'a> RTContext<'a> {
     fn codegen_function<W: Write>(&self, w: &mut W) -> Result<(), Error> {
-        let func = &self.get_func();
+        // If this is an entry function, generate a corresponding runner object
+        // type definition.
+        self.codegen_runner_object(w)?;
 
         // Dump the function signature.
+        let func = &self.get_func();
         write!(
             w,
             "#[allow(unused_variables,unused_mut,unused_parens,unused_unsafe)]\nasync unsafe fn {}(",
@@ -198,7 +204,7 @@ impl<'a> RTContext<'a> {
                 } else if self.module.types[self.typing[idx].idx()].is_float() {
                     "0.0"
                 } else {
-                    "::core::ptr::null::<u8>() as _"
+                    "::core::ptr::null_mut()"
                 }
             )?;
         }
@@ -623,6 +629,76 @@ impl<'a> RTContext<'a> {
         }
     }
 
+    /*
+     * Generate a runner object for this function.
+     */
+    fn codegen_runner_object<W: Write>(&self, w: &mut W) -> Result<(), Error> {
+        // Figure out the devices for the parameters and the return value if
+        // they are collections.
+        let func = self.get_func();
+        let mut param_devices = vec![None; func.param_types.len()];
+        let mut return_device = None;
+        for idx in 0..func.nodes.len() {
+            match func.nodes[idx] {
+                Node::Parameter { index } => {
+                    let device = self.node_colors.get(&NodeID::new(idx));
+                    assert!(param_devices[index].is_none() || param_devices[index] == device);
+                    param_devices[index] = device;
+                }
+                Node::Return { control: _, data } => {
+                    let device = self.node_colors.get(&data);
+                    assert!(return_device.is_none() || return_device == device);
+                    return_device = device;
+                }
+                _ => {}
+            }
+        }
+
+        // Emit the type definition. A runner object owns its backing memory.
+        write!(
+            w,
+            "#[allow(non_camel_case_types)]\nstruct HerculesRunner_{} {{\n",
+            func.name
+        )?;
+        for (device, _) in self.backing_allocation {
+            write!(
+                w,
+                "    backing_ptr_{}: *mut u8,\n    backing_size_{}: usize,\n",
+                device.name(),
+                device.name()
+            )?;
+        }
+        write!(
+            w,
+            "}}\nimpl HerculesRunner_{} {{\n    fn new() -> Self {{\n        Self {{\n",
+            func.name
+        )?;
+        for (device, _) in self.backing_allocation {
+            write!(
+                w,
+                "            backing_ptr_{}: ::core::ptr::null_mut(),\n            backing_size_{}: 0,\n",
+                device.name(),
+                device.name()
+            )?;
+        }
+        write!(
+            w,
+            "        }}\n    }}\n}}\nimpl Drop for HerculesRunner_{} {{\n    fn drop(&mut self) {{\n        unsafe {{\n",
+            func.name
+        )?;
+        for (device, _) in self.backing_allocation {
+            write!(
+                w,
+                "            ::hercules_rt::__{}_dealloc(self.backing_ptr_{}, self.backing_size_{});\n",
+                device.name(),
+                device.name(),
+                device.name()
+            )?;
+        }
+        write!(w, "        }}\n    }}\n}}\n")?;
+        Ok(())
+    }
+
     fn get_func(&self) -> &Function {
         &self.module.functions[self.func_id.idx()]
     }
diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs
index f85ac18a..9929f6d6 100644
--- a/hercules_opt/src/gcm.rs
+++ b/hercules_opt/src/gcm.rs
@@ -80,7 +80,7 @@ pub fn gcm(
     devices: &Vec<Device>,
     object_device_demands: &FunctionObjectDeviceDemands,
     backing_allocations: &BackingAllocations,
-) -> Option<(BasicBlocks, FunctionBackingAllocation)> {
+) -> Option<(BasicBlocks, FunctionNodeColors, FunctionBackingAllocation)> {
     let bbs = basic_blocks(
         editor.func(),
         editor.func_id(),
@@ -153,7 +153,7 @@ pub fn gcm(
         backing_allocations,
     );
 
-    Some((bbs, backing_allocation))
+    Some((bbs, node_colors, backing_allocation))
 }
 
 /*
@@ -1028,7 +1028,7 @@ fn color_nodes(
     reverse_postorder: &Vec<NodeID>,
     objects: &FunctionCollectionObjects,
     object_device_demands: &FunctionObjectDeviceDemands,
-) -> Option<BTreeMap<NodeID, Device>> {
+) -> Option<FunctionNodeColors> {
     // First, try to give each node a single color.
     let mut colors = BTreeMap::new();
     let mut bad_node = None;
@@ -1138,7 +1138,7 @@ fn type_size(edit: &mut FunctionEdit, ty_id: TypeID, alignments: &Vec<usize>) ->
 fn object_allocation(
     editor: &mut FunctionEditor,
     typing: &Vec<TypeID>,
-    node_colors: &BTreeMap<NodeID, Device>,
+    node_colors: &FunctionNodeColors,
     alignments: &Vec<usize>,
     liveness: &Liveness,
     backing_allocations: &BackingAllocations,
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 5bf89f4d..6205fa75 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -184,6 +184,7 @@ struct PassManager {
     pub devices: Option<Vec<Device>>,
     pub object_device_demands: Option<ObjectDeviceDemands>,
     pub bbs: Option<Vec<BasicBlocks>>,
+    pub node_colors: Option<NodeColors>,
     pub backing_allocations: Option<BackingAllocations>,
 }
 
@@ -218,6 +219,7 @@ impl PassManager {
             devices: None,
             object_device_demands: None,
             bbs: None,
+            node_colors: None,
             backing_allocations: None,
         }
     }
@@ -457,6 +459,7 @@ impl PassManager {
         self.devices = None;
         self.object_device_demands = None;
         self.bbs = None;
+        self.node_colors = None;
         self.backing_allocations = None;
     }
 
@@ -509,6 +512,7 @@ impl PassManager {
             callgraph: Some(callgraph),
             devices: Some(devices),
             bbs: Some(bbs),
+            node_colors: Some(node_colors),
             backing_allocations: Some(backing_allocations),
             ..
         } = self
@@ -554,6 +558,7 @@ impl PassManager {
                     &callgraph,
                     &devices,
                     &bbs[idx],
+                    &node_colors[idx],
                     &backing_allocations[&FunctionID::new(idx)],
                     &mut rust_rt,
                 )
@@ -1263,7 +1268,8 @@ fn run_pass(
 
             // Iterate functions in reverse topological order, since inter-
             // device copies introduced in a callee may affect demands in a
-            // caller.
+            // caller, and the object allocation of a callee affects the object
+            // allocation of its callers.
             pm.make_callgraph();
             let callgraph = pm.callgraph.take().unwrap();
             let topo = callgraph.topo();
@@ -1291,12 +1297,13 @@ fn run_pass(
                 let object_device_demands = pm.object_device_demands.take().unwrap();
 
                 let mut bbs = vec![(vec![], vec![]); topo.len()];
+                let mut node_colors = vec![BTreeMap::new(); topo.len()];
                 let mut backing_allocations = BTreeMap::new();
                 let mut editors = build_editors(pm);
                 let mut any_failed = false;
                 for id in topo.iter() {
                     let editor = &mut editors[id.idx()];
-                    if let Some((bb, backing_allocation)) = gcm(
+                    if let Some((bb, function_node_colors, backing_allocation)) = gcm(
                         editor,
                         &def_uses[id.idx()],
                         &reverse_postorders[id.idx()],
@@ -1311,6 +1318,7 @@ fn run_pass(
                         &backing_allocations,
                     ) {
                         bbs[id.idx()] = bb;
+                        node_colors[id.idx()] = function_node_colors;
                         backing_allocations.insert(*id, backing_allocation);
                     } else {
                         any_failed = true;
@@ -1321,6 +1329,7 @@ fn run_pass(
                 pm.clear_analyses();
                 if !any_failed {
                     pm.bbs = Some(bbs);
+                    pm.node_colors = Some(node_colors);
                     pm.backing_allocations = Some(backing_allocations);
                     break;
                 }
-- 
GitLab
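
To make the skeleton concrete, here is a hedged sketch of the Rust text codegen_runner_object produces at this point for a hypothetical entry function foo whose backing allocation only spans the CPU device; the struct, constructor, and Drop shapes mirror the write! calls above, while foo itself is an assumption:

    #[allow(non_camel_case_types)]
    struct HerculesRunner_foo {
        backing_ptr_cpu: *mut u8,
        backing_size_cpu: usize,
    }
    impl HerculesRunner_foo {
        fn new() -> Self {
            Self {
                backing_ptr_cpu: ::core::ptr::null_mut(),
                backing_size_cpu: 0,
            }
        }
    }
    impl Drop for HerculesRunner_foo {
        fn drop(&mut self) {
            // Dropping the runner frees the single backing allocation it owns.
            unsafe {
                ::hercules_rt::__cpu_dealloc(self.backing_ptr_cpu, self.backing_size_cpu);
            }
        }
    }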


From 781db782903c8127b94d8760284cd76f8c670bfe Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Mon, 27 Jan 2025 19:13:42 -0600
Subject: [PATCH 17/24] Cast backing offset to usize in emitted byte_add

---
 hercules_cg/src/rt.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs
index 8800e04c..316449d9 100644
--- a/hercules_cg/src/rt.rs
+++ b/hercules_cg/src/rt.rs
@@ -348,7 +348,7 @@ impl<'a> RTContext<'a> {
                             .unwrap();
                         write!(block, "backing_{}.byte_add(", device.name())?;
                         self.codegen_dynamic_constant(offset, block)?;
-                        write!(block, ")")?
+                        write!(block, " as usize)")?
                     }
                 }
                 write!(block, ";\n")?
-- 
GitLab


From 6ffce538d98980656798ec6d00cbd176ec5779ab Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Mon, 27 Jan 2025 22:03:47 -0600
Subject: [PATCH 18/24] Generate runner object

---
 hercules_cg/src/rt.rs  | 110 ++++++++++++++++++++++++++++++++++++-----
 hercules_rt/src/lib.rs |  32 ++++++++++++
 2 files changed, 130 insertions(+), 12 deletions(-)

diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs
index 316449d9..8e777e5c 100644
--- a/hercules_cg/src/rt.rs
+++ b/hercules_cg/src/rt.rs
@@ -153,9 +153,6 @@ impl<'a> RTContext<'a> {
             } else {
                 write!(w, ", ")?;
             }
-            if !self.module.types[func.param_types[idx].idx()].is_primitive() {
-                write!(w, "mut ")?;
-            }
             write!(w, "p{}: {}", idx, self.get_type(func.param_types[idx]))?;
         }
         write!(w, ") -> {} {{\n", self.get_type(func.return_type))?;
@@ -634,7 +631,8 @@ impl<'a> RTContext<'a> {
      */
     fn codegen_runner_object<W: Write>(&self, w: &mut W) -> Result<(), Error> {
         // Figure out the devices for the parameters and the return value if
-        // they are collections.
+        // they are collections and whether they should be immutable or mutable
+        // references.
         let func = self.get_func();
         let mut param_devices = vec![None; func.param_types.len()];
         let mut return_device = None;
@@ -653,6 +651,8 @@ impl<'a> RTContext<'a> {
                 _ => {}
             }
         }
+        let param_muts = vec![true; func.param_types.len()];
+        let return_mut = false;
 
         // Emit the type definition. A runner object owns its backing memory.
         write!(
@@ -661,16 +661,13 @@ impl<'a> RTContext<'a> {
             func.name
         )?;
         for (device, _) in self.backing_allocation {
-            write!(
-                w,
-                "    backing_ptr_{}: *mut u8,\n    backing_size_{}: usize,\n",
-                device.name(),
-                device.name()
-            )?;
+            write!(w, "    backing_ptr_{}: *mut u8,\n", device.name(),)?;
+            write!(w, "    backing_size_{}: usize,\n", device.name(),)?;
         }
+        write!(w, "}}\n")?;
         write!(
             w,
-            "}}\nimpl HerculesRunner_{} {{\n    fn new() -> Self {{\n        Self {{\n",
+            "impl HerculesRunner_{} {{\n    fn new() -> Self {{\n        Self {{\n",
             func.name
         )?;
         for (device, _) in self.backing_allocation {
@@ -681,9 +678,98 @@ impl<'a> RTContext<'a> {
                 device.name()
             )?;
         }
+        write!(w, "        }}\n    }}\n")?;
+        write!(w, "    async fn run<'a>(&'a mut self")?;
+        for idx in 0..func.num_dynamic_constants {
+            write!(w, ", dc_p{}: u64", idx)?;
+        }
+        for idx in 0..func.param_types.len() {
+            if self.module.types[func.param_types[idx].idx()].is_primitive() {
+                write!(w, ", p{}: {}", idx, self.get_type(func.param_types[idx]))?;
+            } else {
+                let device = match param_devices[idx] {
+                    Some(Device::LLVM) => "CPU",
+                    Some(Device::CUDA) => "CUDA",
+                    _ => panic!(),
+                };
+                let mutability = if param_muts[idx] { "Mut" } else { "" };
+                write!(
+                    w,
+                    ", p{}: ::hercules_rt::Hercules{}Ref{}<'a>",
+                    idx, device, mutability
+                )?;
+            }
+        }
+        if self.module.types[func.return_type.idx()].is_primitive() {
+            write!(w, ") -> {} {{\n", self.get_type(func.return_type))?;
+        } else {
+            let device = match return_device {
+                Some(Device::LLVM) => "CPU",
+                Some(Device::CUDA) => "CUDA",
+                _ => panic!(),
+            };
+            let mutability = if return_mut { "Mut" } else { "" };
+            write!(
+                w,
+                ") -> ::hercules_rt::Hercules{}Ref{}<'a> {{\n",
+                device, mutability
+            )?;
+        }
+        write!(w, "    unsafe {{\n")?;
+        for (device, (total, _)) in self.backing_allocation {
+            write!(w, "        let size = ")?;
+            self.codegen_dynamic_constant(*total, w)?;
+            write!(
+                w,
+                " as usize;\n        if self.backing_size_{} < size {{\n",
+                device.name()
+            )?;
+            write!(w, "            ::hercules_rt::__{}_dealloc(self.backing_ptr_{}, self.backing_size_{});\n", device.name(), device.name(), device.name())?;
+            write!(
+                w,
+                "            self.backing_size_{} = size;\n",
+                device.name()
+            )?;
+            write!(w, "            self.backing_ptr_{} = ::hercules_rt::__{}_alloc(self.backing_size_{});\n", device.name(), device.name(), device.name())?;
+            write!(w, "        }}\n")?;
+        }
+        for idx in 0..func.param_types.len() {
+            if !self.module.types[func.param_types[idx].idx()].is_primitive() {
+                write!(w, "        let p{} = p{}.__ptr();\n", idx, idx)?;
+            }
+        }
+        write!(w, "        let ret = {}(", func.name)?;
+        for (device, _) in self.backing_allocation {
+            write!(w, "self.backing_ptr_{}, ", device.name())?;
+        }
+        for idx in 0..func.num_dynamic_constants {
+            write!(w, "dc_p{}, ", idx)?;
+        }
+        for idx in 0..func.param_types.len() {
+            write!(w, "p{}, ", idx)?;
+        }
+        write!(w, ").await;\n")?;
+        if self.module.types[func.return_type.idx()].is_primitive() {
+            write!(w, "        ret\n")?;
+        } else {
+            let device = match return_device {
+                Some(Device::LLVM) => "CPU",
+                Some(Device::CUDA) => "CUDA",
+                _ => panic!(),
+            };
+            let mutability = if return_mut { "Mut" } else { "" };
+            write!(
+                w,
+                "        ::hercules_rt::Hercules{}Ref{}::__from_parts(ret, {} as usize)\n",
+                device,
+                mutability,
+                self.codegen_type_size(func.return_type)
+            )?;
+        }
+        write!(w, "    }}\n    }}\n")?;
         write!(
             w,
-            "        }}\n    }}\n}}\nimpl Drop for HerculesRunner_{} {{\n    fn drop(&mut self) {{\n        unsafe {{\n",
+            "}}\nimpl Drop for HerculesRunner_{} {{\n    fn drop(&mut self) {{\n        unsafe {{\n",
             func.name
         )?;
         for (device, _) in self.backing_allocation {
diff --git a/hercules_rt/src/lib.rs b/hercules_rt/src/lib.rs
index c244a611..46c1b4c1 100644
--- a/hercules_rt/src/lib.rs
+++ b/hercules_rt/src/lib.rs
@@ -92,6 +92,14 @@ impl<'a> HerculesCPURef<'a> {
     pub unsafe fn __size(&self) -> usize {
         self.size
     }
+
+    pub unsafe fn __from_parts(ptr: *mut u8, size: usize) -> Self {
+        Self {
+            ptr: NonNull::new(ptr).unwrap(),
+            size,
+            _phantom: PhantomData,
+        }
+    }
 }
 
 impl<'a> HerculesCPURefMut<'a> {
@@ -127,6 +135,14 @@ impl<'a> HerculesCPURefMut<'a> {
     pub unsafe fn __size(&self) -> usize {
         self.size
     }
+
+    pub unsafe fn __from_parts(ptr: *mut u8, size: usize) -> Self {
+        Self {
+            ptr: NonNull::new(ptr).unwrap(),
+            size,
+            _phantom: PhantomData,
+        }
+    }
 }
 
 #[cfg(feature = "cuda")]
@@ -138,6 +154,14 @@ impl<'a> HerculesCUDARef<'a> {
     pub unsafe fn __size(&self) -> usize {
         self.size
     }
+
+    pub unsafe fn __from_parts(ptr: *mut u8, size: usize) -> Self {
+        Self {
+            ptr: NonNull::new(ptr).unwrap(),
+            size,
+            _phantom: PhantomData,
+        }
+    }
 }
 
 #[cfg(feature = "cuda")]
@@ -157,6 +181,14 @@ impl<'a> HerculesCUDARefMut<'a> {
     pub unsafe fn __size(&self) -> usize {
         self.size
     }
+
+    pub unsafe fn __from_parts(ptr: *mut u8, size: usize) -> Self {
+        Self {
+            ptr: NonNull::new(ptr).unwrap(),
+            size,
+            _phantom: PhantomData,
+        }
+    }
 }
 
 #[cfg(feature = "cuda")]
-- 
GitLab
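
Putting the pieces above together, here is a hedged sketch of the run method this patch emits for a hypothetical entry foo with one dynamic constant, one CPU collection parameter, and a primitive f32 result; the backing-size expression is illustrative and whitespace is simplified:

    impl HerculesRunner_foo {
        async fn run<'a>(
            &'a mut self,
            dc_p0: u64,
            p0: ::hercules_rt::HerculesCPURefMut<'a>,
        ) -> f32 {
            unsafe {
                // Lazily (re)allocate the CPU backing memory to the size
                // implied by this call's dynamic constants.
                let size = (dc_p0 * 4) as usize;
                if self.backing_size_cpu < size {
                    ::hercules_rt::__cpu_dealloc(self.backing_ptr_cpu, self.backing_size_cpu);
                    self.backing_size_cpu = size;
                    self.backing_ptr_cpu = ::hercules_rt::__cpu_alloc(self.backing_size_cpu);
                }
                // Collection parameters are unwrapped to raw pointers before
                // calling the generated async function.
                let p0 = p0.__ptr();
                let ret = foo(self.backing_ptr_cpu, dc_p0, p0).await;
                ret
            }
        }
    }

A collection result instead goes through __from_parts on the appropriate Hercules reference type, with the size computed by codegen_type_size, binding the returned reference to the runner's lifetime.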


From e7156955a06ec5f940003ff575958412bde2fb2b Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Mon, 27 Jan 2025 22:09:50 -0600
Subject: [PATCH 19/24] Matmul test works

---
 hercules_rt/src/lib.rs              |  7 +++++++
 hercules_samples/matmul/src/main.rs | 11 ++++++-----
 2 files changed, 13 insertions(+), 5 deletions(-)

diff --git a/hercules_rt/src/lib.rs b/hercules_rt/src/lib.rs
index 46c1b4c1..db2dee77 100644
--- a/hercules_rt/src/lib.rs
+++ b/hercules_rt/src/lib.rs
@@ -243,3 +243,10 @@ impl Drop for CUDABox {
         }
     }
 }
+
+#[macro_export]
+macro_rules! runner {
+    ($x: ident) => {
+        <concat_idents!(HerculesRunner_, $x)>::new()
+    };
+}
diff --git a/hercules_samples/matmul/src/main.rs b/hercules_samples/matmul/src/main.rs
index 767fda07..08e66f64 100644
--- a/hercules_samples/matmul/src/main.rs
+++ b/hercules_samples/matmul/src/main.rs
@@ -1,8 +1,8 @@
-#![feature(box_as_ptr, let_chains)]
+#![feature(box_as_ptr, let_chains, concat_idents)]
 
 use rand::random;
 
-use hercules_rt::HerculesBox;
+use hercules_rt::{runner, HerculesCPURefMut};
 
 juno_build::juno!("matmul");
 
@@ -21,9 +21,10 @@ fn main() {
                 }
             }
         }
-        let a = HerculesBox::from_slice_mut(&mut a);
-        let b = HerculesBox::from_slice_mut(&mut b);
-        let mut c = matmul(I as u64, J as u64, K as u64, a, b).await;
+        let a = HerculesCPURefMut::from_slice(&mut a);
+        let b = HerculesCPURefMut::from_slice(&mut b);
+        let mut r = runner!(matmul);
+        let c = r.run(I as u64, J as u64, K as u64, a, b).await;
         assert_eq!(c.as_slice::<i32>(), &*correct_c);
     });
 }
-- 
GitLab
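
Since run borrows the runner exclusively for 'a and a collection result comes back as a Hercules reference with that same lifetime, the result keeps the runner borrowed until it is dropped. A hedged illustration reusing the matmul sample's variables from above:

    let mut r = runner!(matmul);
    let c = r.run(I as u64, J as u64, K as u64, a, b).await;
    // A second r.run(...) would be rejected by the borrow checker while c is
    // still live, because c mutably borrows r.
    assert_eq!(c.as_slice::<i32>(), &*correct_c);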


From d0684833f8ae35838e74768b4a7b74971708d994 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Mon, 27 Jan 2025 22:18:30 -0600
Subject: [PATCH 20/24] Be smarter about requiring immutable/mutable references
 in the runner API

---
 hercules_cg/src/rt.rs               | 19 +++++++++++++++++--
 hercules_samples/matmul/src/main.rs |  6 +++---
 2 files changed, 20 insertions(+), 5 deletions(-)

diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs
index 8e777e5c..190f979d 100644
--- a/hercules_cg/src/rt.rs
+++ b/hercules_cg/src/rt.rs
@@ -651,8 +651,23 @@ impl<'a> RTContext<'a> {
                 _ => {}
             }
         }
-        let param_muts = vec![true; func.param_types.len()];
-        let return_mut = false;
+        let mut param_muts = vec![false; func.param_types.len()];
+        let mut return_mut = true;
+        let objects = &self.collection_objects[&self.func_id];
+        for idx in 0..func.param_types.len() {
+            if let Some(object) = objects.param_to_object(idx)
+                && objects.is_mutated(object)
+            {
+                param_muts[idx] = true;
+            }
+        }
+        for object in objects.returned_objects() {
+            if let Some(idx) = objects.origin(*object).try_parameter()
+                && !param_muts[idx]
+            {
+                return_mut = false;
+            }
+        }
 
         // Emit the type definition. A runner object owns its backing memory.
         write!(
diff --git a/hercules_samples/matmul/src/main.rs b/hercules_samples/matmul/src/main.rs
index 08e66f64..e19ef8ec 100644
--- a/hercules_samples/matmul/src/main.rs
+++ b/hercules_samples/matmul/src/main.rs
@@ -2,7 +2,7 @@
 
 use rand::random;
 
-use hercules_rt::{runner, HerculesCPURefMut};
+use hercules_rt::{runner, HerculesCPURef};
 
 juno_build::juno!("matmul");
 
@@ -21,8 +21,8 @@ fn main() {
                 }
             }
         }
-        let a = HerculesCPURefMut::from_slice(&mut a);
-        let b = HerculesCPURefMut::from_slice(&mut b);
+        let a = HerculesCPURef::from_slice(&mut a);
+        let b = HerculesCPURef::from_slice(&mut b);
         let mut r = runner!(matmul);
         let c = r.run(I as u64, J as u64, K as u64, a, b).await;
         assert_eq!(c.as_slice::<i32>(), &*correct_c);
-- 
GitLab
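
For the matmul sample this inference plays out as follows (a hedged reading, consistent with the main.rs change above): neither input array is mutated, so both collection parameters are taken as plain HerculesCPURef, while the returned object originates from a constant rather than an immutable parameter, so it comes back through the Mut variant. The generated signature is then roughly:

    async fn run<'a>(
        &'a mut self,
        dc_p0: u64, dc_p1: u64, dc_p2: u64,
        p0: ::hercules_rt::HerculesCPURef<'a>,
        p1: ::hercules_rt::HerculesCPURef<'a>,
    ) -> ::hercules_rt::HerculesCPURefMut<'a>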


From c8a69ae3aee1584a53ff541ceba724debbed392d Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Tue, 28 Jan 2025 10:03:35 -0600
Subject: [PATCH 21/24] hercules_samples tests ported

---
 hercules_cg/src/rt.rs             |  2 +-
 hercules_samples/call/src/main.rs | 10 +++++++---
 hercules_samples/ccp/src/main.rs  |  7 +++++--
 hercules_samples/dot/src/main.rs  | 11 ++++++-----
 hercules_samples/fac/src/main.rs  |  7 ++++++-
 5 files changed, 25 insertions(+), 12 deletions(-)

diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs
index 190f979d..df943661 100644
--- a/hercules_cg/src/rt.rs
+++ b/hercules_cg/src/rt.rs
@@ -784,7 +784,7 @@ impl<'a> RTContext<'a> {
         write!(w, "    }}\n    }}\n")?;
         write!(
             w,
-            "}}\nimpl Drop for HerculesRunner_{} {{\n    fn drop(&mut self) {{\n        unsafe {{\n",
+            "}}\nimpl Drop for HerculesRunner_{} {{\n    #[allow(unused_unsafe)]\n    fn drop(&mut self) {{\n        unsafe {{\n",
             func.name
         )?;
         for (device, _) in self.backing_allocation {
diff --git a/hercules_samples/call/src/main.rs b/hercules_samples/call/src/main.rs
index 0b657dd8..75c58761 100644
--- a/hercules_samples/call/src/main.rs
+++ b/hercules_samples/call/src/main.rs
@@ -1,11 +1,15 @@
-#![feature(box_as_ptr, let_chains)]
+#![feature(box_as_ptr, let_chains, concat_idents)]
+
+use hercules_rt::runner;
 
 juno_build::juno!("call");
 
 fn main() {
     async_std::task::block_on(async {
-        let x = myfunc(7).await;
-        let y = add(10, 2, 18).await;
+        let mut r = runner!(myfunc);
+        let x = r.run(7).await;
+        let mut r = runner!(add);
+        let y = r.run(10, 2, 18).await;
         assert_eq!(x, y);
     });
 }
diff --git a/hercules_samples/ccp/src/main.rs b/hercules_samples/ccp/src/main.rs
index 7f6459a0..e51d1eb3 100644
--- a/hercules_samples/ccp/src/main.rs
+++ b/hercules_samples/ccp/src/main.rs
@@ -1,10 +1,13 @@
-#![feature(box_as_ptr, let_chains)]
+#![feature(box_as_ptr, let_chains, concat_idents)]
+
+use hercules_rt::runner;
 
 juno_build::juno!("ccp");
 
 fn main() {
     async_std::task::block_on(async {
-        let x = tricky(7).await;
+        let mut r = runner!(tricky);
+        let x = r.run(7).await;
         assert_eq!(x, 1);
     });
 }
diff --git a/hercules_samples/dot/src/main.rs b/hercules_samples/dot/src/main.rs
index 0b5c6a93..4e45d4de 100644
--- a/hercules_samples/dot/src/main.rs
+++ b/hercules_samples/dot/src/main.rs
@@ -1,6 +1,6 @@
-#![feature(box_as_ptr, let_chains)]
+#![feature(box_as_ptr, let_chains, concat_idents)]
 
-use hercules_rt::HerculesBox;
+use hercules_rt::{runner, HerculesCPURef};
 
 juno_build::juno!("dot");
 
@@ -8,9 +8,10 @@ fn main() {
     async_std::task::block_on(async {
         let a: [f32; 8] = [0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0];
         let b: [f32; 8] = [0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0];
-        let a = HerculesBox::from_slice(&a);
-        let b = HerculesBox::from_slice(&b);
-        let c = dot(8, a, b).await;
+        let a = HerculesCPURef::from_slice(&a);
+        let b = HerculesCPURef::from_slice(&b);
+        let mut r = runner!(dot);
+        let c = r.run(8, a, b).await;
         println!("{}", c);
         assert_eq!(c, 70.0);
     });
diff --git a/hercules_samples/fac/src/main.rs b/hercules_samples/fac/src/main.rs
index b6e0257b..40180d44 100644
--- a/hercules_samples/fac/src/main.rs
+++ b/hercules_samples/fac/src/main.rs
@@ -1,8 +1,13 @@
+#![feature(concat_idents)]
+
+use hercules_rt::runner;
+
 juno_build::juno!("fac");
 
 fn main() {
     async_std::task::block_on(async {
-        let f = fac(8).await;
+        let mut r = runner!(fac);
+        let f = r.run(8).await;
         println!("{}", f);
         assert_eq!(f, 40320);
     });
-- 
GitLab


From 6b902d98bacb0ef78dcd3388ac55a0bc5d478572 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Tue, 28 Jan 2025 10:12:24 -0600
Subject: [PATCH 22/24] remove unused features

---
 hercules_samples/call/src/main.rs   | 2 +-
 hercules_samples/ccp/src/main.rs    | 2 +-
 hercules_samples/dot/src/main.rs    | 2 +-
 hercules_samples/matmul/src/main.rs | 2 +-
 4 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/hercules_samples/call/src/main.rs b/hercules_samples/call/src/main.rs
index 75c58761..ff4b6f4a 100644
--- a/hercules_samples/call/src/main.rs
+++ b/hercules_samples/call/src/main.rs
@@ -1,4 +1,4 @@
-#![feature(box_as_ptr, let_chains, concat_idents)]
+#![feature(concat_idents)]
 
 use hercules_rt::runner;
 
diff --git a/hercules_samples/ccp/src/main.rs b/hercules_samples/ccp/src/main.rs
index e51d1eb3..ecf37973 100644
--- a/hercules_samples/ccp/src/main.rs
+++ b/hercules_samples/ccp/src/main.rs
@@ -1,4 +1,4 @@
-#![feature(box_as_ptr, let_chains, concat_idents)]
+#![feature(concat_idents)]
 
 use hercules_rt::runner;
 
diff --git a/hercules_samples/dot/src/main.rs b/hercules_samples/dot/src/main.rs
index 4e45d4de..335e8909 100644
--- a/hercules_samples/dot/src/main.rs
+++ b/hercules_samples/dot/src/main.rs
@@ -1,4 +1,4 @@
-#![feature(box_as_ptr, let_chains, concat_idents)]
+#![feature(concat_idents)]
 
 use hercules_rt::{runner, HerculesCPURef};
 
diff --git a/hercules_samples/matmul/src/main.rs b/hercules_samples/matmul/src/main.rs
index e19ef8ec..8757a0fd 100644
--- a/hercules_samples/matmul/src/main.rs
+++ b/hercules_samples/matmul/src/main.rs
@@ -1,4 +1,4 @@
-#![feature(box_as_ptr, let_chains, concat_idents)]
+#![feature(concat_idents)]
 
 use rand::random;
 
-- 
GitLab


From 2adf74a07c317d8be4bbb0702498773d96d35180 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Tue, 28 Jan 2025 10:29:09 -0600
Subject: [PATCH 23/24] Zero-initialize collection constants, port antideps test

---
 hercules_cg/src/rt.rs             | 18 +++++++++++++++---
 juno_samples/antideps/src/main.rs | 25 +++++++++++++++++--------
 2 files changed, 32 insertions(+), 11 deletions(-)

diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs
index df943661..8efaa26d 100644
--- a/hercules_cg/src/rt.rs
+++ b/hercules_cg/src/rt.rs
@@ -322,6 +322,7 @@ impl<'a> RTContext<'a> {
             Node::Constant { id: cons_id } => {
                 let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap();
                 write!(block, "                {} = ", self.get_value(id))?;
+                let mut size = None;
                 match self.module.constants[cons_id.idx()] {
                     Constant::Boolean(val) => write!(block, "{}bool", val)?,
                     Constant::Integer8(val) => write!(block, "{}i8", val)?,
@@ -334,7 +335,9 @@ impl<'a> RTContext<'a> {
                     Constant::UnsignedInteger64(val) => write!(block, "{}u64", val)?,
                     Constant::Float32(val) => write!(block, "{}f32", val)?,
                     Constant::Float64(val) => write!(block, "{}f64", val)?,
-                    Constant::Product(_, _) | Constant::Summation(_, _, _) | Constant::Array(_) => {
+                    Constant::Product(ty, _)
+                    | Constant::Summation(ty, _, _)
+                    | Constant::Array(ty) => {
                         let (device, offset) = self
                             .backing_allocation
                             .into_iter()
@@ -345,10 +348,19 @@ impl<'a> RTContext<'a> {
                             .unwrap();
                         write!(block, "backing_{}.byte_add(", device.name())?;
                         self.codegen_dynamic_constant(offset, block)?;
-                        write!(block, " as usize)")?
+                        write!(block, " as usize)")?;
+                        size = Some(self.codegen_type_size(ty));
                     }
                 }
-                write!(block, ";\n")?
+                write!(block, ";\n")?;
+                if let Some(size) = size {
+                    write!(
+                        block,
+                        "                ::core::ptr::write_bytes({}, 0, {});\n",
+                        self.get_value(id),
+                        size
+                    )?;
+                }
             }
             Node::Call {
                 control: _,
diff --git a/juno_samples/antideps/src/main.rs b/juno_samples/antideps/src/main.rs
index 2f1e8efc..9c37bd01 100644
--- a/juno_samples/antideps/src/main.rs
+++ b/juno_samples/antideps/src/main.rs
@@ -1,34 +1,43 @@
-#![feature(future_join, box_as_ptr)]
+#![feature(concat_idents)]
+
+use hercules_rt::runner;
 
 juno_build::juno!("antideps");
 
 fn main() {
     async_std::task::block_on(async {
-        let output = simple_antideps(1, 1).await;
+        let mut r = runner!(simple_antideps);
+        let output = r.run(1, 1).await;
         println!("{}", output);
         assert_eq!(output, 5);
 
-        let output = loop_antideps(11).await;
+        let mut r = runner!(loop_antideps);
+        let output = r.run(11).await;
         println!("{}", output);
         assert_eq!(output, 5);
 
-        let output = complex_antideps1(9).await;
+        let mut r = runner!(complex_antideps1);
+        let output = r.run(9).await;
         println!("{}", output);
         assert_eq!(output, 20);
 
-        let output = complex_antideps2(44).await;
+        let mut r = runner!(complex_antideps2);
+        let output = r.run(44).await;
         println!("{}", output);
         assert_eq!(output, 226);
 
-        let output = very_complex_antideps(3).await;
+        let mut r = runner!(very_complex_antideps);
+        let output = r.run(3).await;
         println!("{}", output);
         assert_eq!(output, 144);
 
-        let output = read_chains(2).await;
+        let mut r = runner!(read_chains);
+        let output = r.run(2).await;
         println!("{}", output);
         assert_eq!(output, 14);
 
-        let output = array_of_structs(2).await;
+        let mut r = runner!(array_of_structs);
+        let output = r.run(2).await;
         println!("{}", output);
         assert_eq!(output, 14);
     });
-- 
GitLab
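
With this change a collection constant lowers to its backing offset followed by a zeroing write. A hedged sketch of the emitted text for a hypothetical CPU-backed array constant, with illustrative names and the byte count from codegen_type_size shown as a literal:

    // Point the node at its slot in the CPU backing memory...
    node_4 = backing_cpu.byte_add(dc2 as usize);
    // ...then zero-fill the object's bytes.
    ::core::ptr::write_bytes(node_4, 0, 32);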


From b39d6339b5bd090b1eeb46fcd28c222aef617bc1 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Tue, 28 Jan 2025 13:42:03 -0600
Subject: [PATCH 24/24] fix juno_samples tests

---
 hercules_cg/src/rt.rs                         | 13 ++++++---
 juno_samples/casts_and_intrinsics/src/main.rs |  7 +++--
 juno_samples/cava/src/main.rs                 | 21 +++++++-------
 juno_samples/concat/src/main.rs               |  7 +++--
 juno_samples/implicit_clone/src/main.rs       | 28 +++++++++++++------
 juno_samples/matmul/src/main.rs               | 20 ++++++-------
 juno_samples/nested_ccp/src/main.rs           | 19 +++++++------
 juno_samples/schedule_test/src/main.rs        | 15 +++++-----
 juno_samples/simple3/src/main.rs              | 15 +++++-----
 juno_scheduler/src/pm.rs                      |  3 ++
 10 files changed, 86 insertions(+), 62 deletions(-)

diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs
index 8efaa26d..445647ef 100644
--- a/hercules_cg/src/rt.rs
+++ b/hercules_cg/src/rt.rs
@@ -118,13 +118,15 @@ impl<'a> RTContext<'a> {
     fn codegen_function<W: Write>(&self, w: &mut W) -> Result<(), Error> {
         // If this is an entry function, generate a corresponding runner object
         // type definition.
-        self.codegen_runner_object(w)?;
+        let func = &self.get_func();
+        if func.entry {
+            self.codegen_runner_object(w)?;
+        }
 
         // Dump the function signature.
-        let func = &self.get_func();
         write!(
             w,
-            "#[allow(unused_variables,unused_mut,unused_parens,unused_unsafe)]\nasync unsafe fn {}(",
+            "#[allow(unused_variables,unused_mut,unused_parens,unused_unsafe,non_snake_case)]\nasync unsafe fn {}(",
             func.name
         )?;
         let mut first_param = true;
@@ -356,7 +358,7 @@ impl<'a> RTContext<'a> {
                 if let Some(size) = size {
                     write!(
                         block,
-                        "                ::core::ptr::write_bytes({}, 0, {});\n",
+                        "                ::core::ptr::write_bytes({}, 0, {} as usize);\n",
                         self.get_value(id),
                         size
                     )?;
@@ -717,6 +719,9 @@ impl<'a> RTContext<'a> {
                 let device = match param_devices[idx] {
                     Some(Device::LLVM) => "CPU",
                     Some(Device::CUDA) => "CUDA",
+                    // For parameters that are unused, it doesn't really matter
+                    // what device is required, so just pick CPU for now.
+                    None => "CPU",
                     _ => panic!(),
                 };
                 let mutability = if param_muts[idx] { "Mut" } else { "" };
diff --git a/juno_samples/casts_and_intrinsics/src/main.rs b/juno_samples/casts_and_intrinsics/src/main.rs
index 8ee509bf..6b27c60c 100644
--- a/juno_samples/casts_and_intrinsics/src/main.rs
+++ b/juno_samples/casts_and_intrinsics/src/main.rs
@@ -1,10 +1,13 @@
-#![feature(future_join)]
+#![feature(concat_idents)]
+
+use hercules_rt::runner;
 
 juno_build::juno!("casts_and_intrinsics");
 
 fn main() {
     async_std::task::block_on(async {
-        let output = casts_and_intrinsics(16.0).await;
+        let mut r = runner!(casts_and_intrinsics);
+        let output = r.run(16.0).await;
         println!("{}", output);
         assert_eq!(output, 4);
     });
diff --git a/juno_samples/cava/src/main.rs b/juno_samples/cava/src/main.rs
index 9c2f99a8..73a75a94 100644
--- a/juno_samples/cava/src/main.rs
+++ b/juno_samples/cava/src/main.rs
@@ -1,4 +1,4 @@
-#![feature(future_join, box_as_ptr, let_chains)]
+#![feature(concat_idents)]
 
 mod camera_model;
 mod cava_rust;
@@ -8,7 +8,7 @@ use self::camera_model::*;
 use self::cava_rust::CHAN;
 use self::image_proc::*;
 
-use hercules_rt::HerculesBox;
+use hercules_rt::{runner, HerculesCPURef};
 
 use image::ImageError;
 
@@ -28,25 +28,26 @@ fn run_cava(
     tonemap: &[f32],
 ) -> Box<[u8]> {
     assert_eq!(image.len(), CHAN * rows * cols);
-    let image = HerculesBox::from_slice(image);
+    let image = HerculesCPURef::from_slice(image);
 
     assert_eq!(tstw.len(), CHAN * CHAN);
-    let tstw = HerculesBox::from_slice(tstw);
+    let tstw = HerculesCPURef::from_slice(tstw);
 
     assert_eq!(ctrl_pts.len(), num_ctrl_pts * CHAN);
-    let ctrl_pts = HerculesBox::from_slice(ctrl_pts);
+    let ctrl_pts = HerculesCPURef::from_slice(ctrl_pts);
 
     assert_eq!(weights.len(), num_ctrl_pts * CHAN);
-    let weights = HerculesBox::from_slice(weights);
+    let weights = HerculesCPURef::from_slice(weights);
 
     assert_eq!(coefs.len(), 4 * CHAN);
-    let coefs = HerculesBox::from_slice(coefs);
+    let coefs = HerculesCPURef::from_slice(coefs);
 
     assert_eq!(tonemap.len(), 256 * CHAN);
-    let tonemap = HerculesBox::from_slice(tonemap);
+    let tonemap = HerculesCPURef::from_slice(tonemap);
 
+    let mut r = runner!(cava);
     async_std::task::block_on(async {
-        cava(
+        r.run(
             rows as u64,
             cols as u64,
             num_ctrl_pts as u64,
@@ -58,7 +59,7 @@ fn run_cava(
             tonemap,
         )
         .await
-    }).as_slice::<u8>().into()
+    }).as_slice::<u8>().to_vec().into_boxed_slice()
 }
 
 enum Error {
diff --git a/juno_samples/concat/src/main.rs b/juno_samples/concat/src/main.rs
index 17a0ab96..db3f37fd 100644
--- a/juno_samples/concat/src/main.rs
+++ b/juno_samples/concat/src/main.rs
@@ -1,10 +1,13 @@
-#![feature(future_join, box_as_ptr)]
+#![feature(concat_idents)]
+
+use hercules_rt::runner;
 
 juno_build::juno!("concat");
 
 fn main() {
     async_std::task::block_on(async {
-        let output = concat_entry(7).await;
+        let mut r = runner!(concat_entry);
+        let output = r.run(7).await;
         println!("{}", output);
         assert_eq!(output, 42);
     });
diff --git a/juno_samples/implicit_clone/src/main.rs b/juno_samples/implicit_clone/src/main.rs
index bc687ed3..1e94ff89 100644
--- a/juno_samples/implicit_clone/src/main.rs
+++ b/juno_samples/implicit_clone/src/main.rs
@@ -1,38 +1,48 @@
-#![feature(future_join, box_as_ptr)]
+#![feature(concat_idents)]
+
+use hercules_rt::runner;
 
 juno_build::juno!("implicit_clone");
 
 fn main() {
     async_std::task::block_on(async {
-        let output = simple_implicit_clone(3).await;
+        let mut r = runner!(simple_implicit_clone);
+        let output = r.run(3).await;
         println!("{}", output);
         assert_eq!(output, 11);
 
-        let output = loop_implicit_clone(100).await;
+        let mut r = runner!(loop_implicit_clone);
+        let output = r.run(100).await;
         println!("{}", output);
         assert_eq!(output, 7);
 
-        let output = double_loop_implicit_clone(3).await;
+        let mut r = runner!(double_loop_implicit_clone);
+        let output = r.run(3).await;
         println!("{}", output);
         assert_eq!(output, 42);
 
-        let output = tricky_loop_implicit_clone(2, 2).await;
+        let mut r = runner!(tricky_loop_implicit_clone);
+        let output = r.run(2, 2).await;
         println!("{}", output);
         assert_eq!(output, 130);
 
-        let output = tricky2_loop_implicit_clone(2, 3).await;
+        let mut r = runner!(tricky2_loop_implicit_clone);
+        let output = r.run(2, 3).await;
         println!("{}", output);
         assert_eq!(output, 39);
 
-        let output = tricky3_loop_implicit_clone(5, 7).await;
+        let mut r = runner!(tricky3_loop_implicit_clone);
+        let output = r.run(5, 7).await;
         println!("{}", output);
         assert_eq!(output, 7);
 
-        let output = no_implicit_clone(4).await;
+        let mut r = runner!(no_implicit_clone);
+        let output = r.run(4).await;
         println!("{}", output);
         assert_eq!(output, 13);
 
-        let output = mirage_implicit_clone(73).await;
+        let mut r = runner!(mirage_implicit_clone);
+        let output = r.run(73).await;
         println!("{}", output);
         assert_eq!(output, 843);
     });
diff --git a/juno_samples/matmul/src/main.rs b/juno_samples/matmul/src/main.rs
index bace3765..fa5d1f04 100644
--- a/juno_samples/matmul/src/main.rs
+++ b/juno_samples/matmul/src/main.rs
@@ -1,8 +1,8 @@
-#![feature(box_as_ptr, let_chains)]
+#![feature(concat_idents)]
 
 use rand::random;
 
-use hercules_rt::HerculesBox;
+use hercules_rt::{runner, HerculesCPURef};
 
 juno_build::juno!("matmul");
 
@@ -21,17 +21,13 @@ fn main() {
                 }
             }
         }
-        let mut c = {
-            let a = HerculesBox::from_slice(&a);
-            let b = HerculesBox::from_slice(&b);
-            matmul(I as u64, J as u64, K as u64, a, b).await
-        };
+        let a = HerculesCPURef::from_slice(&a);
+        let b = HerculesCPURef::from_slice(&b);
+        let mut r = runner!(matmul);
+        let c = r.run(I as u64, J as u64, K as u64, a.clone(), b.clone()).await;
         assert_eq!(c.as_slice::<i32>(), &*correct_c);
-        let mut tiled_c = {
-            let a = HerculesBox::from_slice(&a);
-            let b = HerculesBox::from_slice(&b);
-            tiled_64_matmul(I as u64, J as u64, K as u64, a, b).await
-        };
+        let mut r = runner!(tiled_64_matmul);
+        let tiled_c = r.run(I as u64, J as u64, K as u64, a.clone(), b.clone()).await;
         assert_eq!(tiled_c.as_slice::<i32>(), &*correct_c);
     });
 }
diff --git a/juno_samples/nested_ccp/src/main.rs b/juno_samples/nested_ccp/src/main.rs
index f49171ce..423b66fb 100644
--- a/juno_samples/nested_ccp/src/main.rs
+++ b/juno_samples/nested_ccp/src/main.rs
@@ -1,18 +1,21 @@
-#![feature(box_as_ptr, let_chains)]
+#![feature(concat_idents)]
 
-use hercules_rt::HerculesBox;
+use hercules_rt::{runner, HerculesCPURef, HerculesCPURefMut};
 
 juno_build::juno!("nested_ccp");
 
 fn main() {
     async_std::task::block_on(async {
-        let mut a: Box<[f32]> = Box::new([17.0, 18.0, 19.0]);
+        let a: Box<[f32]> = Box::new([17.0, 18.0, 19.0]);
         let mut b: Box<[i32]> = Box::new([12, 16, 4, 18, 23, 56, 93, 22, 14]);
-        let a = HerculesBox::from_slice_mut(&mut a);
-        let b = HerculesBox::from_slice_mut(&mut b);
-        let output_example = ccp_example(a).await;
-        let output_median = median_array(9, b).await;
-        let out_no_underflow = no_underflow().await;
+        let a = HerculesCPURef::from_slice(&a);
+        let b = HerculesCPURefMut::from_slice(&mut b);
+        let mut r = runner!(ccp_example);
+        let output_example = r.run(a).await;
+        let mut r = runner!(median_array);
+        let output_median = r.run(9, b).await;
+        let mut r = runner!(no_underflow);
+        let out_no_underflow = r.run().await;
         println!("{}", output_example);
         println!("{}", output_median);
         println!("{}", out_no_underflow);
diff --git a/juno_samples/schedule_test/src/main.rs b/juno_samples/schedule_test/src/main.rs
index a64cd16f..2e63babf 100644
--- a/juno_samples/schedule_test/src/main.rs
+++ b/juno_samples/schedule_test/src/main.rs
@@ -1,8 +1,8 @@
-#![feature(box_as_ptr, let_chains)]
+#![feature(concat_idents)]
 
 use rand::random;
 
-use hercules_rt::HerculesBox;
+use hercules_rt::{runner, HerculesCPURef};
 
 juno_build::juno!("code");
 
@@ -26,12 +26,11 @@ fn main() {
             }
         }
 
-        let mut res = {
-            let a = HerculesBox::from_slice(&a);
-            let b = HerculesBox::from_slice(&b);
-            let c = HerculesBox::from_slice(&c);
-            test(N as u64, M as u64, K as u64, a, b, c).await
-        };
+        let a = HerculesCPURef::from_slice(&a);
+        let b = HerculesCPURef::from_slice(&b);
+        let c = HerculesCPURef::from_slice(&c);
+        let mut r = runner!(test);
+        let res = r.run(N as u64, M as u64, K as u64, a, b, c).await;
         assert_eq!(res.as_slice::<i32>(), &*correct_res);
     });
 }
diff --git a/juno_samples/simple3/src/main.rs b/juno_samples/simple3/src/main.rs
index 1f6e213c..4f9fe6a7 100644
--- a/juno_samples/simple3/src/main.rs
+++ b/juno_samples/simple3/src/main.rs
@@ -1,16 +1,17 @@
-#![feature(box_as_ptr, let_chains)]
+#![feature(concat_idents)]
 
-use hercules_rt::HerculesBox;
+use hercules_rt::{runner, HerculesCPURef};
 
 juno_build::juno!("simple3");
 
 fn main() {
     async_std::task::block_on(async {
-        let mut a: Box<[u32]> = Box::new([1, 2, 3, 4, 5, 6, 7, 8]);
-        let mut b: Box<[u32]> = Box::new([8, 7, 6, 5, 4, 3, 2, 1]);
-        let a = HerculesBox::from_slice_mut(&mut a);
-        let b = HerculesBox::from_slice_mut(&mut b);
-        let c = simple3(8, a, b).await;
+        let a: Box<[u32]> = Box::new([1, 2, 3, 4, 5, 6, 7, 8]);
+        let b: Box<[u32]> = Box::new([8, 7, 6, 5, 4, 3, 2, 1]);
+        let a = HerculesCPURef::from_slice(&a);
+        let b = HerculesCPURef::from_slice(&b);
+        let mut r = runner!(simple3);
+        let c = r.run(8, a, b).await;
         println!("{}", c);
         assert_eq!(c, 120);
     });
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 6205fa75..aa540064 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -1324,6 +1324,9 @@ fn run_pass(
                         any_failed = true;
                     }
                     changed |= editor.modified();
+                    if any_failed {
+                        break;
+                    }
                 }
                 pm.delete_gravestones();
                 pm.clear_analyses();
-- 
GitLab