// sroa.rs — Scalar Replacement of Aggregates (SROA) pass.
use std::collections::{BTreeMap, HashMap, VecDeque};

use hercules_ir::ir::*;

use crate::*;

/*
 * Top level function to run SROA, intraprocedurally. Product values can be used
 * and created by a relatively small number of nodes. Here are *all* of them:
 *
 * - Phi: can merge SSA values of products - these get broken up into phis on
 *   the individual fields
 *
 * - Reduce: similarly to phis, reduce nodes can cycle product values through
 *   reduction loops - these get broken up into reduces on the fields
 *
 * - Return: can return a product - the product values will be constructed
 *   at the return site
 *
 * - Parameter: can introduce a product - reads will be introduced for each
 *   field
 *
 * - Constant: can introduce a product - these are broken up into constants for
 *   the individual fields
 *
 * - Ternary: the select ternary operator can select between products - these
 *   are broken up into ternary nodes for the individual fields
 *
 * - Call: the call node can use a product value as an argument to another
 *   function, argument values will be constructed at the call site
 *
 * - DataProjection: data projection nodes can produce a product value that was
 *   returned by a function, we will break the value into individual fields
 *
 * - Read: the read node reads primitive fields from product values - these get
 *   replaced by a direct use of the field value
 *   A read can also extract a product from an array or sum; the value read out
 *   will be broken into individual fields (by individual reads from the array)
 *
 * - Write: the write node writes primitive fields in product values - these get
 *   replaced by a direct def of the field value
 *
 * The allow_sroa_arrays variable controls whether products that contain arrays
 * will be broken into pieces. This option is useful to have since breaking
 * these products up can be expensive if it requires destructing and
 * reconstructing the product at any point.
 *
 * TODO: Handle partial selections (i.e. immutable nodes). This will involve
 * actually tracking each source and use of a product and verifying that all of
 * the nodes involved are mutable.
 */
