use hercules_ir::ir::*; use crate::*; /* * Top-level function for running interprocedural analysis. * * IP SROA expects that all nodes in all functions provided to it can be edited, * since it needs to be able to modify both the functions whose types are being * changed and call sites of those functions. What functions to run IP SROA on * is therefore specified by a separate argument. * * This optimization also takes an allow_sroa_arrays arguments (like non-IP * SROA) which controls whether it will break up products of arrays. */ pub fn interprocedural_sroa( editors: &mut Vec<FunctionEditor>, types: &Vec<Vec<TypeID>>, func_selection: &Vec<bool>, allow_sroa_arrays: bool, ) { let can_sroa_type = |editor: &FunctionEditor, typ: TypeID| { editor.get_type(typ).is_product() && (allow_sroa_arrays || !type_contains_array(editor, typ)) }; let callsites = get_callsites(editors); for ((func_id, apply), callsites) in (0..func_selection.len()) .map(FunctionID::new) .zip(func_selection.iter()) .zip(callsites.into_iter()) { if !apply { continue; } let editor: &mut FunctionEditor = &mut editors[func_id.idx()]; let param_types = &editor.func().param_types.to_vec(); let return_types = &editor.func().return_types.to_vec(); // We determine the new param/return types of the function and track a // map that tells us how the old param/return values are constructed // from the new ones. let mut new_param_types = vec![]; let mut old_param_type_map = vec![]; let mut new_return_types = vec![]; let mut old_return_type_map = vec![]; let mut changed = false; for par_typ in param_types.iter() { if !can_sroa_type(editor, *par_typ) { old_param_type_map.push(IndexTree::Leaf(new_param_types.len())); new_param_types.push(*par_typ); } else { let (types, index) = sroa_type(editor, *par_typ, new_param_types.len()); old_param_type_map.push(index); new_param_types.extend(types); changed = true; } } for ret_typ in return_types.iter() { if !can_sroa_type(editor, *ret_typ) { old_return_type_map.push(IndexTree::Leaf(new_return_types.len())); new_return_types.push(*ret_typ); } else { let (types, index) = sroa_type(editor, *ret_typ, new_return_types.len()); old_return_type_map.push(index); new_return_types.extend(types); changed = true; } } // If the param/return types aren't changed by IP SROA, skip to the next // function. if !changed { continue; } // Modify each parameter in the current function and the param types. let mut param_nodes: Vec<_> = vec![vec![]; param_types.len()]; for id in editor.node_ids() { if let Some(idx) = editor.func().nodes[id.idx()].try_parameter() { param_nodes[idx].push(id); } } println!("{}", editor.func().name); let success = editor.edit(|mut edit| { for (idx, ids) in param_nodes.into_iter().enumerate() { let new_indices = &old_param_type_map[idx]; let built = if let IndexTree::Leaf(new_idx) = new_indices { edit.add_node(Node::Parameter { index: *new_idx }) } else { let prod_ty = param_types[idx]; let cons = edit.add_zero_constant(prod_ty); let mut cons = edit.add_node(Node::Constant { id: cons }); new_indices.for_each(|idx: &Vec<Index>, param_idx: &usize| { let param = edit.add_node(Node::Parameter { index: *param_idx }); cons = edit.add_node(Node::Write { collect: cons, data: param, indices: idx.clone().into_boxed_slice(), }); }); cons }; for id in ids { edit = edit.replace_all_uses(id, built)?; edit = edit.delete_node(id)?; } } edit.set_param_types(new_param_types); Ok(edit) }); assert!(success, "IP SROA expects to be able to edit everything, specify what functions to IP SROA via the func_selection argument"); // Modify each return in the current function and the return types. let return_nodes: Vec<_> = editor .node_ids() .filter(|id| editor.func().nodes[id.idx()].is_return()) .collect(); let success = editor.edit(|mut edit| { for node in return_nodes { let Node::Return { control, data } = edit.get_node(node) else { panic!() }; let control = *control; let data = data.to_vec(); let mut new_data = vec![]; for (idx, (data_id, update_info)) in data.into_iter().zip(old_return_type_map.iter()).enumerate() { if let IndexTree::Leaf(new_idx) = update_info { // Unchanged return value assert!(new_data.len() == *new_idx); new_data.push(data_id); } else { // SROA'd return value let reads = generate_reads_edit(&mut edit, return_types[idx], data_id); reads.zip(update_info).for_each(|_, (read_id, ret_idx)| { assert!(new_data.len() == **ret_idx); new_data.push(*read_id); }); } } let new_ret = edit.add_node(Node::Return { control, data: new_data.into(), }); edit.sub_edit(node, new_ret); edit = edit.delete_node(node)?; } edit.set_return_types(new_return_types); Ok(edit) }); assert!(success, "IP SROA expects to be able to edit everything, specify what functions to IP SROA via the func_selection argument"); // Finally, update calls of this function. for (caller, callsite) in callsites { let editor = &mut editors[caller.idx()]; assert!(editor.func_id() == caller); let projs = editor.get_users(callsite).collect::<Vec<_>>(); for proj_id in projs { let Node::DataProjection { data: _, selection } = editor.node(proj_id) else { panic!("Call has a non data-projection user"); }; let new_return_info = &old_return_type_map[*selection]; let typ = types[caller.idx()][proj_id.idx()]; replace_returned_value(editor, proj_id, typ, new_return_info, callsite); } let (control, callee, dc_args, args) = editor.func().nodes[callsite.idx()].try_call().unwrap(); let dc_args = dc_args.clone(); let args = args.clone(); let success = editor.edit(|mut edit| { let mut new_args = vec![]; for (idx, (data_id, update_info)) in args.iter().zip(old_param_type_map.iter()).enumerate() { if let IndexTree::Leaf(new_idx) = update_info { // Unchanged parameter value assert!(new_args.len() == *new_idx); new_args.push(*data_id); } else { // SROA'd parameter value let reads = generate_reads_edit(&mut edit, param_types[idx], *data_id); reads.zip(update_info).for_each(|_, (read_id, ret_idx)| { assert!(new_args.len() == **ret_idx); new_args.push(*read_id); }); } } let new_call = edit.add_node(Node::Call { control, function: callee, dynamic_constants: dc_args, args: new_args.into_boxed_slice(), }); edit = edit.replace_all_uses(callsite, new_call)?; edit = edit.delete_node(callsite)?; Ok(edit) }); assert!(success); } } } fn sroa_type( editor: &FunctionEditor, typ: TypeID, type_index: usize, ) -> (Vec<TypeID>, IndexTree<usize>) { match &*editor.get_type(typ) { Type::Product(ts) => { let mut res_types = vec![]; let mut index = type_index; let mut children = vec![]; for t in ts { let (types, child) = sroa_type(editor, *t, index); index += types.len(); res_types.extend(types); children.push(child); } (res_types, IndexTree::Node(children)) } _ => (vec![typ], IndexTree::Leaf(type_index)), } } // Returns a list for each function of the call sites of that function fn get_callsites(editors: &Vec<FunctionEditor>) -> Vec<Vec<(FunctionID, NodeID)>> { let mut callsites = vec![vec![]; editors.len()]; for editor in editors { let caller = editor.func_id(); for (callsite, (_, callee, _, _)) in editor .func() .nodes .iter() .enumerate() .filter_map(|(idx, node)| node.try_call().map(|c| (idx, c))) { assert!(editor.is_mutable(NodeID::new(callsite)), "IP SROA expects to be able to edit everything, specify what functions to IP SROA via the func_selection argument"); callsites[callee.idx()].push((caller, NodeID::new(callsite))); } } callsites } // Replaces a projection node (from before the function signature change) based on the of_new_call // description (which tells how to construct the value from the new returned values). fn replace_returned_value( editor: &mut FunctionEditor, proj_id: NodeID, proj_typ: TypeID, of_new_call: &IndexTree<usize>, call_node: NodeID, ) { let constant = generate_constant(editor, proj_typ); let success = editor.edit(|mut edit| { let mut new_val = edit.add_node(Node::Constant { id: constant }); of_new_call.for_each(|idx, selection| { let new_proj = edit.add_node(Node::DataProjection { data: call_node, selection: *selection, }); new_val = edit.add_node(Node::Write { collect: new_val, data: new_proj, indices: idx.clone().into(), }); }); edit = edit.replace_all_uses(proj_id, new_val)?; edit.delete_node(proj_id) }); assert!(success); }