Skip to content
Snippets Groups Projects
editor.rs 38.15 KiB
use std::borrow::Borrow;
use std::cell::{Ref, RefCell};
use std::collections::{BTreeMap, HashSet};
use std::mem::take;
use std::ops::Deref;

use bitvec::prelude::*;
use either::Either;

use hercules_ir::def_use::*;
use hercules_ir::ir::*;
use hercules_ir::DynamicConstantView;

/// Helper object for editing Hercules functions in a trackable manner.
///
/// Edits must be made atomically, that is, only one `.edit` may be called at a
/// time across all editors, and individual edits must leave the function in a
/// valid state.
#[derive(Debug)]
pub struct FunctionEditor<'a> {
    /// Mutable reference to the function being edited. It is never handed out
    /// mutably, so that every change goes through (and is monitored by) an
    /// edit.
    function: &'a mut Function,
    /// ID of the function being edited.
    function_id: FunctionID,
    /// Shared constant table, held in a `RefCell` so that committed edits can
    /// append the constants they create.
    constants: &'a RefCell<Vec<Constant>>,
    /// Shared dynamic constant table; committed edits append to it.
    dynamic_constants: &'a RefCell<Vec<DynamicConstant>>,
    /// Shared type table; committed edits append to it.
    types: &'a RefCell<Vec<Type>>,
    /// Shared string table of label names, so that new labels can be added as
    /// needed. A `LabelID` is an index into this table.
    labels: &'a RefCell<Vec<String>>,
    /// Def-use map updated from every committed edit: entry `i` is the set of
    /// nodes that use node `i`. Most optimizations need def-use info, so it is
    /// maintained incrementally rather than recomputed.
    mut_def_use: Vec<HashSet<NodeID>>,
    /// Bit `i` is set iff node `i` may be modified. The pass manager may
    /// indicate that only a subset of nodes should be modified; the others are
    /// off limits for deletion (equivalently modification) or for being
    /// replaced as a use. Nodes added by committed edits are marked mutable.
    mutable_nodes: BitVec<u8, Lsb0>,
    /// Tracks whether this editor has been used to make any edits to the IR of
    /// this function.
    modified: bool,
}

/// Helper object that records a single edit. The recorded changes are applied
/// to the function when `FunctionEditor::edit` commits the edit.
#[derive(Debug)]
pub struct FunctionEdit<'a: 'b, 'b> {
    /// The active function editor this edit will be committed to.
    editor: &'b mut FunctionEditor<'a>,
    /// Node IDs deleted by this edit.
    deleted_nodeids: HashSet<NodeID>,
    /// Node IDs added by this edit.
    added_nodeids: HashSet<NodeID>,
    /// New nodes, plus existing nodes whose uses were rewritten, keyed by ID.
    added_and_updated_nodes: BTreeMap<NodeID, Node>,
    /// Replacement schedule lists for new or changed nodes.
    added_and_updated_schedules: BTreeMap<NodeID, Vec<Schedule>>,
    /// Replacement label sets for new or changed nodes.
    added_and_updated_labels: BTreeMap<NodeID, HashSet<LabelID>>,
    // (Dynamic) constants, types, and label names created by this edit. They
    // are appended to the editor's shared tables when the edit is committed.
    added_constants: Vec<Constant>,
    added_dynamic_constants: Vec<DynamicConstant>,
    added_types: Vec<Type>,
    added_labels: Vec<String>,
    /// Updated users sets (def-use entries) for every node touched by this
    /// edit, including an entry for every node added by it.
    updated_def_use: BTreeMap<NodeID, HashSet<NodeID>>,
    /// Replacement parameter types, if this edit changes them.
    updated_param_types: Option<Vec<TypeID>>,
    /// Replacement return types, if this edit changes them.
    updated_return_types: Option<Vec<TypeID>>,
    /// `(src, dst)` pairs recording which deleted and added node IDs directly
    /// correspond; used to propagate schedules and labels on commit.
    sub_edits: Vec<(NodeID, NodeID)>,
}

impl<'a: 'b, 'b> FunctionEditor<'a> {
    pub fn new(
        function: &'a mut Function,
        function_id: FunctionID,
        constants: &'a RefCell<Vec<Constant>>,
        dynamic_constants: &'a RefCell<Vec<DynamicConstant>>,
        types: &'a RefCell<Vec<Type>>,
        labels: &'a RefCell<Vec<String>>,
        def_use: &ImmutableDefUseMap,
    ) -> Self {
        let mut_def_use = (0..function.nodes.len())
            .map(|idx| {
                def_use
                    .get_users(NodeID::new(idx))
                    .into_iter()
                    .map(|x| *x)
                    .collect()
            })
            .collect();
        let mutable_nodes = bitvec![u8, Lsb0; 1; function.nodes.len()];

        FunctionEditor {
            function,
            function_id,
            constants,
            dynamic_constants,
            types,
            labels,
            mut_def_use,
            mutable_nodes,
            modified: false,
        }
    }

    // Constructs an editor with a specified mask determining which nodes are mutable
    pub fn new_mask(
        function: &'a mut Function,
        function_id: FunctionID,
        constants: &'a RefCell<Vec<Constant>>,
        dynamic_constants: &'a RefCell<Vec<DynamicConstant>>,
        types: &'a RefCell<Vec<Type>>,
        labels: &'a RefCell<Vec<String>>,
        def_use: &ImmutableDefUseMap,
        mask: BitVec<u8, Lsb0>,
    ) -> Self {
        let mut_def_use = (0..function.nodes.len())
            .map(|idx| {
                def_use
                    .get_users(NodeID::new(idx))
                    .into_iter()
                    .map(|x| *x)
                    .collect()
            })
            .collect();

        FunctionEditor {
            function,
            function_id,
            constants,
            dynamic_constants,
            types,
            labels,
            mut_def_use,
            mutable_nodes: mask,
            modified: false,
        }
    }