/// Run SROA (scalar replacement of aggregates) over a single function.
///
/// - `editor`: editor over the function being transformed
/// - `reverse_postorder`: the function's nodes in reverse postorder
/// - `types`: the type of each node, indexed by `NodeID::idx()`
/// - `allow_sroa_arrays`: if true, products that contain arrays may also be
///   broken apart (see the module comment above for the tradeoff)
///
/// See the module comment above for the full list of node kinds that interact
/// with product values and how each is handled.
pub fn sroa(
    editor: &mut FunctionEditor,
    reverse_postorder: &Vec<NodeID>,
    types: &Vec<TypeID>,
    allow_sroa_arrays: bool,
) {
    let mut types: HashMap<NodeID, TypeID> = types
        .iter()
        .enumerate()
        .map(|(i, t)| (NodeID::new(i), *t))
        .collect();

    let can_sroa_type = |editor: &FunctionEditor, typ: TypeID| {
        editor.get_type(typ).is_product()
            && (allow_sroa_arrays || !type_contains_array(editor, typ))
    };

    // This map stores a map from NodeID to an index tree which can be used to lookup the NodeID
    // that contains the corresponding fields of the original value
    let mut field_map: HashMap<NodeID, IndexTree<NodeID>> = HashMap::new();

    // First: determine all nodes which interact with products (as described above)
    let mut product_nodes: Vec<NodeID> = vec![];
    // We track call, data projection, and return nodes separately since they (may) require
    // constructing new products for the call's arguments, data projection's value, or a
    // returned value
    let mut call_return_nodes: Vec<NodeID> = vec![];

    for node in reverse_postorder {
        match &editor.func().nodes[node.idx()] {
            Node::Phi { .. }
            | Node::Reduce { .. }
            | Node::Parameter { .. }
            | Node::Constant { .. }
            | Node::Ternary {
                first: _,
                second: _,
                third: _,
                op: TernaryOperator::Select,
            } if can_sroa_type(editor, types[&node]) => product_nodes.push(*node),

            Node::Write {
                collect,
                data,
                indices,
            } => {
                let data = *data;
                let collect = *collect;

                // For a write, we may need to split it into two pieces if the it contains a mix of
                // field and non-field indices
                let (fields_write, write_prod_into_non) = {
                    let mut fields = vec![];
                    let mut remainder = vec![];

                    if can_sroa_type(editor, types[&node]) {
                        let mut indices = indices.iter();
                        while let Some(idx) = indices.next() {
                            if idx.is_field() {
                                fields.push(idx.clone());
                            } else {
                                remainder.push(idx.clone());
                                remainder.extend(indices.cloned());
                                break;
                            }
                        }
                    } else {
                        remainder.extend_from_slice(indices);
                    }

                    if fields.is_empty() {
                        if can_sroa_type(editor, types[&data]) {
                            (None, Some((*node, collect, remainder)))
                        } else {
                            (None, None)
                        }
                    } else if remainder.is_empty() {
                        (Some(*node), None)
                    } else {
                        // Here we perform the split into two writes
                        // We need to find the type of the collection that will be extracted from
                        // the collection being modified when we read it at the fields index
                        let after_fields_type = type_at_index(editor, types[&collect], &fields);

                        let mut inner_collection = None;
                        let mut fields_write = None;
                        let mut remainder_write = None;
                        editor.edit(|mut edit| {
                            let read_inner = edit.add_node(Node::Read {
                                collect,
                                indices: fields.clone().into(),
                            });
                            types.insert(read_inner, after_fields_type);
                            product_nodes.push(read_inner);
                            inner_collection = Some(read_inner);

                            let rem_write = edit.add_node(Node::Write {
                                collect: read_inner,
                                data,
                                indices: remainder.clone().into(),
                            });
                            types.insert(rem_write, after_fields_type);
                            remainder_write = Some(rem_write);

                            let complete_write = edit.add_node(Node::Write {
                                collect,
                                data: rem_write,
                                indices: fields.into(),
                            });
                            types.insert(complete_write, types[&collect]);
                            fields_write = Some(complete_write);

                            edit = edit.replace_all_uses(*node, complete_write)?;
                            edit.delete_node(*node)
                        });
                        let inner_collection = inner_collection.unwrap();
                        let fields_write = fields_write.unwrap();
                        let remainder_write = remainder_write.unwrap();

                        if editor.get_type(types[&data]).is_product() {
                            (
                                Some(fields_write),
                                Some((remainder_write, inner_collection, remainder)),
                            )
                        } else {
                            (Some(fields_write), None)
                        }
                    }
                };

                if let Some(node) = fields_write {
                    product_nodes.push(node);
                }

                if let Some((write_node, collection, index)) = write_prod_into_non {
                    let node = write_node;
                    // If we're writing a product into a non-product we need to replace the write
                    // by a sequence of writes that read each field of the product and write them
                    // into the collection, then those write nodes can be ignored for SROA but the
                    // reads will be handled by SROA

                    // The value being written must be the data and so must be a product
                    assert!(editor.get_type(types[&data]).is_product());
                    let fields = generate_reads(editor, types[&data], data);

                    let mut collection = collection;
                    let collection_type = types[&collection];

                    fields.for_each(|field: &Vec<Index>, val: &NodeID| {
                        product_nodes.push(*val);
                        editor.edit(|mut edit| {
                            collection = edit.add_node(Node::Write {
                                collect: collection,
                                data: *val,
                                indices: index
                                    .iter()
                                    .chain(field)
                                    .cloned()
                                    .collect::<Vec<_>>()
                                    .into(),
                            });
                            types.insert(collection, collection_type);
                            Ok(edit)
                        });
                    });

                    editor.edit(|mut edit| {
                        edit = edit.replace_all_uses(node, collection)?;
                        edit.delete_node(node)
                    });
                }
            }
            Node::Read { collect, indices } => {
                // For a read, we split the read into a series of reads where each piece has either
                // only field reads or no field reads. Those with fields are the only ones
                // considered during SROA but any read whose collection is not a product but
                // produces a product (i.e. if there's an array of products) then following the
                // read we replace the read that produces a product by reads of each field and add
                // that information to the node map for the rest of SROA (this produces some reads
                // that mix types of indices, since we only read leaves but that's okay since those
                // reads are not handled by SROA)
                let indices = if can_sroa_type(editor, types[collect]) {
                    indices
                        .chunk_by(|i, j| i.is_field() == j.is_field())
                        .collect::<Vec<_>>()
                } else {
                    vec![indices.as_ref()]
                };

                let (field_reads, non_fields_produce_prod) = {
                    if indices.is_empty() {
                        // If there are no indices then there were no indices originally, this is
                        // only used with clones of arrays
                        (vec![], vec![])
                    } else if indices.len() == 1 {
                        // If once we perform chunking there's only one set of indices, we can just
                        // use the original node
                        if can_sroa_type(editor, types[collect]) {
                            (vec![*node], vec![])
                        } else if can_sroa_type(editor, types[node]) {
                            (vec![], vec![*node])
                        } else {
                            (vec![], vec![])
                        }
                    } else {
                        let mut field_reads = vec![];
                        let mut non_field = vec![];

                        // To construct the multiple reads we need to track the current collection
                        // and the type of that collection
                        let mut collect = *collect;
                        let mut typ = types[&collect];

                        let indices = indices
                            .into_iter()
                            .map(|i| i.into_iter().cloned().collect::<Vec<_>>())
                            .collect::<Vec<_>>();
                        for index in indices {
                            let is_field_read = index[0].is_field();
                            let field_type = type_at_index(editor, typ, &index);

                            editor.edit(|mut edit| {
                                collect = edit.add_node(Node::Read {
                                    collect,
                                    indices: index.into(),
                                });
                                types.insert(collect, field_type);
                                typ = field_type;
                                Ok(edit)
                            });

                            if is_field_read {
                                field_reads.push(collect);
                            } else if editor.get_type(typ).is_product() {
                                non_field.push(collect);
                            }
                        }

                        // Replace all uses of the original read (with mixed indices) with the
                        // newly constructed reads
                        editor.edit(|mut edit| {
                            edit = edit.replace_all_uses(*node, collect)?;
                            edit.delete_node(*node)
                        });

                        (field_reads, non_field)
                    }
                };

                product_nodes.extend(field_reads);

                for node in non_fields_produce_prod {
                    field_map.insert(node, generate_reads(editor, types[&node], node));
                }
            }

            // We add all calls and returns to the call/return list and check their
            // arguments/return values later
            Node::Call { .. } | Node::Return { .. } => call_return_nodes.push(*node),
            // We add DataProjetion nodes that produce SROAable values
            Node::DataProjection { .. } if can_sroa_type(editor, types[&node]) => {
                call_return_nodes.push(*node);
            }

            _ => (),
        }
    }

    // Next, we handle calls and returns. For returns, for each returned value that is a product,
    // we will insert nodes that read each field of it and then write them into a new product.
    // The writes we create are not put into the list of product nodes since they must remain but
    // the reads are put in the list so that they will be replaced later on.
    // For calls, we do a similar process for each (product) argument.
    // For data projection that produce product values, we create reads for each field of that
    // product and store it into our field map
    for node in call_return_nodes {
        match &editor.func().nodes[node.idx()] {
            Node::Return { control, data } => {
                let control = *control;
                let data = data.to_vec();

                let (new_data, changed) =
                    data.into_iter()
                        .fold((vec![], false), |(mut vals, changed), val_id| {
                            if !can_sroa_type(editor, types[&val_id]) {
                                vals.push(val_id);
                                (vals, changed)
                            } else {
                                vals.push(reconstruct_product(
                                    editor,
                                    types[&val_id],
                                    val_id,
                                    &mut product_nodes,
                                ));
                                (vals, true)
                            }
                        });
                if changed {
                    editor.edit(|mut edit| {
                        let new_return = edit.add_node(Node::Return {
                            control,
                            data: new_data.into(),
                        });
                        edit.sub_edit(node, new_return);
                        edit.delete_node(node)
                    });
                }
            }
            Node::Call {
                control,
                function,
                dynamic_constants,
                args,
            } => {
                let control = *control;
                let function = *function;
                let dynamic_constants = dynamic_constants.clone();
                let args = args.to_vec();

                let (new_args, changed) =
                    args.into_iter()
                        .fold((vec![], false), |(mut vals, changed), arg| {
                            if !can_sroa_type(editor, types[&arg]) {
                                vals.push(arg);
                                (vals, changed)
                            } else {
                                vals.push(reconstruct_product(
                                    editor,
                                    types[&arg],
                                    arg,
                                    &mut product_nodes,
                                ));
                                (vals, true)
                            }
                        });

                if changed {
                    editor.edit(|mut edit| {
                        let new_call = edit.add_node(Node::Call {
                            control,
                            function,
                            dynamic_constants,
                            args: new_args.into(),
                        });
                        edit.sub_edit(node, new_call);
                        let edit = edit.replace_all_uses(node, new_call)?;
                        let edit = edit.delete_node(node)?;

                        Ok(edit)
                    });
                }
            }
            Node::DataProjection { .. } => {
                assert!(can_sroa_type(editor, types[&node]));
                field_map.insert(node, generate_reads(editor, types[&node], node));
            }
            _ => panic!("Processing non-call or return node"),
        }
    }

    #[derive(Debug)]
    enum WorkItem {
        Unhandled(NodeID),
        AllocatedPhi {
            control: NodeID,
            data: Vec<NodeID>,
            node: NodeID,
            fields: IndexTree<NodeID>,
        },
        AllocatedReduce {
            control: NodeID,
            init: NodeID,
            reduct: NodeID,
            node: NodeID,
            fields: IndexTree<NodeID>,
        },
        AllocatedTernary {
            cond: NodeID,
            thn: NodeID,
            els: NodeID,
            node: NodeID,
            fields: IndexTree<NodeID>,
        },
    }

    // Now, we process the other nodes that deal with products.
    // The first step is to assign new NodeIDs to the nodes that will be split into multiple: phi,
    // reduce, parameter, constant, and ternary.
    // We do this in several steps: first we break apart parameters and constants
    let mut to_delete = vec![];
    let mut worklist: VecDeque<WorkItem> = VecDeque::new();

    for node in product_nodes {
        match editor.func().nodes[node.idx()] {
            Node::Parameter { .. } => {
                field_map.insert(node, generate_reads(editor, types[&node], node));
            }
            Node::Constant { id } => {
                field_map.insert(node, generate_constant_fields(editor, id, node));
                to_delete.push(node);
            }
            _ => {
                worklist.push_back(WorkItem::Unhandled(node));
            }
        }
    }

    // Now, we process the remaining nodes, allocating NodeIDs for them and updating the field_map.
    // We track the current NodeID and add nodes to a set we maintain of nodes to add (since we
    // need to add nodes in a particular order we wait to do that until the end). If we don't have
    // enough information to process a particular node, we add it back to the worklist
    let mut next_id: usize = editor.func().nodes.len();
    let mut to_insert = BTreeMap::new();
    let mut to_replace: Vec<(NodeID, NodeID)> = vec![];

    while let Some(mut item) = worklist.pop_front() {
        if let WorkItem::Unhandled(node) = item {
            match &editor.func().nodes[node.idx()] {
                // For phi, reduce, and ternary, we break them apart into separate nodes for each field
                Node::Phi { control, data } => {
                    let control = *control;
                    let data = data.clone();
                    let fields = allocate_fields(editor, types[&node], &mut next_id);
                    field_map.insert(node, fields.clone());

                    item = WorkItem::AllocatedPhi {
                        control,
                        data: data.into(),
                        node,
                        fields,
                    };
                }
                Node::Reduce {
                    control,
                    init,
                    reduct,
                } => {
                    let control = *control;
                    let init = *init;
                    let reduct = *reduct;
                    let fields = allocate_fields(editor, types[&node], &mut next_id);
                    field_map.insert(node, fields.clone());

                    item = WorkItem::AllocatedReduce {
                        control,
                        init,
                        reduct,
                        node,
                        fields,
                    };
                }
                Node::Ternary {
                    first,
                    second,
                    third,
                    ..
                } => {
                    let first = *first;
                    let second = *second;
                    let third = *third;
                    let fields = allocate_fields(editor, types[&node], &mut next_id);
                    field_map.insert(node, fields.clone());

                    item = WorkItem::AllocatedTernary {
                        cond: first,
                        thn: second,
                        els: third,
                        node,
                        fields,
                    };
                }

                Node::Write {
                    collect,
                    data,
                    indices,
                } => {
                    if let Some(index_map) = field_map.get(collect) {
                        if can_sroa_type(editor, types[&data]) {
                            if let Some(data_idx) = field_map.get(data) {
                                field_map.insert(
                                    node,
                                    index_map.clone().replace(indices, data_idx.clone()),
                                );
                                to_delete.push(node);
                            } else {
                                worklist.push_back(WorkItem::Unhandled(node));
                            }
                        } else {
                            field_map.insert(node, index_map.clone().set(indices, *data));
                            to_delete.push(node);
                        }
                    } else {
                        worklist.push_back(WorkItem::Unhandled(node));
                    }
                }
                Node::Read { collect, indices } => {
                    if let Some(index_map) = field_map.get(collect) {
                        let read_info = index_map.lookup(indices);
                        // A read that resolves to a single field is replaced by a direct
                        // use of that field's node
                        if let IndexTree::Leaf(field) = read_info {
                            to_replace.push((node, *field));
                        }
                        field_map.insert(node, read_info.clone());
                        to_delete.push(node);
                    } else {
                        worklist.push_back(WorkItem::Unhandled(node));
                    }
                }

                _ => panic!("Unexpected node type"),
            }
        }
        match item {
            WorkItem::Unhandled(_) => {}
            WorkItem::AllocatedPhi {
                control,
                data,
                node,
                fields,
            } => {
                let mut data_fields = vec![];
                let mut ready = true;
                for val in data.iter() {
                    if let Some(val_fields) = field_map.get(val) {
                        data_fields.push(val_fields);
                    } else {
                        ready = false;
                        break;
                    }
                }

                if ready {
                    fields.zip_list(data_fields).for_each(|_, (res, data)| {
                        to_insert.insert(
                            res.idx(),
                            Node::Phi {
                                control,
                                data: data.into_iter().map(|n| **n).collect::<Vec<_>>().into(),
                            },
                        );
                    });
                    to_delete.push(node);
                } else {
                    worklist.push_back(WorkItem::AllocatedPhi {
                        control,
                        data,
                        node,
                        fields,
                    });
                }
            }
            WorkItem::AllocatedReduce {
                control,
                init,
                reduct,
                node,
                fields,
            } => {
                if let (Some(init_fields), Some(reduct_fields)) =
                    (field_map.get(&init), field_map.get(&reduct))
                {
                    fields.zip(init_fields).zip(reduct_fields).for_each(
                        |_, ((res, init), reduct)| {
                            to_insert.insert(
                                res.idx(),
                                Node::Reduce {
                                    control,
                                    init: **init,
                                    reduct: **reduct,
                                },
                            );
                        },
                    );
                    to_delete.push(node);
                } else {
                    worklist.push_back(WorkItem::AllocatedReduce {
                        control,
                        init,
                        reduct,
                        node,
                        fields,
                    });
                }
            }
            WorkItem::AllocatedTernary {
                cond,
                thn,
                els,
                node,
                fields,
            } => {
                if let (Some(thn_fields), Some(els_fields)) =
                    (field_map.get(&thn), field_map.get(&els))
                {
                    fields
                        .zip(thn_fields)
                        .zip(els_fields)
                        .for_each(|_, ((res, thn), els)| {
                            to_insert.insert(
                                res.idx(),
                                Node::Ternary {
                                    first: cond,
                                    second: **thn,
                                    third: **els,
                                    op: TernaryOperator::Select,
                                },
                            );
                        });
                    to_delete.push(node);
                } else {
                    worklist.push_back(WorkItem::AllocatedTernary {
                        cond,
                        thn,
                        els,
                        node,
                        fields,
                    });
                }
            }
        }
    }

    // Create the new nodes. The IDs in to_insert were pre-allocated starting at the node count,
    // so inserting in sorted (BTreeMap) order should land each node at its reserved ID; the
    // assert checks this invariant.
    editor.edit(|mut edit| {
        for (node_id, node) in to_insert {
            let id = edit.add_node(node);
            assert_eq!(node_id, id.idx());
        }
        Ok(edit)
    });

    // Replace uses of old reads
    // Because a read that is being replaced could also be the node some other read is being
    // replaced by (if the first read is then written into a product that is then read from again)
    // we need to track what nodes have already been replaced (and by what) so we can properly
    // replace uses without leaving users of nodes that should be deleted.
    // replaced_by tracks what a node has been replaced by while replaced_of tracks everything that
    // maps to a particular node (which is needed to maintain the data structure efficiently)
    let mut replaced_by: HashMap<NodeID, NodeID> = HashMap::new();
    let mut replaced_of: HashMap<NodeID, Vec<NodeID>> = HashMap::new();
    for (old, new) in to_replace {
        let new = match replaced_by.get(&new) {
            Some(res) => *res,
            None => new,
        };
        editor.edit(|mut edit| {
            edit.sub_edit(old, new);
            edit.replace_all_uses(old, new)
        });
        replaced_by.insert(old, new);

        // Steal the list of nodes that previously mapped to old; they must be redirected to new
        let replaced = replaced_of.remove(&old).unwrap_or_default();

        let new_of = replaced_of.entry(new).or_default();
        new_of.push(old);

        for n in replaced {
            replaced_by.insert(n, new);
            new_of.push(n);
        }
    }

    // Remove nodes
    editor.edit(|mut edit| {
        for node in to_delete {
            edit = edit.delete_node(node)?
        }
        Ok(edit)
    });
}

