Skip to content
Snippets Groups Projects
Commit 7c06d712 authored by rarbore2's avatar rarbore2
Browse files

Merge branch 'bigger_matmul_test' into 'main'

Bigger matmul test

See merge request !85
parents d609f6f6 7a73df97
No related branches found
No related tags found
1 merge request: !85 "Bigger matmul test"
Pipeline #200742 passed
...@@ -738,6 +738,7 @@ version = "0.1.0" ...@@ -738,6 +738,7 @@ version = "0.1.0"
dependencies = [ dependencies = [
"async-std", "async-std",
"juno_build", "juno_build",
"rand",
"with_builtin_macros", "with_builtin_macros",
] ]
......
...@@ -2,43 +2,53 @@ ...@@ -2,43 +2,53 @@
extern crate async_std; extern crate async_std;
extern crate juno_build; extern crate juno_build;
extern crate rand;
use core::ptr::copy_nonoverlapping; use core::ptr::copy_nonoverlapping;
use rand::random;
juno_build::juno!("matmul"); juno_build::juno!("matmul");
fn main() { fn main() {
async_std::task::block_on(async { async_std::task::block_on(async {
let a: Box<[f32]> = Box::new([1.0, 2.0, 3.0, 4.0]); const I: usize = 256;
let b: Box<[f32]> = Box::new([5.0, 6.0, 7.0, 8.0]); const J: usize = 64;
let mut a_bytes: Box<[u8]> = Box::new([0; 16]); const K: usize = 128;
let mut b_bytes: Box<[u8]> = Box::new([0; 16]); let a: Box<[i32]> = (0..I * J).map(|_| random::<i32>() % 100).collect();
let b: Box<[i32]> = (0..J * K).map(|_| random::<i32>() % 100).collect();
let mut a_bytes: Box<[u8]> = Box::new([0; I * J * 4]);
let mut b_bytes: Box<[u8]> = Box::new([0; J * K * 4]);
unsafe { unsafe {
copy_nonoverlapping( copy_nonoverlapping(
Box::as_ptr(&a) as *const u8, Box::as_ptr(&a) as *const u8,
Box::as_mut_ptr(&mut a_bytes) as *mut u8, Box::as_mut_ptr(&mut a_bytes) as *mut u8,
16, I * J * 4,
); );
copy_nonoverlapping( copy_nonoverlapping(
Box::as_ptr(&b) as *const u8, Box::as_ptr(&b) as *const u8,
Box::as_mut_ptr(&mut b_bytes) as *mut u8, Box::as_mut_ptr(&mut b_bytes) as *mut u8,
16, J * K * 4,
); );
}; };
let c_bytes = matmul(2, 2, 2, a_bytes, b_bytes).await; let c_bytes = matmul(I as u64, J as u64, K as u64, a_bytes, b_bytes).await;
let mut c: Box<[f32]> = Box::new([0.0; 4]); let mut c: Box<[i32]> = (0..I * K).map(|_| 0).collect();
unsafe { unsafe {
copy_nonoverlapping( copy_nonoverlapping(
Box::as_ptr(&c_bytes) as *const u8, Box::as_ptr(&c_bytes) as *const u8,
Box::as_mut_ptr(&mut c) as *mut u8, Box::as_mut_ptr(&mut c) as *mut u8,
16, I * K * 4,
); );
}; };
println!("[[{}, {}], [{}, {}]]", c[0], c[1], c[2], c[3]); let mut correct_c: Box<[i32]> = (0..I * K).map(|_| 0).collect();
assert_eq!(c[0], 19.0); for i in 0..I {
assert_eq!(c[1], 22.0); for k in 0..K {
assert_eq!(c[2], 43.0); for j in 0..J {
assert_eq!(c[3], 50.0); correct_c[i * K + k] += a[i * J + j] * b[j * K + k];
}
}
}
assert_eq!(c, correct_c);
}); });
} }
......
fn matmul<3>(a: array(f32, #0, #1), b: array(f32, #1, #2)) -> array(f32, #0, #2) fn matmul<3>(a: array(i32, #0, #1), b: array(i32, #1, #2)) -> array(i32, #0, #2)
c = constant(array(f32, #0, #2), []) c = constant(array(i32, #0, #2), [])
i_j_ctrl = fork(start, #0, #2) i_j_ctrl = fork(start, #0, #2)
i_idx = thread_id(i_j_ctrl, 0) i_idx = thread_id(i_j_ctrl, 0)
j_idx = thread_id(i_j_ctrl, 1) j_idx = thread_id(i_j_ctrl, 1)
...@@ -8,48 +8,11 @@ fn matmul<3>(a: array(f32, #0, #1), b: array(f32, #1, #2)) -> array(f32, #0, #2) ...@@ -8,48 +8,11 @@ fn matmul<3>(a: array(f32, #0, #1), b: array(f32, #1, #2)) -> array(f32, #0, #2)
k_join_ctrl = join(k_ctrl) k_join_ctrl = join(k_ctrl)
i_j_join_ctrl = join(k_join_ctrl) i_j_join_ctrl = join(k_join_ctrl)
r = return(i_j_join_ctrl, update_i_j_c) r = return(i_j_join_ctrl, update_i_j_c)
zero = constant(f32, 0) zero = constant(i32, 0)
a_val = read(a, position(i_idx, k_idx)) a_val = read(a, position(i_idx, k_idx))
b_val = read(b, position(k_idx, j_idx)) b_val = read(b, position(k_idx, j_idx))
mul = mul(a_val, b_val) mul = mul(a_val, b_val)
add = add(mul, dot) add = add(mul, dot)
dot = reduce(k_join_ctrl, zero, add) dot = reduce(k_join_ctrl, zero, add)
update_c = write(update_i_j_c, dot, position(i_idx, j_idx)) update_c = write(update_i_j_c, dot, position(i_idx, j_idx))
update_i_j_c = reduce(i_j_join_ctrl, c, update_c) update_i_j_c = reduce(i_j_join_ctrl, c, update_c)
\ No newline at end of file
...@@ -15,3 +15,4 @@ juno_build = { path = "../../juno_build" } ...@@ -15,3 +15,4 @@ juno_build = { path = "../../juno_build" }
juno_build = { path = "../../juno_build" } juno_build = { path = "../../juno_build" }
with_builtin_macros = "0.1.0" with_builtin_macros = "0.1.0"
async-std = "*" async-std = "*"
rand = "*"
...@@ -2,43 +2,53 @@ ...@@ -2,43 +2,53 @@
extern crate async_std; extern crate async_std;
extern crate juno_build; extern crate juno_build;
extern crate rand;
use core::ptr::copy_nonoverlapping; use core::ptr::copy_nonoverlapping;
use rand::random;
juno_build::juno!("matmul"); juno_build::juno!("matmul");
fn main() { fn main() {
async_std::task::block_on(async { async_std::task::block_on(async {
let a: Box<[f32]> = Box::new([1.0, 2.0, 3.0, 4.0]); const I: usize = 256;
let b: Box<[f32]> = Box::new([5.0, 6.0, 7.0, 8.0]); const J: usize = 64;
let mut a_bytes: Box<[u8]> = Box::new([0; 16]); const K: usize = 128;
let mut b_bytes: Box<[u8]> = Box::new([0; 16]); let a: Box<[i32]> = (0..I * J).map(|_| random::<i32>() % 100).collect();
let b: Box<[i32]> = (0..J * K).map(|_| random::<i32>() % 100).collect();
let mut a_bytes: Box<[u8]> = Box::new([0; I * J * 4]);
let mut b_bytes: Box<[u8]> = Box::new([0; J * K * 4]);
unsafe { unsafe {
copy_nonoverlapping( copy_nonoverlapping(
Box::as_ptr(&a) as *const u8, Box::as_ptr(&a) as *const u8,
Box::as_mut_ptr(&mut a_bytes) as *mut u8, Box::as_mut_ptr(&mut a_bytes) as *mut u8,
16, I * J * 4,
); );
copy_nonoverlapping( copy_nonoverlapping(
Box::as_ptr(&b) as *const u8, Box::as_ptr(&b) as *const u8,
Box::as_mut_ptr(&mut b_bytes) as *mut u8, Box::as_mut_ptr(&mut b_bytes) as *mut u8,
16, J * K * 4,
); );
}; };
let c_bytes = matmul(2, 2, 2, a_bytes, b_bytes).await; let c_bytes = matmul(I as u64, J as u64, K as u64, a_bytes, b_bytes).await;
let mut c: Box<[f32]> = Box::new([0.0; 4]); let mut c: Box<[i32]> = (0..I * K).map(|_| 0).collect();
unsafe { unsafe {
copy_nonoverlapping( copy_nonoverlapping(
Box::as_ptr(&c_bytes) as *const u8, Box::as_ptr(&c_bytes) as *const u8,
Box::as_mut_ptr(&mut c) as *mut u8, Box::as_mut_ptr(&mut c) as *mut u8,
16, I * K * 4,
); );
}; };
println!("[[{}, {}], [{}, {}]]", c[0], c[1], c[2], c[3]); let mut correct_c: Box<[i32]> = (0..I * K).map(|_| 0).collect();
assert_eq!(c[0], 19.0); for i in 0..I {
assert_eq!(c[1], 22.0); for k in 0..K {
assert_eq!(c[2], 43.0); for j in 0..J {
assert_eq!(c[3], 50.0); correct_c[i * K + k] += a[i * J + j] * b[j * K + k];
}
}
}
assert_eq!(c, correct_c);
}); });
} }
...@@ -46,4 +56,3 @@ fn main() { ...@@ -46,4 +56,3 @@ fn main() {
fn matmul_test() { fn matmul_test() {
main(); main();
} }
#[entry] #[entry]
fn matmul<n : usize, m : usize, l : usize>(a : f32[n, m], b : f32[m, l]) -> f32[n, l] { fn matmul<n : usize, m : usize, l : usize>(a : i32[n, m], b : i32[m, l]) -> i32[n, l] {
let res : f32[n, l]; let res : i32[n, l];
@outer for i = 0 to n { @outer for i = 0 to n {
@middle for j = 0 to l { @middle for j = 0 to l {
...@@ -13,3 +13,51 @@ fn matmul<n : usize, m : usize, l : usize>(a : f32[n, m], b : f32[m, l]) -> f32[ ...@@ -13,3 +13,51 @@ fn matmul<n : usize, m : usize, l : usize>(a : f32[n, m], b : f32[m, l]) -> f32[
@exit @exit
return res; return res;
} }
/*
#[entry]
fn tiled_64_matmul<n : usize, m : usize, l : usize>(a : i32[n, m], b : i32[m, l]) -> i32[n, l] {
let res : i32[n, l];
for bi = 0 to n / 64 {
for bk = 0 to l / 64 {
// TODO: make these all the same size, clone analysis should undo GVN's
// combining of these three arrays.
let atile : i32[66, 64];
let btile : i32[65, 64];
let ctile : i32[64, 64];
for tile_idx = 0 to m / 64 {
for ti = 0 to 64 {
for tk = 0 to 64 {
atile[ti, tk] = a[bi * 64 + ti, tile_idx * 64 + tk];
btile[ti, tk] = b[tile_idx * 64 + ti, bk * 64 + tk];
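// TODO: remove setting ctile to zero explicitly, clone analysis
// should see a lack of a phi for ctile in the block loops and
// induce a copy of an initial value of ctile (all zeros) on each
// iteration of the block loops.
// NOTE(review): this reset runs on every tile_idx iteration, so the
// partial sums accumulated from earlier tiles appear to be discarded
// whenever m / 64 > 1; ctile presumably needs to be zeroed once per
// (bi, bk) block instead - confirm before re-enabling this kernel.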
ctile[ti, tk] = 0;
}
}
for ti = 0 to 64 {
for tk = 0 to 64 {
let c_acc = ctile[ti, tk];
for inner_idx = 0 to 64 {
c_acc += atile[ti, inner_idx] * btile[inner_idx, tk];
}
ctile[ti, tk] = c_acc;
}
}
}
for ti = 0 to 64 {
for tk = 0 to 64 {
res[bi * 64 + ti, bk * 64 + tk] = ctile[ti, tk];
}
}
}
}
return res;
}
*/
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment