Skip to content
Snippets Groups Projects
editor.rs 29.70 KiB
extern crate bitvec;
extern crate either;
extern crate hercules_ir;
extern crate itertools;

use std::cell::{Ref, RefCell};
use std::collections::{BTreeMap, HashMap, HashSet, VecDeque};
use std::iter::FromIterator;
use std::mem::take;
use std::ops::Deref;

use self::bitvec::prelude::*;
use self::either::Either;
use self::itertools::Itertools;

use self::hercules_ir::antideps::*;
use self::hercules_ir::dataflow::*;
use self::hercules_ir::def_use::*;
use self::hercules_ir::dom::*;
use self::hercules_ir::gcm::*;
use self::hercules_ir::ir::*;
use self::hercules_ir::loops::*;
use self::hercules_ir::schedule::*;
use self::hercules_ir::subgraph::*;

// One recorded edit: the first set holds the node IDs deleted by the edit, and
// the second set holds the node IDs added by the edit.
pub type Edit = (HashSet<NodeID>, HashSet<NodeID>);

/*
 * Helper object for editing Hercules functions in a trackable manner. Edits are
 * recorded in order to repair partitions and debug info.
 * Edits must be made atomically, that is, only one `.edit` may be called at a time
 * across all editors.
 *
 * Constants, dynamic constants, and types are held through shared `RefCell`s;
 * committing an edit appends to them with `borrow_mut`, so no other borrow of
 * those tables may be alive while `.edit` commits.
 */
#[derive(Debug)]
pub struct FunctionEditor<'a> {
    // Wraps a mutable reference to a function. Doesn't provide access to this
    // reference directly, so that we can monitor edits.
    function: &'a mut Function,
    // Keep a RefCell to (dynamic) constants and types to allow function changes
    // to update these.
    constants: &'a RefCell<Vec<Constant>>,
    dynamic_constants: &'a RefCell<Vec<DynamicConstant>>,
    types: &'a RefCell<Vec<Type>>,
    // Most optimizations need def use info, so provide an iteratively updated
    // mutable version that's automatically updated based on recorded edits.
    // Indexed by node ID: entry i is the set of nodes that use node i.
    mut_def_use: Vec<HashSet<NodeID>>,
    // Record edits as a mapping from sets of node IDs to sets of node IDs. The
    // sets on the "left" side of this map should be mutually disjoint, and the
    // sets on the "right" side should also be mutually disjoint. All of the
    // node IDs on the left side should be deleted node IDs or IDs of unmodified
    // nodes, and all of the node IDs on the right side should be added node IDs
    // or IDs of unmodified nodes. In other words, there should be no added node
    // IDs on the left side, and no deleted node IDs on the right side. These
    // mappings are stored sequentially in a list, rather than as a map. This is
    // because a transformation may iteratively update a function - i.e., a node
    // ID added in iteration N may be deleted in iteration N + M. To maintain a
    // more precise history of edits, we store each edit individually, which
    // allows us to make more precise repairs of partitions and debug info.
    edits: Vec<Edit>,
    // The pass manager may indicate that only a certain subset of nodes should
    // be modified in a function - what this actually means is that some nodes
    // are off limits for deletion (equivalently modification) or being replaced
    // as a use. One bit per node, indexed by node ID; a set bit means mutable.
    // Every bit starts set, and nodes added by an edit are marked mutable.
    mutable_nodes: BitVec<u8, Lsb0>,
}

/*
 * Helper objects to make a single edit. A `FunctionEdit` is handed to the
 * closure passed to `FunctionEditor::edit`. It stages changes without touching
 * the underlying function; the staged changes are only committed if the
 * closure returns `Ok`.
 */
#[derive(Debug)]
pub struct FunctionEdit<'a: 'b, 'b> {
    // Reference the active function editor.
    editor: &'b mut FunctionEditor<'a>,
    // Keep track of deleted node IDs.
    deleted_nodeids: HashSet<NodeID>,
    // Keep track of added node IDs.
    added_nodeids: HashSet<NodeID>,
    // Keep track of added and use updated nodes, keyed by node ID. Kept ordered
    // because commit appends added nodes in increasing ID order.
    added_and_updated_nodes: BTreeMap<NodeID, Node>,
    // Keep track of added (dynamic) constants and types. Their IDs continue
    // after the entries already present in the editor's shared tables.
    added_constants: Vec<Constant>,
    added_dynamic_constants: Vec<DynamicConstant>,
    added_types: Vec<Type>,
    // Compute a def-use map entries iteratively: node ID -> its users. Kept
    // ordered because commit appends entries for added nodes in ID order.
    updated_def_use: BTreeMap<NodeID, HashSet<NodeID>>,
    // Replacement return type, if `set_return_type` was called.
    updated_return_type: Option<TypeID>,
}

impl<'a: 'b, 'b> FunctionEditor<'a> {
    /*
     * Create an editor over `function`. The editor's mutable def-use map is
     * seeded from `def_use` (one users set per node), and every existing node
     * starts out mutable.
     */
    pub fn new(
        function: &'a mut Function,
        constants: &'a RefCell<Vec<Constant>>,
        dynamic_constants: &'a RefCell<Vec<DynamicConstant>>,
        types: &'a RefCell<Vec<Type>>,
        def_use: &ImmutableDefUseMap,
    ) -> Self {
        let mut_def_use = (0..function.nodes.len())
            .map(|idx| {
                def_use
                    .get_users(NodeID::new(idx))
                    .into_iter()
                    .copied()
                    .collect()
            })
            .collect();
        let mutable_nodes = bitvec![u8, Lsb0; 1; function.nodes.len()];

        FunctionEditor {
            function,
            constants,
            dynamic_constants,
            types,
            mut_def_use,
            edits: vec![],
            mutable_nodes,
        }
    }

    /*
     * Run `edit` against a fresh `FunctionEdit`.
     *
     * If the closure returns `Ok`, the staged changes are committed to the
     * function, the mutable def-use map, the shared constant/type tables, and
     * the edit history, and `true` is returned. If it returns `Err` (e.g.
     * because it tried to delete or replace an immutable node), nothing is
     * committed and `false` is returned.
     *
     * Panics if the staged changes are inconsistent: a deleted node that still
     * has users, or added nodes / def-use entries whose IDs are not packed
     * right after the existing nodes. Also panics if one of the shared
     * `RefCell` tables is already borrowed while committing.
     */
    pub fn edit<F>(&'b mut self, edit: F) -> bool
    where
        F: FnOnce(FunctionEdit<'a, 'b>) -> Result<FunctionEdit<'a, 'b>, FunctionEdit<'a, 'b>>,
    {
        // Create the edit helper struct and perform the edit using it.
        let edit_obj = FunctionEdit {
            editor: self,
            deleted_nodeids: HashSet::new(),
            added_nodeids: HashSet::new(),
            added_constants: Vec::new(),
            added_dynamic_constants: Vec::new(),
            added_types: Vec::new(),
            added_and_updated_nodes: BTreeMap::new(),
            updated_def_use: BTreeMap::new(),
            updated_return_type: None,
        };

        // If the populated edit is returned, then the edit can be performed
        // without modifying immutable nodes. Otherwise drop everything staged.
        let populated_edit = match edit(edit_obj) {
            Ok(populated_edit) => populated_edit,
            Err(_) => return false,
        };

        let FunctionEdit {
            editor,
            deleted_nodeids,
            added_nodeids,
            added_constants,
            added_dynamic_constants,
            added_types,
            added_and_updated_nodes: added_and_updated,
            updated_def_use,
            updated_return_type,
        } = populated_edit;

        // Step 1: update the mutable def use map.
        for (u, new_users) in updated_def_use {
            // Go through new def-use entries in order. These are either
            // updates to existing nodes, in which case we just modify them
            // in place, or they are user entries for new nodes, in which
            // case we push them.
            if u.idx() < editor.mut_def_use.len() {
                editor.mut_def_use[u.idx()] = new_users;
            } else {
                // The new nodes must be traversed in order in a packed
                // fashion - if a new node was created without an
                // accompanying new users entry, something is wrong!
                assert_eq!(editor.mut_def_use.len(), u.idx());
                editor.mut_def_use.push(new_users);
            }
        }

        // Step 2: add and update nodes.
        for (id, node) in added_and_updated {
            if id.idx() < editor.function.nodes.len() {
                editor.function.nodes[id.idx()] = node;
            } else {
                // New nodes should've been assigned increasing IDs starting
                // at the previous number of nodes, so check that.
                assert_eq!(editor.function.nodes.len(), id.idx());
                editor.function.nodes.push(node);
            }
        }

        // Step 3: delete nodes. This is done using "gravestones", where a
        // node other than node ID 0 being a start node is considered a
        // gravestone. This runs after step 2 so that a node which was both
        // updated and deleted in this edit ends up as a gravestone.
        for id in deleted_nodeids.iter() {
            // Check that there are no users of deleted nodes.
            assert!(editor.mut_def_use[id.idx()].is_empty());
            editor.function.nodes[id.idx()] = Node::Start;
        }

        // Step 4: add a single edit to the edit list.
        editor.edits.push((deleted_nodeids, added_nodeids));

        // Step 5: update the length of mutable_nodes. All added nodes are
        // mutable.
        editor
            .mutable_nodes
            .resize(editor.function.nodes.len(), true);

        // Step 6: append added (dynamic) constants and types to the shared
        // tables. Their IDs were assigned assuming they follow the existing
        // entries in order.
        editor.constants.borrow_mut().extend(added_constants);
        editor
            .dynamic_constants
            .borrow_mut()
            .extend(added_dynamic_constants);
        editor.types.borrow_mut().extend(added_types);

        // Step 7: update return type if necessary
        if let Some(return_type) = updated_return_type {
            editor.function.return_type = return_type;
        }
        true
    }

    // Read-only access to the function being edited.
    pub fn func(&self) -> &Function {
        &self.function
    }

    // Users of node `id` according to the editor's current def-use map.
    pub fn get_users(&self, id: NodeID) -> impl ExactSizeIterator<Item = NodeID> + '_ {
        self.mut_def_use[id.idx()].iter().copied()
    }

    // Look up a type in the shared type table. Panics if `id` is out of range.
    pub fn get_type(&self, id: TypeID) -> Ref<'_, Type> {
        Ref::map(self.types.borrow(), |types| &types[id.idx()])
    }

    // Look up a constant in the shared constant table. Panics if `id` is out
    // of range.
    pub fn get_constant(&self, id: ConstantID) -> Ref<'_, Constant> {
        Ref::map(self.constants.borrow(), |constants| &constants[id.idx()])
    }

    // Look up a dynamic constant in the shared dynamic constant table. Panics
    // if `id` is out of range.
    pub fn get_dynamic_constant(&self, id: DynamicConstantID) -> Ref<'_, DynamicConstant> {
        Ref::map(self.dynamic_constants.borrow(), |dynamic_constants| {
            &dynamic_constants[id.idx()]
        })
    }

    // Whether node `id` may be deleted or have its uses replaced.
    pub fn is_mutable(&self, id: NodeID) -> bool {
        self.mutable_nodes[id.idx()]
    }

    // Consume the editor, returning the recorded edits in commit order.
    pub fn edits(self) -> Vec<Edit> {
        self.edits
    }
}

impl<'a, 'b> FunctionEdit<'a, 'b> {
    /*
     * Ensure `updated_def_use` has an entry for `id`. A fresh entry starts as a
     * copy of the editor's current users of `id`, or as an empty set when `id`
     * lies beyond the editor's def-use map.
     */
    fn ensure_updated_def_use_entry(&mut self, id: NodeID) {
        if !self.updated_def_use.contains_key(&id) {
            let old_entry = self
                .editor
                .mut_def_use
                .get(id.idx())
                .cloned()
                .unwrap_or_default();
            self.updated_def_use.insert(id, old_entry);
        }
    }

    /*
     * IDs past the end of the editor's mutability bitvector belong to nodes
     * added by this edit and are always mutable; others consult the bitvector.
     */
    fn is_mutable(&self, id: NodeID) -> bool {
        id.idx() >= self.editor.mutable_nodes.len() || self.editor.mutable_nodes[id.idx()]
    }

    /*
     * Stage a new node and return its ID. IDs are handed out sequentially
     * after the function's existing nodes. The new node is registered as a
     * user of each of its uses. Uses are not checked for mutability or for
     * having been deleted in this edit.
     */
    pub fn add_node(&mut self, node: Node) -> NodeID {
        let id = NodeID::new(self.editor.function.nodes.len() + self.added_nodeids.len());
        // Added nodes need to have an entry in the def-use map.
        self.updated_def_use.insert(id, HashSet::new());
        // Added nodes use other nodes, and we need to update their def-use
        // entries.
        for u in get_uses(&node).as_ref() {
            self.ensure_updated_def_use_entry(*u);
            self.updated_def_use.get_mut(u).unwrap().insert(id);
        }
        // Add the node.
        self.added_and_updated_nodes.insert(id, node);
        self.added_nodeids.insert(id);
        id
    }