    // Constructs an editor but makes every node immutable.
    pub fn new_immutable(
        function: &'a mut Function,
        function_id: FunctionID,
        constants: &'a RefCell<Vec<Constant>>,
        dynamic_constants: &'a RefCell<Vec<DynamicConstant>>,
        types: &'a RefCell<Vec<Type>>,
        labels: &'a RefCell<Vec<String>>,
        def_use: &ImmutableDefUseMap,
    ) -> Self {
        let mut_def_use = (0..function.nodes.len())
            .map(|idx| {
                def_use
                    .get_users(NodeID::new(idx))
                    .into_iter()
                    .map(|x| *x)
                    .collect()
            })
            .collect();

        let mutable_nodes = bitvec![u8, Lsb0; 0; function.nodes.len()];

        FunctionEditor {
            function,
            function_id,
            constants,
            dynamic_constants,
            types,
            labels,
            mut_def_use,
            mutable_nodes,
            modified: false,
        }
    }

    pub fn modified(&self) -> bool {
        self.modified
    }

    pub fn edit<F>(&'b mut self, edit: F) -> bool
    where
        F: FnOnce(FunctionEdit<'a, 'b>) -> Result<FunctionEdit<'a, 'b>, FunctionEdit<'a, 'b>>,
    {
        // Create the edit helper struct and perform the edit using it.
        let edit_obj = FunctionEdit {
            editor: self,
            deleted_nodeids: HashSet::new(),
            added_nodeids: HashSet::new(),
            added_and_updated_nodes: BTreeMap::new(),
            added_and_updated_schedules: BTreeMap::new(),
            added_and_updated_labels: BTreeMap::new(),
            added_constants: Vec::new().into(),
            added_dynamic_constants: Vec::new().into(),
            added_types: Vec::new().into(),
            added_labels: Vec::new().into(),
            updated_def_use: BTreeMap::new(),
            updated_param_types: None,
            updated_return_types: None,
            sub_edits: vec![],
        };

        if let Ok(populated_edit) = edit(edit_obj) {
            // If the populated edit is returned, then the edit can be performed
            // without modifying immutable nodes.
            let FunctionEdit {
                editor,
                deleted_nodeids,
                added_nodeids,
                added_and_updated_nodes,
                added_and_updated_schedules,
                added_and_updated_labels,
                added_constants,
                added_dynamic_constants,
                added_types,
                added_labels,
                updated_def_use,
                updated_param_types,
                updated_return_types,
                sub_edits,
            } = populated_edit;

            // Step 0: determine whether the edit changed the IR by checking if
            // any nodes were deleted, added, or updated in any way
            editor.modified |= !deleted_nodeids.is_empty()
                || !added_nodeids.is_empty()
                || !added_and_updated_nodes.is_empty()
                || !added_and_updated_schedules.is_empty()
                || !added_and_updated_labels.is_empty();

            // Step 1: update the mutable def use map.
            for (u, new_users) in updated_def_use {
                // Go through new def-use entries in order. These are either
                // updates to existing nodes, in which case we just modify them
                // in place, or they are user entries for new nodes, in which
                // case we push them.
                if u.idx() < editor.mut_def_use.len() {
                    editor.mut_def_use[u.idx()] = new_users;
                } else {
                    // The new nodes must be traversed in order in a packed
                    // fashion - if a new node was created without an
                    // accompanying new users entry, something is wrong!
                    assert_eq!(editor.mut_def_use.len(), u.idx());
                    editor.mut_def_use.push(new_users);
                }
            }

            // Step 2: add and update nodes.
            for (id, node) in added_and_updated_nodes {
                if id.idx() < editor.function.nodes.len() {
                    editor.function.nodes[id.idx()] = node;
                } else {
                    // New nodes should've been assigned increasing IDs starting
                    // at the previous number of nodes, so check that.
                    assert_eq!(editor.function.nodes.len(), id.idx());
                    editor.function.nodes.push(node);
                }
            }

            // Step 3.0: add and update schedules.
            editor
                .function
                .schedules
                .resize(editor.function.nodes.len(), vec![]);
            for (id, schedule) in added_and_updated_schedules {
                editor.function.schedules[id.idx()] = schedule;
            }

            // Step 3.1: add and update labels.
            editor
                .function
                .labels
                .resize(editor.function.nodes.len(), HashSet::new());
            for (id, label) in added_and_updated_labels {
                editor.function.labels[id.idx()] = label;
            }
            // Step 4: delete nodes. This is done using "gravestones", where a
            // node other than node ID 0 being a start node is considered a
            // gravestone.
            for id in deleted_nodeids.iter() {
                // Check that there are no users of deleted nodes.
                assert!(
                    editor.mut_def_use[id.idx()].is_empty(),
                    "PANIC: Attempted to delete node {:?}, but there are still users of this node ({:?}).",
                    id, editor.mut_def_use[id.idx()]
                );
                editor.function.nodes[id.idx()] = Node::Start;
            }

            // Step 5.0: propagate schedules along sub-edit edges.
            for (src, dst) in sub_edits.iter() {
                let mut dst_schedules = take(&mut editor.function.schedules[dst.idx()]);
                for src_schedule in editor.function.schedules[src.idx()].iter() {
                    if !dst_schedules.contains(src_schedule) {
                        dst_schedules.push(src_schedule.clone());
                    }
                }
                editor.function.schedules[dst.idx()] = dst_schedules;
            }

            // Step 5.1: update and propagate labels
            editor.labels.borrow_mut().extend(added_labels);

            // We propagate labels in two steps, first along sub-edits and then
            // all the labels on any deleted node not used in any sub-edit to all
            // added nodes not in any sub-edit
            let mut sources = deleted_nodeids.clone();
            let mut dests = added_nodeids.clone();

            for (src, dst) in sub_edits {
                let mut dst_labels = take(&mut editor.function.labels[dst.idx()]);
                dst_labels.extend(editor.function.labels[src.idx()].iter());
                editor.function.labels[dst.idx()] = dst_labels;

                sources.remove(&src);
                dests.remove(&dst);
            }

            let mut src_labels = HashSet::new();
            for src in sources {
                src_labels.extend(editor.function.labels[src.idx()].clone());
            }
            for dst in dests {
                editor.function.labels[dst.idx()].extend(src_labels.clone());
            }

            // Step 6: update the length of mutable_nodes. All added nodes are
            // mutable.
            editor
                .mutable_nodes
                .resize(editor.function.nodes.len(), true);

            // Step 7: update types and constants.
            let mut editor_constants = editor.constants.borrow_mut();
            let mut editor_dynamic_constants = editor.dynamic_constants.borrow_mut();
            let mut editor_types = editor.types.borrow_mut();

            editor_constants.extend(added_constants);
            editor_dynamic_constants.extend(added_dynamic_constants);
            editor_types.extend(added_types);

            // Step 8: update parameter types if necessary.
            if let Some(param_types) = updated_param_types {
                editor.function.param_types = param_types;
            }
            // Step 9: update return type if necessary.
            if let Some(return_types) = updated_return_types {
                editor.function.return_types = return_types;
            }

            true
        } else {
            false
        }
    }

    pub fn func(&self) -> &Function {
        &self.function
    }

    pub fn func_id(&self) -> FunctionID {
        self.function_id
    }

    pub fn node(&self, node: impl Borrow<NodeID>) -> &Node {
        &self.function.nodes[node.borrow().idx()]
    }

    pub fn get_types(&self) -> Ref<'_, Vec<Type>> {
        self.types.borrow()
    }

