fork_join_analysis.rs 6.74 KiB
use std::collections::{HashMap, HashSet};
use crate::*;
/*
 * Top level function for creating a fork-join map. The map is from fork node
 * ID to join node ID. It is built starting from each join: a join can
 * determine its fork by walking backwards through control predecessors,
 * which is the mechanism used to implement this analysis.
 *
 * Nested fork-joins are handled with an "unpaired" counter. Every join
 * passed on the way up increments it, and every fork passed while it is
 * positive decrements it. The first fork reached while the counter is zero
 * is the fork of the starting join. A `seen` set makes each step take a
 * predecessor that has not already been visited by this walk.
 *
 * NOTE(review): the original header said this analysis depends on type
 * information. Nothing in this function reads types, so confirm whether
 * that still holds.
 *
 * Panics if a walk runs out of unvisited control predecessors before it
 * finds the join's fork.
 */
pub fn fork_join_map(function: &Function, control: &Subgraph) -> HashMap<NodeID, NodeID> {
    let mut fork_join_map = HashMap::new();
    for idx in 0..function.nodes.len() {
        // Only joins start a walk; each join's fork is discovered from it.
        if !function.nodes[idx].is_join() {
            continue;
        }
        let join_id = NodeID::new(idx);
        let mut unpaired = 0;
        let mut cursor = join_id;
        let mut seen = HashSet::<NodeID>::new();
        let fork_id = loop {
            // Every join is dominated by its fork, so following the first
            // unseen predecessor of each control node eventually reaches it.
            cursor = control
                .preds(cursor)
                .find(|pred| !seen.contains(pred))
                .expect("ran out of unvisited control predecessors before finding the join's fork");
            seen.insert(cursor);
            let cursor_node = &function.nodes[cursor.idx()];
            if cursor_node.is_join() {
                // Entering a nested fork-join from its end.
                unpaired += 1;
            } else if cursor_node.is_fork() {
                // `unpaired` is never negative, so zero means this fork is
                // not closing a nested fork-join and must be ours.
                if unpaired == 0 {
                    break cursor;
                }
                unpaired -= 1;
            }
        };
        fork_join_map.insert(fork_id, join_id);
    }
    fork_join_map
}
/*
 * Compute, for every node contained in the dominator tree `dom`, the list of
 * fork nodes whose fork-joins enclose it. A fork encloses a node when:
 * - it is reached by ascending the dominator tree from the node, and
 * - its corresponding join does NOT properly dominate the node. If the join
 *   did, the node would come after the fork-join rather than inside it.
 *
 * Each list is in the order `dom.ascend` produces, i.e. walking up the
 * dominator tree starting from the node.
 *
 * Panics if a dominating fork has no entry in `fork_join_map`.
 */
pub fn compute_fork_join_nesting(
    function: &Function,
    dom: &DomTree,
    fork_join_map: &HashMap<NodeID, NodeID>,
) -> HashMap<NodeID, Vec<NodeID>> {
    let mut nesting = HashMap::new();
    for raw_idx in 0..function.nodes.len() {
        let control_id = NodeID::new(raw_idx);
        // Only nodes present in the dominator tree get an entry.
        if !dom.contains(control_id) {
            continue;
        }
        let enclosing_forks: Vec<NodeID> = dom
            .ascend(control_id)
            .filter(|fork_id| {
                // A dominating fork encloses this node unless its join
                // properly dominates the node.
                function.nodes[fork_id.idx()].is_fork()
                    && !dom.does_prop_dom(fork_join_map[&fork_id], control_id)
            })
            .collect();
        nesting.insert(control_id, enclosing_forks);
    }
    nesting
}
/*
 * Top level function to calculate reduce cycles. Returns, for each reduce
 * node, the set of other nodes that form a cycle with that reduce node. The
 * reduce itself is not included in its set.
 *
 * The cycle is found by a DFS that starts at the reduce's `reduct` input and
 * transitively follows uses (node inputs). A node belongs to the cycle if
 * some such path from it leads back to the reduce. The search never steps
 * into control nodes. It also does not go through phis or reduces whose
 * control node lies outside the reduce's fork-join.
 *
 * Panics if a reduce's join has no entry in `fork_join_map`.
 */
