// reuse_products.rs
use std::collections::HashMap;
use hercules_ir::ir::*;
use crate::*;
/*
* Reuse Products is an optimization pass which identifies when two product
* values are identical because each field of the "source" product is read and
* then written into the "destination" product, and then replaces the
* destination product with the source product.
*
* This pattern can occur in our code because SROA and IP SROA are both
* aggressive about breaking products into their fields and reconstructing
* products right where needed, so if a function returns a product that is
* produced by a call node, these optimizations will produce code that reads the
* fields out of the call node and then writes them into the product that is
* returned.
*
* This optimization does not delete any nodes other than the destination nodes;
* if other nodes become dead as a result, the clean-up is left to DCE.
*
* The analysis for this starts by labeling each product source node (arguments,
* constants, and call nodes) with themselves as the source of all of their
* fields. Then, these field sources are propagated along read and write nodes.
* At the end all nodes with product values are labeled by the source (node and
* index) of each of its fields. We then check if any node's fields are exactly
* the fields of some other node (i.e. is exactly the same value as some other
* node) we replace it with that other node.
*/
pub fn reuse_products(
editor: &mut FunctionEditor,
reverse_postorder: &Vec<NodeID>,
types: &Vec<TypeID>,
) {
let mut source_nodes = vec![];
let mut read_write_nodes = vec![];
for node in reverse_postorder {
match &editor.node(node) {
Node::Parameter { .. } | Node::Constant { .. } | Node::Call { .. }
if editor.get_type(types[node.idx()]).is_product() =>
{
source_nodes.push(*node)
}
Node::Write { .. } if editor.get_type(types[node.idx()]).is_product() => {
read_write_nodes.push(*node)
}
Node::Read { collect, .. } if editor.get_type(types[collect.idx()]).is_product() => {
read_write_nodes.push(*node)
}
_ => (),
}
}
let mut product_nodes: HashMap<NodeID, IndexTree<(NodeID, Vec<Index>)>> = HashMap::new();
for source in source_nodes {
product_nodes.insert(
source,
generate_source_info(editor, source, types[source.idx()]),
);
}
for node in read_write_nodes {
match editor.node(node) {
Node::Read { collect, indices } => {
let Some(collect) = product_nodes.get(collect) else {
continue;
};
let result = collect.lookup(indices);
product_nodes.insert(node, result.clone());
}
Node::Write {
collect,
data,
indices,
} => {
let Some(collect) = product_nodes.get(collect) else {
continue;
};
let Some(data) = product_nodes.get(data) else {
continue;
};
let result = collect.clone().replace(indices, data.clone());
product_nodes.insert(node, result);
}
_ => panic!("Non read/write node"),
}
}
// Note that we don't have to worry about some node A being equivalent to node B but node B
// being equivalent to node C and being replaced first causing an issue when we try to replace
// node A with B.
// This cannot occur since the only nodes something can be equivalent with are the source nodes
// and they are all equivalent to precisely themselves which we ignore.
for (node, data) in product_nodes {
let Some(replace_with) = is_other_product(editor, types, data) else {
continue;
};
if replace_with != node {
editor.edit(|edit| {
let edit = edit.replace_all_uses(node, replace_with)?;
edit.delete_node(node)
});
}
}
}
/// Labels every leaf field of `source` (whose type is `typ`) with `(source, path)`,
/// where `path` is the list of field indices leading from the root value to that leaf.
fn generate_source_info(
    editor: &FunctionEditor,
    source: NodeID,
    typ: TypeID,
) -> IndexTree<(NodeID, Vec<Index>)> {
    // The root of the value is reached by the empty path.
    let root_path = Vec::new();
    generate_source_info_at_index(editor, source, typ, root_path)
}
/// Builds the label tree for the part of `source` reached by `idx`, whose type is `typ`.
///
/// Product types become `IndexTree::Node`s with one child per field (the field's
/// path is `idx` extended by `Index::Field(i)`); any non-product type becomes a
/// `IndexTree::Leaf` holding `(source, idx)`.
fn generate_source_info_at_index(
    editor: &FunctionEditor,
    source: NodeID,
    typ: TypeID,
    idx: Vec<Index>,
) -> IndexTree<(NodeID, Vec<Index>)> {
    // Copy the field types out so no borrow of the type is held while recursing.
    let field_types: Option<Vec<TypeID>> = match editor.get_type(typ).try_product() {
        Some(ts) => Some(ts.into()),
        None => None,
    };

    match field_types {
        // Not a product: this position is a leaf owned by `source`.
        None => IndexTree::Leaf((source, idx)),
        // A product: label each field under its extended path.
        Some(field_types) => IndexTree::Node(
            field_types
                .into_iter()
                .enumerate()
                .map(|(field, field_type)| {
                    let mut field_path = idx.clone();
                    field_path.push(Index::Field(field));
                    generate_source_info_at_index(editor, source, field_type, field_path)
                })
                .collect(),
        ),
    }
}
/// Returns the node whose value is exactly described by the label tree `node`, if any.
///
/// All leaves must come from a single node, and each leaf must sit at precisely the
/// path it was read from in that node's type.
fn is_other_product(
    editor: &FunctionEditor,
    types: &Vec<TypeID>,
    node: IndexTree<(NodeID, Vec<Index>)>,
) -> Option<NodeID> {
    let candidate = find_only_node(&node)?;
    matches_fields_index(editor, types[candidate.idx()], &node, vec![]).then_some(candidate)
}
/// Returns the single node that every leaf of `tree` is labeled with.
///
/// Returns `None` if the leaves come from different nodes, or if `tree` contains a
/// `Node` with no children (nothing to agree on).
fn find_only_node(tree: &IndexTree<(NodeID, Vec<Index>)>) -> Option<NodeID> {
    match tree {
        IndexTree::Leaf((node, _)) => Some(*node),
        IndexTree::Node(fields) => {
            let mut shared: Option<NodeID> = None;
            for field in fields {
                // Any child without a unique node makes the whole tree ambiguous.
                let this = find_only_node(field)?;
                match shared {
                    None => shared = Some(this),
                    Some(previous) if previous != this => return None,
                    Some(_) => {}
                }
            }
            shared
        }
    }
}
/// Checks that `tree` mirrors the field structure of type `typ` exactly, with every
/// leaf labeled by the path (`index` extended by field positions) at which it appears.
fn matches_fields_index(
    editor: &FunctionEditor,
    typ: TypeID,
    tree: &IndexTree<(NodeID, Vec<Index>)>,
    index: Vec<Index>,
) -> bool {
    match tree {
        IndexTree::Leaf((_, idx)) => {
            // A leaf standing where the original value still has a product cannot match.
            let is_product = editor.get_type(typ).is_product();
            !is_product && *idx == index
        }
        IndexTree::Node(fields) => {
            let field_types: Vec<TypeID> = match editor.get_type(typ).try_product() {
                Some(ts) => ts.into(),
                None => return false,
            };
            field_types.len() == fields.len()
                && fields
                    .iter()
                    .zip(field_types)
                    .enumerate()
                    .all(|(position, (field, field_type))| {
                        let mut field_index = index.clone();
                        field_index.push(Index::Field(position));
                        matches_fields_index(editor, field_type, field, field_index)
                    })
        }
    }
}