use std::cell::Ref;
use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet, VecDeque};
use std::iter::{empty, once, zip, FromIterator};

use bitvec::prelude::*;
use either::Either;
use union_find::{QuickFindUf, UnionBySize, UnionFind};

use hercules_cg::*;
use hercules_ir::*;

use crate::*;

/*
 * Top level function to legalize the reference semantics of a Hercules IR
 * function. Hercules IR is a value semantics representation, meaning that all
 * program state is in the form of copyable values, and mutation takes place by
 * making a new value that is a copy of the old value with some modification.
 * This representation is extremely convenient for optimization, but is not good
 * for code generation, where we need to generate code with references to get
 * good performance. Hercules IR can alternatively be interpreted using
 * reference semantics, where pointers to collection objects are passed around,
 * read from, and written to. However, the value semantics and reference
 * semantics interpretation of a Hercules IR function may not be equal - this
 * pass transforms a Hercules IR function such that its new value semantics is
 * the same as its old value semantics and that its new reference semantics is
 * the same as its new value semantics. This pass returns a placement of nodes
 * into ordered basic blocks, since the reference semantics of a function
 * depends on the order of execution with respect to anti-dependencies. This
 * is analogous to global code motion from the original sea of nodes paper.
 *
 * Our strategy for handling multiple mutating users of a collection is to treat
 * the problem similar to register allocation; we perform a liveness analysis,
 * spill constants into newly allocated constants, and read back the spilled
 * contents when they are used after the first mutation. It's not obvious how
 * many spills are needed upfront, and newly spilled constants may affect the
 * liveness analysis result, so every spill restarts the process of checking for
 * spills. Once no more spills are found, the process terminates. When a spill
 * is found, the basic block assignments, and all the other analyses, are not
 * necessarily valid anymore, so this function is called in a loop in the pass
 * manager until no more spills are found.
 *
 * GCM also generally tries to massage the code to be properly formed for the
 * device backends in other ways. For example, a reduction cycle through the
 * `init` input of an inner reduce and a use of a non-parallel outer reduce in
 * nested fork-joins is not schedulable, since the outer reduce doesn't dominate
 * the fork corresponding to the inner reduce. In such cases, the outer fork-
 * join must be split and unforkified.
 *
 * GCM is additionally complicated by the need to generate code that references
 * objects across multiple devices. In particular, GCM makes sure that every
 * object lives on exactly one device, so that references to that object always
 * live on a single device. Additionally, GCM makes sure that the objects that a
 * node may produce are all on the same device, so that a pointer produced by,
 * for example, a select node can only refer to memory on a single device. Extra
 * collection constants and potentially inter-device copies are inserted as
 * necessary to make sure this is true - an inter-device copy is represented by
 * a write where the `collect` and `data` inputs are on different devices. This
 * is only valid in RT functions - it is asserted that this isn't necessary in
 * device functions. This process "colors" the nodes in the function.
 *
 * GCM has one final responsibility - object allocation. Each Hercules function
 * receives a pointer to a "backing" memory where collection constants live. The
 * backing memory a function receives is for the constants in that function and
 * the constants of every called function. Concretely, a function will pass a
 * sub-region of its backing memory to a callee, which during the call is that
 * function's backing memory. Object allocation consists of finding the required
 * sizes of all collection constants and functions in terms of dynamic constants
 * (dynamic constant math is expressive enough to represent sizes of types,
 * which is very convenient) and determining the concrete offsets into the
 * backing memory where constants and callee sub-regions live. When two users of
 * backing memory are never live at once, they may share backing memory. This is
 * done after nodes are given a single device color, since we need to know what
 * values are on what devices before we can allocate them to backing memory,
 * since there are separate backing memories per-device.
 */
pub fn gcm(
    editor: &mut FunctionEditor,
    def_use: &ImmutableDefUseMap,
    reverse_postorder: &Vec<NodeID>,
    typing: &Vec<TypeID>,
    control_subgraph: &Subgraph,
    dom: &DomTree,
    fork_join_map: &HashMap<NodeID, NodeID>,
    loops: &LoopTree,
    reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
    objects: &CollectionObjects,
    devices: &Vec<Device>,
    node_colors: &NodeColors,
    backing_allocations: &BackingAllocations,
) -> Option<(BasicBlocks, FunctionNodeColors, FunctionBackingAllocation)> {
    // `preliminary_fixups` and `spill_clones` return true after editing the
    // IR. In that case the analyses passed in are stale, so bail out with
    // `None` and let the caller rerun GCM.
    if preliminary_fixups(editor, fork_join_map, loops, reduce_cycles) {
        return None;
    }

    let bbs = basic_blocks(
        editor.func(),
        editor.get_types(),
        editor.func_id(),
        def_use,
        reverse_postorder,
        typing,
        dom,
        loops,
        reduce_cycles,
        fork_join_map,
        objects,
        devices,
    );

    let liveness = liveness_dataflow(
        editor.func(),
        editor.func_id(),
        control_subgraph,
        objects,
        &bbs,
    );

    if spill_clones(editor, typing, control_subgraph, objects, &bbs, &liveness) {
        return None;
    }

    // A failed coloring is propagated to the caller unchanged.
    let node_colors = color_nodes(editor, typing, objects, devices, node_colors)?;

    // Compute the alignment of every type, indexed by type ID. Control types
    // get a placeholder alignment of 0; `get_type_alignment` is only queried
    // for non-control types.
    let types = editor.get_types();
    let alignments: Vec<_> = (0..types.len())
        .map(|idx| {
            if types[idx].is_control() {
                0
            } else {
                get_type_alignment(&types, TypeID::new(idx))
            }
        })
        .collect();
    // Release the `RefCell` borrow of the type table before `editor` is
    // borrowed mutably by `object_allocation`.
    drop(types);

    let backing_allocation = object_allocation(
        editor,
        typing,
        &node_colors,
        &alignments,
        &liveness,
        backing_allocations,
    );

    Some((bbs, node_colors, backing_allocation))
}

/*
 * Do misc. fixups on the IR, such as unforkifying sequential outer forks with
 * problematic reduces. At most one fixup is applied per call; returns true if
 * the IR was edited (in which case all analyses are stale), false otherwise.
 *
 * Candidates are visited in ascending node ID order rather than `HashMap`
 * iteration order, so that which fixup is applied first - and therefore the
 * resulting IR - is deterministic from run to run.
 */
fn preliminary_fixups(
    editor: &mut FunctionEditor,
    fork_join_map: &HashMap<NodeID, NodeID>,
    loops: &LoopTree,
    reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
) -> bool {
    let nodes = &editor.func().nodes;
    let schedules = &editor.func().schedules;

    let mut sorted_reduces: Vec<NodeID> = reduce_cycles.keys().copied().collect();
    sorted_reduces.sort_unstable();
    let mut sorted_joins: Vec<NodeID> = fork_join_map.values().copied().collect();
    sorted_joins.sort_unstable();

    // Sequentialize non-parallel forks that contain problematic reduce cycles,
    // i.e. reduce cycles that pass through another reduce.
    for reduce in sorted_reduces.iter().copied() {
        let cycle = &reduce_cycles[&reduce];
        if schedules[reduce.idx()].contains(&Schedule::ParallelReduce)
            || !cycle.iter().any(|id| nodes[id.idx()].is_reduce())
        {
            continue;
        }

        let join = nodes[reduce.idx()].try_reduce().unwrap().0;
        let fork = fork_join_map
            .iter()
            .find(|(_, j)| **j == join)
            .map(|(f, _)| *f)
            .unwrap();
        let (forks, _) = split_fork(editor, fork, join, reduce_cycles).unwrap();
        // If the fork was actually split, the IR already changed and `fork` /
        // `join` no longer describe a single fork-join - stop here and let the
        // next invocation handle the pieces. Otherwise, unforkify it.
        if forks.len() <= 1 {
            unforkify(editor, fork, join, loops);
        }
        return true;
    }

    // Get rid of the backward edge on parallel reduces in fork-joins.
    for join in sorted_joins {
        let parallel_reduce = editor.get_users(join).find(|id| {
            nodes[id.idx()].is_reduce()
                && schedules[id.idx()].contains(&Schedule::ParallelReduce)
                && !reduce_cycles[id].is_empty()
        });
        let Some(reduce) = parallel_reduce else {
            continue;
        };
        let (_, init, _) = nodes[reduce.idx()].try_reduce().unwrap();
        let cycle = &reduce_cycles[&reduce];

        // Replace uses of the reduce in its cycle with the init.
        let success = editor.edit(|edit| {
            edit.replace_all_uses_where(reduce, init, |id| cycle.contains(id))
        });
        assert!(success);
        return true;
    }

    false
}

