Skip to content
Snippets Groups Projects
fork_join_analysis.rs 6.74 KiB
use std::collections::{HashMap, HashSet};

use crate::*;

/*
 * Top level function for creating a fork-join map. Map is from fork node ID to
 * join node ID, since a join can easily determine the fork it corresponds to
 * (that's the mechanism used to implement this analysis): starting at each
 * join, walk backwards through control predecessors until reaching a fork that
 * is not paired with some join seen earlier on the walk.
 *
 * Panics if the backwards walk from a join runs out of unseen control
 * predecessors before finding its fork.
 */
pub fn fork_join_map(function: &Function, control: &Subgraph) -> HashMap<NodeID, NodeID> {
    let mut fork_join_map = HashMap::new();
    for (idx, node) in function.nodes.iter().enumerate() {
        // We only care about join nodes.
        if !node.is_join() {
            continue;
        }

        // Iterate the control predecessors until finding a fork. Maintain a
        // counter of unmatched fork-join pairs seen on the way, since fork-
        // joins may be nested. Every join is dominated by their fork, so
        // just iterate the first unseen predecessor of each control node.
        let join_id = NodeID::new(idx);
        let mut unpaired = 0usize;
        let mut cursor = join_id;
        let mut seen = HashSet::<NodeID>::new();
        let fork_id = loop {
            cursor = control
                .preds(cursor)
                .find(|pred| !seen.contains(pred))
                .expect("walked back from a join without finding its fork");
            seen.insert(cursor);

            let cursor_node = &function.nodes[cursor.idx()];
            if cursor_node.is_join() {
                // Entering a nested fork-join from its join side.
                unpaired += 1;
            } else if cursor_node.is_fork() {
                // A fork either closes a nested pair we already passed the
                // join of, or (when nothing is pending) is our own fork.
                if unpaired == 0 {
                    break cursor;
                }
                unpaired -= 1;
            }
        };
        fork_join_map.insert(fork_id, join_id);
    }
    fork_join_map
}

/*
 * Find the fork/join nests that each control node is inside of. Every node
 * contained in the dominator tree is mapped to the forks that dominate it and
 * whose matching join does not properly dominate it (a node properly dominated
 * by the join comes after that fork-join). The forks are listed in the order
 * `dom.ascend` visits them, i.e. ascending order of nesting.
 *
 * Panics if a fork found while ascending has no entry in `fork_join_map`.
 */
pub fn compute_fork_join_nesting(
    function: &Function,
    dom: &DomTree,
    fork_join_map: &HashMap<NodeID, NodeID>,
) -> HashMap<NodeID, Vec<NodeID>> {
    let mut nesting = HashMap::new();

    for control_id in (0..function.nodes.len()).map(NodeID::new) {
        // Only nodes in the dominator tree get an entry.
        if !dom.contains(control_id) {
            continue;
        }

        let enclosing_forks: Vec<NodeID> = dom
            .ascend(control_id)
            .filter(|ancestor| function.nodes[ancestor.idx()].is_fork())
            .filter(|fork| {
                let join = fork_join_map[fork];
                !dom.does_prop_dom(join, control_id)
            })
            .collect();

        nesting.insert(control_id, enclosing_forks);
    }

    nesting
}

/*
 * Top level function to calculate reduce cycles. For each reduce node, returns
 * the set of other nodes that form a cycle with it. Strictly, a reduce cycle is
 * the set of nodes reached from the reduce's `reduct` input by iterated uses
 * that, through further iterated uses, lead back to the reduce itself, without
 * passing through phis or reduces outside the reduce's fork-join.
 */
pub fn reduce_cycles(
    function: &Function,
    fork_join_map: &HashMap<NodeID, NodeID>,
    fork_join_nest: &HashMap<NodeID, Vec<NodeID>>,
) -> HashMap<NodeID, HashSet<NodeID>> {
    // Invert fork -> join so that each reduce can look up its fork from the
    // join it is attached to.
    let join_to_fork: HashMap<NodeID, NodeID> = fork_join_map
        .iter()
        .map(|(&fork, &join)| (join, fork))
        .collect();

    (0..function.nodes.len())
        .filter(|&idx| function.nodes[idx].is_reduce())
        .map(NodeID::new)
        .map(|reduce_id| {
            let (join, _, reduct) = function.nodes[reduce_id.idx()].try_reduce().unwrap();
            let fork = join_to_fork[&join];

            // Search from `reduct` back through its uses for the reduce.
            let mut on_path = HashSet::new();
            let mut cycle_nodes = HashSet::new();
            reduce_cycle_dfs_helper(
                function,
                reduct,
                fork,
                reduce_id,
                &mut on_path,
                &mut cycle_nodes,
                fork_join_nest,
            );
            (reduce_id, cycle_nodes)
        })
        .collect()
}

