Skip to content
Snippets Groups Projects
fork_transforms.rs 58.36 KiB
use std::collections::{HashMap, HashSet};
use std::iter::zip;

use bimap::BiMap;
use itertools::Itertools;

use hercules_ir::*;

use crate::*;

type ForkID = usize;

/**
 * Builds the trivial reduce partition for the fork-join ending at `join`:
 * every reduce user of `join` is assigned its own fork index. Indices are
 * handed out consecutively from 0, in the order `editor.get_users(join)`
 * yields the reduces. `_fork` is currently unused.
 */
pub fn default_reduce_partition(
    editor: &FunctionEditor,
    _fork: NodeID,
    join: NodeID,
) -> SparseNodeMap<ForkID> {
    let nodes = &editor.func().nodes;
    let mut partition = SparseNodeMap::new();
    let mut next_fork: ForkID = 0;

    for user in editor.get_users(join) {
        if !nodes[user.idx()].is_reduce() {
            continue;
        }
        partition.insert(user, next_fork);
        next_fork += 1;
    }

    partition
}

/**
 * Collects the nodes through which `reduce` (transitively) depends on `fork`.
 *
 * Starting at `reduce`, the use edges of the graph are walked depth-first.
 * A node is marked dependent when it is `fork` itself, or when, after all of
 * its uses have been visited, at least one of those uses is marked dependent.
 * The result is every node marked dependent, in ascending ID order.
 *
 * Cycles: a node already reached by the walk is not entered again. `reduce`
 * is pre-marked dependent, so nodes in its reduction cycle (which lead back to
 * `reduce`) count as dependent. Any other cycle (e.g. a control loop inside
 * the fork-join) may be under-approximated, because nodes still being visited
 * read as "not dependent".
 *
 * The walk keeps an explicit stack instead of recursing, so long use chains
 * cannot overflow the native call stack. Visit order and marking rules are
 * the same as for a recursive post-order traversal.
 */
// TODO: Refine these conditions.
pub fn find_reduce_dependencies<'a>(
    function: &'a Function,
    reduce: NodeID,
    fork: NodeID,
) -> impl IntoIterator<Item = NodeID> + 'a {
    /// One pending node: the node, its uses, and the index of the next use to visit.
    type Frame = (NodeID, Vec<NodeID>, usize);

    /// Reaches `node`. It is skipped if already reached. It is resolved
    /// immediately if it is `fork`. Otherwise it is pushed so its uses get
    /// visited before it is resolved.
    fn enter(
        function: &Function,
        fork: NodeID,
        node: NodeID,
        visited: &mut DenseNodeMap<bool>,
        dependent: &mut DenseNodeMap<bool>,
        stack: &mut Vec<Frame>,
    ) {
        if visited[node.idx()] {
            return;
        }
        visited[node.idx()] = true;

        if node == fork {
            dependent[node.idx()] = true;
            return;
        }

        let uses = get_uses(&function.nodes[node.idx()]).as_ref().to_vec();
        stack.push((node, uses, 0));
    }

    let len = function.nodes.len();
    let mut visited: DenseNodeMap<bool> = vec![false; len];
    let mut dependent: DenseNodeMap<bool> = vec![false; len];

    // HACK: the condition we really want is "all nodes on any path from the
    // fork to the reduce (in the forward graph)". Cycles break the simple
    // post-order rule, but we assume the only cycles are the ones through the
    // reduce itself (a control loop inside the fork would violate this), so
    // the reduce is pre-marked dependent before the walk starts.
    dependent[reduce.idx()] = true;

    let mut stack: Vec<Frame> = Vec::new();
    enter(function, fork, reduce, &mut visited, &mut dependent, &mut stack);

    while let Some((node, uses, next)) = stack.last_mut() {
        if *next < uses.len() {
            // Visit the next use of the node on top of the stack.
            let used = uses[*next];
            *next += 1;
            enter(function, fork, used, &mut visited, &mut dependent, &mut stack);
        } else {
            // All uses visited: the node is dependent iff any use is.
            let node = *node;
            let is_dependent = uses.iter().any(|id| dependent[id.idx()]);
            dependent[node.idx()] = is_dependent;
            stack.pop();
        }
    }

    // Return the IDs of all dependent nodes.
    let ret_val: Vec<NodeID> = dependent
        .iter()
        .enumerate()
        .filter(|(_, is_dependent)| **is_dependent)
        .map(|(idx, _)| NodeID::new(idx))
        .collect();

    ret_val
}

/**
 * Copies every node of `subgraph` inside an in-progress edit.
 *
 * After copying, the copies are rewired: wherever a copy uses a node of
 * `subgraph`, it uses that node's copy instead. Uses of nodes outside
 * `subgraph` keep pointing at the originals, and the original nodes are not
 * modified.
 *
 * Returns the edit together with a map from each original node to its copy,
 * or gives the edit back as `Err` if any rewiring step is rejected.
 */
pub fn copy_subgraph_in_edit<'a, 'b>(
    mut edit: FunctionEdit<'a, 'b>,
    subgraph: HashSet<NodeID>,
) -> Result<(FunctionEdit<'a, 'b>, HashMap<NodeID, NodeID>), FunctionEdit<'a, 'b>> {
    // Copy in ascending ID order so the IDs given to the copies do not depend
    // on `HashSet` iteration order.
    let mut old_ids: Vec<NodeID> = subgraph.into_iter().collect();
    old_ids.sort_by_key(|id| id.idx());

    let mut map: HashMap<NodeID, NodeID> = HashMap::with_capacity(old_ids.len());
    for &old_id in &old_ids {
        let new_id = edit.copy_node(old_id);
        map.insert(old_id, new_id);
    }

    // Membership set of the copies, built once. This makes each rewiring
    // predicate an O(1) lookup instead of a linear scan over `map.values()`.
    let new_ids: HashSet<NodeID> = map.values().copied().collect();

    // Update edges inside the copy to point at the other copies.
    for &old_id in &old_ids {
        edit = edit.replace_all_uses_where(old_id, map[&old_id], |node_id| {
            new_ids.contains(node_id)
        })?;
    }

    Ok((edit, map))
}

/**
 * Copies every node of `subgraph` (one editor edit per node). It then rewires
 * the copies so that their uses of subgraph nodes point at the corresponding
 * copies.
 *
 * Returns:
 * - the set of newly created nodes;
 * - a map from each original subgraph node to its copy;
 * - every edge `(old node, outside user)`, where `old node` belongs to the
 *   original `subgraph` and `outside user` uses it but is neither part of the
 *   original subgraph nor one of the new copies.
 */
pub fn copy_subgraph(
    editor: &mut FunctionEditor,
    subgraph: HashSet<NodeID>,
) -> (
    HashSet<NodeID>,
    HashMap<NodeID, NodeID>,
    Vec<(NodeID, NodeID)>,
) {
    let mut map: HashMap<NodeID, NodeID> = HashMap::new();
    let mut new_nodes: HashSet<NodeID> = HashSet::new();

    // Work in ascending ID order so the copies get deterministic IDs.
    let mut old_ids: Vec<NodeID> = subgraph.iter().copied().collect();
    old_ids.sort_by_key(|id| id.idx());

    // Copy nodes.
    for &old_id in &old_ids {
        editor.edit(|mut edit| {
            let new_id = edit.copy_node(old_id);
            map.insert(old_id, new_id);
            new_nodes.insert(new_id);
            Ok(edit)
        });
    }

    // Replace uses of each old node with its copy, but only inside the copies.
    for &old_id in &old_ids {
        editor.edit(|edit| {
            edit.replace_all_uses_where(old_id, map[&old_id], |node_id| new_nodes.contains(node_id))
        });
    }

    // Collect edges from the ORIGINAL subgraph to users outside of it. The
    // fresh copies are only used by other copies, so scanning their users
    // would never find an outside user.
    let mut outside_users = Vec::new();
    for &old_id in &old_ids {
        for user in editor.get_users(old_id) {
            if !subgraph.contains(&user) && !new_nodes.contains(&user) {
                outside_users.push((old_id, user));
            }
        }
    }

    (new_nodes, map, outside_users)
}

/**
 * Finds the data edges of the fork-join rooted at `fork` that should be
 * bufferized.
 *
 * An edge `(node, user)` is selected when all of the following hold:
 * - `node` is in `nodes_in_fork_joins[&fork]` and carries `data_label`;
 * - `user` does not carry `data_label`;
 * - neither endpoint is a control node;
 * - `node` is not a direct user of `fork` (e.g. its thread IDs).
 *
 * `loop_tree` and `fork_join_map` are currently unused.
 */
