Skip to content
Snippets Groups Projects
loops.rs 11.79 KiB
use std::collections::hash_map;
use std::collections::VecDeque;
use std::collections::{HashMap, HashSet};

use bitvec::prelude::*;

use crate::*;

/*
 * Custom type for storing a loop tree. Each node corresponds to a single loop
 * or a fork join pair in the IR graph. Each node in the tree corresponds to
 * some subset of the overall IR graph. The root node corresponds to the entire
 * IR graph. The children of the root correspond to the top-level loops and fork
 * join pairs, and so on. Each node in the loop tree has a representative
 * "header" node. For normal loops, this is the region node branched to by a
 * dominated if node. For fork join pairs, this is the fork node. A loop is a
 * top-level loop if its parent is the root node of the subgraph. Each node in
 * the tree is an entry in the loops HashMap - the key is the "header" node for
 * the loop, and the value is a pair of the set of control nodes inside the loop
 * and this loop's parent header.
 */
#[derive(Debug, Clone)]
pub struct LoopTree {
    // Root of the tree, standing for the whole graph. `contains` and
    // `is_in_loop` compare against it directly instead of looking it up in
    // `loops`.
    root: NodeID,
    // Loop header -> (bitset of the control nodes in the loop, indexed by
    // `NodeID::idx`, header of the parent loop). Top-level loops have `root`
    // as their parent.
    loops: HashMap<NodeID, (BitVec<u8, Lsb0>, NodeID)>,
    // Control node -> header of the most deeply nested loop containing it.
    // Control nodes that are in no loop have no entry.
    inverse_loops: HashMap<NodeID, NodeID>,
    // Loop header -> nesting depth. Top-level loops have depth 0.
    nesting: HashMap<NodeID, usize>,
}

impl LoopTree {
    /*
     * Returns true if `x` is the root of the tree or the header of one of its
     * loops.
     */
    pub fn contains(&self, x: NodeID) -> bool {
        x == self.root || self.loops.contains_key(&x)
    }

    /*
     * Iterates over every loop in the tree as `(header, (contents, parent))`.
     * The root is only visited if it is itself a key of the loop map.
     */
    pub fn loops(&self) -> hash_map::Iter<'_, NodeID, (BitVec<u8, Lsb0>, NodeID)> {
        self.loops.iter()
    }

    /*
     * Iterates over the control nodes inside the loop with header `header`.
     * Panics if `header` is not a key of the loop map.
     */
    pub fn nodes_in_loop(&self, header: NodeID) -> impl Iterator<Item = NodeID> + '_ {
        self.loops[&header].0.iter_ones().map(NodeID::new)
    }

    /*
     * Returns true if `is_in` is inside the loop with header `header`. The root
     * contains every node. Panics if `header` is neither the root nor a loop
     * header.
     */
    pub fn is_in_loop(&self, header: NodeID, is_in: NodeID) -> bool {
        header == self.root || self.loops[&header].0[is_in.idx()]
    }

    /*
     * Returns the header of the most deeply nested loop containing
     * `control_node`, or `None` if it is in no loop.
     */
    pub fn header_of(&self, control_node: NodeID) -> Option<NodeID> {
        self.inverse_loops.get(&control_node).copied()
    }

    /*
     * Sometimes, we need to iterate the loop tree bottom-up. Just assemble the
     * order upfront: every loop is emitted only after all of its child loops
     * have been emitted.
     */
    pub fn bottom_up_loops(&self) -> Vec<(NodeID, &BitVec<u8, Lsb0>)> {
        let mut bottom_up = Vec::with_capacity(self.loops.len());

        // Number of child loops not yet emitted, per loop (and for the root).
        let mut children_count: HashMap<NodeID, u32> =
            self.loops.keys().map(|header| (*header, 0)).collect();
        children_count.insert(self.root, 0);
        for (_, parent) in self.loops.values() {
            *children_count.get_mut(parent).unwrap() += 1;
        }

        // Repeatedly emit loops with no pending children, requeueing the rest.
        let mut worklist: VecDeque<_> = self.loops.iter().map(|(k, v)| (*k, &v.0)).collect();
        while let Some(pop) = worklist.pop_front() {
            if children_count[&pop.0] == 0 {
                *children_count.get_mut(&self.loops[&pop.0].1).unwrap() -= 1;
                bottom_up.push(pop);
            } else {
                worklist.push_back(pop);
            }
        }
        bottom_up
    }

    /*
     * Gets the nesting of a loop, keyed by the header. Top-level loops have
     * nesting 0; returns `None` for nodes that are not loop headers.
     */
    pub fn nesting(&self, header: NodeID) -> Option<usize> {
        self.nesting.get(&header).copied()
    }
}

/*
 * Top level function for calculating loop trees.
 */
