From 69d7a09c27fe28d43b99709e9858bc349ae32cb5 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Tue, 18 Feb 2025 16:09:27 -0600
Subject: [PATCH] Assert that Hercules CPU refs are given 32-byte-aligned pointers

---
 hercules_rt/src/lib.rs              |  9 ++++++--
 hercules_samples/matmul/src/main.rs | 36 +++++++++--------------------
 juno_samples/matmul/src/main.rs     | 30 ++++++------------------
 3 files changed, 25 insertions(+), 50 deletions(-)

diff --git a/hercules_rt/src/lib.rs b/hercules_rt/src/lib.rs
index 419a760f..3b79dc48 100644
--- a/hercules_rt/src/lib.rs
+++ b/hercules_rt/src/lib.rs
@@ -1,4 +1,4 @@
-#![feature(once_cell_try)]
+#![feature(once_cell_try, pointer_is_aligned_to)]
 
 use std::alloc::{alloc, dealloc, Layout};
 use std::marker::PhantomData;
@@ -189,6 +189,7 @@ pub struct CUDABox {
 
 impl<'a> HerculesCPURef<'a> {
     pub fn from_slice<T>(slice: &'a [T]) -> Self {
+        assert!(slice.as_ptr().is_aligned_to(32));
         let ptr = unsafe { NonNull::new_unchecked(slice.as_ptr() as *mut u8) };
         let size = slice.len() * size_of::<T>();
         Self {
@@ -201,7 +202,6 @@ impl<'a> HerculesCPURef<'a> {
     pub fn as_slice<T>(self) -> &'a [T] {
         let ptr = self.ptr.as_ptr() as *const T;
         assert_eq!(self.size % size_of::<T>(), 0);
-        assert!(ptr.is_aligned());
         unsafe { from_raw_parts(ptr, self.size / size_of::<T>()) }
     }
 
@@ -214,6 +214,7 @@ impl<'a> HerculesCPURef<'a> {
     }
 
     pub unsafe fn __from_parts(ptr: *mut u8, size: usize) -> Self {
+        assert!(ptr.is_aligned_to(32));
         Self {
             ptr: NonNull::new(ptr).unwrap(),
             size,
@@ -224,6 +225,7 @@ impl<'a> HerculesCPURef<'a> {
 
 impl<'a> HerculesCPURefMut<'a> {
     pub fn from_slice<T>(slice: &'a mut [T]) -> Self {
+        assert!(slice.as_ptr().is_aligned_to(32));
         let ptr = unsafe { NonNull::new_unchecked(slice.as_ptr() as *mut u8) };
         let size = slice.len() * size_of::<T>();
         Self {
@@ -257,6 +259,7 @@ impl<'a> HerculesCPURefMut<'a> {
     }
 
     pub unsafe fn __from_parts(ptr: *mut u8, size: usize) -> Self {
+        assert!(ptr.is_aligned_to(32));
         Self {
             ptr: NonNull::new(ptr).unwrap(),
             size,
@@ -268,6 +271,7 @@ impl<'a> HerculesCPURefMut<'a> {
 #[cfg(feature = "cuda")]
 impl<'a> HerculesCUDARef<'a> {
     pub fn to_cpu_ref<'b, T>(self, dst: &'b mut [T]) -> HerculesCPURefMut<'b> {
+        assert!(dst.as_ptr().is_aligned_to(32));
         unsafe {
             let size = self.size;
             assert_eq!(size, dst.len() * size_of::<T>());
@@ -309,6 +313,7 @@ impl<'a> HerculesCUDARefMut<'a> {
     }
 
     pub fn to_cpu_ref<'b, T>(self, dst: &mut [T]) -> HerculesCPURefMut<'b> {
+        assert!(dst.as_ptr().is_aligned_to(32));
         unsafe {
             let size = self.size;
             let ptr = NonNull::new(dst.as_ptr() as *mut u8).unwrap();
diff --git a/hercules_samples/matmul/src/main.rs b/hercules_samples/matmul/src/main.rs
index 5c879915..cd49b8e9 100644
--- a/hercules_samples/matmul/src/main.rs
+++ b/hercules_samples/matmul/src/main.rs
@@ -1,10 +1,9 @@
 #![feature(concat_idents)]
+use std::iter::zip;
 
 use rand::random;
 
-#[cfg(feature = "cuda")]
-use hercules_rt::CUDABox;
-use hercules_rt::{runner, HerculesCPURef};
+use hercules_rt::{runner, HerculesImmBox, HerculesImmBoxTo, HerculesMutBox};
 
 juno_build::juno!("matmul");
 
@@ -13,9 +12,9 @@ fn main() {
         const I: usize = 256;
         const J: usize = 8; // hardcoded constant in matmul.hir
         const K: usize = 128;
-        let mut a: Box<[i32]> = (0..I * J).map(|_| random::<i32>() % 100).collect();
-        let mut b: Box<[i32]> = (0..J * K).map(|_| random::<i32>() % 100).collect();
-        let mut correct_c: Box<[i32]> = (0..I * K).map(|_| 0).collect();
+        let a: Box<[f32]> = (0..I * J).map(|_| random::<f32>()).collect();
+        let b: Box<[f32]> = (0..J * K).map(|_| random::<f32>()).collect();
+        let mut correct_c: Box<[f32]> = (0..I * K).map(|_| 0.0).collect();
         for i in 0..I {
             for k in 0..K {
                 for j in 0..J {
@@ -23,25 +22,12 @@ fn main() {
                 }
             }
         }
-        #[cfg(not(feature = "cuda"))]
-        {
-            let a = HerculesCPURef::from_slice(&mut a);
-            let b = HerculesCPURef::from_slice(&mut b);
-            let mut r = runner!(matmul);
-            let c = r.run(I as u64, J as u64, K as u64, a, b).await;
-            assert_eq!(c.as_slice::<i32>(), &*correct_c);
-        }
-        #[cfg(feature = "cuda")]
-        {
-            let a = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut a));
-            let b = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut b));
-            let mut r = runner!(matmul);
-            let c = r
-                .run(I as u64, J as u64, K as u64, a.get_ref(), b.get_ref())
-                .await;
-            let mut c_cpu: Box<[i32]> = vec![0; correct_c.len()].into_boxed_slice();
-            c.to_cpu_ref(&mut c_cpu);
-            assert_eq!(&*c_cpu, &*correct_c);
+        let a = HerculesImmBox::from(&a as &[f32]);
+        let b = HerculesImmBox::from(&b as &[f32]);
+        let mut r = runner!(matmul);
+        let mut c = HerculesMutBox::from(r.run(I as u64, J as u64, K as u64, a.to(), b.to()).await);
+        for (calc, correct) in zip(c.as_slice().into_iter().map(|x: &mut f32| *x), correct_c) {
+            assert!((calc - correct).abs() < 0.0001, "{} != {}", calc, correct);
         }
     });
 }
diff --git a/juno_samples/matmul/src/main.rs b/juno_samples/matmul/src/main.rs
index f91b7d8a..cb078c74 100644
--- a/juno_samples/matmul/src/main.rs
+++ b/juno_samples/matmul/src/main.rs
@@ -3,9 +3,7 @@ use std::iter::zip;
 
 use rand::random;
 
-use hercules_rt::{runner, HerculesRefInto};
-#[cfg(feature = "cuda")]
-use hercules_rt::{CUDABox, HerculesCPURef};
+use hercules_rt::{runner, HerculesImmBox, HerculesImmBoxTo, HerculesMutBox};
 
 juno_build::juno!("matmul");
 
@@ -24,26 +22,12 @@ fn main() {
                 }
             }
         }
-        #[cfg(not(feature = "cuda"))]
-        {
-            let mut r = runner!(matmul);
-            let c = r.run(I as u64, J as u64, K as u64, a.to(), b.to()).await;
-            let c = c.as_slice::<f32>();
-            assert_eq!(c, &*correct_c);
-        }
-        #[cfg(feature = "cuda")]
-        {
-            let a = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&a));
-            let b = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&b));
-            let mut r = runner!(matmul);
-            let c = r
-                .run(I as u64, J as u64, K as u64, a.get_ref(), b.get_ref())
-                .await;
-            let mut c_cpu: Box<[f32]> = vec![0.0; correct_c.len()].into_boxed_slice();
-            c.to_cpu_ref(&mut c_cpu);
-            for (calc, correct) in zip(c_cpu, correct_c) {
-                assert!((calc - correct).abs() < 0.00001, "{} != {}", calc, correct);
-            }
+        let a = HerculesImmBox::from(&a as &[f32]);
+        let b = HerculesImmBox::from(&b as &[f32]);
+        let mut r = runner!(matmul);
+        let mut c = HerculesMutBox::from(r.run(I as u64, J as u64, K as u64, a.to(), b.to()).await);
+        for (calc, correct) in zip(c.as_slice().into_iter().map(|x: &mut f32| *x), correct_c) {
+            assert!((calc - correct).abs() < 0.0001, "{} != {}", calc, correct);
         }
     });
 }
-- 
GitLab