/*
 * Top level global code motion function. Assigns each data node to one of its
 * immediate control use / user nodes, forming basic blocks, and then orders
 * the data nodes inside each basic block. Returns the control node / basic
 * block each node is in, along with, for each basic block (indexed by its
 * control node), the ordered list of non-control nodes placed in it - phis and
 * reduces always come first. Based on the schedule-early-schedule-late method
 * from Cliff Click's PhD thesis.
 *
 * Anti-dependency edges between collection reads and mutations are obeyed on
 * a best effort basis: if honoring them stops the placement or ordering
 * worklists from making progress, they are no longer considered. Any
 * resulting conflicts are meant to be resolved by the liveness-based spilling
 * that runs after this function.
 */
fn basic_blocks(
    function: &Function,
    types: Ref<Vec<Type>>,
    func_id: FunctionID,
    def_use: &ImmutableDefUseMap,
    reverse_postorder: &Vec<NodeID>,
    typing: &Vec<TypeID>,
    dom: &DomTree,
    loops: &LoopTree,
    reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
    fork_join_map: &HashMap<NodeID, NodeID>,
    objects: &CollectionObjects,
    devices: &Vec<Device>,
) -> BasicBlocks {
    let mut bbs: Vec<Option<NodeID>> = vec![None; function.nodes.len()];

    // Step 1: assign the basic block locations of all nodes that must be in a
    // specific block. This includes control nodes as well as some special data
    // nodes, such as phis.
    for idx in 0..function.nodes.len() {
        match function.nodes[idx] {
            Node::Phi { control, data: _ } => bbs[idx] = Some(control),
            Node::ThreadID {
                control,
                dimension: _,
            } => bbs[idx] = Some(control),
            Node::Reduce {
                control,
                init: _,
                reduct: _,
            } => bbs[idx] = Some(control),
            Node::Call {
                control,
                function: _,
                dynamic_constants: _,
                args: _,
            } => bbs[idx] = Some(control),
            Node::Parameter { index: _ } => bbs[idx] = Some(NodeID::new(0)),
            _ if function.nodes[idx].is_control() => bbs[idx] = Some(NodeID::new(idx)),
            _ => {}
        }
    }

    // Step 2: schedule early. Place nodes in the earliest position they could
    // go - use worklist to iterate nodes.
    let mut schedule_early = bbs.clone();
    let mut worklist = VecDeque::from(reverse_postorder.clone());
    while let Some(id) = worklist.pop_front() {
        if schedule_early[id.idx()].is_some() {
            continue;
        }

        // For every use, check what block is its "schedule early" block. This
        // node goes in the lowest block amongst those blocks.
        let use_places: Option<Vec<NodeID>> = get_uses(&function.nodes[id.idx()])
            .as_ref()
            .into_iter()
            .map(|id| *id)
            .map(|id| schedule_early[id.idx()])
            .collect();
        if let Some(use_places) = use_places {
            // If every use has been placed, we can place this node as the
            // lowest place in the domtree that dominates all of the use places.
            let lowest = dom.lowest_amongst(use_places.into_iter());
            schedule_early[id.idx()] = Some(lowest);
        } else {
            // If not, then just push this node back on the worklist.
            worklist.push_back(id);
        }
    }

    // Step 3: find anti-dependence edges. An anti-dependence edge needs to be
    // drawn between a collection reading node and a collection mutating node
    // when the following conditions are true:
    //
    // 1: The reading and mutating nodes may involve the same collection.
    // 2: The node producing the collection used by the reading node is in a
    //    schedule early block that dominates the schedule early block of the
    //    mutating node. The node producing the collection used by the reading
    //    node may be an originator of a collection, phi or reduce, or mutator,
    //    but not forwarding read - forwarding reads are collapsed, and the
    //    bottom read is treated as reading from the transitive parent of the
    //    forwarding read(s).
    // 3: If the node producing the collection is a reduce node, then any read
    //    users that aren't in the reduce's cycle shouldn't anti-depend on any
    //    mutators in the reduce cycle.
    //
    // Because we do a liveness analysis based spill of collections, anti-
    // dependencies can be best effort. Thus, when we encounter a read and
    // mutator where the read doesn't dominate the mutator, but an anti-
    // dependence edge is derived for the pair, we just don't draw the edge
    // since it would break the scheduler.
    let mut antideps = BTreeSet::new();
    for id in reverse_postorder.iter() {
        // Find a terminating read node and the collections it reads.
        let terminating_reads: BTreeSet<_> =
            terminating_reads(function, func_id, *id, objects).collect();
        if !terminating_reads.is_empty() {
            // Walk forwarding reads to find anti-dependency roots.
            let mut workset = terminating_reads.clone();
            let mut roots = BTreeSet::new();
            while let Some(pop) = workset.pop_first() {
                let forwarded: BTreeSet<_> =
                    forwarding_reads(function, func_id, pop, objects).collect();
                if forwarded.is_empty() {
                    roots.insert(pop);
                } else {
                    workset.extend(forwarded);
                }
            }

            // The objects read by the terminating read don't depend on the
            // root, so compute them once per read rather than once per root.
            let func_objects = &objects[&func_id];
            let read_objs: BTreeSet<_> = terminating_reads
                .iter()
                .map(|read_use| func_objects.objects(*read_use).into_iter())
                .flatten()
                .map(|id| *id)
                .collect();

            // For each root, find mutating nodes dominated by the root that
            // modify an object read on any input of the current node (the
            // terminating read).
            // TODO: make this less outrageously inefficient.
            for root in roots.iter() {
                let root_is_reduce_and_read_isnt_in_cycle = reduce_cycles
                    .get(root)
                    .map(|cycle| !cycle.contains(&id))
                    .unwrap_or(false);
                let root_early = schedule_early[root.idx()].unwrap();

                // Inside the root's own schedule early block, only mutators
                // transitively reachable from the root through non-phi,
                // non-reduce users in that same block are considered.
                let mut root_block_iterated_users: BTreeSet<NodeID> = BTreeSet::new();
                let mut workset = BTreeSet::new();
                workset.insert(*root);
                while let Some(pop) = workset.pop_first() {
                    let users = def_use.get_users(pop).into_iter().filter(|user| {
                        !function.nodes[user.idx()].is_phi()
                            && !function.nodes[user.idx()].is_reduce()
                            && schedule_early[user.idx()].unwrap() == root_early
                    });
                    workset.extend(users.clone());
                    root_block_iterated_users.extend(users);
                }

                for mutator in reverse_postorder.iter() {
                    let mutator_early = schedule_early[mutator.idx()].unwrap();
                    if dom.does_dom(root_early, mutator_early)
                        && (root_early != mutator_early
                            || root_block_iterated_users.contains(&mutator))
                        && mutating_objects(function, func_id, *mutator, objects)
                            .any(|mutated| read_objs.contains(&mutated))
                        && id != mutator
                        && (!root_is_reduce_and_read_isnt_in_cycle
                            || !reduce_cycles
                                .get(root)
                                .map(|cycle| cycle.contains(mutator))
                                .unwrap_or(false))
                        && dom.does_dom(schedule_early[id.idx()].unwrap(), mutator_early)
                    {
                        antideps.insert((*id, *mutator));
                    }
                }
            }
        }
    }
    let mut antideps_uses = vec![vec![]; function.nodes.len()];
    let mut antideps_users = vec![vec![]; function.nodes.len()];
    for (reader, mutator) in antideps.iter() {
        antideps_uses[mutator.idx()].push(*reader);
        antideps_users[reader.idx()].push(*mutator);
    }

    // Step 4: schedule late and pick each nodes final position. Since the late
    // schedule of each node depends on the final positions of its users, these
    // two steps must be fused. Compute their latest position, then use the
    // control dependent + shallow loop heuristic to actually place them. A
    // placement might not necessarily be found due to anti-dependency edges.
    // These are optional and not necessary to consider, but we do since obeying
    // them can reduce the number of clones. If the worklist stops making
    // progress, stop considering the anti-dependency edges.
    let join_fork_map: HashMap<NodeID, NodeID> = fork_join_map
        .into_iter()
        .map(|(fork, join)| (*join, *fork))
        .collect();
    let mut worklist = VecDeque::from_iter(reverse_postorder.into_iter().map(|id| *id).rev());
    let mut num_skip_iters = 0;
    let mut consider_antidependencies = true;
    while let Some(id) = worklist.pop_front() {
        if num_skip_iters >= worklist.len() {
            consider_antidependencies = false;
        }

        if bbs[id.idx()].is_some() {
            num_skip_iters = 0;
            continue;
        }

        // Calculate the least common ancestor of user blocks, a.k.a. the "late"
        // schedule. Returns `None` if some user hasn't been placed yet, and
        // `Some(None)` if the node has no users.
        let calculate_lca = || -> Option<_> {
            let mut lca = None;
            // Helper to incrementally update the LCA.
            let mut update_lca = |a| {
                if let Some(acc) = lca {
                    lca = Some(dom.least_common_ancestor(acc, a));
                } else {
                    lca = Some(a);
                }
            };

            // For every user, consider where we need to be to directly dominate the
            // user.
            for user in def_use
                .get_users(id)
                .as_ref()
                .into_iter()
                .chain(if consider_antidependencies {
                    Either::Left(antideps_users[id.idx()].iter())
                } else {
                    Either::Right(empty())
                })
                .map(|id| *id)
            {
                if let Node::Phi { control, data } = &function.nodes[user.idx()] {
                    // For phis, we need to dominate the block jumping to the phi in
                    // the slot that corresponds to our use.
                    for (control, data) in
                        zip(get_uses(&function.nodes[control.idx()]).as_ref(), data)
                    {
                        if id == *data {
                            update_lca(*control);
                        }
                    }
                } else if let Node::Reduce {
                    control,
                    init,
                    reduct,
                } = &function.nodes[user.idx()]
                {
                    // For reduces, we need to either dominate the block right
                    // before the fork if we're the init input, or we need to
                    // dominate the join if we're the reduct input.
                    if id == *init {
                        let before_fork = function.nodes[join_fork_map[control].idx()]
                            .try_fork()
                            .unwrap()
                            .0;
                        update_lca(before_fork);
                    } else {
                        assert_eq!(id, *reduct);
                        update_lca(*control);
                    }
                } else {
                    // For everything else, we just need to dominate the user.
                    update_lca(bbs[user.idx()]?);
                }
            }

            Some(lca)
        };

        // Check if all users have been placed. If one of them hasn't, then add
        // this node back on to the worklist.
        let Some(lca) = calculate_lca() else {
            worklist.push_back(id);
            num_skip_iters += 1;
            continue;
        };

        // Look between the LCA and the schedule early location to place the
        // node. If the node has no users, then it doesn't really matter where
        // we place it - just place it at the early placement.
        let early = schedule_early[id.idx()].unwrap();
        let late = lca.unwrap_or(early);
        let mut chain = dom.chain(late, early);

        if let Some(mut location) = chain.next() {
            // Nodes are only lifted below if they aren't a non-primitive
            // constant or undef - if a node is a constant or undef, we want its
            // placement to be as control dependent as possible, even inside
            // loops. In GPU functions specifically, lift constants that may be
            // returned outside fork-joins. Both properties only depend on the
            // node itself, so they are computed once, outside the walk up the
            // dominator chain.
            let is_constant_or_undef = (function.nodes[id.idx()].is_constant()
                || function.nodes[id.idx()].is_undef())
                && !types[typing[id.idx()].idx()].is_primitive();
            let is_gpu_returned = devices[func_id.idx()] == Device::CUDA
                && objects[&func_id]
                    .objects(id)
                    .into_iter()
                    .any(|obj| objects[&func_id].returned_objects().contains(obj));

            for control_node in chain {
                // If the next node further up the dominator tree is in a
                // shallower loop nest or if we can get out of a reduce loop
                // when we don't need to be in one, place this data node in a
                // higher-up location.
                let old_nest = loops
                    .header_of(location)
                    .map(|header| loops.nesting(header).unwrap());
                let new_nest = loops
                    .header_of(control_node)
                    .map(|header| loops.nesting(header).unwrap());
                let shallower_nest = if let (Some(old_nest), Some(new_nest)) = (old_nest, new_nest)
                {
                    old_nest > new_nest
                } else {
                    // If the new location isn't a loop, its nesting level should
                    // be considered "shallower" if the current location is in a
                    // loop.
                    old_nest.is_some()
                };
                // This will move all nodes that don't need to be in reduce loops
                // outside of reduce loops. Nodes that do need to be in a reduce
                // loop use the reduce node forming the loop, so the dominator chain
                // will consist of one block, and this loop won't ever iterate.
                let currently_at_join = function.nodes[location.idx()].is_join()
                    && !function.nodes[control_node.idx()].is_join();

                if (!is_constant_or_undef || is_gpu_returned)
                    && (shallower_nest || currently_at_join)
                {
                    location = control_node;
                }
            }

            bbs[id.idx()] = Some(location);
            num_skip_iters = 0;
        } else {
            // If there is no valid location for this node, then it's a reading
            // node of a collection that can't be placed above a mutation that
            // anti-depend uses it. Push the node back on the list, and we'll
            // stop considering anti-dependencies soon. Don't immediately stop
            // considering anti-dependencies, as we may be able to eke out some
            // more use of them.
            worklist.push_back(id);
            num_skip_iters += 1;
            continue;
        }
    }
    let bbs: Vec<_> = bbs.into_iter().map(Option::unwrap).collect();
    // Calculate the number of phis and reduces per basic block. We use this to
    // emit phis and reduces at the top of basic blocks. We want to emit phis
    // and reduces first into ordered basic blocks for two reasons:
    // 1. This is useful for liveness analysis.
    // 2. This is needed for some backends - LLVM expects phis to be at the top
    //    of basic blocks.
    let mut num_phis_reduces = vec![0; function.nodes.len()];
    for (node_idx, bb) in bbs.iter().enumerate() {
        let node = &function.nodes[node_idx];
        if node.is_phi() || node.is_reduce() {
            num_phis_reduces[bb.idx()] += 1;
        }
    }

    // Step 5: determine the order of nodes inside each block. Use worklist to
    // add nodes to blocks in order that obeys dependencies.
    let mut order: Vec<Vec<NodeID>> = vec![vec![]; function.nodes.len()];
    let mut worklist = VecDeque::from_iter(
        reverse_postorder
            .into_iter()
            .filter(|id| !function.nodes[id.idx()].is_control()),
    );
    let mut visited = bitvec![u8, Lsb0; 0; function.nodes.len()];
    let mut num_skip_iters = 0;
    let mut consider_antidependencies = true;
    while let Some(id) = worklist.pop_front() {
        // If the worklist isn't making progress, then there's at least one
        // reading node of a collection that is in a anti-depend + normal depend
        // use cycle with a mutating node. See above comment about anti-
        // dependencies being optional; we just stop considering them here.
        if num_skip_iters >= worklist.len() {
            consider_antidependencies = false;
        }

        // Phis and reduces always get emitted. Other nodes need to obey
        // dependency relationships and need to come after phis and reduces.
        let node = &function.nodes[id.idx()];
        let bb = bbs[id.idx()];
        if node.is_phi()
            || node.is_reduce()
            || (num_phis_reduces[bb.idx()] == 0
                && get_uses(node)
                    .as_ref()
                    .into_iter()
                    .chain(if consider_antidependencies {
                        Either::Left(antideps_uses[id.idx()].iter())
                    } else {
                        Either::Right(empty())
                    })
                    .all(|u| {
                        function.nodes[u.idx()].is_control()
                            || bbs[u.idx()] != bbs[id.idx()]
                            || visited[u.idx()]
                    }))
        {
            order[bb.idx()].push(*id);
            visited.set(id.idx(), true);
            num_skip_iters = 0;
            if node.is_phi() || node.is_reduce() {
                num_phis_reduces[bb.idx()] -= 1;
            }
        } else {
            worklist.push_back(id);
            num_skip_iters += 1;
        }
    }

    (bbs, order)
}

