// Skip to content
// Snippets Groups Projects
// inline.rs 17.43 KiB
use std::cell::Ref;
use std::collections::HashMap;

use hercules_ir::*;

use crate::*;

/*
 * Top level function to run inlining. Currently, inlines every function call
 * whose callee could be collapsed to a single return node, since mutual
 * recursion is not valid in Hercules IR.
 *
 * `editors` holds one editor per function and is indexed by function ID (the
 * lookups below index it with `FunctionID`s). `callgraph` supplies both the
 * processing order and the callees of each function. Does nothing when
 * `editors` is empty.
 */
pub fn inline(editors: &mut [FunctionEditor], callgraph: &CallGraph) {
    // With no functions there is nothing to inline. Bail out before the steps
    // below, which need at least one editor (`editors[0]`).
    if editors.is_empty() {
        return;
    }

    // Step 1: run topological sort on the call graph to inline the "deepest"
    // function first.
    let topo = callgraph.topo();

    // Step 2: make sure each function has a single return node. If an edit
    // failed to make a function have a single return node, then we can't inline
    // calls of it (its entry is `None`).
    let single_return_nodes: Vec<_> = editors
        .iter_mut()
        .map(|editor| collapse_returns(editor))
        .collect();

    // Step 3: get dynamic constant IDs for parameters. `dc_args[i]` is the ID
    // of `DynamicConstant::Parameter(i)`, for every parameter index used by
    // any function.
    // NOTE(review): the IDs are created through `editors[0]` but used for all
    // editors - presumably dynamic constants are shared module-wide; confirm.
    let max_num_dc_params = editors
        .iter()
        .map(|editor| editor.func().num_dynamic_constants)
        .max()
        .unwrap_or(0);
    let mut dc_args = vec![];
    editors[0].edit(|mut edit| {
        dc_args = (0..max_num_dc_params as usize)
            .map(|i| edit.add_dynamic_constant(DynamicConstant::Parameter(i)))
            .collect();
        Ok(edit)
    });

    // Step 4: run inlining on each function individually. Iterate the functions
    // in topological order.
    for to_inline_id in topo {
        // Since Rust cannot analyze the accesses into an array of mutable
        // references, we need to do some weirdness here to simultaneously get:
        // 1. A mutable reference to the function we're modifying.
        // 2. Shared references to all of the functions called by that function.
        let callees = callgraph.get_callees(to_inline_id);
        let (caller_editor, callee_editors) = get_mut_and_immuts(editors, to_inline_id, callees);
        inline_func(caller_editor, callee_editors, &single_return_nodes, &dc_args);
    }
}

/*
 * Helper function to get from an array of mutable references:
 * 1. A single mutable reference (the element at index `mut_id.idx()`).
 * 2. Several shared references (the elements at the indices in `shared_id`),
 *    returned in a map keyed by ID.
 * Where none of the references alias. We need to use this both for function
 * editors and schedules.
 *
 * Repeated IDs in `shared_id` are tolerated and yield a single map entry.
 *
 * Panics if `mut_id` also appears in `shared_id` (the references would alias)
 * or if any ID's index is out of bounds for `mut_refs`.
 */
fn get_mut_and_immuts<'a, T, I: ID>(
    mut_refs: &'a mut [T],
    mut_id: I,
    shared_id: &[I],
) -> (&'a mut T, HashMap<I, &'a T>) {
    // Visit the requested IDs in ascending index order so the slice can be
    // carved up front-to-back. Duplicates are dropped first: the walk below
    // consumes each element at most once.
    let mut all_id = Vec::from(shared_id);
    all_id.sort_unstable();
    all_id.dedup();
    let mut_pos = all_id
        .binary_search(&mut_id)
        .expect_err("PANIC: the mutably borrowed ID was also requested as a shared ID");
    all_id.insert(mut_pos, mut_id);

    let mut mut_ref = None;
    let mut shared_refs = HashMap::new();
    // `cursor` is the index (in the original slice) of the first element of
    // `slice`, the part that has not been handed out yet.
    let mut cursor = 0;
    let mut slice = &mut *mut_refs;
    for id in all_id {
        // Skip the elements before `id`, then peel off the element itself.
        let (left, right) = slice.split_at_mut(id.idx() - cursor);
        cursor += left.len() + 1;
        let (left, right) = right.split_at_mut(1);
        let item = &mut left[0];
        if id == mut_id {
            assert!(mut_ref.is_none());
            mut_ref = Some(item);
        } else {
            shared_refs.insert(id, &*item);
        }
        slice = right;
    }

    (mut_ref.unwrap(), shared_refs)
}