pub fn reduce_cycles(
    function: &Function,
    fork_join_map: &HashMap<NodeID, NodeID>,
    fork_join_nest: &HashMap<NodeID, Vec<NodeID>>,
) -> HashMap<NodeID, HashSet<NodeID>> {
    // Invert fork -> join so each reduce can find its fork through its join.
    let join_fork_map: HashMap<NodeID, NodeID> = fork_join_map
        .iter()
        .map(|(fork, join)| (*join, *fork))
        .collect();
    let reduces = (0..function.nodes.len())
        .filter(|idx| function.nodes[*idx].is_reduce())
        .map(NodeID::new);
    let mut result = HashMap::new();
    for reduce in reduces {
        let (join, _, reduct) = function.nodes[reduce.idx()]
            .try_reduce()
            .expect("node filtered by is_reduce must be a reduce");
        let fork = join_fork_map[&join];
        // DFS the uses of `reduct` until finding the reduce itself.
        let mut current_visited = HashSet::new();
        let mut in_cycle = HashSet::new();
        reduce_cycle_dfs_helper(
            function,
            reduct,
            fork,
            reduce,
            &mut current_visited,
            &mut in_cycle,
            fork_join_nest,
        );
        result.insert(reduce, in_cycle);
    }
    result
}
/*
 * Recursive DFS helper for `reduce_cycles`. Starting at `iter`, it follows
 * uses (node inputs, via `get_uses`) looking for `reduce`.
 *
 * Returns true if any of these holds:
 * - `iter` is `reduce`;
 * - `iter` is already in `in_cycle`;
 * - some use path from `iter` reaches `reduce`.
 * Every node for which this holds is added to `in_cycle`, except `reduce`
 * itself.
 *
 * Parameters:
 * - `fork`: the fork of `reduce`'s fork-join. Phi and reduce nodes are only
 *   entered if `fork` is in the fork-join nest of their control node.
 * - `current_visited`: nodes on the current DFS path. It stops the search
 *   from following a use back onto the path. `iter` is inserted on entry and
 *   removed before returning, so the set is restored on exit.
 * - `in_cycle`: the accumulated result set, shared across the whole search
 *   for one reduce.
 * - `fork_join_nest`: indexed by control node to get a list of forks.
 *   Presumably this is the output of `compute_fork_join_nesting`; confirm
 *   against callers.
 *
 * Control nodes are never entered. Nodes that fail to reach `reduce` are not
 * memoized, so they may be explored again along a different path.
 *
 * Panics if a phi's or reduce's control node has no entry in
 * `fork_join_nest`.
 */
fn reduce_cycle_dfs_helper(
    function: &Function,
    iter: NodeID,
    fork: NodeID,
    reduce: NodeID,
    current_visited: &mut HashSet<NodeID>,
    in_cycle: &mut HashSet<NodeID>,
    fork_join_nest: &HashMap<NodeID, Vec<NodeID>>,
) -> bool {
    // Returns true for nodes the search may enter. Nodes that are neither a
    // phi nor a reduce always pass. A phi or reduce passes only when `fork`
    // is listed in the nest of its control node.
    let isnt_outside_fork_join = |id: NodeID| {
        let node = &function.nodes[id.idx()];
        node.try_phi()
            .map(|(control, _)| control)
            .or(node.try_reduce().map(|(control, _, _)| control))
            .map(|control| fork_join_nest[&control].contains(&fork))
            .unwrap_or(true)
    };
    // Base cases: either we reached the target, or we reached a node already
    // proven to lie on a path to it.
    if iter == reduce || in_cycle.contains(&iter) {
        return true;
    }
    current_visited.insert(iter);
    let mut found_reduce = false;
    // This doesn't short circuit on purpose: each eligible use is explored
    // even after an earlier one reached `reduce`, so the recursive calls
    // record every cycle node in `in_cycle`. The `&&` chain still skips the
    // recursion for uses on the current path, control nodes, and
    // phis/reduces outside the fork-join.
    for u in get_uses(&function.nodes[iter.idx()]).as_ref() {
        found_reduce |= !current_visited.contains(u)
            && !function.nodes[u.idx()].is_control()
            && isnt_outside_fork_join(*u)
            && reduce_cycle_dfs_helper(
                function,
                *u,
                fork,
                reduce,
                current_visited,
                in_cycle,
                fork_join_nest,
            )
    }
    if found_reduce {
        in_cycle.insert(iter);
    }
    // Pop `iter` off the current path so other branches may pass through it.
    current_visited.remove(&iter);
    found_reduce
}
/*
 * Top level function to calculate which data nodes are "inside" each
 * fork-join, not including the fork-join's own reduces. Returns a map from
 * each fork node in `fork_join_map` to that set of nodes.
 *
 * The set is built by a worklist traversal over users, starting at the fork.
 * Every user that is not skipped is added to the set, and its own users are
 * explored in turn. Two kinds of user are skipped, and the traversal does
 * not continue past them:
 * - control nodes;
 * - reduces whose control input is this fork-join's join.
 */
pub fn data_nodes_in_fork_joins(
    function: &Function,
    def_use: &ImmutableDefUseMap,
    fork_join_map: &HashMap<NodeID, NodeID>,
) -> HashMap<NodeID, HashSet<NodeID>> {
    let mut result = HashMap::new();
    for (fork, join) in fork_join_map {
        let mut worklist = vec![*fork];
        let mut set = HashSet::new();
        while let Some(item) = worklist.pop() {
            for u in def_use.get_users(item) {
                let user = &function.nodes[u.idx()];
                // Ignore control users and reduces of this fork-join.
                if user.is_control()
                    || matches!(user.try_reduce(), Some((control, _, _)) if control == *join)
                {
                    continue;
                }
                // `insert` returns false if the node was already present. This
                // queues each node at most once with a single hash lookup.
                if set.insert(*u) {
                    worklist.push(*u);
                }
            }
        }
        result.insert(*fork, set);
    }
    result
}