    /*
     * Stage the deletion of node `id`. Returns `Err(self)` if the node is
     * immutable, which signals that the whole edit should be aborted.
     *
     * Panics if `id` was added in this same edit. The node must have no users
     * by the time the edit is committed (checked on commit).
     */
    pub fn delete_node(mut self, id: NodeID) -> Result<Self, Self> {
        if !self.is_mutable(id) {
            return Err(self);
        }
        assert!(
            !self.added_nodeids.contains(&id),
            "PANIC: Please don't delete a node that was added in the same edit."
        );
        // Deleted nodes use other nodes, and we need to update their def-use
        // entries. Use the node as it currently stands in this edit: if
        // `replace_all_uses` already rewrote its uses, the def-use entries were
        // moved to the new use targets, so the original node's uses would leave
        // this (deleted) node registered as a user of the replacement.
        let node = match self.added_and_updated_nodes.get(&id) {
            Some(node) => node,
            None => &self.editor.function.nodes[id.idx()],
        };
        let uses: Box<[NodeID]> = get_uses(node).as_ref().into();
        for u in uses.iter().copied() {
            self.ensure_updated_def_use_entry(u);
            self.updated_def_use.get_mut(&u).unwrap().remove(&id);
        }
        self.deleted_nodeids.insert(id);
        Ok(self)
    }

    /*
     * Rewrite every user of `old` to use `new` instead, and move `old`'s users
     * over to `new` in the def-use map. Returns `Err(self)` if `old` is
     * immutable, which signals that the whole edit should be aborted.
     *
     * Note: if `new` itself uses `old` (e.g. it was just added with `old` as an
     * input), `new` is also one of `old`'s users and is rewritten to use
     * itself.
     */
    pub fn replace_all_uses(mut self, old: NodeID, new: NodeID) -> Result<Self, Self> {
        if !self.is_mutable(old) {
            return Err(self);
        }
        // Update all of the users of the old node.
        self.ensure_updated_def_use_entry(old);
        for user_id in self.updated_def_use[&old].iter() {
            // Replace uses of old with new.
            let mut updated_user = self.get_node(*user_id).clone();
            for u in get_uses_mut(&mut updated_user).as_mut() {
                if **u == old {
                    **u = new;
                }
            }
            // Add the updated user to added_and_updated.
            self.added_and_updated_nodes.insert(*user_id, updated_user);
        }

        // All of the users of the old node become users of the new node, so
        // move all of the entries in the def-use from the old to the new.
        let old_entries = take(self.updated_def_use.get_mut(&old).unwrap());
        self.ensure_updated_def_use_entry(new);
        self.updated_def_use
            .get_mut(&new)
            .unwrap()
            .extend(old_entries);

        Ok(self)
    }

    /*
     * The current state of node `id` within this edit: the staged version if
     * it was added or had its uses updated, otherwise the original node.
     * Panics if `id` was deleted in this edit.
     */
    pub fn get_node(&self, id: NodeID) -> &Node {
        assert!(!self.deleted_nodeids.contains(&id));
        if let Some(node) = self.added_and_updated_nodes.get(&id) {
            // Refer to added or updated node. This node is guaranteed to be
            // updated with uses after replace_all_uses is called.
            node
        } else {
            // Refer to the original node.
            &self.editor.function.nodes[id.idx()]
        }
    }