/// Returns whether the given type is, or transitively contains, an array.
/// Products and summations are searched recursively; every other type kind is
/// array-bearing only if it is itself an array.
pub fn type_contains_array(editor: &FunctionEditor, typ: TypeID) -> bool {
    match &*editor.get_type(typ) {
        Type::Product(elems) | Type::Summation(elems) => elems
            .iter()
            .any(|elem| type_contains_array(editor, *elem)),
        other => matches!(other, Type::Array(_, _)),
    }
}

/// A tree mirroring the field structure of a product type, used to store a
/// result for every list of field indices into that type.
#[derive(Clone, Debug)]
pub enum IndexTree<T> {
    // A non-product position holding a single value.
    Leaf(T),
    // A product position, with one subtree per field.
    Node(Vec<IndexTree<T>>),
}

impl<T: std::fmt::Debug> IndexTree<T> {
    pub fn lookup(&self, idx: &[Index]) -> &IndexTree<T> {
        self.lookup_idx(idx, 0)
    }

    fn lookup_idx(&self, idx: &[Index], n: usize) -> &IndexTree<T> {
        if n < idx.len() {
            if let Index::Field(i) = idx[n] {
                match self {
                    IndexTree::Leaf(_) => panic!("Invalid field"),
                    IndexTree::Node(ts) => ts[i].lookup_idx(idx, n + 1),
                }
            } else {
                panic!("Cannot process a mix of fields and other read indices; pre-processing should prevent this");
            }
        } else {
            self
        }
    }