/*
 * Returns the inputs of `reader` whose collection contents it reads without
 * forwarding or mutating them. A read producing a non-collection reads its
 * `collect` input; a write whose `data` input is a collection reads `data`; a
 * call reads every argument whose callee parameter is neither mutated nor
 * returned; a GEMM library call reads its second and third arguments. Every
 * other node yields nothing.
 */
fn terminating_reads<'a>(
    function: &'a Function,
    func_id: FunctionID,
    reader: NodeID,
    objects: &'a CollectionObjects,
) -> Box<dyn Iterator<Item = NodeID> + 'a> {
    // True when `node` may refer to some collection object in this function.
    let has_objects = |node: NodeID| !objects[&func_id].objects(node).is_empty();

    match function.nodes[reader.idx()] {
        Node::Read { collect, .. } if !has_objects(reader) => Box::new(once(collect)),
        Node::Write { data, .. } if has_objects(data) => Box::new(once(data)),
        Node::Call {
            function: callee,
            ref args,
            ..
        } => Box::new(args.iter().enumerate().filter_map(move |(slot, arg)| {
            let callee_objects = &objects[&callee];
            let param_obj = callee_objects.param_to_object(slot)?;
            let only_read = !callee_objects.is_mutated(param_obj)
                && !callee_objects.returned_objects().contains(&param_obj);
            only_read.then_some(*arg)
        })),
        Node::LibraryCall {
            library_function,
            ref args,
            ..
        } => match library_function {
            LibraryFunction::GEMM => Box::new(args[1..3].iter().copied()),
        },
        _ => Box::new(empty()),
    }
}

