// inline.rs 17.43 KiB
use std::cell::Ref;
use std::collections::HashMap;
use hercules_ir::*;
use crate::*;
/*
 * Top level function to run inlining. Currently, inlines every function call,
 * since mutual recursion is not valid in Hercules IR.
 *
 * `editors` is indexed by function ID. `callgraph` supplies both the
 * topological order in which functions are processed and the callees of each
 * function. If `editors` is empty there is nothing to inline and the function
 * returns without doing anything further.
 */
pub fn inline(editors: &mut [FunctionEditor], callgraph: &CallGraph) {
    // Step 1: run topological sort on the call graph to inline the "deepest"
    // function first.
    let topo = callgraph.topo();

    // Step 2: make sure each function has a single return node. If an edit
    // failed to make a function have a single return node, then we can't inline
    // calls of it.
    let single_return_nodes: Vec<_> = editors
        .iter_mut()
        .map(|editor| collapse_returns(editor))
        .collect();

    // Step 3: get dynamic constant IDs for parameters. `max()` yields `None`
    // when there are no functions at all; bail out in that case rather than
    // unwrapping it and indexing `editors[0]` below.
    let Some(max_num_dc_params) = editors
        .iter()
        .map(|editor| editor.func().num_dynamic_constants)
        .max()
    else {
        return;
    };
    // NOTE(review): the IDs are created through the first editor only and then
    // used for every function - presumably the dynamic constant table is shared
    // between editors; confirm against FunctionEditor.
    let mut dc_args = vec![];
    editors[0].edit(|mut edit| {
        dc_args = (0..max_num_dc_params as usize)
            .map(|i| edit.add_dynamic_constant(DynamicConstant::Parameter(i)))
            .collect();
        Ok(edit)
    });

    // Step 4: run inlining on each function individually. Iterate the functions
    // in topological order.
    for to_inline_id in topo {
        // Since Rust cannot analyze the accesses into an array of mutable
        // references, we need to do some weirdness here to simultaneously get:
        // 1. A mutable reference to the function we're modifying.
        // 2. Shared references to all of the functions called by that function.
        let callees = callgraph.get_callees(to_inline_id);
        let (editor, called) = get_mut_and_immuts(editors, to_inline_id, callees);
        inline_func(editor, called, &single_return_nodes, &dc_args);
    }
}
/*
 * Splits one mutable slice into non-aliasing references:
 * 1. An exclusive reference to the element at `mut_id`.
 * 2. Shared references to the elements named in `shared_id`, keyed by ID.
 * This is needed both for function editors and schedules.
 *
 * The requested IDs are visited in ascending order, and at each step the
 * wanted element is carved off the front of the not-yet-visited tail, so that
 * every returned reference points into a disjoint piece of the slice.
 *
 * Panics if `mut_id` also appears in `shared_id`, or if any ID is out of
 * bounds for `mut_refs`.
 */
fn get_mut_and_immuts<'a, T, I: ID>(
    mut_refs: &'a mut [T],
    mut_id: I,
    shared_id: &[I],
) -> (&'a mut T, HashMap<I, &'a T>) {
    // Sorted list of every requested ID, with the exclusive one slotted in.
    // `unwrap_err` is the check that the exclusive ID is not also shared.
    let mut ordered = shared_id.to_vec();
    ordered.sort_unstable();
    let slot = ordered.binary_search(&mut_id).unwrap_err();
    ordered.insert(slot, mut_id);

    let mut exclusive = None;
    let mut shared = HashMap::new();
    // Index (in the original slice) of the first element of `remaining`.
    let mut consumed = 0;
    let mut remaining = &mut *mut_refs;
    for id in ordered {
        // Skip the elements nobody asked for, then peel off exactly one.
        let (_, tail) = remaining.split_at_mut(id.idx() - consumed);
        let (head, tail) = tail.split_at_mut(1);
        consumed = id.idx() + 1;
        let item = &mut head[0];
        if id == mut_id {
            assert!(exclusive.is_none());
            exclusive = Some(item);
        } else {
            shared.insert(id, &*item);
        }
        remaining = tail;
    }
    (exclusive.unwrap(), shared)
}
/*
 * Run inlining on a single function. Pass a mutable reference to the function
 * to modify and shared references for all called functions.
 *
 * - `editor`: editor of the function whose call nodes get replaced.
 * - `called`: editors of the callees, keyed by function ID. Every function
 *   called from `editor` must be present.
 * - `single_return_nodes`: indexed by function ID; the single return node of
 *   each function, or `None` if it has none, in which case calls to that
 *   function are left untouched.
 * - `dc_param_idx_to_dc_id`: maps a dynamic constant parameter index to the ID
 *   of the corresponding `DynamicConstant::Parameter`.
 *
 * Only nodes that exist when this function is entered are scanned, so nodes
 * copied in from callees are not visited again.
 *
 * Panics if a call's control node does not have exactly one use, or if a user
 * of a call node is not a data projection.
 */