pub fn find_bufferize_edges(
    editor: &mut FunctionEditor,
    fork: NodeID,
    loop_tree: &LoopTree,
    fork_join_map: &HashMap<NodeID, NodeID>,
    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
    data_label: &LabelID,
) -> HashSet<(NodeID, NodeID)> {
    // Direct users of the fork (such as its thread IDs) are never the source
    // of a bufferize edge. Collect them once instead of re-walking the fork's
    // users for every candidate node.
    let fork_users: HashSet<NodeID> = editor.get_users(fork).collect();
    let labels = &editor.func().labels;

    let mut edges: HashSet<_> = HashSet::new();

    for node in &nodes_in_fork_joins[&fork] {
        // Edges go from a node *with* the data label to a user *without* it.
        if !labels[node.idx()].contains(data_label)
            || fork_users.contains(node)
            || editor.node(node).is_control()
        {
            continue;
        }

        for user in editor.get_users(*node) {
            if labels[user.idx()].contains(data_label) || editor.node(user).is_control() {
                continue;
            }
            edges.insert((*node, user));
        }
    }
    edges
}

/**
 * Creates a fresh label and attaches it to every node of the fork-join rooted
 * at `fork` that is not part of a reduce cycle.
 *
 * Excluded from labelling: the join, every reduce user of the join that has
 * an entry in `reduce_cycles`, and all nodes of those reduce cycles.
 *
 * Returns the new label.
 *
 * # Panics
 * Panics if the labelling edit is rejected.
 */
pub fn ff_bufferize_create_not_reduce_cycle_label_helper(
    editor: &mut FunctionEditor,
    fork: NodeID,
    fork_join_map: &HashMap<NodeID, NodeID>,
    reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
) -> LabelID {
    let join = fork_join_map[&fork];

    // Every reduce of the join with a known cycle, its cycle, and the join itself.
    let excluded: HashSet<NodeID> = editor
        .get_users(join)
        .filter_map(|reduce| reduce_cycles.get(&reduce).map(|cycle| (reduce, cycle)))
        .flat_map(|(reduce, cycle)| std::iter::once(reduce).chain(cycle.iter().copied()))
        .chain(std::iter::once(join))
        .collect();

    let to_label: Vec<NodeID> = nodes_in_fork_joins[&fork]
        .iter()
        .copied()
        .filter(|id| !excluded.contains(id))
        .collect();

    let mut label = LabelID::new(0);
    let applied = editor.edit(|mut edit| {
        label = edit.fresh_label();
        for id in to_label {
            edit = edit.add_label(id, label)?;
        }
        Ok(edit)
    });

    assert!(applied);
    label
}

/**
 * Tries to bufferize-fission one fork-join whose fork carries `fork_label`.
 *
 * Candidate forks are visited in the reverse of `loop_tree.bottom_up_loops()`
 * order. For each candidate, the label selecting which data stays on the
 * "producer" side is `data_label` if provided. Otherwise it is a fresh label
 * from `ff_bufferize_create_not_reduce_cycle_label_helper`, which edits the
 * function to attach it.
 *
 * Returns the first successful `fork_bufferize_fission_helper` result
 * (original fork, new fork), or `None` if no fork-join was split.
 */
pub fn ff_bufferize_any_fork<'a, 'b>(
    editor: &'b mut FunctionEditor<'a>,
    loop_tree: &'b LoopTree,
    fork_join_map: &'b HashMap<NodeID, NodeID>,
    reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
    nodes_in_fork_joins: &'b HashMap<NodeID, HashSet<NodeID>>,
    typing: &'b Vec<TypeID>,
    fork_label: LabelID,
    data_label: Option<LabelID>,
) -> Option<(NodeID, NodeID)>
where
    'a: 'b,
{
    let candidates: Vec<_> = loop_tree
        .bottom_up_loops()
        .into_iter()
        .filter(|(header, _)| editor.func().nodes[header.idx()].is_fork())
        .collect();

    for (header, control) in candidates.into_iter().rev() {
        let fork = header;
        let fork_info = Loop {
            header: fork,
            control: control.clone(),
        };
        let join = fork_join_map[&fork];

        let has_fork_label = editor.func().labels[fork.idx()].contains(&fork_label);
        if !has_fork_label {
            continue;
        }

        let edge_label = match data_label {
            Some(label) => label,
            None => ff_bufferize_create_not_reduce_cycle_label_helper(
                editor,
                fork,
                fork_join_map,
                reduce_cycles,
                nodes_in_fork_joins,
            ),
        };
        let edges = find_bufferize_edges(
            editor,
            fork,
            loop_tree,
            fork_join_map,
            nodes_in_fork_joins,
            &edge_label,
        );

        if let Some(split) = fork_bufferize_fission_helper(
            editor,
            &fork_info,
            &edges,
            nodes_in_fork_joins,
            typing,
            fork,
            join,
        ) {
            return Some(split);
        }
    }
    None
}

/**
 * Performs reduce fission on every fork-join of the function. Each fork-join
 * is split into a chain of fork-joins, one per reduce, using
 * `default_reduce_partition` and `fork_reduce_fission_helper`.
 *
 * Each chain is attached after the original fork's own control predecessor,
 * so the surrounding control flow is preserved. Fork-joins without any reduce
 * are left untouched.
 *
 * # Panics
 * Panics (`todo!`) on a fork-join with control nodes between its fork and
 * join.
 */
pub fn fork_fission<'a>(
    editor: &'a mut FunctionEditor,
    _control_subgraph: &Subgraph,
    _types: &Vec<TypeID>,
    _loop_tree: &LoopTree,
    fork_join_map: &HashMap<NodeID, NodeID>,
) -> () {
    let forks: Vec<NodeID> = editor
        .func()
        .nodes
        .iter()
        .enumerate()
        .filter(|(_, node)| node.is_fork())
        .map(|(idx, _)| NodeID::new(idx))
        .collect();

    // This does the reduction fission:
    for fork in forks {
        let join = fork_join_map[&fork];
        let join_pred = editor.func().nodes[join.idx()].try_join().unwrap();
        if join_pred != fork {
            // FIXME: If there is control in between fork and join, don't just give up.
            // Inner control LOOPs are hard; inner control in general *should*
            // work right now without modifications.
            todo!("Can't do fork fission on nodes with internal control")
        }

        // The split chain must hang off the fork's *current* control
        // predecessor (read fresh, since earlier fissions may have rewired it).
        // It must not hang off the start node.
        let (control_pred, _) = editor.func().nodes[fork.idx()].try_fork().unwrap();

        let reduce_partition = default_reduce_partition(editor, fork, join);
        if reduce_partition.is_empty() {
            // Nothing to split.
            continue;
        }
        fork_reduce_fission_helper(editor, fork_join_map, reduce_partition, control_pred, fork);
    }
}

