diff --git a/hercules_samples/matmul/src/main.rs b/hercules_samples/matmul/src/main.rs index 29415b511992946b08a1496f3eb92d957615d8aa..277276648e905186bfeb54714fb00f7275f17b22 100644 --- a/hercules_samples/matmul/src/main.rs +++ b/hercules_samples/matmul/src/main.rs @@ -12,9 +12,9 @@ fn main() { const I: usize = 256; const J: usize = 64; const K: usize = 128; - let a: Box<[f32]> = (0..I * J).map(|_| random::<f32>()).collect(); - let b: Box<[f32]> = (0..J * K).map(|_| random::<f32>()).collect(); - let mut correct_c: Box<[f32]> = (0..I * K).map(|_| 0.0).collect(); + let a: Box<[i32]> = (0..I * J).map(|_| random::<i32>() % 100).collect(); + let b: Box<[i32]> = (0..J * K).map(|_| random::<i32>() % 100).collect(); + let mut correct_c: Box<[i32]> = (0..I * K).map(|_| 0).collect(); for i in 0..I { for k in 0..K { for j in 0..J { @@ -25,10 +25,8 @@ fn main() { let a = HerculesImmBox::from(a.as_ref()); let b = HerculesImmBox::from(b.as_ref()); let mut r = runner!(matmul); - let mut c = HerculesMutBox::from(r.run(I as u64, J as u64, K as u64, a.to(), b.to()).await); - for (calc, correct) in zip(c.as_slice().into_iter().map(|x: &mut f32| *x), correct_c) { - assert!((calc - correct).abs() < 0.0001, "{} != {}", calc, correct); - } + let mut c: HerculesMutBox<i32> = HerculesMutBox::from(r.run(I as u64, J as u64, K as u64, a.to(), b.to()).await); + assert_eq!(c.as_slice(), correct_c.as_ref()); }); }