    /*
     * The current users of node `id` within this edit. Panics if `id` was
     * deleted in this edit.
     */
    pub fn get_users(&self, id: NodeID) -> impl Iterator<Item = NodeID> + '_ {
        assert!(!self.deleted_nodeids.contains(&id));
        if let Some(users) = self.updated_def_use.get(&id) {
            // Refer to the updated users set.
            Either::Left(users.iter().copied())
        } else {
            // Refer to the original users set.
            Either::Right(self.editor.mut_def_use[id.idx()].iter().copied())
        }
    }

    /*
     * Intern `ty`: if an equal type already exists (in the shared table or
     * among types staged by this edit) its ID is returned; otherwise `ty` is
     * staged and receives the next ID after all existing and staged types.
     */
    pub fn add_type(&mut self, ty: Type) -> TypeID {
        let pos = self
            .editor
            .types
            .borrow()
            .iter()
            .chain(self.added_types.iter())
            .position(|t| *t == ty);
        if let Some(idx) = pos {
            TypeID::new(idx)
        } else {
            let id = TypeID::new(self.editor.types.borrow().len() + self.added_types.len());
            self.added_types.push(ty);
            id
        }
    }

    /*
     * Look up a type by ID, covering both the shared table and types staged by
     * this edit. Panics if `id` is out of range.
     */
    pub fn get_type(&self, id: TypeID) -> impl Deref<Target = Type> + '_ {
        if id.idx() < self.editor.types.borrow().len() {
            Either::Left(Ref::map(self.editor.types.borrow(), |types| {
                &types[id.idx()]
            }))
        } else {
            Either::Right(
                self.added_types
                    .get(id.idx() - self.editor.types.borrow().len())
                    .unwrap(),
            )
        }
    }

    /*
     * Intern `constant`: reuse the ID of an equal existing or staged constant,
     * otherwise stage it under the next free ID.
     */
    pub fn add_constant(&mut self, constant: Constant) -> ConstantID {
        let pos = self
            .editor
            .constants
            .borrow()
            .iter()
            .chain(self.added_constants.iter())
            .position(|c| *c == constant);
        if let Some(idx) = pos {
            ConstantID::new(idx)
        } else {
            let id =
                ConstantID::new(self.editor.constants.borrow().len() + self.added_constants.len());
            self.added_constants.push(constant);
            id
        }
    }

    /*
     * Look up a constant by ID, covering both the shared table and constants
     * staged by this edit. Panics if `id` is out of range.
     */
    pub fn get_constant(&self, id: ConstantID) -> impl Deref<Target = Constant> + '_ {
        if id.idx() < self.editor.constants.borrow().len() {
            Either::Left(Ref::map(self.editor.constants.borrow(), |constants| {
                &constants[id.idx()]
            }))
        } else {
            Either::Right(
                self.added_constants
                    .get(id.idx() - self.editor.constants.borrow().len())
                    .unwrap(),
            )
        }
    }

    /*
     * Intern `dynamic_constant`: reuse the ID of an equal existing or staged
     * dynamic constant, otherwise stage it under the next free ID.
     */
    pub fn add_dynamic_constant(&mut self, dynamic_constant: DynamicConstant) -> DynamicConstantID {
        let pos = self
            .editor
            .dynamic_constants
            .borrow()
            .iter()
            .chain(self.added_dynamic_constants.iter())
            .position(|c| *c == dynamic_constant);
        if let Some(idx) = pos {
            DynamicConstantID::new(idx)
        } else {
            let id = DynamicConstantID::new(
                self.editor.dynamic_constants.borrow().len() + self.added_dynamic_constants.len(),
            );
            self.added_dynamic_constants.push(dynamic_constant);
            id
        }
    }

    /*
     * Look up a dynamic constant by ID, covering both the shared table and
     * dynamic constants staged by this edit. Panics if `id` is out of range.
     */
    pub fn get_dynamic_constant(
        &self,
        id: DynamicConstantID,
    ) -> impl Deref<Target = DynamicConstant> + '_ {
        if id.idx() < self.editor.dynamic_constants.borrow().len() {
            Either::Left(Ref::map(
                self.editor.dynamic_constants.borrow(),
                |dynamic_constants| &dynamic_constants[id.idx()],
            ))
        } else {
            Either::Right(
                self.added_dynamic_constants
                    .get(id.idx() - self.editor.dynamic_constants.borrow().len())
                    .unwrap(),
            )
        }
    }

    // Stage a new return type for the function; applied on commit.
    pub fn set_return_type(&mut self, ty: TypeID) {
        self.updated_return_type = Some(ty);
    }
}

/*
 * Simplify an edit sequence into a single, larger, edit.
 */
fn collapse_edits(edits: &[Edit]) -> Edit {
    let mut total_edit = Edit::default();

    for edit in edits {
        assert!(edit.0.is_disjoint(&edit.1), "PANIC: Edit sequence is malformed - can't add and delete the same node ID in a single edit.");
        assert!(
            total_edit.0.is_disjoint(&edit.0),
            "PANIC: Edit sequence is malformed - can't delete the same node ID twice."
        );
        assert!(
            total_edit.1.is_disjoint(&edit.1),
            "PANIC: Edit sequence is malformed - can't add the same node ID twice."
        );

        for delete in edit.0.iter() {
            total_edit.0.insert(*delete);
            total_edit.1.remove(delete);
        }

        for addition in edit.1.iter() {
            total_edit.0.remove(addition);
            total_edit.1.insert(*addition);
        }
    }

    total_edit
}