    pub fn get_constants(&self) -> Ref<'_, Vec<Constant>> {
        self.constants.borrow()
    }

    pub fn get_dynamic_constants(&self) -> Ref<'_, Vec<DynamicConstant>> {
        self.dynamic_constants.borrow()
    }

    pub fn get_users(&self, id: NodeID) -> impl ExactSizeIterator<Item = NodeID> + '_ {
        self.mut_def_use[id.idx()].iter().map(|x| *x)
    }

    pub fn get_uses(&self, id: NodeID) -> impl ExactSizeIterator<Item = NodeID> + '_ {
        get_uses(&self.function.nodes[id.idx()])
            .as_ref()
            .into_iter()
            .map(|x| *x)
            .collect::<Vec<_>>()
            .into_iter()
    }

    pub fn get_type(&self, id: TypeID) -> Ref<'_, Type> {
        Ref::map(self.types.borrow(), |types| &types[id.idx()])
    }

    pub fn get_constant(&self, id: ConstantID) -> Ref<'_, Constant> {
        Ref::map(self.constants.borrow(), |constants| &constants[id.idx()])
    }

    pub fn get_dynamic_constant(&self, id: DynamicConstantID) -> Ref<'_, DynamicConstant> {
        Ref::map(self.dynamic_constants.borrow(), |dynamic_constants| {
            &dynamic_constants[id.idx()]
        })
    }

    pub fn is_mutable(&self, id: NodeID) -> bool {
        self.mutable_nodes[id.idx()]
    }

    pub fn node_ids(&self) -> impl ExactSizeIterator<Item = NodeID> {
        let num = self.function.nodes.len();
        (0..num).map(NodeID::new)
    }

    pub fn dynamic_constant_ids(&self) -> impl ExactSizeIterator<Item = DynamicConstantID> {
        let num = self.dynamic_constants.borrow().len();
        (0..num).map(DynamicConstantID::new)
    }
}

impl<'a, 'b> FunctionEdit<'a, 'b> {
    /// Guarantees that `updated_def_use` holds an entry for `id`. A missing
    /// entry is seeded with the editor's committed users of `id`, or with an
    /// empty set when the editor has no entry for it (e.g. a node added in this
    /// edit).
    fn ensure_updated_def_use_entry(&mut self, id: NodeID) {
        let committed_users = &self.editor.mut_def_use;
        self.updated_def_use.entry(id).or_insert_with(|| {
            committed_users
                .get(id.idx())
                .cloned()
                .unwrap_or_default()
        });
    }

    /// Whether this edit may modify node `id`. IDs beyond the editor's
    /// mutability mask (such as nodes added during this edit) count as
    /// mutable.
    fn is_mutable(&self, id: NodeID) -> bool {
        let mask = &self.editor.mutable_nodes;
        if id.idx() < mask.len() {
            mask[id.idx()]
        } else {
            true
        }
    }

    /// Number of dynamic constants visible to this edit: the committed ones
    /// plus those queued by this edit.
    pub fn num_dynamic_constants(&self) -> usize {
        let committed = self.editor.dynamic_constants.borrow().len();
        let pending = self.added_dynamic_constants.len();
        committed + pending
    }

    /// Number of node IDs in use: committed node slots plus the nodes added so
    /// far in this edit. This is also the next ID `add_node` will hand out.
    pub fn num_node_ids(&self) -> usize {
        let pending = self.added_nodeids.len();
        self.editor.func().nodes.len() + pending
    }

    /// Adds a new node that is a copy of `node` as it currently stands in this
    /// edit, and returns the copy's ID.
    ///
    /// If `node` was added, or had its uses rewritten, earlier in this edit,
    /// the pending version is copied. Copying the committed version instead
    /// would resurrect uses that `replace_all_uses*` already redirected, and
    /// would index out of bounds for nodes added in this edit.
    pub fn copy_node(&mut self, node: NodeID) -> NodeID {
        // Deliberately not `get_node`: that asserts the node isn't deleted,
        // and the original behavior allowed copying a deleted node.
        let current = match self.added_and_updated_nodes.get(&node) {
            Some(pending) => pending.clone(),
            None => self.editor.function.nodes[node.idx()].clone(),
        };
        self.add_node(current)
    }

    /// Adds `node` to the function and returns its freshly allocated ID.
    ///
    /// IDs are allocated contiguously after the committed nodes and the nodes
    /// already added in this edit. The new node gets an (initially empty)
    /// users entry, and it is registered as a user of every node it uses.
    pub fn add_node(&mut self, node: Node) -> NodeID {
        let fresh = NodeID::new(self.num_node_ids());
        self.updated_def_use.entry(fresh).or_default();
        for &used in get_uses(&node).as_ref() {
            self.ensure_updated_def_use_entry(used);
            self.updated_def_use.get_mut(&used).unwrap().insert(fresh);
        }
        self.added_and_updated_nodes.insert(fresh, node);
        self.added_nodeids.insert(fresh);
        fresh
    }

