From 423235ed094c4530228faf3e1c4985188ea5c030 Mon Sep 17 00:00:00 2001 From: rarbore2 <rarbore2@illinois.edu> Date: Wed, 1 Jan 2025 19:45:14 -0600 Subject: [PATCH] Incorporate anti-dependencies into GCM --- hercules_ir/src/antideps.rs | 54 +++++++++++-------- hercules_ir/src/gcm.rs | 38 +++++++++++-- hercules_ir/src/subgraph.rs | 71 ++++++++++++++++++++++++- hercules_opt/src/materialize_clones.rs | 6 ++- hercules_opt/src/pass.rs | 24 +++++++-- juno_samples/antideps/src/antideps.jn | 58 ++++++++++++++++++-- juno_samples/antideps/src/main.rs | 6 ++- juno_samples/implicit_clone/src/main.rs | 2 +- juno_samples/matmul/src/main.rs | 19 ++++++- juno_samples/matmul/src/matmul.jn | 2 - 10 files changed, 241 insertions(+), 39 deletions(-) diff --git a/hercules_ir/src/antideps.rs b/hercules_ir/src/antideps.rs index 8aa9f7d9..af949708 100644 --- a/hercules_ir/src/antideps.rs +++ b/hercules_ir/src/antideps.rs @@ -43,24 +43,14 @@ pub fn antideps( objects: &FunctionCollectionObjects, ) -> Vec<(NodeID, NodeID)> { // First, we analyze "generations" of collections as they are mutated. - // Originating nodes and mutating nodes start a new generation of a + // Originating, mutating, phi, and reduce nodes start a new generation of a // collection. Generations are not ordered due to loops, but are rather just - // node IDs of the originating or mutating node (parameter, constant, call, - // undef, write). Other nodes operating on collections mean reads / writes - // can operate on potentially different generations of multiple collections - // (phi, reduce, select). + // node IDs of node (parameter, constant, call, undef, write, phi, reduce). + // Other nodes operating on collections mean reads / writes can operate on + // potentially different generations of multiple collections (select). let lattice = forward_dataflow(function, reverse_postorder, |inputs, id| { match function.nodes[id.idx()] { - Node::Phi { - control: _, - data: _, - } - | Node::Reduce { - control: _, - init: _, - reduct: _, - } - | Node::Ternary { + Node::Ternary { op: TernaryOperator::Select, first: _, second: _, @@ -121,12 +111,21 @@ pub fn antideps( collect: _, indices: _, } => inputs[0].clone(), - Node::Write { + Node::Phi { + control: _, + data: _, + } + | Node::Reduce { + control: _, + init: _, + reduct: _, + } + | Node::Write { collect: _, data: _, indices: _, } => { - // Writes update the generation to the write. + // Phis, reduces, and writes update the generation to the write. let objs = inputs[0].objs.iter().map(|(obj, _)| (*obj, id)).collect(); GenerationLattice { objs } } @@ -169,7 +168,7 @@ pub fn antideps( } Node::Write { collect, - data: _, + data, indices: _, } => { for pair in lattice[collect.idx()].objs.iter() { @@ -179,6 +178,19 @@ pub fn antideps( .1 .push(id); } + + // When a write takes a collection on its `data` input, it + // memcpys that collection into the mutated collection. This is + // a read. + if !objects.objects(*data).is_empty() { + for pair in lattice[collect.idx()].objs.iter() { + reads_writes_calls_mut_calls_per_pair + .entry(*pair) + .or_default() + .0 + .push(id); + } + } } Node::Call { control: _, @@ -219,7 +231,7 @@ pub fn antideps( // Case 1: for read in reads.iter() { for write in writes.iter() { - if *write != gen { + if *write != gen && *read != *write { antideps.push((*read, *write)); } } @@ -228,7 +240,7 @@ pub fn antideps( // Case 2: for read in reads.iter() { for mut_call in mut_calls.iter() { - if *mut_call != gen { + if *mut_call != gen && *read != *mut_call { antideps.push((*read, *mut_call)); } } @@ -237,7 +249,7 @@ pub fn antideps( // Case 3: for call in calls.iter().chain(mut_calls.iter()) { for write in writes.iter() { - if *write != gen { + if *write != gen && *call != *write { antideps.push((*call, *write)); } } diff --git a/hercules_ir/src/gcm.rs b/hercules_ir/src/gcm.rs index a8fd0a21..3718df9b 100644 --- a/hercules_ir/src/gcm.rs +++ b/hercules_ir/src/gcm.rs @@ -30,12 +30,18 @@ pub fn gcm( function: &Function, def_use: &ImmutableDefUseMap, reverse_postorder: &Vec<NodeID>, + control_subgraph: &Subgraph, dom: &DomTree, antideps: &Vec<(NodeID, NodeID)>, loops: &LoopTree, fork_join_map: &HashMap<NodeID, NodeID>, ) -> BasicBlocks { let mut bbs: Vec<Option<NodeID>> = vec![None; function.nodes.len()]; + let back_edges = control_subgraph.back_edges(NodeID::new(0)); + let no_loop_reachability = + control_subgraph.pairwise_reachability(|src, dst| !back_edges.contains(&(src, dst))); + let antideps_users = map_antideps(antideps); + let antideps_uses = flip_antideps(antideps); // Step 1: assign the basic block locations of all nodes that must be in a // specific block. This includes control nodes as well as some special data @@ -103,7 +109,7 @@ pub fn gcm( .map(|(fork, join)| (*join, *fork)) .collect(); let mut worklist = VecDeque::from_iter(reverse_postorder.into_iter().map(|id| *id).rev()); - while let Some(id) = worklist.pop_front() { + 'worklist: while let Some(id) = worklist.pop_front() { if bbs[id.idx()].is_some() { continue; } @@ -169,13 +175,38 @@ pub fn gcm( continue; }; + // Check if all anti-dependency users have been placed. If one of them + // hasn't, then add this node back on to the worklist. We need to know + // where the anti-dependency users are, so that we can place this + // read "above" mutators that anti-depend on it. The condition for a + // potential placement location is that in the CFG *without loop back- + // edges* the mutator cannot reach the read. Ask Russel about why this + // works, hopefully I'll have a convincing argument by then ;). + let mut antidep_user_locations = vec![]; + for antidep_user in antideps_users.get(&id).unwrap_or(&vec![]) { + if let Some(location) = bbs[antidep_user.idx()] { + antidep_user_locations.push(location); + } else { + worklist.push_back(id); + continue 'worklist; + } + } + // Look between the LCA and the schedule early location to place the // node. let schedule_early = schedule_early[id.idx()].unwrap(); let mut chain = dom // If the node has no users, then it doesn't really matter where we // place it - just place it at the early placement. - .chain(lca.unwrap_or(schedule_early), schedule_early); + .chain(lca.unwrap_or(schedule_early), schedule_early) + // Only allow locations that don't violate the anti-depence property + // listed above. + .filter(|location| { + !antidep_user_locations.iter().any(|antidep_user_location| { + antidep_user_location != location + && no_loop_reachability[antidep_user_location.idx()][location.idx()] + }) + }); let mut location = chain.next().unwrap(); while let Some(control_node) = chain.next() { // If the next node further up the dominator tree is in a shallower @@ -211,7 +242,6 @@ pub fn gcm( .filter(|id| !function.nodes[id.idx()].is_control()), ); let mut visited = bitvec![u8, Lsb0; 0; function.nodes.len()]; - let antideps = flip_antideps(&antideps); while let Some(id) = worklist.pop_front() { let node = &function.nodes[id.idx()]; if node.is_phi() @@ -219,7 +249,7 @@ pub fn gcm( || get_uses(node) .as_ref() .into_iter() - .chain(antideps.get(&id).into_iter().flatten()) + .chain(antideps_uses.get(&id).into_iter().flatten()) .all(|u| { function.nodes[u.idx()].is_control() || bbs[u.idx()] != bbs[id.idx()] diff --git a/hercules_ir/src/subgraph.rs b/hercules_ir/src/subgraph.rs index de342579..89e8bcc6 100644 --- a/hercules_ir/src/subgraph.rs +++ b/hercules_ir/src/subgraph.rs @@ -1,6 +1,11 @@ -use crate::*; +extern crate bitvec; + +use std::collections::{HashMap, HashSet}; +use std::mem::take; + +use self::bitvec::prelude::*; -use std::collections::HashMap; +use crate::*; /* * In various parts of the compiler, we want to consider a subset of a complete @@ -166,6 +171,68 @@ impl Subgraph { original_num_nodes, } } + + pub fn back_edges(&self, root: NodeID) -> HashSet<(NodeID, NodeID)> { + let mut edges = HashSet::new(); + let mut stack = vec![]; + let mut visited = bitvec![u8, Lsb0; 0; self.original_num_nodes as usize]; + let mut in_stack = bitvec![u8, Lsb0; 0; self.original_num_nodes as usize]; + + stack.push(root); + visited.set(root.idx(), true); + in_stack.set(root.idx(), true); + + while let Some(pop) = stack.pop() { + if self.succs(pop).any(|succ| !visited[succ.idx()]) { + stack.push(pop); + } else { + in_stack.set(pop.idx(), false); + } + for succ in self.succs(pop) { + if in_stack[succ.idx()] { + edges.insert((pop, succ)); + } else if !visited[succ.idx()] { + visited.set(succ.idx(), true); + stack.push(succ); + in_stack.set(succ.idx(), true); + break; + } + } + } + + edges + } + + pub fn pairwise_reachability<P>(&self, p: P) -> Vec<BitVec<u8, Lsb0>> + where + P: Fn(NodeID, NodeID) -> bool, + { + let mut reachability = vec![ + bitvec![u8, Lsb0; 0; self.original_num_nodes as usize]; + self.original_num_nodes as usize + ]; + + // Every control node can reach itself. + for node in self.iter() { + reachability[node.idx()].set(node.idx(), true); + } + + // Propagate N - 1 times, as that is the length of the longest path. + for _ in 0..self.original_num_nodes { + for node in self.iter() { + for succ in self.succs(*node) { + assert_ne!(*node, succ); + if p(*node, succ) { + let mut node_reachability = take(&mut reachability[node.idx()]); + node_reachability |= &reachability[succ.idx()]; + reachability[node.idx()] = node_reachability; + } + } + } + } + + reachability + } } /* diff --git a/hercules_opt/src/materialize_clones.rs b/hercules_opt/src/materialize_clones.rs index 79002d15..687ac10c 100644 --- a/hercules_opt/src/materialize_clones.rs +++ b/hercules_opt/src/materialize_clones.rs @@ -12,6 +12,10 @@ use crate::*; * in the backend that collections have reference, rather than value, semantics. * The pass calling this function is mandatory for correctness. */ -pub fn materialize_clones(editor: &mut FunctionEditor, objects: &FunctionCollectionObjects) { +pub fn materialize_clones( + editor: &mut FunctionEditor, + objects: &FunctionCollectionObjects, + bbs: &BasicBlocks, +) { todo!() } diff --git a/hercules_opt/src/pass.rs b/hercules_opt/src/pass.rs index c02e967e..bcaba374 100644 --- a/hercules_opt/src/pass.rs +++ b/hercules_opt/src/pass.rs @@ -284,12 +284,14 @@ impl PassManager { if self.bbs.is_none() { self.make_def_uses(); self.make_reverse_postorders(); + self.make_control_subgraphs(); self.make_doms(); self.make_antideps(); self.make_loops(); self.make_fork_join_maps(); let def_uses = self.def_uses.as_ref().unwrap().iter(); let reverse_postorders = self.reverse_postorders.as_ref().unwrap().iter(); + let control_subgraphs = self.control_subgraphs.as_ref().unwrap().iter(); let doms = self.doms.as_ref().unwrap().iter(); let antideps = self.antideps.as_ref().unwrap().iter(); let loops = self.loops.as_ref().unwrap().iter(); @@ -301,19 +303,29 @@ impl PassManager { def_uses, zip( reverse_postorders, - zip(doms, zip(antideps, zip(loops, fork_join_maps))), + zip( + control_subgraphs, + zip(doms, zip(antideps, zip(loops, fork_join_maps))), + ), ), ), ) .map( |( function, - (def_use, (reverse_postorder, (dom, (antideps, (loops, fork_join_map))))), + ( + def_use, + ( + reverse_postorder, + (control_subgraph, (dom, (antideps, (loops, fork_join_map)))), + ), + ), )| { gcm( function, def_use, reverse_postorder, + control_subgraph, dom, antideps, loops, @@ -802,8 +814,10 @@ impl PassManager { Pass::MaterializeClones => { self.make_def_uses(); self.make_collection_objects(); + self.make_bbs(); let def_uses = self.def_uses.as_ref().unwrap(); let collection_objects = self.collection_objects.as_ref().unwrap(); + let bbs = self.bbs.as_ref().unwrap(); for idx in 0..self.module.functions.len() { let constants_ref = RefCell::new(std::mem::take(&mut self.module.constants)); @@ -817,7 +831,11 @@ impl PassManager { &types_ref, &def_uses[idx], ); - materialize_clones(&mut editor, &collection_objects[&FunctionID::new(idx)]); + materialize_clones( + &mut editor, + &collection_objects[&FunctionID::new(idx)], + &bbs[idx], + ); self.module.constants = constants_ref.take(); self.module.dynamic_constants = dynamic_constants_ref.take(); diff --git a/juno_samples/antideps/src/antideps.jn b/juno_samples/antideps/src/antideps.jn index f54b8dcf..5949c91a 100644 --- a/juno_samples/antideps/src/antideps.jn +++ b/juno_samples/antideps/src/antideps.jn @@ -9,16 +9,68 @@ fn simple_antideps(a : usize, b : usize) -> i32 { #[entry] fn complex_antideps(x : i32) -> i32 { let arr : i32[4]; - arr[1] = 7; - let r = 0; + let arr2 : i32[12]; + arr[1] = 7 + arr2[0]; + let r = arr2[1]; + arr2[1] = 99; while x > 6 { + arr2[10] = 98; if x > 5 { - r = arr[1]; + r = arr[1] + arr2[(x + 1) as usize]; } else { + arr2[11] = 97; arr[1] = 8; r = arr[1]; } + arr2[x as usize] = 13; x -= 1; } return r; } + +#[entry] +fn very_complex_antideps(x: usize) -> usize { + let arr1 : usize[203]; + let arr2 : usize[202]; + let arr3 : usize[201]; + let arr4 : usize[200]; + let y = x; + while y < 10 { + arr1[y] = y; + arr2[y] = arr1[x]; + y += 4; + continue; + } + let z = arr1[x] + arr1[x + 4] + arr2[x + 4]; + if z == 13 { + z = 12; + } + while true { + if arr1[x] == 3 { + z = 11; + break; + } else { + z = 10; + continue; + } + } + if z == 11 { + arr3[z - 10] = 99; + arr4[arr3[z - 10] - 98] = 4; + arr4[arr4[1]] = 42; + } else { + arr1[203] = 99999999; + arr1[204] = 99999999; + arr1[1000] = 99999999; + } + let w = arr4[4] + arr2[7]; + while z > 0 { + arr2[w] = arr3[1]; + arr4[w] = arr2[w]; + z += z; + if z > 100 { + break; + } + } + return arr4[w] + w; +} diff --git a/juno_samples/antideps/src/main.rs b/juno_samples/antideps/src/main.rs index e613f5ff..b0a99163 100644 --- a/juno_samples/antideps/src/main.rs +++ b/juno_samples/antideps/src/main.rs @@ -13,7 +13,11 @@ fn main() { let output = complex_antideps(9).await; println!("{}", output); - assert_eq!(output, 7); + assert_eq!(output, 20); + + let output = very_complex_antideps(3).await; + println!("{}", output); + assert_eq!(output, 144); }); } diff --git a/juno_samples/implicit_clone/src/main.rs b/juno_samples/implicit_clone/src/main.rs index 5ff952a3..ca7ddeb1 100644 --- a/juno_samples/implicit_clone/src/main.rs +++ b/juno_samples/implicit_clone/src/main.rs @@ -9,7 +9,7 @@ fn main() { async_std::task::block_on(async { let output = implicit_clone(3).await; println!("{}", output); - assert_eq!(output, 9); + assert_eq!(output, 11); }); } diff --git a/juno_samples/matmul/src/main.rs b/juno_samples/matmul/src/main.rs index 865beaf5..948459dd 100644 --- a/juno_samples/matmul/src/main.rs +++ b/juno_samples/matmul/src/main.rs @@ -31,7 +31,14 @@ fn main() { J * K * 4, ); }; - let c_bytes = matmul(I as u64, J as u64, K as u64, a_bytes, b_bytes).await; + let c_bytes = matmul( + I as u64, + J as u64, + K as u64, + a_bytes.clone(), + b_bytes.clone(), + ) + .await; let mut c: Box<[i32]> = (0..I * K).map(|_| 0).collect(); unsafe { copy_nonoverlapping( @@ -40,6 +47,15 @@ fn main() { I * K * 4, ); }; + let tiled_c_bytes = matmul(I as u64, J as u64, K as u64, a_bytes, b_bytes).await; + let mut tiled_c: Box<[i32]> = (0..I * K).map(|_| 0).collect(); + unsafe { + copy_nonoverlapping( + Box::as_ptr(&tiled_c_bytes) as *const u8, + Box::as_mut_ptr(&mut tiled_c) as *mut u8, + I * K * 4, + ); + }; let mut correct_c: Box<[i32]> = (0..I * K).map(|_| 0).collect(); for i in 0..I { for k in 0..K { @@ -49,6 +65,7 @@ fn main() { } } assert_eq!(c, correct_c); + assert_eq!(tiled_c, correct_c); }); } diff --git a/juno_samples/matmul/src/matmul.jn b/juno_samples/matmul/src/matmul.jn index bcfa7afb..775bb382 100644 --- a/juno_samples/matmul/src/matmul.jn +++ b/juno_samples/matmul/src/matmul.jn @@ -14,7 +14,6 @@ fn matmul<n : usize, m : usize, l : usize>(a : i32[n, m], b : i32[m, l]) -> i32[ return res; } -/* #[entry] fn tiled_64_matmul<n : usize, m : usize, l : usize>(a : i32[n, m], b : i32[m, l]) -> i32[n, l] { let res : i32[n, l]; @@ -60,4 +59,3 @@ fn tiled_64_matmul<n : usize, m : usize, l : usize>(a : i32[n, m], b : i32[m, l] return res; } -*/ \ No newline at end of file -- GitLab