extern crate hercules_ir; use std::collections::HashMap; use std::iter::zip; use self::hercules_ir::callgraph::*; use self::hercules_ir::def_use::*; use self::hercules_ir::ir::*; use self::hercules_ir::schedule::*; use crate::*; /* * Top level function to run inlining. Currently, inlines every function call, * since mutual recursion is not valid in Hercules IR. */ pub fn inline( editors: &mut [FunctionEditor], callgraph: &CallGraph, mut plans: Option<&mut Vec<Plan>>, ) { // Step 1: run topological sort on the call graph to inline the "deepest" // function first. Mutual recursion is not currently supported, so assert // that a topological sort exists. let mut num_calls: Vec<usize> = (0..editors.len()) .map(|idx| callgraph.num_callees(FunctionID::new(idx))) .collect(); let mut no_calls_stack: Vec<FunctionID> = num_calls .iter() .enumerate() .filter(|(_, num)| **num == 0) .map(|(idx, _)| FunctionID::new(idx)) .collect(); let mut topo = vec![]; while let Some(no_call_func) = no_calls_stack.pop() { topo.push(no_call_func); for caller in callgraph.get_callers(no_call_func) { num_calls[caller.idx()] -= 1; if num_calls[caller.idx()] == 0 { no_calls_stack.push(*caller); } } } assert_eq!( topo.len(), editors.len(), "PANIC: Found mutual recursion in Hercules IR." ); // Step 2: make sure each function has a single return node. If an edit // failed to make a function have a single return node, then we can't inline // calls of it. let single_return_nodes: Vec<_> = editors .iter_mut() .map(|editor| collapse_returns(editor)) .collect(); // Step 3: verify that each possible dynamic constant parameter index has a // single unique dynamic constant ID. If this isn't true, dynamic constant // substitution won't work, and this should be true anyway! let mut found_idxs = HashMap::new(); for id in editors[0].dynamic_constant_ids() { let dc = editors[0].get_dynamic_constant(id); if let DynamicConstant::Parameter(idx) = *dc { assert!(!found_idxs.contains_key(&idx)); found_idxs.insert(idx, id); } } let mut dc_param_idx_to_dc_id = vec![]; for idx in 0..found_idxs.len() { dc_param_idx_to_dc_id.push(found_idxs[&idx]); } // Step 4: run inlining on each function individually. Iterate the functions // in topological order. for to_inline_id in topo { // Since Rust cannot analyze the accesses into an array of mutable // references, we need to do some weirdness here to simultaneously get: // 1. A mutable reference to the function we're modifying. // 2. Shared references to all of the functions called by that function. // We need to get the same for plans, if we receive them. let callees = callgraph.get_callees(to_inline_id); let editor_refs = get_mut_and_immuts(editors, to_inline_id, callees); let plan_refs = plans .as_mut() .map(|plans| get_mut_and_immuts(*plans, to_inline_id, callees)); inline_func( editor_refs.0, editor_refs.1, plan_refs, &single_return_nodes, &dc_param_idx_to_dc_id, ); } } /* * Helper function to get from an array of mutable references: * 1. A single mutable reference. * 2. Several shared references. * Where none of the references alias. We need to use this both for function * editors and plans. */ fn get_mut_and_immuts<'a, T, I: ID>( mut_refs: &'a mut [T], mut_id: I, shared_id: &[I], ) -> (&'a mut T, HashMap<I, &'a T>) { let mut all_id = Vec::from(shared_id); all_id.sort_unstable(); all_id.insert(all_id.binary_search(&mut_id).unwrap_err(), mut_id); let mut mut_ref = None; let mut shared_refs = HashMap::new(); let mut cursor = 0; let mut slice = &mut *mut_refs; for id in all_id { let (left, right) = slice.split_at_mut(id.idx() - cursor); cursor += left.len() + 1; let (left, right) = right.split_at_mut(1); let item = &mut left[0]; if id == mut_id { assert!(mut_ref.is_none()); mut_ref = Some(item); } else { shared_refs.insert(id, &*item); } slice = right; } (mut_ref.unwrap(), shared_refs) } /* * Run inlining on a single function. Pass a mutable reference to the function * to modify and shared references for all called functions. */ fn inline_func( editor: &mut FunctionEditor, called: HashMap<FunctionID, &FunctionEditor>, plans: Option<(&mut Plan, HashMap<FunctionID, &Plan>)>, single_return_nodes: &Vec<Option<NodeID>>, dc_param_idx_to_dc_id: &Vec<DynamicConstantID>, ) { let first_num_nodes = editor.func().nodes.len(); for id in (0..first_num_nodes).map(NodeID::new) { // Break down the call node. let Node::Call { control, function, ref dynamic_constants, ref args, } = editor.func().nodes[id.idx()] else { continue; }; // Assemble all the info we'll need to do the edit. let dcs_a = &dc_param_idx_to_dc_id[..dynamic_constants.len()]; let dcs_b = dynamic_constants.clone(); let args = args.clone(); let old_num_nodes = editor.func().nodes.len(); let old_id_to_new_id = |old_id: NodeID| NodeID::new(old_id.idx() + old_num_nodes); let call_pred = get_uses(&editor.func().nodes[control.idx()]); assert_eq!(call_pred.as_ref().len(), 1); let call_pred = call_pred.as_ref()[0]; let called_func = called[&function].func(); // We can't inline calls to functions with multiple returns. let Some(called_return) = single_return_nodes[function.idx()] else { continue; }; let called_return_uses = get_uses(&called_func.nodes[called_return.idx()]); let called_return_pred = called_return_uses.as_ref()[0]; let called_return_data = called_return_uses.as_ref()[1]; // Perform the actual edit. let success = editor.edit(|mut edit| { // Insert the nodes from the called function. There are a few // special cases: // - Start: don't add start nodes - later, we'll replace_all_uses on // the start node with the one predecessor of the call's region // node. // - Parameter: don't add parameter nodes - later, we'll // replace_all_uses on the parameter nodes with the arguments to // the call node. // - Return: don't add return nodes - later, we'll replace_all_uses // on the call's region node with the predecessor to the return // node. for (idx, node) in called_func.nodes.iter().enumerate() { if node.is_start() || node.is_parameter() || node.is_return() { // We still need to add some node to make sure the IDs line // up. Just add a gravestone. edit.add_node(Node::Start); continue; } // Get the node from the callee function and replace all the // uses with the to-be IDs in the caller function. let mut node = node.clone(); if node.is_fork() || node.is_constant() || node.is_dynamic_constant() || node.is_call() { // We have to perform the subsitution in two steps. First, // we map every dynamic constant A to a non-sense dynamic // constant ID. Second, we map each non-sense dynamic // constant ID to the appropriate dynamic constant B. Why // not just do this in one step from A to B? We update // dynamic constants one at a time, so imagine the following // A -> B mappings: // ID 0 -> ID 1 // ID 1 -> ID 0 // First, we apply the first mapping. This changes all // references to dynamic constant 0 to dynamic constant 1. // Then, we apply the second mapping. This updates all // already present references to dynamic constant 1, as well // as the new references we just made in the first step. We // actually want to institute all the updates // *simultaneously*, hence the two step maneuver. let num_dcs = edit.num_dynamic_constants(); for (dc_a, dc_n) in zip(dcs_a, num_dcs..) { substitute_dynamic_constants_in_node( *dc_a, DynamicConstantID::new(dc_n), &mut node, &mut edit, ); } for (dc_n, dc_b) in zip(num_dcs.., dcs_b.iter()) { substitute_dynamic_constants_in_node( DynamicConstantID::new(dc_n), *dc_b, &mut node, &mut edit, ); } } let mut uses = get_uses_mut(&mut node); for u in uses.as_mut() { **u = old_id_to_new_id(**u); } // Add the node and check that the IDs line up. let add_id = edit.add_node(node); assert_eq!(add_id, old_id_to_new_id(NodeID::new(idx))); } // Stitch the control use of the inlined start node with the // predecessor control node of the call's region. let start_node = &called_func.nodes[0]; assert!(start_node.is_start()); let start_id = old_id_to_new_id(NodeID::new(0)); edit = edit.replace_all_uses(start_id, call_pred)?; // Stich the control use of the original call node's region with // the predecessor control of the inlined function's return. edit = edit.replace_all_uses( control, if called_return_pred == NodeID::new(0) { call_pred } else { old_id_to_new_id(called_return_pred) }, )?; // Stitch uses of parameter nodes in the inlined function to the IDs // of arguments provided to the call node. for (node_idx, node) in called_func.nodes.iter().enumerate() { if let Node::Parameter { index } = node { let param_id = old_id_to_new_id(NodeID::new(node_idx)); edit = edit.replace_all_uses(param_id, args[*index])?; } } // Finally, delete the call node. edit = edit.replace_all_uses(id, old_id_to_new_id(called_return_data))?; edit = edit.delete_node(control)?; edit = edit.delete_node(id)?; Ok(edit) }); } } /* * Substitute all uses of a dynamic constant A with dynamic constant B in a * dynamic constant C. Return the substituted version of C, once memoized. Takes * a mutable edit instead of an editor since this may create new dynamic * constants, which can only be done inside an edit. */ fn substitute_dynamic_constants( dc_a: DynamicConstantID, dc_b: DynamicConstantID, dc_c: DynamicConstantID, edit: &mut FunctionEdit, ) -> DynamicConstantID { // If C is just A, then just replace all of C with B. if dc_a == dc_c { return dc_b; } // Since we substitute non-sense dynamic constant IDs earlier, we explicitly // check that the provided ID to replace inside of is valid. Otherwise, // ignore. if dc_c.idx() >= edit.num_dynamic_constants() { return dc_c; } // If C is not just A, look inside of it to possibly substitute a child DC. let dc_clone = edit.get_dynamic_constant(dc_c).clone(); match dc_clone { DynamicConstant::Constant(_) | DynamicConstant::Parameter(_) => dc_c, // This is a certified Rust moment. DynamicConstant::Add(left, right) => { let new_left = substitute_dynamic_constants(dc_a, dc_b, left, edit); let new_right = substitute_dynamic_constants(dc_a, dc_b, right, edit); if new_left != left || new_right != right { edit.add_dynamic_constant(DynamicConstant::Add(new_left, new_right)) } else { dc_c } } DynamicConstant::Sub(left, right) => { let new_left = substitute_dynamic_constants(dc_a, dc_b, left, edit); let new_right = substitute_dynamic_constants(dc_a, dc_b, right, edit); if new_left != left || new_right != right { edit.add_dynamic_constant(DynamicConstant::Sub(new_left, new_right)) } else { dc_c } } DynamicConstant::Mul(left, right) => { let new_left = substitute_dynamic_constants(dc_a, dc_b, left, edit); let new_right = substitute_dynamic_constants(dc_a, dc_b, right, edit); if new_left != left || new_right != right { edit.add_dynamic_constant(DynamicConstant::Mul(new_left, new_right)) } else { dc_c } } DynamicConstant::Div(left, right) => { let new_left = substitute_dynamic_constants(dc_a, dc_b, left, edit); let new_right = substitute_dynamic_constants(dc_a, dc_b, right, edit); if new_left != left || new_right != right { edit.add_dynamic_constant(DynamicConstant::Div(new_left, new_right)) } else { dc_c } } DynamicConstant::Rem(left, right) => { let new_left = substitute_dynamic_constants(dc_a, dc_b, left, edit); let new_right = substitute_dynamic_constants(dc_a, dc_b, right, edit); if new_left != left || new_right != right { edit.add_dynamic_constant(DynamicConstant::Rem(new_left, new_right)) } else { dc_c } } } } /* * Substitute all uses of a dynamic constant A with dynamic constant B in a * type. Return the substituted version of the type, once memozied. */ fn substitute_dynamic_constants_in_type( dc_a: DynamicConstantID, dc_b: DynamicConstantID, ty: TypeID, edit: &mut FunctionEdit, ) -> TypeID { // Look inside the type for references to dynamic constants. let ty_clone = edit.get_type(ty).clone(); match ty_clone { Type::Product(ref fields) => { let new_fields = fields .into_iter() .map(|field_id| substitute_dynamic_constants_in_type(dc_a, dc_b, *field_id, edit)) .collect(); if new_fields != *fields { edit.add_type(Type::Product(new_fields)) } else { ty } } Type::Summation(ref variants) => { let new_variants = variants .into_iter() .map(|variant_id| { substitute_dynamic_constants_in_type(dc_a, dc_b, *variant_id, edit) }) .collect(); if new_variants != *variants { edit.add_type(Type::Summation(new_variants)) } else { ty } } Type::Array(elem_ty, ref dims) => { let new_elem_ty = substitute_dynamic_constants_in_type(dc_a, dc_b, elem_ty, edit); let new_dims = dims .into_iter() .map(|dim_id| substitute_dynamic_constants(dc_a, dc_b, *dim_id, edit)) .collect(); if new_elem_ty != elem_ty || new_dims != *dims { edit.add_type(Type::Array(new_elem_ty, new_dims)) } else { ty } } _ => ty, } } /* * Substitute all uses of a dynamic constant A with dynamic constant B in a * constant. Return the substituted version of the constant, once memozied. */ fn substitute_dynamic_constants_in_constant( dc_a: DynamicConstantID, dc_b: DynamicConstantID, cons: ConstantID, edit: &mut FunctionEdit, ) -> ConstantID { // Look inside the type for references to dynamic constants. let cons_clone = edit.get_constant(cons).clone(); match cons_clone { Constant::Product(ty, fields) => { let new_ty = substitute_dynamic_constants_in_type(dc_a, dc_b, ty, edit); let new_fields = fields .iter() .map(|field_id| { substitute_dynamic_constants_in_constant(dc_a, dc_b, *field_id, edit) }) .collect(); if new_ty != ty || new_fields != fields { edit.add_constant(Constant::Product(new_ty, new_fields)) } else { cons } } Constant::Summation(ty, idx, variant) => { let new_ty = substitute_dynamic_constants_in_type(dc_a, dc_b, ty, edit); let new_variant = substitute_dynamic_constants_in_constant(dc_a, dc_b, variant, edit); if new_ty != ty || new_variant != variant { edit.add_constant(Constant::Summation(new_ty, idx, new_variant)) } else { cons } } Constant::Array(ty) => { let new_ty = substitute_dynamic_constants_in_type(dc_a, dc_b, ty, edit); if new_ty != ty { edit.add_constant(Constant::Array(new_ty)) } else { cons } } _ => cons, } } /* * Substitute all uses of a dynamic constant A with dynamic constant B in a * node. */ fn substitute_dynamic_constants_in_node( dc_a: DynamicConstantID, dc_b: DynamicConstantID, node: &mut Node, edit: &mut FunctionEdit, ) { match node { Node::Fork { control: _, factors, } => { for factor in factors.into_iter() { *factor = substitute_dynamic_constants(dc_a, dc_b, *factor, edit); } } Node::Constant { id } => { *id = substitute_dynamic_constants_in_constant(dc_a, dc_b, *id, edit); } Node::DynamicConstant { id } => { *id = substitute_dynamic_constants(dc_a, dc_b, *id, edit); } Node::Call { control: _, function: _, dynamic_constants, args: _, } => { for dc_arg in dynamic_constants.into_iter() { *dc_arg = substitute_dynamic_constants(dc_a, dc_b, *dc_arg, edit); } } _ => {} } } /* * Top level function to make a function have only a single return. */ pub fn collapse_returns(editor: &mut FunctionEditor) -> Option<NodeID> { let returns: Vec<NodeID> = (0..editor.func().nodes.len()) .filter(|idx| editor.func().nodes[*idx].is_return()) .map(NodeID::new) .collect(); assert!(!returns.is_empty()); if returns.len() == 1 { return Some(returns[0]); } let preds_before_returns: Vec<NodeID> = returns .iter() .map(|ret_id| get_uses(&editor.func().nodes[ret_id.idx()]).as_ref()[0]) .collect(); let data_to_return: Vec<NodeID> = returns .iter() .map(|ret_id| get_uses(&editor.func().nodes[ret_id.idx()]).as_ref()[1]) .collect(); // All of the old returns get replaced in a single edit. let mut new_return = None; editor.edit(|mut edit| { let region = edit.add_node(Node::Region { preds: preds_before_returns.into_boxed_slice(), }); let phi = edit.add_node(Node::Phi { control: region, data: data_to_return.into_boxed_slice(), }); for ret in returns { edit = edit.delete_node(ret)?; } new_return = Some(edit.add_node(Node::Return { control: region, data: phi, })); Ok(edit) }); new_return }