use std::collections::{HashMap, HashSet, VecDeque};
use std::iter::{zip, FromIterator};

use crate::*;

/*
 * Top level global code motion function. Assigns each data node to a control
 * node, forming (unordered) basic blocks. Returns the control node / basic
 * block each node is placed in. Takes in a partial partitioning that must be
 * respected. Based on the schedule-early-schedule-late method from Cliff
 * Click's PhD thesis.
 */
pub fn gcm(
    function: &Function,
    def_use: &ImmutableDefUseMap,
    reverse_postorder: &Vec<NodeID>,
    dom: &DomTree,
    antideps: &Vec<(NodeID, NodeID)>,
    loops: &LoopTree,
    fork_join_map: &HashMap<NodeID, NodeID>,
    partial_partition: &mut Vec<Option<PartitionID>>,
) -> Vec<NodeID> {
    let mut bbs: Vec<Option<NodeID>> = vec![None; function.nodes.len()];

    // Step 1: assign the basic block locations of all nodes that must be in a
    // specific block. This includes control nodes as well as some special data
    // nodes, such as phis.
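    // These nodes are "pinned": their placement is forced by their semantics,
    // and they seed the fixed point iterations in steps 2 and 3 below.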
    for idx in 0..function.nodes.len() {
        match function.nodes[idx] {
            Node::Phi { control, data: _ } => bbs[idx] = Some(control),
            Node::ThreadID {
                control,
                dimension: _,
            } => bbs[idx] = Some(control),
            Node::Reduce {
                control,
                init: _,
                reduct: _,
            } => bbs[idx] = Some(control),
            Node::Call {
                control,
                function: _,
                dynamic_constants: _,
                args: _,
            } => bbs[idx] = Some(control),
            Node::Parameter { index: _ } => bbs[idx] = Some(NodeID::new(0)),
            Node::Constant { id: _ } => bbs[idx] = Some(NodeID::new(0)),
            Node::DynamicConstant { id: _ } => bbs[idx] = Some(NodeID::new(0)),
            _ if function.nodes[idx].is_control() => bbs[idx] = Some(NodeID::new(idx)),
            _ => {}
        }
    }

    // Step 2: schedule early. Place each node at the earliest position it
    // could legally go, using a worklist to iterate the nodes.
    let mut schedule_early = bbs.clone();
    let mut antideps_uses = HashMap::<NodeID, Vec<NodeID>>::new();
    for (read, write) in antideps {
        antideps_uses.entry(*write).or_default().push(*read);
    }
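    // Note that each anti-dependence (read, write) pair acts as an extra use
    // edge from the write to the read, so the write is scheduled no earlier
    // than the read it must not clobber.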
    let mut worklist = VecDeque::from(reverse_postorder.clone());
    while let Some(id) = worklist.pop_front() {
        if schedule_early[id.idx()].is_some() {
            continue;
        }

        // For every use, check what block is its "schedule early" block. This
        // node goes in the lowest block amongst those blocks.
        let use_places: Option<Vec<NodeID>> = get_uses(&function.nodes[id.idx()])
            .as_ref()
            .into_iter()
            .copied()
            // Include "uses" from anti-dependencies.
            .chain(antideps_uses.get(&id).into_iter().flatten().copied())
            .map(|id| schedule_early[id.idx()])
            .collect();
        if let Some(use_places) = use_places {
            // If every use has been placed, we can place this node as the
            // lowest place in the domtree that dominates all of the use places.
            let lowest = dom.lowest_amongst(use_places.into_iter());
            schedule_early[id.idx()] = Some(lowest);
        } else {
            // If not, then just push this node back on the worklist.
            worklist.push_back(id);
        }
    }
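    // This iteration reaches a fixed point because data flow cycles can only
    // pass through phis and reduces, which step 1 already pinned to blocks, so
    // every remaining node's uses eventually get placed.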

    // Step 3: schedule late and pick each node's final position. Since the
    // late schedule of each node depends on the final positions of its users,
    // these two steps must be fused. Compute each node's latest position, then
    // use the control dependence + shallow loop nest heuristic to actually
    // place it.
    let join_fork_map: HashMap<NodeID, NodeID> = fork_join_map
        .into_iter()
        .map(|(fork, join)| (*join, *fork))
        .collect();
    let mut antideps_users = HashMap::<NodeID, Vec<NodeID>>::new();
    for (read, write) in antideps {
        antideps_users.entry(*read).or_default().push(*write);
    }
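    // Symmetrically to step 2, each anti-dependence (read, write) pair acts as
    // an extra user edge from the read to the write, so the read is scheduled
    // no later than the write that overwrites its value.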
    let mut worklist = VecDeque::from_iter(reverse_postorder.iter().copied().rev());
    while let Some(id) = worklist.pop_front() {
        if bbs[id.idx()].is_some() {
            continue;
        }

        // Calculate the least common ancestor of user blocks, a.k.a. the "late"
        // schedule.
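        // The closure returns a doubled Option: the outer None means some user
        // hasn't been placed yet, while an inner None means this node has no
        // users at all (handled below by falling back to the early schedule).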
        let calculate_lca = || -> Option<_> {
            let mut lca = None;
            // Helper to incrementally update the LCA.
            let mut update_lca = |a| {
                if let Some(acc) = lca {
                    lca = Some(dom.least_common_ancestor(acc, a));
                } else {
                    lca = Some(a);
                }
            };

            // For every user, consider where we need to be placed in order to
            // dominate that user's use of this node.
            for user in def_use
                .get_users(id)
                .as_ref()
                .into_iter()
                .copied()
                // Include "users" from anti-dependencies.
                .chain(antideps_users.get(&id).into_iter().flatten().copied())
            {
                if let Node::Phi { control, data } = &function.nodes[user.idx()] {
                    // For phis, we need to dominate the block jumping to the phi in
                    // the slot that corresponds to our use.
                    for (control, data) in
                        zip(get_uses(&function.nodes[control.idx()]).as_ref(), data)
                    {
                        if id == *data {
                            update_lca(*control);
                        }
                    }
                } else if let Node::Reduce {
                    control,
                    init,
                    reduct,
                } = &function.nodes[user.idx()]
                {
                    // For reduces, we need to either dominate the block right
                    // before the fork if we're the init input, or we need to
                    // dominate the join if we're the reduct input.
                    if id == *init {
                        let before_fork = function.nodes[join_fork_map[control].idx()]
                            .try_fork()
                            .unwrap()
                            .0;
                        update_lca(before_fork);
                    } else {
                        assert_eq!(id, *reduct);
                        update_lca(*control);
                    }
                } else {
                    // For everything else, we just need to dominate the user.
                    update_lca(bbs[user.idx()]?);
                }
            }

            Some(lca)
        };

        // Check if all users have been placed. If one of them hasn't, then add
        // this node back on to the worklist.
        let Some(lca) = calculate_lca() else {
            worklist.push_back(id);
            continue;
        };

        // Look between the LCA and the schedule early location to place the
        // node. If a data node can't be scheduled to any control nodes in its
        // partition (this may happen if all of the control nodes in a partition
        // got deleted, for example), then the data node can be scheduled into a
        // control node in a different partition.
        let schedule_early = schedule_early[id.idx()].unwrap();
        let need_to_repartition = !dom
            .chain(lca.unwrap_or(schedule_early), schedule_early)
            .any(|dominator| {
                partial_partition[id.idx()].is_none()
                    || partial_partition[dominator.idx()] == partial_partition[id.idx()]
            });
        if need_to_repartition {
            partial_partition[id.idx()] = None;
        }
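        // Either the partition was just cleared (so the filter below accepts
        // every block in the chain), or some block in the chain is in this
        // node's partition; either way the filtered chain is non-empty, making
        // the unwrap below safe.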
        let mut chain = dom
            // If the node has no users, then it doesn't really matter where we
            // place it - just place it at the early placement.
            .chain(lca.unwrap_or(schedule_early), schedule_early)
            // Filter the dominator chain by control nodes in the same partition
            // as this data node, if the data node is in a partition already.
            .filter(|dominator| {
                partial_partition[id.idx()].is_none()
                    || partial_partition[dominator.idx()] == partial_partition[id.idx()]
            });
        let mut location = chain.next().unwrap();
        for control_node in chain {
            // If the next node further up the dominator tree is in a shallower
            // loop nest or if we can get out of a reduce loop when we don't
            // need to be in one, place this data node in a higher-up location.
            let shallower_nest = if let (Some(old_nest), Some(new_nest)) =
                (loops.nesting(location), loops.nesting(control_node))
            {
                old_nest > new_nest
            } else {
                false
            };
            // This will move all nodes that don't need to be in reduce loops
            // outside of reduce loops. Nodes that do need to be in a reduce
            // loop use the reduce node forming the loop, so the dominator chain
            // will consist of one block, and this loop won't ever iterate.
            let currently_at_join = function.nodes[location.idx()].is_join();
            if shallower_nest || currently_at_join {
                location = control_node;
            }
        }

        bbs[id.idx()] = Some(location);
    }

    bbs.into_iter().map(Option::unwrap).collect()
}
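
/*
 * A rough sketch of how gcm might be driven, assuming hypothetical local
 * variables holding the prerequisite analysis results (the actual helper names
 * in this crate may differ); fork_join_map below is the function defined in
 * this file:
 *
 *     let fork_join = fork_join_map(&function, &control_subgraph);
 *     let bbs = gcm(
 *         &function,
 *         &def_use,
 *         &reverse_postorder,
 *         &dom,
 *         &antideps,
 *         &loops,
 *         &fork_join,
 *         &mut partial_partition,
 *     );
 */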

/*
 * Top level function for creating a fork-join map. Map is from fork node ID to
 * join node ID, since a join can easily determine the fork it corresponds to
 * (that's the mechanism used to implement this analysis). This analysis depends
 * on type information.
 */
pub fn fork_join_map(function: &Function, control: &Subgraph) -> HashMap<NodeID, NodeID> {
    let mut fork_join_map = HashMap::new();
    for idx in 0..function.nodes.len() {
        // We only care about join nodes.
        if function.nodes[idx].is_join() {
            // Iterate the control predecessors until finding a fork. Maintain
            // a counter of unmatched fork-join pairs seen on the way, since
            // fork-joins may be nested. Every join is dominated by its fork,
            // so just follow the first unseen predecessor of each control
            // node.
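            // For example, ascending from the outer join of the nest
            //   fork_a -> fork_b -> ... -> join_b -> join_a,
            // the walk sees join_b (unpaired becomes 1), then fork_b (unpaired
            // drops back to 0), then fork_a with unpaired == 0, which is the
            // matching fork.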
            let join_id = NodeID::new(idx);
            let mut unpaired = 0;
            let mut cursor = join_id;
            let mut seen = HashSet::<NodeID>::new();
            let fork_id = loop {
                cursor = control
                    .preds(cursor)
                    .find(|pred| !seen.contains(pred))
                    .unwrap();
                seen.insert(cursor);

                if function.nodes[cursor.idx()].is_join() {
                    unpaired += 1;
                } else if function.nodes[cursor.idx()].is_fork() && unpaired > 0 {
                    unpaired -= 1;
                } else if function.nodes[cursor.idx()].is_fork() {
                    break cursor;
                }
            };
            fork_join_map.insert(fork_id, join_id);
        }
    }
    fork_join_map
}

/*
 * Find the fork-join nests each control node is inside of. The result is a map
 * from each control node to a list of fork nodes, listed innermost first (the
 * order in which they are encountered ascending the dominator tree).
 */
pub fn compute_fork_join_nesting(
    function: &Function,
    dom: &DomTree,
    fork_join_map: &HashMap<NodeID, NodeID>,
) -> HashMap<NodeID, Vec<NodeID>> {
    // For each control node, ascend dominator tree, looking for fork nodes. For
    // each fork node, make sure each control node isn't strictly dominated by
    // the corresponding join node.
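    // For example, a control node in the body of fork_b, itself nested inside
    // fork_a, maps to vec![fork_b, fork_a].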
    (0..function.nodes.len())
        .map(NodeID::new)
        .filter(|id| dom.contains(*id))
        .map(|id| {
            (
                id,
                dom.ascend(id)
                    // Filter for forks that dominate this control node,
                    .filter(|id| function.nodes[id.idx()].is_fork())
                    // where the corresponding join doesn't properly dominate
                    // the control node (if it did, this control node would be
                    // after the fork-join, not inside it).
                    .filter(|fork_id| !dom.does_prop_dom(fork_join_map[fork_id], id))
                    .collect(),
            )
        })
        .collect()
}