/*
 * Run inlining on a single function. Pass a mutable reference to the function
 * to modify and shared references for all called functions.
 *
 * - `called` maps each callee's function ID to its editor.
 * - `single_return_nodes` is indexed by function ID and holds the callee's
 *   unique return node, or `None` if the callee can't be inlined.
 * - `dc_param_idx_to_dc_id[i]` is the dynamic constant ID standing for the
 *   i-th dynamic constant parameter; it is substituted by the i-th dynamic
 *   constant argument of each inlined call.
 *
 * Panics if an inlined call has a user that is not a data projection, or if
 * the call's control node does not have exactly one use.
 */
fn inline_func(
    editor: &mut FunctionEditor,
    called: HashMap<FunctionID, &FunctionEditor>,
    single_return_nodes: &[Option<NodeID>],
    dc_param_idx_to_dc_id: &[DynamicConstantID],
) {
    // Only visit the nodes that exist before any inlining happens; nodes
    // copied in from callees are appended after them and are not revisited.
    let first_num_nodes = editor.func().nodes.len();
    for id in (0..first_num_nodes).map(NodeID::new) {
        // Break down the call node.
        let Node::Call {
            control,
            function,
            ref dynamic_constants,
            ref args,
        } = editor.func().nodes[id.idx()]
        else {
            continue;
        };

        // We can't inline calls to functions with multiple returns. Check
        // this first so that calls we leave alone cost nothing further.
        let Some(called_return) = single_return_nodes[function.idx()] else {
            continue;
        };

        // Assemble all the info we'll need to do the edit.
        // Map each dynamic constant parameter of the callee to the dynamic
        // constant passed for it at this call site.
        let substs = dc_param_idx_to_dc_id[..dynamic_constants.len()]
            .iter()
            .copied()
            .zip(dynamic_constants.iter().copied())
            .collect::<HashMap<_, _>>();
        let args = args.clone();
        // Callee node `i` will live at `i + old_num_nodes` in the caller.
        let old_num_nodes = editor.func().nodes.len();
        let old_id_to_new_id = |old_id: NodeID| NodeID::new(old_id.idx() + old_num_nodes);
        let call_pred = get_uses(&editor.func().nodes[control.idx()]);
        assert_eq!(call_pred.as_ref().len(), 1);
        let call_pred = call_pred.as_ref()[0];
        let called_func = called[&function].func();
        let call_users = editor.get_users(id);
        // Pair every user of the call with the index of the return value it
        // projects out.
        let call_projs = call_users
            .map(|node_id| {
                (
                    node_id,
                    editor.func().nodes[node_id.idx()]
                        .try_data_proj()
                        .expect("PANIC: Call user is not a data projection")
                        .1,
                )
            })
            .collect::<Vec<_>>();
        // The return node's first use is its control predecessor; the rest
        // are the returned data values.
        let called_return_uses = get_uses(&called_func.nodes[called_return.idx()]);
        let called_return_pred = called_return_uses.as_ref()[0];
        let called_return_data = &called_return_uses.as_ref()[1..];

        // Perform the actual edit.
        editor.edit(|mut edit| {
            // Insert the nodes from the called function. There are a few
            // special cases:
            // - Start: don't add start nodes - later, we'll replace_all_uses on
            //   the start node with the one predecessor of the call's region
            //   node.
            // - Parameter: don't add parameter nodes - later, we'll
            //   replace_all_uses on the parameter nodes with the arguments to
            //   the call node.
            // - Return: don't add return nodes - later, we'll replace_all_uses
            //   on the call's region node with the predecessor to the return
            //   node.
            for (idx, node) in called_func.nodes.iter().enumerate() {
                if node.is_start() || node.is_parameter() || node.is_return() {
                    // We still need to add some node to make sure the IDs line
                    // up. Just add a gravestone.
                    edit.add_node(Node::Start);
                    continue;
                }
                // Get the node from the callee function and replace all the
                // uses with the to-be IDs in the caller function.
                let mut node = node.clone();
                if node.is_fork()
                    || node.is_constant()
                    || node.is_dynamic_constant()
                    || node.is_call()
                {
                    substitute_dynamic_constants_in_node(&substs, &mut node, &mut edit);
                }
                let mut uses = get_uses_mut(&mut node);
                for u in uses.as_mut() {
                    **u = old_id_to_new_id(**u);
                }
                // Add the node and check that the IDs line up.
                let add_id = edit.add_node(node);
                assert_eq!(add_id, old_id_to_new_id(NodeID::new(idx)));
                // Copy the schedule from the callee.
                let callee_schedule = &called_func.schedules[idx];
                for schedule in callee_schedule {
                    edit = edit.add_schedule(add_id, schedule.clone())?;
                }
                // Copy the labels from the callee.
                let callee_labels = &called_func.labels[idx];
                for label in callee_labels {
                    edit = edit.add_label(add_id, *label)?;
                }
            }

            // Stitch the control use of the inlined start node with the
            // predecessor control node of the call's region.
            let start_node = &called_func.nodes[0];
            assert!(start_node.is_start());
            let start_id = old_id_to_new_id(NodeID::new(0));
            edit = edit.replace_all_uses(start_id, call_pred)?;

            // Stitch the control use of the original call node's region with
            // the predecessor control of the inlined function's return. If
            // the return hangs directly off the callee's start node, the
            // region's replacement is simply the call's own predecessor.
            edit = edit.replace_all_uses(
                control,
                if called_return_pred == NodeID::new(0) {
                    call_pred
                } else {
                    old_id_to_new_id(called_return_pred)
                },
            )?;

            // Replace and delete the call's (data projection) users.
            for (proj_id, proj_idx) in call_projs {
                let proj_val = called_return_data[proj_idx];
                edit = edit.replace_all_uses(proj_id, old_id_to_new_id(proj_val))?;
                edit = edit.delete_node(proj_id)?;
            }

            // Stitch uses of parameter nodes in the inlined function to the IDs
            // of arguments provided to the call node.
            for (node_idx, node) in called_func.nodes.iter().enumerate() {
                if let Node::Parameter { index } = node {
                    let param_id = old_id_to_new_id(NodeID::new(node_idx));
                    edit = edit.replace_all_uses(param_id, args[*index])?;
                }
            }

            // Finally delete the call's region node and the call node itself.
            edit = edit.delete_node(control)?;
            edit = edit.delete_node(id)?;

            Ok(edit)
        });
    }
}

