From 966a8cc49c1a90da3fd4dfdef71aeec7fe242840 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Tue, 18 Feb 2025 17:54:12 -0600
Subject: [PATCH] wtf

---
 hercules_samples/matmul/src/main.rs    |  6 +++---
 hercules_samples/matmul/src/matmul.hir |  4 ++--
 juno_samples/concat/src/main.rs        | 14 ++++++--------
 3 files changed, 11 insertions(+), 13 deletions(-)

diff --git a/hercules_samples/matmul/src/main.rs b/hercules_samples/matmul/src/main.rs
index cd49b8e9..29415b51 100644
--- a/hercules_samples/matmul/src/main.rs
+++ b/hercules_samples/matmul/src/main.rs
@@ -10,7 +10,7 @@ juno_build::juno!("matmul");
 fn main() {
     async_std::task::block_on(async {
         const I: usize = 256;
-        const J: usize = 8; // hardcoded constant in matmul.hir
+        const J: usize = 64;
         const K: usize = 128;
         let a: Box<[f32]> = (0..I * J).map(|_| random::<f32>()).collect();
         let b: Box<[f32]> = (0..J * K).map(|_| random::<f32>()).collect();
@@ -22,8 +22,8 @@ fn main() {
                 }
             }
         }
-        let a = HerculesImmBox::from(&a as &[f32]);
-        let b = HerculesImmBox::from(&b as &[f32]);
+        let a = HerculesImmBox::from(a.as_ref());
+        let b = HerculesImmBox::from(b.as_ref());
         let mut r = runner!(matmul);
         let mut c = HerculesMutBox::from(r.run(I as u64, J as u64, K as u64, a.to(), b.to()).await);
         for (calc, correct) in zip(c.as_slice().into_iter().map(|x: &mut f32| *x), correct_c) {
diff --git a/hercules_samples/matmul/src/matmul.hir b/hercules_samples/matmul/src/matmul.hir
index f9d37afc..b0c31da4 100644
--- a/hercules_samples/matmul/src/matmul.hir
+++ b/hercules_samples/matmul/src/matmul.hir
@@ -1,9 +1,9 @@
-fn matmul<3>(a: array(i32, #0, 8), b: array(i32, 8, #2)) -> array(i32, #0, #2)
+fn matmul<3>(a: array(i32, #0, #1), b: array(i32, #1, #2)) -> array(i32, #0, #2)
   c = constant(array(i32, #0, #2), [])
   i_j_ctrl = fork(start, #0, #2)
   i_idx = thread_id(i_j_ctrl, 0)
   j_idx = thread_id(i_j_ctrl, 1)
-  k_ctrl = fork(i_j_ctrl, 8)
+  k_ctrl = fork(i_j_ctrl, #1)
   k_idx = thread_id(k_ctrl, 0)
   k_join_ctrl = join(k_ctrl)
   i_j_join_ctrl = join(k_join_ctrl)
diff --git a/juno_samples/concat/src/main.rs b/juno_samples/concat/src/main.rs
index 547dee08..2f704f16 100644
--- a/juno_samples/concat/src/main.rs
+++ b/juno_samples/concat/src/main.rs
@@ -10,12 +10,12 @@ juno_build::juno!("concat");
 fn main() {
     async_std::task::block_on(async {
         let mut r = runner!(concat_entry);
-        let mut a_data = [7, 7, 0];
-        let mut b_data = [7, 7, 0, 0, 7, 7];
+        let mut a_data = Box::new([7, 7, 0]);
+        let mut b_data = Box::new([7, 7, 0, 0, 7, 7]);
         #[cfg(not(feature = "cuda"))]
         {
-            let a = HerculesCPURef::from_slice(&mut a_data);
-            let b = HerculesCPURef::from_slice(&mut b_data);
+            let a = HerculesCPURef::from_slice(a_data.as_ref());
+            let b = HerculesCPURef::from_slice(b_data.as_ref());
             let output = r.run(3, 6, a, b).await;
             assert_eq!(output, 42);
 
@@ -36,10 +36,8 @@ fn main() {
         }
         #[cfg(feature = "cuda")]
         {
-            let mut a_data = [7, 7, 0];
-            let a = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut a_data));
-            let mut b_data = [7, 7, 0, 0, 7, 7];
-            let b = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut b_data));
+            let a = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(a_data.as_ref()));
+            let b = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(b_data.as_ref()));
             let output = r.run(3, 6, a.get_ref(), b.get_ref()).await;
             assert_eq!(output, 42);
         }
-- 
GitLab