#![feature(box_as_ptr, let_chains)]

use rand::random;

use hercules_rt::HerculesBox;

juno_build::juno!("matmul");

/// Computes the row-major product of `a` (`rows x inner`) and `b` (`inner x cols`) on the CPU.
///
/// Serves as the ground truth that the compiled `matmul` result is compared against. Each
/// output entry is accumulated over `j = 0..inner` in ascending order.
///
/// # Panics
///
/// Panics if `a.len() != rows * inner` or `b.len() != inner * cols`, or (in debug builds) if an
/// accumulated entry overflows `i32`.
fn reference_matmul(a: &[i32], b: &[i32], rows: usize, inner: usize, cols: usize) -> Box<[i32]> {
    assert_eq!(a.len(), rows * inner, "lhs length does not match rows * inner");
    assert_eq!(b.len(), inner * cols, "rhs length does not match inner * cols");
    let mut c = vec![0i32; rows * cols].into_boxed_slice();
    for i in 0..rows {
        let a_row = &a[i * inner..(i + 1) * inner];
        for k in 0..cols {
            c[i * cols + k] = a_row
                .iter()
                .enumerate()
                .map(|(j, &x)| x * b[j * cols + k])
                .sum();
        }
    }
    c
}

/// End-to-end check of the Juno `matmul` kernel.
///
/// Builds random `I x J` and `J x K` row-major `i32` matrices and computes their product on the
/// CPU. It then runs `matmul` (presumably emitted by the `juno!` invocation in this file) and
/// asserts that both results agree element-for-element.
fn main() {
    async_std::task::block_on(async {
        const I: usize = 256;
        const J: usize = 64;
        const K: usize = 128;
        // `% 100` keeps every entry in -99..=99, so each output entry is bounded by
        // J * 99 * 99 = 627_264 in magnitude and cannot overflow `i32`.
        let mut a: Box<[i32]> = (0..I * J).map(|_| random::<i32>() % 100).collect();
        let mut b: Box<[i32]> = (0..J * K).map(|_| random::<i32>() % 100).collect();
        // Compute the expected result before `a`/`b` are mutably borrowed below.
        let correct_c = reference_matmul(&a, &b, I, J, K);
        let a = HerculesBox::from_slice_mut(&mut a);
        let b = HerculesBox::from_slice_mut(&mut b);
        let mut c = matmul(I as u64, J as u64, K as u64, a, b).await;
        assert_eq!(c.as_slice::<i32>(), &*correct_c);
    });
}

/// Runs the end-to-end `main` check under the test harness, so `cargo test` compares the
/// compiled `matmul` output against the CPU reference.
#[test]
fn matmul_test() {
    crate::main()
}