/*
 * Returns the inputs of `reader` whose collection it passes through as its
 * own result. A read or a select that produces a collection forwards its
 * `collect` input (read) or both of its value inputs (select). A call forwards
 * every argument whose callee parameter is returned but not mutated. Every
 * other node yields nothing.
 */
fn forwarding_reads<'a>(
    function: &'a Function,
    func_id: FunctionID,
    reader: NodeID,
    objects: &'a CollectionObjects,
) -> Box<dyn Iterator<Item = NodeID> + 'a> {
    // Evaluated lazily, only for the node kinds that need it.
    let produces_collection = || !objects[&func_id].objects(reader).is_empty();

    match function.nodes[reader.idx()] {
        Node::Read { collect, .. } if produces_collection() => Box::new(once(collect)),
        Node::Ternary {
            op: TernaryOperator::Select,
            second,
            third,
            ..
        } if produces_collection() => Box::new(once(second).chain(once(third))),
        Node::Call {
            function: callee,
            ref args,
            ..
        } => Box::new(args.iter().enumerate().filter_map(move |(slot, arg)| {
            let callee_objects = &objects[&callee];
            let param_obj = callee_objects.param_to_object(slot)?;
            let passed_through = !callee_objects.is_mutated(param_obj)
                && callee_objects.returned_objects().contains(&param_obj);
            passed_through.then_some(*arg)
        })),
        _ => Box::new(empty()),
    }
}

/*
 * Returns the collection objects (of the current function) that `mutator` may
 * modify. A write modifies the objects of its `collect` input. A call modifies
 * the objects of every argument whose callee parameter is mutated. A GEMM
 * library call modifies the objects of its first argument. Every other node
 * yields nothing.
 */
fn mutating_objects<'a>(
    function: &'a Function,
    func_id: FunctionID,
    mutator: NodeID,
    objects: &'a CollectionObjects,
) -> Box<dyn Iterator<Item = CollectionObjectID> + 'a> {
    match function.nodes[mutator.idx()] {
        Node::Write { collect, .. } => {
            Box::new(objects[&func_id].objects(collect).iter().copied())
        }
        Node::Call {
            function: callee,
            ref args,
            ..
        } => Box::new(
            args.iter()
                .enumerate()
                .filter(move |&(slot, _)| {
                    let callee_objects = &objects[&callee];
                    callee_objects
                        .param_to_object(slot)
                        .map_or(false, |param_obj| callee_objects.is_mutated(param_obj))
                })
                .flat_map(move |(_, arg)| objects[&func_id].objects(*arg).iter().copied()),
        ),
        Node::LibraryCall {
            library_function,
            ref args,
            ..
        } => match library_function {
            LibraryFunction::GEMM => {
                Box::new(objects[&func_id].objects(args[0]).iter().copied())
            }
        },
        _ => Box::new(empty()),
    }
}

/*
 * Returns the node IDs of the collection values that `mutator` writes into in
 * place:
 * - a `Write` mutates its `collect` input;
 * - a `Call` mutates every argument whose corresponding callee parameter
 *   object is marked as mutated in the callee's collection objects (arguments
 *   with no parameter object are skipped);
 * - a GEMM `LibraryCall` mutates its first argument.
 * Every other node mutates nothing.
 */
fn mutating_writes<'a>(
    function: &'a Function,
    mutator: NodeID,
    objects: &'a CollectionObjects,
) -> Box<dyn Iterator<Item = NodeID> + 'a> {
    match function.nodes[mutator.idx()] {
        Node::Write { collect, .. } => Box::new(once(collect)),
        Node::Call {
            function: callee,
            ref args,
            ..
        } => Box::new(
            args.iter()
                .enumerate()
                .filter(move |&(param_idx, _)| {
                    let callee_objects = &objects[&callee];
                    callee_objects
                        .param_to_object(param_idx)
                        .is_some_and(|param_obj| callee_objects.is_mutated(param_obj))
                })
                .map(|(_, arg)| *arg),
        ),
        Node::LibraryCall {
            library_function,
            ref args,
            ..
        } => match library_function {
            LibraryFunction::GEMM => Box::new(once(args[0])),
        },
        _ => Box::new(empty()),
    }
}

/*
 * Top level function to find implicit clones that need to be spilled. Returns
 * whether a clone was spilled, in which case the whole scheduling process must
 * be restarted.
 *
 * At most one spill is performed per call: the first "spill edge" found (see
 * Step 2 below) is handled, and the function returns right after editing.
 *
 * - `typing`: the type of every node, indexed by node ID.
 * - `control_subgraph`: the control-flow subgraph; used for predecessor lists
 *   and for a reverse post-order rooted at node 0.
 * - `objects`: collection object analysis results, keyed by function ID.
 * - `bbs`: the current schedule; `bbs.0` maps each node to its basic block and
 *   `bbs.1` lists the ordered instructions of each basic block.
 * - `liveness`: the result of `liveness_dataflow` for this schedule.
 *
 * Panics if the function edits performing the spill fail, since GCM must
 * always be able to legalize a function.
 */
