diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index 68fdc26cdeb7b414273c33250b22225b15b4d0da..7a0158fbc3d4dcef22f1f6231ed67d39b84346a8 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -1318,6 +1318,14 @@ impl Node { } } + pub fn try_data_proj(&self) -> Option<(NodeID, usize)> { + if let Node::DataProjection { data, selection } = self { + Some((*data, *selection)) + } else { + None + } + } + pub fn try_phi(&self) -> Option<(NodeID, &[NodeID])> { if let Node::Phi { control, data } = self { Some((*control, data)) diff --git a/hercules_opt/src/inline.rs b/hercules_opt/src/inline.rs index 2a5ad9afdf06f2591ea25c5b97e5bef0ceff2282..99187dd2bbfb2a9bbc2bf5937cb40f5246ba5b29 100644 --- a/hercules_opt/src/inline.rs +++ b/hercules_opt/src/inline.rs @@ -211,6 +211,13 @@ fn inline_func( }, )?; + // Replace and delete the call's (data projection) users + for (proj_id, proj_idx) in call_projs { + let proj_val = called_return_data[proj_idx]; + edit = edit.replace_all_uses(proj_id, old_id_to_new_id(proj_val))?; + edit = edit.delete_node(proj_id)?; + } + // Stitch uses of parameter nodes in the inlined function to the IDs // of arguments provided to the call node. for (node_idx, node) in called_func.nodes.iter().enumerate() { @@ -220,12 +227,7 @@ fn inline_func( } } - // Replace and delete the call's (data projection) users and the call node - for (proj_id, proj_idx) in call_proj { - edit = - edit.replace_all_uses(proj_id, old_id_to_new_id(called_return_data[proj_idx]))?; - edit = edit.delete_node(proj_id)?; - } + // Finally delete the call node edit = edit.delete_node(control)?; edit = edit.delete_node(id)?; diff --git a/hercules_opt/src/interprocedural_sroa.rs b/hercules_opt/src/interprocedural_sroa.rs index 944ef8fd02e54b6164c7eb185c5e5e9aa27b5a28..32fa9cc8e612831c666455555b6c9f550be7caa4 100644 --- a/hercules_opt/src/interprocedural_sroa.rs +++ b/hercules_opt/src/interprocedural_sroa.rs @@ -5,466 +5,196 @@ use hercules_ir::ir::*; use crate::*; -/** - * Given an editor for each function in a module, return V s.t. - * V[i] = true iff every call node to the function with index i - * is editable. If there are no calls to this function, V[i] = true. - */ -fn get_editable_callsites(editors: &mut Vec<FunctionEditor>) -> Vec<bool> { - let mut callsites_editable = vec![true; editors.len()]; - for editor in editors { - for (idx, (_, function, _, _)) in editor - .func() - .nodes - .iter() - .enumerate() - .filter_map(|(idx, node)| node.try_call().map(|c| (idx, c))) - { - if !editor.is_mutable(NodeID::new(idx)) { - callsites_editable[function.idx()] = false; - } - } - } - callsites_editable -} - -/** - * Given a type tree, return a Vec containing all leaves which are not units. - */ -fn get_nonempty_leaves(edit: &FunctionEdit, type_id: &TypeID) -> Vec<TypeID> { - let ty = edit.get_type(*type_id).clone(); - match ty { - Type::Product(type_ids) => { - let mut leaves = vec![]; - for type_id in type_ids { - leaves.extend(get_nonempty_leaves(&edit, &type_id)) - } - leaves - } - _ => vec![*type_id], - } -} - -/** - * Given a `source` NodeID which produces a product containing - * all nonempty leaves of the type tree for `type_id` in order, build - * a node producing the `type_id`. +/* + * Top-level function for running interprocedural analysis. * - * `offset` represents the index at which to begin reading - * elements of the `source` product. + * IP SROA expects that all nodes in all functions provided to it can be edited, + * since it needs to be able to modify both the functions whose types are being + * changed and call sites of those functions. What functions to run IP SROA on + * is therefore specified by a separate argument. * - * Returns a 3-tuple of - * 1. Node producing the `type` - * 2. "Next" offset, i.e. `offset` + number of reads performed to build (1) - * 3. List of node IDs which read `source` (tracked so that these will not - * be replaced by replace_all_uses_where) + * This optimization also takes an allow_sroa_arrays arguments (like non-IP + * SROA) which controls whether it will break up products of arrays. */ -fn build_uncompressed_product( - edit: &mut FunctionEdit, - source: &NodeID, - type_id: &TypeID, - offset: usize, -) -> (NodeID, usize, Vec<NodeID>) { - let ty = edit.get_type(*type_id).clone(); - match ty { - Type::Product(child_type_ids) => { - // Step 1. Create an empty constant for the type. We'll write - // child values into this constant. - let empty_constant_id = edit.add_zero_constant(*type_id); - let empty_constant_node = edit.add_node(Node::Constant { - id: empty_constant_id, - }); - // Step 2. Build a node that generates each inner type. - // Since `source` contains nonempty leaves *in order*, - // we must process inner types in order; as part of this, - // inner type i+1 must read from where inner type i left off, - // hence we track the `current_offset` at which we are reading. - // Similarly, to combine results of all recursive calls, - // we keep the invariant that, at iteration i+1, currently_writing_to - // is an instance of `type_id` for which the first i elements - // have been populated based on inorder nonempty leaves - // (and, at iteration 0, it is empty). - let mut current_offset = offset; - let mut currently_writing_to = empty_constant_node; - let mut readers = vec![]; - for (idx, child_type_id) in child_type_ids.iter().enumerate() { - let (child_data, next_offset, child_readers) = - build_uncompressed_product(edit, source, child_type_id, current_offset); - current_offset = next_offset; - currently_writing_to = edit.add_node(Node::Write { - collect: currently_writing_to, - data: child_data, - indices: Box::new([Index::Field(idx)]), - }); - readers.extend(child_readers) - } - (currently_writing_to, current_offset, readers) - } - _ => { - // If the type is not a product, then we've reached a nonempty - // leaf, which we must read from source. Since this is a single - // read, the new offset increases by only 1. - let reader = edit.add_node(Node::Read { - collect: *source, - indices: Box::new([Index::Field(offset)]), - }); - (reader, offset + 1, vec![reader]) +pub fn interprocedural_sroa( + editors: &mut Vec<FunctionEditor>, + types: &Vec<Vec<TypeID>>, + func_selection: &Vec<bool>, + allow_sroa_arrays: bool, +) { + let can_sroa_type = |editor: &FunctionEditor, typ: TypeID| { + editor.get_type(typ).is_product() + && (allow_sroa_arrays || !type_contains_array(editor, typ)) + }; + + let callsites = get_callsites(editors); + + for ((func_id, apply), callsites) in (0..func_selection.len()).map(FunctionID::new).zip(func_selection.iter()).zip(callsites.into_iter()) { + if !apply { + continue; } - } -} -/** - * Given a node with a product value, read the product's values - * *in order* into the nonempty leaves of a product type represented - * by type_id. Returns the ID of the resulting node, as well as the IDs - * of all nodes which read from `node_id`. - */ -fn uncompress_product( - edit: &mut FunctionEdit, - node_id: &NodeID, - type_id: &TypeID, -) -> (NodeID, Vec<NodeID>) { - let (uncompressed_value, _, readers) = build_uncompressed_product(edit, node_id, type_id, 0); - (uncompressed_value, readers) -} - -/** -* Let `read_from` be a node with a value of type `type_id`. -* Let `source` be a product value. -* Returns a node representing the value obtained by writing -* nonempty leaves of `read_from` *in order* into `source`, -* starting at `offset`. -* -* `source` should be a product type with at least enough indices -* to support this operation. Typically, `build_compressed_product` -* should be called initially with a `source` created by adding a -* zero constant for the flattened `type_id`. -* -* Returns: -* 1. The ID of the node to which all nonempty leaves have been written -* 2. The first offset after `offset` which was not written to. -*/ -fn build_compressed_product( - mut edit: &mut FunctionEdit, - source: &NodeID, - type_id: &TypeID, - offset: usize, - read_from: &NodeID, -) -> (NodeID, usize) { - let ty = edit.get_type(*type_id).clone(); - match ty { - Type::Product(child_type_ids) => { - // Iterate through child types in order. For each type, construct - // a node that reads the corresponding value from `read_from`, - // and pass it as the node to read from in the recursive call. - let mut next_offset = offset; - let mut next_destination = *source; - for (idx, child_type_id) in child_type_ids.iter().enumerate() { - let child_value = edit.add_node(Node::Read { - collect: *read_from, - indices: Box::new([Index::Field(idx)]), - }); - (next_destination, next_offset) = build_compressed_product( - &mut edit, - &next_destination, - &child_type_id, - next_offset, - &child_value, - ); + let editor: &mut FunctionEditor = &mut editors[func_id.idx()]; + let return_types = &editor.func().return_types.to_vec(); + + // We determine the new return types of the function and track a map + // that tells us how the old return values are constructed from the + // new ones + let mut new_return_types = vec![]; + let mut old_return_type_map = vec![]; + let mut changed = false; + + for ret_typ in return_types.iter() { + if !can_sroa_type(editor, *ret_typ) { + old_return_type_map.push(IndexTree::Leaf(new_return_types.len())); + new_return_types.push(*ret_typ); + } else { + let (types, index) = sroa_type(editor, *ret_typ, new_return_types.len()); + old_return_type_map.push(index); + new_return_types.extend(types); + changed = true; } - (next_destination, next_offset) } - _ => { - let writer = edit.add_node(Node::Write { - collect: *source, - data: *read_from, - indices: Box::new([Index::Field(offset)]), - }); - (writer, offset + 1) - } - } -} - -/** - * Given a node which has a value of the given type (which must be a product) - * generate a new product node created by inserting nonempty leaves of the - * source node *in order*. Returns the ID of this node, as well as the ID of - * its type. - */ -fn compress_product( - edit: &mut FunctionEdit, - node_id: &NodeID, - type_id: &TypeID, -) -> (NodeID, TypeID) { - let nonempty_leaves = get_nonempty_leaves(&edit, &type_id); - let compressed_type = Type::Product(nonempty_leaves.into_boxed_slice()); - let compressed_type_id = edit.add_type(compressed_type); - - let empty_compressed_constant_id = edit.add_zero_constant(compressed_type_id); - let empty_compressed_node_id = edit.add_node(Node::Constant { - id: empty_compressed_constant_id, - }); - - let (compressed_value, _) = - build_compressed_product(edit, &empty_compressed_node_id, type_id, 0, node_id); - - (compressed_value, compressed_type_id) -} - -fn compress_return_products(editors: &mut Vec<FunctionEditor>, all_callsites_editable: &Vec<bool>) { - // Track whether we successfully applied edits to return statements, - // so that callsites are only modified when returns were. This is - // initialized to false, so that `is_compressed` is false when - // the corresponding entry in `callsites_editable` is false. - let mut is_compressed = vec![false; editors.len()]; - let old_return_type_ids: Vec<_> = editors - .iter() - .map(|editor| editor.func().return_type) - .collect(); - - // Step 1. Track mapping of dynamic constant indexes to ids, so that - // we can substitute when generating empty constants later. The reason - // this works is that the following property is satisfied: - // Let f and g be two functions such that f has d_f dynamic constants - // and g has d_g dynamic constants. Wlog assume d_f < d_g. Then, the - // first d_f dynamic constants of g are the dynamic constants of f. - // For any call node, the ith dynamic constant in the node is provided - // for the ith dynamic constant of the function called. So, when we need - // to take a type and replace d function dynamic constants with their - // values from a call, it suffices to look at the first d entries of - // dc_param_idx_to_dc_id to get the id of the dynamic constants in the function, - // and then replace dc_param_idx_to_dc_id[i] with call.dynamic_constants[i], - // for all i. - let max_num_dc_params = editors - .iter() - .map(|editor| editor.func().num_dynamic_constants) - .max() - .unwrap(); - let mut dc_args = vec![]; - editors[0].edit(|mut edit| { - dc_args = (0..max_num_dc_params as usize) - .map(|i| edit.add_dynamic_constant(DynamicConstant::Parameter(i))) - .collect(); - Ok(edit) - }); - // Step 2. Modify the return type of all editors corresponding to a function - // for which we can edit every callsite, and the return type is a product. - for (idx, editor) in editors.iter_mut().enumerate() { - if !all_callsites_editable[idx] { + // If the return type is not changed by IP SROA, skip to the next function + if !changed { continue; } - let old_return_id = NodeID::new( - (0..editor.func().nodes.len()) - .filter(|idx| editor.func().nodes[*idx].is_return()) - .next() - .unwrap(), - ); - let old_return_type_id = old_return_type_ids[idx]; - - is_compressed[idx] = editor.get_type(editor.func().return_type).is_product() - && editor.edit(|mut edit| { - let return_node = edit.get_node(old_return_id); - let (return_control, return_data) = return_node.try_return().unwrap(); - - let (compressed_data_id, compressed_type_id) = - compress_product(&mut edit, &return_data, &old_return_type_id); + // Now, modify each return in the current function and the return type + let return_nodes = editor.func().nodes + .iter() + .enumerate() + .filter_map(|(idx, node)| if node.try_return().is_some() { + Some(NodeID::new(idx)) + } else { + None + }) + .collect::<Vec<_>>(); + let success = editor.edit(|mut edit| { + for node in return_nodes { + let Node::Return { control, data } = edit.get_node(node) else { + panic!() + }; + let control = *control; + let data = data.to_vec(); + + let mut new_data = vec![]; + for (idx, (data_id, update_info)) in data.into_iter().zip(old_return_type_map.iter()).enumerate() { + if let IndexTree::Leaf(new_idx) = update_info { + // Unchanged return value + assert!(new_data.len() == *new_idx); + new_data.push(data_id); + } else { + // SROA'd return value + let reads = generate_reads_edit(&mut edit, return_types[idx], data_id); + reads.zip(update_info).for_each(|_, (read_id, ret_idx)| { + assert!(new_data.len() == **ret_idx); + new_data.push(*read_id); + }); + } + } - edit.set_return_type(compressed_type_id); - let new_return_id = edit.add_node(Node::Return { - control: return_control, - data: compressed_data_id, + let new_ret = edit.add_node(Node::Return { + control, + data: new_data.into(), }); - edit.sub_edit(old_return_id, new_return_id); - let edit = edit.replace_all_uses(old_return_id, new_return_id)?; - edit.delete_node(old_return_id) - }); - } - - // Step 3: For every editor, update all mutable callsites corresponding to - // calls to functions which have been compressed. Since we only compress returns - // for functions for which every callsite is mutable, this should never fail, - // so we panic if it does. - for (_, editor) in editors.iter_mut().enumerate() { - let call_node_ids: Vec<_> = (0..editor.func().nodes.len()) - .map(NodeID::new) - .filter(|id| editor.func().nodes[id.idx()].is_call()) - .filter(|id| editor.is_mutable(*id)) - .collect(); - - for call_node_id in call_node_ids { - let (_, function_id, ref dynamic_constants, _) = - editor.func().nodes[call_node_id.idx()].try_call().unwrap(); - if !is_compressed[function_id.idx()] { - continue; + edit.sub_edit(node, new_ret); + edit = edit.delete_node(node)?; } - // Before creating the uncompressed product, we must update - // the type of the uncompressed product to reflect the dynamic - // constants provided when calling the function. Since we can - // only replace one constant at a time, we need to map - // constants to dummy values, and then map these to the - // replacement values (this prevents the case of replacements - // (0->1), (1->2) causing conflicts when we have [0, 1], we should - // get [1, 2], not [2, 2], which a naive loop would generate). - - // A similar loop exists in the inline pass but at the node level. - // If this becomes a common pattern, it would be worth creating - // a better abstraction around bulk replacement. - - let new_dcs = (*dynamic_constants).to_vec(); - let old_dcs = dc_args[..new_dcs.len()].to_vec(); - assert_eq!(old_dcs.len(), new_dcs.len()); - let substs = old_dcs - .into_iter() - .zip(new_dcs.into_iter()) - .collect::<HashMap<_, _>>(); - - let edit_successful = editor.edit(|mut edit| { - let substituted = substitute_dynamic_constants_in_type( - &substs, - old_return_type_ids[function_id.idx()], - &mut edit, - ); - - let (expanded_product, readers) = - uncompress_product(&mut edit, &call_node_id, &substituted); - edit.replace_all_uses_where(call_node_id, expanded_product, |id| { - !readers.contains(id) - }) - }); - - if !edit_successful { - panic!("Tried and failed to edit mutable callsite!"); + edit.set_return_types(new_return_types); + + Ok(edit) + }); + assert!(success, "IP SROA expects to be able to edit everything, specify what functions to IP SROA via the func_selection argument"); + + // Finally, update calls of this function + // In particular, we actually don't have to update the call node at all but have to update + // its DataProjection users + for (caller, callsite) in callsites { + let editor = &mut editors[caller.idx()]; + assert!(editor.func_id() == caller); + let projs = editor.get_users(callsite).collect::<Vec<_>>(); + for proj_id in projs { + let Node::DataProjection { data: _, selection } = editor.node(proj_id) else { + panic!("Call has a non data-projection user"); + }; + let new_return_info = &old_return_type_map[*selection]; + let typ = types[caller.idx()][proj_id.idx()]; + replace_returned_value(editor, proj_id, typ, new_return_info, callsite); } } } } -fn remove_return_singletons(editors: &mut Vec<FunctionEditor>, all_callsites_editable: &Vec<bool>) { - // Track whether we removed a singleton product from the return of each - // editor's function. Defaults to false so that if the function was not - // edited (i.e. because not all callsites are editable), then no callsites - // will be edited. - let mut singleton_removed = vec![false; editors.len()]; - let old_return_type_ids: Vec<_> = editors - .iter() - .map(|editor| editor.func().return_type) - .collect(); - - // Step 1. For all editors which correspond to a function for whic hall - // callsites are editable, modify their return type by extracting the - // value from the singleton and returning it directly. - for (idx, editor) in editors.iter_mut().enumerate() { - if !all_callsites_editable[idx] { - continue; - } - - let return_type = editor.get_type(old_return_type_ids[idx]).clone(); - singleton_removed[idx] = match return_type { - Type::Product(tys) if tys.len() == 1 && all_callsites_editable[idx] => { - let old_return_id = NodeID::new( - (0..editor.func().nodes.len()) - .filter(|idx| editor.func().nodes[*idx].is_return()) - .next() - .unwrap(), - ); - - editor.edit(|mut edit| { - let (old_control, old_data) = - edit.get_node(old_return_id).try_return().unwrap(); - - let extracted_singleton_id = edit.add_node(Node::Read { - collect: old_data, - indices: Box::new([Index::Field(0)]), - }); - let new_return_id = edit.add_node(Node::Return { - control: old_control, - data: extracted_singleton_id, - }); - edit.sub_edit(old_return_id, new_return_id); - edit.set_return_type(tys[0]); - - edit.delete_node(old_return_id) - }) +fn sroa_type(editor: &FunctionEditor, typ: TypeID, type_index: usize) -> (Vec<TypeID>, IndexTree<usize>) { + match &*editor.get_type(typ) { + Type::Product(ts) => { + let mut res_types = vec![]; + let mut index = type_index; + let mut children = vec![]; + for t in ts { + let (types, child) = sroa_type(editor, *t, index); + index += types.len(); + res_types.extend(types); + children.push(child); } - _ => false, + (res_types, IndexTree::Node(children)) } + _ => (vec![typ], IndexTree::Leaf(type_index)), } +} - // Step 2. For each editor, find all callsites and reconstruct - // the singleton product at each if the return of the corresponding - // function was modified. This should always succeed since we only - // edited functions for which all callsites were mutable, so panic - // if an edit does not succeed. - for editor in editors.iter_mut() { - let call_node_ids: Vec<_> = (0..editor.func().nodes.len()) - .map(NodeID::new) - .filter(|id| editor.func().nodes[id.idx()].is_call()) - .filter(|id| editor.is_mutable(*id)) - .collect(); - - for call_node_id in call_node_ids { - let (_, function, dc_args, _) = - editor.func().nodes[call_node_id.idx()].try_call().unwrap(); - - let dc_args = dc_args.to_vec(); - - if singleton_removed[function.idx()] { - let edit_successful = editor.edit(|mut edit| { - let dc_params = (0..dc_args.len()) - .map(|param_idx| { - edit.add_dynamic_constant(DynamicConstant::Parameter(param_idx)) - }) - .collect::<Vec<_>>(); - let substs = dc_params - .into_iter() - .zip(dc_args.into_iter()) - .collect::<HashMap<_, _>>(); - - let substituted = substitute_dynamic_constants_in_type( - &substs, - old_return_type_ids[function.idx()], - &mut edit, - ); - let empty_constant_id = edit.add_zero_constant(substituted); - let empty_node_id = edit.add_node(Node::Constant { - id: empty_constant_id, - }); - - let restored_singleton_id = edit.add_node(Node::Write { - collect: empty_node_id, - data: call_node_id, - indices: Box::new([Index::Field(0)]), - }); - edit.replace_all_uses_where(call_node_id, restored_singleton_id, |id| { - *id != restored_singleton_id - }) - }); +// Returns a list for each function of the call sites of that function +fn get_callsites(editors: &Vec<FunctionEditor>) -> Vec<Vec<(FunctionID, NodeID)>> { + let mut callsites = vec![vec![]; editors.len()]; - if !edit_successful { - panic!("Tried and failed to edit mutable callsite!"); - } - } + for editor in editors { + let caller = editor.func_id(); + for (callsite, (_, callee, _, _)) in editor + .func() + .nodes + .iter() + .enumerate() + .filter_map(|(idx, node)| node.try_call().map(|c| (idx, c))) { + assert!(editor.is_mutable(NodeID::new(callsite)), "IP SROA expects to be able to edit everything, specify what functions to IP SROA via the func_selection argument"); + callsites[callee.idx()].push((caller, NodeID::new(callsite))); } } + + callsites } -pub fn interprocedural_sroa(editors: &mut Vec<FunctionEditor>) { - // SROA is implemented in two phases. First, we flatten (or "compress") - // all product return types, so that they are only depth 1 products, - // and do not contain any empty products. - // Next, if any return type is now a singleton product, we - // remove the singleton and just retun the type directly. - // We only apply these changes to functions for which - // all their callsites are editable. - let all_callsites_editable = get_editable_callsites(editors); - compress_return_products(editors, &all_callsites_editable); - remove_return_singletons(editors, &all_callsites_editable); +// Replaces a projection node (from before the function signature change) based on the of_new_call +// description (which tells how to construct the value from the new returned values). +fn replace_returned_value( + editor: &mut FunctionEditor, + proj_id: NodeID, + proj_typ: TypeID, + of_new_call: &IndexTree<usize>, + call_node: NodeID, +) { + let constant = generate_constant(editor, proj_typ); + + let success = editor.edit(|mut edit| { + let mut new_val = edit.add_node(Node::Constant { + id: constant, + }); + of_new_call.for_each(|idx, selection| { + let new_proj = edit.add_node(Node::DataProjection { + data: call_node, + selection: *selection, + }); + new_val = edit.add_node(Node::Write { + collect: new_val, + data: new_proj, + indices: idx.clone().into(), + }); + }); - // Run DCE to prevent issues with schedule repair. - for editor in editors.iter_mut() { - dce(editor); - } + edit = edit.replace_all_uses(proj_id, new_val)?; + edit.delete_node(proj_id) + }); + assert!(success); } diff --git a/hercules_opt/src/sroa.rs b/hercules_opt/src/sroa.rs index eff0a7296161149b32def1da5e6b46d681fc9434..68a1b25e0c9461f81ecea9c8502b40aeda55ed44 100644 --- a/hercules_opt/src/sroa.rs +++ b/hercules_opt/src/sroa.rs @@ -328,19 +328,19 @@ pub fn sroa( match &editor.func().nodes[node.idx()] { Node::Return { control, data } => { let control = *control; - let data = data.clone(); + let data = data.to_vec(); let (new_data, changed) = data.into_iter() .fold((vec![], false), |(mut vals, changed), val_id| { - if !can_sroa_type(editor, types[val_id]) { - vals.push(*val_id); + if !can_sroa_type(editor, types[&val_id]) { + vals.push(val_id); (vals, changed) } else { vals.push(reconstruct_product( editor, - types[val_id], - *val_id, + types[&val_id], + val_id, &mut product_nodes, )); (vals, true) @@ -366,19 +366,19 @@ pub fn sroa( let control = *control; let function = *function; let dynamic_constants = dynamic_constants.clone(); - let args = args.clone(); + let args = args.to_vec(); let (new_args, changed) = args.into_iter() .fold((vec![], false), |(mut vals, changed), arg| { - if !can_sroa_type(editor, types[arg]) { - vals.push(*arg); + if !can_sroa_type(editor, types[&arg]) { + vals.push(arg); (vals, changed) } else { vals.push(reconstruct_product( editor, - types[arg], - *arg, + types[&arg], + arg, &mut product_nodes, )); (vals, true) @@ -736,7 +736,7 @@ pub fn sroa( }); } -fn type_contains_array(editor: &FunctionEditor, typ: TypeID) -> bool { +pub fn type_contains_array(editor: &FunctionEditor, typ: TypeID) -> bool { match &*editor.get_type(typ) { Type::Array(_, _) => true, Type::Product(ts) | Type::Summation(ts) => { @@ -978,20 +978,31 @@ fn reconstruct_product( // Given a node val of type typ, adds nodes to the function which read all (leaf) fields of val and // returns an IndexTree that tracks the nodes reading each leaf field -fn generate_reads(editor: &mut FunctionEditor, typ: TypeID, val: NodeID) -> IndexTree<NodeID> { - let res = generate_reads_at_index(editor, typ, val, vec![]); - res +pub fn generate_reads(editor: &mut FunctionEditor, typ: TypeID, val: NodeID) -> IndexTree<NodeID> { + let mut result = None; + + editor.edit(|mut edit| { + result = Some(generate_reads_edit(&mut edit, typ, val)); + Ok(edit) + }); + + result.unwrap() +} + +// The same as generate_reads but for if we have a FunctionEdit rather than a FunctionEditor +pub fn generate_reads_edit(edit: &mut FunctionEdit, typ: TypeID, val: NodeID) -> IndexTree<NodeID> { + generate_reads_at_index_edit(edit, typ, val, vec![]) } // Given a node val of type which at the indices idx has type typ, construct reads of all (leaf) // fields within this sub-value of val and return the correspondence list -fn generate_reads_at_index( - editor: &mut FunctionEditor, +fn generate_reads_at_index_edit( + edit: &mut FunctionEdit, typ: TypeID, val: NodeID, idx: Vec<Index>, ) -> IndexTree<NodeID> { - let ts: Option<Vec<TypeID>> = if let Some(ts) = editor.get_type(typ).try_product() { + let ts: Option<Vec<TypeID>> = if let Some(ts) = edit.get_type(typ).try_product() { Some(ts.into()) } else { None @@ -1004,22 +1015,18 @@ fn generate_reads_at_index( for (i, t) in ts.into_iter().enumerate() { let mut new_idx = idx.clone(); new_idx.push(Index::Field(i)); - fields.push(generate_reads_at_index(editor, t, val, new_idx)); + fields.push(generate_reads_at_index_edit(edit, t, val, new_idx)); } IndexTree::Node(fields) } else { // For non-product types, we've reached a leaf so we generate the read and return it's // information - let mut read_id = None; - editor.edit(|mut edit| { - read_id = Some(edit.add_node(Node::Read { - collect: val, - indices: idx.clone().into(), - })); - Ok(edit) + let read_id = edit.add_node(Node::Read { + collect: val, + indices: idx.into(), }); - IndexTree::Leaf(read_id.expect("Add node canont fail")) + IndexTree::Leaf(read_id) } } @@ -1035,7 +1042,7 @@ macro_rules! add_const { } // Given a type, builds a default constant of that type -fn generate_constant(editor: &mut FunctionEditor, typ: TypeID) -> ConstantID { +pub fn generate_constant(editor: &mut FunctionEditor, typ: TypeID) -> ConstantID { let t = editor.get_type(typ).clone(); match t { diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs index 4ac5a732d12f87b6eb76c2771c2c5addfa42cb50..a888cf74dc223a8e52466daa1bacdec533809de4 100644 --- a/juno_scheduler/src/ir.rs +++ b/juno_scheduler/src/ir.rs @@ -56,6 +56,7 @@ impl Pass { Pass::Print => num == 1, Pass::Rename => num == 1, Pass::SROA => num == 0 || num == 1, + Pass::InterproceduralSROA => num == 0 || num == 1, Pass::Xdot => num == 0 || num == 1, _ => num == 0, } @@ -70,6 +71,7 @@ impl Pass { Pass::Print => "1", Pass::Rename => "1", Pass::SROA => "0 or 1", + Pass::InterproceduralSROA => "0 or 1", Pass::Xdot => "0 or 1", _ => "0", }