    pub fn set(self, idx: &[Index], val: T) -> IndexTree<T> {
        self.set_idx(idx, val, 0)
    }

    fn set_idx(self, idx: &[Index], val: T, n: usize) -> IndexTree<T> {
        if n < idx.len() {
            if let Index::Field(i) = idx[n] {
                match self {
                    IndexTree::Leaf(_) => panic!("Invalid field"),
                    IndexTree::Node(mut ts) => {
                        if i + 1 == ts.len() {
                            let t = ts.pop().unwrap();
                            ts.push(t.set_idx(idx, val, n + 1));
                        } else {
                            let mut t = ts.pop().unwrap();
                            std::mem::swap(&mut ts[i], &mut t);
                            t = t.set_idx(idx, val, n + 1);
                            std::mem::swap(&mut ts[i], &mut t);
                            ts.push(t);
                        }
                        IndexTree::Node(ts)
                    }
                }
            } else {
                panic!("Cannot process a mix of fields and other read indices; pre-processing should prevent this");
            }
        } else {
            IndexTree::Leaf(val)
        }
    }

    pub fn replace(self, idx: &[Index], val: IndexTree<T>) -> IndexTree<T> {
        self.replace_idx(idx, val, 0)
    }

    fn replace_idx(self, idx: &[Index], val: IndexTree<T>, n: usize) -> IndexTree<T> {
        if n < idx.len() {
            if let Index::Field(i) = idx[n] {
                match self {
                    IndexTree::Leaf(_) => panic!("Invalid field"),
                    IndexTree::Node(mut ts) => {
                        if i + 1 == ts.len() {
                            let t = ts.pop().unwrap();
                            ts.push(t.replace_idx(idx, val, n + 1));
                        } else {
                            let mut t = ts.pop().unwrap();
                            std::mem::swap(&mut ts[i], &mut t);
                            t = t.replace_idx(idx, val, n + 1);
                            std::mem::swap(&mut ts[i], &mut t);
                            ts.push(t);
                        }
                        IndexTree::Node(ts)
                    }
                }
            } else {
                panic!("Cannot process a mix of fields and other read indices; pre-processing should prevent this");
            }
        } else {
            val
        }
    }