/*
 * Recursive helper for `reduce_cycles`. Explores the uses of `iter` depth-first
 * looking for `reduce`. Returns true when `iter` is `reduce`, is already in
 * `in_cycle`, or has an eligible use from which the search reaches `reduce`.
 * Every node (other than `reduce` itself) from which `reduce` was reached is
 * added to `in_cycle`.
 *
 * `current_visited` holds the nodes on the current DFS path: a node is added on
 * entry and removed on exit, so it only stops the search from looping around a
 * cycle on the current path. Only positive answers are cached (in `in_cycle`),
 * so a node that cannot reach `reduce` may be explored again along each
 * distinct path leading to it. NOTE(review): on use graphs with many
 * reconverging paths this could be costly — confirm it is not a hotspot.
 *
 * The search never steps into control nodes, nor into phis/reduces whose
 * control node's nest (per `fork_join_nest`) does not contain `fork`.
 */
fn reduce_cycle_dfs_helper(
    function: &Function,
    iter: NodeID,
    fork: NodeID,
    reduce: NodeID,
    current_visited: &mut HashSet<NodeID>,
    in_cycle: &mut HashSet<NodeID>,
    fork_join_nest: &HashMap<NodeID, Vec<NodeID>>,
) -> bool {
    // True unless `id` is a phi or reduce whose control node's nest does not
    // list `fork` (i.e. it belongs outside this fork-join). Non-phi/non-reduce
    // nodes are always allowed. Indexing `fork_join_nest` panics if the control
    // node has no entry.
    let isnt_outside_fork_join = |id: NodeID| {
        let node = &function.nodes[id.idx()];
        node.try_phi()
            .map(|(control, _)| control)
            .or(node.try_reduce().map(|(control, _, _)| control))
            .map(|control| fork_join_nest[&control].contains(&fork))
            .unwrap_or(true)
    };

    // Found the reduce, or reached a node already known to lead to it.
    if iter == reduce || in_cycle.contains(&iter) {
        return true;
    }

    // Mark `iter` as on the current DFS path.
    current_visited.insert(iter);
    let mut found_reduce = false;

    // `|=` (rather than stopping at the first hit) is deliberate: every use is
    // explored, so all nodes on every path back to the reduce are recorded.
    // Within one use, the `&&` chain skips nodes already on the current path,
    // control nodes, and phis/reduces outside the fork-join before recursing.
    for u in get_uses(&function.nodes[iter.idx()]).as_ref() {
        found_reduce |= !current_visited.contains(u)
            && !function.nodes[u.idx()].is_control()
            && isnt_outside_fork_join(*u)
            && reduce_cycle_dfs_helper(
                function,
                *u,
                fork,
                reduce,
                current_visited,
                in_cycle,
                fork_join_nest,
            )
    }
    // Some use led back to the reduce, so `iter` is part of the cycle.
    if found_reduce {
        in_cycle.insert(iter);
    }
    // Leaving `iter`: pop it off the current DFS path.
    current_visited.remove(&iter);
    found_reduce
}

/*
 * Top level function to calculate which data nodes are "inside" a fork-join,
 * not including its reduces. For each fork in `fork_join_map`, the result maps
 * the fork to the set of nodes transitively reachable from it through user
 * edges in `def_use`. The traversal skips control users and the reduces whose
 * control input is that fork-join's join.
 */
pub fn data_nodes_in_fork_joins(
    function: &Function,
    def_use: &ImmutableDefUseMap,
    fork_join_map: &HashMap<NodeID, NodeID>,
) -> HashMap<NodeID, HashSet<NodeID>> {
    fork_join_map
        .iter()
        .map(|(fork, join)| {
            let mut set = HashSet::new();
            let mut worklist = vec![*fork];

            while let Some(item) = worklist.pop() {
                for u in def_use.get_users(item) {
                    let user = &function.nodes[u.idx()];
                    // Ignore control users and reduces of the fork-join.
                    if user.is_control()
                        || matches!(user.try_reduce(), Some((control, _, _)) if control == *join)
                    {
                        continue;
                    }
                    // `insert` returns false when `u` was already recorded, so
                    // each node is enqueued at most once, with a single lookup.
                    if set.insert(*u) {
                        worklist.push(*u);
                    }
                }
            }

            (*fork, set)
        })
        .collect()
}