diff --git a/hercules_ir/src/fork_join_analysis.rs b/hercules_ir/src/fork_join_analysis.rs index 0ed61c7cfdbc2c9847402677c5ed2e4caed57590..88d700aa7940dad960a0f929600763df03bdad84 100644 --- a/hercules_ir/src/fork_join_analysis.rs +++ b/hercules_ir/src/fork_join_analysis.rs @@ -167,6 +167,7 @@ pub fn nodes_in_fork_joins( function: &Function, def_use: &ImmutableDefUseMap, fork_join_map: &HashMap<NodeID, NodeID>, + reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>, ) -> HashMap<NodeID, HashSet<NodeID>> { let mut result = HashMap::new(); @@ -188,11 +189,14 @@ pub fn nodes_in_fork_joins( worklist.push(*u); } set.insert(*u); + + // Nodes in reduce cycles might not depend on the thread ID. + if terminate && let Some(cycle) = reduce_cycles.get(u) { + set.extend(cycle); + } } } - // Nodes in reduce cycles might not depend on the thread ID. - result.insert(*fork, set); } diff --git a/juno_samples/matmul/src/gpu.sch b/juno_samples/matmul/src/gpu.sch index e6eb3641fef1496e0b82534f1bc449b1027374a6..edb83d74a54bd72c0207923b82485eff8905cbb6 100644 --- a/juno_samples/matmul/src/gpu.sch +++ b/juno_samples/matmul/src/gpu.sch @@ -1,34 +1,23 @@ -gvn(*); phi-elim(*); + +forkify(*); +fork-guard-elim(*); dce(*); -let out = auto-outline(*); -gpu(out.matmul, out.tiled_64_matmul); +fixpoint { + reduce-slf(*); + slf(*); + infer-schedules(*); +} +fork-coalesce(*); +infer-schedules(*); +dce(*); +let out = auto-outline(*); +gpu(out.matmul); ip-sroa(*); sroa(*); dce(*); -gvn(*); -phi-elim(*); -dce(*); -fixpoint panic after 20 { - forkify(out.matmul); - fork-guard-elim(out.matmul); -} -gvn(out.matmul); -phi-elim(out.matmul); -dce(out.matmul); - -fixpoint panic after 20 { - reduce-slf(out.matmul); - slf(out.matmul); - infer-schedules(out.matmul); -} -fork-coalesce(out.matmul); -dce(out.matmul); - -gcm(*); float-collections(*); -dce(*); gcm(*); diff --git a/juno_samples/matmul/src/main.rs b/juno_samples/matmul/src/main.rs index 98b6c7776089b2dee6e67ea11f14d7914cf82025..2eb2804b33003b9383a981ac652d0e71142ba2a5 100644 --- a/juno_samples/matmul/src/main.rs +++ b/juno_samples/matmul/src/main.rs @@ -32,11 +32,6 @@ fn main() { .run(I as u64, J as u64, K as u64, a.clone(), b.clone()) .await; assert_eq!(c.as_slice::<i32>(), &*correct_c); - let mut r = runner!(tiled_64_matmul); - let tiled_c = r - .run(I as u64, J as u64, K as u64, a.clone(), b.clone()) - .await; - assert_eq!(tiled_c.as_slice::<i32>(), &*correct_c); } #[cfg(feature = "cuda")] { @@ -49,13 +44,6 @@ fn main() { let mut c_cpu: Box<[i32]> = vec![0; correct_c.len()].into_boxed_slice(); c.to_cpu_ref(&mut c_cpu); assert_eq!(&*c_cpu, &*correct_c); - let mut r = runner!(tiled_64_matmul); - let tiled_c = r - .run(I as u64, J as u64, K as u64, a.get_ref(), b.get_ref()) - .await; - let mut tiled_c_cpu: Box<[i32]> = vec![0; correct_c.len()].into_boxed_slice(); - tiled_c.to_cpu_ref(&mut tiled_c_cpu); - assert_eq!(&*tiled_c_cpu, &*correct_c); } }); } diff --git a/juno_samples/matmul/src/matmul.jn b/juno_samples/matmul/src/matmul.jn index fb6de5bd89e0c48b3a0ac4aa66d4764894f2fcff..e36d94e209a2385fc28e54b1cfe4c432564d3706 100644 --- a/juno_samples/matmul/src/matmul.jn +++ b/juno_samples/matmul/src/matmul.jn @@ -10,52 +10,5 @@ fn matmul<n : usize, m : usize, l : usize>(a : i32[n, m], b : i32[m, l]) -> i32[ } } - @exit - return res; -} - -#[entry] -fn tiled_64_matmul<n : usize, m : usize, l : usize>(a : i32[n, m], b : i32[m, l]) -> i32[n, l] { - let res : i32[n, l]; - let atile : i32[64, 64]; - let btile : i32[64, 64]; - let ctile : i32[64, 64]; - - for bi = 0 to n / 64 { - for bk = 0 to l / 64 { - for ti = 0 to 64 { - for tk = 0 to 64 { - atile[ti, tk] = 0; - btile[ti, tk] = 0; - ctile[ti, tk] = 0; - } - } - - for tile_idx = 0 to m / 64 { - for ti = 0 to 64 { - for tk = 0 to 64 { - atile[ti, tk] = a[bi * 64 + ti, tile_idx * 64 + tk]; - btile[ti, tk] = b[tile_idx * 64 + ti, bk * 64 + tk]; - } - } - for ti = 0 to 64 { - for tk = 0 to 64 { - let c_acc = ctile[ti, tk]; - for inner_idx = 0 to 64 { - c_acc += atile[ti, inner_idx] * btile[inner_idx, tk]; - } - ctile[ti, tk] = c_acc; - } - } - } - - for ti = 0 to 64 { - for tk = 0 to 64 { - res[bi * 64 + ti, bk * 64 + tk] = ctile[ti, tk]; - } - } - } - } - return res; } diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 12d00b95a41b1e6fc67f980c5cca3014602327ae..8e152cfe2cfed6b73e4e881abce09677e5704409 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -423,16 +423,20 @@ impl PassManager { if self.nodes_in_fork_joins.is_none() { self.make_def_uses(); self.make_fork_join_maps(); + self.make_reduce_cycles(); self.nodes_in_fork_joins = Some( zip( self.functions.iter(), zip( self.def_uses.as_ref().unwrap().iter(), - self.fork_join_maps.as_ref().unwrap().iter(), + zip( + self.fork_join_maps.as_ref().unwrap().iter(), + self.reduce_cycles.as_ref().unwrap().iter(), + ), ), ) - .map(|(function, (def_use, fork_join_map))| { - nodes_in_fork_joins(function, def_use, fork_join_map) + .map(|(function, (def_use, (fork_join_map, reduce_cycles)))| { + nodes_in_fork_joins(function, def_use, fork_join_map, reduce_cycles) }) .collect(), );