Skip to content
Snippets Groups Projects
fork_join_analysis.rs 7.85 KiB
use std::collections::{HashMap, HashSet};

use crate::*;

/*
 * Top level function for creating a fork-join map. The map is from fork node ID
 * to join node ID. It is built starting from each join: walk the join's control
 * predecessors backwards until the fork it corresponds to is found.
 *
 * Panics if the backwards walk from a join runs out of unvisited control
 * predecessors before a matching fork is found.
 */
pub fn fork_join_map(function: &Function, control: &Subgraph) -> HashMap<NodeID, NodeID> {
    let mut fork_join_map = HashMap::new();
    for idx in 0..function.nodes.len() {
        // We only care about join nodes.
        if !function.nodes[idx].is_join() {
            continue;
        }

        // Iterate the control predecessors until finding a fork. Maintain a
        // counter of unmatched fork-join pairs seen on the way, since fork-
        // joins may be nested. Every join is dominated by its fork, so just
        // follow the first unseen predecessor of each control node.
        let join_id = NodeID::new(idx);
        let mut unpaired = 0;
        let mut cursor = join_id;
        let mut seen = HashSet::<NodeID>::new();
        let fork_id = loop {
            cursor = control
                .preds(cursor)
                .find(|pred| !seen.contains(pred))
                .expect("PANIC: ran out of control predecessors before finding a join's fork.");
            seen.insert(cursor);

            let node = &function.nodes[cursor.idx()];
            if node.is_join() {
                // Entering a nested fork-join from its join side.
                unpaired += 1;
            } else if node.is_fork() {
                // A fork with no pending nested joins is the one we want;
                // otherwise it closes a nested fork-join.
                if unpaired == 0 {
                    break cursor;
                }
                unpaired -= 1;
            }
        };
        fork_join_map.insert(fork_id, join_id);
    }
    fork_join_map
}

/*
 * Find the fork-join nests that each control node is inside of. The result maps
 * each control node (every node contained in `dom`) to a list of fork nodes.
 * The forks are listed in the order `dom.ascend` visits them, i.e. walking up
 * the dominator tree from the control node, so the innermost fork comes first.
 *
 * Panics if a fork that dominates some control node has no entry in
 * `fork_join_map`.
 */
pub fn compute_fork_join_nesting(
    function: &Function,
    dom: &DomTree,
    fork_join_map: &HashMap<NodeID, NodeID>,
) -> HashMap<NodeID, Vec<NodeID>> {
    // For each control node, ascend the dominator tree, looking for fork
    // nodes. For each fork node, make sure the control node isn't strictly
    // dominated by the corresponding join node.
    (0..function.nodes.len())
        .map(NodeID::new)
        .filter(|id| dom.contains(*id))
        .map(|id| {
            let nest: Vec<NodeID> = dom
                .ascend(id)
                // Keep forks that dominate this control node...
                .filter(|ancestor| function.nodes[ancestor.idx()].is_fork())
                // ...whose corresponding join doesn't strictly dominate the
                // control node (if it did, this control node would come after
                // the fork-join). Strict dominance is used so that the join
                // itself is considered nested in its fork.
                .filter(|fork_id| !dom.does_prop_dom(fork_join_map[fork_id], id))
                .collect();
            (id, nest)
        })
        .collect()
}

/*
 * Top level function to calculate reduce cycles. Returns, for each reduce node,
 * the other nodes that form a cycle with that reduce node. A node belongs to a
 * reduce's cycle when it is reachable both:
 *   (1) by iterating uses, starting from the reduce's `reduct` input, and
 *   (2) by iterating users, starting from the reduce itself.
 * Apart from those starting nodes, neither walk enters the reduce itself, phis
 * or reduces outside the fork-join, parameters, constants, dynamic constants,
 * or undefs.
 */