/** Split a 1D fork into two forks, placing select intermediate data into buffers. */
pub fn fork_bufferize_fission_helper<'a, 'b>(
    editor: &'b mut FunctionEditor<'a>,
    l: &Loop,
    bufferized_edges: &HashSet<(NodeID, NodeID)>, // Describes what intermediate data should be bufferized.
    data_node_in_fork_joins: &'b HashMap<NodeID, HashSet<NodeID>>,
    types: &'b Vec<TypeID>,
    fork: NodeID,
    join: NodeID,
) -> Option<(NodeID, NodeID)>
where
    'a: 'b,
{
    if bufferized_edges.is_empty() {
        return None;
    }

    let all_loop_nodes = l.get_all_nodes();

    // FIXME: Cloning hell.
    let data_nodes = data_node_in_fork_joins[&fork].clone();
    let loop_nodes = editor
        .node_ids()
        .filter(|node_id| all_loop_nodes[node_id.idx()]);
    // Clone the subgraph that consists of this fork-join and all data and control nodes in it.
    let subgraph = HashSet::from_iter(data_nodes.into_iter().chain(loop_nodes));

    let mut outside_users = Vec::new(); // old_node, outside_user

    for node in subgraph.iter() {
        for user in editor.get_users(*node) {
            if !subgraph.iter().contains(&user) {
                outside_users.push((*node, user));
            }
        }
    }

    let factors: Vec<_> = editor.func().nodes[fork.idx()]
        .try_fork()
        .unwrap()
        .1
        .iter()
        .cloned()
        .collect();

    let thread_stuff_it = factors.into_iter().enumerate();

    // Control succesors
    let fork_pred = editor
        .get_uses(fork)
        .filter(|a| editor.node(a).is_control())
        .next()
        .unwrap();
    let join_successor = editor
        .get_users(join)
        .filter(|a| editor.node(a).is_control())
        .next()
        .unwrap();

    let mut new_fork_id = NodeID::new(0);

    let edit_result = editor.edit(|edit| {
        let (mut edit, map) = copy_subgraph_in_edit(edit, subgraph)?;

        edit = edit.replace_all_uses_where(fork_pred, join, |a| *a == map[&fork])?;
        edit = edit.replace_all_uses_where(join, map[&join], |a| *a == join_successor)?;

        // Replace outside uses of reduces in old subgraph with reduces in new subgraph.
        for (old_node, outside_user) in outside_users {
            edit = edit
                .replace_all_uses_where(old_node, map[&old_node], |node| *node == outside_user)?;
        }

        let new_fork = map[&fork];

        // FIXME: Do this as part of copy subgraph?
        // Add tids to original subgraph for indexing.
        let mut old_tids = Vec::new();
        let mut new_tids = Vec::new();
        for (dim, _) in thread_stuff_it.clone() {
            let old_id = edit.add_node(Node::ThreadID {
                control: fork,
                dimension: dim,
            });

            let new_id = edit.add_node(Node::ThreadID {
                control: new_fork,
                dimension: dim,
            });

            old_tids.push(old_id);
            new_tids.push(new_id);
        }

        for (src, dst) in bufferized_edges {
            let array_dims = thread_stuff_it.clone().map(|(_, factor)| (factor));
            let position_idx = Index::Position(old_tids.clone().into_boxed_slice());

            let write = edit.add_node(Node::Write {
                collect: NodeID::new(0),
                data: *src,
                indices: vec![position_idx.clone()].into(),
            });
            let ele_type = types[src.idx()];
            let empty_buffer = edit.add_type(hercules_ir::Type::Array(
                ele_type,
                array_dims.collect::<Vec<_>>().into_boxed_slice(),
            ));
            let empty_buffer = edit.add_zero_constant(empty_buffer);
            let empty_buffer = edit.add_node(Node::Constant { id: empty_buffer });
            edit = edit.add_schedule(empty_buffer, Schedule::NoResetConstant)?;
            let reduce = Node::Reduce {
                control: join,
                init: empty_buffer,
                reduct: write,
            };
            let reduce = edit.add_node(reduce);
            edit = edit.add_schedule(reduce, Schedule::ParallelReduce)?;

            // Fix write node
            edit = edit.replace_all_uses_where(NodeID::new(0), reduce, |usee| *usee == write)?;

            // Create reads from buffer
            let position_idx = Index::Position(new_tids.clone().into_boxed_slice());

            let read = edit.add_node(Node::Read {
                collect: reduce,
                indices: vec![position_idx].into(),
            });

            // Replaces uses of bufferized edge src with corresponding reduce and read in old subraph
            edit = edit.replace_all_uses_where(map[src], read, |usee| *usee == map[dst])?;
        }

        new_fork_id = new_fork;

        Ok(edit)
    });

    if edit_result {
        Some((fork, new_fork_id))
    } else {
        None
    }
}
/**
 * Splits the fork-join rooted at `fork` into a chain of fork-joins, one per
 * reduce key in `reduce_partition`.
 *
 * NOTE: the `ForkID` values of `reduce_partition` are currently ignored.
 * Every reduce gets its own copy of the fork-join.
 *
 * For each reduce, in ascending node-ID order:
 * - The nodes from `find_reduce_dependencies`, plus the fork, join and
 *   reduce, are copied.
 * - The copied fork is attached after the previous copy's join; the first
 *   copy is attached after `original_control_pred`.
 * - All uses of the reduce are redirected to its copy.
 *
 * Afterwards, uses of the original join and fork are redirected to the last
 * copies, and the originals are deleted.
 *
 * Returns the last `(new_fork, new_join)`. If `reduce_partition` is empty,
 * nothing is edited and `(fork, join)` is returned.
 */
pub fn fork_reduce_fission_helper<'a>(
    editor: &'a mut FunctionEditor,
    fork_join_map: &HashMap<NodeID, NodeID>,
    reduce_partition: SparseNodeMap<ForkID>, // Describes how the reduces of the fork should be split,
    original_control_pred: NodeID,           // What the new fork connects to.

    fork: NodeID,
) -> (NodeID, NodeID) {
    let join = fork_join_map[&fork];

    // With no reduce there is nothing to split. The final rewiring below
    // would otherwise redirect the fork and join to a placeholder NodeID 0.
    if reduce_partition.is_empty() {
        return (fork, join);
    }

    // Deterministic processing order.
    let mut reduces: Vec<NodeID> = reduce_partition
        .into_iter()
        .map(|(reduce, _)| reduce)
        .collect();
    reduces.sort_by_key(|id| id.idx());

    // NOTE: if reduce A depends on reduce B and they end up in separate
    // forks, B is duplicated into A's copy (it is found as a dependency).
    let mut new_control_pred: NodeID = original_control_pred;
    let mut new_fork = fork;
    let mut new_join = join;

    for reduce in reduces {
        // Everything between fork & join that this reduce needs (ALL CONTROL).
        let mut subgraph: HashSet<NodeID> = find_reduce_dependencies(editor.func(), reduce, fork)
            .into_iter()
            .collect();
        subgraph.insert(join);
        subgraph.insert(fork);
        subgraph.insert(reduce);

        let (_, mapping, _) = copy_subgraph(editor, subgraph);

        new_fork = mapping[&fork];
        new_join = mapping[&join];

        editor.edit(|mut edit| {
            // Attach new_fork after the previous link of the chain.
            let (old_control_pred, _) = edit.get_node(new_fork).try_fork().unwrap().clone();
            edit = edit.replace_all_uses_where(old_control_pred, new_control_pred, |usee| {
                *usee == new_fork
            })?;

            // Users of the reduce now read the copy's result.
            edit = edit.replace_all_uses(reduce, mapping[&reduce])?;
            Ok(edit)
        });

        new_control_pred = new_join;
    }

    editor.edit(|mut edit| {
        // Replace original join w/ new final join.
        edit = edit.replace_all_uses_where(join, new_join, |_| true)?;

        // Delete original join (all reduce users have been moved).
        edit = edit.delete_node(join)?;

        // Replace all users of original fork, and then delete it, leftover users will be DCE'd.
        edit = edit.replace_all_uses(fork, new_fork)?;
        edit.delete_node(fork)
    });

    (new_fork, new_join)
}

/**
 * Tries to coalesce one pair of directly nested fork-joins.
 *
 * Every ordered pair of *distinct* fork headers from `loops` is offered to
 * `fork_coalesce_helper`, which rejects pairs that are not directly nested.
 *
 * Returns `true` right after the first successful merge. Presumably the
 * caller must recompute `fork_join_map` and other analyses before calling
 * again, since the merged forks and joins are deleted. Returns `false` if no
 * pair could be merged.
 */
pub fn fork_coalesce(
    editor: &mut FunctionEditor,
    loops: &LoopTree,
    fork_join_map: &HashMap<NodeID, NodeID>,
) -> bool {
    let fork_joins: Vec<NodeID> = loops
        .bottom_up_loops()
        .into_iter()
        .filter_map(|(k, _)| editor.func().nodes[k.idx()].is_fork().then_some(k))
        .collect();

    // FIXME: Add a postorder traversal to optimize this. Something like
    // `fork_joins.postorder_iter().windows(2)` is ideal here; for now unrelated
    // pairs are tried and rejected early by the helper.
    for (inner, outer) in fork_joins.iter().cartesian_product(fork_joins.iter()) {
        // A fork-join can never be coalesced with itself.
        if inner == outer {
            continue;
        }
        if fork_coalesce_helper(editor, *outer, *inner, fork_join_map).is_some() {
            return true;
        }
    }
    false
}