fn spill_clones(
    editor: &mut FunctionEditor,
    typing: &Vec<TypeID>,
    control_subgraph: &Subgraph,
    objects: &CollectionObjects,
    bbs: &BasicBlocks,
    liveness: &Liveness,
) -> bool {
    // Step 1: compute an interference graph from the liveness result. This
    // graph contains a vertex per node ID producing a collection value and an
    // edge per pair of node IDs that interfere. Nodes A and B interfere if node
    // A is defined right above a point where node B is live and A != B. Extra
    // edges are drawn for forwarding reads - when there is a node A that is a
    // forwarding read of a node B, A and B really have the same live range for
    // the purpose of determining when spills are necessary, since forwarding
    // reads can be thought of as nothing but pointer math. For this purpose, we
    // maintain a union-find of nodes that form a forwarding read DAG (notably,
    // phis and reduces are not considered forwarding reads). The more precise
    // version of the interference condition is nodes A and B interfere is node
    // A is defined right above a point where a node C is live where C is in the
    // same union-find class as B.

    // Assemble the union-find to group forwarding read DAGs.
    let mut union_find = QuickFindUf::<UnionBySize>::new(editor.func().nodes.len());
    for id in editor.node_ids() {
        for forwarding_read in forwarding_reads(editor.func(), editor.func_id(), id, objects) {
            union_find.union(id.idx(), forwarding_read.idx());
        }
    }

    // Figure out which classes contain which node IDs, since we need to iterate
    // the disjoint sets. Keys are the union-find representatives.
    let mut disjoint_sets: BTreeMap<usize, Vec<NodeID>> = BTreeMap::new();
    for id in editor.node_ids() {
        disjoint_sets
            .entry(union_find.find(id.idx()))
            .or_default()
            .push(id);
    }

    // Create the graph. Entry 0 of a block's liveness vector is the set live
    // above its first instruction, so skipping it pairs each instruction with
    // the set live right below it (i.e. right after its definition).
    let mut edges = vec![];
    for (bb, liveness) in liveness {
        let insts = &bbs.1[bb.idx()];
        for (node, live) in zip(insts, liveness.into_iter().skip(1)) {
            for live_node in live {
                // Expand each live node to its whole forwarding-read class.
                for live_node in disjoint_sets[&union_find.find(live_node.idx())].iter() {
                    if *node != *live_node {
                        edges.push((*node, *live_node));
                    }
                }
            }
        }
    }

    // Step 2: filter edges (A, B) to just see edges where A uses B and A
    // mutates B. These are the edges that may require a spill. Concretely,
    // (A, B) is kept when either:
    // - B is one of the collections A writes into in place (`mutating_writes`),
    //   or
    // - A is a phi or reduce that uses B, unless A is a `ParallelReduce`
    //   reduce and B is its `init` input.
    // The filter is lazy; only the first surviving edge is ever pulled.
    let mut spill_edges = edges.into_iter().filter(|(a, b)| {
        mutating_writes(editor.func(), *a, objects).any(|id| id == *b)
            || (get_uses(&editor.func().nodes[a.idx()])
                .as_ref()
                .into_iter()
                .any(|u| *u == *b)
                && (editor.func().nodes[a.idx()].is_phi()
                    || editor.func().nodes[a.idx()].is_reduce())
                && !editor.func().nodes[a.idx()]
                    .try_reduce()
                    .map(|(_, init, _)| {
                        init == *b
                            && editor.func().schedules[a.idx()].contains(&Schedule::ParallelReduce)
                    })
                    .unwrap_or(false))
    });

    // Step 3: if there is a spill edge, spill it and return true. Otherwise,
    // return false.
    if let Some((user, obj)) = spill_edges.next() {
        // Figure out the most immediate dominating region for every basic
        // block. These are the points where spill slot phis get placed.
        // Regions map to themselves; every other control node (except the
        // start node) inherits the entry of its first control predecessor.
        // Entries never assigned (the start node and all non-control nodes)
        // stay at NodeID 0.
        let nodes = &editor.func().nodes;
        let mut imm_dom_reg = vec![NodeID::new(0); editor.func().nodes.len()];
        for (idx, node) in nodes.into_iter().enumerate() {
            if node.is_region() {
                imm_dom_reg[idx] = NodeID::new(idx);
            }
        }
        let rev_po = control_subgraph.rev_po(NodeID::new(0));
        for bb in rev_po.iter() {
            if !nodes[bb.idx()].is_region() && !nodes[bb.idx()].is_start() {
                imm_dom_reg[bb.idx()] =
                    imm_dom_reg[control_subgraph.preds(*bb).next().unwrap().idx()];
            }
        }

        // Every user of `obj` other than the mutating `user` must read the
        // spilled copy instead of `obj` itself.
        let other_obj_users: Vec<_> = editor.get_users(obj).filter(|id| *id != user).collect();
        // Indexed by region node ID; entries for non-regions stay NodeID 0.
        let mut dummy_phis = vec![NodeID::new(0); imm_dom_reg.len()];
        let mut success = editor.edit(|mut edit| {
            // Construct the spill slot. This is just a constant that gets phi-
            // ed throughout the entire function.
            let cons_id = edit.add_zero_constant(typing[obj.idx()]);
            let slot_id = edit.add_node(Node::Constant { id: cons_id });
            edit = edit.add_schedule(slot_id, Schedule::NoResetConstant)?;

            // Allocate IDs for phis that move the spill slot throughout the
            // function without implicit clones. These are dummy phis, since
            // there are potentially cycles between them. We will replace them
            // later. Index 0 is skipped since it is the start node, where the
            // slot constant itself is used directly.
            for (idx, reg) in imm_dom_reg.iter().enumerate().skip(1) {
                if idx == reg.idx() {
                    dummy_phis[idx] = edit.add_node(Node::Phi {
                        control: *reg,
                        data: empty().collect(),
                    });
                }
            }

            // Spill `obj` before `user` potentially modifies it. A write with
            // no indices is a cloning write: it copies all of `data` into the
            // slot value reaching `obj`'s region.
            let spill_region = imm_dom_reg[bbs.0[obj.idx()].idx()];
            let spill_id = edit.add_node(Node::Write {
                collect: if spill_region == NodeID::new(0) {
                    slot_id
                } else {
                    dummy_phis[spill_region.idx()]
                },
                data: obj,
                indices: empty().collect(),
            });

            // Before each other user, unspill `obj`.
            for other_user in other_obj_users {
                let other_region = imm_dom_reg[bbs.0[other_user.idx()].idx()];
                // If this assert fails, then `obj` is not in the first basic
                // block, but it has a user that is in the first basic block,
                // which violates SSA.
                assert!(other_region == spill_region || other_region != NodeID::new(0));

                // If an other user is a phi, we need to be a little careful
                // about how we insert unspilling code for `obj`. Instead of
                // inserting an unspill in the same block as the user, we need
                // to insert one in each predecessor of the phi that corresponds
                // to a use of `obj`. Since this requires modifying individual
                // uses in a phi, just rebuild the node entirely.
                if let Node::Phi { control, data } = edit.get_node(other_user).clone() {
                    assert_eq!(control, other_region);
                    let mut new_data = vec![];
                    for (pred, data) in zip(control_subgraph.preds(control), data) {
                        let pred = imm_dom_reg[pred.idx()];
                        if data == obj {
                            // Unspill from the slot value reaching the end of
                            // this predecessor's region.
                            let unspill_id = edit.add_node(Node::Write {
                                collect: obj,
                                data: if pred == spill_region {
                                    spill_id
                                } else {
                                    dummy_phis[pred.idx()]
                                },
                                indices: empty().collect(),
                            });
                            new_data.push(unspill_id);
                        } else {
                            new_data.push(data);
                        }
                    }
                    let new_phi = edit.add_node(Node::Phi {
                        control,
                        data: new_data.into_boxed_slice(),
                    });
                    edit = edit.replace_all_uses(other_user, new_phi)?;
                    edit = edit.delete_node(other_user)?;
                } else {
                    let unspill_id = edit.add_node(Node::Write {
                        collect: obj,
                        data: if other_region == spill_region {
                            spill_id
                        } else {
                            dummy_phis[other_region.idx()]
                        },
                        indices: empty().collect(),
                    });
                    // Only redirect the use inside `other_user`; other uses of
                    // `obj` are handled by their own iteration.
                    edit = edit.replace_all_uses_where(obj, unspill_id, |id| *id == other_user)?;
                }
            }

            // Create and hook up all the real phis. Phi elimination will clean
            // this up. For each predecessor, a region's phi takes the spilled
            // value if the predecessor lies in the spill region, the original
            // slot if it lies in the start region, and otherwise that
            // region's (dummy, soon real) phi.
            let mut real_phis = vec![NodeID::new(0); imm_dom_reg.len()];
            for (idx, reg) in imm_dom_reg.iter().enumerate().skip(1) {
                if idx == reg.idx() {
                    real_phis[idx] = edit.add_node(Node::Phi {
                        control: *reg,
                        data: control_subgraph
                            .preds(*reg)
                            .map(|pred| {
                                let pred = imm_dom_reg[pred.idx()];
                                if pred == spill_region {
                                    spill_id
                                } else if pred == NodeID::new(0) {
                                    slot_id
                                } else {
                                    dummy_phis[pred.idx()]
                                }
                            })
                            .collect(),
                    });
                }
            }
            // Entries for non-regions are NodeID 0 in both vectors and are
            // skipped by the inequality check.
            for (dummy, real) in zip(dummy_phis.iter(), real_phis) {
                if *dummy != real {
                    edit = edit.replace_all_uses(*dummy, real)?;
                }
            }

            Ok(edit)
        });
        // The dummy phis are now unused; delete them in a second edit.
        success = success
            && editor.edit(|mut edit| {
                for dummy in dummy_phis {
                    if dummy != NodeID::new(0) {
                        edit = edit.delete_node(dummy)?;
                    }
                }
                Ok(edit)
            });
        assert!(success, "PANIC: GCM cannot fail to edit a function, as it needs to legalize the reference semantics of every function before code generation.");
        true
    } else {
        false
    }
}