/*
 * Plans can be repaired - this entails repairing schedules as well as
 * partitions. `new_function` is the function after the edits have occurred, but
 * before gravestones have been removed.
 *
 * Schedules of deleted nodes are cleared, added nodes get empty schedule lists,
 * control nodes are assigned partitions by node-specific rules, and data nodes
 * take the partition of the basic block computed for them by `gcm`. Partitions
 * containing calls are placed on the `AsyncRust` device; new partitions
 * otherwise default to `CPU`.
 */
pub fn repair_plan(plan: &mut Plan, new_function: &Function, edits: &[Edit]) {
    // Step 1: collapse all of the edits into a single edit. For repairing
    // partitions, we don't need to consider the intermediate edit states.
    let total_edit = collapse_edits(edits);

    // Step 2: make sure every node ID in `new_function` has a schedule list,
    // then drop schedules for deleted nodes. Growing must happen first: a node
    // added by one edit and deleted by a later one lands in the collapsed
    // deleted set with an ID past the end of the old schedule list. For the
    // same reason the collapsed added set may be empty, or may not contain the
    // last ID, even though the function grew.
    if plan.schedules.len() < new_function.nodes.len() {
        plan.schedules.resize(new_function.nodes.len(), vec![]);
    }
    for deleted in total_edit.0.iter() {
        plan.schedules[deleted.idx()] = vec![];
    }
    if let Some(max_added) = total_edit.1.iter().max() {
        assert!(max_added.idx() < new_function.nodes.len());
    }

    // Step 3: figure out the order to add nodes to partitions. Roughly, we look
    // at the added nodes in reverse postorder and partition by control/data. We
    // first add control nodes to partitions using node-specific rules. We then
    // add data nodes based on the partitions of their immediate control uses
    // and users.
    let def_use = def_use(new_function);
    let rev_po = reverse_postorder(&def_use);
    let added_control_nodes: Vec<NodeID> = rev_po
        .iter()
        .filter(|id| total_edit.1.contains(id) && new_function.nodes[id.idx()].is_control())
        .copied()
        .collect();
    let added_data_nodes: Vec<NodeID> = rev_po
        .iter()
        .filter(|id| total_edit.1.contains(id) && !new_function.nodes[id.idx()].is_control())
        .copied()
        .collect();

    // Step 4: figure out the partitions for added control nodes.
    // Do a bunch of analysis that basically boils down to finding what fork-
    // joins are top-level.
    let control_subgraph = control_subgraph(new_function, &def_use);
    let dom = dominator(&control_subgraph, NodeID::new(0));
    let fork_join_map = fork_join_map(new_function, &control_subgraph);
    let fork_join_nesting = compute_fork_join_nesting(new_function, &dom, &fork_join_map);
    // While building, the new partitions map uses Option since we don't have
    // partitions for new nodes yet, and we need to record that specifically for
    // computing the partitions of region nodes.
    let mut new_partitions: Vec<Option<PartitionID>> = take(&mut plan.partitions)
        .into_iter()
        .map(Some)
        .collect();
    new_partitions.resize(new_function.nodes.len(), None);
    // Iterate the added control nodes using a worklist.
    let mut worklist = VecDeque::from(added_control_nodes);
    while let Some(control_id) = worklist.pop_front() {
        let node = &new_function.nodes[control_id.idx()];
        // There are a few cases where this control node needs to start a new
        // partition:
        // 1. It's a non-gravestone start node. This is any start node visited
        //    by the reverse postorder.
        // 2. It's a return node.
        // 3. It's a top-level fork.
        // 4. One of its control predecessors is a top-level join.
        // 5. It's a region node whose predecessors are not all in the same
        //    partition (only region nodes can have multiple predecessors).
        // 6. It's a region node with a call user.
        // 7. Its predecessor is a region node with a call user.
        let top_level_fork = node.is_fork() && fork_join_nesting[&control_id].len() == 1;
        let top_level_join = control_subgraph.preds(control_id).any(|pred| {
            new_function.nodes[pred.idx()].is_join() && fork_join_nesting[&pred].len() == 1
        });
        // It's not possible for every predecessor to not have been assigned a
        // partition yet because of reverse postorder traversal.
        let multi_pred_region = !control_subgraph
            .preds(control_id)
            .map(|pred| new_partitions[pred.idx()])
            .all_equal();
        let region_with_call_user = |id: NodeID| {
            new_function.nodes[id.idx()].is_region()
                && def_use
                    .get_users(id)
                    .as_ref()
                    .into_iter()
                    .any(|id| new_function.nodes[id.idx()].is_call())
        };
        let call_region = region_with_call_user(control_id);
        let pred_is_call_region = control_subgraph
            .preds(control_id)
            .any(|pred| region_with_call_user(pred));

        if node.is_start()
            || node.is_return()
            || top_level_fork
            || top_level_join
            || multi_pred_region
            || call_region
            || pred_is_call_region
        {
            // This control node goes in a new partition.
            let part_id = PartitionID::new(plan.num_partitions);
            plan.num_partitions += 1;
            new_partitions[control_id.idx()] = Some(part_id);
        } else {
            // This control node goes in the partition of any one of its
            // predecessors. They're all the same by condition 5 above.
            let any_pred = control_subgraph.preds(control_id).next().unwrap();
            if new_partitions[any_pred.idx()].is_some() {
                new_partitions[control_id.idx()] = new_partitions[any_pred.idx()];
            } else {
                worklist.push_back(control_id);
            }
        }
    }

    // Step 5: figure out the partitions for added data nodes. Any node still
    // without a partition after `gcm` takes the partition of the control node
    // `gcm` placed it in.
    let antideps = antideps(&new_function, &def_use);
    let loops = loops(&control_subgraph, NodeID::new(0), &dom, &fork_join_map);
    let bbs = gcm(
        new_function,
        &def_use,
        &rev_po,
        &dom,
        &antideps,
        &loops,
        &fork_join_map,
        &mut new_partitions,
    );
    let added_and_to_repartition_data_nodes: Vec<NodeID> = new_partitions
        .iter()
        .enumerate()
        .filter(|(_, part)| part.is_none())
        .map(|(idx, _)| NodeID::new(idx))
        .collect();
    for data_id in added_and_to_repartition_data_nodes {
        new_partitions[data_id.idx()] = new_partitions[bbs[data_id.idx()].idx()];
    }

    // Step 6: wrap everything up.
    plan.partitions = new_partitions.into_iter().map(|id| id.unwrap()).collect();
    plan.partition_devices
        .resize(plan.num_partitions, Device::CPU);
    // Place call partitions on the "AsyncRust" device.
    for idx in 0..new_function.nodes.len() {
        if new_function.nodes[idx].is_call() {
            plan.partition_devices[plan.partitions[idx].idx()] = Device::AsyncRust;
        }
    }
}

