-
Russel Arbore authoredRussel Arbore authored
loops.rs 11.79 KiB
use std::collections::hash_map;
use std::collections::VecDeque;
use std::collections::{HashMap, HashSet};
use bitvec::prelude::*;
use crate::*;
/// Loop tree for a subgraph of the IR.
///
/// Every node of the tree is either a natural loop or a fork-join pair and
/// covers some subset of the subgraph's control nodes. The root stands for the
/// whole subgraph; its children are the top-level loops and fork-join pairs,
/// their children are the loops nested directly inside them, and so on.
///
/// Each tree node is identified by a representative "header" control node:
/// - for a natural loop, the destination of its back edge(s), i.e. a node that
///   dominates the source of the back edge;
/// - for a fork-join pair, the fork.
///
/// A loop is top-level when its parent is the root.
#[derive(Debug, Clone)]
pub struct LoopTree {
/// Header standing for the entire subgraph; parent of every top-level loop.
root: NodeID,
/// Loop header -> (bitset of the control nodes inside the loop, indexed by
/// `NodeID::idx`, header of the parent loop or `root`).
loops: HashMap<NodeID, (BitVec<u8, Lsb0>, NodeID)>,
/// Control node -> header of the most deeply nested loop containing it.
/// Nodes that are in no loop have no entry.
inverse_loops: HashMap<NodeID, NodeID>,
/// Loop header -> nesting depth (top-level loops have depth 0).
nesting: HashMap<NodeID, usize>,
}
impl LoopTree {
    /// Returns `true` if `x` is the root of the tree or the header of one of
    /// its loops.
    pub fn contains(&self, x: NodeID) -> bool {
        x == self.root || self.loops.contains_key(&x)
    }