/*
 * Liveness results per basic block, keyed by the block's control node ID. For
 * a block with `n` scheduled instructions the vector holds `n + 1` sets: entry
 * `i` (for `i < n`) is the set of collection nodes live right above
 * instruction `i`, and entry `n` is the set live at the bottom of the block.
 */
type Liveness = BTreeMap<NodeID, Vec<BTreeSet<NodeID>>>;

/*
 * Liveness dataflow analysis on scheduled Hercules IR. Just look at nodes that
 * involve collections.
 *
 * This is a backward analysis iterated to a fixed point over the control
 * blocks in post-order. For each block with `n` scheduled instructions it
 * computes `n + 1` sets: entry `i` is the set of collection nodes live right
 * above instruction `i`, and entry `n` is the set live at the bottom of the
 * block.
 *
 * Phis and reduces get special treatment at block boundaries:
 * - each phi input is made live at the bottom of its corresponding
 *   predecessor;
 * - a block containing a sequential (non-`ParallelReduce`) reduce is treated
 *   as its own successor, so that the reduce's `reduct` is live at its bottom;
 * - the `init` of a `ParallelReduce` reduce is never made live.
 * The code assumes that the phis and reduces of a block are scheduled as its
 * first instructions.
 */
fn liveness_dataflow(
    function: &Function,
    func_id: FunctionID,
    control_subgraph: &Subgraph,
    objects: &CollectionObjects,
    bbs: &BasicBlocks,
) -> Liveness {
    let mut po = control_subgraph.rev_po(NodeID::new(0));
    po.reverse();
    let mut liveness = Liveness::default();
    for (bb_idx, insts) in bbs.1.iter().enumerate() {
        liveness.insert(NodeID::new(bb_idx), vec![BTreeSet::new(); insts.len() + 1]);
    }
    // Per basic block: how many phis / reduces it contains, and whether it
    // contains at least one sequential (non-parallel) reduce.
    let mut num_phis_reduces = vec![0; function.nodes.len()];
    let mut has_seq_reduce = vec![false; function.nodes.len()];
    for (node_idx, bb) in bbs.0.iter().enumerate() {
        let node = &function.nodes[node_idx];
        if node.is_phi() || node.is_reduce() {
            num_phis_reduces[bb.idx()] += 1;
        }
        // Accumulate with `|=`: a plain assignment would let whichever node
        // of this block is visited last decide the flag, dropping sequential
        // reduces followed by any other node in the same block.
        has_seq_reduce[bb.idx()] |=
            node.is_reduce() && !function.schedules[node_idx].contains(&Schedule::ParallelReduce);
        assert!(!node.is_phi() || !node.is_reduce());
    }
    let is_obj = |id: NodeID| !objects[&func_id].objects(id).is_empty();

    loop {
        let mut changed = false;

        for bb in po.iter() {
            // First, calculate the liveness set for the bottom of this block.
            let last_pt = bbs.1[bb.idx()].len();
            let old_value = &liveness[&bb][last_pt];
            let mut new_value = BTreeSet::new();
            for succ in control_subgraph
                .succs(*bb)
                .chain(if has_seq_reduce[bb.idx()] {
                    Either::Left(once(*bb))
                } else {
                    Either::Right(empty())
                })
            {
                // The liveness at the bottom of a basic block is the union of:
                // 1. The liveness of each succecessor right after its phis and
                //    reduces.
                // 2. Every data use in a phi or reduce that corresponds to this
                //    block as the predecessor.
                let after_phis_reduces_pt = num_phis_reduces[succ.idx()];
                new_value.extend(&liveness[&succ][after_phis_reduces_pt]);
                for inst_idx in 0..after_phis_reduces_pt {
                    let id = bbs.1[succ.idx()][inst_idx];
                    new_value.remove(&id);
                    match function.nodes[id.idx()] {
                        Node::Phi { control, ref data } if is_obj(data[0]) => {
                            assert_eq!(control, succ);
                            new_value.extend(
                                zip(control_subgraph.preds(succ), data)
                                    .filter(|(pred, _)| *pred == *bb)
                                    .map(|(_, data)| *data),
                            );
                        }
                        Node::Reduce {
                            control,
                            init,
                            reduct,
                        } if is_obj(init) => {
                            assert_eq!(control, succ);
                            if succ == *bb {
                                new_value.insert(reduct);
                            } else if !function.schedules[id.idx()]
                                .contains(&Schedule::ParallelReduce)
                            {
                                new_value.insert(init);
                            }
                        }
                        _ => {}
                    }
                }
            }
            changed |= *old_value != new_value;
            liveness.get_mut(&bb).unwrap()[last_pt] = new_value;

            // Second, calculate the liveness set above each instruction in this block.
            for pt in (0..last_pt).rev() {
                let old_value = &liveness[&bb][pt];
                let mut new_value = liveness[&bb][pt + 1].clone();
                let id = bbs.1[bb.idx()][pt];
                let uses = get_uses(&function.nodes[id.idx()]);
                let is_obj = |id: &NodeID| is_obj(*id);
                new_value.remove(&id);
                new_value.extend(
                    if let Node::Write {
                        collect: _,
                        data,
                        ref indices,
                    } = function.nodes[id.idx()]
                        && indices.is_empty()
                    {
                        // If this write is a cloning write, the `collect` input
                        // isn't actually live, because its value doesn't
                        // matter.
                        Either::Left(once(data).filter(is_obj))
                    } else if let Node::Reduce {
                        control: _,
                        init: _,
                        reduct,
                    } = function.nodes[id.idx()]
                        && function.schedules[id.idx()].contains(&Schedule::ParallelReduce)
                    {
                        // If this reduce is a parallel reduce, the `init` input
                        // isn't actually live.
                        Either::Left(once(reduct).filter(is_obj))
                    } else {
                        Either::Right(uses.as_ref().into_iter().map(|id| *id).filter(is_obj))
                    },
                );
                changed |= *old_value != new_value;
                liveness.get_mut(&bb).unwrap()[pt] = new_value;
            }
        }

        if !changed {
            return liveness;
        }
    }
}

/*
 * A term in a device unification equation: either a node whose device is not
 * yet known, or a concrete device.
 */
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum UTerm {
    Node(NodeID),
    Device(Device),
}

/*
 * Solve a queue of device unification equations.
 *
 * Equations are processed front to back. Equating a node with a device binds
 * that node in the solution and substitutes the device for the node in every
 * remaining equation. An equation between two distinct nodes is deferred to
 * the back of the queue until substitution turns one side into a device; an
 * equation between a node and itself is dropped. Solving stops when the queue
 * is empty, or once more node-node equations have been examined in a row than
 * there are equations left. Node-node equations still unsolved at that point
 * are simply ignored.
 *
 * Returns `Ok(solution)` on success, or `Err(partial_solution)` as soon as
 * two different devices are equated.
 */