/** Opposite of fork split: takes two fork-joins with no control between them
    (`inner_fork` directly follows `outer_fork`, and `outer_join` directly
    follows `inner_join`) and merges them into a single fork-join. The merged
    fork's factors are the outer factors followed by the inner factors.

    Every reduce on the outer join must be paired one-to-one with a reduce on
    the inner join. The inner reduce must be the outer reduce's `reduct`, and
    its `init` must be the outer reduce. Each pair is fused into the inner
    reduce, which takes over the outer reduce's `init` and all of its uses.

    Returns None if the forks could not be merged, and the NodeIDs of the
    resulting fork and join (the inner join is reused) if it succeeds.

    FIXME: reduces attached to control inside the fork-joins, and inner-join
    reduces not paired with an outer reduce, are not checked.
*/
pub fn fork_coalesce_helper(
    editor: &mut FunctionEditor,
    outer_fork: NodeID,
    inner_fork: NodeID,
    fork_join_map: &HashMap<NodeID, NodeID>,
) -> Option<(NodeID, NodeID)> {
    let outer_join = fork_join_map[&outer_fork];
    let inner_join = fork_join_map[&inner_fork];

    // Check for control between fork-fork and join-join first. It is cheap
    // and rejects most candidate pairs.
    let (inner_fork_pred, _) = editor.node(inner_fork).try_fork().unwrap();
    if inner_fork_pred != outer_fork {
        return None;
    }
    if editor.node(outer_join).try_join().unwrap() != inner_join {
        return None;
    }

    // Check that all reduces in the outer fork are in *simple* cycles with a
    // unique reduce of the inner fork.
    let mut pairs: BiMap<NodeID, NodeID> = BiMap::new(); // Outer <-> Inner

    // FIXME: Iterate all control uses of joins to really collect all reduces
    // (reduces can be attached to inner control)
    for outer_reduce in editor
        .get_users(outer_join)
        .filter(|node| editor.func().nodes[node.idx()].is_reduce())
    {
        let (_, _, inner_reduce) = editor.func().nodes[outer_reduce.idx()]
            .try_reduce()
            .unwrap();

        let Node::Reduce {
            control: inner_control,
            init: inner_init,
            reduct: _,
        } = &editor.func().nodes[inner_reduce.idx()]
        else {
            return None;
        };

        // FIXME: check this condition better (i.e reduce might not be attached to join)
        if *inner_control != inner_join || *inner_init != outer_reduce {
            return None;
        }

        if pairs.contains_left(&outer_reduce) || pairs.contains_right(&inner_reduce) {
            return None;
        }
        pairs.insert(outer_reduce, inner_reduce);
    }

    // Checklist:
    // Increment inner TIDs
    // Add outer fork's dimension to front of inner fork.
    // Fuse reductions
    //  - Initializer becomes outer initializer
    // Replace uses of outer fork w/ inner fork.
    // Replace uses of outer join w/ inner join.
    // Delete outer fork-join

    let inner_tids: Vec<NodeID> = editor
        .get_users(inner_fork)
        .filter(|node| editor.func().nodes[node.idx()].is_thread_id())
        .collect();

    let (outer_pred, outer_dims) = editor.func().nodes[outer_fork.idx()].try_fork().unwrap();
    let (_, inner_dims) = editor.func().nodes[inner_fork.idx()].try_fork().unwrap();
    let num_outer_dims = outer_dims.len();

    // Outer dimensions first, then inner ones. This matches the thread-ID
    // remapping below, which shifts every inner dimension by `num_outer_dims`.
    let mut new_factors = outer_dims.to_vec();
    new_factors.extend_from_slice(inner_dims);

    let mut new_fork = NodeID::new(0);
    let new_join = inner_join; // We'll reuse the inner join as the join of the new fork

    let success = editor.edit(|mut edit| {
        // Re-index inner thread IDs past the outer dimensions. The stale
        // originals are deleted so no thread ID with a wrong dimension is left.
        for tid in inner_tids {
            let (fork, dim) = edit.get_node(tid).try_thread_id().unwrap();
            let new_tid = edit.add_node(Node::ThreadID {
                control: fork,
                dimension: dim + num_outer_dims,
            });
            edit = edit.replace_all_uses(tid, new_tid)?;
            edit.sub_edit(tid, new_tid);
            edit = edit.delete_node(tid)?;
        }

        // Fuse Reductions
        for (outer_reduce, inner_reduce) in pairs {
            let (_, outer_init, _) = edit.get_node(outer_reduce).try_reduce().unwrap();
            let (_, inner_init, _) = edit.get_node(inner_reduce).try_reduce().unwrap();
            // Set inner init to outer init.
            edit =
                edit.replace_all_uses_where(inner_init, outer_init, |usee| *usee == inner_reduce)?;
            edit = edit.replace_all_uses(outer_reduce, inner_reduce)?;
            edit = edit.delete_node(outer_reduce)?;
        }

        new_fork = edit.add_node(Node::Fork {
            control: outer_pred,
            factors: new_factors.into(),
        });

        // The merged fork is parallel only if both inputs were.
        if edit
            .get_schedule(outer_fork)
            .contains(&Schedule::ParallelFork)
            && edit
                .get_schedule(inner_fork)
                .contains(&Schedule::ParallelFork)
        {
            edit = edit.add_schedule(new_fork, Schedule::ParallelFork)?;
        }

        edit = edit.replace_all_uses(inner_fork, new_fork)?;
        edit = edit.replace_all_uses(outer_fork, new_fork)?;
        edit = edit.replace_all_uses(outer_join, inner_join)?;
        edit = edit.delete_node(outer_join)?;
        edit = edit.delete_node(inner_fork)?;
        edit = edit.delete_node(outer_fork)?;

        Ok(edit)
    });

    if success {
        Some((new_fork, new_join))
    } else {
        None
    }
}

/**
 * Splits the first fork-join, in `fork_join_map` iteration order, that
 * `split_fork` turns into more than one fork.
 *
 * Fork-joins that are already one-dimensional, or whose split edit is
 * rejected, are skipped. Returns the new forks and joins of the split, or
 * `None` if nothing was split.
 */
pub fn split_any_fork(
    editor: &mut FunctionEditor,
    fork_join_map: &HashMap<NodeID, NodeID>,
    reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
) -> Option<(Vec<NodeID>, Vec<NodeID>)> {
    fork_join_map
        .iter()
        .filter_map(|(&fork, &join)| split_fork(editor, fork, join, reduce_cycles))
        .find(|(new_forks, _)| new_forks.len() > 1)
}

/*
 * Split the multi-dimensional fork-join `(fork, join)` into a nest of
 * one-dimensional fork-joins, one per factor of `fork`. Useful for code
 * generation. A single call of `split_fork` splits at most one fork-join; it
 * must be called repeatedly (e.g. via `split_any_fork`) to split all
 * fork-joins.
 *
 * Returns:
 * - `Some((vec![fork], vec![join]))`, without editing, if `fork` has fewer
 *   than two factors;
 * - `Some((new_forks, new_joins))` on success. Both are ordered outermost
 *   first: `new_forks[i]` iterates factor `i`, and `new_joins[i]` is its join;
 * - `None` if the edit was rejected.
 *
 * Panics if a reduce attached to `join` has no entry in `reduce_cycles`.
 */