fn inline_func(
    editor: &mut FunctionEditor,
    called: HashMap<FunctionID, &FunctionEditor>,
    single_return_nodes: &[Option<NodeID>],
    dc_param_idx_to_dc_id: &[DynamicConstantID],
) {
    let first_num_nodes = editor.func().nodes.len();
    for id in (0..first_num_nodes).map(NodeID::new) {
        // Break down the call node.
        let Node::Call {
            control,
            function,
            ref dynamic_constants,
            ref args,
        } = editor.func().nodes[id.idx()]
        else {
            continue;
        };

        // Assemble all the info we'll need to do the edit.
        // `substs` maps each dynamic constant parameter of the callee to the
        // dynamic constant supplied at this call site.
        let dcs_a = &dc_param_idx_to_dc_id[..dynamic_constants.len()];
        let substs = dcs_a
            .iter()
            .copied()
            .zip(dynamic_constants.iter().copied())
            .collect::<HashMap<_, _>>();
        let args = args.clone();
        // Nodes of the callee are appended after the caller's current nodes,
        // so a callee node ID translates to a caller node ID by this offset.
        let old_num_nodes = editor.func().nodes.len();
        let old_id_to_new_id = |old_id: NodeID| NodeID::new(old_id.idx() + old_num_nodes);
        let call_pred = get_uses(&editor.func().nodes[control.idx()]);
        assert_eq!(call_pred.as_ref().len(), 1);
        let call_pred = call_pred.as_ref()[0];
        let called_func = called[&function].func();
        let call_users = editor.get_users(id);
        let call_projs = call_users
            .map(|node_id| {
                (
                    node_id,
                    editor.func().nodes[node_id.idx()]
                        .try_data_proj()
                        .expect("PANIC: Call user is not a data projection")
                        .1,
                )
            })
            .collect::<Vec<_>>();

        // We can't inline calls to functions with multiple returns.
        let Some(called_return) = single_return_nodes[function.idx()] else {
            continue;
        };
        // The first use of the return node is its control predecessor, the
        // remaining uses are the returned data values.
        let called_return_uses = get_uses(&called_func.nodes[called_return.idx()]);
        let called_return_pred = called_return_uses.as_ref()[0];
        let called_return_data = &called_return_uses.as_ref()[1..];

        // Perform the actual edit.
        editor.edit(|mut edit| {
            // Insert the nodes from the called function. There are a few
            // special cases:
            // - Start: don't add start nodes - later, we'll replace_all_uses on
            //   the start node with the one predecessor of the call's region
            //   node.
            // - Parameter: don't add parameter nodes - later, we'll
            //   replace_all_uses on the parameter nodes with the arguments to
            //   the call node.
            // - Return: don't add return nodes - later, we'll replace_all_uses
            //   on the call's region node with the predecessor to the return
            //   node.
            for (idx, node) in called_func.nodes.iter().enumerate() {
                if node.is_start() || node.is_parameter() || node.is_return() {
                    // We still need to add some node to make sure the IDs line
                    // up. Just add a gravestone.
                    edit.add_node(Node::Start);
                    continue;
                }

                // Get the node from the callee function and replace all the
                // uses with the to-be IDs in the caller function.
                let mut node = node.clone();
                if node.is_fork()
                    || node.is_constant()
                    || node.is_dynamic_constant()
                    || node.is_call()
                {
                    substitute_dynamic_constants_in_node(&substs, &mut node, &mut edit);
                }
                let mut uses = get_uses_mut(&mut node);
                for u in uses.as_mut() {
                    **u = old_id_to_new_id(**u);
                }

                // Add the node and check that the IDs line up.
                let add_id = edit.add_node(node);
                assert_eq!(add_id, old_id_to_new_id(NodeID::new(idx)));

                // Copy the schedule from the callee.
                let callee_schedule = &called_func.schedules[idx];
                for schedule in callee_schedule {
                    edit = edit.add_schedule(add_id, schedule.clone())?;
                }

                // Copy the labels from the callee.
                let callee_labels = &called_func.labels[idx];
                for label in callee_labels {
                    edit = edit.add_label(add_id, *label)?;
                }
            }

            // Stitch the control use of the inlined start node with the
            // predecessor control node of the call's region.
            let start_node = &called_func.nodes[0];
            assert!(start_node.is_start());
            let start_id = old_id_to_new_id(NodeID::new(0));
            edit = edit.replace_all_uses(start_id, call_pred)?;

            // Stitch the control use of the original call node's region with
            // the predecessor control of the inlined function's return. If the
            // return hangs directly off the callee's start node, there is no
            // inlined control node to point at, so use the call's own
            // predecessor instead.
            edit = edit.replace_all_uses(
                control,
                if called_return_pred == NodeID::new(0) {
                    call_pred
                } else {
                    old_id_to_new_id(called_return_pred)
                },
            )?;

            // Replace and delete the call's (data projection) users. This must
            // happen before the parameter stitching below: a returned value may
            // itself be a parameter, whose placeholder is only redirected to
            // the call argument in the next step.
            for (proj_id, proj_idx) in call_projs {
                let proj_val = called_return_data[proj_idx];
                edit = edit.replace_all_uses(proj_id, old_id_to_new_id(proj_val))?;
                edit = edit.delete_node(proj_id)?;
            }

            // Stitch uses of parameter nodes in the inlined function to the IDs
            // of arguments provided to the call node.
            for (node_idx, node) in called_func.nodes.iter().enumerate() {
                if let Node::Parameter { index } = node {
                    let param_id = old_id_to_new_id(NodeID::new(node_idx));
                    edit = edit.replace_all_uses(param_id, args[*index])?;
                }
            }

            // Finally delete the call node and its control node.
            edit = edit.delete_node(control)?;
            edit = edit.delete_node(id)?;
            Ok(edit)
        });
    }
}
// Lattice value describing what is known about the argument passed for one
// parameter, accumulated over call sites with `meet`. `Top` is the starting
// point and `Bottom` is absorbing.
#[derive(Clone, Debug, Copy, PartialEq, Eq)]
enum ParameterLattice {
// Nothing constrains the parameter yet. Also what an `Undef` argument maps to.
Top,
// The argument is this constant.
Constant(ConstantID),
// The argument is this dynamic constant. The function ID is whatever was
// handed to `from_node` when the value was created.
DynamicConstant(DynamicConstantID, FunctionID),
// The argument is not a (single) constant or dynamic constant.
Bottom,
}
impl ParameterLattice {
    /*
     * Classify the node passed as an argument: undef nodes map to `Top`,
     * constant and dynamic constant nodes map to the matching variant (the
     * latter tagged with `func_id`), and everything else maps to `Bottom`.
     */
    fn from_node(node: &Node, func_id: FunctionID) -> Self {
        use ParameterLattice::*;
        match node {
            Node::Undef { .. } => Top,
            Node::Constant { id } => Constant(*id),
            Node::DynamicConstant { id } => DynamicConstant(*id, func_id),
            _ => Bottom,
        }
    }