    /// Deletes node `id` when the edit is committed.
    ///
    /// Returns `Err(self)` if `id` is immutable, which means the whole edit
    /// should be aborted. `id` is removed from the users sets of every node it
    /// currently uses. The commit step asserts that no users of `id` remain.
    ///
    /// # Panics
    ///
    /// Panics if `id` was added in this same edit.
    pub fn delete_node(mut self, id: NodeID) -> Result<Self, Self> {
        // We can only delete mutable nodes. Return Err if we try to modify an
        // immutable node, as it means the whole edit should be aborted.
        if !self.is_mutable(id) {
            return Err(self);
        }
        assert!(
            !self.added_nodeids.contains(&id),
            "PANIC: Please don't delete a node that was added in the same edit."
        );
        // Unregister `id` from the users sets of the nodes it uses *now*. If an
        // earlier `replace_all_uses*` call in this edit rewrote `id`'s uses,
        // the pending version is authoritative. Using the committed node would
        // leave `id` as a stale user of its redirected uses.
        let current = match self.added_and_updated_nodes.get(&id) {
            Some(pending) => pending,
            None => &self.editor.function.nodes[id.idx()],
        };
        let uses: Box<[NodeID]> = get_uses(current).as_ref().into();
        for &u in uses.iter() {
            self.ensure_updated_def_use_entry(u);
            self.updated_def_use.get_mut(&u).unwrap().remove(&id);
        }
        self.deleted_nodeids.insert(id);
        Ok(self)
    }

    /// Replaces every use of `old` with `new`. This is shorthand for
    /// [`FunctionEdit::replace_all_uses_where`] with a predicate that accepts
    /// every user. Returns `Err(self)` if `old` is immutable.
    pub fn replace_all_uses(self, old: NodeID, new: NodeID) -> Result<Self, Self> {
        let every_user = |_: &NodeID| true;
        self.replace_all_uses_where(old, new, every_user)
    }

    /// Replaces uses of `old` with `new` in each current user of `old` for
    /// which `pred(&user)` returns true.
    ///
    /// Each matching user is rewritten, recorded in `added_and_updated_nodes`,
    /// and moved from `old`'s users set to `new`'s. Non-matching users remain
    /// users of `old`. Only the mutability of `old` is checked. If `old` is
    /// immutable, `Err(self)` is returned and the whole edit should be aborted.
    ///
    /// `pred` is evaluated twice per user (once to rewrite, once to partition
    /// the def-use entry), so it should be deterministic.
    pub fn replace_all_uses_where<P>(
        mut self,
        old: NodeID,
        new: NodeID,
        pred: P,
    ) -> Result<Self, Self>
    where
        P: Fn(&NodeID) -> bool,
    {
        // We can only replace uses of mutable nodes. Return Err if we try to
        // replace uses of an immutable node, as it means the whole edit should
        // be aborted.
        if self.is_mutable(old) {
            // Update all of the matching users of the old node.
            self.ensure_updated_def_use_entry(old);
            for user_id in self.updated_def_use[&old].iter() {
                if pred(user_id) {
                    // Replace uses of old with new, starting from the user's
                    // current (possibly already updated) version.
                    let mut updated_user = self.get_node(*user_id).clone();
                    for u in get_uses_mut(&mut updated_user).as_mut() {
                        if **u == old {
                            **u = new;
                        }
                    }
                    // Add the updated user to added_and_updated.
                    self.added_and_updated_nodes.insert(*user_id, updated_user);
                }
            }

            // The users matching `pred` become users of the new node, so move
            // their def-use entries from the old node to the new one; the rest
            // stay with the old node.
            let old_entries = take(self.updated_def_use.get_mut(&old).unwrap());
            let (new_users, old_users): (Vec<_>, Vec<_>) = old_entries.into_iter().partition(pred);
            self.ensure_updated_def_use_entry(new);
            self.updated_def_use
                .get_mut(&new)
                .unwrap()
                .extend(new_users);
            self.updated_def_use
                .get_mut(&old)
                .unwrap()
                .extend(old_users);

            Ok(self)
        } else {
            Err(self)
        }
    }

    /// Records that node `src` directly corresponds to node `dst`. When the
    /// edit is committed, `src`'s schedules and labels are propagated to `dst`.
    ///
    /// # Panics
    ///
    /// Panics if `src` was added in this edit, or if `dst` was deleted in this
    /// edit.
    pub fn sub_edit(&mut self, src: NodeID, dst: NodeID) {
        assert!(
            !self.added_nodeids.contains(&src),
            "PANIC: The source of a sub-edit ({:?}) can't be a node added in the same edit.",
            src
        );
        assert!(
            !self.deleted_nodeids.contains(&dst),
            "PANIC: The destination of a sub-edit ({:?}) can't be a node deleted in the same edit.",
            dst
        );
        self.sub_edits.push((src, dst));
    }

    /// Name of the function being edited.
    pub fn get_name(&self) -> &str {
        let function = self.editor.func();
        &function.name
    }

    /// Number of dynamic constant parameters of the function being edited.
    pub fn get_num_dynamic_constant_params(&self) -> u32 {
        let function = self.editor.func();
        function.num_dynamic_constants
    }

    /// Node `id` as it stands in this edit. If the node was added or updated
    /// in this edit, the pending version is returned (it reflects any
    /// `replace_all_uses*` rewrites); otherwise the committed node is returned.
    ///
    /// # Panics
    ///
    /// Panics if `id` was deleted in this edit.
    pub fn get_node(&self, id: NodeID) -> &Node {
        assert!(!self.deleted_nodeids.contains(&id));
        match self.added_and_updated_nodes.get(&id) {
            Some(pending) => pending,
            None => &self.editor.function.nodes[id.idx()],
        }
    }

