diff --git a/Cargo.lock b/Cargo.lock index ad604b437dae65dd256850126e68f58ffa336378..6ef91b760ec331eeac7e0850f7121e7c0b9f6b83 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -652,8 +652,10 @@ name = "hercules_opt" version = "0.1.0" dependencies = [ "bitvec", + "either", "hercules_cg", "hercules_ir", + "itertools", "ordered-float", "postcard", "serde", diff --git a/hercules_ir/src/antideps.rs b/hercules_ir/src/antideps.rs index 9dc3d1ee8b00360044ecf0ce99ad9f022df59693..a9080fd6d49904d162392b77af5eb218f81ea2bd 100644 --- a/hercules_ir/src/antideps.rs +++ b/hercules_ir/src/antideps.rs @@ -1,51 +1,17 @@ use crate::*; -/* - * Top level function to get all anti-dependencies. - */ -pub fn antideps(function: &Function, def_use: &ImmutableDefUseMap) -> Vec<(NodeID, NodeID)> { - generic_antideps( - function, - def_use, - (0..function.nodes.len()).map(NodeID::new), - ) -} - -/* - * Sometimes, we are only interested in anti-dependence edges involving arrays. - */ -pub fn array_antideps( - function: &Function, - def_use: &ImmutableDefUseMap, - types: &Vec<Type>, - typing: &Vec<TypeID>, -) -> Vec<(NodeID, NodeID)> { - generic_antideps( - function, - def_use, - (0..function.nodes.len()) - .map(NodeID::new) - .filter(|id| types[typing[id.idx()].idx()].is_array()), - ) -} - /* * Function to assemble anti-dependence edges. Returns a list of pairs of nodes. * The first item in the pair is the read node, and the second item is the write - * node. Take an iterator of nodes in case we want a subset of all anti- - * dependencies. + * node. */ -fn generic_antideps<I: Iterator<Item = NodeID>>( - function: &Function, - def_use: &ImmutableDefUseMap, - nodes: I, -) -> Vec<(NodeID, NodeID)> { +pub fn antideps(function: &Function, def_use: &ImmutableDefUseMap) -> Vec<(NodeID, NodeID)> { // Anti-dependence edges are between a write node and a read node, where // each node uses the same array value. The read must be scheduled before // the write to avoid incorrect compilation. let mut antideps = vec![]; - for id in nodes { + for id in (0..function.nodes.len()).map(NodeID::new) { // Collect the reads and writes to / from this collection. let users = def_use.get_users(id); let reads = users.iter().filter(|user| { @@ -78,9 +44,6 @@ fn generic_antideps<I: Iterator<Item = NodeID>>( antideps.push((*read, *write)); } } - - // TODO: Multiple write uses should clone the collection for N - 1 of the writes. - assert!(writes.next() == None, "Can't form anti-dependencies when there are two independent writes depending on a single collection value."); } antideps diff --git a/hercules_ir/src/def_use.rs b/hercules_ir/src/def_use.rs index e237ef55a6c9296d2c2fd8e7e6bb5a41d98e393b..eff103eedfbc121c33ac32053bf7a95706d30bfa 100644 --- a/hercules_ir/src/def_use.rs +++ b/hercules_ir/src/def_use.rs @@ -73,16 +73,12 @@ pub fn def_use(function: &Function) -> ImmutableDefUseMap { * the defs that a node uses. */ #[derive(Debug, Clone)] -pub enum NodeUses<'a> { +pub enum NodeUses { Zero, One([NodeID; 1]), Two([NodeID; 2]), Three([NodeID; 3]), - Variable(&'a Box<[NodeID]>), - // Phi nodes are special, and store both a NodeID locally *and* many in a - // boxed slice. Since these NodeIDs are not stored contiguously, we have to - // construct a new contiguous slice by copying. Sigh. 
- Owned(Box<[NodeID]>), + Variable(Box<[NodeID]>), } /* @@ -109,7 +105,7 @@ impl<'a> NodeUsesMut<'a> { } } -impl<'a> AsRef<[NodeID]> for NodeUses<'a> { +impl AsRef<[NodeID]> for NodeUses { fn as_ref(&self) -> &[NodeID] { match self { NodeUses::Zero => &[], @@ -117,7 +113,6 @@ impl<'a> AsRef<[NodeID]> for NodeUses<'a> { NodeUses::Two(x) => x, NodeUses::Three(x) => x, NodeUses::Variable(x) => x, - NodeUses::Owned(x) => x, } } } @@ -137,10 +132,10 @@ impl<'a> AsMut<[&'a mut NodeID]> for NodeUsesMut<'a> { /* * Construct NodeUses for a Node. */ -pub fn get_uses<'a>(node: &'a Node) -> NodeUses<'a> { +pub fn get_uses(node: &Node) -> NodeUses { match node { Node::Start => NodeUses::Zero, - Node::Region { preds } => NodeUses::Variable(preds), + Node::Region { preds } => NodeUses::Variable(preds.clone()), Node::If { control, cond } => NodeUses::Two([*control, *cond]), Node::Match { control, sum } => NodeUses::Two([*control, *sum]), Node::Fork { @@ -151,7 +146,7 @@ pub fn get_uses<'a>(node: &'a Node) -> NodeUses<'a> { Node::Phi { control, data } => { let mut uses: Vec<NodeID> = Vec::from(&data[..]); uses.push(*control); - NodeUses::Owned(uses.into_boxed_slice()) + NodeUses::Variable(uses.into_boxed_slice()) } Node::ThreadID { control, @@ -178,8 +173,8 @@ pub fn get_uses<'a>(node: &'a Node) -> NodeUses<'a> { function: _, dynamic_constants: _, args, - } => NodeUses::Variable(args), - Node::IntrinsicCall { intrinsic: _, args } => NodeUses::Variable(args), + } => NodeUses::Variable(args.clone()), + Node::IntrinsicCall { intrinsic: _, args } => NodeUses::Variable(args.clone()), Node::Read { collect, indices } => { let mut uses = vec![]; for index in indices.iter() { @@ -191,7 +186,7 @@ pub fn get_uses<'a>(node: &'a Node) -> NodeUses<'a> { uses.reverse(); uses.push(*collect); uses.reverse(); - NodeUses::Owned(uses.into_boxed_slice()) + NodeUses::Variable(uses.into_boxed_slice()) } else { NodeUses::One([*collect]) } @@ -212,7 +207,7 @@ pub fn get_uses<'a>(node: &'a Node) -> NodeUses<'a> { uses.push(*data); uses.push(*collect); uses.reverse(); - NodeUses::Owned(uses.into_boxed_slice()) + NodeUses::Variable(uses.into_boxed_slice()) } else { NodeUses::Two([*collect, *data]) } diff --git a/hercules_ir/src/dom.rs b/hercules_ir/src/dom.rs index 622f3a3e44ef8a22daa79d6565a26d988822c12a..7c4bb8a2779100e5bbf4f43b54a67b2fabe32de1 100644 --- a/hercules_ir/src/dom.rs +++ b/hercules_ir/src/dom.rs @@ -47,16 +47,12 @@ impl DomTree { a != b && self.does_dom(a, b) } - /* - * Check if a node is in the dom tree (if the node is the root of the tree, - * will still return true). - */ pub fn is_non_root(&self, x: NodeID) -> bool { self.idom.contains_key(&x) } pub fn contains(&self, x: NodeID) -> bool { - x == self.root || self.idom.contains_key(&x) + x == self.root || self.is_non_root(x) } /* @@ -81,32 +77,22 @@ impl DomTree { .1 } - pub fn common_ancestor<I>(&self, x: I) -> Option<NodeID> - where - I: Iterator<Item = NodeID>, - { - let mut positions: HashMap<NodeID, u32> = x - .map(|x| (x, if x == self.root { 0 } else { self.idom[&x].0 })) - .collect(); - if positions.len() == 0 { - return None; + /* + * Find the least common ancestor in the tree of two nodes. This is an + * ancestor of the two nodes that is as far down the tree as possible. 
+ */ + pub fn least_common_ancestor(&self, mut a: NodeID, mut b: NodeID) -> NodeID { + while self.idom[&a].0 < self.idom[&b].0 { + a = self.idom[&a].1; } - let mut current_level = *positions.iter().map(|(_, level)| level).max().unwrap(); - while positions.len() > 1 { - let at_current_level: Vec<NodeID> = positions - .iter() - .filter(|(_, level)| **level == current_level) - .map(|(node, _)| *node) - .collect(); - for node in at_current_level.into_iter() { - positions.remove(&node); - let (level, parent) = self.idom[&node]; - assert!(level == current_level); - positions.insert(parent, level - 1); - } - current_level -= 1; + while self.idom[&a].0 > self.idom[&b].0 { + b = self.idom[&b].1; + } + while a != b { + a = self.idom[&a].1; + b = self.idom[&b].1; } - Some(positions.into_iter().next().unwrap().0) + a } pub fn chain<'a>(&'a self, bottom: NodeID, top: NodeID) -> DomChainIterator<'a> { diff --git a/hercules_ir/src/dot.rs b/hercules_ir/src/dot.rs index 328ecb77fe0870dd2dea707d732f2ec52f04a7e3..156433885ea09a3dc9f907fdbb6cc729d28f3b4c 100644 --- a/hercules_ir/src/dot.rs +++ b/hercules_ir/src/dot.rs @@ -193,7 +193,7 @@ pub fn write_dot<W: Write>( } } - // Step 4: draw BB edges in light magenta. + // Step 4: draw BB edges in olive. if let Some(bbs) = bbs { let bbs = &bbs[function_id.idx()]; for node_idx in 0..bbs.len() { diff --git a/hercules_ir/src/gcm.rs b/hercules_ir/src/gcm.rs index 716dae5f143fe591fa0c51b5b8b389312c6a4b9a..27939931fa061cedea447c739e9bec63e07216b2 100644 --- a/hercules_ir/src/gcm.rs +++ b/hercules_ir/src/gcm.rs @@ -1,15 +1,16 @@ extern crate bitvec; -use std::collections::{HashMap, VecDeque}; - -use self::bitvec::prelude::*; +use std::collections::{HashMap, HashSet}; +use std::iter::zip; use crate::*; /* * Top level global code motion function. Assigns each data node to one of its * immediate control use / user nodes, forming (unordered) basic blocks. Returns - * the control node / basic block each node is in. + * the control node / basic block each node is in. Takes in a partial + * partitioning that must be respected. Based on the schedule-early-schedule- + * late method from Cliff Click's PhD thesis. */ pub fn gcm( function: &Function, @@ -18,119 +19,212 @@ pub fn gcm( dom: &DomTree, antideps: &Vec<(NodeID, NodeID)>, loops: &LoopTree, + fork_join_map: &HashMap<NodeID, NodeID>, + partial_partition: &Vec<Option<PartitionID>>, ) -> Vec<NodeID> { - // Step 1: find the immediate control uses and immediate control users of - // each node. - let mut immediate_control_uses = - forward_dataflow(function, reverse_postorder, |inputs, node_id| { - immediate_control_flow(inputs, node_id, function) - }); - let mut immediate_control_users = - backward_dataflow(function, def_use, reverse_postorder, |inputs, node_id| { - immediate_control_flow(inputs, node_id, function) - }); + let mut bbs: Vec<Option<NodeID>> = vec![None; function.nodes.len()]; + + // Step 1: assign the basic block locations of all nodes that must be in a + // specific block. This includes control nodes as well as some special data + // nodes, such as phis. 
+ for idx in 0..function.nodes.len() { + match function.nodes[idx] { + Node::Phi { control, data: _ } => bbs[idx] = Some(control), + Node::ThreadID { + control, + dimension: _, + } => bbs[idx] = Some(control), + Node::Reduce { + control, + init: _, + reduct: _, + } => bbs[idx] = Some(control), + Node::Parameter { index: _ } => bbs[idx] = Some(NodeID::new(0)), + Node::Constant { id: _ } => bbs[idx] = Some(NodeID::new(0)), + Node::DynamicConstant { id: _ } => bbs[idx] = Some(NodeID::new(0)), + _ if function.nodes[idx].is_control() => bbs[idx] = Some(NodeID::new(idx)), + _ => {} + } + } - // Reads and writes forming anti dependencies must be put in the same block. + // Step 2: schedule early. Place nodes in reverse postorder in the earliest + // position they could go. + let mut schedule_early = bbs.clone(); + let mut antideps_uses = HashMap::<NodeID, Vec<NodeID>>::new(); for (read, write) in antideps { - let meet = UnionNodeSet::meet( - &immediate_control_uses[read.idx()], - &immediate_control_uses[write.idx()], - ); - immediate_control_uses[read.idx()] = meet.clone(); - immediate_control_uses[write.idx()] = meet; + antideps_uses.entry(*write).or_default().push(*read); + } + for id in reverse_postorder { + if schedule_early[id.idx()].is_some() { + continue; + } - let meet = UnionNodeSet::meet( - &immediate_control_users[read.idx()], - &immediate_control_users[write.idx()], + // For every use, check what block is its "schedule early" block. This + // node goes in the lowest block amongst those blocks. + let lowest = dom.lowest_amongst( + get_uses(&function.nodes[id.idx()]) + .as_ref() + .into_iter() + .map(|id| *id) + // Include "uses" from anti-dependencies. + .chain(antideps_uses.remove(&id).unwrap_or_default().into_iter()) + .map(|id| schedule_early[id.idx()].unwrap()), ); - immediate_control_users[read.idx()] = meet.clone(); - immediate_control_users[write.idx()] = meet; + schedule_early[id.idx()] = Some(lowest); } - // Step 2: find most control dependent, shallowest loop level node for every - // node. - let bbs = (0..function.nodes.len()) - .map(|idx| { - let highest = - dom.lowest_amongst(immediate_control_uses[idx].nodes(function.nodes.len() as u32)); - let lowest = dom - .common_ancestor(immediate_control_users[idx].nodes(function.nodes.len() as u32)) - .unwrap_or(highest); + // Step 3: schedule late and pick each nodes final position. Since the late + // schedule of each node depends on the final positions of its users, these + // two steps must be fused. Place nodes in postorder. Compute their latest + // position, then use the control dependent + shallow loop heuristic to + // actually place them. + let join_fork_map: HashMap<NodeID, NodeID> = fork_join_map + .into_iter() + .map(|(fork, join)| (*join, *fork)) + .collect(); + let mut antideps_users = HashMap::<NodeID, Vec<NodeID>>::new(); + for (read, write) in antideps { + antideps_users.entry(*read).or_default().push(*write); + } + for id in reverse_postorder.into_iter().rev() { + if bbs[id.idx()].is_some() { + continue; + } - // If the ancestor of the control users isn't below the lowest - // control use, then just place in the lowest control use. - if !dom.does_dom(highest, lowest) { - highest + // Calculate the least common ancestor of user blocks, a.k.a. the "late" + // schedule. + let mut lca = None; + // Helper to incrementally update the LCA. 
+ let mut update_lca = |a| { + if let Some(acc) = lca { + lca = Some(dom.least_common_ancestor(acc, a)); } else { - // Collect in vector to reverse, since we want to traverse down - // the dom tree, not up it. - let mut chain = dom - .chain(lowest, highest) - .collect::<Vec<_>>() - .into_iter() - .rev(); + lca = Some(a); + } + }; - let mut location = chain.next().unwrap(); - while let Some(control_node) = chain.next() { - // Traverse down the dom tree until we find a loop. - if loops.contains(control_node) { - break; - } else { - location = control_node; + // For every user, consider where we need to be to directly dominate the + // user. + for user in def_use + .get_users(*id) + .as_ref() + .into_iter() + .map(|id| *id) + // Include "users" from anti-dependencies. + .chain(antideps_users.remove(&id).unwrap_or_default().into_iter()) + { + if let Node::Phi { control, data } = &function.nodes[user.idx()] { + // For phis, we need to dominate the block jumping to the phi in + // the slot that corresponds to our use. + for (control, data) in zip(get_uses(&function.nodes[control.idx()]).as_ref(), data) + { + if *id == *data { + update_lca(*control); } } + } else if let Node::Reduce { + control, + init, + reduct, + } = &function.nodes[user.idx()] + { + // For reduces, we need to either dominate the block right + // before the fork if we're the init input, or we need to + // dominate the join if we're the reduct input. + if *id == *init { + let before_fork = function.nodes[join_fork_map[control].idx()] + .try_fork() + .unwrap() + .0; + update_lca(before_fork); + } else { + assert_eq!(*id, *reduct); + update_lca(*control); + } + } else { + // For everything else, we just need to dominate the user. + update_lca(bbs[user.idx()].unwrap()); + } + } - // If the assigned location is a join and this node doesn't use - // a reduce from that join, we actually want to place these - // nodes in the predecessor of the join, so that the code will - // get executed in parallel. - if let Some(control) = function.nodes[location.idx()].try_join() - && location != NodeID::new(idx) - { - // Set up BFS to find reduce nodes. - let mut bfs = VecDeque::new(); - let mut bfs_visited = bitvec![u8, Lsb0; 0; function.nodes.len()]; - bfs.push_back(NodeID::new(idx)); - bfs_visited.set(idx, true); - let mut found_reduce = false; - 'bfs: while let Some(id) = bfs.pop_front() { - for use_id in get_uses(&function.nodes[id.idx()]).as_ref() { - // If we find a reduce, check that it's attached to - // the join we care about - if let Some((join, _, _)) = function.nodes[use_id.idx()].try_reduce() - && join == location - { - found_reduce = true; - break 'bfs; - } - - // Only go through data nodes. - if bfs_visited[use_id.idx()] - || function.nodes[use_id.idx()].is_control() - { - continue; - } + // Look between the LCA and the schedule early location to place the + // node. + let schedule_early = schedule_early[id.idx()].unwrap(); + let mut chain = dom + .chain(lca.unwrap_or(schedule_early), schedule_early) + // Filter the dominator chain by control nodes in the same partition + // as this data node, if the data node is in a partition already. 
+ .filter(|dominator| { + partial_partition[id.idx()].is_none() + || partial_partition[dominator.idx()] == partial_partition[id.idx()] + }); + let mut location = chain.next().unwrap(); + while let Some(control_node) = chain.next() { + // If the next node further up the dominator tree is in a shallower + // loop nest or if we can get out of a reduce loop when we don't + // need to be in one, place this data node in a higher-up location. + let shallower_nest = if let (Some(old_nest), Some(new_nest)) = + (loops.nesting(location), loops.nesting(control_node)) + { + old_nest > new_nest + } else { + false + }; + // This will move all nodes that don't need to be in reduce loops + // outside of reduce loops. Nodes that do need to be in a reduce + // loop use the reduce node forming the loop, so the dominator chain + // will consist of one block, and this loop won't ever iterate. + let currently_at_join = function.nodes[location.idx()].is_join(); + if shallower_nest || currently_at_join { + location = control_node; + } + } - bfs.push_back(*use_id); - bfs_visited.set(use_id.idx(), true); - } - } + bbs[id.idx()] = Some(location); + } - // If we don't depend on the reduce, we're not in a cycle - // with the reduce. Therefore, we should be scheduled to the - // predecessor of the join, since this code can run in - // parallel. - if !found_reduce { - location = control; - } - } + bbs.into_iter().map(Option::unwrap).collect() +} - location - } - }) - .collect(); +/* + * Top level function for creating a fork-join map. Map is from fork node ID to + * join node ID, since a join can easily determine the fork it corresponds to + * (that's the mechanism used to implement this analysis). This analysis depends + * on type information. + */ +pub fn fork_join_map(function: &Function, control: &Subgraph) -> HashMap<NodeID, NodeID> { + let mut fork_join_map = HashMap::new(); + for idx in 0..function.nodes.len() { + // We only care about join nodes. + if function.nodes[idx].is_join() { + // Iterate the control predecessors until finding a fork. Maintain a + // counter of unmatched fork-join pairs seen on the way, since fork- + // joins may be nested. Every join is dominated by their fork, so + // just iterate the first unseen predecessor of each control node. + let join_id = NodeID::new(idx); + let mut unpaired = 0; + let mut cursor = join_id; + let mut seen = HashSet::<NodeID>::new(); + let fork_id = loop { + cursor = control + .preds(cursor) + .filter(|pred| !seen.contains(pred)) + .next() + .unwrap(); + seen.insert(cursor); - bbs + if function.nodes[cursor.idx()].is_join() { + unpaired += 1; + } else if function.nodes[cursor.idx()].is_fork() && unpaired > 0 { + unpaired -= 1; + } else if function.nodes[cursor.idx()].is_fork() { + break cursor; + } + }; + fork_join_map.insert(fork_id, join_id); + } + } + fork_join_map } /* @@ -148,7 +242,7 @@ pub fn compute_fork_join_nesting( // the corresponding join node. 
(0..function.nodes.len()) .map(NodeID::new) - .filter(|id| function.nodes[id.idx()].is_control()) + .filter(|id| dom.contains(*id)) .map(|id| { ( id, diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index 9b946d7a41c19a5b4732c3e2a299d4ccbbdba24e..4c487c59be74ba21e319c188b54a6d29ed6234b6 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -422,6 +422,13 @@ impl Module { Ok(()) } + + pub fn map<F, O>(&self, f: F) -> Vec<O> + where + F: Fn(&Function) -> O, + { + self.functions.iter().map(|func| f(func)).collect() + } } /* @@ -625,11 +632,11 @@ impl Function { * Some analysis results can be updated after gravestone deletions. */ pub trait GraveUpdatable { - fn map_gravestones(self, grave_mapping: &Vec<NodeID>) -> Self; + fn fix_gravestones(&mut self, grave_mapping: &Vec<NodeID>); } impl<T: Clone> GraveUpdatable for Vec<T> { - fn map_gravestones(self, grave_mapping: &Vec<NodeID>) -> Self { + fn fix_gravestones(&mut self, grave_mapping: &Vec<NodeID>) { let mut new_self = vec![]; for (data, (idx, mapping)) in std::iter::zip(self.into_iter(), grave_mapping.iter().enumerate()) { @@ -639,7 +646,7 @@ impl<T: Clone> GraveUpdatable for Vec<T> { new_self.push(data.clone()); } } - new_self + *self = new_self; } } diff --git a/hercules_ir/src/loops.rs b/hercules_ir/src/loops.rs index c657572f6f1cac3027219642903221d244bddd93..5aa6bd19a65f842ab19ac855066ce894e0e568f8 100644 --- a/hercules_ir/src/loops.rs +++ b/hercules_ir/src/loops.rs @@ -18,13 +18,14 @@ use crate::*; * dominated if node. For fork join pairs, this is the fork node. A loop is a * top-level loop if its parent is the root node of the subgraph. Each node in * the tree is an entry in the loops HashMap - the key is the "header" node for - * the loop, and the key is a pair of the set of control nodes inside the loop + * the loop, and the value is a pair of the set of control nodes inside the loop * and this loop's parent header. */ #[derive(Debug, Clone)] pub struct LoopTree { root: NodeID, loops: HashMap<NodeID, (BitVec<u8, Lsb0>, NodeID)>, + nesting: HashMap<NodeID, usize>, } impl LoopTree { @@ -42,7 +43,8 @@ impl LoopTree { */ pub fn bottom_up_loops(&self) -> Vec<(NodeID, &BitVec<u8, Lsb0>)> { let mut bottom_up = vec![]; - let mut children_count: HashMap<NodeID, u32> = self.loops.iter().map(|(k, _)| (*k, 0)).collect(); + let mut children_count: HashMap<NodeID, u32> = + self.loops.iter().map(|(k, _)| (*k, 0)).collect(); children_count.insert(self.root, 0); for (_, (_, parent)) in self.loops.iter() { *children_count.get_mut(&parent).unwrap() += 1; @@ -58,6 +60,13 @@ impl LoopTree { } bottom_up } + + /* + * Gets the nesting of a loop, keyed by the header. + */ + pub fn nesting(&self, header: NodeID) -> Option<usize> { + self.nesting.get(&header).map(|id| *id) + } } /* @@ -111,23 +120,46 @@ pub fn loops( // Step 5: figure out loop tree edges. A loop with header a can only be an // outer loop of a loop with header b if a dominates b. - let loops = loops + let loops: HashMap<NodeID, (BitVec<u8, Lsb0>, NodeID)> = loops .iter() .map(|(header, contents)| { let mut dominator = *header; + // Climb the dominator tree. while let Some(new_dominator) = dom.imm_dom(dominator) { dominator = new_dominator; + // Check if the dominator node is a loop header. if let Some(outer_contents) = loops.get(&dominator) { + // Check if the dominating loop actually contains this loop.
if outer_contents[header.idx()] { return (*header, (contents.clone(), dominator)); } } } + // If no dominating node is a loop header for a loop containing this + // loop, then this loop is a top-level loop. (*header, (contents.clone(), root)) }) .collect(); - LoopTree { root, loops } + // Step 6: compute loop tree nesting. + let mut nesting = HashMap::new(); + let mut worklist: VecDeque<NodeID> = loops.keys().map(|id| *id).collect(); + while let Some(header) = worklist.pop_front() { + let parent = loops[&header].1; + if parent == root { + nesting.insert(header, 0); + } else if let Some(nest) = nesting.get(&parent) { + nesting.insert(header, nest + 1); + } else { + worklist.push_back(header); + } + } + + LoopTree { + root, + loops, + nesting, + } } fn loop_reachability(n: NodeID, d: NodeID, subgraph: &Subgraph) -> BitVec<u8, Lsb0> { diff --git a/hercules_ir/src/schedule.rs b/hercules_ir/src/schedule.rs index ea2c6153e40b9c0bfa28c7cc84a69d41c493ad39..2c6335147658dac3d590458c5585f92956c93d42 100644 --- a/hercules_ir/src/schedule.rs +++ b/hercules_ir/src/schedule.rs @@ -70,6 +70,16 @@ impl Plan { } /* + * GCM takes in a partial partitioning, but most of the time we have a full + * partitioning. + */ + pub fn make_partial_partitioning(&self) -> Vec<Option<PartitionID>> { + self.partitions.iter().map(|id| Some(*id)).collect() + } + + /* + * LEGACY API: This is the legacy mechanism for repairing plans. Future code + * should use the `repair_plan` function in editor.rs. * Plans must be "repairable", in the sense that the IR that's referred to * may change after many passes. Since a plan is an explicit side data * structure, it must be updated after every change in the IR. @@ -86,7 +96,7 @@ impl Plan { // Schedules of old nodes just get dropped. Since schedules don't hold // necessary semantic information, we are free to drop them arbitrarily. - schedules = schedules.map_gravestones(grave_mapping); + schedules.fix_gravestones(grave_mapping); schedules.resize(function.nodes.len(), vec![]); // Delete now empty partitions. First, filter out deleted nodes from the @@ -476,6 +486,20 @@ impl Plan { } } +impl GraveUpdatable for Plan { + fn fix_gravestones(&mut self, grave_mapping: &Vec<NodeID>) { + self.schedules.fix_gravestones(grave_mapping); + self.partitions.fix_gravestones(grave_mapping); + self.partition_devices.fix_gravestones(grave_mapping); + let mut renumber_partitions = HashMap::new(); + for id in self.partitions.iter_mut() { + let next_id = PartitionID::new(renumber_partitions.len()); + *id = *renumber_partitions.entry(*id).or_insert(next_id); + } + self.num_partitions = renumber_partitions.len(); + } +} + /* * A "default" plan should be available, where few schedules are used and * conservative partitioning is enacted. Only schedules that can be proven safe diff --git a/hercules_ir/src/subgraph.rs b/hercules_ir/src/subgraph.rs index f1466ce3ab91770c5c2130c00cb9a234dcfd928b..3549718a12d170f02df0da3ec548667cf40326e9 100644 --- a/hercules_ir/src/subgraph.rs +++ b/hercules_ir/src/subgraph.rs @@ -238,11 +238,12 @@ where } /* - * Get the control subgraph of a function. + * Get the control subgraph of a function. Ignore gravestones. 
*/ pub fn control_subgraph(function: &Function, def_use: &ImmutableDefUseMap) -> Subgraph { subgraph(function, def_use, |node| { function.nodes[node.idx()].is_control() + && (!function.nodes[node.idx()].is_start() || node.idx() == 0) }) } diff --git a/hercules_ir/src/typecheck.rs b/hercules_ir/src/typecheck.rs index b6b9b2811f5f1e1323876f3138c907f26b9c1e8e..64d460016ee7c3b20573feb228673d6d75bf7fb2 100644 --- a/hercules_ir/src/typecheck.rs +++ b/hercules_ir/src/typecheck.rs @@ -1,4 +1,4 @@ -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; use std::iter::zip; use crate::*; @@ -1071,47 +1071,6 @@ fn typeflow( } } -/* - * Top level function for creating a fork-join map. Map is from fork node ID to - * join node ID, since a join can easily determine the fork it corresponds to - * (that's the mechanism used to implement this analysis). This analysis depends - * on type information. - */ -pub fn fork_join_map(function: &Function, control: &Subgraph) -> HashMap<NodeID, NodeID> { - let mut fork_join_map = HashMap::new(); - for idx in 0..function.nodes.len() { - // We only care about join nodes. - if function.nodes[idx].is_join() { - // Iterate the control predecessors until finding a fork. Maintain a - // counter of unmatched fork-join pairs seen on the way, since fork- - // joins may be nested. Every join is dominated by their fork, so - // just iterate the first unseen predecessor of each control node. - let join_id = NodeID::new(idx); - let mut unpaired = 0; - let mut cursor = join_id; - let mut seen = HashSet::<NodeID>::new(); - let fork_id = loop { - cursor = control - .preds(cursor) - .filter(|pred| !seen.contains(pred)) - .next() - .unwrap(); - seen.insert(cursor); - - if function.nodes[cursor.idx()].is_join() { - unpaired += 1; - } else if function.nodes[cursor.idx()].is_fork() && unpaired > 0 { - unpaired -= 1; - } else if function.nodes[cursor.idx()].is_fork() { - break cursor; - } - }; - fork_join_map.insert(fork_id, join_id); - } - } - fork_join_map -} - /* * Determine if a given cast conversion is valid. */ diff --git a/hercules_opt/Cargo.toml b/hercules_opt/Cargo.toml index 6682ae8980d76bbdd4ac392add2de7b280442fb9..e1936a97d4e717b06195188016b711735af6367b 100644 --- a/hercules_opt/Cargo.toml +++ b/hercules_opt/Cargo.toml @@ -6,6 +6,8 @@ authors = ["Russel Arbore <rarbore2@illinois.edu>, Aaron Councilman <aaronjc4@il [dependencies] ordered-float = "*" bitvec = "*" +either = "*" +itertools = "*" take_mut = "*" postcard = { version = "*", features = ["alloc"] } serde = { version = "*", features = ["derive"] } diff --git a/hercules_opt/src/dce.rs b/hercules_opt/src/dce.rs index 255402902aa6ec0bcc240cbb7bd18a9eb8945526..14fd39899ff18756f4c7ab35897b0093ef807038 100644 --- a/hercules_opt/src/dce.rs +++ b/hercules_opt/src/dce.rs @@ -3,45 +3,36 @@ extern crate hercules_ir; use self::hercules_ir::def_use::*; use self::hercules_ir::ir::*; +use crate::*; + /* * Top level function to run dead code elimination. Deletes nodes by setting * nodes to gravestones. Works with a function already containing gravestones. */ -pub fn dce(function: &mut Function) { - // Step 1: count number of users for each node. - let mut num_users = vec![0; function.nodes.len()]; - for (idx, node) in function.nodes.iter().enumerate() { - for u in get_uses(node).as_ref() { - num_users[u.idx()] += 1; - } +pub fn dce(editor: &mut FunctionEditor) { + // Create worklist (starts as all nodes). 
+ let mut worklist: Vec<NodeID> = (0..editor.func().nodes.len()).map(NodeID::new).collect(); - // Return nodes shouldn't be considered dead code, so create a "phantom" - // user. - if node.is_return() { - num_users[idx] += 1; - } - } - - // Step 2: worklist over zero user nodes. - - // Worklist starts as list of all nodes with 0 users. - let mut worklist: Vec<_> = num_users - .iter() - .enumerate() - .filter(|(_, num_users)| **num_users == 0) - .map(|(idx, _)| idx) - .collect(); while let Some(work) = worklist.pop() { - // Use start node as gravestone node value. - let mut gravestone = Node::Start; - std::mem::swap(&mut function.nodes[work], &mut gravestone); + // If a node on the worklist is a start node, it is either *the* start + // node (which we shouldn't delete), or is a gravestone for an already + // deleted node earlier in the worklist. If a node is a return node, it + // shouldn't be removed. + if editor.func().nodes[work.idx()].is_start() || editor.func().nodes[work.idx()].is_return() + { + continue; + } - // Now that we set the gravestone, figure out other nodes that need to - // be added to the worklist. - for u in get_uses(&gravestone).as_ref() { - num_users[u.idx()] -= 1; - if num_users[u.idx()] == 0 { - worklist.push(u.idx()); + // If a node on the worklist has 0 users, delete it. Add its uses onto + // the worklist. + if editor.users(work).len() == 0 { + let uses = get_uses(&editor.func().nodes[work.idx()]); + let success = editor.edit(|edit| edit.delete_node(work)); + if success { + // If the edit was performed, then its uses may now be dead. + for u in uses.as_ref() { + worklist.push(*u); + } } } } diff --git a/hercules_opt/src/editor.rs b/hercules_opt/src/editor.rs new file mode 100644 index 0000000000000000000000000000000000000000..9f71991db33609737cf39d53a1b00ff5615bf2f1 --- /dev/null +++ b/hercules_opt/src/editor.rs @@ -0,0 +1,557 @@ +extern crate bitvec; +extern crate either; +extern crate hercules_ir; +extern crate itertools; + +use std::collections::{BTreeMap, HashSet}; +use std::mem::take; + +use self::bitvec::prelude::*; +use self::either::Either; +use self::itertools::Itertools; + +use self::hercules_ir::antideps::*; +use self::hercules_ir::dataflow::*; +use self::hercules_ir::def_use::*; +use self::hercules_ir::dom::*; +use self::hercules_ir::gcm::*; +use self::hercules_ir::ir::*; +use self::hercules_ir::loops::*; +use self::hercules_ir::schedule::*; +use self::hercules_ir::subgraph::*; + +pub type Edit = (HashSet<NodeID>, HashSet<NodeID>); + +/* + * Helper object for editing Hercules functions in a trackable manner. Edits are + * recorded in order to repair partitions and debug info. + */ +#[derive(Debug)] +pub struct FunctionEditor<'a> { + // Wraps a mutable reference to a function. Doesn't provide access to this + // reference directly, so that we can monitor edits. + function: &'a mut Function, + // Most optimizations need def use info, so provide an iteratively updated + // mutable version that's automatically updated based on recorded edits. + mut_def_use: Vec<HashSet<NodeID>>, + // Record edits as a mapping from sets of node IDs to sets of node IDs. The + // sets on the "left" side of this map should be mutually disjoint, and the + // sets on the "right" side should also be mutually disjoint. All of the + // node IDs on the left side should be deleted node IDs or IDs of unmodified + // nodes, and all of the node IDs on the right side should be added node IDs + // or IDs of unmodified nodes. 
In other words, there should be no added node + // IDs on the left side, and no deleted node IDs on the right side. These + // mappings are stored sequentially in a list, rather than as a map. This is + // because a transformation may iteratively update a function - i.e., a node + // ID added in iteration N may be deleted in iteration N + M. To maintain a + // more precise history of edits, we store each edit individually, which + // allows us to make more precise repairs of partitions and debug info. + edits: Vec<Edit>, + // The pass manager may indicate that only a certain subset of nodes should + // be modified in a function - what this actually means is that some nodes + // are off limits for deletion (equivalently modification) or being replaced + // as a use. + mutable_nodes: BitVec<u8, Lsb0>, +} + +/* + * Helper object to make a single edit. + */ +#[derive(Debug)] +pub struct FunctionEdit<'a: 'b, 'b> { + // Reference the active function editor. + editor: &'b mut FunctionEditor<'a>, + // Keep track of deleted node IDs. + deleted: HashSet<NodeID>, + // Keep track of added node IDs. + added: HashSet<NodeID>, + // Keep track of added and use-updated nodes. + added_and_updated: BTreeMap<NodeID, Node>, + // Compute def-use map entries iteratively. + updated_def_use: BTreeMap<NodeID, HashSet<NodeID>>, + // Technically client edit functions don't have to early exit, so explicitly + // keep track of whether this edit should be aborted. + abort: bool, +} + +impl<'a: 'b, 'b> FunctionEditor<'a> { + pub fn new(function: &'a mut Function, def_use: &ImmutableDefUseMap) -> Self { + let mut_def_use = (0..function.nodes.len()) + .map(|idx| { + def_use + .get_users(NodeID::new(idx)) + .into_iter() + .map(|x| *x) + .collect() + }) + .collect(); + let mutable_nodes = bitvec![u8, Lsb0; 1; function.nodes.len()]; + + FunctionEditor { + function, + mut_def_use, + edits: vec![], + mutable_nodes, + } + } + + pub fn edit<F>(&'b mut self, edit: F) -> bool + where + F: FnOnce(FunctionEdit<'a, 'b>) -> Result<FunctionEdit<'a, 'b>, FunctionEdit<'a, 'b>>, + { + // Create the edit helper struct and perform the edit using it. + let edit_obj = FunctionEdit { + editor: self, + deleted: HashSet::new(), + added: HashSet::new(), + added_and_updated: BTreeMap::new(), + updated_def_use: BTreeMap::new(), + abort: false, + }; + + if let Ok(populated_edit) = edit(edit_obj) { + // If the populated edit is returned, then the edit can be performed + // without modifying immutable nodes. + let FunctionEdit { + editor, + deleted, + added, + added_and_updated, + updated_def_use, + abort, + } = populated_edit; + if abort { + return false; + } + + // Step 1: update the mutable def use map. + for (u, new_users) in updated_def_use { + // Go through new def-use entries in order. These are either + // updates to existing nodes, in which case we just modify them + // in place, or they are user entries for new nodes, in which + // case we push them. + if u.idx() < editor.mut_def_use.len() { + editor.mut_def_use[u.idx()] = new_users; + } else { + // The new nodes must be traversed in order in a packed + // fashion - if a new node was created without an + // accompanying new users entry, something is wrong! + assert_eq!(editor.mut_def_use.len(), u.idx()); + editor.mut_def_use.push(new_users); + } + } + + // Step 2: add and update nodes.
+ for (id, node) in added_and_updated { + if id.idx() < editor.function.nodes.len() { + editor.function.nodes[id.idx()] = node; + } else { + // New nodes should've been assigned increasing IDs starting + // at the previous number of nodes, so check that. + assert_eq!(editor.function.nodes.len(), id.idx()); + editor.function.nodes.push(node); + } + } + + // Step 3: delete nodes. This is done using "gravestones", where a + // node other than node ID 0 being a start node is considered a + // gravestone. + for id in deleted.iter() { + // Check that there are no users of deleted nodes. + assert!(editor.mut_def_use[id.idx()].is_empty()); + editor.function.nodes[id.idx()] = Node::Start; + } + + // Step 4: add a single edit to the edit list. + editor.edits.push((deleted, added)); + + // Step 5: update the length of mutable_nodes. All added nodes are + // mutable. + editor + .mutable_nodes + .resize(editor.function.nodes.len(), true); + + true + } else { + false + } + } + + pub fn func(&self) -> &Function { + &self.function + } + + pub fn users(&self, id: NodeID) -> impl ExactSizeIterator<Item = NodeID> + '_ { + self.mut_def_use[id.idx()].iter().map(|x| *x) + } + + pub fn edits(self) -> Vec<Edit> { + self.edits + } +} + +impl<'a, 'b> FunctionEdit<'a, 'b> { + fn ensure_updated_def_use_entry(&mut self, id: NodeID) { + if !self.updated_def_use.contains_key(&id) { + let old_entry = self + .editor + .mut_def_use + .get(id.idx()) + .map(|entry| entry.clone()) + .unwrap_or_default(); + self.updated_def_use.insert(id, old_entry); + } + } + + pub fn add_node(&mut self, node: Node) -> NodeID { + let id = NodeID::new(self.editor.function.nodes.len() + self.added.len()); + // Added nodes need to have an entry in the def-use map. + self.updated_def_use.insert(id, HashSet::new()); + // Added nodes use other nodes, and we need to update their def-use + // entries. + for u in get_uses(&node).as_ref() { + self.ensure_updated_def_use_entry(*u); + self.updated_def_use.get_mut(u).unwrap().insert(id); + } + // Add the node. + self.added_and_updated.insert(id, node); + self.added.insert(id); + id + } + + pub fn delete_node(mut self, id: NodeID) -> Result<Self, Self> { + // We can only delete mutable nodes. Return Err if we try to modify an + // immutable node, as it means the whole edit should be aborted. + if self.editor.mutable_nodes[id.idx()] { + assert!( + !self.added.contains(&id), + "PANIC: Please don't delete a node that was added in the same edit." + ); + // Deleted nodes use other nodes, and we need to update their def- + // use entries. + let uses: Box<[NodeID]> = get_uses(&self.editor.function.nodes[id.idx()]) + .as_ref() + .into(); + for u in uses { + self.ensure_updated_def_use_entry(u); + self.updated_def_use.get_mut(&u).unwrap().remove(&id); + } + self.deleted.insert(id); + Ok(self) + } else { + self.abort = true; + Err(self) + } + } + + pub fn replace_all_uses(mut self, old: NodeID, new: NodeID) -> Result<Self, Self> { + // We can only replace uses of mutable nodes. Return Err if we try to + // replace uses of an immutable node, as it means the whole edit should + // be aborted. + if self.editor.mutable_nodes[old.idx()] { + // Update all of the users of the old node. + self.ensure_updated_def_use_entry(old); + for user_id in self.updated_def_use[&old].iter() { + // Replace uses of old with new. + let mut updated_user = self.node(*user_id).clone(); + for u in get_uses_mut(&mut updated_user).as_mut() { + if **u == old { + **u = new; + } + } + // Add the updated user to added_and_updated.
+ self.added_and_updated.insert(*user_id, updated_user); + } + + // All of the users of the old node become users of the new node, so + // move all of the entries in the def-use from the old to the new. + let old_entries = take(self.updated_def_use.get_mut(&old).unwrap()); + self.ensure_updated_def_use_entry(new); + self.updated_def_use + .get_mut(&new) + .unwrap() + .extend(old_entries); + + Ok(self) + } else { + self.abort = true; + Err(self) + } + } + + pub fn node(&self, id: NodeID) -> &Node { + assert!(!self.deleted.contains(&id)); + if let Some(node) = self.added_and_updated.get(&id) { + // Refer to added or updated node. This node is guaranteed to be + // updated with uses after replace_all_uses is called. + node + } else { + // Refer to the original node. + &self.editor.function.nodes[id.idx()] + } + } + + pub fn users(&self, id: NodeID) -> impl Iterator<Item = NodeID> + '_ { + assert!(!self.deleted.contains(&id)); + if let Some(users) = self.updated_def_use.get(&id) { + // Refer to the updated users set. + Either::Left(users.iter().map(|x| *x)) + } else { + // Refer to the original users set. + Either::Right(self.editor.mut_def_use[id.idx()].iter().map(|x| *x)) + } + } +} + +/* + * Simplify an edit sequence into a single, larger, edit. + */ +fn collapse_edits(edits: &[Edit]) -> Edit { + let mut total_edit = Edit::default(); + + for edit in edits { + assert!(edit.0.is_disjoint(&edit.1), "PANIC: Edit sequence is malformed - can't add and delete the same node ID in a single edit."); + assert!( + total_edit.0.is_disjoint(&edit.0), + "PANIC: Edit sequence is malformed - can't delete the same node ID twice." + ); + assert!( + total_edit.1.is_disjoint(&edit.1), + "PANIC: Edit sequence is malformed - can't add the same node ID twice." + ); + + for delete in edit.0.iter() { + total_edit.0.insert(*delete); + total_edit.1.remove(delete); + } + + for addition in edit.1.iter() { + total_edit.0.remove(addition); + total_edit.1.insert(*addition); + } + } + + total_edit +} + +/* + * Plans can be repaired - this entails repairing schedules as well as + * partitions. `new_function` is the function after the edits have occurred, but + * before gravestones have been removed. + */ +pub fn repair_plan(plan: &mut Plan, new_function: &Function, edits: &[Edit]) { + // Step 1: collapse all of the edits into a single edit. For repairing + // partitions, we don't need to consider the intermediate edit states. + let total_edit = collapse_edits(edits); + + // Step 2: drop schedules for deleted nodes and create empty schedule lists + // for added nodes. + for deleted in total_edit.0.iter() { + plan.schedules[deleted.idx()] = vec![]; + } + if !total_edit.1.is_empty() { + assert_eq!( + total_edit.1.iter().max().unwrap().idx() + 1, + new_function.nodes.len() + ); + plan.schedules.resize(new_function.nodes.len(), vec![]); + } + + // Step 3: figure out the order to add nodes to partitions. Roughly, we look + // at the added nodes in reverse postorder and partition by control/data. We + // first add control nodes to partitions using node-specific rules. We then + // add data nodes based on the partitions of their immediate control uses + // and users. 
+ let def_use = def_use(new_function); + let rev_po = reverse_postorder(&def_use); + let added_control_nodes: Vec<NodeID> = rev_po + .iter() + .filter(|id| total_edit.1.contains(id) && new_function.nodes[id.idx()].is_control()) + .map(|id| *id) + .collect(); + let added_data_nodes: Vec<NodeID> = rev_po + .iter() + .filter(|id| total_edit.1.contains(id) && !new_function.nodes[id.idx()].is_control()) + .map(|id| *id) + .collect(); + + // Step 4: figure out the partitions for added control nodes. + // Do a bunch of analysis that basically boils down to finding what fork- + // joins are top-level. + let control_subgraph = control_subgraph(new_function, &def_use); + let dom = dominator(&control_subgraph, NodeID::new(0)); + let fork_join_map = fork_join_map(new_function, &control_subgraph); + let fork_join_nesting = compute_fork_join_nesting(new_function, &dom, &fork_join_map); + // While building, the new partitions map uses Option since we don't have + // partitions for new nodes yet, and we need to record that specifically for + // computing the partitions of region nodes. + let mut new_partitions: Vec<Option<PartitionID>> = take(&mut plan.partitions) + .into_iter() + .map(|part| Some(part)) + .collect(); + new_partitions.resize(new_function.nodes.len(), None); + // Iterate the added control nodes in reverse postorder. + for control_id in added_control_nodes { + let node = &new_function.nodes[control_id.idx()]; + // There are three cases where this control node needs to start a new + // partition: + // 1. It's a top-level fork. + // 2. One of its control predecessors is a top-level join. + // 3. It's a region node where not every predecessor is in the same + // partition (only region nodes can have multiple predecessors). + let top_level_fork = node.is_fork() && fork_join_nesting[&control_id].len() == 1; + let top_level_join = control_subgraph.preds(control_id).any(|pred| { + new_function.nodes[pred.idx()].is_join() && fork_join_nesting[&pred].len() == 1 + }); + // Because of the reverse postorder traversal, it's not possible for + // every predecessor to still be missing a partition assignment. + let multi_pred_region = !control_subgraph + .preds(control_id) + .map(|pred| new_partitions[pred.idx()]) + .all_equal(); + + if top_level_fork || top_level_join || multi_pred_region { + // This control node goes in a new partition. + let part_id = PartitionID::new(plan.num_partitions); + plan.num_partitions += 1; + new_partitions[control_id.idx()] = Some(part_id); + } else { + // This control node goes in the partition of any one of its + // predecessors. They're all the same by condition 3 above. + new_partitions[control_id.idx()] = control_subgraph + .preds(control_id) + .filter_map(|pred_id| new_partitions[pred_id.idx()]) + .next(); + } + } + + // Step 5: figure out the partitions for added data nodes. + let antideps = antideps(&new_function, &def_use); + let loops = loops(&control_subgraph, NodeID::new(0), &dom, &fork_join_map); + let bbs = gcm( + new_function, + &def_use, + &rev_po, + &dom, + &antideps, + &loops, + &fork_join_map, + &new_partitions, + ); + for data_id in added_data_nodes { + new_partitions[data_id.idx()] = new_partitions[bbs[data_id.idx()].idx()]; + } + + // Step 6: wrap everything up.
+ plan.partitions = new_partitions.into_iter().map(|id| id.unwrap()).collect(); + plan.partition_devices + .resize(new_function.nodes.len(), Device::CPU); +} + +#[cfg(test)] +mod editor_tests { + #[allow(unused_imports)] + use super::*; + + use std::mem::replace; + + use self::hercules_ir::parse::parse; + + fn canonicalize(function: &mut Function) -> Vec<Option<NodeID>> { + // The reverse postorder traversal from the Start node is a map from new + // index to old ID. + let rev_po = reverse_postorder(&def_use(function)); + let num_new_nodes = rev_po.len(); + + // Construct a map from old ID to new ID. + let mut old_to_new = vec![None; function.nodes.len()]; + for (new_idx, old_id) in rev_po.into_iter().enumerate() { + old_to_new[old_id.idx()] = Some(NodeID::new(new_idx)); + } + + // Move the old nodes before permuting them. + let mut old_nodes = take(&mut function.nodes); + function.nodes = vec![Node::Start; num_new_nodes]; + + // Permute the old nodes back into the function and fix their uses. + for (old_idx, new_id) in old_to_new.iter().enumerate() { + // Check if this old node is in the canonicalized form. + if let Some(new_id) = new_id { + // Get the old node. + let mut node = replace(&mut old_nodes[old_idx], Node::Start); + + // Fix its uses. + for u in get_uses_mut(&mut node).as_mut() { + // Map every use using the old-to-new map. If we try to use + // a node that doesn't have a mapping, then the original IR + // had a node reachable from the start using another node + // not reachable from the start, which is malformed. + **u = old_to_new[u.idx()].unwrap(); + } + + // Insert the fixed node into its new spot. + function.nodes[new_id.idx()] = node; + } + } + + old_to_new + } + + #[test] + fn example1() { + // Define the original function. + let mut src_module = parse( + " +fn func(x: i32) -> i32 + c = constant(i32, 7) + y = add(x, c) + r = return(start, y) +", + ) + .unwrap(); + + // Find the ID of the add node and its uses. + let func = &mut src_module.functions[0]; + let (add, left, right) = func + .nodes + .iter() + .enumerate() + .filter_map(|(idx, node)| { + node.try_binary(BinaryOperator::Add) + .map(|(left, right)| (NodeID::new(idx), left, right)) + }) + .next() + .unwrap(); + + // Edit the function by replacing the add with a multiply. + let mut editor = FunctionEditor::new(func, &def_use(func)); + let success = editor.edit(|mut edit| { + let mul = edit.add_node(Node::Binary { + op: BinaryOperator::Mul, + left, + right, + }); + let edit = edit.replace_all_uses(add, mul)?; + let edit = edit.delete_node(add)?; + Ok(edit) + }); + assert!(success); + + // Canonicalize the function. + canonicalize(func); + + // Check that the function is correct. + let mut dst_module = parse( + " +fn func(x: i32) -> i32 + c = constant(i32, 7) + y = mul(x, c) + r = return(start, y) +", + ) + .unwrap(); + canonicalize(&mut dst_module.functions[0]); + assert_eq!(src_module.functions[0].nodes, dst_module.functions[0].nodes); + } +} diff --git a/hercules_opt/src/gvn.rs b/hercules_opt/src/gvn.rs index e8337e609b3ae881c3e8a012d9cd09e212eee2c1..e1a179f7986ad7134b27b1e0d87e4947faa44fa6 100644 --- a/hercules_opt/src/gvn.rs +++ b/hercules_opt/src/gvn.rs @@ -2,57 +2,57 @@ extern crate hercules_ir; use std::collections::HashMap; -use self::hercules_ir::def_use::*; use self::hercules_ir::ir::*; +use crate::*; + /* * Top level function to run global value numbering. In the sea of nodes, GVN is * fairly simple compared to in a normal CFG. Needs access to constants for * identity function simplification. 
*/ -pub fn gvn(function: &mut Function, constants: &Vec<Constant>, def_use: &ImmutableDefUseMap) { - // Step 1: create worklist (starts as all nodes) and value number hashmap. - let mut worklist: Vec<_> = (0..function.nodes.len()).rev().map(NodeID::new).collect(); +pub fn gvn(editor: &mut FunctionEditor, constants: &Vec<Constant>) { + // Create worklist (starts as all nodes) and value number hashmap. + let mut worklist: Vec<NodeID> = (0..editor.func().nodes.len()).map(NodeID::new).collect(); let mut value_numbers: HashMap<Node, NodeID> = HashMap::new(); - // Step 2: do worklist. while let Some(work) = worklist.pop() { - // First, iteratively simplify the work node by unwrapping identity - // functions. - let value = crawl_identities(work, function, constants); + // First, simplify the work node by unwrapping identity functions. + let value = crawl_identities(work, editor.func(), constants); // Next, check if there is a value number for this simplified value yet. - if let Some(leader) = value_numbers.get(&function.nodes[value.idx()]) { - // Also need to check that leader is not the current work ID. The - // leader should never remove itself. - if *leader != work { - // If there is a value number (a previously found Node ID) for the - // current node, then replace all users' uses of the current work - // node ID with the value number node ID. - for user in def_use.get_users(work) { - for u in get_uses_mut(&mut function.nodes[user.idx()]).as_mut() { - if **u == work { - **u = *leader; - } - } + if let Some(number) = value_numbers.get(&editor.func().nodes[value.idx()]) { + // If the number is this worklist item, there's nothing to be done. + if *number == work { + continue; + } - // Since we modified user, it may now be congruent to other - // nodes, so add it back into the worklist. - worklist.push(*user); - } + // Record the users of `work` before making any edits. + let work_users: Vec<NodeID> = editor.users(work).collect(); - // Since all ex-users now use the value number node ID, delete this - // node. - function.nodes[work.idx()] = Node::Start; + // At this point, we know the number of the node IDed `work` is + // `number`. We want to replace `work` with `number`, which means + // 1. replacing all uses of `work` with `number` + // 2. deleting `work` + let success = editor.edit(|edit| { + let edit = edit.replace_all_uses(work, *number)?; + edit.delete_node(work) + }); - // Explicitly continue to branch away from adding current work - // as leader into value_numbers. - continue; + // If the edit was performed, then the old users of `work` may now + // be congruent to other nodes. + if success { + worklist.extend(work_users); } + } else { + // If there isn't a number yet for this value, assign the value the + // ID of the node as its number. + value_numbers.insert(editor.func().nodes[value.idx()].clone(), value); + // The crawled identity `value` may be different than `work` - add + // `work` back on to the worklist so that in a subsequent iteration, + // we can simplify it. + worklist.push(work); } - // If not found, insert the simplified node with its node ID as the - // value number. 
- value_numbers.insert(function.nodes[value.idx()].clone(), value); } } diff --git a/hercules_opt/src/lib.rs b/hercules_opt/src/lib.rs index 5e53d5327765139e80897006b3cb0375f00a9378..ff789dd2da1648fee29796871505a9f4fd642dc1 100644 --- a/hercules_opt/src/lib.rs +++ b/hercules_opt/src/lib.rs @@ -2,6 +2,7 @@ pub mod ccp; pub mod dce; +pub mod editor; pub mod fork_guard_elim; pub mod forkify; pub mod gvn; @@ -12,6 +13,7 @@ pub mod sroa; pub use crate::ccp::*; pub use crate::dce::*; +pub use crate::editor::*; pub use crate::fork_guard_elim::*; pub use crate::forkify::*; pub use crate::gvn::*; diff --git a/hercules_opt/src/pass.rs b/hercules_opt/src/pass.rs index 719819fff19b4f937250e2d3c74b9a7c4f373a4c..e96fb4b62e3f36557fe53de201b2c8ee826b91fc 100644 --- a/hercules_opt/src/pass.rs +++ b/hercules_opt/src/pass.rs @@ -237,22 +237,39 @@ impl PassManager { self.make_doms(); self.make_antideps(); self.make_loops(); + self.make_fork_join_maps(); let def_uses = self.def_uses.as_ref().unwrap().iter(); let reverse_postorders = self.reverse_postorders.as_ref().unwrap().iter(); let doms = self.doms.as_ref().unwrap().iter(); let antideps = self.antideps.as_ref().unwrap().iter(); let loops = self.loops.as_ref().unwrap().iter(); + let fork_join_maps = self.fork_join_maps.as_ref().unwrap().iter(); self.bbs = Some( zip( self.module.functions.iter(), zip( def_uses, - zip(reverse_postorders, zip(doms, zip(antideps, loops))), + zip( + reverse_postorders, + zip(doms, zip(antideps, zip(loops, fork_join_maps))), + ), ), ) .map( - |(function, (def_use, (reverse_postorder, (dom, (antideps, loops)))))| { - gcm(function, def_use, reverse_postorder, dom, antideps, loops) + |( + function, + (def_use, (reverse_postorder, (dom, (antideps, (loops, fork_join_map))))), + )| { + gcm( + function, + def_use, + reverse_postorder, + dom, + antideps, + loops, + fork_join_map, + &vec![None; function.nodes.len()], + ) }, ) .collect(), @@ -308,9 +325,23 @@ impl PassManager { for pass in self.passes.clone().iter() { match pass { Pass::DCE => { + self.make_def_uses(); + let def_uses = self.def_uses.as_ref().unwrap(); for idx in 0..self.module.functions.len() { - dce(&mut self.module.functions[idx]); + let mut editor = + FunctionEditor::new(&mut self.module.functions[idx], &def_uses[idx]); + dce(&mut editor); + + let edits = &editor.edits(); + if let Some(plans) = self.plans.as_mut() { + repair_plan(&mut plans[idx], &self.module.functions[idx], edits); + } + let grave_mapping = self.module.functions[idx].delete_gravestones(); + if let Some(plans) = self.plans.as_mut() { + plans[idx].fix_gravestones(&grave_mapping); + } } + self.clear_analyses(); } Pass::CCP => { self.make_def_uses(); @@ -325,17 +356,27 @@ impl PassManager { &reverse_postorders[idx], ); } + self.legacy_repair_plan(); + self.clear_analyses(); } Pass::GVN => { self.make_def_uses(); let def_uses = self.def_uses.as_ref().unwrap(); for idx in 0..self.module.functions.len() { - gvn( - &mut self.module.functions[idx], - &self.module.constants, - &def_uses[idx], - ); + let mut editor = + FunctionEditor::new(&mut self.module.functions[idx], &def_uses[idx]); + gvn(&mut editor, &self.module.constants); + + let edits = &editor.edits(); + if let Some(plans) = self.plans.as_mut() { + repair_plan(&mut plans[idx], &self.module.functions[idx], edits); + } + let grave_mapping = self.module.functions[idx].delete_gravestones(); + if let Some(plans) = self.plans.as_mut() { + plans[idx].fix_gravestones(&grave_mapping); + } } + self.clear_analyses(); } Pass::Forkify => { self.make_def_uses(); @@ 
-351,11 +392,15 @@ impl PassManager { &loops[idx], ) } + self.legacy_repair_plan(); + self.clear_analyses(); } Pass::PhiElim => { for function in self.module.functions.iter_mut() { phi_elim(function); } + self.legacy_repair_plan(); + self.clear_analyses(); } Pass::ForkGuardElim => { self.make_def_uses(); @@ -370,6 +415,8 @@ impl PassManager { &def_uses[idx], ) } + self.legacy_repair_plan(); + self.clear_analyses(); } Pass::Predication => { self.make_def_uses(); @@ -392,15 +439,10 @@ impl PassManager { &plans[idx].schedules, ) } + self.legacy_repair_plan(); + self.clear_analyses(); } Pass::SROA => { - println!("{:?}", self.module.functions[0].nodes); - println!("{:?}", self.module.constants); - for ty_id in (0..self.module.types.len()).map(TypeID::new) { - let mut str_ty = "".to_string(); - self.module.write_type(ty_id, &mut str_ty).unwrap(); - println!("{}: {}", ty_id.idx(), str_ty); - } self.make_def_uses(); self.make_reverse_postorders(); self.make_typing(); @@ -417,6 +459,8 @@ impl PassManager { &mut self.module.constants, ); } + self.legacy_repair_plan(); + self.clear_analyses(); } Pass::Verify => { let ( @@ -450,9 +494,6 @@ impl PassManager { ); } } - - // Verify doesn't require clearing analysis results. - continue; } Pass::Xdot(force_analyses) => { self.make_reverse_postorders(); @@ -470,9 +511,6 @@ impl PassManager { self.bbs.as_ref(), self.plans.as_ref(), ); - - // Xdot doesn't require clearing analysis results. - continue; } Pass::SchedXdot => { self.make_def_uses(); @@ -497,9 +535,6 @@ impl PassManager { ); xdot_sched_module(&smodule); - - // Xdot doesn't require clearing analysis results. - continue; } Pass::Codegen => { self.make_def_uses(); @@ -570,25 +605,23 @@ impl PassManager { file.write_all(&hman_contents) .expect("PANIC: Unable to write output manifest file contents."); self.manifests = Some(smodule.manifests); - - // Codegen doesn't require clearing analysis results. - continue; } } + } + } - // Cleanup the module after passes. Delete gravestone nodes. Repair - // the plans. Clear out-of-date analyses. - for idx in 0..self.module.functions.len() { - let grave_mapping = self.module.functions[idx].delete_gravestones(); - let plans = &mut self.plans; - let functions = &self.module.functions; - if let Some(plans) = plans.as_mut() { - take_mut::take(&mut plans[idx], |plan| { - plan.repair(&functions[idx], &grave_mapping) - }); - } + fn legacy_repair_plan(&mut self) { + // Cleanup the module after passes. Delete gravestone nodes. Repair the + // plans. + for idx in 0..self.module.functions.len() { + let grave_mapping = self.module.functions[idx].delete_gravestones(); + let plans = &mut self.plans; + let functions = &self.module.functions; + if let Some(plans) = plans.as_mut() { + take_mut::take(&mut plans[idx], |plan| { + plan.repair(&functions[idx], &grave_mapping) + }); } - self.clear_analyses(); } }
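
The following is not part of the patch; it is an editorial sketch of how the editor-based passes in this diff are intended to be driven, mirroring the new `Pass::DCE` and `Pass::GVN` arms in `pass.rs`: build a `FunctionEditor` over a function, run the pass, repair the plan from the recorded edits, then delete gravestones and remap the plan. The wrapper name `dce_with_plan_repair` is hypothetical; the items it calls (`FunctionEditor::new`, `dce`, `repair_plan`, `delete_gravestones`, `fix_gravestones`) are the ones added or used above.

extern crate hercules_ir;
extern crate hercules_opt;

use hercules_ir::def_use::def_use;
use hercules_ir::ir::{Function, GraveUpdatable};
use hercules_ir::schedule::Plan;
use hercules_opt::{dce, repair_plan, FunctionEditor};

/* Hypothetical driver showing the intended pass + plan repair sequence. */
fn dce_with_plan_repair(function: &mut Function, plan: &mut Plan) {
    // Build an editor over the function; it tracks def-use info and records
    // every edit the pass makes.
    let def_use_map = def_use(function);
    let mut editor = FunctionEditor::new(function, &def_use_map);

    // Run the editor-based pass (DCE here; GVN is driven the same way).
    dce(&mut editor);

    // Consume the editor to get the recorded edits, then repair the plan
    // against the edited function while gravestones are still in place.
    let edits = editor.edits();
    repair_plan(plan, function, &edits);

    // Finally, physically delete gravestones and remap the plan's per-node
    // entries onto the surviving node IDs.
    let grave_mapping = function.delete_gravestones();
    plan.fix_gravestones(&grave_mapping);
}

In `pass.rs` this same sequence runs per function inside the pass loop, with the plan repair steps guarded on a plan actually being present.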