use std::collections::HashMap; use hercules_ir::ir::*; use crate::*; /* * Reuse Products is an optimization pass which identifies when two product * values are identical because each field of the "source" product is read and * then written into the "destination" product and then replaces the destination * product by the source product. * * This pattern can occur in our code because SROA and IP SROA are both * aggressive about breaking products into their fields and reconstructing * products right where needed, so if a function returns a product that is * produced by a call node, these optimizations will produce code that reads the * fields out of the call node and then writes them into the product that is * returned. * * This optimization does not delete any nodes other than the destination nodes, * if other nodes become dead as a result the clean up is left to DCE. * * The analysis for this starts by labeling each product source node (arguments, * constants, and call nodes) with themselves as the source of all of their * fields. Then, these field sources are propagated along read and write nodes. * At the end all nodes with product values are labeled by the source (node and * index) of each of its fields. We then check if any node's fields are exactly * the fields of some other node (i.e. is exactly the same value as some other * node) we replace it with that other node. */ pub fn reuse_products( editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: &Vec<TypeID>, ) { let mut source_nodes = vec![]; let mut read_write_nodes = vec![]; for node in reverse_postorder { match &editor.node(node) { Node::Parameter { .. } | Node::Constant { .. } | Node::Call { .. } if editor.get_type(types[node.idx()]).is_product() => { source_nodes.push(*node) } Node::Write { .. } if editor.get_type(types[node.idx()]).is_product() => { read_write_nodes.push(*node) } Node::Read { collect, .. } if editor.get_type(types[collect.idx()]).is_product() => { read_write_nodes.push(*node) } _ => (), } } let mut product_nodes: HashMap<NodeID, IndexTree<(NodeID, Vec<Index>)>> = HashMap::new(); for source in source_nodes { product_nodes.insert( source, generate_source_info(editor, source, types[source.idx()]), ); } for node in read_write_nodes { match editor.node(node) { Node::Read { collect, indices } => { let Some(collect) = product_nodes.get(collect) else { continue; }; let result = collect.lookup(indices); product_nodes.insert(node, result.clone()); } Node::Write { collect, data, indices, } => { let Some(collect) = product_nodes.get(collect) else { continue; }; let Some(data) = product_nodes.get(data) else { continue; }; let result = collect.clone().replace(indices, data.clone()); product_nodes.insert(node, result); } _ => panic!("Non read/write node"), } } // Note that we don't have to worry about some node A being equivalent to node B but node B // being equivalent to node C and being replaced first causing an issue when we try to replace // node A with B. // This cannot occur since the only nodes something can be equivalent with are the source nodes // and they are all equivalent to precisely themselves which we ignore. for (node, data) in product_nodes { let Some(replace_with) = is_other_product(editor, types, data) else { continue; }; if replace_with != node { editor.edit(|edit| { let edit = edit.replace_all_uses(node, replace_with)?; edit.delete_node(node) }); } } } fn generate_source_info( editor: &FunctionEditor, source: NodeID, typ: TypeID, ) -> IndexTree<(NodeID, Vec<Index>)> { generate_source_info_at_index(editor, source, typ, vec![]) } fn generate_source_info_at_index( editor: &FunctionEditor, source: NodeID, typ: TypeID, idx: Vec<Index>, ) -> IndexTree<(NodeID, Vec<Index>)> { let ts: Option<Vec<TypeID>> = if let Some(ts) = editor.get_type(typ).try_product() { Some(ts.into()) } else { None }; if let Some(ts) = ts { // Recurse on each field with an extended index and appropriate type let mut fields = vec![]; for (i, t) in ts.into_iter().enumerate() { let mut new_idx = idx.clone(); new_idx.push(Index::Field(i)); fields.push(generate_source_info_at_index(editor, source, t, new_idx)); } IndexTree::Node(fields) } else { // We've reached the leaf IndexTree::Leaf((source, idx)) } } fn is_other_product( editor: &FunctionEditor, types: &Vec<TypeID>, node: IndexTree<(NodeID, Vec<Index>)>, ) -> Option<NodeID> { let Some(other_node) = find_only_node(&node) else { return None; }; if matches_fields_index(editor, types[other_node.idx()], &node, vec![]) { Some(other_node) } else { None } } fn find_only_node(tree: &IndexTree<(NodeID, Vec<Index>)>) -> Option<NodeID> { match tree { IndexTree::Leaf((node, _)) => Some(*node), IndexTree::Node(fields) => fields .iter() .map(|t| find_only_node(t)) .reduce(|n, m| match (n, m) { (Some(n), Some(m)) if n == m => Some(n), (_, _) => None, }) .flatten(), } } fn matches_fields_index( editor: &FunctionEditor, typ: TypeID, tree: &IndexTree<(NodeID, Vec<Index>)>, index: Vec<Index>, ) -> bool { match tree { IndexTree::Leaf((_, idx)) => { // If in the original value we still have a product, these can't match if editor.get_type(typ).is_product() { false } else { *idx == index } } IndexTree::Node(fields) => { let ts: Vec<TypeID> = if let Some(ts) = editor.get_type(typ).try_product() { ts.into() } else { return false; }; if fields.len() != ts.len() { return false; } ts.into_iter() .zip(fields.iter()) .enumerate() .all(|(i, (ty, field))| { let mut new_index = index.clone(); new_index.push(Index::Field(i)); matches_fields_index(editor, ty, field, new_index) }) } } }