From 966a8cc49c1a90da3fd4dfdef71aeec7fe242840 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Tue, 18 Feb 2025 17:54:12 -0600 Subject: [PATCH] wtf --- hercules_samples/matmul/src/main.rs | 6 +++--- hercules_samples/matmul/src/matmul.hir | 4 ++-- juno_samples/concat/src/main.rs | 14 ++++++-------- 3 files changed, 11 insertions(+), 13 deletions(-) diff --git a/hercules_samples/matmul/src/main.rs b/hercules_samples/matmul/src/main.rs index cd49b8e9..29415b51 100644 --- a/hercules_samples/matmul/src/main.rs +++ b/hercules_samples/matmul/src/main.rs @@ -10,7 +10,7 @@ juno_build::juno!("matmul"); fn main() { async_std::task::block_on(async { const I: usize = 256; - const J: usize = 8; // hardcoded constant in matmul.hir + const J: usize = 64; const K: usize = 128; let a: Box<[f32]> = (0..I * J).map(|_| random::<f32>()).collect(); let b: Box<[f32]> = (0..J * K).map(|_| random::<f32>()).collect(); @@ -22,8 +22,8 @@ fn main() { } } } - let a = HerculesImmBox::from(&a as &[f32]); - let b = HerculesImmBox::from(&b as &[f32]); + let a = HerculesImmBox::from(a.as_ref()); + let b = HerculesImmBox::from(b.as_ref()); let mut r = runner!(matmul); let mut c = HerculesMutBox::from(r.run(I as u64, J as u64, K as u64, a.to(), b.to()).await); for (calc, correct) in zip(c.as_slice().into_iter().map(|x: &mut f32| *x), correct_c) { diff --git a/hercules_samples/matmul/src/matmul.hir b/hercules_samples/matmul/src/matmul.hir index f9d37afc..b0c31da4 100644 --- a/hercules_samples/matmul/src/matmul.hir +++ b/hercules_samples/matmul/src/matmul.hir @@ -1,9 +1,9 @@ -fn matmul<3>(a: array(i32, #0, 8), b: array(i32, 8, #2)) -> array(i32, #0, #2) +fn matmul<3>(a: array(i32, #0, #1), b: array(i32, #1, #2)) -> array(i32, #0, #2) c = constant(array(i32, #0, #2), []) i_j_ctrl = fork(start, #0, #2) i_idx = thread_id(i_j_ctrl, 0) j_idx = thread_id(i_j_ctrl, 1) - k_ctrl = fork(i_j_ctrl, 8) + k_ctrl = fork(i_j_ctrl, #1) k_idx = thread_id(k_ctrl, 0) k_join_ctrl = join(k_ctrl) i_j_join_ctrl = join(k_join_ctrl) diff --git a/juno_samples/concat/src/main.rs b/juno_samples/concat/src/main.rs index 547dee08..2f704f16 100644 --- a/juno_samples/concat/src/main.rs +++ b/juno_samples/concat/src/main.rs @@ -10,12 +10,12 @@ juno_build::juno!("concat"); fn main() { async_std::task::block_on(async { let mut r = runner!(concat_entry); - let mut a_data = [7, 7, 0]; - let mut b_data = [7, 7, 0, 0, 7, 7]; + let mut a_data = Box::new([7, 7, 0]); + let mut b_data = Box::new([7, 7, 0, 0, 7, 7]); #[cfg(not(feature = "cuda"))] { - let a = HerculesCPURef::from_slice(&mut a_data); - let b = HerculesCPURef::from_slice(&mut b_data); + let a = HerculesCPURef::from_slice(a_data.as_ref()); + let b = HerculesCPURef::from_slice(b_data.as_ref()); let output = r.run(3, 6, a, b).await; assert_eq!(output, 42); @@ -36,10 +36,8 @@ fn main() { } #[cfg(feature = "cuda")] { - let mut a_data = [7, 7, 0]; - let a = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut a_data)); - let mut b_data = [7, 7, 0, 0, 7, 7]; - let b = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut b_data)); + let a = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(a_data.as_ref())); + let b = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(b_data.as_ref())); let output = r.run(3, 6, a.get_ref(), b.get_ref()).await; assert_eq!(output, 42); } -- GitLab