/*
 * Default plans can be constructed by conservatively inferring schedules and
 * creating partitions by "repairing" a partition where the edit is adding every
 * node in the function.
 */
pub fn default_plan(
    function: &Function,
    dynamic_constants: &Vec<DynamicConstant>,
    def_use: &ImmutableDefUseMap,
    reverse_postorder: &Vec<NodeID>,
    fork_join_map: &HashMap<NodeID, NodeID>,
    fork_join_nesting: &HashMap<NodeID, Vec<NodeID>>,
    bbs: &Vec<NodeID>,
) -> Plan {
    // Start by creating a completely bare-bones plan doing nothing interesting.
    let mut plan = Plan {
        schedules: vec![vec![]; function.nodes.len()],
        partitions: vec![],
        partition_devices: vec![],
        num_partitions: 1,
    };

    // Infer a partitioning by using `repair_plan`, where the "edit" is creating
    // the entire function.
    let edit = (
        HashSet::new(),
        HashSet::from_iter((0..function.nodes.len()).map(NodeID::new)),
    );
    repair_plan(&mut plan, function, &[edit]);
    plan.renumber_partitions();

    // Infer schedules.
    infer_parallel_reduce(function, fork_join_map, &mut plan);
    infer_parallel_fork(
        function,
        def_use,
        fork_join_map,
        fork_join_nesting,
        &mut plan,
    );
    infer_vectorizable(function, dynamic_constants, fork_join_map, &mut plan);
    infer_associative(function, &mut plan);

    // TODO: uncomment once GPU backend is implemented.
    // place_fork_partitions_on_gpu(function, &mut plan);

    plan
}