pub fn split_fork(
    editor: &mut FunctionEditor,
    fork: NodeID,
    join: NodeID,
    reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
) -> Option<(Vec<NodeID>, Vec<NodeID>)> {
    // A single multi-dimensional fork becomes multiple forks, a join becomes
    // multiple joins, a thread ID becomes a thread ID on the correct
    // fork, and a reduce becomes multiple reduces to shuffle the reduction
    // value through the fork-join nest.
    let nodes = &editor.func().nodes;
    let (fork_control, factors) = nodes[fork.idx()].try_fork().unwrap();
    if factors.len() < 2 {
        return Some((vec![fork], vec![join]));
    }
    let factors: Box<[DynamicConstantID]> = factors.into();
    let join_control = nodes[join.idx()].try_join().unwrap();
    let tids: Vec<_> = editor
        .get_users(fork)
        .filter(|id| nodes[id.idx()].is_thread_id())
        .collect();
    let reduces: Vec<_> = editor
        .get_users(join)
        .filter(|id| nodes[id.idx()].is_reduce())
        .collect();

    // Pairs (user, reduce) where `user` uses `reduce` from inside the
    // reduce's cycle. Those uses get the innermost new reduce (the running
    // value); every other use gets the outermost one (the final value).
    let data_in_reduce_cycle: HashSet<(NodeID, NodeID)> = reduces
        .iter()
        .map(|reduce| editor.get_users(*reduce).map(move |user| (user, *reduce)))
        .flatten()
        .filter(|(user, reduce)| reduce_cycles[&reduce].contains(&user))
        .collect();

    let mut new_forks = vec![];
    let mut new_joins = vec![];
    let success = editor.edit(|mut edit| {
        // Create the forks and a thread ID per fork.
        // Each new fork's control is the previous one, so new_forks[0] is
        // outermost (attached to the original fork's predecessor).
        let mut acc_fork = fork_control;
        let mut new_tids = vec![];
        for factor in factors {
            acc_fork = edit.add_node(Node::Fork {
                control: acc_fork,
                factors: Box::new([factor]),
            });
            new_forks.push(acc_fork);
            edit.sub_edit(fork, acc_fork);
            new_tids.push(edit.add_node(Node::ThreadID {
                control: acc_fork,
                dimension: 0,
            }));
        }

        // Create the joins.
        // Joins are created innermost first: the first one hangs off the
        // innermost fork (no internal control) or off the original join's
        // control predecessor.
        let mut acc_join = if join_control == fork {
            acc_fork
        } else {
            join_control
        };
        for _ in new_tids.iter() {
            acc_join = edit.add_node(Node::Join { control: acc_join });
            edit.sub_edit(join, acc_join);
            new_joins.push(acc_join);
        }

        // Create the reduces.
        // One reduce per new join, chained: the innermost (join_idx 0) takes
        // the original `reduct`; each outer one takes the next-inner reduce as
        // its reduct; each inner one takes the next-outer reduce as its init;
        // the outermost takes the original `init`.
        // The IDs of not-yet-created reduces are predicted as
        // `num_nodes + join_idx`. This relies on `add_node` handing out
        // consecutive IDs, which the `assert_eq!` below checks.
        let mut new_reduces = vec![];
        for reduce in reduces.iter() {
            let (_, init, reduct) = edit.get_node(*reduce).try_reduce().unwrap();
            let num_nodes = edit.num_node_ids();
            let mut inner_reduce = NodeID::new(0);
            let mut outer_reduce = NodeID::new(0);
            for (join_idx, join) in new_joins.iter().enumerate() {
                let init = if join_idx == new_joins.len() - 1 {
                    init
                } else {
                    NodeID::new(num_nodes + join_idx + 1)
                };
                let reduct = if join_idx == 0 {
                    reduct
                } else {
                    NodeID::new(num_nodes + join_idx - 1)
                };
                let new_reduce = edit.add_node(Node::Reduce {
                    control: *join,
                    init,
                    reduct,
                });
                assert_eq!(new_reduce, NodeID::new(num_nodes + join_idx));
                edit.sub_edit(*reduce, new_reduce);
                if join_idx == 0 {
                    inner_reduce = new_reduce;
                }
                if join_idx == new_joins.len() - 1 {
                    outer_reduce = new_reduce;
                }
            }
            new_reduces.push((inner_reduce, outer_reduce));
        }

        // Replace everything.
        // Uses of the old fork go to the innermost new fork. Uses of the old
        // join go to the outermost new join. Old thread IDs go to the thread
        // ID of the fork for their dimension.
        edit = edit.replace_all_uses(fork, acc_fork)?;
        edit = edit.replace_all_uses(join, acc_join)?;
        for tid in tids.iter() {
            let dim = edit.get_node(*tid).try_thread_id().unwrap().1;
            edit.sub_edit(*tid, new_tids[dim]);
            edit = edit.replace_all_uses(*tid, new_tids[dim])?;
        }
        for (reduce, (inner_reduce, outer_reduce)) in zip(reduces.iter(), new_reduces) {
            edit = edit.replace_all_uses_where(*reduce, inner_reduce, |id| {
                data_in_reduce_cycle.contains(&(*id, *reduce))
            })?;
            edit = edit.replace_all_uses_where(*reduce, outer_reduce, |id| {
                !data_in_reduce_cycle.contains(&(*id, *reduce))
            })?;
        }

        // Delete all the old stuff.
        edit = edit.delete_node(fork)?;
        edit = edit.delete_node(join)?;
        for tid in tids {
            edit = edit.delete_node(tid)?;
        }
        for reduce in reduces {
            edit = edit.delete_node(reduce)?;
        }

        Ok(edit)
    });
    if success {
        // Joins were built innermost first; return them outermost first to
        // line up with `new_forks`.
        new_joins.reverse();
        Some((new_forks, new_joins))
    } else {
        None
    }
}

/**
 * Tiles dimension `dim_idx` of every mutable fork in `fork_join_map` by
 * `tile_size`, using `chunk_fork_unguarded`. "Unguarded" means there is no
 * check that `tile_size` divides the dimension.
 *
 * `order == true` selects `TileOrder::TileInner`; `false` selects
 * `TileOrder::TileOuter`. The tile size is added once as a dynamic constant
 * and shared by all forks.
 */
pub fn chunk_all_forks_unguarded(
    editor: &mut FunctionEditor,
    fork_join_map: &HashMap<NodeID, NodeID>,
    dim_idx: usize,
    tile_size: usize,
    order: bool,
) -> () {
    let mut tile_dc = DynamicConstantID::new(0);
    editor.edit(|mut edit| {
        tile_dc = edit.add_dynamic_constant(DynamicConstant::Constant(tile_size));
        Ok(edit)
    });

    let tile_order = if order {
        TileOrder::TileInner
    } else {
        TileOrder::TileOuter
    };

    for fork in fork_join_map.keys() {
        if !editor.is_mutable(*fork) {
            continue;
        }
        chunk_fork_unguarded(editor, *fork, dim_idx, tile_dc, &tile_order);
    }
}
// Splits a dimension of a single fork join into multiple.
// Iterates an outer loop original_dim / tile_size times
// adds a tile_size loop as the inner loop
// Assumes that tile size divides original dim evenly.

/// Where the `tile_size`-sized dimension goes when `chunk_fork_unguarded`
/// splits one fork dimension into two adjacent dimensions.
///
/// This type appears in the signature of the public `chunk_fork_unguarded`,
/// so it must itself be public.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TileOrder {
    /// Factors become `[.., original / tile_size, tile_size, ..]`.
    TileInner,
    /// Factors become `[.., tile_size, original / tile_size, ..]`.
    TileOuter,
}

/// Splits dimension `dim_idx` of `fork` into two adjacent dimensions, without
/// guarding against a remainder.
///
/// With `N` the original factor at `dim_idx` and `T` the `tile_size`:
/// - `TileOrder::TileInner` produces factors `[.., N / T, T, ..]`;
/// - `TileOrder::TileOuter` produces factors `[.., T, N / T, ..]`.
///
/// Uses of the original thread ID for `dim_idx` are rewritten as
/// `outer_tid * inner_factor + inner_tid`. Thread IDs of later dimensions are
/// re-created one dimension to the right. The caller must guarantee that `T`
/// evenly divides `N`; otherwise some iterations are lost (presumably the
/// trailing `N % T`, assuming dynamic-constant division rounds down).
///
/// Does nothing if `fork` is not a fork node. Whether the edit succeeded is not
/// reported.
///
/// # Panics
/// Panics if `dim_idx` is not a dimension of the fork.
pub fn chunk_fork_unguarded(
    editor: &mut FunctionEditor,
    fork: NodeID,
    dim_idx: usize,
    tile_size: DynamicConstantID,
    order: &TileOrder,
) -> () {
    let Node::Fork {
        control: old_control,
        factors: ref old_factors,
    } = *editor.node(fork)
    else {
        return;
    };
    assert!(dim_idx < old_factors.len());
    let mut new_factors: Vec<_> = old_factors.to_vec();
    let fork_users: Vec<_> = editor
        .get_users(fork)
        .map(|f| (f, editor.node(f).clone()))
        .collect();

    editor.edit(|mut edit| {
        // N / T: the number of tiles. Which of (N / T, T) is the outer factor
        // depends on the requested order; the inner factor is always the
        // stride of the outer thread ID.
        let quotient = DynamicConstant::div(new_factors[dim_idx], tile_size);
        let quotient = edit.add_dynamic_constant(quotient);
        let (outer_factor, inner_factor) = match order {
            TileOrder::TileInner => (quotient, tile_size),
            TileOrder::TileOuter => (tile_size, quotient),
        };
        new_factors[dim_idx] = outer_factor;
        new_factors.insert(dim_idx + 1, inner_factor);

        let new_fork = edit.add_node(Node::Fork {
            control: old_control,
            factors: new_factors.into(),
        });
        edit = edit.replace_all_uses(fork, new_fork)?;
        edit.sub_edit(fork, new_fork);

        for (tid, node) in fork_users {
            let Node::ThreadID {
                control: _,
                dimension: tid_dim,
            } = node
            else {
                continue;
            };
            if tid_dim > dim_idx {
                // Later dimensions move one slot to the right.
                let new_tid = edit.add_node(Node::ThreadID {
                    control: new_fork,
                    dimension: tid_dim + 1,
                });
                edit = edit.replace_all_uses(tid, new_tid)?;
                edit.sub_edit(tid, new_tid);
                edit = edit.delete_node(tid)?;
            } else if tid_dim == dim_idx {
                // The old thread ID is kept: it already points at `new_fork`
                // (via the replace above) and now indexes the outer dimension.
                // Its former users get `tid * inner_factor + inner_tid`.
                let inner_tid = edit.add_node(Node::ThreadID {
                    control: new_fork,
                    dimension: tid_dim + 1,
                });
                let stride = edit.add_node(Node::DynamicConstant { id: inner_factor });
                let mul = edit.add_node(Node::Binary {
                    left: tid,
                    right: stride,
                    op: BinaryOperator::Mul,
                });
                let add = edit.add_node(Node::Binary {
                    left: mul,
                    right: inner_tid,
                    op: BinaryOperator::Add,
                });
                edit.sub_edit(tid, add);
                edit.sub_edit(tid, inner_tid);
                // `mul` must keep using the raw thread ID.
                edit = edit.replace_all_uses_where(tid, add, |usee| *usee != mul)?;
            }
        }
        edit = edit.delete_node(fork)?;
        Ok(edit)
    });
}