    /// Schedules of node `id` as they stand in this edit. May be called on a
    /// to-be-deleted node.
    ///
    /// A node with no pending and no committed schedule entry (e.g. one added
    /// in this edit without `add_schedule`) reads as having no schedules. This
    /// matches the empty schedule it receives when the edit is committed.
    /// Previously such nodes caused an out-of-bounds panic.
    pub fn get_schedule(&self, id: NodeID) -> &Vec<Schedule> {
        const NO_SCHEDULES: &Vec<Schedule> = &Vec::new();
        match self.added_and_updated_schedules.get(&id) {
            // Refer to the added or updated schedule.
            Some(pending) => pending,
            // Refer to the original schedule of this node, if it has one.
            None => self
                .editor
                .function
                .schedules
                .get(id.idx())
                .unwrap_or(NO_SCHEDULES),
        }
    }

    /// Adds `schedule` to node `id` (if not already present) in this edit.
    /// Returns `Err(self)` if `id` is immutable.
    pub fn add_schedule(mut self, id: NodeID, schedule: Schedule) -> Result<Self, Self> {
        if !self.is_mutable(id) {
            return Err(self);
        }
        match self.added_and_updated_schedules.get_mut(&id) {
            Some(pending) => {
                if !pending.contains(&schedule) {
                    pending.push(schedule);
                }
            }
            None => {
                // Start from the committed schedules (none for brand-new
                // nodes) and only record an update if something changes.
                let no_schedules = vec![];
                let committed = self
                    .editor
                    .function
                    .schedules
                    .get(id.idx())
                    .unwrap_or(&no_schedules);
                if !committed.contains(&schedule) {
                    let mut merged = committed.clone();
                    merged.push(schedule);
                    self.added_and_updated_schedules.insert(id, merged);
                }
            }
        }
        Ok(self)
    }

    /// Removes every schedule from node `id` in this edit. Returns `Err(self)`
    /// if `id` is immutable.
    pub fn clear_schedule(mut self, id: NodeID) -> Result<Self, Self> {
        if !self.is_mutable(id) {
            return Err(self);
        }
        self.added_and_updated_schedules.insert(id, Vec::new());
        Ok(self)
    }

    /// Labels of node `id` as they stand in this edit. May be called on a
    /// to-be-deleted node.
    ///
    /// # Panics
    ///
    /// Panics if `id` has no pending label set and no committed label entry
    /// (e.g. a node added in this edit whose labels were never set).
    pub fn get_label(&self, id: NodeID) -> &HashSet<LabelID> {
        match self.added_and_updated_labels.get(&id) {
            Some(pending) => pending,
            None => &self.editor.function.labels[id.idx()],
        }
    }

    pub fn add_label(mut self, id: NodeID, label: LabelID) -> Result<Self, Self> {
        if self.is_mutable(id) {
            if let Some(labels) = self.added_and_updated_labels.get_mut(&id) {
                labels.insert(label);
            } else {
                let mut labels = self
                    .editor
                    .function
                    .labels
                    .get(id.idx())
                    .unwrap_or(&HashSet::new())
                    .clone();
                labels.insert(label);
                self.added_and_updated_labels.insert(id, labels);
            }
            Ok(self)
        } else {
            Err(self)
        }
    }

    /// Returns the `LabelID` of the label called `name`. An existing committed
    /// or pending label with that name is reused; otherwise `name` is queued
    /// as a new label.
    pub fn new_label(&mut self, name: String) -> LabelID {
        let (existing, committed) = {
            let committed_labels = self.editor.labels.borrow();
            let existing = committed_labels
                .iter()
                .chain(self.added_labels.iter())
                .position(|candidate| *candidate == name);
            let committed = committed_labels.len();
            (existing, committed)
        };
        match existing {
            Some(idx) => LabelID::new(idx),
            None => {
                let idx = committed + self.added_labels.len();
                self.added_labels.push(name);
                LabelID::new(idx)
            }
        }
    }

    /// Queues a brand-new label named `#fresh_<index>` and returns its ID.
    /// Unlike `new_label`, no lookup of existing names is performed.
    pub fn fresh_label(&mut self) -> LabelID {
        let committed = self.editor.labels.borrow().len();
        let idx = committed + self.added_labels.len();
        let name = format!("#fresh_{}", idx);
        self.added_labels.push(name);
        LabelID::new(idx)
    }