    pub fn zip<'a, A>(self, other: &'a IndexTree<A>) -> IndexTree<(T, &'a A)> {
        match (self, other) {
            (IndexTree::Leaf(t), IndexTree::Leaf(a)) => IndexTree::Leaf((t, a)),
            (IndexTree::Node(t), IndexTree::Node(a)) => {
                let mut fields = vec![];
                for (t, a) in t.into_iter().zip(a.iter()) {
                    fields.push(t.zip(a));
                }
                IndexTree::Node(fields)
            }
            _ => panic!("IndexTrees do not have the same fields, cannot zip"),
        }
    }

    pub fn zip_list<'a, A>(self, others: Vec<&'a IndexTree<A>>) -> IndexTree<(T, Vec<&'a A>)> {
        match self {
            IndexTree::Leaf(t) => {
                let mut res = vec![];
                for other in others {
                    match other {
                        IndexTree::Leaf(a) => res.push(a),
                        _ => panic!("IndexTrees do not have the same fields, cannot zip"),
                    }
                }
                IndexTree::Leaf((t, res))
            }
            IndexTree::Node(t) => {
                let mut fields: Vec<Vec<&'a IndexTree<A>>> = vec![vec![]; t.len()];
                for other in others {
                    match other {
                        IndexTree::Node(a) => {
                            for (i, a) in a.iter().enumerate() {
                                fields[i].push(a);
                            }
                        }
                        _ => panic!("IndexTrees do not have the same fields, cannot zip"),
                    }
                }
                IndexTree::Node(
                    t.into_iter()
                        .zip(fields.into_iter())
                        .map(|(t, f)| t.zip_list(f))
                        .collect(),
                )
            }
        }
    }

    pub fn for_each<F>(&self, mut f: F)
    where
        F: FnMut(&Vec<Index>, &T),
    {
        self.for_each_idx(&mut vec![], &mut f);
    }

    fn for_each_idx<F>(&self, idx: &mut Vec<Index>, f: &mut F)
    where
        F: FnMut(&Vec<Index>, &T),
    {
        match self {
            IndexTree::Leaf(t) => f(idx, t),
            IndexTree::Node(ts) => {
                for (i, t) in ts.iter().enumerate() {
                    idx.push(Index::Field(i));
                    t.for_each_idx(idx, f);
                    idx.pop();
                }
            }
        }
    }
}