/// Collapses every mutable fork in `fork_join_map` into a single dimension by
/// repeatedly merging its first two dimensions with `fork_dim_merge`.
///
/// Forks with fewer than two dimensions are left alone. If a merge makes no
/// progress (the edit was rejected), the remaining merges for that fork are
/// skipped.
pub fn merge_all_fork_dims(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>) {
    for fork in fork_join_map.keys() {
        // Consistent with the other `*_all_forks` drivers: never try to edit
        // an immutable fork.
        if !editor.is_mutable(*fork) {
            continue;
        }
        let Node::Fork {
            control: _,
            factors: ref dims,
        } = *editor.node(*fork)
        else {
            unreachable!();
        };
        // `saturating_sub` avoids underflow for a zero-dimension fork.
        let num_merges = dims.len().saturating_sub(1);

        let mut fork = *fork;
        for _ in 0..num_merges {
            let outer = 0;
            let inner = 1;
            let merged = fork_dim_merge(editor, fork, outer, inner);
            // Stop if no new fork was produced.
            if merged == fork || editor.func().nodes[merged.idx()].try_fork().is_none() {
                break;
            }
            fork = merged;
        }
    }
}

/// Merges two dimensions of `fork` into one and returns the ID of the new fork.
///
/// With `O` the factor of the outer (smaller-index) dimension and `I` the
/// factor of the inner one, the merged dimension sits at the outer index with
/// factor `O * I`, and the inner dimension is removed. Thread IDs are rewritten
/// so that the original lexicographic iteration order is preserved:
/// - outer thread ID <- merged / I
/// - inner thread ID <- merged % I
///
/// Thread IDs of dimensions after the inner one shift down by one. Every
/// replaced thread ID is deleted, so no thread ID with a stale dimension is
/// left attached to the new fork.
///
/// Returns `fork` unchanged if it is not a fork node or if the edit fails.
///
/// # Panics
/// Panics if `dim_idx1 == dim_idx2` or either index is not a dimension of the
/// fork.
pub fn fork_dim_merge(
    editor: &mut FunctionEditor,
    fork: NodeID,
    dim_idx1: usize,
    dim_idx2: usize,
) -> NodeID {
    assert_ne!(dim_idx1, dim_idx2);

    // Outer is smaller, and also closer to the left of the factors array.
    let (outer_idx, inner_idx) = if dim_idx2 < dim_idx1 {
        (dim_idx2, dim_idx1)
    } else {
        (dim_idx1, dim_idx2)
    };
    let Node::Fork {
        control: old_control,
        factors: ref old_factors,
    } = *editor.node(fork)
    else {
        return fork;
    };
    assert!(inner_idx < old_factors.len());
    let mut new_factors: Vec<_> = old_factors.to_vec();
    let fork_users: Vec<_> = editor
        .get_users(fork)
        .map(|f| (f, editor.node(f).clone()))
        .collect();
    let outer_dc_id = new_factors[outer_idx];
    let inner_dc_id = new_factors[inner_idx];
    // Stays `fork` unless the edit closure runs to completion.
    let mut new_fork = fork;

    let success = editor.edit(|mut edit| {
        new_factors[outer_idx] =
            edit.add_dynamic_constant(DynamicConstant::mul(outer_dc_id, inner_dc_id));
        new_factors.remove(inner_idx);
        let merged_fork = edit.add_node(Node::Fork {
            control: old_control,
            factors: new_factors.into(),
        });
        edit.sub_edit(fork, merged_fork);
        edit = edit.replace_all_uses(fork, merged_fork)?;
        edit = edit.delete_node(fork)?;

        for (tid, node) in fork_users {
            let Node::ThreadID {
                control: _,
                dimension: tid_dim,
            } = node
            else {
                continue;
            };
            if tid_dim > inner_idx {
                // Dimensions after the removed inner one move left by one.
                let new_tid = edit.add_node(Node::ThreadID {
                    control: merged_fork,
                    dimension: tid_dim - 1,
                });
                edit = edit.replace_all_uses(tid, new_tid)?;
                edit.sub_edit(tid, new_tid);
                edit = edit.delete_node(tid)?;
            } else if tid_dim == outer_idx || tid_dim == inner_idx {
                // Recover the original index from the merged one: the outer
                // index is merged / I and the inner index is merged % I.
                let merged_tid = edit.add_node(Node::ThreadID {
                    control: merged_fork,
                    dimension: outer_idx,
                });
                let inner_dc = edit.add_node(Node::DynamicConstant { id: inner_dc_id });
                let op = if tid_dim == outer_idx {
                    BinaryOperator::Div
                } else {
                    BinaryOperator::Rem
                };
                let recovered = edit.add_node(Node::Binary {
                    left: merged_tid,
                    right: inner_dc,
                    op,
                });
                edit.sub_edit(tid, recovered);
                edit.sub_edit(tid, merged_tid);
                edit = edit.replace_all_uses(tid, recovered)?;
                edit = edit.delete_node(tid)?;
            }
        }
        new_fork = merged_fork;
        Ok(edit)
    });

    if success {
        new_fork
    } else {
        fork
    }
}

/*
 * Applies `fork_interchange` with the same pair of dimensions to every
 * fork-join whose fork is mutable in the editor. Results are ignored.
 */
pub fn fork_interchange_all_forks(
    editor: &mut FunctionEditor,
    fork_join_map: &HashMap<NodeID, NodeID>,
    first_dim: usize,
    second_dim: usize,
) {
    for (&fork, &join) in fork_join_map.iter() {
        if !editor.is_mutable(fork) {
            continue;
        }
        fork_interchange(editor, fork, join, first_dim, second_dim);
    }
}

/*
 * Swaps dimensions `first_dim` and `second_dim` of `fork` (whose join is
 * `join`) and re-creates the thread IDs of those two dimensions with the
 * swapped index. Because this changes iteration order, it is only performed
 * when every reduce on the join is scheduled ParallelReduce or MonoidReduce.
 * A ParallelFork schedule on the old fork is carried over to the new one.
 *
 * Returns the new fork's ID if the interchange was applied, or None if it was
 * refused or the edit failed. Panics if `fork` is not a fork node or either
 * dimension is out of range.
 */
pub fn fork_interchange(
    editor: &mut FunctionEditor,
    fork: NodeID,
    join: NodeID,
    first_dim: usize,
    second_dim: usize,
) -> Option<NodeID> {
    // Check that every reduce on the join is parallel or associative.
    let nodes = &editor.func().nodes;
    let schedules = &editor.func().schedules;
    let reorderable = editor
        .get_users(join)
        .filter(|id| nodes[id.idx()].is_reduce())
        .all(|id| {
            schedules[id.idx()].contains(&Schedule::ParallelReduce)
                || schedules[id.idx()].contains(&Schedule::MonoidReduce)
        });
    if !reorderable {
        // If not, we can't necessarily do interchange.
        return None;
    }

    let Node::Fork {
        control,
        ref factors,
    } = nodes[fork.idx()]
    else {
        panic!()
    };
    // Thread IDs of the two swapped dimensions get the other dimension's
    // index; every other thread ID is left untouched.
    let fix_tids: Vec<(NodeID, Node)> = editor
        .get_users(fork)
        .filter_map(|id| {
            let (_, dim) = nodes[id.idx()].try_thread_id()?;
            let swapped_dim = if dim == first_dim {
                second_dim
            } else if dim == second_dim {
                first_dim
            } else {
                return None;
            };
            Some((
                id,
                Node::ThreadID {
                    control: fork,
                    dimension: swapped_dim,
                },
            ))
        })
        .collect();
    let mut factors = factors.clone();
    factors.swap(first_dim, second_dim);
    let new_fork = Node::Fork { control, factors };
    let mut new_fork_id = None;
    let success = editor.edit(|mut edit| {
        for (old_id, new_tid) in fix_tids {
            let new_id = edit.add_node(new_tid);
            edit = edit.replace_all_uses(old_id, new_id)?;
            edit = edit.delete_node(old_id)?;
        }
        let new_fork = edit.add_node(new_fork);
        if edit.get_schedule(fork).contains(&Schedule::ParallelFork) {
            edit = edit.add_schedule(new_fork, Schedule::ParallelFork)?;
        }
        edit = edit.replace_all_uses(fork, new_fork)?;
        edit = edit.delete_node(fork)?;

        new_fork_id = Some(new_fork);
        Ok(edit)
    });

    // Never hand back the ID of a node from an edit that was not applied.
    if success {
        new_fork_id
    } else {
        None
    }
}

/*
 * Tries `fork_unroll` on the fork-joins whose fork is mutable, stopping after
 * the first one that is unrolled. Unrolling deletes nodes that
 * `nodes_in_fork_joins` still lists, so the maps are stale afterwards.
 */