    /// Users of node `id` as they stand in this edit: the pending users set if
    /// this edit touched it, otherwise the committed one.
    ///
    /// # Panics
    ///
    /// Panics if `id` was deleted in this edit.
    pub fn get_users(&self, id: NodeID) -> impl Iterator<Item = NodeID> + '_ {
        assert!(!self.deleted_nodeids.contains(&id));
        match self.updated_def_use.get(&id) {
            Some(pending) => Either::Left(pending.iter().copied()),
            None => Either::Right(self.editor.mut_def_use[id.idx()].iter().copied()),
        }
    }

    /// Interns `ty`. If an equal committed or pending type exists, its ID is
    /// returned; otherwise `ty` is queued for addition on commit and its
    /// future ID is returned.
    pub fn add_type(&mut self, ty: Type) -> TypeID {
        let (existing, committed) = {
            let committed_types = self.editor.types.borrow();
            let existing = committed_types
                .iter()
                .chain(self.added_types.iter())
                .position(|candidate| *candidate == ty);
            let committed = committed_types.len();
            (existing, committed)
        };
        match existing {
            Some(idx) => TypeID::new(idx),
            None => {
                let fresh = TypeID::new(committed + self.added_types.len());
                self.added_types.push(ty);
                fresh
            }
        }
    }

    /// Looks up type `id`, which may be either committed or queued earlier in
    /// this edit.
    ///
    /// # Panics
    ///
    /// Panics if `id` is neither committed nor queued by this edit.
    pub fn get_type(&self, id: TypeID) -> impl Deref<Target = Type> + '_ {
        let committed = self.editor.types.borrow().len();
        match id.idx().checked_sub(committed) {
            None => Either::Left(Ref::map(self.editor.types.borrow(), |types| {
                &types[id.idx()]
            })),
            Some(offset) => Either::Right(self.added_types.get(offset).unwrap()),
        }
    }

    /// Interns `constant`. If an equal committed or pending constant exists,
    /// its ID is returned; otherwise `constant` is queued for addition on
    /// commit and its future ID is returned.
    pub fn add_constant(&mut self, constant: Constant) -> ConstantID {
        let (existing, committed) = {
            let committed_constants = self.editor.constants.borrow();
            let existing = committed_constants
                .iter()
                .chain(self.added_constants.iter())
                .position(|candidate| *candidate == constant);
            let committed = committed_constants.len();
            (existing, committed)
        };
        match existing {
            Some(idx) => ConstantID::new(idx),
            None => {
                let fresh = ConstantID::new(committed + self.added_constants.len());
                self.added_constants.push(constant);
                fresh
            }
        }
    }

    /// Adds (or reuses, via `add_constant`) the zero constant of type `id` and
    /// returns its ID.
    ///
    /// Products get a zero for every field. Summations get variant 0 holding
    /// the zero of the first variant's type. Arrays map to
    /// `Constant::Array(id)`.
    ///
    /// # Panics
    ///
    /// Panics for `Float8`/`BFloat16`, control, multi-return, and empty
    /// summation types.
    pub fn add_zero_constant(&mut self, id: TypeID) -> ConstantID {
        let ty = self.get_type(id).clone();
        let constant_to_construct = match ty {
            Type::Boolean => Constant::Boolean(false),
            Type::Integer8 => Constant::Integer8(0),
            Type::Integer16 => Constant::Integer16(0),
            Type::Integer32 => Constant::Integer32(0),
            Type::Integer64 => Constant::Integer64(0),
            Type::UnsignedInteger8 => Constant::UnsignedInteger8(0),
            Type::UnsignedInteger16 => Constant::UnsignedInteger16(0),
            Type::UnsignedInteger32 => Constant::UnsignedInteger32(0),
            Type::UnsignedInteger64 => Constant::UnsignedInteger64(0),
            Type::Float8 | Type::BFloat16 => {
                panic!("PANIC: Can't create zero constant for Float8 or BFloat16 types.")
            }
            Type::Float32 => Constant::Float32(ordered_float::OrderedFloat(0.0)),
            Type::Float64 => Constant::Float64(ordered_float::OrderedFloat(0.0)),
            Type::Control => panic!("PANIC: Can't create zero constant for the control type."),
            Type::Product(tys) => {
                let dummy_elems: Vec<_> =
                    tys.iter().map(|ty| self.add_zero_constant(*ty)).collect();
                Constant::Product(id, dummy_elems.into_boxed_slice())
            }
            Type::Summation(tys) => {
                let first_variant = *tys
                    .first()
                    .expect("PANIC: Can't create zero constant for an empty summation type.");
                Constant::Summation(id, 0, self.add_zero_constant(first_variant))
            }
            Type::Array(_, _) => Constant::Array(id),
            Type::MultiReturn(_) => {
                panic!("PANIC: Can't create zero constant for multi-return types.")
            }
        };
        self.add_constant(constant_to_construct)
    }

    /// Adds (or reuses, via `add_constant`) the constant one of scalar type
    /// `id` (`true` for booleans) and returns its ID.
    ///
    /// # Panics
    ///
    /// Panics for `Float8`/`BFloat16`, control, collection, and multi-return
    /// types.
    pub fn add_one_constant(&mut self, id: TypeID) -> ConstantID {
        let ty = self.get_type(id).clone();
        let constant_to_construct = match ty {
            Type::Boolean => Constant::Boolean(true),
            Type::Integer8 => Constant::Integer8(1),
            Type::Integer16 => Constant::Integer16(1),
            Type::Integer32 => Constant::Integer32(1),
            Type::Integer64 => Constant::Integer64(1),
            Type::UnsignedInteger8 => Constant::UnsignedInteger8(1),
            Type::UnsignedInteger16 => Constant::UnsignedInteger16(1),
            Type::UnsignedInteger32 => Constant::UnsignedInteger32(1),
            Type::UnsignedInteger64 => Constant::UnsignedInteger64(1),
            Type::Float8 | Type::BFloat16 => {
                panic!("PANIC: Can't create one constant for Float8 or BFloat16 types.")
            }
            Type::Float32 => Constant::Float32(ordered_float::OrderedFloat(1.0)),
            Type::Float64 => Constant::Float64(ordered_float::OrderedFloat(1.0)),
            Type::Control => panic!("PANIC: Can't create one constant for the control type."),
            Type::Product(_) | Type::Summation(_) | Type::Array(_, _) => {
                panic!("PANIC: Can't create one constant of a collection type.")
            }
            Type::MultiReturn(_) => {
                panic!("PANIC: Can't create one constant for multi-return types.")
            }
        };
        self.add_constant(constant_to_construct)
    }

    /// Adds (or reuses, via `add_constant`) the largest value of scalar type
    /// `id` and returns its ID: `true` for booleans, `MAX` for integers, and
    /// positive infinity for floats.
    ///
    /// # Panics
    ///
    /// Panics for `Float8`/`BFloat16`, control, collection, and multi-return
    /// types.
    pub fn add_largest_constant(&mut self, id: TypeID) -> ConstantID {
        let ty = self.get_type(id).clone();
        let constant_to_construct = match ty {
            Type::Boolean => Constant::Boolean(true),
            Type::Integer8 => Constant::Integer8(i8::MAX),
            Type::Integer16 => Constant::Integer16(i16::MAX),
            Type::Integer32 => Constant::Integer32(i32::MAX),
            Type::Integer64 => Constant::Integer64(i64::MAX),
            Type::UnsignedInteger8 => Constant::UnsignedInteger8(u8::MAX),
            Type::UnsignedInteger16 => Constant::UnsignedInteger16(u16::MAX),
            Type::UnsignedInteger32 => Constant::UnsignedInteger32(u32::MAX),
            Type::UnsignedInteger64 => Constant::UnsignedInteger64(u64::MAX),
            Type::Float8 | Type::BFloat16 => {
                panic!("PANIC: Can't create largest constant for Float8 or BFloat16 types.")
            }
            Type::Float32 => Constant::Float32(ordered_float::OrderedFloat(f32::INFINITY)),
            Type::Float64 => Constant::Float64(ordered_float::OrderedFloat(f64::INFINITY)),
            Type::Control => panic!("PANIC: Can't create largest constant for the control type."),
            Type::Product(_) | Type::Summation(_) | Type::Array(_, _) => {
                panic!("PANIC: Can't create largest constant of a collection type.")
            }
            Type::MultiReturn(_) => {
                panic!("PANIC: Can't create largest constant for multi-return types.")
            }
        };
        self.add_constant(constant_to_construct)
    }

    /// Adds (or reuses, via `add_constant`) the smallest value of scalar type
    /// `id` and returns its ID: `false` for booleans, `MIN` for integers, and
    /// negative infinity for floats.
    ///
    /// # Panics
    ///
    /// Panics for `Float8`/`BFloat16`, control, collection, and multi-return
    /// types.
    pub fn add_smallest_constant(&mut self, id: TypeID) -> ConstantID {
        let ty = self.get_type(id).clone();
        let constant_to_construct = match ty {
            // `false < true`, so the smallest boolean is `false` (previously
            // this returned `true`, the largest boolean).
            Type::Boolean => Constant::Boolean(false),
            Type::Integer8 => Constant::Integer8(i8::MIN),
            Type::Integer16 => Constant::Integer16(i16::MIN),
            Type::Integer32 => Constant::Integer32(i32::MIN),
            Type::Integer64 => Constant::Integer64(i64::MIN),
            Type::UnsignedInteger8 => Constant::UnsignedInteger8(u8::MIN),
            Type::UnsignedInteger16 => Constant::UnsignedInteger16(u16::MIN),
            Type::UnsignedInteger32 => Constant::UnsignedInteger32(u32::MIN),
            Type::UnsignedInteger64 => Constant::UnsignedInteger64(u64::MIN),
            Type::Float8 | Type::BFloat16 => {
                panic!("PANIC: Can't create smallest constant for Float8 or BFloat16 types.")
            }
            Type::Float32 => Constant::Float32(ordered_float::OrderedFloat(f32::NEG_INFINITY)),
            Type::Float64 => Constant::Float64(ordered_float::OrderedFloat(f64::NEG_INFINITY)),
            Type::Control => panic!("PANIC: Can't create smallest constant for the control type."),
            Type::Product(_) | Type::Summation(_) | Type::Array(_, _) => {
                panic!("PANIC: Can't create smallest constant of a collection type.")
            }
            Type::MultiReturn(_) => {
                panic!("PANIC: Can't create smallest constant for multi-return types.")
            }
        };
        self.add_constant(constant_to_construct)
    }

    /// Looks up a constant by ID.
    ///
    /// IDs below the length of the module's constant table refer to constants
    /// that existed before this edit. Larger IDs index into `added_constants`,
    /// the constants created during this edit.
    ///
    /// # Panics
    ///
    /// Panics if `id` is past the end of both tables.
    pub fn get_constant(&self, id: ConstantID) -> impl Deref<Target = Constant> + '_ {
        let module_constants = self.editor.constants.borrow();
        let module_len = module_constants.len();
        match id.idx().checked_sub(module_len) {
            // Beyond the module table: the constant was added by this edit.
            Some(local_idx) => Either::Right(self.added_constants.get(local_idx).unwrap()),
            None => Either::Left(Ref::map(module_constants, |all| &all[id.idx()])),
        }
    }

    pub fn add_dynamic_constant(&mut self, dynamic_constant: DynamicConstant) -> DynamicConstantID {
        self.dc_normalize(dynamic_constant)
    }

    /// Looks up a dynamic constant by ID.
    ///
    /// IDs below the length of the module's dynamic-constant table refer to
    /// pre-existing entries. Larger IDs index into `added_dynamic_constants`,
    /// the dynamic constants created during this edit.
    ///
    /// # Panics
    ///
    /// Panics if `id` is past the end of both tables.
    pub fn get_dynamic_constant(
        &self,
        id: DynamicConstantID,
    ) -> impl Deref<Target = DynamicConstant> + '_ {
        let module_dcs = self.editor.dynamic_constants.borrow();
        let module_len = module_dcs.len();
        match id.idx().checked_sub(module_len) {
            // Beyond the module table: the dynamic constant was added by this edit.
            Some(local_idx) => {
                Either::Right(self.added_dynamic_constants.get(local_idx).unwrap())
            }
            None => Either::Left(Ref::map(module_dcs, |all| &all[id.idx()])),
        }
    }

    /// Returns the parameter types as they will be after this edit.
    ///
    /// If `set_param_types` has been called, the types it set are returned.
    /// Otherwise the function's current parameter types are returned.
    pub fn get_param_types(&self) -> &Vec<TypeID> {
        match &self.updated_param_types {
            Some(overridden) => overridden,
            None => &self.editor.function.param_types,
        }
    }

    /// Returns the return types as they will be after this edit.
    ///
    /// If `set_return_types` has been called, the types it set are returned.
    /// Otherwise the function's current return types are returned.
    pub fn get_return_types(&self) -> &Vec<TypeID> {
        match &self.updated_return_types {
            Some(overridden) => overridden,
            None => &self.editor.function.return_types,
        }
    }

    /// Records `tys` as the function's parameter types for this edit.
    ///
    /// Any previously recorded override is replaced.
    pub fn set_param_types(&mut self, tys: Vec<TypeID>) {
        let _ = self.updated_param_types.replace(tys);
    }

    /// Records `tys` as the function's return types for this edit.
    ///
    /// Any previously recorded override is replaced.
    pub fn set_return_types(&mut self, tys: Vec<TypeID>) {
        let _ = self.updated_return_types.replace(tys);
    }
}

