From a84d8c89f93ccab54cfdb7677f6518b10664a741 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Tue, 18 Feb 2025 17:42:05 -0600
Subject: [PATCH] Use AlignedAlloc

---
 Cargo.lock                       |  1 -
 hercules_rt/src/lib.rs           | 35 +++++++++++++++++++++++++-------
 hercules_samples/dot/Cargo.toml  |  1 -
 hercules_samples/dot/src/main.rs |  6 ++----
 juno_samples/matmul/src/main.rs  |  4 ++--
 5 files changed, 32 insertions(+), 15 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 0cad8e19..1973fbbe 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -670,7 +670,6 @@ dependencies = [
 name = "dot"
 version = "0.1.0"
 dependencies = [
- "aligned-vec",
  "async-std",
  "clap",
  "hercules_rt",
diff --git a/hercules_rt/src/lib.rs b/hercules_rt/src/lib.rs
index 3b79dc48..a245a264 100644
--- a/hercules_rt/src/lib.rs
+++ b/hercules_rt/src/lib.rs
@@ -1,6 +1,6 @@
 #![feature(once_cell_try, pointer_is_aligned_to)]
 
-use std::alloc::{alloc, dealloc, Layout};
+use std::alloc::{alloc, dealloc, GlobalAlloc, Layout, System};
 use std::marker::PhantomData;
 use std::ptr::{copy_nonoverlapping, write_bytes, NonNull};
 use std::slice::{from_raw_parts, from_raw_parts_mut};
@@ -189,7 +189,7 @@ pub struct CUDABox {
 
 impl<'a> HerculesCPURef<'a> {
     pub fn from_slice<T>(slice: &'a [T]) -> Self {
-        assert!(slice.as_ptr().is_aligned_to(32));
+        assert!(slice.as_ptr().is_aligned_to(LARGEST_ALIGNMENT));
         let ptr = unsafe { NonNull::new_unchecked(slice.as_ptr() as *mut u8) };
         let size = slice.len() * size_of::<T>();
         Self {
@@ -214,7 +214,7 @@ impl<'a> HerculesCPURef<'a> {
     }
 
     pub unsafe fn __from_parts(ptr: *mut u8, size: usize) -> Self {
-        assert!(ptr.is_aligned_to(32));
+        assert!(ptr.is_aligned_to(LARGEST_ALIGNMENT));
         Self {
             ptr: NonNull::new(ptr).unwrap(),
             size,
@@ -225,7 +225,7 @@ impl<'a> HerculesCPURef<'a> {
 
 impl<'a> HerculesCPURefMut<'a> {
     pub fn from_slice<T>(slice: &'a mut [T]) -> Self {
-        assert!(slice.as_ptr().is_aligned_to(32));
+        assert!(slice.as_ptr().is_aligned_to(LARGEST_ALIGNMENT));
         let ptr = unsafe { NonNull::new_unchecked(slice.as_ptr() as *mut u8) };
         let size = slice.len() * size_of::<T>();
         Self {
@@ -259,7 +259,7 @@ impl<'a> HerculesCPURefMut<'a> {
     }
 
     pub unsafe fn __from_parts(ptr: *mut u8, size: usize) -> Self {
-        assert!(ptr.is_aligned_to(32));
+        assert!(ptr.is_aligned_to(LARGEST_ALIGNMENT));
         Self {
             ptr: NonNull::new(ptr).unwrap(),
             size,
@@ -271,7 +271,7 @@ impl<'a> HerculesCPURefMut<'a> {
 #[cfg(feature = "cuda")]
 impl<'a> HerculesCUDARef<'a> {
     pub fn to_cpu_ref<'b, T>(self, dst: &'b mut [T]) -> HerculesCPURefMut<'b> {
-        assert!(dst.as_ptr().is_aligned_to(32));
+        assert!(dst.as_ptr().is_aligned_to(LARGEST_ALIGNMENT));
         unsafe {
             let size = self.size;
             assert_eq!(size, dst.len() * size_of::<T>());
@@ -313,7 +313,7 @@ impl<'a> HerculesCUDARefMut<'a> {
     }
 
     pub fn to_cpu_ref<'b, T>(self, dst: &mut [T]) -> HerculesCPURefMut<'b> {
-        assert!(dst.as_ptr().is_aligned_to(32));
+        assert!(dst.as_ptr().is_aligned_to(LARGEST_ALIGNMENT));
         unsafe {
             let size = self.size;
             let ptr = NonNull::new(dst.as_ptr() as *mut u8).unwrap();
@@ -872,3 +872,41 @@ impl<'a, T> HerculesRefInto<'a> for Box<[T]> {
         HerculesCPURef::from_slice(self)
     }
 }
+
+/*
+ * We need all allocations to be aligned to LARGEST_ALIGNMENT bytes for
+ * vectorization. Installing a global allocator that raises the alignment of
+ * every request is the easiest way to do that.
+ */
+pub struct AlignedAlloc;
+
+unsafe impl GlobalAlloc for AlignedAlloc {
+    /*
+     * Allocate from the system allocator with the alignment raised to at least
+     * LARGEST_ALIGNMENT. GlobalAlloc methods must not unwind, so a layout that
+     * cannot be over-aligned is reported as a failed allocation (null pointer)
+     * instead of a panic.
+     */
+    unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
+        match layout.align_to(LARGEST_ALIGNMENT) {
+            Ok(aligned) => System.alloc(aligned),
+            Err(_) => std::ptr::null_mut(),
+        }
+    }
+
+    /*
+     * Free a pointer obtained from alloc, rebuilding the same over-aligned
+     * layout that alloc handed to the system allocator.
+     */
+    unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
+        // align_to cannot fail for a layout that alloc accepted. If it does, ptr
+        // was never returned by alloc, so there is nothing to free; panicking is
+        // not an option inside a global allocator.
+        if let Ok(aligned) = layout.align_to(LARGEST_ALIGNMENT) {
+            System.dealloc(ptr, aligned);
+        }
+    }
+}
+
+#[global_allocator]
+static A: AlignedAlloc = AlignedAlloc;
diff --git a/hercules_samples/dot/Cargo.toml b/hercules_samples/dot/Cargo.toml
index ab35cbaf..9b11ddc1 100644
--- a/hercules_samples/dot/Cargo.toml
+++ b/hercules_samples/dot/Cargo.toml
@@ -17,4 +17,3 @@ hercules_rt = { path = "../../hercules_rt" }
 rand = "*"
 async-std = "*"
 with_builtin_macros = "0.1.0"
-aligned-vec = "*"
diff --git a/hercules_samples/dot/src/main.rs b/hercules_samples/dot/src/main.rs
index 7bcaaeba..1f28cee2 100644
--- a/hercules_samples/dot/src/main.rs
+++ b/hercules_samples/dot/src/main.rs
@@ -4,14 +4,12 @@
 use hercules_rt::CUDABox;
 use hercules_rt::{runner, HerculesImmBox, HerculesImmBoxTo};
 
-use aligned_vec::ABox;
-
 juno_build::juno!("dot");
 
 fn main() {
     async_std::task::block_on(async {
-        let a: ABox<[f32; 8]> = ABox::new(32, [0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0]);
-        let b: ABox<[f32; 8]> = ABox::new(32, [0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0]);
+        let a: Box<[f32; 8]> = Box::new([0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0]);
+        let b: Box<[f32; 8]> = Box::new([0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0]);
         let a = HerculesImmBox::from(a.as_ref() as &[f32]);
         let b = HerculesImmBox::from(b.as_ref() as &[f32]);
         let mut r = runner!(dot);
diff --git a/juno_samples/matmul/src/main.rs b/juno_samples/matmul/src/main.rs
index cb078c74..29415b51 100644
--- a/juno_samples/matmul/src/main.rs
+++ b/juno_samples/matmul/src/main.rs
@@ -22,8 +22,8 @@ fn main() {
                 }
             }
         }
-        let a = HerculesImmBox::from(&a as &[f32]);
-        let b = HerculesImmBox::from(&b as &[f32]);
+        let a = HerculesImmBox::from(a.as_ref());
+        let b = HerculesImmBox::from(b.as_ref());
         let mut r = runner!(matmul);
         let mut c = HerculesMutBox::from(r.run(I as u64, J as u64, K as u64, a.to(), b.to()).await);
         for (calc, correct) in zip(c.as_slice().into_iter().map(|x: &mut f32| *x), correct_c) {
-- 
GitLab