// What is known about the value passed for one parameter, combined over the
// call sites examined so far (see `from_node` and `meet`).
#[derive(Clone, Debug, Copy, PartialEq, Eq)]
enum ParameterLattice {
    // Nothing constrains the value yet. `from_node` produces this for `Undef`
    // arguments, and `meet` with `Top` keeps the other operand unchanged.
    Top,
    // The argument is a constant node with this constant ID.
    Constant(ConstantID),
    // Dynamic constant: the argument is a dynamic constant node with this ID.
    // The function ID is whatever was handed to `from_node`.
    DynamicConstant(DynamicConstantID, FunctionID),
    // The argument is some other kind of node, or call sites disagree. `meet`
    // with `Bottom` always yields `Bottom`.
    Bottom,
}

impl ParameterLattice {
    fn from_node(node: &Node, func_id: FunctionID) -> Self {
        use ParameterLattice::*;
        match node {
            Node::Undef { ty: _ } => Top,
            Node::Constant { id } => Constant(*id),
            Node::DynamicConstant { id } => DynamicConstant(*id, func_id),
            _ => Bottom,
        }
    }

    fn meet(&mut self, b: Self, cons: Ref<'_, Vec<Constant>>, dcs: Ref<'_, Vec<DynamicConstant>>) {
        use ParameterLattice::*;
        *self = match (*self, b) {
            (Top, b) => b,
            (a, Top) => a,
            (Bottom, _) | (_, Bottom) => Bottom,
            (Constant(id_a), Constant(id_b)) => {
                if id_a == id_b {
                    Constant(id_a)
                } else {
                    Bottom
                }
            }
            (DynamicConstant(dc_a, f_a), DynamicConstant(dc_b, f_b)) => {
                if dc_a == dc_b && f_a == f_b {
                    DynamicConstant(dc_a, f_a)
                } else if let (
                    ir::DynamicConstant::Constant(dcv_a),
                    ir::DynamicConstant::Constant(dcv_b),
                ) = (&dcs[dc_a.idx()], &dcs[dc_b.idx()])
                    && *dcv_a == *dcv_b
                {
                    DynamicConstant(dc_a, f_a)
                } else {
                    Bottom
                }
            }
            (DynamicConstant(dc, _), Constant(con)) | (Constant(con), DynamicConstant(dc, _)) => {
                match (&cons[con.idx()], &dcs[dc.idx()]) {
                    (ir::Constant::UnsignedInteger64(conv), ir::DynamicConstant::Constant(dcv))
                        if *conv as usize == *dcv =>
                    {
                        Constant(con)
                    }
                    _ => Bottom,
                }
            }
        }
    }
}