impl<'a, 'b> DynamicConstantView for FunctionEdit<'a, 'b> {
    /// Looks up a dynamic constant. This covers both module-level entries and
    /// entries added during this edit.
    fn get_dynconst(&self, id: DynamicConstantID) -> impl Deref<Target = DynamicConstant> + '_ {
        Self::get_dynamic_constant(self, id)
    }

    /// Interns `dc` and returns its ID.
    ///
    /// If an equal dynamic constant already exists, its ID is returned. The
    /// module table is searched first, then this edit's additions. Otherwise
    /// `dc` is appended to `added_dynamic_constants`. Module entries occupy
    /// IDs `[0, module_len)` and edit-local entries follow them.
    fn add_dynconst(&mut self, dc: DynamicConstant) -> DynamicConstantID {
        let module_len = self.editor.dynamic_constants.borrow().len();
        let in_module = self
            .editor
            .dynamic_constants
            .borrow()
            .iter()
            .position(|existing| *existing == dc);
        let found = in_module.or_else(|| {
            self.added_dynamic_constants
                .iter()
                .position(|existing| *existing == dc)
                .map(|offset| module_len + offset)
        });

        let idx = match found {
            Some(idx) => idx,
            None => {
                self.added_dynamic_constants.push(dc);
                // The entry just pushed is the last one in the edit-local list.
                module_len + self.added_dynamic_constants.len() - 1
            }
        };
        DynamicConstantID::new(idx)
    }
}

