Skip to content
Snippets Groups Projects
reuse_products.rs 6.91 KiB
use std::collections::HashMap;

use hercules_ir::ir::*;

use crate::*;

/*
 * Reuse Products is an optimization pass which identifies when two product
 * values are identical because each field of the "source" product is read and
 * then written into the "destination" product and then replaces the destination
 * product by the source product.
 *
 * This pattern can occur in our code because SROA and IP SROA are both
 * aggressive about breaking products into their fields and reconstructing
 * products right where needed, so if a function returns a product that is
 * produced by a call node, these optimizations will produce code that reads the
 * fields out of the call node and then writes them into the product that is
 * returned.
 *
 * This optimization does not delete any nodes other than the destination nodes,
 * if other nodes become dead as a result the clean up is left to DCE.
 *
 * The analysis for this starts by labeling each product source node (arguments,
 * constants, and call nodes) with themselves as the source of all of their
 * fields. Then, these field sources are propagated along read and write nodes.
 * At the end all nodes with product values are labeled by the source (node and
 * index) of each of its fields. We then check if any node's fields are exactly
 * the fields of some other node (i.e. is exactly the same value as some other
 * node) we replace it with that other node.
 */
pub fn reuse_products(
    editor: &mut FunctionEditor,
    reverse_postorder: &Vec<NodeID>,
    types: &Vec<TypeID>,
) {
    let mut source_nodes = vec![];
    let mut read_write_nodes = vec![];

    for node in reverse_postorder {
        match &editor.node(node) {
            Node::Parameter { .. } | Node::Constant { .. } | Node::Call { .. }
                if editor.get_type(types[node.idx()]).is_product() =>
            {
                source_nodes.push(*node)
            }
            Node::Write { .. } if editor.get_type(types[node.idx()]).is_product() => {
                read_write_nodes.push(*node)
            }
            Node::Read { collect, .. } if editor.get_type(types[collect.idx()]).is_product() => {
                read_write_nodes.push(*node)
            }
            _ => (),
        }
    }

    let mut product_nodes: HashMap<NodeID, IndexTree<(NodeID, Vec<Index>)>> = HashMap::new();

    for source in source_nodes {
        product_nodes.insert(
            source,
            generate_source_info(editor, source, types[source.idx()]),
        );
    }

    for node in read_write_nodes {
        match editor.node(node) {
            Node::Read { collect, indices } => {
                let Some(collect) = product_nodes.get(collect) else {
                    continue;
                };
                let result = collect.lookup(indices);
                product_nodes.insert(node, result.clone());
            }
            Node::Write {
                collect,
                data,
                indices,
            } => {
                let Some(collect) = product_nodes.get(collect) else {
                    continue;
                };
                let Some(data) = product_nodes.get(data) else {
                    continue;
                };
                let result = collect.clone().replace(indices, data.clone());
                product_nodes.insert(node, result);
            }
            _ => panic!("Non read/write node"),
        }
    }

    // Note that we don't have to worry about some node A being equivalent to node B but node B
    // being equivalent to node C and being replaced first causing an issue when we try to replace
    // node A with B.
    // This cannot occur since the only nodes something can be equivalent with are the source nodes
    // and they are all equivalent to precisely themselves which we ignore.
    for (node, data) in product_nodes {
        let Some(replace_with) = is_other_product(editor, types, data) else {
            continue;
        };

        if replace_with != node {
            editor.edit(|edit| {
                let edit = edit.replace_all_uses(node, replace_with)?;
                edit.delete_node(node)
            });
        }
    }
}

/*
 * Builds the field-source tree for a source node: every leaf of type `typ` is
 * labeled with `source` together with the field-index path leading to it,
 * starting from the empty (root) path.
 */
fn generate_source_info(
    editor: &FunctionEditor,
    source: NodeID,
    typ: TypeID,
) -> IndexTree<(NodeID, Vec<Index>)> {
    let root_path: Vec<Index> = Vec::new();
    generate_source_info_at_index(editor, source, typ, root_path)
}

/*
 * Recursively builds the field-source tree for `source` below the path `idx`.
 * Product types become interior nodes with one child per field (the child's
 * path is `idx` extended by that field index); any non-product type becomes a
 * leaf labeled `(source, idx)`.
 */
fn generate_source_info_at_index(
    editor: &FunctionEditor,
    source: NodeID,
    typ: TypeID,
    idx: Vec<Index>,
) -> IndexTree<(NodeID, Vec<Index>)> {
    // Copy the field types out so the type borrow ends before recursing.
    let field_types: Option<Vec<TypeID>> = editor
        .get_type(typ)
        .try_product()
        .map(|fields| fields.to_vec());

    match field_types {
        None => IndexTree::Leaf((source, idx)),
        Some(field_types) => IndexTree::Node(
            field_types
                .into_iter()
                .enumerate()
                .map(|(field, field_type)| {
                    let mut path = idx.clone();
                    path.push(Index::Field(field));
                    generate_source_info_at_index(editor, source, field_type, path)
                })
                .collect(),
        ),
    }
}

/*
 * If every leaf of the field-source tree `node` comes from one single node, and
 * each leaf sits at exactly the index path it occupies in that node's type,
 * then `node` is the same value as that node and it is returned. Otherwise
 * returns None.
 */
fn is_other_product(
    editor: &FunctionEditor,
    types: &Vec<TypeID>,
    node: IndexTree<(NodeID, Vec<Index>)>,
) -> Option<NodeID> {
    let candidate = find_only_node(&node)?;
    let candidate_type = types[candidate.idx()];
    matches_fields_index(editor, candidate_type, &node, Vec::new()).then_some(candidate)
}

/*
 * Returns the node that every leaf of `tree` is sourced from, or None if the
 * leaves come from different nodes or if any subtree has no leaves at all
 * (an interior node with zero fields).
 */
fn find_only_node(tree: &IndexTree<(NodeID, Vec<Index>)>) -> Option<NodeID> {
    let fields = match tree {
        IndexTree::Leaf((node, _)) => return Some(*node),
        IndexTree::Node(fields) => fields,
    };

    // An empty field list has no source node.
    let (first, rest) = fields.split_first()?;
    let candidate = find_only_node(first)?;
    for field in rest {
        if find_only_node(field)? != candidate {
            return None;
        }
    }
    Some(candidate)
}

/*
 * Checks that `tree` has exactly the shape of type `typ` and that each leaf is
 * labeled with the full index path (`index` plus the field indices below it)
 * at which it appears. A leaf only matches a non-product position; an interior
 * node only matches a product type with the same number of fields.
 */
fn matches_fields_index(
    editor: &FunctionEditor,
    typ: TypeID,
    tree: &IndexTree<(NodeID, Vec<Index>)>,
    index: Vec<Index>,
) -> bool {
    let fields = match tree {
        // If the original value still has a product here, the leaf can't match.
        IndexTree::Leaf((_, idx)) => {
            return !editor.get_type(typ).is_product() && *idx == index;
        }
        IndexTree::Node(fields) => fields,
    };

    let Some(field_types) = editor.get_type(typ).try_product().map(|ts| ts.to_vec()) else {
        return false;
    };

    field_types.len() == fields.len()
        && field_types
            .iter()
            .zip(fields)
            .enumerate()
            .all(|(position, (&field_type, field))| {
                let mut field_index = index.clone();
                field_index.push(Index::Field(position));
                matches_fields_index(editor, field_type, field, field_index)
            })
}