From 69d7a09c27fe28d43b99709e9858bc349ae32cb5 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Tue, 18 Feb 2025 16:09:27 -0600 Subject: [PATCH] assert hercules cpu refs are given aligned pointers --- hercules_rt/src/lib.rs | 9 ++++++-- hercules_samples/matmul/src/main.rs | 36 +++++++++-------------------- juno_samples/matmul/src/main.rs | 30 ++++++------------------ 3 files changed, 25 insertions(+), 50 deletions(-) diff --git a/hercules_rt/src/lib.rs b/hercules_rt/src/lib.rs index 419a760f..3b79dc48 100644 --- a/hercules_rt/src/lib.rs +++ b/hercules_rt/src/lib.rs @@ -1,4 +1,4 @@ -#![feature(once_cell_try)] +#![feature(once_cell_try, pointer_is_aligned_to)] use std::alloc::{alloc, dealloc, Layout}; use std::marker::PhantomData; @@ -189,6 +189,7 @@ pub struct CUDABox { impl<'a> HerculesCPURef<'a> { pub fn from_slice<T>(slice: &'a [T]) -> Self { + assert!(slice.as_ptr().is_aligned_to(32)); let ptr = unsafe { NonNull::new_unchecked(slice.as_ptr() as *mut u8) }; let size = slice.len() * size_of::<T>(); Self { @@ -201,7 +202,6 @@ impl<'a> HerculesCPURef<'a> { pub fn as_slice<T>(self) -> &'a [T] { let ptr = self.ptr.as_ptr() as *const T; assert_eq!(self.size % size_of::<T>(), 0); - assert!(ptr.is_aligned()); unsafe { from_raw_parts(ptr, self.size / size_of::<T>()) } } @@ -214,6 +214,7 @@ impl<'a> HerculesCPURef<'a> { } pub unsafe fn __from_parts(ptr: *mut u8, size: usize) -> Self { + assert!(ptr.is_aligned_to(32)); Self { ptr: NonNull::new(ptr).unwrap(), size, @@ -224,6 +225,7 @@ impl<'a> HerculesCPURef<'a> { impl<'a> HerculesCPURefMut<'a> { pub fn from_slice<T>(slice: &'a mut [T]) -> Self { + assert!(slice.as_ptr().is_aligned_to(32)); let ptr = unsafe { NonNull::new_unchecked(slice.as_ptr() as *mut u8) }; let size = slice.len() * size_of::<T>(); Self { @@ -257,6 +259,7 @@ impl<'a> HerculesCPURefMut<'a> { } pub unsafe fn __from_parts(ptr: *mut u8, size: usize) -> Self { + assert!(ptr.is_aligned_to(32)); Self { ptr: NonNull::new(ptr).unwrap(), size, @@ -268,6 +271,7 @@ impl<'a> HerculesCPURefMut<'a> { #[cfg(feature = "cuda")] impl<'a> HerculesCUDARef<'a> { pub fn to_cpu_ref<'b, T>(self, dst: &'b mut [T]) -> HerculesCPURefMut<'b> { + assert!(dst.as_ptr().is_aligned_to(32)); unsafe { let size = self.size; assert_eq!(size, dst.len() * size_of::<T>()); @@ -309,6 +313,7 @@ impl<'a> HerculesCUDARefMut<'a> { } pub fn to_cpu_ref<'b, T>(self, dst: &mut [T]) -> HerculesCPURefMut<'b> { + assert!(dst.as_ptr().is_aligned_to(32)); unsafe { let size = self.size; let ptr = NonNull::new(dst.as_ptr() as *mut u8).unwrap(); diff --git a/hercules_samples/matmul/src/main.rs b/hercules_samples/matmul/src/main.rs index 5c879915..cd49b8e9 100644 --- a/hercules_samples/matmul/src/main.rs +++ b/hercules_samples/matmul/src/main.rs @@ -1,10 +1,9 @@ #![feature(concat_idents)] +use std::iter::zip; use rand::random; -#[cfg(feature = "cuda")] -use hercules_rt::CUDABox; -use hercules_rt::{runner, HerculesCPURef}; +use hercules_rt::{runner, HerculesImmBox, HerculesImmBoxTo, HerculesMutBox}; juno_build::juno!("matmul"); @@ -13,9 +12,9 @@ fn main() { const I: usize = 256; const J: usize = 8; // hardcoded constant in matmul.hir const K: usize = 128; - let mut a: Box<[i32]> = (0..I * J).map(|_| random::<i32>() % 100).collect(); - let mut b: Box<[i32]> = (0..J * K).map(|_| random::<i32>() % 100).collect(); - let mut correct_c: Box<[i32]> = (0..I * K).map(|_| 0).collect(); + let a: Box<[f32]> = (0..I * J).map(|_| random::<f32>()).collect(); + let b: Box<[f32]> = (0..J * K).map(|_| random::<f32>()).collect(); + let mut correct_c: Box<[f32]> = (0..I * K).map(|_| 0.0).collect(); for i in 0..I { for k in 0..K { for j in 0..J { @@ -23,25 +22,12 @@ fn main() { } } } - #[cfg(not(feature = "cuda"))] - { - let a = HerculesCPURef::from_slice(&mut a); - let b = HerculesCPURef::from_slice(&mut b); - let mut r = runner!(matmul); - let c = r.run(I as u64, J as u64, K as u64, a, b).await; - assert_eq!(c.as_slice::<i32>(), &*correct_c); - } - #[cfg(feature = "cuda")] - { - let a = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut a)); - let b = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut b)); - let mut r = runner!(matmul); - let c = r - .run(I as u64, J as u64, K as u64, a.get_ref(), b.get_ref()) - .await; - let mut c_cpu: Box<[i32]> = vec![0; correct_c.len()].into_boxed_slice(); - c.to_cpu_ref(&mut c_cpu); - assert_eq!(&*c_cpu, &*correct_c); + let a = HerculesImmBox::from(&a as &[f32]); + let b = HerculesImmBox::from(&b as &[f32]); + let mut r = runner!(matmul); + let mut c = HerculesMutBox::from(r.run(I as u64, J as u64, K as u64, a.to(), b.to()).await); + for (calc, correct) in zip(c.as_slice().into_iter().map(|x: &mut f32| *x), correct_c) { + assert!((calc - correct).abs() < 0.0001, "{} != {}", calc, correct); } }); } diff --git a/juno_samples/matmul/src/main.rs b/juno_samples/matmul/src/main.rs index f91b7d8a..cb078c74 100644 --- a/juno_samples/matmul/src/main.rs +++ b/juno_samples/matmul/src/main.rs @@ -3,9 +3,7 @@ use std::iter::zip; use rand::random; -use hercules_rt::{runner, HerculesRefInto}; -#[cfg(feature = "cuda")] -use hercules_rt::{CUDABox, HerculesCPURef}; +use hercules_rt::{runner, HerculesImmBox, HerculesImmBoxTo, HerculesMutBox}; juno_build::juno!("matmul"); @@ -24,26 +22,12 @@ fn main() { } } } - #[cfg(not(feature = "cuda"))] - { - let mut r = runner!(matmul); - let c = r.run(I as u64, J as u64, K as u64, a.to(), b.to()).await; - let c = c.as_slice::<f32>(); - assert_eq!(c, &*correct_c); - } - #[cfg(feature = "cuda")] - { - let a = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&a)); - let b = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&b)); - let mut r = runner!(matmul); - let c = r - .run(I as u64, J as u64, K as u64, a.get_ref(), b.get_ref()) - .await; - let mut c_cpu: Box<[f32]> = vec![0.0; correct_c.len()].into_boxed_slice(); - c.to_cpu_ref(&mut c_cpu); - for (calc, correct) in zip(c_cpu, correct_c) { - assert!((calc - correct).abs() < 0.00001, "{} != {}", calc, correct); - } + let a = HerculesImmBox::from(&a as &[f32]); + let b = HerculesImmBox::from(&b as &[f32]); + let mut r = runner!(matmul); + let mut c = HerculesMutBox::from(r.run(I as u64, J as u64, K as u64, a.to(), b.to()).await); + for (calc, correct) in zip(c.as_slice().into_iter().map(|x: &mut f32| *x), correct_c) { + assert!((calc - correct).abs() < 0.0001, "{} != {}", calc, correct); } }); } -- GitLab