fn unify(
    mut equations: VecDeque<(UTerm, UTerm)>,
) -> Result<BTreeMap<NodeID, Device>, BTreeMap<NodeID, Device>> {
    let mut solution = BTreeMap::new();

    // Number of node-node equations examined in a row without binding any
    // node to a device.
    let mut stalled = 0;
    while stalled <= equations.len() {
        let Some(equation) = equations.pop_front() else {
            break;
        };
        match equation {
            (UTerm::Node(lhs), UTerm::Node(rhs)) => {
                if lhs != rhs {
                    equations.push_back(equation);
                }
                stalled += 1;
            }
            (UTerm::Node(node), UTerm::Device(device))
            | (UTerm::Device(device), UTerm::Node(node)) => {
                solution.insert(node, device);
                let bound = UTerm::Node(node);
                for term in equations.iter_mut().flat_map(|(lhs, rhs)| [lhs, rhs]) {
                    if *term == bound {
                        *term = UTerm::Device(device);
                    }
                }
                stalled = 0;
            }
            (UTerm::Device(lhs), UTerm::Device(rhs)) if lhs == rhs => stalled = 0,
            _ => return Err(solution),
        }
    }

    Ok(solution)
}

/*
 * Determine what device each node produces a collection onto. Insert inter-
 * device clones when a single node may potentially be on different devices.
 *
 * Device-equality constraints between nodes and concrete devices are
 * collected as unification equations and solved with `unify`.
 *
 * - `typing`: the type of every node, indexed by node ID.
 * - `objects`: collection object analysis results, keyed by function ID.
 * - `devices`: the device of every function, indexed by function ID.
 * - `node_colors`: previously computed colors of other functions; only the
 *   callees' parameter devices (`.1`) and return device (`.2`) are read here.
 *
 * Returns `Some(colors)` on success. `colors.0` maps nodes to devices,
 * `colors.1` holds the device of each parameter (if determined) and `colors.2`
 * holds the device of the returned value (if determined). If unification
 * fails, this inserts copies on the def-use edges of one node from the partial
 * solution and returns `None`. Presumably the caller then retries on the
 * edited function (NOTE(review): confirm against the caller).
 */
fn color_nodes(
    editor: &mut FunctionEditor,
    typing: &Vec<TypeID>,
    objects: &CollectionObjects,
    devices: &Vec<Device>,
    node_colors: &NodeColors,
) -> Option<FunctionNodeColors> {
    let nodes = &editor.func().nodes;
    let func_id = editor.func_id();
    let func_device = devices[func_id.idx()];
    // (node -> device map, per-parameter devices, return device).
    let mut func_colors = (
        BTreeMap::new(),
        vec![None; editor.func().param_types.len()],
        None,
    );

    // Assigning nodes to devices is tricky due to function calls. Technically,
    // all the information is there to decide what device to place nodes on, but
    // coherently expressing the constraints and deriving the devices is not
    // obvious. Express this as a unification problem, where we need to assign
    // types (devices) to each node. Each unification term is either a node ID
    // or a concrete device. Assemble a list of unification equations to solve.
    let mut equations = vec![];
    for id in editor.node_ids() {
        match nodes[id.idx()] {
            Node::Phi {
                control: _,
                ref data,
            } if !editor.get_type(typing[id.idx()]).is_primitive() => {
                // Every input to a phi needs to be on the same device. The
                // phi itself is also on this device. Chain consecutive inputs
                // pairwise, with the last input paired with the phi.
                for (l, r) in zip(data.into_iter(), data.into_iter().skip(1).chain(once(&id))) {
                    equations.push((UTerm::Node(*l), UTerm::Node(*r)));
                }
            }
            // For a select, `first`/`second` are bound to the two selectable
            // values (the condition operand is ignored).
            Node::Reduce {
                control: _,
                init: first,
                reduct: second,
            }
            | Node::Ternary {
                op: TernaryOperator::Select,
                first: _,
                second: first,
                third: second,
            } if !editor.get_type(typing[id.idx()]).is_primitive() => {
                // Every input to the reduce, and the reduce itself, are on
                // the same device.
                equations.push((UTerm::Node(first), UTerm::Node(second)));
                equations.push((UTerm::Node(second), UTerm::Node(id)));
            }
            Node::Constant { id: _ }
                if !editor.get_type(typing[id.idx()]).is_primitive()
                    && func_device != Device::AsyncRust =>
            {
                // Constants inside device functions are allocated on that
                // device.
                equations.push((UTerm::Node(id), UTerm::Device(func_device)));
            }
            Node::Read {
                collect,
                indices: _,
            } => {
                if editor.get_type(typing[id.idx()]).is_primitive() {
                    // If this reads a primitive, then the collection needs to
                    // be on the device of this function.
                    equations.push((
                        UTerm::Node(collect),
                        UTerm::Device(backing_device(func_device)),
                    ));
                } else {
                    // If this read just reads a sub-collection, then `collect`
                    // and the read itself need to be on the same device.
                    equations.push((UTerm::Node(collect), UTerm::Node(id)));
                }
            }
            Node::Write {
                collect,
                data,
                indices: _,
            } => {
                if func_device != Device::AsyncRust
                    || editor.get_type(typing[data.idx()]).is_primitive()
                {
                    // If this writes a primitive or this is in a device
                    // function, then the collection needs to be on the backing
                    // device of this function.
                    equations.push((
                        UTerm::Node(collect),
                        UTerm::Device(backing_device(func_device)),
                    ));

                    if func_device != Device::AsyncRust
                        && !editor.get_type(typing[data.idx()]).is_primitive()
                    {
                        // We can only do inter-device copies in AsyncRust
                        // functions.
                        equations.push((UTerm::Node(collect), UTerm::Node(data)));
                    }
                }
                // The write's result always lives where `collect` lives.
                equations.push((UTerm::Node(collect), UTerm::Node(id)));
            }
            Node::Call {
                control: _,
                function: callee,
                dynamic_constants: _,
                ref args,
            } => {
                // If the callee has a definite device for a parameter, add an
                // equation for the corresponding argument.
                for (idx, arg) in args.into_iter().enumerate() {
                    if let Some(device) = node_colors[&callee].1[idx] {
                        equations.push((UTerm::Node(*arg), UTerm::Device(device)));
                    }
                }

                // If the callee has a definite device for the returned value,
                // add an equation for the call node itself.
                if let Some(device) = node_colors[&callee].2 {
                    equations.push((UTerm::Node(id), UTerm::Device(device)));
                }

                // For any object that may be returned by the callee that
                // originates as a parameter in the callee, the device of the
                // corresponding argument and call node itself must be equal.
                for ret in objects[&callee].returned_objects() {
                    if let Some(idx) = objects[&callee].origin(*ret).try_parameter() {
                        equations.push((UTerm::Node(args[idx]), UTerm::Node(id)));
                    }
                }
            }
            Node::LibraryCall {
                library_function: _,
                ref args,
                ty: _,
                device,
            } => {
                // Library calls pin every argument and their result to the
                // library call's device.
                for arg in args {
                    equations.push((UTerm::Node(*arg), UTerm::Device(device)));
                }
                equations.push((UTerm::Node(id), UTerm::Device(device)));
            }
            _ => {}
        }
    }

    // Solve the unification problem. I couldn't find a simple enough crate for
    // this, and the problems are usually pretty small, so just use a hand-
    // rolled implementation for now.
    match unify(VecDeque::from(equations)) {
        Ok(solve) => {
            func_colors.0 = solve;
            // Look at parameter and return nodes to get the device signature of
            // the function.
            for id in editor.node_ids() {
                if let Node::Parameter { index } = nodes[id.idx()]
                    && let Some(device) = func_colors.0.get(&id)
                {
                    assert!(func_colors.1[index].is_none(), "PANIC: Found multiple parameter nodes for the same index in GCM. Please just run GVN first.");
                    func_colors.1[index] = Some(*device);
                } else if let Node::Return { control: _, data } = nodes[id.idx()]
                    && let Some(device) = func_colors.0.get(&data)
                {
                    assert!(func_colors.2.is_none(), "PANIC: Found multiple return nodes in GCM. Contact Russel if you see this, it's an easy fix.");
                    func_colors.2 = Some(*device);
                }
            }
            Some(func_colors)
        }
        Err(progress) => {
            // If unification failed, then there's some node using a node in
            // `progress` that's expecting a different type than what it got.
            // Pick one and add potentially inter-device copies on each def-use
            // edge. We'll clean these up later.
            // NOTE(review): this picks the smallest node ID in the partial
            // solution, which is not necessarily a node involved in the
            // conflict. Also, `unwrap` relies on `progress` being non-empty,
            // which holds because device-device mismatches only arise after
            // at least one node has been bound.
            let (id, _) = progress.into_iter().next().unwrap();
            let users: Vec<_> = editor.get_users(id).collect();
            let success = editor.edit(|mut edit| {
                let cons = edit.add_zero_constant(typing[id.idx()]);
                for user in users {
                    // Each user gets its own fresh constant to copy `id` into,
                    // so the copies are free to land on different devices.
                    let cons = edit.add_node(Node::Constant { id: cons });
                    edit = edit.add_schedule(cons, Schedule::NoResetConstant)?;
                    let copy = edit.add_node(Node::Write {
                        collect: cons,
                        data: id,
                        indices: Box::new([]),
                    });
                    edit = edit.replace_all_uses_where(id, copy, |id| *id == user)?;
                }
                Ok(edit)
            });
            assert!(success);
            None
        }
    }
}