    /// Iterates over every loop as `(header, (contents, parent_header))`.
    pub fn loops(&self) -> hash_map::Iter<'_, NodeID, (BitVec<u8, Lsb0>, NodeID)> {
        self.loops.iter()
    }

    /// Iterates over the nodes inside the loop headed by `header`.
    ///
    /// # Panics
    /// Panics if `header` is not a loop header. The root is not stored as a
    /// loop, so passing the root also panics.
    pub fn nodes_in_loop(&self, header: NodeID) -> impl Iterator<Item = NodeID> + '_ {
        self.loops[&header].0.iter_ones().map(NodeID::new)
    }

    /// Returns whether `is_in` lies inside the loop headed by `header`. Every
    /// node is considered to be inside the root.
    ///
    /// # Panics
    /// Panics if `header` is neither the root nor a loop header, or if
    /// `is_in.idx()` is out of range for the loop's contents bitset.
    pub fn is_in_loop(&self, header: NodeID, is_in: NodeID) -> bool {
        header == self.root || self.loops[&header].0[is_in.idx()]
    }

    /// Returns the header of the most deeply nested loop containing
    /// `control_node`, or `None` if it is not inside any loop.
    pub fn header_of(&self, control_node: NodeID) -> Option<NodeID> {
        self.inverse_loops.get(&control_node).copied()
    }

    /// Returns every loop, paired with its contents, ordered bottom-up: a loop
    /// is only emitted once all of its child loops have been emitted.
    ///
    /// The order is assembled upfront by repeatedly taking loops whose
    /// remaining child count has dropped to zero.
    pub fn bottom_up_loops(&self) -> Vec<(NodeID, &BitVec<u8, Lsb0>)> {
        // Number of not-yet-emitted children of each loop (and of the root).
        let mut children_count: HashMap<NodeID, u32> =
            self.loops.keys().map(|header| (*header, 0)).collect();
        children_count.insert(self.root, 0);
        for (_, parent) in self.loops.values() {
            *children_count
                .get_mut(parent)
                .expect("a loop's parent is always the root or another loop") += 1;
        }

        let mut bottom_up = Vec::with_capacity(self.loops.len());
        let mut worklist: VecDeque<_> = self.loops.iter().map(|(k, v)| (*k, &v.0)).collect();
        while let Some(pop) = worklist.pop_front() {
            if children_count[&pop.0] == 0 {
                // All children emitted: emit this loop and release its parent.
                *children_count
                    .get_mut(&self.loops[&pop.0].1)
                    .expect("a loop's parent is always the root or another loop") -= 1;
                bottom_up.push(pop);
            } else {
                // Some child is still pending; revisit this loop later.
                worklist.push_back(pop);
            }
        }
        bottom_up
    }

    /// Returns the nesting depth of the loop headed by `header` (top-level
    /// loops have depth 0), or `None` if `header` is not a loop header.
    pub fn nesting(&self, header: NodeID) -> Option<usize> {
        self.nesting.get(&header).copied()
    }
}
/// Computes the loop tree of `subgraph`.
///
/// Two kinds of loop become nodes of the tree:
/// - natural loops, found from back edges (edges whose destination dominates
///   their source);
/// - fork-join pairs from `fork_join_map`, each treated as a loop headed by
///   its fork.
///
/// `root` is used as the parent of every top-level loop.
///
/// # Panics
/// In debug builds, panics if two distinct loops at the same nesting depth
/// both contain the same control node.
pub fn loops(
    subgraph: &Subgraph,
    root: NodeID,
    dom: &DomTree,
    fork_join_map: &HashMap<NodeID, NodeID>,
) -> LoopTree {
    // Step 1: collect loop back edges. Any successor dominating its
    // predecessor is the destination of a loop back edge.
    let mut loop_back_edges = vec![];
    for node in subgraph.iter() {
        for succ in subgraph.succs(*node) {
            if dom.does_dom(succ, *node) {
                loop_back_edges.push((*node, succ));
            }
        }
    }

    // Step 2: collect "edges" from joins to forks. This doesn't correspond to
    // a real edge in the graph, but the loop tree includes fork-join pairs as
    // loops, so create a phantom loop back edge for each pair.
    for (fork, join) in fork_join_map {
        loop_back_edges.push((*join, *fork));
    }

    // Step 3: find the control nodes inside each loop. For a natural loop
    // with header d and a back edge from n to d, the loop consists of d itself
    // plus every node with a path to n that does not go through d.
    let loop_contents = loop_back_edges.iter().map(|(n, d)| {
        let mut contents = loop_reachability(*n, *d, subgraph);
        contents.set(d.idx(), true);
        (d, contents)
    });

    // Step 4: merge loops that share a header into a single natural loop.
    let mut loops: HashMap<NodeID, BitVec<u8, Lsb0>> = HashMap::new();
    for (header, contents) in loop_contents {
        let merged = match loops.remove(header) {
            Some(old_contents) => old_contents | contents,
            None => contents,
        };
        loops.insert(*header, merged);
    }

    // Step 5: figure out loop tree edges. A loop with header a can only be an
    // outer loop of a loop with header b if a dominates b, so climb the
    // dominator tree from each header. The parent is the nearest dominating
    // header whose loop actually contains this header.
    let loops: HashMap<NodeID, (BitVec<u8, Lsb0>, NodeID)> = loops
        .iter()
        .map(|(header, contents)| {
            let mut dominator = *header;
            while let Some(new_dominator) = dom.imm_dom(dominator) {
                dominator = new_dominator;
                if loops
                    .get(&dominator)
                    .is_some_and(|outer_contents| outer_contents[header.idx()])
                {
                    return (*header, (contents.clone(), dominator));
                }
            }
            // No dominating header's loop contains this loop, so it is a
            // top-level loop.
            (*header, (contents.clone(), root))
        })
        .collect();

    // Step 6: compute loop tree nesting depths. A loop's depth is known once
    // its parent's depth is known; revisit loops whose parent isn't done yet.
    let mut nesting: HashMap<NodeID, usize> = HashMap::new();
    let mut worklist: VecDeque<NodeID> = loops.keys().copied().collect();
    while let Some(header) = worklist.pop_front() {
        let parent = loops[&header].1;
        if parent == root {
            nesting.insert(header, 0);
        } else if let Some(&parent_depth) = nesting.get(&parent) {
            nesting.insert(header, parent_depth + 1);
        } else {
            worklist.push_back(header);
        }
    }

    // Step 7: compute the inverse loop map, which maps each control node to
    // the loop it is in (identified by header), if any. Pick the most deeply
    // nested loop as the loop it is in.
    let mut inverse_loops: HashMap<NodeID, NodeID> = HashMap::new();
    for (header, (contents, _)) in loops.iter() {
        for idx in contents.iter_ones() {
            let id = NodeID::new(idx);
            match inverse_loops.get(&id).copied() {
                // A previously recorded loop that is strictly more deeply
                // nested wins; leave it in place.
                Some(old_header) if nesting[&old_header] > nesting[header] => {}
                previous => {
                    // Distinct loops at the same nesting depth are expected
                    // to be disjoint, so a tie here means the loop tree is
                    // malformed.
                    if let Some(old_header) = previous {
                        debug_assert!(
                            old_header == *header || nesting[&old_header] != nesting[header],
                            "control node {:?} is in distinct loops {:?} and {:?} at the same nesting depth",
                            id,
                            old_header,
                            header
                        );
                    }
                    inverse_loops.insert(id, *header);
                }
            }
        }
    }

    LoopTree {
        root,
        loops,
        inverse_loops,
        nesting,
    }
}
/// Returns the set of nodes that reach `n` through predecessor edges of
/// `subgraph` without stepping onto `d`. `n` itself is always in the set.
/// The bitset is sized to the subgraph's original node count and indexed by
/// `NodeID::idx`.
fn loop_reachability(n: NodeID, d: NodeID, subgraph: &Subgraph) -> BitVec<u8, Lsb0> {
    // Start from an all-false bitset; `n` is the root of the traversal and
    // reaching `d` terminates that path.
    let empty = bitvec![u8, Lsb0; 0; subgraph.original_num_nodes() as usize];
    loop_reachability_helper(n, d, subgraph, empty)
}
/// Marks in `visited` every node reachable from `n` by walking predecessor
/// edges of `subgraph` without stepping onto `d`, and returns the updated
/// bitset. `n` is always marked. If `n` was already marked on entry, nothing
/// new is explored from it.
///
/// This uses an explicit stack instead of recursion, so long predecessor
/// chains in large control-flow graphs cannot overflow the call stack.
fn loop_reachability_helper(
    n: NodeID,
    d: NodeID,
    subgraph: &Subgraph,
    mut visited: BitVec<u8, Lsb0>,
) -> BitVec<u8, Lsb0> {
    let mut stack = vec![n];
    while let Some(node) = stack.pop() {
        // A node can be pushed more than once before its first pop; only the
        // first pop explores it.
        if visited[node.idx()] {
            continue;
        }
        visited.set(node.idx(), true);
        for pred in subgraph.preds(node) {
            // Never traverse through `d`.
            if pred != d && !visited[pred.idx()] {
                stack.push(pred);
            }
        }
    }
    visited
}
/// Computes, for every reduce node in `function`, the set of data nodes that
/// form a cycle with it.
///
/// A node is in a reduce's cycle when it is both:
/// - transitively used by the reduce's `reduct` input (the `reduct` itself
///   counts), and
/// - a transitive user of the reduce.
///
/// Control nodes are never traversed. Phi and reduce nodes are only traversed
/// when `isnt_outside_fork_join` accepts them (see below).
///
/// # Panics
/// Panics if a reduce's join is not a value of `fork_join_map`. Also panics
/// if a phi or reduce node is examined during traversal while
/// `fork_join_nest` has no entry for the reduce's fork.
pub fn reduce_cycles(
    function: &Function,
    def_use: &ImmutableDefUseMap,
    fork_join_map: &HashMap<NodeID, NodeID>,
    fork_join_nest: &HashMap<NodeID, Vec<NodeID>>,
) -> HashMap<NodeID, HashSet<NodeID>> {
    let reduces = (0..function.nodes.len())
        .filter(|idx| function.nodes[*idx].is_reduce())
        .map(NodeID::new);
    let join_fork_map: HashMap<NodeID, NodeID> = fork_join_map
        .iter()
        .map(|(fork, join)| (*join, *fork))
        .collect();

    let mut result = HashMap::new();
    for reduce in reduces {
        let (join, _, reduct) = function.nodes[reduce.idx()]
            .try_reduce()
            .expect("node reported is_reduce() but try_reduce() returned None");
        let fork = join_fork_map[&join];

        // A phi or reduce is followed only if its control node appears in
        // `fork_join_nest[&fork]`. All other data nodes are always followed.
        // NOTE(review): if `fork_join_nest` maps a control node to the forks
        // enclosing it, this lookup may be reversed (i.e. intended as
        // `fork_join_nest[&control].contains(&fork)`). Confirm against the
        // producer of `fork_join_nest`.
        let isnt_outside_fork_join = |id: NodeID| {
            let node = &function.nodes[id.idx()];
            node.try_phi()
                .map(|(control, _)| control)
                .or_else(|| node.try_reduce().map(|(control, _, _)| control))
                .map(|control| fork_join_nest[&fork].contains(&control))
                .unwrap_or(true)
        };

        // First, find all data nodes used (transitively) by the `reduct`
        // input of the reduce, including the `reduct` itself.
        let mut use_reachable = HashSet::new();
        use_reachable.insert(reduct);
        let mut worklist = vec![reduct];
        while let Some(item) = worklist.pop() {
            for u in get_uses(&function.nodes[item.idx()]).as_ref() {
                if !function.nodes[u.idx()].is_control()
                    && isnt_outside_fork_join(*u)
                    && use_reachable.insert(*u)
                {
                    worklist.push(*u);
                }
            }
        }

        // Second, find all data nodes that are (transitively) users of the
        // reduce node.
        let mut user_reachable = HashSet::new();
        let mut worklist = vec![reduce];
        while let Some(item) = worklist.pop() {
            for u in def_use.get_users(item) {
                if !function.nodes[u.idx()].is_control()
                    && isnt_outside_fork_join(*u)
                    && user_reachable.insert(*u)
                {
                    worklist.push(*u);
                }
            }
        }

        // Nodes reachable both ways form the reduce node's cycle.
        result.insert(
            reduce,
            use_reachable
                .intersection(&user_reachable)
                .copied()
                .collect(),
        );
    }
    result
}
/// For every fork in `fork_join_map`, collects the data nodes that are
/// transitively users of that fork.
///
/// The traversal starts from the fork's users; the fork itself is not inserted
/// at the start. Control users are skipped and not traversed. Reduces whose
/// control input is the fork's own join are also skipped and not traversed,
/// so the result excludes the fork-join's reduces.
pub fn data_nodes_in_fork_joins(
    function: &Function,
    def_use: &ImmutableDefUseMap,
    fork_join_map: &HashMap<NodeID, NodeID>,
) -> HashMap<NodeID, HashSet<NodeID>> {
    fork_join_map
        .iter()
        .map(|(fork, join)| {
            let mut inside = HashSet::new();
            let mut pending = vec![*fork];
            while let Some(current) = pending.pop() {
                for user in def_use.get_users(current) {
                    let user_node = &function.nodes[user.idx()];
                    let excluded = user_node.is_control()
                        || user_node
                            .try_reduce()
                            .is_some_and(|(control, _, _)| control == *join);
                    // `insert` reports whether the node is new, so each data
                    // node is queued at most once.
                    if !excluded && inside.insert(*user) {
                        pending.push(*user);
                    }
                }
            }
            (*fork, inside)
        })
        .collect()
}