diff --git a/hercules_ir/src/fork_join_analysis.rs b/hercules_ir/src/fork_join_analysis.rs index bba6ac42a4a2479f8309b575fe1fb1030f5b5a21..3e89ae90ead672ab1092531e950b5f741b16d5bf 100644 --- a/hercules_ir/src/fork_join_analysis.rs +++ b/hercules_ir/src/fork_join_analysis.rs @@ -87,7 +87,7 @@ pub fn reduce_cycles( function: &Function, def_use: &ImmutableDefUseMap, fork_join_map: &HashMap<NodeID, NodeID>, - nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>, + fork_join_nest: &HashMap<NodeID, Vec<NodeID>>, ) -> HashMap<NodeID, HashSet<NodeID>> { let reduces = (0..function.nodes.len()) .filter(|idx| function.nodes[*idx].is_reduce()) @@ -101,6 +101,24 @@ pub fn reduce_cycles( for reduce in reduces { let (join, _, reduct) = function.nodes[reduce.idx()].try_reduce().unwrap(); let fork = join_fork_map[&join]; + let might_be_in_fork_join = |id| { + fork_join_nest + .get(&id) + .map(|nest| nest.contains(&fork)) + .unwrap_or(true) + && function.nodes[id.idx()] + .try_phi() + .map(|(control, _)| fork_join_nest[&control].contains(&fork)) + .unwrap_or(true) + && function.nodes[id.idx()] + .try_reduce() + .map(|(control, _, _)| fork_join_nest[&control].contains(&fork)) + .unwrap_or(true) + && !function.nodes[id.idx()].is_parameter() + && !function.nodes[id.idx()].is_constant() + && !function.nodes[id.idx()].is_dynamic_constant() + && !function.nodes[id.idx()].is_undef() + }; // Find nodes in the fork-join that the reduce can reach through uses. let mut reachable_uses = HashSet::new(); @@ -109,10 +127,7 @@ pub fn reduce_cycles( workset.push(reduct); while let Some(pop) = workset.pop() { for u in get_uses(&function.nodes[pop.idx()]).as_ref() { - if !reachable_uses.contains(u) - && nodes_in_fork_joins[&fork].contains(u) - && *u != reduce - { + if !reachable_uses.contains(u) && might_be_in_fork_join(*u) && *u != reduce { reachable_uses.insert(*u); workset.push(*u); } @@ -126,10 +141,7 @@ pub fn reduce_cycles( workset.push(reduce); while let Some(pop) = workset.pop() { for u in def_use.get_users(pop) { - if !reachable_users.contains(u) - && nodes_in_fork_joins[&fork].contains(u) - && *u != reduce - { + if !reachable_users.contains(u) && might_be_in_fork_join(*u) && *u != reduce { reachable_users.insert(*u); workset.push(*u); } @@ -155,6 +167,7 @@ pub fn nodes_in_fork_joins( function: &Function, def_use: &ImmutableDefUseMap, fork_join_map: &HashMap<NodeID, NodeID>, + reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>, ) -> HashMap<NodeID, HashSet<NodeID>> { let mut result = HashMap::new(); @@ -164,6 +177,7 @@ pub fn nodes_in_fork_joins( let mut set = HashSet::new(); set.insert(*fork); + // Iterate uses of the fork. while let Some(item) = worklist.pop() { for u in def_use.get_users(item) { let terminate = *u == *join @@ -177,6 +191,15 @@ pub fn nodes_in_fork_joins( set.insert(*u); } } + assert!(set.contains(join)); + + // Add all the nodes in the reduce cycle. Some of these nodes may not + // use thread IDs of the fork, so do this explicitly. + for u in def_use.get_users(*join) { + if let Some(cycle) = reduce_cycles.get(u) { + set.extend(cycle); + } + } result.insert(*fork, set); } diff --git a/hercules_opt/src/editor.rs b/hercules_opt/src/editor.rs index 8c339d728bf3e139a8960673fcd687806c7def70..16e5c3264d33a7c9bef85fc0fa3cec02963dbf48 100644 --- a/hercules_opt/src/editor.rs +++ b/hercules_opt/src/editor.rs @@ -292,7 +292,11 @@ impl<'a: 'b, 'b> FunctionEditor<'a> { // gravestone. for id in deleted_nodeids.iter() { // Check that there are no users of deleted nodes. - assert!(editor.mut_def_use[id.idx()].is_empty()); + assert!( + editor.mut_def_use[id.idx()].is_empty(), + "PANIC: Attempted to delete node {:?}, but there are still users of this node ({:?}).", + id, editor.mut_def_use[id.idx()] + ); editor.function.nodes[id.idx()] = Node::Start; } diff --git a/juno_samples/matmul/src/gpu.sch b/juno_samples/matmul/src/gpu.sch index e6eb3641fef1496e0b82534f1bc449b1027374a6..edb83d74a54bd72c0207923b82485eff8905cbb6 100644 --- a/juno_samples/matmul/src/gpu.sch +++ b/juno_samples/matmul/src/gpu.sch @@ -1,34 +1,23 @@ -gvn(*); phi-elim(*); + +forkify(*); +fork-guard-elim(*); dce(*); -let out = auto-outline(*); -gpu(out.matmul, out.tiled_64_matmul); +fixpoint { + reduce-slf(*); + slf(*); + infer-schedules(*); +} +fork-coalesce(*); +infer-schedules(*); +dce(*); +let out = auto-outline(*); +gpu(out.matmul); ip-sroa(*); sroa(*); dce(*); -gvn(*); -phi-elim(*); -dce(*); -fixpoint panic after 20 { - forkify(out.matmul); - fork-guard-elim(out.matmul); -} -gvn(out.matmul); -phi-elim(out.matmul); -dce(out.matmul); - -fixpoint panic after 20 { - reduce-slf(out.matmul); - slf(out.matmul); - infer-schedules(out.matmul); -} -fork-coalesce(out.matmul); -dce(out.matmul); - -gcm(*); float-collections(*); -dce(*); gcm(*); diff --git a/juno_samples/matmul/src/main.rs b/juno_samples/matmul/src/main.rs index 98b6c7776089b2dee6e67ea11f14d7914cf82025..2eb2804b33003b9383a981ac652d0e71142ba2a5 100644 --- a/juno_samples/matmul/src/main.rs +++ b/juno_samples/matmul/src/main.rs @@ -32,11 +32,6 @@ fn main() { .run(I as u64, J as u64, K as u64, a.clone(), b.clone()) .await; assert_eq!(c.as_slice::<i32>(), &*correct_c); - let mut r = runner!(tiled_64_matmul); - let tiled_c = r - .run(I as u64, J as u64, K as u64, a.clone(), b.clone()) - .await; - assert_eq!(tiled_c.as_slice::<i32>(), &*correct_c); } #[cfg(feature = "cuda")] { @@ -49,13 +44,6 @@ fn main() { let mut c_cpu: Box<[i32]> = vec![0; correct_c.len()].into_boxed_slice(); c.to_cpu_ref(&mut c_cpu); assert_eq!(&*c_cpu, &*correct_c); - let mut r = runner!(tiled_64_matmul); - let tiled_c = r - .run(I as u64, J as u64, K as u64, a.get_ref(), b.get_ref()) - .await; - let mut tiled_c_cpu: Box<[i32]> = vec![0; correct_c.len()].into_boxed_slice(); - tiled_c.to_cpu_ref(&mut tiled_c_cpu); - assert_eq!(&*tiled_c_cpu, &*correct_c); } }); } diff --git a/juno_samples/matmul/src/matmul.jn b/juno_samples/matmul/src/matmul.jn index fb6de5bd89e0c48b3a0ac4aa66d4764894f2fcff..e36d94e209a2385fc28e54b1cfe4c432564d3706 100644 --- a/juno_samples/matmul/src/matmul.jn +++ b/juno_samples/matmul/src/matmul.jn @@ -10,52 +10,5 @@ fn matmul<n : usize, m : usize, l : usize>(a : i32[n, m], b : i32[m, l]) -> i32[ } } - @exit - return res; -} - -#[entry] -fn tiled_64_matmul<n : usize, m : usize, l : usize>(a : i32[n, m], b : i32[m, l]) -> i32[n, l] { - let res : i32[n, l]; - let atile : i32[64, 64]; - let btile : i32[64, 64]; - let ctile : i32[64, 64]; - - for bi = 0 to n / 64 { - for bk = 0 to l / 64 { - for ti = 0 to 64 { - for tk = 0 to 64 { - atile[ti, tk] = 0; - btile[ti, tk] = 0; - ctile[ti, tk] = 0; - } - } - - for tile_idx = 0 to m / 64 { - for ti = 0 to 64 { - for tk = 0 to 64 { - atile[ti, tk] = a[bi * 64 + ti, tile_idx * 64 + tk]; - btile[ti, tk] = b[tile_idx * 64 + ti, bk * 64 + tk]; - } - } - for ti = 0 to 64 { - for tk = 0 to 64 { - let c_acc = ctile[ti, tk]; - for inner_idx = 0 to 64 { - c_acc += atile[ti, inner_idx] * btile[inner_idx, tk]; - } - ctile[ti, tk] = c_acc; - } - } - } - - for ti = 0 to 64 { - for tk = 0 to 64 { - res[bi * 64 + ti, bk * 64 + tk] = ctile[ti, tk]; - } - } - } - } - return res; } diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index a4783a93701564f041237f2fecef8a22c62abe5c..8e152cfe2cfed6b73e4e881abce09677e5704409 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -401,21 +401,19 @@ impl PassManager { if self.reduce_cycles.is_none() { self.make_def_uses(); self.make_fork_join_maps(); - self.make_nodes_in_fork_joins(); + self.make_fork_join_nests(); let def_uses = self.def_uses.as_ref().unwrap().iter(); let fork_join_maps = self.fork_join_maps.as_ref().unwrap().iter(); - let nodes_in_fork_joins = self.nodes_in_fork_joins.as_ref().unwrap().iter(); + let fork_join_nests = self.fork_join_nests.as_ref().unwrap().iter(); self.reduce_cycles = Some( self.functions .iter() .zip(fork_join_maps) - .zip(nodes_in_fork_joins) + .zip(fork_join_nests) .zip(def_uses) - .map( - |(((function, fork_join_map), nodes_in_fork_joins), def_use)| { - reduce_cycles(function, def_use, fork_join_map, nodes_in_fork_joins) - }, - ) + .map(|(((function, fork_join_map), fork_join_nests), def_use)| { + reduce_cycles(function, def_use, fork_join_map, fork_join_nests) + }) .collect(), ); } @@ -425,16 +423,20 @@ impl PassManager { if self.nodes_in_fork_joins.is_none() { self.make_def_uses(); self.make_fork_join_maps(); + self.make_reduce_cycles(); self.nodes_in_fork_joins = Some( zip( self.functions.iter(), zip( self.def_uses.as_ref().unwrap().iter(), - self.fork_join_maps.as_ref().unwrap().iter(), + zip( + self.fork_join_maps.as_ref().unwrap().iter(), + self.reduce_cycles.as_ref().unwrap().iter(), + ), ), ) - .map(|(function, (def_use, fork_join_map))| { - nodes_in_fork_joins(function, def_use, fork_join_map) + .map(|(function, (def_use, (fork_join_map, reduce_cycles)))| { + nodes_in_fork_joins(function, def_use, fork_join_map, reduce_cycles) }) .collect(), );