    /*
     * Combine `self` with `b` in place.
     *
     * - `Top` is the identity and `Bottom` is absorbing.
     * - Two constants agree only if they have the same ID.
     * - Two dynamic constants agree if they have the same ID and function, or
     *   if both are `DynamicConstant::Constant` with the same value.
     * - A constant and a dynamic constant agree only if the constant is an
     *   `UnsignedInteger64` and the dynamic constant is a
     *   `DynamicConstant::Constant` of the same value; the result is then the
     *   constant.
     * - Anything else becomes `Bottom`.
     *
     * `cons` and `dcs` are the constant and dynamic constant tables that the
     * IDs stored in the lattice values index into.
     */
    fn meet(&mut self, b: Self, cons: Ref<'_, Vec<Constant>>, dcs: Ref<'_, Vec<DynamicConstant>>) {
        use ParameterLattice::*;
        *self = match (*self, b) {
            (Top, b) => b,
            (a, Top) => a,
            (Bottom, _) | (_, Bottom) => Bottom,
            (Constant(id_a), Constant(id_b)) => {
                if id_a == id_b {
                    Constant(id_a)
                } else {
                    Bottom
                }
            }
            (DynamicConstant(dc_a, f_a), DynamicConstant(dc_b, f_b)) => {
                if dc_a == dc_b && f_a == f_b {
                    DynamicConstant(dc_a, f_a)
                } else if let (
                    ir::DynamicConstant::Constant(dcv_a),
                    ir::DynamicConstant::Constant(dcv_b),
                ) = (&dcs[dc_a.idx()], &dcs[dc_b.idx()])
                    && *dcv_a == *dcv_b
                {
                    DynamicConstant(dc_a, f_a)
                } else {
                    Bottom
                }
            }
            (DynamicConstant(dc, _), Constant(con)) | (Constant(con), DynamicConstant(dc, _)) => {
                match (&cons[con.idx()], &dcs[dc.idx()]) {
                    // Compare in 64 bits: widening the `usize` is lossless,
                    // whereas narrowing the constant to `usize` could make two
                    // different values compare equal where `usize` is smaller
                    // than 64 bits.
                    (ir::Constant::UnsignedInteger64(conv), ir::DynamicConstant::Constant(dcv))
                        if *conv == *dcv as u64 =>
                    {
                        Constant(con)
                    }
                    _ => Bottom,
                }
            }
        }
    }
}
/*
 * Top level function to inline constant parameters and constant dynamic
 * constant parameters. Identifies functions that are:
 *
 * 1. Not marked as entry.
 * 2. At every call site, a particular parameter is always a specific constant
 *    or dynamic constant.
 *
 * These functions can have that constant "inlined" - the parameter is removed
 * and all uses of the parameter become uses of the constant directly.
 *
 * `editors` is indexed by function ID. When `inline_collections` is false, only
 * parameters whose type is primitive are considered. Parameters that have no
 * parameter node in the function body are left in place.
 *
 * Panics if updating a call site fails after the callee was already rewritten.
 */