pub fn loops(
    subgraph: &Subgraph,
    root: NodeID,
    dom: &DomTree,
    fork_join_map: &HashMap<NodeID, NodeID>,
) -> LoopTree {
    // Step 1: collect loop back edges.
    let mut loop_back_edges = vec![];
    for node in subgraph.iter() {
        // Check successors. Any successor dominating its predecessor is the
        // destination of a loop back edge.
        for succ in subgraph.succs(*node) {
            if dom.does_dom(succ, *node) {
                loop_back_edges.push((*node, succ));
            }
        }
    }

    // Step 2: collect "edges" from joins to forks. Technically, this doesn't
    // correspond to a real edge in the graph. However, our loop tree includes
    // fork join pairs as loops, so create a phantom loop back edge.
    for (fork, join) in fork_join_map {
        loop_back_edges.push((*join, *fork));
    }

    // Step 3: find control nodes inside each loop. For a particular natural
    // loop with header d and a back edge from node n to d, the nodes in the
    // loop are d itself, and all nodes with a path to n not going through d.
    let loop_contents = loop_back_edges.iter().map(|(n, d)| {
        // Compute reachability for each loop back edge.
        let mut loop_contents = loop_reachability(*n, *d, subgraph);
        loop_contents.set(d.idx(), true);
        (d, loop_contents)
    });

    // Step 4: merge loops with same header into a single natural loop.
    let mut loops: HashMap<NodeID, BitVec<u8, Lsb0>> = HashMap::new();
    for (header, contents) in loop_contents {
        if loops.contains_key(header) {
            let old_contents = loops.remove(header).unwrap();
            loops.insert(*header, old_contents | contents);
        } else {
            loops.insert(*header, contents);
        }
    }

    // Step 5: figure out loop tree edges. A loop with header a can only be an
    // outer loop of a loop with header b if a dominates b.
    let loops: HashMap<NodeID, (BitVec<u8, Lsb0>, NodeID)> = loops
        .iter()
        .map(|(header, contents)| {
            let mut dominator = *header;
            // Climb the cominator tree.
            while let Some(new_dominator) = dom.imm_dom(dominator) {
                dominator = new_dominator;
                // Check if the dominator node is a loop header.
                if let Some(outer_contents) = loops.get(&dominator) {
                    // Check if the dominating loop actually contains this loop.
                    if outer_contents[header.idx()] {
                        return (*header, (contents.clone(), dominator));
                    }
                }
            }
            // If no dominating node is a loop header for a loop containing this
            // loop, then this loop is a top-level loop.
            (*header, (contents.clone(), root))
        })
        .collect();

    // Step 6: compute loop tree nesting.
    let mut nesting = HashMap::new();
    let mut worklist: VecDeque<NodeID> = loops.keys().map(|id| *id).collect();
    while let Some(header) = worklist.pop_front() {
        let parent = loops[&header].1;
        if parent == root {
            nesting.insert(header, 0);
        } else if let Some(nest) = nesting.get(&parent) {
            nesting.insert(header, nest + 1);
        } else {
            worklist.push_back(header);
        }
    }

    // Step 7: compute the inverse loop map - this maps control nodes to which
    // loop they are in (identified by header), if they are in one. Pick the
    // most nested loop as the loop they are in.
    let mut inverse_loops = HashMap::new();
    for (header, (contents, _)) in loops.iter() {
        for idx in contents.iter_ones() {
            let id = NodeID::new(idx);
            if let Some(old_header) = inverse_loops.get(&id)
                && nesting[old_header] > nesting[header]
            {
                // If the inserted header is more deeply nested, don't do anything.
                assert!(nesting[old_header] != nesting[header] || old_header == header);
            } else {
                inverse_loops.insert(id, *header);
            }
        }
    }

    LoopTree {
        root,
        loops,
        inverse_loops,
        nesting,
    }
}

/*
 * Computes the body of the natural loop for back edge n -> d, without d itself
 * (d is only included when n == d). The result marks n and every node from
 * which n can be reached along predecessor edges without passing through d.
 */
fn loop_reachability(n: NodeID, d: NodeID, subgraph: &Subgraph) -> BitVec<u8, Lsb0> {
    let node_count = subgraph.original_num_nodes() as usize;
    let seen: BitVec<u8, Lsb0> = bitvec![u8, Lsb0; 0; node_count];

    // Walk backwards from the back edge source `n`; `d` acts as a barrier.
    loop_reachability_helper(n, d, subgraph, seen)
}

/*
 * Marks `n`, plus every node that reaches `n` through predecessor edges
 * without passing through `d`, in `visited` and returns the updated bitset.
 * Nodes already marked on entry are not expanded again, and `d` is never
 * entered via a predecessor edge (it is only marked when `n == d`).
 *
 * Uses an explicit stack instead of recursion so that long chains of control
 * nodes cannot overflow the call stack.
 */
