Skip to content
Snippets Groups Projects
Commit 7a73df97 authored by rarbore2's avatar rarbore2
Browse files

Bigger matmul test

parent d609f6f6
No related branches found
No related tags found
1 merge request!85Bigger matmul test
......@@ -738,6 +738,7 @@ version = "0.1.0"
dependencies = [
"async-std",
"juno_build",
"rand",
"with_builtin_macros",
]
......
......@@ -2,43 +2,53 @@
extern crate async_std;
extern crate juno_build;
extern crate rand;
use core::ptr::copy_nonoverlapping;
use rand::random;
juno_build::juno!("matmul");
fn main() {
async_std::task::block_on(async {
let a: Box<[f32]> = Box::new([1.0, 2.0, 3.0, 4.0]);
let b: Box<[f32]> = Box::new([5.0, 6.0, 7.0, 8.0]);
let mut a_bytes: Box<[u8]> = Box::new([0; 16]);
let mut b_bytes: Box<[u8]> = Box::new([0; 16]);
const I: usize = 256;
const J: usize = 64;
const K: usize = 128;
let a: Box<[i32]> = (0..I * J).map(|_| random::<i32>() % 100).collect();
let b: Box<[i32]> = (0..J * K).map(|_| random::<i32>() % 100).collect();
let mut a_bytes: Box<[u8]> = Box::new([0; I * J * 4]);
let mut b_bytes: Box<[u8]> = Box::new([0; J * K * 4]);
unsafe {
copy_nonoverlapping(
Box::as_ptr(&a) as *const u8,
Box::as_mut_ptr(&mut a_bytes) as *mut u8,
16,
I * J * 4,
);
copy_nonoverlapping(
Box::as_ptr(&b) as *const u8,
Box::as_mut_ptr(&mut b_bytes) as *mut u8,
16,
J * K * 4,
);
};
let c_bytes = matmul(2, 2, 2, a_bytes, b_bytes).await;
let mut c: Box<[f32]> = Box::new([0.0; 4]);
let c_bytes = matmul(I as u64, J as u64, K as u64, a_bytes, b_bytes).await;
let mut c: Box<[i32]> = (0..I * K).map(|_| 0).collect();
unsafe {
copy_nonoverlapping(
Box::as_ptr(&c_bytes) as *const u8,
Box::as_mut_ptr(&mut c) as *mut u8,
16,
I * K * 4,
);
};
println!("[[{}, {}], [{}, {}]]", c[0], c[1], c[2], c[3]);
assert_eq!(c[0], 19.0);
assert_eq!(c[1], 22.0);
assert_eq!(c[2], 43.0);
assert_eq!(c[3], 50.0);
let mut correct_c: Box<[i32]> = (0..I * K).map(|_| 0).collect();
for i in 0..I {
for k in 0..K {
for j in 0..J {
correct_c[i * K + k] += a[i * J + j] * b[j * K + k];
}
}
}
assert_eq!(c, correct_c);
});
}
......
fn matmul<3>(a: array(f32, #0, #1), b: array(f32, #1, #2)) -> array(f32, #0, #2)
c = constant(array(f32, #0, #2), [])
fn matmul<3>(a: array(i32, #0, #1), b: array(i32, #1, #2)) -> array(i32, #0, #2)
c = constant(array(i32, #0, #2), [])
i_j_ctrl = fork(start, #0, #2)
i_idx = thread_id(i_j_ctrl, 0)
j_idx = thread_id(i_j_ctrl, 1)
......@@ -8,48 +8,11 @@ fn matmul<3>(a: array(f32, #0, #1), b: array(f32, #1, #2)) -> array(f32, #0, #2)
k_join_ctrl = join(k_ctrl)
i_j_join_ctrl = join(k_join_ctrl)
r = return(i_j_join_ctrl, update_i_j_c)
zero = constant(f32, 0)
zero = constant(i32, 0)
a_val = read(a, position(i_idx, k_idx))
b_val = read(b, position(k_idx, j_idx))
mul = mul(a_val, b_val)
add = add(mul, dot)
dot = reduce(k_join_ctrl, zero, add)
update_c = write(update_i_j_c, dot, position(i_idx, j_idx))
update_i_j_c = reduce(i_j_join_ctrl, c, update_c)
update_i_j_c = reduce(i_j_join_ctrl, c, update_c)
\ No newline at end of file
......@@ -15,3 +15,4 @@ juno_build = { path = "../../juno_build" }
juno_build = { path = "../../juno_build" }
with_builtin_macros = "0.1.0"
async-std = "*"
rand = "*"
......@@ -2,43 +2,53 @@
extern crate async_std;
extern crate juno_build;
extern crate rand;
use core::ptr::copy_nonoverlapping;
use rand::random;
juno_build::juno!("matmul");
fn main() {
async_std::task::block_on(async {
let a: Box<[f32]> = Box::new([1.0, 2.0, 3.0, 4.0]);
let b: Box<[f32]> = Box::new([5.0, 6.0, 7.0, 8.0]);
let mut a_bytes: Box<[u8]> = Box::new([0; 16]);
let mut b_bytes: Box<[u8]> = Box::new([0; 16]);
const I: usize = 256;
const J: usize = 64;
const K: usize = 128;
let a: Box<[i32]> = (0..I * J).map(|_| random::<i32>() % 100).collect();
let b: Box<[i32]> = (0..J * K).map(|_| random::<i32>() % 100).collect();
let mut a_bytes: Box<[u8]> = Box::new([0; I * J * 4]);
let mut b_bytes: Box<[u8]> = Box::new([0; J * K * 4]);
unsafe {
copy_nonoverlapping(
Box::as_ptr(&a) as *const u8,
Box::as_mut_ptr(&mut a_bytes) as *mut u8,
16,
I * J * 4,
);
copy_nonoverlapping(
Box::as_ptr(&b) as *const u8,
Box::as_mut_ptr(&mut b_bytes) as *mut u8,
16,
J * K * 4,
);
};
let c_bytes = matmul(2, 2, 2, a_bytes, b_bytes).await;
let mut c: Box<[f32]> = Box::new([0.0; 4]);
let c_bytes = matmul(I as u64, J as u64, K as u64, a_bytes, b_bytes).await;
let mut c: Box<[i32]> = (0..I * K).map(|_| 0).collect();
unsafe {
copy_nonoverlapping(
Box::as_ptr(&c_bytes) as *const u8,
Box::as_mut_ptr(&mut c) as *mut u8,
16,
I * K * 4,
);
};
println!("[[{}, {}], [{}, {}]]", c[0], c[1], c[2], c[3]);
assert_eq!(c[0], 19.0);
assert_eq!(c[1], 22.0);
assert_eq!(c[2], 43.0);
assert_eq!(c[3], 50.0);
let mut correct_c: Box<[i32]> = (0..I * K).map(|_| 0).collect();
for i in 0..I {
for k in 0..K {
for j in 0..J {
correct_c[i * K + k] += a[i * J + j] * b[j * K + k];
}
}
}
assert_eq!(c, correct_c);
});
}
......@@ -46,4 +56,3 @@ fn main() {
fn matmul_test() {
main();
}
#[entry]
fn matmul<n : usize, m : usize, l : usize>(a : f32[n, m], b : f32[m, l]) -> f32[n, l] {
let res : f32[n, l];
fn matmul<n : usize, m : usize, l : usize>(a : i32[n, m], b : i32[m, l]) -> i32[n, l] {
let res : i32[n, l];
@outer for i = 0 to n {
@middle for j = 0 to l {
......@@ -13,3 +13,51 @@ fn matmul<n : usize, m : usize, l : usize>(a : f32[n, m], b : f32[m, l]) -> f32[
@exit
return res;
}
/*
#[entry]
fn tiled_64_matmul<n : usize, m : usize, l : usize>(a : i32[n, m], b : i32[m, l]) -> i32[n, l] {
let res : i32[n, l];
for bi = 0 to n / 64 {
for bk = 0 to l / 64 {
// TODO: make these all the same size, clone analysis should undo GVN's
// combining of these three arrays.
let atile : i32[66, 64];
let btile : i32[65, 64];
let ctile : i32[64, 64];
for tile_idx = 0 to m / 64 {
for ti = 0 to 64 {
for tk = 0 to 64 {
atile[ti, tk] = a[bi * 64 + ti, tile_idx * 64 + tk];
btile[ti, tk] = b[tile_idx * 64 + ti, bk * 64 + tk];
// TODO: remove setting ctile to zero explicitly, clone analysis
// should see a lack of a phi for ctile in the block loops and
// induce a copy of an initial value of ctile (all zeros) on each
// iteration of the block loops.
ctile[ti, tk] = 0;
}
}
for ti = 0 to 64 {
for tk = 0 to 64 {
let c_acc = ctile[ti, tk];
for inner_idx = 0 to 64 {
c_acc += atile[ti, inner_idx] * btile[inner_idx, tk];
}
ctile[ti, tk] = c_acc;
}
}
}
for ti = 0 to 64 {
for tk = 0 to 64 {
res[bi * 64 + ti, bk * 64 + tk] = ctile[ti, tk];
}
}
}
}
return res;
}
*/
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment