fork_join_analysis.rs 7.85 KiB
use std::collections::{HashMap, HashSet};
use crate::*;
/*
 * Top level function for creating a fork-join map. The map is keyed by fork
 * node ID and yields the matching join node ID. The analysis starts from the
 * joins because a join can locate its fork by walking control predecessors
 * backwards. This analysis depends on type information.
 *
 * Panics if the backward walk from a join runs out of unvisited control
 * predecessors before reaching its fork, which indicates malformed control
 * flow.
 */
pub fn fork_join_map(function: &Function, control: &Subgraph) -> HashMap<NodeID, NodeID> {
    let mut fork_join_map = HashMap::new();
    for (idx, node) in function.nodes.iter().enumerate() {
        // We only care about join nodes.
        if !node.is_join() {
            continue;
        }

        // Iterate the control predecessors until finding a fork. Keep a count
        // of inner joins whose forks have not yet been passed, since fork-joins
        // may be nested. Every join is dominated by its fork, so following the
        // first unseen predecessor of each control node eventually reaches it.
        let join_id = NodeID::new(idx);
        let mut unpaired: usize = 0;
        let mut cursor = join_id;
        let mut seen = HashSet::<NodeID>::new();
        let fork_id = loop {
            cursor = control
                .preds(cursor)
                .find(|pred| !seen.contains(pred))
                .expect("fork-join analysis: ran out of unvisited control predecessors before finding the fork of a join");
            seen.insert(cursor);

            let pred_node = &function.nodes[cursor.idx()];
            if pred_node.is_join() {
                // Entering an inner fork-join from its end.
                unpaired += 1;
            } else if pred_node.is_fork() {
                // An unmatched fork belongs to this join; otherwise it closes
                // one of the inner fork-joins passed on the way up.
                if unpaired == 0 {
                    break cursor;
                }
                unpaired -= 1;
            }
        };
        fork_join_map.insert(fork_id, join_id);
    }
    fork_join_map
}
/*
 * Find the fork/join nests that each control node is inside of. The result
 * maps every node contained in the dominator tree to a list of fork nodes.
 * The forks are listed in the order met while ascending the dominator tree
 * from that node, i.e. ascending order of nesting.
 */
pub fn compute_fork_join_nesting(
    function: &Function,
    dom: &DomTree,
    fork_join_map: &HashMap<NodeID, NodeID>,
) -> HashMap<NodeID, Vec<NodeID>> {
    let mut nesting = HashMap::new();
    for control in (0..function.nodes.len()).map(NodeID::new) {
        // Only nodes present in the dominator tree get an entry.
        if !dom.contains(control) {
            continue;
        }

        let mut enclosing_forks = vec![];
        for ancestor in dom.ascend(control) {
            if !function.nodes[ancestor.idx()].is_fork() {
                continue;
            }
            // A dominating fork encloses this node unless its join strictly
            // dominates it, in which case the node comes after the fork-join.
            // Strict dominance is used so that the join itself still counts
            // as nested inside its fork.
            let matching_join = fork_join_map[&ancestor];
            if !dom.does_prop_dom(matching_join, control) {
                enclosing_forks.push(ancestor);
            }
        }
        nesting.insert(control, enclosing_forks);
    }
    nesting
}
/*
 * Top level function to calculate reduce cycles. For each reduce node, returns
 * the set of other nodes that form a cycle with that reduce.
 *
 * A node is in the reduce cycle when both of these hold:
 *   (a) it is reachable from the reduce's `reduct` input by iterating uses;
 *   (b) it is reachable from the reduce by iterating users.
 * Neither traversal expands through:
 *   - the reduce itself;
 *   - control nodes whose nest does not contain the reduce's fork;
 *   - phis or reduces whose control node is outside the fork-join;
 *   - parameters, constants, dynamic constants, or undefs.
 * The reduce itself is never part of its own cycle.
 *
 * Panics if a reduce's join has no entry in `fork_join_map`. Also panics if a
 * visited phi or reduce has a control node with no entry in `fork_join_nest`.
 */
pub fn reduce_cycles(
    function: &Function,
    def_use: &ImmutableDefUseMap,
    fork_join_map: &HashMap<NodeID, NodeID>,
    fork_join_nest: &HashMap<NodeID, Vec<NodeID>>,
) -> HashMap<NodeID, HashSet<NodeID>> {
    let reduces = (0..function.nodes.len())
        .filter(|idx| function.nodes[*idx].is_reduce())
        .map(NodeID::new);
    let mut result = HashMap::new();

    // Invert the fork-join map so that each reduce's join can find its fork.
    let join_fork_map: HashMap<NodeID, NodeID> = fork_join_map
        .iter()
        .map(|(fork, join)| (*join, *fork))
        .collect();

    for reduce in reduces {
        let (join, _, reduct) = function.nodes[reduce.idx()].try_reduce().unwrap();
        let fork = join_fork_map[&join];

        // Whether a control node belongs to this reduce's fork-join.
        let control_in_fork_join = |control: NodeID| fork_join_nest[&control].contains(&fork);
        // Conservative filter: `false` means the node is definitely outside
        // the fork-join, or cannot be part of a reduce cycle.
        let might_be_in_fork_join = |id: NodeID| {
            let node = &function.nodes[id.idx()];
            fork_join_nest
                .get(&id)
                .map_or(true, |nest| nest.contains(&fork))
                && node
                    .try_phi()
                    .map_or(true, |(control, _)| control_in_fork_join(control))
                && node
                    .try_reduce()
                    .map_or(true, |(control, _, _)| control_in_fork_join(control))
                && !node.is_parameter()
                && !node.is_constant()
                && !node.is_dynamic_constant()
                && !node.is_undef()
        };

        // Find nodes in the fork-join that the reduct reaches through uses.
        // The reduct is seeded unconditionally. If it fails the filter, the
        // users traversal below never reaches it, so the intersection drops it.
        let mut reachable_uses = HashSet::new();
        let mut workset = vec![];
        reachable_uses.insert(reduct);
        workset.push(reduct);
        while let Some(pop) = workset.pop() {
            for u in get_uses(&function.nodes[pop.idx()]).as_ref() {
                if !reachable_uses.contains(u) && might_be_in_fork_join(*u) && *u != reduce {
                    reachable_uses.insert(*u);
                    workset.push(*u);
                }
            }
        }

        // Find nodes in the fork-join that the reduce reaches through users.
        let mut reachable_users = HashSet::new();
        workset.clear();
        reachable_users.insert(reduce);
        workset.push(reduce);
        while let Some(pop) = workset.pop() {
            for u in def_use.get_users(pop) {
                if !reachable_users.contains(u) && might_be_in_fork_join(*u) && *u != reduce {
                    reachable_users.insert(*u);
                    workset.push(*u);
                }
            }
        }

        // The reduce cycle is the intersection of the nodes reachable through
        // uses and through users. The reduce itself is never in
        // `reachable_uses`, so it is excluded here.
        let intersection = reachable_uses
            .intersection(&reachable_users)
            .copied()
            .collect();
        result.insert(reduce, intersection);
    }
    result
}
/*
 * Top level function to calculate which nodes are "inside" each fork-join.
 * Returns a map from each fork node ID to the set of nodes inside its
 * fork-join. There is also an entry for the start node (ID 0), whose set
 * holds every node in the function.
 *
 * A fork-join's set contains:
 *   - the fork itself;
 *   - every node reachable from the fork by following users;
 *   - the join and any reduces controlled by the join. These are recorded,
 *     but the walk does not continue past them;
 *   - every node in the reduce cycles of the reduces that use the join.
 *
 * Panics if the walk from a fork never reaches its join.
 */
pub fn nodes_in_fork_joins(
    function: &Function,
    def_use: &ImmutableDefUseMap,
    fork_join_map: &HashMap<NodeID, NodeID>,
    reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
) -> HashMap<NodeID, HashSet<NodeID>> {
    let mut result = HashMap::new();
    for (fork, join) in fork_join_map.iter() {
        let mut set = HashSet::new();
        set.insert(*fork);

        // Depth-first walk over users, starting at the fork.
        let mut pending = vec![*fork];
        while let Some(current) = pending.pop() {
            for user in def_use.get_users(current) {
                // The join and the reduces hanging off it mark the end of the
                // fork-join. They are recorded, but not expanded further.
                let is_boundary = *user == *join
                    || function.nodes[user.idx()]
                        .try_reduce()
                        .map_or(false, |(control, _, _)| control == *join);
                // `insert` reports whether the node is new; only new,
                // non-boundary nodes are expanded.
                let first_visit = set.insert(*user);
                if first_visit && !is_boundary {
                    pending.push(*user);
                }
            }
        }
        assert!(set.contains(join));

        // Reduce-cycle nodes need not depend on the fork's thread IDs, so the
        // walk above can miss them. Pull them in from each reduce on the join.
        for user in def_use.get_users(*join) {
            if let Some(cycle) = reduce_cycles.get(user) {
                set.extend(cycle);
            }
        }
        result.insert(*fork, set);
    }

    // The start node encloses the entire function.
    let all_nodes: HashSet<NodeID> = (0..function.nodes.len()).map(NodeID::new).collect();
    result.insert(NodeID::new(0), all_nodes);
    result
}