#![feature(concat_idents)]
use std::iter::zip;

use rand::random;

use hercules_rt::{runner, HerculesImmBox, HerculesImmBoxTo, HerculesMutBox};

// NOTE(review): presumably brings the build-generated code for the "matmul"
// Juno module into this crate, which `runner!(matmul)` in `main` relies on —
// confirm against the juno_build documentation.
juno_build::juno!("matmul");

fn main() {
    async_std::task::block_on(async {
        const I: usize = 256;
        const J: usize = 64;
        const K: usize = 128;
        let a: Box<[i32]> = (0..I * J).map(|_| random::<i32>() % 100).collect();
        let b: Box<[i32]> = (0..J * K).map(|_| random::<i32>() % 100).collect();
        let mut correct_c: Box<[i32]> = (0..I * K).map(|_| 0).collect();
        for i in 0..I {
            for k in 0..K {
                for j in 0..J {
                    correct_c[i * K + k] += a[i * J + j] * b[j * K + k];
                }
            }
        }
        let a = HerculesImmBox::from(a.as_ref());
        let b = HerculesImmBox::from(b.as_ref());
        let mut r = runner!(matmul);
        let mut c: HerculesMutBox<i32> =
            HerculesMutBox::from(r.run(I as u64, J as u64, K as u64, a.to(), b.to()).await);
        assert_eq!(c.as_slice(), correct_c.as_ref());
    });
}

/// Runs `main` under the test harness, so the result comparison asserted
/// inside `main` also serves as a `cargo test` check.
#[test]
fn matmul_test() {
    main();
}