pub fn const_inline(
    editors: &mut [FunctionEditor],
    callgraph: &CallGraph,
    inline_collections: bool,
) {
    // Run const inlining on each function, starting at the most shallow
    // function first, since we want to propagate constants down the call graph.
    for func_id in callgraph.topo().into_iter().rev() {
        let func = editors[func_id.idx()].func();
        if func.entry || callgraph.num_callers(func_id) == 0 {
            continue;
        }

        // Figure out what we know about the parameters to this function.
        let mut param_lattice = vec![ParameterLattice::Top; func.param_types.len()];
        let mut callers = vec![];
        for caller in callgraph.get_callers(func_id) {
            let editor = &editors[caller.idx()];
            let nodes = &editor.func().nodes;
            for id in editor.node_ids() {
                if let Some((_, callee, _, args)) = nodes[id.idx()].try_call()
                    && callee == func_id
                {
                    if editor.is_mutable(id) {
                        for (idx, id) in args.into_iter().enumerate() {
                            // NOTE(review): the lattice value is tagged with
                            // the callee, which is the same for every call
                            // site examined here - confirm whether the caller
                            // was intended.
                            let lattice = ParameterLattice::from_node(&nodes[id.idx()], callee);
                            param_lattice[idx].meet(
                                lattice,
                                editor.get_constants(),
                                editor.get_dynamic_constants(),
                            );
                        }
                    } else {
                        // If we can't modify the call node in the caller, then
                        // we can't perform the inlining.
                        param_lattice = vec![ParameterLattice::Bottom; func.param_types.len()];
                    }
                    callers.push((caller, id));
                }
            }
        }
        if param_lattice.iter().all(|v| *v == ParameterLattice::Bottom) {
            continue;
        }

        // Replace the arguments.
        let editor = &mut editors[func_id.idx()];
        let mut param_idx_to_ids: HashMap<usize, Vec<NodeID>> = HashMap::new();
        for id in editor.node_ids() {
            if let Some(idx) = editor.func().nodes[id.idx()].try_parameter() {
                param_idx_to_ids.entry(idx).or_default().push(id);
            }
        }
        let mut params_to_remove = vec![];
        let success = editor.edit(|mut edit| {
            let mut param_tys = edit.get_param_types().clone();
            // Number of parameters removed so far; `idx - decrement_index_by`
            // is the position of original parameter `idx` in `param_tys`.
            let mut decrement_index_by = 0;
            for idx in 0..param_tys.len() {
                if (inline_collections
                    || edit
                        .get_type(param_tys[idx - decrement_index_by])
                        .is_primitive())
                    && let Some(node) = match param_lattice[idx] {
                        ParameterLattice::Top => Some(Node::Undef {
                            ty: param_tys[idx - decrement_index_by],
                        }),
                        ParameterLattice::Constant(id) => Some(Node::Constant { id }),
                        ParameterLattice::DynamicConstant(id, _) => {
                            // Rust moment.
                            let maybe_cons = edit.get_dynamic_constant(id).try_constant();
                            if let Some(val) = maybe_cons {
                                Some(Node::DynamicConstant {
                                    id: edit.add_dynamic_constant(DynamicConstant::Constant(val)),
                                })
                            } else {
                                None
                            }
                        }
                        _ => None,
                    }
                    && let Some(ids) = param_idx_to_ids.get(&idx)
                {
                    // The parameter is known: substitute the value for every
                    // parameter node and drop the parameter.
                    let node = edit.add_node(node);
                    for id in ids {
                        edit = edit.replace_all_uses(*id, node)?;
                        edit = edit.delete_node(*id)?;
                    }
                    param_tys.remove(idx - decrement_index_by);
                    params_to_remove.push(idx);
                    decrement_index_by += 1;
                } else if decrement_index_by != 0
                    && let Some(ids) = param_idx_to_ids.get(&idx)
                {
                    // The parameter stays, but earlier parameters were removed,
                    // so it needs to be renumbered.
                    let node = edit.add_node(Node::Parameter {
                        index: idx - decrement_index_by,
                    });
                    for id in ids {
                        edit = edit.replace_all_uses(*id, node)?;
                        edit = edit.delete_node(*id)?;
                    }
                }
            }
            edit.set_param_types(param_tys);
            Ok(edit)
        });
        // Remove arguments from the back so that earlier indices stay valid.
        params_to_remove.reverse();

        // Update callers. If no parameter was actually removed, the signature
        // is unchanged and the call sites need no rewriting.
        if success && !params_to_remove.is_empty() {
            for (caller, call) in callers {
                let editor = &mut editors[caller.idx()];
                let success = editor.edit(|mut edit| {
                    let Node::Call {
                        control,
                        function,
                        dynamic_constants,
                        args,
                    } = edit.get_node(call).clone()
                    else {
                        panic!("PANIC: Recorded call site is not a call node");
                    };
                    let mut args = args.into_vec();
                    for idx in params_to_remove.iter() {
                        args.remove(*idx);
                    }
                    let node = edit.add_node(Node::Call {
                        control,
                        function,
                        dynamic_constants,
                        args: args.into_boxed_slice(),
                    });
                    edit = edit.replace_all_uses(call, node)?;
                    edit = edit.delete_node(call)?;
                    Ok(edit)
                });
                // The callee has already lost the parameters, so a call site
                // that still passes them would leave the IR inconsistent.
                assert!(success);
            }
        }
    }
}