// Given the editor, type of some collection, and a list of indices to access that type at, returns
// the TypeID of accessing the collection at the given indices
fn type_at_index(editor: &FunctionEditor, typ: TypeID, idx: &[Index]) -> TypeID {
    // Walk the index list, stepping the current type down one level per index.
    idx.iter().fold(typ, |cur, index| match index {
        Index::Field(i) => match *editor.get_type(cur) {
            Type::Product(ref ts) => ts[*i],
            _ => panic!("Accessing a field of a non-product type; did typechecking succeed?"),
        },
        Index::Variant(i) => match *editor.get_type(cur) {
            Type::Summation(ref ts) => ts[*i],
            _ => panic!(
                "Accessing a variant of a non-summation type; did typechecking succeed?"
            ),
        },
        Index::Position(pos) => match *editor.get_type(cur) {
            Type::Array(elem, ref dims) => {
                // A position index must supply one coordinate per dimension.
                assert!(pos.len() == dims.len(), "Read mismatch array dimensions");
                elem
            }
            _ => panic!("Accessing an array position of a non-array type; did typechecking succeed?"),
        },
    })
}

// Given a product value val of type typ, constructs a copy of that value by extracting all fields
// from that value and then writing them into a new constant
// This process also adds all the read nodes that are generated into the read_list so that the
// reads can be eliminated by later parts of SROA
fn reconstruct_product(
    editor: &mut FunctionEditor,
    typ: TypeID,
    val: NodeID,
    read_list: &mut Vec<NodeID>,
) -> NodeID {
    // Read every leaf field out of the original value, and build a default
    // constant of the product type to write those fields into.
    let field_reads = generate_reads(editor, typ, val);
    let new_const = generate_constant(editor, typ);

    // Materialize the constant as a node.
    let mut built = None;
    editor.edit(|mut edit| {
        built = Some(edit.add_node(Node::Constant { id: new_const }));
        Ok(edit)
    });
    let mut built = built.expect("Add node cannot fail");

    // Chain one write per leaf field onto the constant, recording each read so
    // later SROA stages can eliminate it.
    field_reads.for_each(|idx: &Vec<Index>, field: &NodeID| {
        read_list.push(*field);
        editor.edit(|mut edit| {
            built = edit.add_node(Node::Write {
                collect: built,
                data: *field,
                indices: idx.clone().into(),
            });
            Ok(edit)
        });
    });

    built
}