fn loop_reachability_helper(
    n: NodeID,
    d: NodeID,
    subgraph: &Subgraph,
    mut visited: BitVec<u8, Lsb0>,
) -> BitVec<u8, Lsb0> {
    let mut stack = vec![n];
    while let Some(node) = stack.pop() {
        // A node can be pushed more than once before it is popped; only the
        // first pop expands it.
        if visited[node.idx()] {
            continue;
        }
        visited.set(node.idx(), true);

        // Iterate over predecessors, never traversing d.
        for pred in subgraph.preds(node) {
            if pred != d && !visited[pred.idx()] {
                stack.push(pred);
            }
        }
    }

    visited
}

/*
 * Top level function to calculate reduce cycles. Returns for each reduce node
 * what other nodes form a cycle with that reduce node: the non-control nodes
 * that are both transitively used by the reduce's `reduct` input and
 * transitively users of the reduce. Phi and reduce nodes whose control node is
 * not nested inside the reduce's fork-join are not traversed, which confines
 * each cycle to that fork-join.
 */
pub fn reduce_cycles(
    function: &Function,
    def_use: &ImmutableDefUseMap,
    fork_join_map: &HashMap<NodeID, NodeID>,
    fork_join_nest: &HashMap<NodeID, Vec<NodeID>>,
) -> HashMap<NodeID, HashSet<NodeID>> {
    let reduces = (0..function.nodes.len())
        .filter(|idx| function.nodes[*idx].is_reduce())
        .map(NodeID::new);
    let join_fork_map: HashMap<NodeID, NodeID> = fork_join_map
        .iter()
        .map(|(fork, join)| (*join, *fork))
        .collect();
    let mut result = HashMap::new();

    for reduce in reduces {
        let (join, _, reduct) = function.nodes[reduce.idx()].try_reduce().unwrap();
        let fork = join_fork_map[&join];

        // A phi or reduce is inside this fork-join iff its control node (a
        // region or a join) is nested inside `fork`, so the nest map is looked
        // up by that control node and searched for `fork`.
        // NOTE(review): relies on `fork_join_nest` mapping each control node
        // to the forks enclosing it - confirm against its producer.
        // Every other data node is considered inside.
        let isnt_outside_fork_join = |id: NodeID| {
            let node = &function.nodes[id.idx()];
            node.try_phi()
                .map(|(control, _)| control)
                .or_else(|| node.try_reduce().map(|(control, _, _)| control))
                .map(|control| fork_join_nest[&control].contains(&fork))
                .unwrap_or(true)
        };

        // First, find all data nodes that are used by the `reduct` input of the
        // reduce, including the `reduct` itself.
        let mut use_reachable = HashSet::new();
        use_reachable.insert(reduct);
        let mut worklist = vec![reduct];
        while let Some(item) = worklist.pop() {
            for u in get_uses(&function.nodes[item.idx()]).as_ref() {
                if !function.nodes[u.idx()].is_control()
                    && !use_reachable.contains(u)
                    && isnt_outside_fork_join(*u)
                {
                    use_reachable.insert(*u);
                    worklist.push(*u);
                }
            }
        }

        // Second, find all data nodes that are users of the reduce node.
        let mut user_reachable = HashSet::new();
        let mut worklist = vec![reduce];
        while let Some(item) = worklist.pop() {
            for u in def_use.get_users(item) {
                if !function.nodes[u.idx()].is_control()
                    && !user_reachable.contains(u)
                    && isnt_outside_fork_join(*u)
                {
                    user_reachable.insert(*u);
                    worklist.push(*u);
                }
            }
        }

        // Nodes that are both use-reachable and user-reachable by the reduce
        // node are in the reduce node's cycle.
        result.insert(
            reduce,
            use_reachable
                .intersection(&user_reachable)
                .copied()
                .collect(),
        );
    }

    result
}

/*
 * Top level function to calculate which data nodes are "inside" a fork-join,
 * not including its reduces. For every fork in `fork_join_map`, the result
 * holds the transitive users of the fork, where the walk never steps onto a
 * control node or onto a reduce attached to that fork's join.
 */
pub fn data_nodes_in_fork_joins(
    function: &Function,
    def_use: &ImmutableDefUseMap,
    fork_join_map: &HashMap<NodeID, NodeID>,
) -> HashMap<NodeID, HashSet<NodeID>> {
    fork_join_map
        .iter()
        .map(|(fork, join)| {
            // Control users and reduces of this fork-join bound the walk.
            let is_boundary = |id: NodeID| {
                let node = &function.nodes[id.idx()];
                node.is_control()
                    || matches!(node.try_reduce(), Some((control, _, _)) if control == *join)
            };

            let mut inside: HashSet<NodeID> = HashSet::new();
            let mut stack = vec![*fork];
            while let Some(current) = stack.pop() {
                for user in def_use.get_users(current) {
                    // `insert` returns true only the first time a node is seen.
                    if !is_boundary(*user) && inside.insert(*user) {
                        stack.push(*user);
                    }
                }
            }

            (*fork, inside)
        })
        .collect()
}