#[cfg(test)]
mod editor_tests {
    // The glob import brings the parent module's items and its `use`d names
    // (e.g. `take`, `RefCell`, `FunctionEditor`) into scope. Presumably not
    // all of them are used here, hence the lint allowance.
    #[allow(unused_imports)]
    use super::*;

    use std::mem::replace;

    use hercules_ir::dataflow::reverse_postorder;
    use hercules_ir::parse::parse;

    /// Renumbers `function`'s nodes into the reverse postorder computed by
    /// `reverse_postorder` over its def-use graph, and rewrites every use to
    /// the new IDs. Two functions with the same structure but different node
    /// numbering then have equal `nodes` vectors.
    ///
    /// Nodes not visited by the traversal are dropped.
    ///
    /// Returns the map from old node index to new node ID. Dropped nodes map
    /// to `None`.
    ///
    /// Panics if a retained node uses a dropped node.
    fn canonicalize(function: &mut Function) -> Vec<Option<NodeID>> {
        // The reverse postorder traversal from the Start node is a map from new
        // index to old ID.
        let rev_po = reverse_postorder(&def_use(function));
        let num_new_nodes = rev_po.len();

        // Construct a map from old ID to new ID.
        let mut old_to_new = vec![None; function.nodes.len()];
        for (new_idx, old_id) in rev_po.into_iter().enumerate() {
            old_to_new[old_id.idx()] = Some(NodeID::new(new_idx));
        }

        // Move the old nodes before permuting them. The new vector is filled
        // with `Node::Start` placeholders that are overwritten below.
        let mut old_nodes = take(&mut function.nodes);
        function.nodes = vec![Node::Start; num_new_nodes];

        // Permute the old nodes back into the function and fix their uses.
        for (old_idx, new_id) in old_to_new.iter().enumerate() {
            // Check if this old node is in the canonicalized form.
            if let Some(new_id) = new_id {
                // Get the old node (leaving a placeholder behind so it can be
                // moved out of `old_nodes` without cloning).
                let mut node = replace(&mut old_nodes[old_idx], Node::Start);

                // Fix its uses.
                for u in get_uses_mut(&mut node).as_mut() {
                    // Map every use using the old-to-new map. If we try to use
                    // a node that doesn't have a mapping, then the original IR
                    // had a node reachable from the start using another node
                    // not reachable from the start, which is malformed.
                    **u = old_to_new[u.idx()].unwrap();
                }

                // Insert the fixed node into its new spot.
                function.nodes[new_id.idx()] = node;
            }
        }

        old_to_new
    }

    /// Replaces an `add` with a `mul` through `FunctionEditor`. After
    /// canonicalizing both functions, checks that the result matches a
    /// directly parsed function containing the `mul`.
    #[test]
    fn example1() {
        // Define the original function.
        let mut src_module = parse(
            "
fn func(x: i32) -> i32
  c = constant(i32, 7)
  y = add(x, c)
  r = return(start, y)
",
        )
        .unwrap();

        // Find the ID of the add node and its uses.
        let func = &mut src_module.functions[0];
        let (add, left, right) = func
            .nodes
            .iter()
            .enumerate()
            .filter_map(|(idx, node)| {
                node.try_binary(BinaryOperator::Add)
                    .map(|(left, right)| (NodeID::new(idx), left, right))
            })
            .next()
            .unwrap();

        // `FunctionEditor` takes the module-level tables behind `RefCell`s,
        // so move them out of the module and wrap them here.
        let constants_ref = RefCell::new(src_module.constants);
        let dynamic_constants_ref = RefCell::new(src_module.dynamic_constants);
        let types_ref = RefCell::new(src_module.types);
        let labels_ref = RefCell::new(src_module.labels);
        // Edit the function by replacing the add with a multiply.
        let mut editor = FunctionEditor::new(
            func,
            FunctionID::new(0),
            &constants_ref,
            &dynamic_constants_ref,
            &types_ref,
            &labels_ref,
            &def_use(func),
        );
        // Each step inside the closure returns a `Result`; `?` aborts the
        // edit on failure. The returned flag is asserted to be true.
        let success = editor.edit(|mut edit| {
            let mul = edit.add_node(Node::Binary {
                op: BinaryOperator::Mul,
                left,
                right,
            });
            let edit = edit.replace_all_uses(add, mul)?;
            let edit = edit.delete_node(add)?;
            Ok(edit)
        });
        assert!(success);

        // Canonicalize the function.
        canonicalize(func);

        // Check that the function is correct.
        let mut dst_module = parse(
            "
fn func(x: i32) -> i32
  c = constant(i32, 7)
  y = mul(x, c)
  r = return(start, y)
",
        )
        .unwrap();
        canonicalize(&mut dst_module.functions[0]);
        assert_eq!(src_module.functions[0].nodes, dst_module.functions[0].nodes);
    }
}