pub fn fork_unroll_all_forks(
    editor: &mut FunctionEditor,
    fork_joins: &HashMap<NodeID, NodeID>,
    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
) {
    for (&fork, &join) in fork_joins.iter() {
        if !editor.is_mutable(fork) {
            continue;
        }
        if fork_unroll(editor, fork, join, nodes_in_fork_joins) {
            return;
        }
    }
}

/*
 * Fully unrolls a one-dimensional fork-join whose single factor is a
 * compile-time constant. For each iteration, the fork-join body is copied, the
 * thread ID is replaced with the iteration's constant index, the copy is
 * chained after the previous one in control flow, and each reduce's value is
 * threaded from one copy into the next. Finally, the old fork-join is deleted.
 *
 * Returns whether the edit was applied. Returns false without editing when the
 * fork is not a one-dimensional constant fork with exactly two users, one of
 * which is a thread ID. Panics if `fork` is not a fork node.
 */
pub fn fork_unroll(
    editor: &mut FunctionEditor,
    fork: NodeID,
    join: NodeID,
    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
) -> bool {
    // We can only unroll fork-joins with a compile time known factor list. For
    // simplicity, just unroll fork-joins that have a single dimension.
    let nodes = &editor.func().nodes;
    let Node::Fork {
        control,
        ref factors,
    } = nodes[fork.idx()]
    else {
        panic!()
    };
    if factors.len() != 1 || editor.get_users(fork).count() != 2 {
        return false;
    }
    let DynamicConstant::Constant(cons) = *editor.get_dynamic_constant(factors[0]) else {
        return false;
    };
    // If neither of the two users is a thread ID, bail out rather than panic.
    let Some(tid) = editor
        .get_users(fork)
        .find(|id| nodes[id.idx()].is_thread_id())
    else {
        return false;
    };
    let inits: HashMap<NodeID, NodeID> = editor
        .get_users(join)
        .filter_map(|id| nodes[id.idx()].try_reduce().map(|(_, init, _)| (id, init)))
        .collect();

    editor.edit(|mut edit| {
        // Create a copy of the nodes in the fork join per unrolled iteration,
        // excluding the fork itself, the join itself, the thread IDs of the fork,
        // and the reduces on the join. Keep a running tally of the top control
        // node and the current reduction value.
        let mut top_control = control;
        let mut current_reduces = inits;
        for iter in 0..cons {
            let iter_cons = edit.add_constant(Constant::UnsignedInteger64(iter as u64));
            let iter_tid = edit.add_node(Node::Constant { id: iter_cons });

            // First, add a copy of each node in the fork join unmodified.
            // Record the mapping from old ID to new ID.
            let mut added_ids = HashSet::new();
            let mut old_to_new_ids = HashMap::new();
            let mut new_control = None;
            let mut new_reduces = HashMap::new();
            for node in nodes_in_fork_joins[&fork].iter() {
                if *node == fork {
                    old_to_new_ids.insert(*node, top_control);
                } else if *node == join {
                    new_control = Some(edit.get_node(*node).try_join().unwrap());
                } else if *node == tid {
                    old_to_new_ids.insert(*node, iter_tid);
                } else if let Some(current) = current_reduces.get(node) {
                    old_to_new_ids.insert(*node, *current);
                    new_reduces.insert(*node, edit.get_node(*node).try_reduce().unwrap().2);
                } else {
                    let new_node = edit.add_node(edit.get_node(*node).clone());
                    old_to_new_ids.insert(*node, new_node);
                    added_ids.insert(new_node);
                }
            }

            // Second, replace all the uses in the just added nodes.
            if let Some(new_control) = new_control {
                top_control = old_to_new_ids[&new_control];
            }
            for (reduce, reduct) in new_reduces {
                // A reduct defined outside the fork-join was not copied; it is
                // the same value in every iteration, so it maps to itself.
                let next = old_to_new_ids.get(&reduct).copied().unwrap_or(reduct);
                current_reduces.insert(reduce, next);
            }
            for (old, new) in old_to_new_ids {
                edit = edit.replace_all_uses_where(old, new, |id| added_ids.contains(id))?;
            }
        }

        // Hook up the control and reduce outputs to the rest of the function.
        edit = edit.replace_all_uses(join, top_control)?;
        for (reduce, reduct) in current_reduces {
            edit = edit.replace_all_uses(reduce, reduct)?;
        }

        // Delete the old fork-join.
        for node in nodes_in_fork_joins[&fork].iter() {
            edit = edit.delete_node(*node)?;
        }
        Ok(edit)
    })
}

/*
 * Scans for adjacent fork-joins with identical factors and no dependence from
 * the first into the second, and fuses the first such pair found, pooling
 * their reductions into one fork-join. Only forks that are mutable in the
 * editor are tried.
 */
pub fn fork_fusion_all_forks(
    editor: &mut FunctionEditor,
    fork_join_map: &HashMap<NodeID, NodeID>,
    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
) {
    for (&fork, &join) in fork_join_map.iter() {
        if !editor.is_mutable(fork) {
            continue;
        }
        // Fusion deletes nodes that both maps still refer to, so stop after
        // the first success.
        if fork_fusion(editor, fork, join, fork_join_map, nodes_in_fork_joins) {
            return;
        }
    }
}

/*
 * Attempts to fuse the fork-join (`top_fork`, `top_join`) with the fork-join
 * whose fork is the control successor of `top_join`. Fusion requires equal
 * fork factors and that no value flowing out of a top reduce reaches the
 * bottom fork-join. Returns whether the fusion edit was applied.
 */
fn fork_fusion(
    editor: &mut FunctionEditor,
    top_fork: NodeID,
    top_join: NodeID,
    fork_join_map: &HashMap<NodeID, NodeID>,
    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
) -> bool {
    let nodes = &editor.func().nodes;

    // The control successor of the top join must be a fork with a known join.
    let Some(bottom_fork) = editor
        .get_users(top_join)
        .find(|id| nodes[id.idx()].is_control())
    else {
        return false;
    };
    let Some(&bottom_join) = fork_join_map.get(&bottom_fork) else {
        return false;
    };
    let (_, top_factors) = nodes[top_fork.idx()].try_fork().unwrap();
    let (bottom_fork_pred, bottom_factors) = nodes[bottom_fork.idx()].try_fork().unwrap();
    assert_eq!(bottom_fork_pred, top_join);
    let top_join_pred = nodes[top_join.idx()].try_join().unwrap();
    let bottom_join_pred = nodes[bottom_join.idx()].try_join().unwrap();

    if top_factors != bottom_factors {
        return false;
    }

    let in_top = |id: &NodeID| nodes_in_fork_joins[&top_fork].contains(id);
    let in_bottom = |id: &NodeID| nodes_in_fork_joins[&bottom_fork].contains(id);

    // Walk forward from each top reduce through nodes outside the top
    // fork-join. Reaching the bottom fork-join means a dependence, so fusion
    // is illegal. The walk does not continue past a phi or reduce.
    let top_reduces = editor
        .get_users(top_join)
        .filter(|id| nodes[id.idx()].is_reduce());
    for reduce in top_reduces {
        let mut seen = HashSet::from([reduce]);
        let mut frontier = vec![reduce];
        while let Some(current) = frontier.pop() {
            for user in editor.get_users(current) {
                if in_bottom(&user) {
                    return false;
                }
                let is_boundary = nodes[user.idx()].is_phi() || nodes[user.idx()].is_reduce();
                if !is_boundary && !seen.contains(&user) && !in_top(&user) {
                    seen.insert(user);
                    frontier.push(user);
                }
            }
        }
    }

    let bottom_tids: Vec<_> = editor
        .get_users(bottom_fork)
        .filter(|id| nodes[id.idx()].is_thread_id())
        .collect();
    editor.edit(|mut edit| {
        // Re-home the bottom thread IDs onto the top fork.
        edit = edit.replace_all_uses_where(bottom_fork, top_fork, |id| bottom_tids.contains(id))?;
        if bottom_join_pred != bottom_fork {
            // Splice the bottom body's control flow between the end of the
            // top body and the top join.
            edit = edit.replace_all_uses_where(bottom_fork, top_join_pred, |id| in_bottom(id))?;
            edit =
                edit.replace_all_uses_where(top_join_pred, bottom_join_pred, |id| *id == top_join)?;
        }
        // The top fork and join stand in for the bottom ones from now on.
        edit = edit.replace_all_uses(bottom_fork, top_fork)?;
        edit = edit.replace_all_uses(bottom_join, top_join)?;
        edit = edit.delete_node(bottom_fork)?;
        edit = edit.delete_node(bottom_join)?;
        Ok(edit)
    })
}

