use std::collections::{HashMap, HashSet, VecDeque};
use std::iter::{zip, FromIterator};

use crate::*;

/*
* Top level global code motion function. Assigns each data node to one of its
* immediate control use / user nodes, forming (unordered) basic blocks. Returns
* the control node / basic block each node is in. Takes in a partial
* partitioning that must be respected. Based on the schedule-early-schedule-
* late method from Cliff Click's PhD thesis.
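 *
 * The algorithm runs in three steps: first, pin the nodes whose block is
 * forced (phis, thread IDs, reduces, calls, parameters, constants, and
 * control nodes); second, schedule every remaining node as early as
 * possible; third, schedule each node as late as possible and pick a final
 * block between the two bounds using loop nesting heuristics.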
*/
pub fn gcm(
function: &Function,
def_use: &ImmutableDefUseMap,
reverse_postorder: &Vec<NodeID>,
dom: &DomTree,
antideps: &Vec<(NodeID, NodeID)>,
loops: &LoopTree,
fork_join_map: &HashMap<NodeID, NodeID>,
partial_partition: &mut Vec<Option<PartitionID>>,
) -> Vec<NodeID> {
let mut bbs: Vec<Option<NodeID>> = vec![None; function.nodes.len()];

    // Step 1: assign the basic block locations of all nodes that must be in a
// specific block. This includes control nodes as well as some special data
// nodes, such as phis.
for idx in 0..function.nodes.len() {
match function.nodes[idx] {
Node::Phi { control, data: _ } => bbs[idx] = Some(control),
Node::ThreadID {
control,
dimension: _,
} => bbs[idx] = Some(control),
Node::Reduce {
control,
init: _,
reduct: _,
} => bbs[idx] = Some(control),
Node::Call {
control,
function: _,
dynamic_constants: _,
args: _,
} => bbs[idx] = Some(control),
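            // Parameters and constants don't depend on any control node, so
            // pin them to the start node (node 0).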
Node::Parameter { index: _ } => bbs[idx] = Some(NodeID::new(0)),
Node::Constant { id: _ } => bbs[idx] = Some(NodeID::new(0)),
Node::DynamicConstant { id: _ } => bbs[idx] = Some(NodeID::new(0)),
_ if function.nodes[idx].is_control() => bbs[idx] = Some(NodeID::new(idx)),
_ => {}
}
}

    // Step 2: schedule early. Place nodes in the earliest position they could
    // go; use a worklist to iterate over the nodes.
let mut schedule_early = bbs.clone();
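    // Invert the anti-dependence edges: each write gains its corresponding
    // reads as extra uses, so a write is scheduled no earlier than the reads
    // it must come after.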
let mut antideps_uses = HashMap::<NodeID, Vec<NodeID>>::new();
for (read, write) in antideps {
antideps_uses.entry(*write).or_default().push(*read);
}
let mut worklist = VecDeque::from(reverse_postorder.clone());
while let Some(id) = worklist.pop_front() {
if schedule_early[id.idx()].is_some() {
continue;
}
        // For every use, check which block is its "schedule early" block. This
        // node goes in the lowest of those blocks in the dominator tree.
let use_places: Option<Vec<NodeID>> = get_uses(&function.nodes[id.idx()])
.as_ref()
.into_iter()
.map(|id| *id)
// Include "uses" from anti-dependencies.
.chain(
antideps_uses
.get(&id)
.unwrap_or(&vec![])
.into_iter()
.map(|id| *id),
)
.map(|id| schedule_early[id.idx()])
.collect();
if let Some(use_places) = use_places {
            // If every use has been placed, we can place this node at the
            // lowest amongst the use places in the dominator tree - every
            // use's block must dominate this node's block.
let lowest = dom.lowest_amongst(use_places.into_iter());
schedule_early[id.idx()] = Some(lowest);
} else {
// If not, then just push this node back on the worklist.
worklist.push_back(id);
}
}

    // Step 3: schedule late and pick each node's final position. Since the
    // late schedule of each node depends on the final positions of its users,
    // these two steps must be fused. Compute each node's latest position, then
    // use the control dependent + shallow loop heuristic to actually place it.
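    // Invert the fork-join map so that the fork corresponding to a reduce's
    // join can be looked up.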
let join_fork_map: HashMap<NodeID, NodeID> = fork_join_map
.into_iter()
.map(|(fork, join)| (*join, *fork))
.collect();
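    // Invert the anti-dependence edges in the other direction: each read gains
    // its corresponding writes as extra users, so a read is scheduled no later
    // than the writes that must come after it.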
let mut antideps_users = HashMap::<NodeID, Vec<NodeID>>::new();
for (read, write) in antideps {
antideps_users.entry(*read).or_default().push(*write);
}
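    // Iterate in postorder (reverse postorder, reversed) so that a node's
    // users tend to be visited before the node itself.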
let mut worklist = VecDeque::from_iter(reverse_postorder.into_iter().map(|id| *id).rev());
while let Some(id) = worklist.pop_front() {
if bbs[id.idx()].is_some() {
continue;
}
// Calculate the least common ancestor of user blocks, a.k.a. the "late"
// schedule.
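        // The outer Option is None if some user hasn't been placed yet; the
        // inner Option is None only when this node has no users at all.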
let calculate_lca = || -> Option<_> {
let mut lca = None;
// Helper to incrementally update the LCA.
let mut update_lca = |a| {
if let Some(acc) = lca {
lca = Some(dom.least_common_ancestor(acc, a));
} else {
lca = Some(a);
}
};
            // For every user, consider where we need to be in order to
            // dominate the user.
for user in def_use
.get_users(id)
.as_ref()
.into_iter()
.map(|id| *id)
// Include "users" from anti-dependencies.
.chain(
antideps_users
.get(&id)
.unwrap_or(&vec![])
.into_iter()
.map(|id| *id),
)
{
if let Node::Phi { control, data } = &function.nodes[user.idx()] {
// For phis, we need to dominate the block jumping to the phi in
// the slot that corresponds to our use.
for (control, data) in
zip(get_uses(&function.nodes[control.idx()]).as_ref(), data)
{
if id == *data {
update_lca(*control);
}
}
} else if let Node::Reduce {
control,
init,
reduct,
} = &function.nodes[user.idx()]
{
                    // For reduces, we need to dominate the block right before
                    // the fork if we're the init input, or dominate the join
                    // if we're the reduct input.
if id == *init {
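                        // `control` is this reduce's join; map it back to its
                        // fork, whose first input is the control node
                        // immediately before the fork.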
let before_fork = function.nodes[join_fork_map[control].idx()]
.try_fork()
.unwrap()
.0;
update_lca(before_fork);
} else {
assert_eq!(id, *reduct);
update_lca(*control);
}
} else {
// For everything else, we just need to dominate the user.
update_lca(bbs[user.idx()]?);
}
}
Some(lca)
};
        // Check if all users have been placed. If one of them hasn't been,
        // then add this node back onto the worklist.
let Some(lca) = calculate_lca() else {
worklist.push_back(id);
continue;
};
        // Look between the LCA and the schedule early location to place the
        // node. If a data node can't be scheduled to any control node in its
        // partition (this may happen if all of the control nodes in a
        // partition got deleted, for example), then the data node can be
        // scheduled into a control node in a different partition.
let schedule_early = schedule_early[id.idx()].unwrap();
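        // First check whether any block in the dominator chain lies in this
        // node's current partition; if none does, drop the node's partition
        // assignment so it can be placed anywhere.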
let need_to_repartition = !dom
.chain(lca.unwrap_or(schedule_early), schedule_early)
.any(|dominator| {
partial_partition[id.idx()].is_none()
|| partial_partition[dominator.idx()] == partial_partition[id.idx()]
});
if need_to_repartition {
partial_partition[id.idx()] = None;
}
let mut chain = dom
// If the node has no users, then it doesn't really matter where we
// place it - just place it at the early placement.
.chain(lca.unwrap_or(schedule_early), schedule_early)
// Filter the dominator chain by control nodes in the same partition
// as this data node, if the data node is in a partition already.
.filter(|dominator| {
partial_partition[id.idx()].is_none()
|| partial_partition[dominator.idx()] == partial_partition[id.idx()]
});
let mut location = chain.next().unwrap();
while let Some(control_node) = chain.next() {
// If the next node further up the dominator tree is in a shallower
// loop nest or if we can get out of a reduce loop when we don't
// need to be in one, place this data node in a higher-up location.
let shallower_nest = if let (Some(old_nest), Some(new_nest)) =
(loops.nesting(location), loops.nesting(control_node))
{
old_nest > new_nest
} else {
false
};
// This will move all nodes that don't need to be in reduce loops
// outside of reduce loops. Nodes that do need to be in a reduce
// loop use the reduce node forming the loop, so the dominator chain
// will consist of one block, and this loop won't ever iterate.
let currently_at_join = function.nodes[location.idx()].is_join();
if shallower_nest || currently_at_join {
location = control_node;
}
}
bbs[id.idx()] = Some(location);
}
bbs.into_iter().map(Option::unwrap).collect()
}

/*
* Top level function for creating a fork-join map. Map is from fork node ID to
* join node ID, since a join can easily determine the fork it corresponds to
* (that's the mechanism used to implement this analysis). This analysis depends
* on type information.
*/
pub fn fork_join_map(function: &Function, control: &Subgraph) -> HashMap<NodeID, NodeID> {
let mut fork_join_map = HashMap::new();
for idx in 0..function.nodes.len() {
// We only care about join nodes.
if function.nodes[idx].is_join() {
            // Iterate the control predecessors until finding a fork. Maintain
            // a counter of unmatched fork-join pairs seen on the way, since
            // fork-joins may be nested. Every join is dominated by its fork,
            // so just iterate the first unseen predecessor of each control
            // node.
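            // For example, walking up from the outer join of the nest
            // fork1 -> fork2 -> join2 -> join1 first visits join2 (unpaired
            // becomes 1), then fork2 (unpaired drops back to 0), and finally
            // breaks at fork1, correctly skipping the inner pair.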
let join_id = NodeID::new(idx);
let mut unpaired = 0;
let mut cursor = join_id;
let mut seen = HashSet::<NodeID>::new();
let fork_id = loop {
                cursor = control
                    .preds(cursor)
                    .find(|pred| !seen.contains(pred))
                    .unwrap();
seen.insert(cursor);
if function.nodes[cursor.idx()].is_join() {
unpaired += 1;
} else if function.nodes[cursor.idx()].is_fork() && unpaired > 0 {
unpaired -= 1;
} else if function.nodes[cursor.idx()].is_fork() {
break cursor;
}
};
fork_join_map.insert(fork_id, join_id);
}
}
fork_join_map
}

/*
* Find fork/join nests that each control node is inside of. Result is a map
* from each control node to a list of fork nodes. The fork nodes are listed in
* ascending order of nesting.
*/
pub fn compute_fork_join_nesting(
function: &Function,
dom: &DomTree,
fork_join_map: &HashMap<NodeID, NodeID>,
) -> HashMap<NodeID, Vec<NodeID>> {
    // For each control node, ascend the dominator tree, looking for fork
    // nodes. For each such fork, check that the control node isn't strictly
    // dominated by the fork's corresponding join node.
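    // The forks are collected in the order the dominator tree is ascended,
    // from each node towards the root.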
(0..function.nodes.len())
.map(NodeID::new)
.filter(|id| dom.contains(*id))
.map(|id| {
(
id,
dom.ascend(id)
                    // Filter for forks that dominate this control node,
                    .filter(|id| function.nodes[id.idx()].is_fork())
                    // where the fork's corresponding join doesn't strictly
                    // dominate the control node (if it did, this control node
                    // would be after the fork-join, not inside it).
                    .filter(|fork_id| !dom.does_prop_dom(fork_join_map[&fork_id], id))
.collect(),
)
})
.collect()
}