From a84d8c89f93ccab54cfdb7677f6518b10664a741 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Tue, 18 Feb 2025 17:42:05 -0600
Subject: [PATCH] Use AlignedAlloc

---
 Cargo.lock                       |  1 -
 hercules_rt/src/lib.rs           | 35 +++++++++++++++++++++++++-------
 hercules_samples/dot/Cargo.toml  |  1 -
 hercules_samples/dot/src/main.rs |  6 ++----
 juno_samples/matmul/src/main.rs  |  4 ++--
 5 files changed, 32 insertions(+), 15 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 0cad8e19..1973fbbe 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -670,7 +670,6 @@ dependencies = [
 name = "dot"
 version = "0.1.0"
 dependencies = [
- "aligned-vec",
  "async-std",
  "clap",
  "hercules_rt",
diff --git a/hercules_rt/src/lib.rs b/hercules_rt/src/lib.rs
index 3b79dc48..a245a264 100644
--- a/hercules_rt/src/lib.rs
+++ b/hercules_rt/src/lib.rs
@@ -1,6 +1,6 @@
 #![feature(once_cell_try, pointer_is_aligned_to)]
 
-use std::alloc::{alloc, dealloc, Layout};
+use std::alloc::{alloc, dealloc, GlobalAlloc, Layout, System};
 use std::marker::PhantomData;
 use std::ptr::{copy_nonoverlapping, write_bytes, NonNull};
 use std::slice::{from_raw_parts, from_raw_parts_mut};
@@ -189,7 +189,7 @@ pub struct CUDABox {
 
 impl<'a> HerculesCPURef<'a> {
     pub fn from_slice<T>(slice: &'a [T]) -> Self {
-        assert!(slice.as_ptr().is_aligned_to(32));
+        assert!(slice.as_ptr().is_aligned_to(LARGEST_ALIGNMENT));
         let ptr = unsafe { NonNull::new_unchecked(slice.as_ptr() as *mut u8) };
         let size = slice.len() * size_of::<T>();
         Self {
@@ -214,7 +214,7 @@ impl<'a> HerculesCPURef<'a> {
     }
 
     pub unsafe fn __from_parts(ptr: *mut u8, size: usize) -> Self {
-        assert!(ptr.is_aligned_to(32));
+        assert!(ptr.is_aligned_to(LARGEST_ALIGNMENT));
         Self {
             ptr: NonNull::new(ptr).unwrap(),
             size,
@@ -225,7 +225,7 @@ impl<'a> HerculesCPURef<'a> {
 
 impl<'a> HerculesCPURefMut<'a> {
     pub fn from_slice<T>(slice: &'a mut [T]) -> Self {
-        assert!(slice.as_ptr().is_aligned_to(32));
+        assert!(slice.as_ptr().is_aligned_to(LARGEST_ALIGNMENT));
         let ptr = unsafe { NonNull::new_unchecked(slice.as_ptr() as *mut u8) };
         let size = slice.len() * size_of::<T>();
         Self {
@@ -259,7 +259,7 @@ impl<'a> HerculesCPURefMut<'a> {
     }
 
     pub unsafe fn __from_parts(ptr: *mut u8, size: usize) -> Self {
-        assert!(ptr.is_aligned_to(32));
+        assert!(ptr.is_aligned_to(LARGEST_ALIGNMENT));
         Self {
             ptr: NonNull::new(ptr).unwrap(),
             size,
@@ -271,7 +271,7 @@ impl<'a> HerculesCPURefMut<'a> {
 #[cfg(feature = "cuda")]
 impl<'a> HerculesCUDARef<'a> {
     pub fn to_cpu_ref<'b, T>(self, dst: &'b mut [T]) -> HerculesCPURefMut<'b> {
-        assert!(dst.as_ptr().is_aligned_to(32));
+        assert!(dst.as_ptr().is_aligned_to(LARGEST_ALIGNMENT));
         unsafe {
             let size = self.size;
             assert_eq!(size, dst.len() * size_of::<T>());
@@ -313,7 +313,7 @@ impl<'a> HerculesCUDARefMut<'a> {
     }
 
     pub fn to_cpu_ref<'b, T>(self, dst: &mut [T]) -> HerculesCPURefMut<'b> {
-        assert!(dst.as_ptr().is_aligned_to(32));
+        assert!(dst.as_ptr().is_aligned_to(LARGEST_ALIGNMENT));
         unsafe {
             let size = self.size;
             let ptr = NonNull::new(dst.as_ptr() as *mut u8).unwrap();
@@ -872,3 +872,24 @@ impl<'a, T> HerculesRefInto<'a> for Box<[T]> {
         HerculesCPURef::from_slice(self)
     }
 }
+
+/*
+ * Global allocator: every allocation is aligned to at least LARGEST_ALIGNMENT
+ * bytes for vectorization. Oversized layouts yield null rather than panicking.
+ */
+pub struct AlignedAlloc;
+
+unsafe impl GlobalAlloc for AlignedAlloc {
+    unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
+        let aligned = layout.align_to(LARGEST_ALIGNMENT);
+        aligned.map_or(std::ptr::null_mut(), |aligned| System.alloc(aligned))
+    }
+
+    unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
+        let layout = layout.align_to(LARGEST_ALIGNMENT).unwrap();
+        System.dealloc(ptr, layout)
+    }
+}
+
+#[global_allocator]
+static A: AlignedAlloc = AlignedAlloc;
diff --git a/hercules_samples/dot/Cargo.toml b/hercules_samples/dot/Cargo.toml
index ab35cbaf..9b11ddc1 100644
--- a/hercules_samples/dot/Cargo.toml
+++ b/hercules_samples/dot/Cargo.toml
@@ -17,4 +17,3 @@ hercules_rt = { path = "../../hercules_rt" }
 rand = "*"
 async-std = "*"
 with_builtin_macros = "0.1.0"
-aligned-vec = "*"
diff --git a/hercules_samples/dot/src/main.rs b/hercules_samples/dot/src/main.rs
index 7bcaaeba..1f28cee2 100644
--- a/hercules_samples/dot/src/main.rs
+++ b/hercules_samples/dot/src/main.rs
@@ -4,14 +4,12 @@ use hercules_rt::CUDABox;
 
 use hercules_rt::{runner, HerculesImmBox, HerculesImmBoxTo};
 
-use aligned_vec::ABox;
-
 juno_build::juno!("dot");
 
 fn main() {
     async_std::task::block_on(async {
-        let a: ABox<[f32; 8]> = ABox::new(32, [0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0]);
-        let b: ABox<[f32; 8]> = ABox::new(32, [0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0]);
+        let a: Box<[f32; 8]> = Box::new([0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0]);
+        let b: Box<[f32; 8]> = Box::new([0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0]);
         let a = HerculesImmBox::from(a.as_ref() as &[f32]);
         let b = HerculesImmBox::from(b.as_ref() as &[f32]);
         let mut r = runner!(dot);
diff --git a/juno_samples/matmul/src/main.rs b/juno_samples/matmul/src/main.rs
index cb078c74..29415b51 100644
--- a/juno_samples/matmul/src/main.rs
+++ b/juno_samples/matmul/src/main.rs
@@ -22,8 +22,8 @@ fn main() {
                 }
             }
         }
-        let a = HerculesImmBox::from(&a as &[f32]);
-        let b = HerculesImmBox::from(&b as &[f32]);
+        let a = HerculesImmBox::from(a.as_ref());
+        let b = HerculesImmBox::from(b.as_ref());
         let mut r = runner!(matmul);
         let mut c = HerculesMutBox::from(r.run(I as u64, J as u64, K as u64, a.to(), b.to()).await);
         for (calc, correct) in zip(c.as_slice().into_iter().map(|x: &mut f32| *x), correct_c) {
-- 
GitLab