use std::collections::{HashMap, HashSet, VecDeque};
use std::iter::{zip, FromIterator};

use crate::*;

/*
 * Top level global code motion function. Assigns each data node to one of its
 * immediate control use / user nodes, forming (unordered) basic blocks.
 * Returns the control node / basic block each node is in. Takes in a partial
 * partitioning that must be respected. Based on the schedule-early-schedule-
 * late method from Cliff Click's PhD thesis.
 */
pub fn gcm(
    function: &Function,
    def_use: &ImmutableDefUseMap,
    reverse_postorder: &Vec<NodeID>,
    dom: &DomTree,
    antideps: &Vec<(NodeID, NodeID)>,
    loops: &LoopTree,
    fork_join_map: &HashMap<NodeID, NodeID>,
    partial_partition: &mut Vec<Option<PartitionID>>,
) -> Vec<NodeID> {
    let mut bbs: Vec<Option<NodeID>> = vec![None; function.nodes.len()];

    // Step 1: assign the basic block locations of all nodes that must be in a
    // specific block. This includes control nodes as well as some special data
    // nodes, such as phis.
    for idx in 0..function.nodes.len() {
        match function.nodes[idx] {
            Node::Phi { control, data: _ } => bbs[idx] = Some(control),
            Node::ThreadID {
                control,
                dimension: _,
            } => bbs[idx] = Some(control),
            Node::Reduce {
                control,
                init: _,
                reduct: _,
            } => bbs[idx] = Some(control),
            Node::Call {
                control,
                function: _,
                dynamic_constants: _,
                args: _,
            } => bbs[idx] = Some(control),
            Node::Parameter { index: _ } => bbs[idx] = Some(NodeID::new(0)),
            Node::Constant { id: _ } => bbs[idx] = Some(NodeID::new(0)),
            Node::DynamicConstant { id: _ } => bbs[idx] = Some(NodeID::new(0)),
            _ if function.nodes[idx].is_control() => bbs[idx] = Some(NodeID::new(idx)),
            _ => {}
        }
    }

    // Step 2: schedule early. Place nodes in the earliest position they could
    // go - use a worklist to iterate the nodes.
    let mut schedule_early = bbs.clone();
    let mut antideps_uses = HashMap::<NodeID, Vec<NodeID>>::new();
    for (read, write) in antideps {
        antideps_uses.entry(*write).or_default().push(*read);
    }
    let mut worklist = VecDeque::from(reverse_postorder.clone());
    while let Some(id) = worklist.pop_front() {
        if schedule_early[id.idx()].is_some() {
            continue;
        }

        // For every use, check what block is its "schedule early" block. This
        // node goes in the lowest block amongst those blocks.
        let use_places: Option<Vec<NodeID>> = get_uses(&function.nodes[id.idx()])
            .as_ref()
            .into_iter()
            .map(|id| *id)
            // Include "uses" from anti-dependencies.
            .chain(
                antideps_uses
                    .get(&id)
                    .unwrap_or(&vec![])
                    .into_iter()
                    .map(|id| *id),
            )
            .map(|id| schedule_early[id.idx()])
            .collect();
        if let Some(use_places) = use_places {
            // If every use has been placed, we can place this node at the
            // lowest place in the domtree that dominates all of the use
            // places.
            let lowest = dom.lowest_amongst(use_places.into_iter());
            schedule_early[id.idx()] = Some(lowest);
        } else {
            // If not, then just push this node back on the worklist.
            worklist.push_back(id);
        }
    }

    // Step 3: schedule late and pick each node's final position. Since the
    // late schedule of each node depends on the final positions of its users,
    // these two steps must be fused. Compute each node's latest position, then
    // use the control dependent + shallow loop heuristic to actually place it.
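    //
    // A hedged illustration with hypothetical blocks (not taken from any
    // particular program): if a node's schedule-early block is B1 and the LCA
    // of its users is B3, with dominator chain B1 -> B2 -> B3 where only B3
    // sits inside a loop, then every block on the chain is a legal home, and
    // the placement loop below walks upward from B3 and settles on B2, the
    // first block outside the loop.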
    let join_fork_map: HashMap<NodeID, NodeID> = fork_join_map
        .into_iter()
        .map(|(fork, join)| (*join, *fork))
        .collect();
    let mut antideps_users = HashMap::<NodeID, Vec<NodeID>>::new();
    for (read, write) in antideps {
        antideps_users.entry(*read).or_default().push(*write);
    }
    let mut worklist = VecDeque::from_iter(reverse_postorder.into_iter().map(|id| *id).rev());
    while let Some(id) = worklist.pop_front() {
        if bbs[id.idx()].is_some() {
            continue;
        }

        // Calculate the least common ancestor of user blocks, a.k.a. the
        // "late" schedule.
        let calculate_lca = || -> Option<_> {
            let mut lca = None;
            // Helper to incrementally update the LCA.
            let mut update_lca = |a| {
                if let Some(acc) = lca {
                    lca = Some(dom.least_common_ancestor(acc, a));
                } else {
                    lca = Some(a);
                }
            };

            // For every user, consider where we need to be to directly
            // dominate the user.
            for user in def_use
                .get_users(id)
                .as_ref()
                .into_iter()
                .map(|id| *id)
                // Include "users" from anti-dependencies.
                .chain(
                    antideps_users
                        .get(&id)
                        .unwrap_or(&vec![])
                        .into_iter()
                        .map(|id| *id),
                )
            {
                if let Node::Phi { control, data } = &function.nodes[user.idx()] {
                    // For phis, we need to dominate the block jumping to the
                    // phi in the slot that corresponds to our use.
                    for (control, data) in
                        zip(get_uses(&function.nodes[control.idx()]).as_ref(), data)
                    {
                        if id == *data {
                            update_lca(*control);
                        }
                    }
                } else if let Node::Reduce {
                    control,
                    init,
                    reduct,
                } = &function.nodes[user.idx()]
                {
                    // For reduces, we need to either dominate the block right
                    // before the fork if we're the init input, or we need to
                    // dominate the join if we're the reduct input.
                    if id == *init {
                        let before_fork = function.nodes[join_fork_map[control].idx()]
                            .try_fork()
                            .unwrap()
                            .0;
                        update_lca(before_fork);
                    } else {
                        assert_eq!(id, *reduct);
                        update_lca(*control);
                    }
                } else {
                    // For everything else, we just need to dominate the user.
                    update_lca(bbs[user.idx()]?);
                }
            }

            Some(lca)
        };

        // Check if all users have been placed. If one of them hasn't, then add
        // this node back on to the worklist.
        let Some(lca) = calculate_lca() else {
            worklist.push_back(id);
            continue;
        };

        // Look between the LCA and the schedule early location to place the
        // node. If a data node can't be scheduled to any control nodes in its
        // partition (this may happen if all of the control nodes in a
        // partition got deleted, for example), then the data node can be
        // scheduled into a control node in a different partition.
        let schedule_early = schedule_early[id.idx()].unwrap();
        let need_to_repartition = !dom
            .chain(lca.unwrap_or(schedule_early), schedule_early)
            .any(|dominator| {
                partial_partition[id.idx()].is_none()
                    || partial_partition[dominator.idx()] == partial_partition[id.idx()]
            });
        if need_to_repartition {
            partial_partition[id.idx()] = None;
        }
        let mut chain = dom
            // If the node has no users, then it doesn't really matter where we
            // place it - just place it at the early placement.
            .chain(lca.unwrap_or(schedule_early), schedule_early)
            // Filter the dominator chain by control nodes in the same
            // partition as this data node, if the data node is in a partition
            // already.
            .filter(|dominator| {
                partial_partition[id.idx()].is_none()
                    || partial_partition[dominator.idx()] == partial_partition[id.idx()]
            });

        let mut location = chain.next().unwrap();
        while let Some(control_node) = chain.next() {
            // If the next node further up the dominator tree is in a shallower
            // loop nest or if we can get out of a reduce loop when we don't
            // need to be in one, place this data node in a higher-up location.
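            //
            // As a hedged sketch with hypothetical blocks: for a chain
            // J -> H -> E, where J is a join at nest depth 2, H sits at depth
            // 1, and E at depth 0, the node first moves off of J because J is
            // a join, then from H to E because E's nest is shallower.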
            let shallower_nest = if let (Some(old_nest), Some(new_nest)) =
                (loops.nesting(location), loops.nesting(control_node))
            {
                old_nest > new_nest
            } else {
                false
            };
            // This will move all nodes that don't need to be in reduce loops
            // outside of reduce loops. Nodes that do need to be in a reduce
            // loop use the reduce node forming the loop, so the dominator
            // chain will consist of one block, and this loop won't ever
            // iterate.
            let currently_at_join = function.nodes[location.idx()].is_join();

            if shallower_nest || currently_at_join {
                location = control_node;
            }
        }

        bbs[id.idx()] = Some(location);
    }

    bbs.into_iter().map(Option::unwrap).collect()
}

/*
 * Top level function for creating a fork-join map. The map is from fork node
 * ID to join node ID, since a join can easily determine the fork it
 * corresponds to (that's the mechanism used to implement this analysis). This
 * analysis depends on type information.
 */
pub fn fork_join_map(function: &Function, control: &Subgraph) -> HashMap<NodeID, NodeID> {
    let mut fork_join_map = HashMap::new();
    for idx in 0..function.nodes.len() {
        // We only care about join nodes.
        if function.nodes[idx].is_join() {
            // Iterate the control predecessors until finding a fork. Maintain
            // a counter of unmatched fork-join pairs seen on the way, since
            // fork-joins may be nested. Every join is dominated by its fork,
            // so just iterate the first unseen predecessor of each control
            // node. (A standalone sketch of this scan over a toy graph appears
            // at the end of this file.)
            let join_id = NodeID::new(idx);
            let mut unpaired = 0;
            let mut cursor = join_id;
            let mut seen = HashSet::<NodeID>::new();
            let fork_id = loop {
                cursor = control
                    .preds(cursor)
                    .find(|pred| !seen.contains(pred))
                    .unwrap();
                seen.insert(cursor);
                if function.nodes[cursor.idx()].is_join() {
                    unpaired += 1;
                } else if function.nodes[cursor.idx()].is_fork() && unpaired > 0 {
                    unpaired -= 1;
                } else if function.nodes[cursor.idx()].is_fork() {
                    break cursor;
                }
            };
            fork_join_map.insert(fork_id, join_id);
        }
    }
    fork_join_map
}

/*
 * Find the fork/join nests that each control node is inside of. The result is
 * a map from each control node to a list of fork nodes. The fork nodes are
 * listed in ascending order of nesting.
 */
pub fn compute_fork_join_nesting(
    function: &Function,
    dom: &DomTree,
    fork_join_map: &HashMap<NodeID, NodeID>,
) -> HashMap<NodeID, Vec<NodeID>> {
    // For each control node, ascend the dominator tree, looking for fork
    // nodes. For each fork node, make sure the control node isn't strictly
    // dominated by the corresponding join node.
    (0..function.nodes.len())
        .map(NodeID::new)
        .filter(|id| dom.contains(*id))
        .map(|id| {
            (
                id,
                dom.ascend(id)
                    // Filter for forks that dominate this control node,
                    .filter(|id| function.nodes[id.idx()].is_fork())
                    // where the corresponding join doesn't dominate the
                    // control node (if it does, then this control node is
                    // after the fork-join).
                    .filter(|fork_id| !dom.does_prop_dom(fork_join_map[fork_id], id))
                    .collect(),
            )
        })
        .collect()
}
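
// A minimal, self-contained sketch (test-only) of the unmatched-counter scan
// used by `fork_join_map` above. The `Kind` enum, the predecessor table, and
// the node indices below are illustrative assumptions, not this IR's types.
#[cfg(test)]
mod fork_join_counter_sketch {
    #[derive(Clone, Copy)]
    enum Kind {
        Fork,
        Join,
        Other,
    }

    // Walk backwards from a join along the first unseen predecessor, counting
    // unmatched joins seen on the way, until the matching fork is found.
    fn matching_fork(kinds: &[Kind], preds: &[Vec<usize>], join: usize) -> usize {
        let mut seen = vec![false; kinds.len()];
        let mut unpaired = 0;
        let mut cursor = join;
        loop {
            cursor = *preds[cursor].iter().find(|pred| !seen[**pred]).unwrap();
            seen[cursor] = true;
            match kinds[cursor] {
                // A nested join opens an unmatched pair.
                Kind::Join => unpaired += 1,
                // Its fork closes that pair again.
                Kind::Fork if unpaired > 0 => unpaired -= 1,
                // A fork with no unmatched pairs left matches the join we
                // started from.
                Kind::Fork => return cursor,
                Kind::Other => {}
            }
        }
    }

    #[test]
    fn nested_forks_match() {
        // 0: outer fork, 1: inner fork, 2: inner join, 3: outer join.
        let kinds = [Kind::Fork, Kind::Fork, Kind::Join, Kind::Join];
        let preds = vec![vec![], vec![0], vec![1], vec![2]];
        // The outer join skips over the nested inner pair and finds node 0.
        assert_eq!(matching_fork(&kinds, &preds, 3), 0);
        assert_eq!(matching_fork(&kinds, &preds, 2), 1);
    }
}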