Patch summary
=============
This text precedes the first file header and is not part of any hunk.

NOTE: this patch was reconstructed from a copy whose line breaks and
indentation had been collapsed. Line counts match every hunk header, but the
leading whitespace inside hunks was restored by convention (4 spaces for Rust,
2 spaces for .hir/.jn, 1 space for Cargo.lock entries). Confirm against the
target files before applying.

What the patch does:
  * hercules_samples/matmul/src/main.rs and juno_samples/matmul/src/main.rs:
    replace the fixed 2x2 f32 inputs with random i32 matrices of shape
    I x J (256 x 64) and J x K (64 x 128), call matmul with the dimensions cast
    to u64, and compare the whole result against a triple-loop reference
    product computed in main instead of four hard-coded asserts.
  * hercules_samples/matmul/src/matmul.hir: element type f32 -> i32; 37
    trailing blank lines removed (the file now ends without a final newline).
  * juno_samples/matmul/src/matmul.jn: element type f32 -> i32; a
    commented-out tiled_64_matmul entry point is appended.
  * juno_samples/matmul/Cargo.toml and Cargo.lock: add the rand crate.

Review notes:
  * random::<i32>() % 100 keeps the sign of the dividend, so inputs lie in
    -99..=99; the largest possible |dot product| is 64 * 99 * 99 = 627264,
    well inside i32, so the reference accumulation cannot overflow.
  * NOTE(review): hercules_samples/matmul/src/main.rs now declares
    `extern crate rand`, but no manifest change for that sample is visible in
    this patch - confirm its Cargo.toml already lists rand.
  * NOTE(review): `rand = "*"` is a wildcard requirement; consider pinning a
    version range.
  * NOTE(review): the byte counts use a literal 4 for the i32 width;
    size_of::<i32>() would document the intent.
  * NOTE(review): in the commented-out tiled_64_matmul, ctile is reset to 0
    inside the tile_idx loop, so it looks like only the last tile's
    contribution would survive - confirm before enabling that code.

diff --git a/Cargo.lock b/Cargo.lock
index e9e4f311440fbafe440d4734b35fbcc54365bd3e..73ab201c405dd39e11fb32dc014cde9a8e44fc12 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -738,6 +738,7 @@ version = "0.1.0"
 dependencies = [
  "async-std",
  "juno_build",
+ "rand",
  "with_builtin_macros",
 ]
 
diff --git a/hercules_samples/matmul/src/main.rs b/hercules_samples/matmul/src/main.rs
index 12c14249aa62502c766028cef3c0518cf0fb4633..93d007c791579a75dea65d5680ab3018e9b00085 100644
--- a/hercules_samples/matmul/src/main.rs
+++ b/hercules_samples/matmul/src/main.rs
@@ -2,43 +2,53 @@
 
 extern crate async_std;
 extern crate juno_build;
+extern crate rand;
 
 use core::ptr::copy_nonoverlapping;
 
+use rand::random;
+
 juno_build::juno!("matmul");
 
 fn main() {
     async_std::task::block_on(async {
-        let a: Box<[f32]> = Box::new([1.0, 2.0, 3.0, 4.0]);
-        let b: Box<[f32]> = Box::new([5.0, 6.0, 7.0, 8.0]);
-        let mut a_bytes: Box<[u8]> = Box::new([0; 16]);
-        let mut b_bytes: Box<[u8]> = Box::new([0; 16]);
+        const I: usize = 256;
+        const J: usize = 64;
+        const K: usize = 128;
+        let a: Box<[i32]> = (0..I * J).map(|_| random::<i32>() % 100).collect();
+        let b: Box<[i32]> = (0..J * K).map(|_| random::<i32>() % 100).collect();
+        let mut a_bytes: Box<[u8]> = Box::new([0; I * J * 4]);
+        let mut b_bytes: Box<[u8]> = Box::new([0; J * K * 4]);
         unsafe {
             copy_nonoverlapping(
                 Box::as_ptr(&a) as *const u8,
                 Box::as_mut_ptr(&mut a_bytes) as *mut u8,
-                16,
+                I * J * 4,
             );
             copy_nonoverlapping(
                 Box::as_ptr(&b) as *const u8,
                 Box::as_mut_ptr(&mut b_bytes) as *mut u8,
-                16,
+                J * K * 4,
             );
         };
-        let c_bytes = matmul(2, 2, 2, a_bytes, b_bytes).await;
-        let mut c: Box<[f32]> = Box::new([0.0; 4]);
+        let c_bytes = matmul(I as u64, J as u64, K as u64, a_bytes, b_bytes).await;
+        let mut c: Box<[i32]> = (0..I * K).map(|_| 0).collect();
         unsafe {
             copy_nonoverlapping(
                 Box::as_ptr(&c_bytes) as *const u8,
                 Box::as_mut_ptr(&mut c) as *mut u8,
-                16,
+                I * K * 4,
             );
         };
-        println!("[[{}, {}], [{}, {}]]", c[0], c[1], c[2], c[3]);
-        assert_eq!(c[0], 19.0);
-        assert_eq!(c[1], 22.0);
-        assert_eq!(c[2], 43.0);
-        assert_eq!(c[3], 50.0);
+        let mut correct_c: Box<[i32]> = (0..I * K).map(|_| 0).collect();
+        for i in 0..I {
+            for k in 0..K {
+                for j in 0..J {
+                    correct_c[i * K + k] += a[i * J + j] * b[j * K + k];
+                }
+            }
+        }
+        assert_eq!(c, correct_c);
     });
 }
 