#[cfg(test)]
mod editor_tests {
    #[allow(unused_imports)]
    use super::*;

    use std::mem::replace;

    use self::hercules_ir::parse::parse;

    /*
     * Renumber `function` so node IDs follow the reverse postorder from the
     * start node, dropping every node that traversal does not reach (e.g.
     * gravestones left behind by deletions). Uses are rewritten to the new
     * IDs. Returns the old-ID -> new-ID map (`None` for dropped nodes). This
     * lets two functions be compared node-by-node regardless of ID order.
     */
    fn canonicalize(function: &mut Function) -> Vec<Option<NodeID>> {
        // The reverse postorder traversal from the Start node is a map from new
        // index to old ID.
        let rev_po = reverse_postorder(&def_use(function));
        let num_new_nodes = rev_po.len();

        // Construct a map from old ID to new ID.
        let mut old_to_new = vec![None; function.nodes.len()];
        for (new_idx, old_id) in rev_po.into_iter().enumerate() {
            old_to_new[old_id.idx()] = Some(NodeID::new(new_idx));
        }

        // Move the old nodes before permuting them.
        let mut old_nodes = take(&mut function.nodes);
        function.nodes = vec![Node::Start; num_new_nodes];

        // Permute the old nodes back into the function and fix their uses.
        for (old_idx, new_id) in old_to_new.iter().enumerate() {
            // Check if this old node is in the canonicalized form.
            if let Some(new_id) = new_id {
                // Get the old node.
                let mut node = replace(&mut old_nodes[old_idx], Node::Start);

                // Fix its uses.
                for u in get_uses_mut(&mut node).as_mut() {
                    // Map every use using the old-to-new map. If we try to use
                    // a node that doesn't have a mapping, then the original IR
                    // had a node reachable from the start using another node
                    // not reachable from the start, which is malformed.
                    **u = old_to_new[u.idx()].unwrap();
                }

                // Insert the fixed node into its new spot.
                function.nodes[new_id.idx()] = node;
            }
        }

        old_to_new
    }

    /*
     * Replace an `add` with a `mul` through a `FunctionEditor` (add the new
     * node, redirect users, delete the old node) and check that, after
     * canonicalization, the result matches the directly parsed `mul` version.
     */
    #[test]
    fn example1() {
        // Define the original function.
        let mut src_module = parse(
            "
fn func(x: i32) -> i32
  c = constant(i32, 7)
  y = add(x, c)
  r = return(start, y)
",
        )
        .unwrap();

        // Find the ID of the add node and its uses.
        let func = &mut src_module.functions[0];
        let (add, left, right) = func
            .nodes
            .iter()
            .enumerate()
            .filter_map(|(idx, node)| {
                node.try_binary(BinaryOperator::Add)
                    .map(|(left, right)| (NodeID::new(idx), left, right))
            })
            .next()
            .unwrap();

        // The editor shares these tables through RefCells.
        let constants_ref = RefCell::new(src_module.constants);
        let dynamic_constants_ref = RefCell::new(src_module.dynamic_constants);
        let types_ref = RefCell::new(src_module.types);
        // Edit the function by replacing the add with a multiply.
        let mut editor = FunctionEditor::new(
            func,
            &constants_ref,
            &dynamic_constants_ref,
            &types_ref,
            &def_use(func),
        );
        let success = editor.edit(|mut edit| {
            let mul = edit.add_node(Node::Binary {
                op: BinaryOperator::Mul,
                left,
                right,
            });
            let edit = edit.replace_all_uses(add, mul)?;
            let edit = edit.delete_node(add)?;
            Ok(edit)
        });
        assert!(success);

        // Canonicalize the function. This also drops the gravestone left where
        // the add node used to be.
        canonicalize(func);

        // Check that the function is correct.
        let mut dst_module = parse(
            "
fn func(x: i32) -> i32
  c = constant(i32, 7)
  y = mul(x, c)
  r = return(start, y)
",
        )
        .unwrap();
        canonicalize(&mut dst_module.functions[0]);
        assert_eq!(src_module.functions[0].nodes, dst_module.functions[0].nodes);
    }
}