// Given a node val of type typ, adds nodes to the function which read all (leaf) fields of val and
// returns an IndexTree that tracks the nodes reading each leaf field
pub fn generate_reads(editor: &mut FunctionEditor, typ: TypeID, val: NodeID) -> IndexTree<NodeID> {
    // All of the reads are added inside a single edit; the closure stashes the
    // resulting tree so it can be returned afterwards.
    let mut reads = None;
    editor.edit(|mut edit| {
        reads = Some(generate_reads_edit(&mut edit, typ, val));
        Ok(edit)
    });
    reads.unwrap()
}

// The same as generate_reads, but usable when we already hold a FunctionEdit
// rather than a FunctionEditor
pub fn generate_reads_edit(edit: &mut FunctionEdit, typ: TypeID, val: NodeID) -> IndexTree<NodeID> {
    // Start the recursion at the root of the value (empty index list).
    generate_reads_at_index_edit(edit, typ, val, vec![])
}

// Given a node val which at the indices idx has type typ, construct reads of all (leaf)
// fields within this sub-value of val and return the correspondence list
fn generate_reads_at_index_edit(
    edit: &mut FunctionEdit,
    typ: TypeID,
    val: NodeID,
    idx: Vec<Index>,
) -> IndexTree<NodeID> {
    // Copy the field types into an owned Vec first so the type borrow ends
    // before we mutate the edit below.
    let ts: Option<Vec<TypeID>> = edit.get_type(typ).try_product().map(|ts| ts.into());

    if let Some(ts) = ts {
        // For product values, we will recurse down each of its fields with an extended index
        // and the appropriate type of that field
        let mut fields = vec![];
        for (i, t) in ts.into_iter().enumerate() {
            let mut new_idx = idx.clone();
            new_idx.push(Index::Field(i));
            fields.push(generate_reads_at_index_edit(edit, t, val, new_idx));
        }
        IndexTree::Node(fields)
    } else {
        // For non-product types, we've reached a leaf so we generate the read and return its
        // information
        let read_id = edit.add_node(Node::Read {
            collect: val,
            indices: idx.into(),
        });

        IndexTree::Leaf(read_id)
    }
}