diff --git a/hercules_samples/matmul/src/matmul.hir b/hercules_samples/matmul/src/matmul.hir
index 8bbccfdfc012d09610c21fe024d9b709374d18de..ab0f384a563ccb6144e59b811745fe5aa76f08dd 100644
--- a/hercules_samples/matmul/src/matmul.hir
+++ b/hercules_samples/matmul/src/matmul.hir
@@ -1,5 +1,5 @@
-fn matmul<3>(a: array(f32, #0, #1), b: array(f32, #1, #2)) -> array(f32, #0, #2)
-  c = constant(array(f32, #0, #2), [])
+fn matmul<3>(a: array(i32, #0, #1), b: array(i32, #1, #2)) -> array(i32, #0, #2)
+  c = constant(array(i32, #0, #2), [])
   i_j_ctrl = fork(start, #0, #2)
   i_idx = thread_id(i_j_ctrl, 0)
   j_idx = thread_id(i_j_ctrl, 1)
@@ -8,48 +8,11 @@ fn matmul<3>(a: array(f32, #0, #1), b: array(f32, #1, #2)) -> array(f32, #0, #2)
   k_join_ctrl = join(k_ctrl)
   i_j_join_ctrl = join(k_join_ctrl)
   r = return(i_j_join_ctrl, update_i_j_c)
-  zero = constant(f32, 0)
+  zero = constant(i32, 0)
   a_val = read(a, position(i_idx, k_idx))
   b_val = read(b, position(k_idx, j_idx))
   mul = mul(a_val, b_val)
   add = add(mul, dot)
   dot = reduce(k_join_ctrl, zero, add)
   update_c = write(update_i_j_c, dot, position(i_idx, j_idx))
-  update_i_j_c = reduce(i_j_join_ctrl, c, update_c)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+  update_i_j_c = reduce(i_j_join_ctrl, c, update_c)
\ No newline at end of file
diff --git a/juno_samples/matmul/Cargo.toml b/juno_samples/matmul/Cargo.toml
index c272fc443df485aaacd80fe5fdc882bd4d02225c..ea705dddd2fac0e4b5a4b8fe0ddfeef72039e3c4 100644
--- a/juno_samples/matmul/Cargo.toml
+++ b/juno_samples/matmul/Cargo.toml
@@ -15,3 +15,4 @@ juno_build = { path = "../../juno_build" }
 juno_build = { path = "../../juno_build" }
 with_builtin_macros = "0.1.0"
 async-std = "*"
+rand = "*"
diff --git a/juno_samples/matmul/src/main.rs b/juno_samples/matmul/src/main.rs
index 6ec3dae763672075b5410f1b0350c56504f36068..865beaf5d9979aac85ecdf0d7c9183dc5b209cf2 100644
--- a/juno_samples/matmul/src/main.rs
+++ b/juno_samples/matmul/src/main.rs
@@ -2,43 +2,53 @@
 
 extern crate async_std;
 extern crate juno_build;
+extern crate rand;
 
 use core::ptr::copy_nonoverlapping;
 
+use rand::random;
+
 juno_build::juno!("matmul");
 
 fn main() {
     async_std::task::block_on(async {
-        let a: Box<[f32]> = Box::new([1.0, 2.0, 3.0, 4.0]);
-        let b: Box<[f32]> = Box::new([5.0, 6.0, 7.0, 8.0]);
-        let mut a_bytes: Box<[u8]> = Box::new([0; 16]);
-        let mut b_bytes: Box<[u8]> = Box::new([0; 16]);
+        const I: usize = 256;
+        const J: usize = 64;
+        const K: usize = 128;
+        let a: Box<[i32]> = (0..I * J).map(|_| random::<i32>() % 100).collect();
+        let b: Box<[i32]> = (0..J * K).map(|_| random::<i32>() % 100).collect();
+        let mut a_bytes: Box<[u8]> = Box::new([0; I * J * 4]);
+        let mut b_bytes: Box<[u8]> = Box::new([0; J * K * 4]);
         unsafe {
             copy_nonoverlapping(
                 Box::as_ptr(&a) as *const u8,
                 Box::as_mut_ptr(&mut a_bytes) as *mut u8,
-                16,
+                I * J * 4,
             );
             copy_nonoverlapping(
                 Box::as_ptr(&b) as *const u8,
                 Box::as_mut_ptr(&mut b_bytes) as *mut u8,
-                16,
+                J * K * 4,
             );
         };
-        let c_bytes = matmul(2, 2, 2, a_bytes, b_bytes).await;
-        let mut c: Box<[f32]> = Box::new([0.0; 4]);
+        let c_bytes = matmul(I as u64, J as u64, K as u64, a_bytes, b_bytes).await;
+        let mut c: Box<[i32]> = (0..I * K).map(|_| 0).collect();
         unsafe {
             copy_nonoverlapping(
                 Box::as_ptr(&c_bytes) as *const u8,
                 Box::as_mut_ptr(&mut c) as *mut u8,
-                16,
+                I * K * 4,
             );
         };
-        println!("[[{}, {}], [{}, {}]]", c[0], c[1], c[2], c[3]);
-        assert_eq!(c[0], 19.0);
-        assert_eq!(c[1], 22.0);
-        assert_eq!(c[2], 43.0);
-        assert_eq!(c[3], 50.0);
+        let mut correct_c: Box<[i32]> = (0..I * K).map(|_| 0).collect();
+        for i in 0..I {
+            for k in 0..K {
+                for j in 0..J {
+                    correct_c[i * K + k] += a[i * J + j] * b[j * K + k];
+                }
+            }
+        }
+        assert_eq!(c, correct_c);
     });
 }
 
@@ -46,4 +56,3 @@ fn main() {
 fn matmul_test() {
     main();
 }
-
diff --git a/juno_samples/matmul/src/matmul.jn b/juno_samples/matmul/src/matmul.jn
index 2dc5ec3ddccf7283a421b838960f5cadb3923d88..bcfa7afba69b684a6decd7b5a66ee50d3c077921 100644
--- a/juno_samples/matmul/src/matmul.jn
+++ b/juno_samples/matmul/src/matmul.jn
@@ -1,6 +1,6 @@
 #[entry]
-fn matmul<n : usize, m : usize, l : usize>(a : f32[n, m], b : f32[m, l]) -> f32[n, l] {
-  let res : f32[n, l];
+fn matmul<n : usize, m : usize, l : usize>(a : i32[n, m], b : i32[m, l]) -> i32[n, l] {
+  let res : i32[n, l];
 
   @outer for i = 0 to n {
     @middle for j = 0 to l {
@@ -13,3 +13,51 @@ fn matmul<n : usize, m : usize, l : usize>(a : f32[n, m], b : f32[m, l]) -> f32[
   @exit
   return res;
 }
+
+/*
+#[entry]
+fn tiled_64_matmul<n : usize, m : usize, l : usize>(a : i32[n, m], b : i32[m, l]) -> i32[n, l] {
+  let res : i32[n, l];
+
+  for bi = 0 to n / 64 {
+    for bk = 0 to l / 64 {
+      // TODO: make these all the same size, clone analysis should undo GVN's
+      // combining of these three arrays.
+      let atile : i32[66, 64];
+      let btile : i32[65, 64];
+      let ctile : i32[64, 64];
+
+      for tile_idx = 0 to m / 64 {
+        for ti = 0 to 64 {
+          for tk = 0 to 64 {
+            atile[ti, tk] = a[bi * 64 + ti, tile_idx * 64 + tk];
+            btile[ti, tk] = b[tile_idx * 64 + ti, bk * 64 + tk];
+            // TODO: remove setting ctile to zero explicitly, clone analysis
+            // should see a lack of a phi for ctile in the block loops and
+            // induce a copy of an initial value of ctile (all zeros) on each
+            // iteration of the block loops.
+            ctile[ti, tk] = 0;
+          }
+        }
+        for ti = 0 to 64 {
+          for tk = 0 to 64 {
+            let c_acc = ctile[ti, tk];
+            for inner_idx = 0 to 64 {
+              c_acc += atile[ti, inner_idx] * btile[inner_idx, tk];
+            }
+            ctile[ti, tk] = c_acc;
+          }
+        }
+      }
+
+      for ti = 0 to 64 {
+        for tk = 0 to 64 {
+          res[bi * 64 + ti, bk * 64 + tk] = ctile[ti, tk];
+        }
+      }
+    }
+  }
+
+  return res;
+}
+*/
\ No newline at end of file