pub fn reduce_cycles(
    function: &Function,
    def_use: &ImmutableDefUseMap,
    fork_join_map: &HashMap<NodeID, NodeID>,
    fork_join_nest: &HashMap<NodeID, Vec<NodeID>>,
) -> HashMap<NodeID, HashSet<NodeID>> {
    let reduces = (0..function.nodes.len())
        .filter(|idx| function.nodes[*idx].is_reduce())
        .map(NodeID::new);
    let mut result = HashMap::new();
    // A reduce names its join as control; invert the fork -> join map so the
    // matching fork can be looked up from that join.
    let join_fork_map: HashMap<NodeID, NodeID> = fork_join_map
        .iter()
        .map(|(fork, join)| (*join, *fork))
        .collect();

    for reduce in reduces {
        // `reduces` only yields reduce nodes, so `try_reduce` succeeds here.
        let (join, _, reduct) = function.nodes[reduce.idx()].try_reduce().unwrap();
        let fork = join_fork_map[&join];

        // Conservatively decide whether `id` may be inside this reduce's
        // fork-join. Nodes with no entry in `fork_join_nest` are assumed to be
        // inside; phis and reduces are judged by the nest of their control
        // node; parameters, constants, dynamic constants, and undefs are
        // always excluded.
        let might_be_in_fork_join = |id: NodeID| {
            let node = &function.nodes[id.idx()];
            fork_join_nest
                .get(&id)
                .map(|nest| nest.contains(&fork))
                .unwrap_or(true)
                && node
                    .try_phi()
                    .map(|(control, _)| fork_join_nest[&control].contains(&fork))
                    .unwrap_or(true)
                && node
                    .try_reduce()
                    .map(|(control, _, _)| fork_join_nest[&control].contains(&fork))
                    .unwrap_or(true)
                && !node.is_parameter()
                && !node.is_constant()
                && !node.is_dynamic_constant()
                && !node.is_undef()
        };

        // Find nodes in the fork-join that the reduct can reach through uses,
        // without passing through the reduce itself.
        let mut reachable_uses = HashSet::new();
        let mut workset = vec![];
        reachable_uses.insert(reduct);
        workset.push(reduct);
        while let Some(pop) = workset.pop() {
            for u in get_uses(&function.nodes[pop.idx()]).as_ref() {
                if !reachable_uses.contains(u) && might_be_in_fork_join(*u) && *u != reduce {
                    reachable_uses.insert(*u);
                    workset.push(*u);
                }
            }
        }

        // Find nodes in the fork-join that the reduce can reach through users.
        // The reduce is seeded into the set, so it is never revisited.
        let mut reachable_users = HashSet::new();
        workset.clear();
        reachable_users.insert(reduce);
        workset.push(reduce);
        while let Some(pop) = workset.pop() {
            for u in def_use.get_users(pop) {
                if !reachable_users.contains(u) && might_be_in_fork_join(*u) {
                    reachable_users.insert(*u);
                    workset.push(*u);
                }
            }
        }

        // The reduce cycle is the intersection of nodes reachable through uses
        // and through users.
        let intersection: HashSet<NodeID> = reachable_uses
            .intersection(&reachable_users)
            .copied()
            .collect();
        result.insert(reduce, intersection);
    }

    result
}

/*
 * Top level function to calculate which nodes are "inside" a fork-join. The
 * result maps each fork in `fork_join_map` to the set of nodes inside its
 * fork-join. It also maps the start node (ID 0) to every node in the function.
 *
 * Panics if the forward walk from a fork never reaches its join.
 */
pub fn nodes_in_fork_joins(
    function: &Function,
    def_use: &ImmutableDefUseMap,
    fork_join_map: &HashMap<NodeID, NodeID>,
    reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
) -> HashMap<NodeID, HashSet<NodeID>> {
    let mut result: HashMap<NodeID, HashSet<NodeID>> = fork_join_map
        .iter()
        .map(|(&fork, &join)| {
            // The walk halts at the join and at reduces controlled by the
            // join: such nodes are recorded, but their users are not visited.
            let halts_walk = |id: NodeID| {
                id == join
                    || matches!(
                        function.nodes[id.idx()].try_reduce(),
                        Some((control, _, _)) if control == join
                    )
            };

            // Walk forward through users, starting at the fork.
            let mut inside = HashSet::new();
            inside.insert(fork);
            let mut pending = vec![fork];
            while let Some(current) = pending.pop() {
                for &user in def_use.get_users(current) {
                    let boundary = halts_walk(user);
                    // `insert` reports whether `user` is new, so every node
                    // is queued at most once.
                    if inside.insert(user) && !boundary {
                        pending.push(user);
                    }
                }
            }
            assert!(inside.contains(&join));

            // Nodes in a reduce cycle may not use the fork's thread IDs, so
            // the walk above can miss them; add them explicitly.
            for user in def_use.get_users(join) {
                if let Some(cycle) = reduce_cycles.get(user) {
                    inside.extend(cycle);
                }
            }

            (fork, inside)
        })
        .collect();

    // The start node is treated as containing every node.
    let every_node = (0..function.nodes.len()).map(NodeID::new).collect();
    result.insert(NodeID::new(0), every_node);

    result
}