    use std::cell::Ref;
    use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet, VecDeque};
    use std::iter::{empty, once, zip, FromIterator};

    use bitvec::prelude::*;
    use either::Either;
    use union_find::{QuickFindUf, UnionBySize, UnionFind};

    use hercules_cg::*;
    use hercules_ir::*;

    use crate::*;
    
    /*
     * Top level function to legalize the reference semantics of a Hercules IR
     * function. Hercules IR is a value semantics representation, meaning that all
     * program state is in the form of copyable values, and mutation takes place by
     * making a new value that is a copy of the old value with some modification.
     * This representation is extremely convenient for optimization, but is not good
     * for code generation, where we need to generate code with references to get
     * good performance. Hercules IR can alternatively be interpreted using
     * reference semantics, where pointers to collection objects are passed around,
     * read from, and written to. However, the value semantics and reference
     * semantics interpretation of a Hercules IR function may not be equal - this
     * pass transforms a Hercules IR function such that its new value semantics is
     * the same as its old value semantics and that its new reference semantics is
     * the same as its new value semantics. This pass returns a placement of nodes
     * into ordered basic blocks, since the reference semantics of a function
     * depends on the order of execution with respect to anti-dependencies. This
     * is analogous to global code motion from the original sea of nodes paper.
     *
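     * As an illustrative sketch (not verbatim IR syntax): under value
     * semantics, `w = write(c, d, idx)` produces a fresh collection value,
     * leaving `c` unchanged for its other users; under reference semantics,
     * the same write mutates the object behind `c` in place, so a later
     * reader of `c` would observe `d`. The two interpretations only agree
     * when no other user of `c` executes after the write - this pass
     * establishes exactly that property.
     *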
     * Our strategy for handling multiple mutating users of a collection is to
     * treat the problem like register allocation; we perform a liveness analysis,
     * spill constants into newly allocated constants, and read back the spilled
     * contents when they are used after the first mutation. It's not obvious how
     * many spills are needed upfront, and newly spilled constants may affect the
     * liveness analysis result, so every spill restarts the process of checking for
     * spills. Once no more spills are found, the process terminates. When a spill
     * is found, the basic block assignments, and all the other analyses, are not
     * necessarily valid anymore, so this function is called in a loop in the pass
     * manager until no more spills are found.
     *
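     * As a sketch of one spill: if a collection constant `c` has two mutating
     * users `w1` and `w2`, the pass may allocate a fresh constant, copy `c`
     * into it (a write with empty indices) before `w1` executes, and make
     * `w2` read the spilled contents back, so each user sees the value it
     * would have seen under value semantics.
     *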
     * GCM also generally tries to massage the code to be properly formed for the
     * device backends in other ways. For example, reduction cycles through the
     * `init` inputs of an inner reduce and a use of a non-parallel outer reduce in
     * nested fork-joins are not schedulable, since the outer reduce doesn't
     * dominate the fork corresponding to the inner reduce. In such cases, the
     * outer fork-join must be split and unforkified.
     *
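     * As a concrete shape (hypothetical, for illustration): given an outer
     * fork-join (F1, J1) containing an inner fork-join (F2, J2), an inner
     * reduce on J2 whose `init` uses a non-parallel outer reduce on J1 forms
     * such a cycle - the outer reduce doesn't dominate F2, so the outer
     * fork-join must be split and unforkified first.
     *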
     * GCM is additionally complicated by the need to generate code that references
     * objects across multiple devices. In particular, GCM makes sure that every
     * object lives on exactly one device, so that references to that object always
     * live on a single device. Additionally, GCM makes sure that the objects that a
     * node may produce are all on the same device, so that a pointer produced by,
     * for example, a select node can only refer to memory on a single device. Extra
     * collection constants and potentially inter-device copies are inserted as
     * necessary to make sure this is true - an inter-device copy is represented by
     * a write where the `collect` and `data` inputs are on different devices. This
     * is only valid in RT functions - it is asserted that this isn't necessary in
     * device functions. This process "colors" the nodes in the function.
     *
     * GCM has one final responsibility - object allocation. Each Hercules function
     * receives a pointer to a "backing" memory where collection constants live. The
     * backing memory a function receives is for the constants in that function and
     * the constants of every called function. Concretely, a function will pass a
     * sub-region of its backing memory to a callee, which during the call is that
     * function's backing memory. Object allocation consists of finding the required
     * sizes of all collection constants and functions in terms of dynamic constants
     * (dynamic constant math is expressive enough to represent sizes of types,
     * which is very convenient) and determining the concrete offsets into the
     * backing memory where constants and callee sub-regions live. When two users of
     * backing memory are never live at once, they may share backing memory. This is
     * done after nodes are given a single device color, since we need to know what
     * values are on what devices before we can allocate them to backing memory,
     * since there are separate backing memories per-device.
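     *
     * For example (sizes illustrative only): two constants that are never
     * live at the same time may both be allocated at offset 0 of their
     * device's backing memory, while a callee's sub-region must not overlap
     * any constant that is live across the call.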
     */
    pub fn gcm(
        editor: &mut FunctionEditor,
        def_use: &ImmutableDefUseMap,
        reverse_postorder: &Vec<NodeID>,
        typing: &Vec<TypeID>,
        control_subgraph: &Subgraph,
        dom: &DomTree,
        fork_join_map: &HashMap<NodeID, NodeID>,
        fork_join_nest: &HashMap<NodeID, Vec<NodeID>>,
        loops: &LoopTree,
        reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
        objects: &CollectionObjects,
        devices: &Vec<Device>,
        node_colors: &NodeColors,
        backing_allocations: &BackingAllocations,
    ) -> Option<(BasicBlocks, FunctionNodeColors, FunctionBackingAllocation)> {
    
        if preliminary_fixups(editor, fork_join_map, loops, reduce_cycles) {
            return None;
        }
    
    
        let bbs = basic_blocks(
            editor.func(),
            editor.get_types(),
            editor.func_id(),
            def_use,
            reverse_postorder,
            typing,
            dom,
            loops,
            reduce_cycles,
            fork_join_map,
            objects,
            devices,
        );
    
    
        let liveness = liveness_dataflow(
            editor.func(),
            editor.func_id(),
            control_subgraph,
            objects,
            &bbs,
        );
    
        if spill_clones(editor, typing, control_subgraph, objects, &bbs, &liveness) {
            return None;
        }

        if add_extra_collection_dims(
            editor,
            typing,
            fork_join_map,
            fork_join_nest,
            objects,
            devices,
            &bbs,
        ) {
            return None;
        }
    
    
        let Some(node_colors) = color_nodes(editor, typing, &objects, &devices, node_colors) else {
    
            return None;
        };
    
        // Compute the alignment of each type, for use during backing memory
        // allocation. Control and multi-return types don't correspond to
        // allocatable values, so they're given a dummy alignment of zero.
        // `Ref::map` is used here just to hold the borrow of the type table
        // for the duration of the computation.
        let mut alignments = vec![];
        Ref::map(editor.get_types(), |types| {
            for idx in 0..types.len() {
    
                if types[idx].is_control() || types[idx].is_multireturn() {
    
                    alignments.push(0);
                } else {
                    alignments.push(get_type_alignment(types, TypeID::new(idx)));
                }
            }
            &()
        });
    
        let backing_allocation = object_allocation(
            editor,
            typing,
    
            fork_join_nest,
    
            &node_colors,
            &alignments,
            &liveness,
            backing_allocations,
        );
    
        Some((bbs, node_colors, backing_allocation))
    }
    
    /*
     * Do misc. fixups on the IR, such as unforkifying sequential outer forks with
     * problematic reduces.
     */
    fn preliminary_fixups(
        editor: &mut FunctionEditor,
        fork_join_map: &HashMap<NodeID, NodeID>,
        loops: &LoopTree,
        reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
    ) -> bool {
        let nodes = &editor.func().nodes;
    
        let schedules = &editor.func().schedules;
    
        // Sequentialize non-parallel forks that contain problematic reduce cycles.
    
        for (reduce, cycle) in reduce_cycles {
    
            if !schedules[reduce.idx()].contains(&Schedule::ParallelReduce)
                && cycle.into_iter().any(|id| nodes[id.idx()].is_reduce())
            {
    
                let join = nodes[reduce.idx()].try_reduce().unwrap().0;
                let fork = fork_join_map
                    .into_iter()
                    .filter(|(_, j)| join == **j)
                    .map(|(f, _)| *f)
                    .next()
                    .unwrap();
                let (forks, _) = split_fork(editor, fork, join, reduce_cycles).unwrap();
                if forks.len() > 1 {
                    return true;
                }
                unforkify(editor, fork, join, loops);
                return true;
            }
        }
    
    
        // Get rid of the backward edge on parallel reduces in fork-joins.
        for (_, join) in fork_join_map {
            let parallel_reduces: Vec<_> = editor
                .get_users(*join)
                .filter(|id| {
                    nodes[id.idx()].is_reduce()
                        && schedules[id.idx()].contains(&Schedule::ParallelReduce)
                })
                .collect();
            for reduce in parallel_reduces {
                if reduce_cycles[&reduce].is_empty() {
                    continue;
                }
                let (_, init, _) = nodes[reduce.idx()].try_reduce().unwrap();
    
                // Replace uses of the reduce in its cycle with the init.
                let success = editor.edit(|edit| {
                    edit.replace_all_uses_where(reduce, init, |id| reduce_cycles[&reduce].contains(id))
                });
                assert!(success);
                return true;
            }
        }
        false
    }
    
    
    /*
     * Top level global code motion function. Assigns each data node to one of its
     * immediate control use / user nodes, forming (unordered) basic blocks. Returns
     * the control node / basic block each node is in. Takes in a partial
     * partitioning that must be respected. Based on the schedule-early-schedule-
     * late method from Cliff Click's PhD thesis. May fail if an anti-dependency
     * edge can't be satisfied - in this case, a clone that has to be induced is
     * returned instead.
     */
    fn basic_blocks(
        function: &Function,
        types: Ref<Vec<Type>>,
        func_id: FunctionID,
        def_use: &ImmutableDefUseMap,
        reverse_postorder: &Vec<NodeID>,
        typing: &Vec<TypeID>,
        dom: &DomTree,
        loops: &LoopTree,
        reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
        fork_join_map: &HashMap<NodeID, NodeID>,
        objects: &CollectionObjects,
        devices: &Vec<Device>,
    ) -> BasicBlocks {
        let mut bbs: Vec<Option<NodeID>> = vec![None; function.nodes.len()];
    
        // Step 1: assign the basic block locations of all nodes that must be in a
        // specific block. This includes control nodes as well as some special data
        // nodes, such as phis.
        for idx in 0..function.nodes.len() {
            match function.nodes[idx] {
                Node::Phi { control, data: _ } => bbs[idx] = Some(control),
                Node::ThreadID {
                    control,
                    dimension: _,
                } => bbs[idx] = Some(control),
                Node::Reduce {
                    control,
                    init: _,
                    reduct: _,
                } => bbs[idx] = Some(control),
                Node::Call {
                    control,
                    function: _,
                    dynamic_constants: _,
                    args: _,
                } => bbs[idx] = Some(control),
    
                Node::DataProjection { data, selection: _ } => {
                    let Node::Call { control, .. } = function.nodes[data.idx()] else {
                        panic!();
                    };
                    bbs[idx] = Some(control);
                }
    
    rarbore2's avatar
    rarbore2 committed
                Node::Parameter { index: _ } => bbs[idx] = Some(NodeID::new(0)),
                _ if function.nodes[idx].is_control() => bbs[idx] = Some(NodeID::new(idx)),
                _ => {}
            }
        }
    
        // Step 2: schedule early. Place nodes in the earliest position they could
        // go - use worklist to iterate nodes.
        let mut schedule_early = bbs.clone();
        let mut worklist = VecDeque::from(reverse_postorder.clone());
        while let Some(id) = worklist.pop_front() {
            if schedule_early[id.idx()].is_some() {
                continue;
            }
    
            // For every use, check what block is its "schedule early" block. This
            // node goes in the lowest block amongst those blocks.
            let use_places: Option<Vec<NodeID>> = get_uses(&function.nodes[id.idx()])
                .as_ref()
                .into_iter()
                .map(|id| *id)
                .map(|id| schedule_early[id.idx()])
                .collect();
            if let Some(use_places) = use_places {
                // If every use has been placed, we can place this node at the
                // lowest place in the domtree that dominates all of the use places.
                let lowest = dom.lowest_amongst(use_places.into_iter());
                schedule_early[id.idx()] = Some(lowest);
            } else {
                // If not, then just push this node back on the worklist.
                worklist.push_back(id);
            }
        }
    
        // Step 3: find anti-dependence edges. An anti-dependence edge needs to be
        // drawn between a collection reading node and a collection mutating node
        // when the following conditions are true:
        //
        // 1: The reading and mutating nodes may involve the same collection.
        // 2: The node producing the collection used by the reading node is in a
        //    schedule early block that dominates the schedule early block of the
        //    mutating node. The node producing the collection used by the reading
        //    node may be an originator of a collection, a phi or reduce, or a
        //    mutator, but not a forwarding read - forwarding reads are collapsed,
        //    and the bottom read is treated as reading from the transitive parent
        //    of the forwarding read(s).
        //
        // 3: If the node producing the collection is a reduce node, then any read
        //    users that aren't in the reduce's cycle shouldn't anti-depend on any
        //    mutators in the reduce cycle.
        //
        // Because we do a liveness analysis based spill of collections, anti-
        // dependencies can be best effort. Thus, when we encounter a read and
        // mutator where the read doesn't dominate the mutator, but an anti-
        // dependence edge is derived for the pair, we just don't draw the edge,
        // since it would break the scheduler.
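        //
        // For example (a sketch): if `r` reads an object that `w` writes, and
        // the definition of that object reaching `r` dominates `w`'s schedule
        // early block, the edge (r, w) forces `r` to be ordered before `w`, so
        // `r` observes the pre-mutation contents.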
    
        let mut antideps = BTreeSet::new();
        for id in reverse_postorder.iter() {
            // Find a terminating read node and the collections it reads.
            let terminating_reads: BTreeSet<_> =
                terminating_reads(function, func_id, *id, objects).collect();
            if !terminating_reads.is_empty() {
                // Walk forwarding reads to find anti-dependency roots.
                let mut workset = terminating_reads.clone();
                let mut roots = BTreeSet::new();
                while let Some(pop) = workset.pop_first() {
                    let forwarded: BTreeSet<_> =
                        forwarding_reads(function, func_id, pop, objects).collect();
                    if forwarded.is_empty() {
                        roots.insert(pop);
                    } else {
                        workset.extend(forwarded);
                    }
                }
    
                // For each root, find mutating nodes dominated by the root that
                // modify an object read on any input of the current node (the
                // terminating read).
                // TODO: make this less outrageously inefficient.
                let func_objects = &objects[&func_id];
                for root in roots.iter() {
    
                    let root_is_reduce_and_read_isnt_in_cycle = reduce_cycles
                        .get(root)
                        .map(|cycle| !cycle.contains(id))
                        .unwrap_or(false);
                    let root_early = schedule_early[root.idx()].unwrap();
                    let mut root_block_iterated_users: BTreeSet<NodeID> = BTreeSet::new();
                    let mut workset = BTreeSet::new();
                    workset.insert(*root);
                    while let Some(pop) = workset.pop_first() {
                        let users = def_use.get_users(pop).into_iter().filter(|user| {
                            !function.nodes[user.idx()].is_phi()
                                && !function.nodes[user.idx()].is_reduce()
                                && schedule_early[user.idx()].unwrap() == root_early
                        });
                        workset.extend(users.clone());
                        root_block_iterated_users.extend(users);
                    }
                    let read_objs: BTreeSet<_> = terminating_reads
                        .iter()
                        .map(|read_use| func_objects.objects(*read_use).into_iter())
                        .flatten()
                        .map(|id| *id)
                        .collect();
                    for mutator in reverse_postorder.iter() {
                        let mutator_early = schedule_early[mutator.idx()].unwrap();
                        if dom.does_dom(root_early, mutator_early)
                            && (root_early != mutator_early
                                || root_block_iterated_users.contains(mutator))
                            && mutating_objects(function, func_id, *mutator, objects)
                                .any(|mutated| read_objs.contains(&mutated))
                            && id != mutator
                            && (!root_is_reduce_and_read_isnt_in_cycle
                                || !reduce_cycles
                                    .get(root)
                                    .map(|cycle| cycle.contains(mutator))
                                    .unwrap_or(false))
                            && dom.does_dom(schedule_early[id.idx()].unwrap(), mutator_early)
                        {
                            antideps.insert((*id, *mutator));
                        }
                    }
                }
            }
        }
        let mut antideps_uses = vec![vec![]; function.nodes.len()];
        let mut antideps_users = vec![vec![]; function.nodes.len()];
        for (reader, mutator) in antideps.iter() {
            antideps_uses[mutator.idx()].push(*reader);
            antideps_users[reader.idx()].push(*mutator);
        }
    
        // Step 4: schedule late and pick each node's final position. Since the late
        // schedule of each node depends on the final positions of its users, these
        // two steps must be fused. Compute their latest position, then use the
        // control dependent + shallow loop heuristic to actually place them. A
        // placement might not necessarily be found due to anti-dependency edges.
        // These are optional and not necessary to consider, but we do since obeying
        // them can reduce the number of clones. If the worklist stops making
        // progress, stop considering the anti-dependency edges.
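        //
        // For example (a sketch): a node whose schedule early block is before a
        // loop but whose users' LCA is inside the loop gets placed at the
        // shallowest block on the dominator chain between the two, hoisting it
        // out of the loop when that's legal.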
        let join_fork_map: HashMap<NodeID, NodeID> = fork_join_map
            .into_iter()
            .map(|(fork, join)| (*join, *fork))
            .collect();
        let mut worklist = VecDeque::from_iter(reverse_postorder.into_iter().map(|id| *id).rev());
        let mut num_skip_iters = 0;
        let mut consider_antidependencies = true;
        while let Some(id) = worklist.pop_front() {
            if num_skip_iters >= worklist.len() {
                consider_antidependencies = false;
            }
    
            if bbs[id.idx()].is_some() {
                num_skip_iters = 0;
                continue;
            }
    
            // Calculate the least common ancestor of user blocks, a.k.a. the "late"
            // schedule.
            let calculate_lca = || -> Option<_> {
                let mut lca = None;
                // Helper to incrementally update the LCA.
                let mut update_lca = |a| {
                    if let Some(acc) = lca {
                        lca = Some(dom.least_common_ancestor(acc, a));
                    } else {
                        lca = Some(a);
                    }
                };
    
                // For every user, consider where we need to be to directly dominate the
                // user.
                for user in def_use
                    .get_users(id)
                    .as_ref()
                    .into_iter()
                    .chain(if consider_antidependencies {
                        Either::Left(antideps_users[id.idx()].iter())
                    } else {
                        Either::Right(empty())
                    })
                    .map(|id| *id)
                {
                    if let Node::Phi { control, data } = &function.nodes[user.idx()] {
                        // For phis, we need to dominate the block jumping to the phi in
                        // the slot that corresponds to our use.
                        for (control, data) in
                            zip(get_uses(&function.nodes[control.idx()]).as_ref(), data)
                        {
                            if id == *data {
                                update_lca(*control);
                            }
                        }
                    } else if let Node::Reduce {
                        control,
                        init,
                        reduct,
                    } = &function.nodes[user.idx()]
                    {
                        // For reduces, we need to either dominate the block right
                        // before the fork if we're the init input, or we need to
                        // dominate the join if we're the reduct input.
                        if id == *init {
                            let before_fork = function.nodes[join_fork_map[control].idx()]
                                .try_fork()
                                .unwrap()
                                .0;
                            update_lca(before_fork);
                        } else {
                            assert_eq!(id, *reduct);
                            update_lca(*control);
                        }
                    } else {
                        // For everything else, we just need to dominate the user.
                        update_lca(bbs[user.idx()]?);
                    }
                }
    
                Some(lca)
            };
    
            // Check if all users have been placed. If one of them hasn't, then add
            // this node back on to the worklist.
            let Some(lca) = calculate_lca() else {
                worklist.push_back(id);
                num_skip_iters += 1;
                continue;
            };
    
            // Look between the LCA and the schedule early location to place the
            // node.
            let schedule_early = schedule_early[id.idx()].unwrap();
    
            // If the node has no users, then it doesn't really matter where we
            // place it - just place it at the early placement.
    
            let schedule_late = lca.unwrap_or(schedule_early);
    
            let mut chain = dom.chain(schedule_late, schedule_early);
    
    
            if let Some(mut location) = chain.next() {
                while let Some(control_node) = chain.next() {
                    // If the next node further up the dominator tree is in a shallower
                    // loop nest or if we can get out of a reduce loop when we don't
                    // need to be in one, place this data node in a higher-up location.
    
                    // Only do this if the node isn't a constant or undef - if a
                    // node is a constant or undef, we want its placement to be as
                    // control dependent as possible, even inside loops. In GPU
                    // functions specifically, lift constants that may be returned
                    // outside fork-joins.
                    let is_constant_or_undef = (function.nodes[id.idx()].is_constant()
                        || function.nodes[id.idx()].is_undef())
                        && !types[typing[id.idx()].idx()].is_primitive();
    
                    let is_gpu_returned = devices[func_id.idx()] == Device::CUDA
    
                        && objects[&func_id].objects(id).into_iter().any(|obj| {
                            objects[&func_id]
                                .all_returned_objects()
                                .any(|ret| ret == *obj)
                        });
    
                    let old_nest = loops
                        .header_of(location)
                        .map(|header| loops.nesting(header).unwrap());
                    let new_nest = loops
                        .header_of(control_node)
                        .map(|header| loops.nesting(header).unwrap());
                    let shallower_nest = if let (Some(old_nest), Some(new_nest)) = (old_nest, new_nest)
                    {
                        old_nest > new_nest
                    } else {
                        // If the new location isn't a loop, its nesting level should
                        // be considered "shallower" if the current location is in a
                        // loop.
                        old_nest.is_some()
                    };
                    // This will move all nodes that don't need to be in reduce loops
                    // outside of reduce loops. Nodes that do need to be in a reduce
                    // loop use the reduce node forming the loop, so the dominator chain
                    // will consist of one block, and this loop won't ever iterate.
    
                    let currently_at_join = function.nodes[location.idx()].is_join()
                        && !function.nodes[control_node.idx()].is_join();
    
    
                    if (!is_constant_or_undef || is_gpu_returned)
                        && (shallower_nest || currently_at_join)
                    {
    
                        location = control_node;
                    }
                }
    
                bbs[id.idx()] = Some(location);
                num_skip_iters = 0;
            } else {
                // If there is no valid location for this node, then it's a reading
                // node of a collection that can't be placed above a mutating node
                // that anti-depends on it. Push the node back on the list, and we'll
                // stop considering anti-dependencies soon. Don't immediately stop
                // considering anti-dependencies, as we may be able to eke out some
                // more use of them.
                worklist.push_back(id);
                num_skip_iters += 1;
                continue;
            }
        }
        let bbs: Vec<_> = bbs.into_iter().map(Option::unwrap).collect();
        // Calculate the number of phis and reduces per basic block. We use this to
        // emit phis and reduces at the top of basic blocks. We want to emit phis
        // and reduces first into ordered basic blocks for two reasons:
        // 1. This is useful for liveness analysis.
        // 2. This is needed for some backends - LLVM expects phis to be at the top
        //    of basic blocks.
        let mut num_phis_reduces = vec![0; function.nodes.len()];
        for (node_idx, bb) in bbs.iter().enumerate() {
            let node = &function.nodes[node_idx];
            if node.is_phi() || node.is_reduce() {
                num_phis_reduces[bb.idx()] += 1;
            }
        }
    
        // Step 5: determine the order of nodes inside each block. Use worklist to
        // add nodes to blocks in order that obeys dependencies.
        let mut order: Vec<Vec<NodeID>> = vec![vec![]; function.nodes.len()];
        let mut worklist = VecDeque::from_iter(
            reverse_postorder
                .into_iter()
                .filter(|id| !function.nodes[id.idx()].is_control()),
        );
        let mut visited = bitvec![u8, Lsb0; 0; function.nodes.len()];
        let mut num_skip_iters = 0;
        let mut consider_antidependencies = true;
        while let Some(id) = worklist.pop_front() {
            // If the worklist isn't making progress, then there's at least one
            // reading node of a collection that is in an anti-depend + normal depend
            // use cycle with a mutating node. See above comment about anti-
            // dependencies being optional; we just stop considering them here.
            if num_skip_iters >= worklist.len() {
                consider_antidependencies = false;
            }
    
            // Phis and reduces always get emitted. Other nodes need to obey
            // dependency relationships and need to come after phis and reduces.
            let node = &function.nodes[id.idx()];
            let bb = bbs[id.idx()];
            if node.is_phi()
                || node.is_reduce()
                || (num_phis_reduces[bb.idx()] == 0
                    && get_uses(node)
                        .as_ref()
                        .into_iter()
                        .chain(if consider_antidependencies {
                            Either::Left(antideps_uses[id.idx()].iter())
                        } else {
                            Either::Right(empty())
                        })
                        .all(|u| {
                            function.nodes[u.idx()].is_control()
                                || bbs[u.idx()] != bbs[id.idx()]
                                || visited[u.idx()]
                        }))
            {
                order[bb.idx()].push(*id);
                visited.set(id.idx(), true);
                num_skip_iters = 0;
                if node.is_phi() || node.is_reduce() {
                    num_phis_reduces[bb.idx()] -= 1;
                }
            } else {
                worklist.push_back(id);
                num_skip_iters += 1;
            }
        }
    
        (bbs, order)
    }
    
    fn terminating_reads<'a>(
        function: &'a Function,
        func_id: FunctionID,
        reader: NodeID,
        objects: &'a CollectionObjects,
    ) -> Box<dyn Iterator<Item = NodeID> + 'a> {
        match function.nodes[reader.idx()] {
            Node::Read {
                collect,
                indices: _,
            } if objects[&func_id].objects(reader).is_empty() => Box::new(once(collect)),
            Node::Write {
                collect: _,
                data,
                indices: _,
            } if !objects[&func_id].objects(data).is_empty() => Box::new(once(data)),
            Node::Call {
                control: _,
                function: callee,
                dynamic_constants: _,
                ref args,
            } => Box::new(args.into_iter().enumerate().filter_map(move |(idx, arg)| {
                let objects = &objects[&callee];
    
                let mut returns = objects.all_returned_objects();
                let param_obj = objects.param_to_object(idx)?;
                if !objects.is_mutated(param_obj) && !returns.any(|ret| ret == param_obj) {
                    Some(*arg)
                } else {
                    None
                }
            })),
    
            Node::LibraryCall {
                library_function,
                ref args,
                ty: _,
                device: _,
            } => match library_function {
                LibraryFunction::GEMM => Box::new(once(args[1]).chain(once(args[2]))),
            },
    
            _ => Box::new(empty()),
        }
    }
    
    fn forwarding_reads<'a>(
        function: &'a Function,
        func_id: FunctionID,
        reader: NodeID,
        objects: &'a CollectionObjects,
    ) -> Box<dyn Iterator<Item = NodeID> + 'a> {
        match function.nodes[reader.idx()] {
            Node::Read {
                collect,
                indices: _,
            } if !objects[&func_id].objects(reader).is_empty() => Box::new(once(collect)),
            Node::Ternary {
                op: TernaryOperator::Select,
                first: _,
                second,
                third,
            } if !objects[&func_id].objects(reader).is_empty() => {
                Box::new(once(second).chain(once(third)))
            }
            Node::Call {
                control: _,
                function: callee,
                dynamic_constants: _,
                ref args,
            } => Box::new(args.into_iter().enumerate().filter_map(move |(idx, arg)| {
                let objects = &objects[&callee];
    
                let mut returns = objects.all_returned_objects();
                let param_obj = objects.param_to_object(idx)?;
                if !objects.is_mutated(param_obj) && returns.any(|ret| ret == param_obj) {
                    Some(*arg)
                } else {
                    None
                }
            })),
            _ => Box::new(empty()),
        }
    }
    
    fn mutating_objects<'a>(
        function: &'a Function,
        func_id: FunctionID,
        mutator: NodeID,
        objects: &'a CollectionObjects,
    ) -> Box<dyn Iterator<Item = CollectionObjectID> + 'a> {
        match function.nodes[mutator.idx()] {
            Node::Write {
                collect,
                data: _,
                indices: _,
            } => Box::new(objects[&func_id].objects(collect).into_iter().map(|id| *id)),
            Node::Call {
                control: _,
                function: callee,
                dynamic_constants: _,
                ref args,
            } => Box::new(
                args.into_iter()
                    .enumerate()
                    .filter_map(move |(idx, arg)| {
                        let callee_objects = &objects[&callee];
                        let param_obj = callee_objects.param_to_object(idx)?;
                        if callee_objects.is_mutated(param_obj) {
                            Some(objects[&func_id].objects(*arg).into_iter().map(|id| *id))
                        } else {
                            None
                        }
                    })
                    .flatten(),
            ),
    
            Node::LibraryCall {
                library_function,
                ref args,
                ty: _,
                device: _,
            } => match library_function {
                LibraryFunction::GEMM => {
                    Box::new(objects[&func_id].objects(args[0]).into_iter().map(|id| *id))
                }
            },
    
            _ => Box::new(empty()),
        }
    }
    
    
    fn mutating_writes<'a>(
        function: &'a Function,
        mutator: NodeID,
        objects: &'a CollectionObjects,
    ) -> Box<dyn Iterator<Item = NodeID> + 'a> {
        match function.nodes[mutator.idx()] {
            Node::Write {
                collect,
                data: _,
                indices: _,
            } => Box::new(once(collect)),
            Node::Call {
                control: _,
                function: callee,
                dynamic_constants: _,
                ref args,
            } => Box::new(args.into_iter().enumerate().filter_map(move |(idx, arg)| {
                let callee_objects = &objects[&callee];
                let param_obj = callee_objects.param_to_object(idx)?;
                if callee_objects.is_mutated(param_obj) {
                    Some(*arg)
                } else {
                    None
                }
            })),
    
            Node::LibraryCall {
                library_function,
                ref args,
                ty: _,
                device: _,
            } => match library_function {
                LibraryFunction::GEMM => Box::new(once(args[0])),
            },
    
            _ => Box::new(empty()),
        }
    }
    
    
    /*
     * Top level function to find implicit clones that need to be spilled. Returns
     * whether a clone was spilled, in which case the whole scheduling process must
     * be restarted.
     */
    fn spill_clones(
        editor: &mut FunctionEditor,
        typing: &Vec<TypeID>,
        control_subgraph: &Subgraph,
        objects: &CollectionObjects,
        bbs: &BasicBlocks,
        liveness: &Liveness,
    ) -> bool {
    
        // Step 1: compute an interference graph from the liveness result. This
        // graph contains a vertex per node ID producing a collection value and an
        // edge per pair of node IDs that interfere. Nodes A and B interfere if node
        // A is defined right above a point where node B is live and A != B. Extra
        // edges are drawn for forwarding reads - when there is a node A that is a
        // forwarding read of a node B, A and B really have the same live range for
        // the purpose of determining when spills are necessary, since forwarding
        // reads can be thought of as nothing but pointer math. For this purpose, we
        // maintain a union-find of nodes that form a forwarding read DAG (notably,
        // phis and reduces are not considered forwarding reads). The more precise
        // version of the interference condition is: nodes A and B interfere if
        // node A is defined right above a point where a node C is live, where C
        // is in the same union-find class as B.
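        //
        // For example (a sketch): if `a = read(b, ...)` merely forwards `b`,
        // then `a` and `b` share a union-find class, so a node defined while
        // `a` is live interferes with `b` too, even at points where `b` has no
        // direct uses left.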
    
        // Assemble the union-find to group forwarding read DAGs.
        let mut union_find = QuickFindUf::<UnionBySize>::new(editor.func().nodes.len());
        for id in editor.node_ids() {
            for forwarding_read in forwarding_reads(editor.func(), editor.func_id(), id, objects) {
                union_find.union(id.idx(), forwarding_read.idx());
            }
        }
    
        // Figure out which classes contain which node IDs, since we need to iterate
        // the disjoint sets.
        let mut disjoint_sets: BTreeMap<usize, Vec<NodeID>> = BTreeMap::new();
        for id in editor.node_ids() {
            disjoint_sets
                .entry(union_find.find(id.idx()))
                .or_default()
                .push(id);
        }
    
        // Create the graph.
    
        let mut edges = vec![];
        for (bb, liveness) in liveness {
            let insts = &bbs.1[bb.idx()];
            for (node, live) in zip(insts, liveness.into_iter().skip(1)) {
                for live_node in live {
    
                    for live_node in disjoint_sets[&union_find.find(live_node.idx())].iter() {
                        if *node != *live_node {
                            edges.push((*node, *live_node));
                        }
                    }
                }
            }
        }

        // Step 2: filter edges (A, B) down to just the edges where A uses B
        // and A mutates B. These are the edges that may require a spill.
        let mut spill_edges = edges.into_iter().filter(|(a, b)| {
    
            mutating_writes(editor.func(), *a, objects).any(|id| id == *b)
                || (get_uses(&editor.func().nodes[a.idx()])
                    .as_ref()
                    .into_iter()
                    .any(|u| *u == *b)
                    && (editor.func().nodes[a.idx()].is_phi()
                        || editor.func().nodes[a.idx()].is_reduce())
                    && !editor.func().nodes[a.idx()]
                        .try_reduce()
    
    rarbore2's avatar
    rarbore2 committed
                        .map(|(_, init, reduct)| {
                            (init == *b || reduct == *b)
                                && editor.func().schedules[a.idx()].contains(&Schedule::ParallelReduce)
                        })
                        .unwrap_or(false)
                    && !editor.func().nodes[a.idx()]
                        .try_phi()
                        .map(|(_, data)| {
                            data.contains(b)
                                && editor.func().schedules[a.idx()].contains(&Schedule::ParallelReduce)
                        })
                        .unwrap_or(false))
        });

        // Step 3: if there is a spill edge, spill it and return true. Otherwise,
        // return false.
        if let Some((user, obj)) = spill_edges.next() {
            // Figure out the most immediate dominating region for every basic
            // block. These are the points where spill slot phis get placed.
            let nodes = &editor.func().nodes;
            let mut imm_dom_reg = vec![NodeID::new(0); editor.func().nodes.len()];
            for (idx, node) in nodes.into_iter().enumerate() {
                if node.is_region() {
                    imm_dom_reg[idx] = NodeID::new(idx);
                }
            }
            let rev_po = control_subgraph.rev_po(NodeID::new(0));
            for bb in rev_po.iter() {
                if !nodes[bb.idx()].is_region() && !nodes[bb.idx()].is_start() {
                    imm_dom_reg[bb.idx()] =
                        imm_dom_reg[control_subgraph.preds(*bb).next().unwrap().idx()];
                }
            }
    
            let other_obj_users: Vec<_> = editor.get_users(obj).filter(|id| *id != user).collect();
            let mut dummy_phis = vec![NodeID::new(0); imm_dom_reg.len()];
            let mut success = editor.edit(|mut edit| {
                // Construct the spill slot. This is just a constant that gets phi-
                // ed throughout the entire function.
                let cons_id = edit.add_zero_constant(typing[obj.idx()]);
                let slot_id = edit.add_node(Node::Constant { id: cons_id });

                edit = edit.add_schedule(slot_id, Schedule::NoResetConstant)?;

                // Allocate IDs for phis that move the spill slot throughout the
                // function without implicit clones. These are dummy phis, since
                // there are potentially cycles between them. We will replace them
                // later.
                for (idx, reg) in imm_dom_reg.iter().enumerate().skip(1) {
                    if idx == reg.idx() {
                        dummy_phis[idx] = edit.add_node(Node::Phi {
                            control: *reg,
                            data: empty().collect(),
                        });
                    }
                }
    
                // Spill `obj` before `user` potentially modifies it.
                let spill_region = imm_dom_reg[bbs.0[obj.idx()].idx()];
                let spill_id = edit.add_node(Node::Write {
                    collect: if spill_region == NodeID::new(0) {
                        slot_id
                    } else {
                        dummy_phis[spill_region.idx()]
                    },
                    data: obj,
                    indices: empty().collect(),
                });
    
                // Before each other user, unspill `obj`.
                for other_user in other_obj_users {
                    let other_region = imm_dom_reg[bbs.0[other_user.idx()].idx()];
                    // If this assert fails, then `obj` is not in the first basic
                    // block, but it has a user that is in the first basic block,
                    // which violates SSA.
                    assert!(other_region == spill_region || other_region != NodeID::new(0));
    
                    // If an other user is a phi, we need to be a little careful
                    // about how we insert unspilling code for `obj`. Instead of
                    // inserting an unspill in the same block as the user, we need
                    // to insert one in each predecessor of the phi that corresponds
                    // to a use of `obj`. Since this requires modifying individual
                    // uses in a phi, just rebuild the node entirely.
                    if let Node::Phi { control, data } = edit.get_node(other_user).clone() {
                        assert_eq!(control, other_region);
                        let mut new_data = vec![];
                        for (pred, data) in zip(control_subgraph.preds(control), data) {
                            let pred = imm_dom_reg[pred.idx()];
                            if data == obj {
                                let unspill_id = edit.add_node(Node::Write {
                                    collect: obj,
                                    data: if pred == spill_region {
                                        spill_id
                                    } else {
                                        dummy_phis[pred.idx()]
                                    },
                                    indices: empty().collect(),
                                });
                                new_data.push(unspill_id);
                            } else {
                                new_data.push(data);
                            }
                        }
                        let new_phi = edit.add_node(Node::Phi {
                            control,
                            data: new_data.into_boxed_slice(),
                        });
                        edit = edit.replace_all_uses(other_user, new_phi)?;
                        edit = edit.delete_node(other_user)?;
                    } else {
                        let unspill_id = edit.add_node(Node::Write {
                            collect: obj,
                            data: if other_region == spill_region {
                                spill_id
                            } else {
                                dummy_phis[other_region.idx()]
                            },
                            indices: empty().collect(),
                        });
                        edit = edit.replace_all_uses_where(obj, unspill_id, |id| *id == other_user)?;