// Adds a constant to the function through $editor and evaluates to its
// ConstantID; wraps the edit-closure boilerplate required for each addition.
macro_rules! add_const {
    ($editor:ident, $const:expr) => {{
        let mut res = None;
        $editor.edit(|mut edit| {
            res = Some(edit.add_constant($const));
            Ok(edit)
        });
        res.expect("Add constant cannot fail")
    }};
}

// Given a type, builds a default (zero-initialized) constant of that type and
// returns its ConstantID. Panics for types that have no constant form
// (control, multi-return) and for Float8/BFloat16.
pub fn generate_constant(editor: &mut FunctionEditor, typ: TypeID) -> ConstantID {
    // Clone the type out of the editor so the type borrow is released before
    // we mutate the editor by adding constants below.
    let t = editor.get_type(typ).clone();

    match t {
        // Products: build a default constant per field, then wrap them.
        Type::Product(ts) => {
            let mut cs = vec![];
            for t in ts {
                cs.push(generate_constant(editor, t));
            }
            add_const!(editor, Constant::Product(typ, cs.into()))
        }
        Type::Boolean => add_const!(editor, Constant::Boolean(false)),
        Type::Integer8 => add_const!(editor, Constant::Integer8(0)),
        Type::Integer16 => add_const!(editor, Constant::Integer16(0)),
        Type::Integer32 => add_const!(editor, Constant::Integer32(0)),
        Type::Integer64 => add_const!(editor, Constant::Integer64(0)),
        Type::UnsignedInteger8 => add_const!(editor, Constant::UnsignedInteger8(0)),
        Type::UnsignedInteger16 => add_const!(editor, Constant::UnsignedInteger16(0)),
        Type::UnsignedInteger32 => add_const!(editor, Constant::UnsignedInteger32(0)),
        Type::UnsignedInteger64 => add_const!(editor, Constant::UnsignedInteger64(0)),
        // No default constant is produced for these low-precision floats; this
        // path panics unconditionally.
        Type::Float8 | Type::BFloat16 => panic!(),
        Type::Float32 => add_const!(editor, Constant::Float32(ordered_float::OrderedFloat(0.0))),
        Type::Float64 => add_const!(editor, Constant::Float64(ordered_float::OrderedFloat(0.0))),
        // Summations default to variant 0 holding that variant's default.
        Type::Summation(ts) => {
            let const_id = generate_constant(editor, ts[0]);
            add_const!(editor, Constant::Summation(typ, 0, const_id))
        }
        Type::Array(_, _) => {
            add_const!(editor, Constant::Array(typ))
        }
        Type::Control => panic!("Cannot create constant of control type"),
        Type::MultiReturn(_) => panic!("Cannot create constant of multi-return type"),
    }
}

// Given a constant cnst, adds nodes to the function which are the constant values of each field
// and returns an IndexTree mapping each index list to the node that holds that field.
// Each added leaf node is registered (via sub_edit) as partially replacing old_node.
fn generate_constant_fields(
    editor: &mut FunctionEditor,
    cnst: ConstantID,
    old_node: NodeID,
) -> IndexTree<NodeID> {
    // Copy the field constants into an owned Vec first so the constant borrow
    // ends before we mutate the editor below.
    let cs: Option<Vec<ConstantID>> = editor
        .get_constant(cnst)
        .try_product_fields()
        .map(|cs| cs.into());

    if let Some(cs) = cs {
        // Product constant: recurse into each field.
        let mut fields = vec![];
        for c in cs {
            fields.push(generate_constant_fields(editor, c, old_node));
        }
        IndexTree::Node(fields)
    } else {
        // Leaf constant: materialize it as a constant node and record that it
        // (partially) replaces old_node.
        let mut node = None;
        editor.edit(|mut edit| {
            node = Some(edit.add_node(Node::Constant { id: cnst }));
            edit.sub_edit(old_node, node.unwrap());
            Ok(edit)
        });
        IndexTree::Leaf(node.expect("Add node cannot fail"))
    }
}

// Given a type, return an IndexTree of fresh NodeIDs for its leaf fields, numbered
// consecutively starting from *id; id is advanced past every NodeID handed out.
fn allocate_fields(editor: &FunctionEditor, typ: TypeID, id: &mut usize) -> IndexTree<NodeID> {
    // Copy the field types into an owned Vec first so the type borrow is not
    // held across the recursive calls below.
    let ts: Option<Vec<TypeID>> = editor.get_type(typ).try_product().map(|ts| ts.into());

    if let Some(ts) = ts {
        // Product: allocate ids for each field in order.
        let mut fields = vec![];
        for t in ts {
            fields.push(allocate_fields(editor, t, id));
        }
        IndexTree::Node(fields)
    } else {
        // Leaf: consume the next id from the counter.
        let node = *id;
        *id += 1;
        IndexTree::Leaf(NodeID::new(node))
    }
}