/*
 * Looks for monoid reductions whose initial input is not the identity element
 * of the combining operation, and rewrites them so that the reduce starts from
 * the identity and the original initial value is combined in once afterwards
 * (`init op reduce`). This aids in parallelizing outer loops. Only reduces with
 * the MonoidReduce schedule are considered, since that schedule implies a
 * particular structure that is annoying to check for again.
 *
 * Supported combining operations and the identity used:
 *   Add / Or  -> zero
 *   Mul / And -> one
 *   Max       -> smallest value of the type
 *   Min       -> largest value of the type
 */
pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
    // Which identity constant a reduction's initial value gets replaced with.
    enum Identity {
        Zero,
        One,
        Smallest,
        Largest,
    }

    for id in editor.node_ids() {
        if !editor.func().schedules[id.idx()].contains(&Schedule::MonoidReduce) {
            continue;
        }
        let nodes = &editor.func().nodes;
        let Some((_, init, reduct)) = nodes[id.idx()].try_reduce() else {
            continue;
        };
        let out_uses: Vec<_> = editor.get_users(id).filter(|id| *id != reduct).collect();

        // Pick the identity and build the node that folds the original `init`
        // back into the reduction's result. Reduces whose init already is the
        // identity are skipped.
        let (identity_kind, final_node) = match nodes[reduct.idx()] {
            Node::Binary {
                op,
                left: _,
                right: _,
            } if (op == BinaryOperator::Add || op == BinaryOperator::Or)
                && !is_zero(editor, init) =>
            {
                (
                    Identity::Zero,
                    Node::Binary {
                        op,
                        left: init,
                        right: id,
                    },
                )
            }
            Node::Binary {
                op,
                left: _,
                right: _,
            } if (op == BinaryOperator::Mul || op == BinaryOperator::And)
                && !is_one(editor, init) =>
            {
                // NOTE(review): `add_one_constant` presumably yields the value 1.
                // That is the identity of logical And on booleans, but not of
                // bitwise And on integers (whose identity is all-ones). Confirm
                // that MonoidReduce is only assigned to And reductions over
                // booleans.
                (
                    Identity::One,
                    Node::Binary {
                        op,
                        left: init,
                        right: id,
                    },
                )
            }
            Node::IntrinsicCall {
                intrinsic: Intrinsic::Max,
                args: _,
            } if !is_smallest(editor, init) => (
                Identity::Smallest,
                Node::IntrinsicCall {
                    intrinsic: Intrinsic::Max,
                    args: Box::new([init, id]),
                },
            ),
            Node::IntrinsicCall {
                intrinsic: Intrinsic::Min,
                args: _,
            } if !is_largest(editor, init) => (
                Identity::Largest,
                Node::IntrinsicCall {
                    intrinsic: Intrinsic::Min,
                    args: Box::new([init, id]),
                },
            ),
            _ => continue,
        };

        editor.edit(|mut edit| {
            let ty = typing[init.idx()];
            let identity = match identity_kind {
                Identity::Zero => edit.add_zero_constant(ty),
                Identity::One => edit.add_one_constant(ty),
                Identity::Smallest => edit.add_smallest_constant(ty),
                Identity::Largest => edit.add_largest_constant(ty),
            };
            let identity = edit.add_node(Node::Constant { id: identity });
            edit.sub_edit(id, identity);
            // Only the reduce's init input switches to the identity.
            edit = edit.replace_all_uses_where(init, identity, |u| *u == id)?;
            let final_op = edit.add_node(final_node);
            for u in out_uses {
                edit.sub_edit(u, final_op);
            }
            // Users outside the reduction cycle now see `init op reduce`; the
            // reduct and the final op itself keep using the raw reduce.
            edit.replace_all_uses_where(id, final_op, |u| *u != reduct && *u != final_op)
        });
    }
}

/*
 * Rounds the dimensions of every mutable fork-join up to a multiple of
 * `multiple`, guarding the body so that the padding iterations do nothing.
 */
pub fn extend_all_forks(
    editor: &mut FunctionEditor,
    fork_join_map: &HashMap<NodeID, NodeID>,
    multiple: usize,
) {
    for (&fork, &join) in fork_join_map.iter() {
        if !editor.is_mutable(fork) {
            continue;
        }
        extend_fork(editor, fork, join, multiple);
    }
}

/*
 * Rounds every factor of the fork-join (`fork`, `join`) up to a multiple of
 * `multiple`, then guards the body with an If so that only thread IDs below
 * the original bounds execute it. On the skipped path, each reduce's value is
 * passed through a phi unchanged.
 */
fn extend_fork(editor: &mut FunctionEditor, fork: NodeID, join: NodeID, multiple: usize) {
    let nodes = &editor.func().nodes;
    let (fork_pred, old_factors) = nodes[fork.idx()].try_fork().unwrap();
    let old_factors = old_factors.to_vec();
    let fork_succ = editor
        .get_users(fork)
        .find(|id| nodes[id.idx()].is_control())
        .unwrap();
    let join_pred = nodes[join.idx()].try_join().unwrap();
    // Whether there is any control flow strictly between the fork and join.
    let has_inner_control = join_pred != fork;
    let reduces: Vec<_> = editor
        .get_users(join)
        .filter_map(|id| nodes[id.idx()].try_reduce().map(|parts| (id, parts)))
        .collect();

    editor.edit(|mut edit| {
        // Round A up to a multiple of B with ((A + B - 1) / B) * B.
        let mut rounded_factors = Vec::with_capacity(old_factors.len());
        for &a in old_factors.iter() {
            let b = edit.add_dynamic_constant(DynamicConstant::Constant(multiple));
            let a_plus_b = edit.add_dynamic_constant(DynamicConstant::add(a, b));
            let one = edit.add_dynamic_constant(DynamicConstant::Constant(1));
            let numerator = edit.add_dynamic_constant(DynamicConstant::sub(a_plus_b, one));
            let tiles = edit.add_dynamic_constant(DynamicConstant::div(numerator, b));
            rounded_factors.push(edit.add_dynamic_constant(DynamicConstant::mul(tiles, b)));
        }

        // New fork over the rounded bounds; everything except the old body
        // entry now hangs off it.
        let new_fork = edit.add_node(Node::Fork {
            control: fork_pred,
            factors: rounded_factors.into_boxed_slice(),
        });
        edit = edit.replace_all_uses_where(fork, new_fork, |id| *id != fork_succ)?;
        edit.sub_edit(fork, new_fork);

        // Guard condition: every thread ID is below its original bound.
        let mut in_bounds = Vec::with_capacity(old_factors.len());
        for (dim, &bound_dc) in old_factors.iter().enumerate() {
            let tid = edit.add_node(Node::ThreadID {
                control: new_fork,
                dimension: dim,
            });
            let bound = edit.add_node(Node::DynamicConstant { id: bound_dc });
            in_bounds.push(edit.add_node(Node::Binary {
                op: BinaryOperator::LT,
                left: tid,
                right: bound,
            }));
        }
        let mut in_bounds = in_bounds.into_iter();
        let mut cond = in_bounds.next().unwrap();
        for next in in_bounds {
            cond = edit.add_node(Node::Binary {
                op: BinaryOperator::And,
                left: cond,
                right: next,
            });
        }

        let guard = edit.add_node(Node::If {
            control: new_fork,
            cond,
        });
        let skip_proj = edit.add_node(Node::ControlProjection {
            control: guard,
            selection: 0,
        });
        let body_proj = edit.add_node(Node::ControlProjection {
            control: guard,
            selection: 1,
        });
        if has_inner_control {
            edit = edit.replace_all_uses_where(fork, body_proj, |id| *id == fork_succ)?;
        }
        let body_exit = if has_inner_control { join_pred } else { body_proj };
        let merge_region = edit.add_node(Node::Region {
            preds: Box::new([skip_proj, body_exit]),
        });
        let new_join = edit.add_node(Node::Join {
            control: merge_region,
        });
        edit = edit.replace_all_uses(join, new_join)?;
        edit.sub_edit(join, new_join);
        edit = edit.delete_node(fork)?;
        edit = edit.delete_node(join)?;

        // Each reduce takes its old value on the skip path and the body's
        // reduct on the body path.
        for (reduce, (_, init, reduct)) in reduces {
            let gated = edit.add_node(Node::Phi {
                control: merge_region,
                data: Box::new([reduce, reduct]),
            });
            let new_reduce = edit.add_node(Node::Reduce {
                control: new_join,
                init,
                reduct: gated,
            });
            edit = edit.replace_all_uses(reduce, new_reduce)?;
            edit.sub_edit(reduce, new_reduce);
            edit = edit.delete_node(reduce)?;
        }

        Ok(edit)
    });
}