/*
 * Top level function to inline constant parameters and constant dynamic
 * constant parameters. Identifies functions that are:
 *
 * 1. Not marked as entry.
 * 2. At every call site, a particular parameter is always a specific constant
 *    or dynamic constant.
 *
 * These functions can have that constant "inlined" - the parameter is removed
 * and all uses of the parameter become uses of the constant directly.
 *
 * `editors` is indexed by function ID. When `inline_collections` is false,
 * only parameters whose type is primitive are considered; when it is true,
 * parameters of any type are.
 */
pub fn const_inline(
    editors: &mut [FunctionEditor],
    callgraph: &CallGraph,
    inline_collections: bool,
) {
    // Run const inlining on each function, starting at the most shallow
    // function first, since we want to propagate constants down the call graph.
    for func_id in callgraph.topo().into_iter().rev() {
        let func = editors[func_id.idx()].func();
        // Entry functions and functions nobody calls are left untouched.
        if func.entry || callgraph.num_callers(func_id) == 0 {
            continue;
        }

        // Figure out what we know about the parameters to this function.
        // Every parameter starts at Top and is lowered by each call site's
        // argument. `callers` records every (caller, call node) pair found so
        // the call sites can be rewritten afterwards.
        let mut param_lattice = vec![ParameterLattice::Top; func.param_types.len()];
        let mut callers = vec![];
        for caller in callgraph.get_callers(func_id) {
            let editor = &editors[caller.idx()];
            let nodes = &editor.func().nodes;
            for id in editor.node_ids() {
                if let Some((_, callee, _, args)) = nodes[id.idx()].try_call()
                    && callee == func_id
                {
                    if editor.is_mutable(id) {
                        for (idx, id) in args.into_iter().enumerate() {
                            let lattice = ParameterLattice::from_node(&nodes[id.idx()], callee);
                            param_lattice[idx].meet(
                                lattice,
                                editor.get_constants(),
                                editor.get_dynamic_constants(),
                            );
                        }
                    } else {
                        // If we can't modify the call node in the caller, then
                        // we can't perform the inlining. Bottom absorbs every
                        // later meet, so this sticks for all parameters.
                        param_lattice = vec![ParameterLattice::Bottom; func.param_types.len()];
                    }
                    callers.push((caller, id));
                }
            }
        }
        // Nothing is known about any parameter - nothing to do.
        if param_lattice.iter().all(|v| *v == ParameterLattice::Bottom) {
            continue;
        }

        // Replace the arguments.
        let editor = &mut editors[func_id.idx()];
        // Map each parameter index to the parameter nodes that read it.
        let mut param_idx_to_ids: HashMap<usize, Vec<NodeID>> = HashMap::new();
        for id in editor.node_ids() {
            if let Some(idx) = editor.func().nodes[id.idx()].try_parameter() {
                param_idx_to_ids.entry(idx).or_default().push(id);
            }
        }
        // Original indices of the parameters removed by the edit below, in
        // ascending order.
        let mut params_to_remove = vec![];
        let success = editor.edit(|mut edit| {
            let mut param_tys = edit.get_param_types().clone();
            // Number of parameters removed so far. `idx` always refers to the
            // original parameter list, so `idx - decrement_index_by` is the
            // position of the same parameter in the shrinking `param_tys`.
            let mut decrement_index_by = 0;
            for idx in 0..param_tys.len() {
                // A parameter is removed when all of the following hold:
                // - its type is eligible (primitive, or any type when
                //   `inline_collections` is set),
                // - the lattice yields a replacement node for it,
                // - at least one parameter node in the function reads it.
                if (inline_collections
                    || edit
                        .get_type(param_tys[idx - decrement_index_by])
                        .is_primitive())
                    && let Some(node) = match param_lattice[idx] {
                        ParameterLattice::Top => Some(Node::Undef {
                            ty: param_tys[idx - decrement_index_by],
                        }),
                        ParameterLattice::Constant(id) => Some(Node::Constant { id }),
                        ParameterLattice::DynamicConstant(id, _) => {
                            // Only dynamic constants that are literal values
                            // are inlined. The lookup is bound to a local
                            // before `edit` is used again to add the constant.
                            let maybe_cons = edit.get_dynamic_constant(id).try_constant();
                            if let Some(val) = maybe_cons {
                                Some(Node::DynamicConstant {
                                    id: edit.add_dynamic_constant(DynamicConstant::Constant(val)),
                                })
                            } else {
                                None
                            }
                        }
                        _ => None,
                    }
                    && let Some(ids) = param_idx_to_ids.get(&idx)
                {
                    // Swap every read of this parameter for the replacement
                    // node, then drop the parameter from the signature.
                    let node = edit.add_node(node);
                    for id in ids {
                        edit = edit.replace_all_uses(*id, node)?;
                        edit = edit.delete_node(*id)?;
                    }
                    param_tys.remove(idx - decrement_index_by);
                    params_to_remove.push(idx);
                    decrement_index_by += 1;
                } else if decrement_index_by != 0
                    && let Some(ids) = param_idx_to_ids.get(&idx)
                {
                    // This parameter stays, but earlier ones were removed, so
                    // its nodes are replaced by one with the shifted index.
                    let node = edit.add_node(Node::Parameter {
                        index: idx - decrement_index_by,
                    });
                    for id in ids {
                        edit = edit.replace_all_uses(*id, node)?;
                        edit = edit.delete_node(*id)?;
                    }
                }
            }
            edit.set_param_types(param_tys);
            Ok(edit)
        });
        // Remove arguments from the back first so that earlier removals don't
        // shift the positions of the ones still to be removed.
        params_to_remove.reverse();

        // Update callers.
        // NOTE(review): this runs whenever the edit above succeeded, even if
        // `params_to_remove` is empty - each call is then replaced by a new
        // call with the same arguments.
        if success {
            for (caller, call) in callers {
                let editor = &mut editors[caller.idx()];
                let success = editor.edit(|mut edit| {
                    // `callers` only holds IDs of call nodes, so anything else
                    // here is a broken invariant.
                    let Node::Call {
                        control,
                        function,
                        dynamic_constants,
                        args,
                    } = edit.get_node(call).clone()
                    else {
                        panic!();
                    };
                    let mut args = args.into_vec();
                    for idx in params_to_remove.iter() {
                        args.remove(*idx);
                    }
                    // Rebuild the call without the removed arguments and
                    // redirect users of the old call to it.
                    let node = edit.add_node(Node::Call {
                        control,
                        function,
                        dynamic_constants,
                        args: args.into_boxed_slice(),
                    });
                    edit = edit.replace_all_uses(call, node)?;
                    edit = edit.delete_node(call)?;
                    Ok(edit)
                });
                // The callee's signature has already changed, so a call site
                // that can't be rewritten would be left inconsistent.
                assert!(success);
            }
        }
    }
}