// main.rs (1.23 KiB)
#![feature(box_as_ptr, let_chains)]
use rand::random;
use hercules_rt::HerculesBox;
// Expands the `juno!` macro for the "matmul" module. Presumably this is what brings the
// `matmul` and `tiled_64_matmul` async entry points called in `main` into scope —
// TODO confirm against the build script / juno_build docs.
juno_build::juno!("matmul");
/// Checks the `matmul` and `tiled_64_matmul` entry points against a naive
/// triple-loop matrix product computed in plain Rust on random inputs.
///
/// # Panics
///
/// Panics if either result differs from the reference product.
fn main() {
    async_std::task::block_on(async {
        // `lhs` is ROWS x INNER, `rhs` is INNER x COLS, and the product is ROWS x COLS;
        // every matrix is stored row-major in a flat boxed slice.
        const ROWS: usize = 256;
        const INNER: usize = 64;
        const COLS: usize = 128;

        // Entries fall in -99..=99 because `%` keeps the sign of the dividend.
        let random_matrix =
            |len: usize| -> Box<[i32]> { (0..len).map(|_| random::<i32>() % 100).collect() };
        let lhs = random_matrix(ROWS * INNER);
        let rhs = random_matrix(INNER * COLS);

        // Reference product: each output cell is the dot product of one row of `lhs`
        // with one column of `rhs`.
        let mut expected: Box<[i32]> = vec![0; ROWS * COLS].into_boxed_slice();
        for (row, out_row) in expected.chunks_exact_mut(COLS).enumerate() {
            for (col, cell) in out_row.iter_mut().enumerate() {
                *cell = (0..INNER)
                    .map(|inner| lhs[row * INNER + inner] * rhs[inner * COLS + col])
                    .sum();
            }
        }

        // The generated entry points take the dimensions as `u64`.
        let (rows, inner, cols) = (ROWS as u64, INNER as u64, COLS as u64);

        let mut plain = matmul(
            rows,
            inner,
            cols,
            HerculesBox::from_slice(&lhs),
            HerculesBox::from_slice(&rhs),
        )
        .await;
        assert_eq!(plain.as_slice::<i32>(), &*expected);

        let mut tiled = tiled_64_matmul(
            rows,
            inner,
            cols,
            HerculesBox::from_slice(&lhs),
            HerculesBox::from_slice(&rhs),
        )
        .await;
        assert_eq!(tiled.as_slice::<i32>(), &*expected);
    });
}
/// Runs the full `main` comparison under the test harness; any mismatch panics
/// inside `main` and fails this test.
#[test]
fn matmul_test() {
    main()
}