/*
 * Round the dynamic constant `acc` up to the next multiple of `align`. This
 * builds the expression `((acc + (align - 1)) / align) * align`. An alignment
 * of 1 returns `acc` as-is without creating any dynamic constants.
 *
 * Panics if `align` is 0.
 */
fn align(edit: &mut FunctionEdit, acc: DynamicConstantID, align: usize) -> DynamicConstantID {
    assert_ne!(align, 0);
    if align == 1 {
        return acc;
    }
    let multiple = edit.add_dynamic_constant(DynamicConstant::Constant(align));
    let bump = edit.add_dynamic_constant(DynamicConstant::Constant(align - 1));
    let bumped = edit.add_dynamic_constant(DynamicConstant::add(acc, bump));
    let quotient = edit.add_dynamic_constant(DynamicConstant::div(bumped, multiple));
    edit.add_dynamic_constant(DynamicConstant::mul(quotient, multiple))
}

/*
 * Determine the size, in bytes, of a type in terms of dynamic constants.
 *
 * - Primitives have fixed sizes (1, 2, 4 or 8 bytes).
 * - Products use a C-style layout: each field is placed at the next offset
 *   aligned to the field's alignment, and the total is rounded up to the
 *   product's own alignment.
 * - Summations are as large as their largest variant, plus one discriminant
 *   byte, rounded up to the summation's alignment.
 * - Arrays are the element size multiplied by every bound (row-major).
 *
 * `alignments` is indexed by type ID. Panics on `Type::Control`, which has no
 * size.
 */
fn type_size(edit: &mut FunctionEdit, ty_id: TypeID, alignments: &Vec<usize>) -> DynamicConstantID {
    let ty = edit.get_type(ty_id).clone();
    match ty {
        Type::Control => panic!(),
        Type::Boolean | Type::Integer8 | Type::UnsignedInteger8 | Type::Float8 => {
            edit.add_dynamic_constant(DynamicConstant::Constant(1))
        }
        Type::Integer16 | Type::UnsignedInteger16 | Type::BFloat16 => {
            edit.add_dynamic_constant(DynamicConstant::Constant(2))
        }
        Type::Integer32 | Type::UnsignedInteger32 | Type::Float32 => {
            edit.add_dynamic_constant(DynamicConstant::Constant(4))
        }
        Type::Integer64 | Type::UnsignedInteger64 | Type::Float64 => {
            edit.add_dynamic_constant(DynamicConstant::Constant(8))
        }
        Type::Product(fields) => {
            let start = edit.add_dynamic_constant(DynamicConstant::Constant(0));
            // Walk the fields, padding the running offset to each field's
            // alignment before appending that field.
            let end = fields.iter().fold(start, |offset, field| {
                let field_size = type_size(edit, *field, alignments);
                let offset = align(edit, offset, alignments[field.idx()]);
                edit.add_dynamic_constant(DynamicConstant::add(offset, field_size))
            });
            // The whole product's size must be a multiple of its alignment.
            align(edit, end, alignments[ty_id.idx()])
        }
        Type::Summation(variants) => {
            let start = edit.add_dynamic_constant(DynamicConstant::Constant(0));
            // All variants share the same memory, so keep the largest one.
            let largest = variants.iter().fold(start, |widest, variant| {
                let variant_size = type_size(edit, *variant, alignments);
                edit.add_dynamic_constant(DynamicConstant::max(widest, variant_size))
            });
            let discriminant = edit.add_dynamic_constant(DynamicConstant::Constant(1));
            let tagged = edit.add_dynamic_constant(DynamicConstant::add(largest, discriminant));
            align(edit, tagged, alignments[ty_id.idx()])
        }
        Type::Array(elem, bounds) => {
            let elem_size = type_size(edit, elem, alignments);
            bounds.iter().fold(elem_size, |total, bound| {
                edit.add_dynamic_constant(DynamicConstant::mul(total, *bound))
            })
        }
    }
}

/*
 * Compute the backing memory layout of a function. The result maps each
 * device to a pair: the total number of bytes reserved on that device, and the
 * byte offset of every node that owns a slot in that memory. Two kinds of nodes
 * get slots:
 * - Non-primitive constants, placed on the device given by the node coloring,
 *   aligned to their type's alignment and sized with `type_size`.
 * - Calls, which reserve room on every backed device where the callee itself
 *   has a backing allocation. The callee's total is rewritten in terms of the
 *   dynamic constants passed at this call site.
 * Callee allocations must already be present in `backing_allocations`.
 * Every object currently gets its own slot; `_liveness` is not consulted yet.
 */
fn object_allocation(
    editor: &mut FunctionEditor,
    typing: &Vec<TypeID>,
    node_colors: &FunctionNodeColors,
    alignments: &Vec<usize>,
    _liveness: &Liveness,
    backing_allocations: &BackingAllocations,
) -> FunctionBackingAllocation {
    let mut allocation = BTreeMap::new();

    let ids = editor.node_ids();
    editor.edit(|mut edit| {
        // Starting total for any device that doesn't have an entry yet.
        let zero = edit.add_dynamic_constant(DynamicConstant::Constant(0));
        for id in ids {
            match *edit.get_node(id) {
                Node::Constant { .. } => {
                    let ty = typing[id.idx()];
                    // Primitive constants don't live in backing memory.
                    if edit.get_type(ty).is_primitive() {
                        continue;
                    }
                    let device = node_colors.0[&id];
                    let (total, offsets) = allocation
                        .entry(device)
                        .or_insert_with(|| (zero, BTreeMap::new()));
                    // Place the object at the next suitably aligned offset,
                    // then bump the running total past it.
                    let start = align(&mut edit, *total, alignments[ty.idx()]);
                    offsets.insert(id, start);
                    let size = type_size(&mut edit, ty, alignments);
                    *total = edit.add_dynamic_constant(DynamicConstant::add(start, size));
                }
                Node::Call {
                    function: callee,
                    ref dynamic_constants,
                    ..
                } => {
                    // Copy the call's dynamic constants out so that `edit` can
                    // be mutated below.
                    let actual_dcs = dynamic_constants.to_vec();
                    // Map each dynamic constant parameter of the callee to the
                    // dynamic constant supplied at this call site.
                    let mut substs = HashMap::new();
                    for (idx, actual) in actual_dcs.into_iter().enumerate() {
                        let formal = edit.add_dynamic_constant(DynamicConstant::Parameter(idx));
                        substs.insert(formal, actual);
                    }

                    for device in BACKED_DEVICES {
                        if let Some(&(callee_size, _)) =
                            backing_allocations[&callee].get(&device)
                        {
                            let (total, offsets) = allocation
                                .entry(device)
                                .or_insert_with(|| (zero, BTreeMap::new()));
                            // The callee's internal alignment needs aren't
                            // visible here, so conservatively use the largest.
                            let start = align(&mut edit, *total, LARGEST_ALIGNMENT);
                            offsets.insert(id, start);
                            // Rewrite the callee's size in terms of this call
                            // site's dynamic constants before adding it.
                            let callee_size =
                                substitute_dynamic_constants(&substs, callee_size, &mut edit);
                            *total =
                                edit.add_dynamic_constant(DynamicConstant::add(start, callee_size));
                        }
                    }
                }
                _ => {}
            }
        }
        Ok(edit)
    });

    allocation
}