From c9445e5256cab3279f3119fd2baa8bfd2e4cc8a8 Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Mon, 18 Nov 2024 14:50:21 -0600 Subject: [PATCH] Start rewriting SROA --- hercules_ir/src/ir.rs | 8 + hercules_opt/src/pass.rs | 31 ++- hercules_opt/src/sroa.rs | 507 ++++++++++++++++----------------------- juno_frontend/src/lib.rs | 6 + juno_samples/products.jn | 7 + 5 files changed, 250 insertions(+), 309 deletions(-) create mode 100644 juno_samples/products.jn diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index 5cf549a8..1486f5ce 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -728,6 +728,14 @@ impl Type { None } } + + pub fn try_product(&self) -> Option<&[TypeID]> { + if let Type::Product(ts) = self { + Some(ts) + } else { + None + } + } } impl Constant { diff --git a/hercules_opt/src/pass.rs b/hercules_opt/src/pass.rs index cb56f709..a7d48efb 100644 --- a/hercules_opt/src/pass.rs +++ b/hercules_opt/src/pass.rs @@ -494,16 +494,37 @@ impl PassManager { let reverse_postorders = self.reverse_postorders.as_ref().unwrap(); let typing = self.typing.as_ref().unwrap(); for idx in 0..self.module.functions.len() { - sroa( + let constants_ref = + RefCell::new(std::mem::take(&mut self.module.constants)); + let dynamic_constants_ref = + RefCell::new(std::mem::take(&mut self.module.dynamic_constants)); + let types_ref = RefCell::new(std::mem::take(&mut self.module.types)); + let mut editor = FunctionEditor::new( &mut self.module.functions[idx], + &constants_ref, + &dynamic_constants_ref, + &types_ref, &def_uses[idx], + ); + sroa( + &mut editor, &reverse_postorders[idx], - &typing[idx], - &self.module.types, - &mut self.module.constants, + &typing[idx] ); + + self.module.constants = constants_ref.take(); + self.module.dynamic_constants = dynamic_constants_ref.take(); + self.module.types = types_ref.take(); + + let edits = &editor.edits(); + if let Some(plans) = self.plans.as_mut() { + repair_plan(&mut plans[idx], &self.module.functions[idx], edits); + } + let grave_mapping = self.module.functions[idx].delete_gravestones(); + if let Some(plans) = self.plans.as_mut() { + plans[idx].fix_gravestones(&grave_mapping); + } } - self.legacy_repair_plan(); self.clear_analyses(); } Pass::Inline => { diff --git a/hercules_opt/src/sroa.rs b/hercules_opt/src/sroa.rs index cb5ecd25..ae5d76dc 100644 --- a/hercules_opt/src/sroa.rs +++ b/hercules_opt/src/sroa.rs @@ -1,15 +1,12 @@ -extern crate bitvec; extern crate hercules_ir; -use std::collections::HashMap; -use std::iter::zip; - -use self::bitvec::prelude::*; +use std::collections::{HashMap, LinkedList}; use self::hercules_ir::dataflow::*; -use self::hercules_ir::def_use::*; use self::hercules_ir::ir::*; +use crate::*; + /* * Top level function to run SROA, intraprocedurally. Product values can be used * and created by a relatively small number of nodes. 
Here are *all* of them: @@ -20,11 +17,11 @@ use self::hercules_ir::ir::*; * - Reduce: similarly to phis, reduce nodes can cycle product values through * reduction loops - these get broken up into reduces on the fields * - * + Return: can return a product - these are untouched, and are the sinks for - * unbroken product values + * - Return: can return a product - the product values will be constructed + * at the return site * - * + Parameter: can introduce a product - these are untouched, and are the - * sources for unbroken product values + * - Parameter: can introduce a product - reads will be introduced for each + * field * * - Constant: can introduce a product - these are broken up into constants for * the individual fields @@ -32,334 +29,236 @@ use self::hercules_ir::ir::*; * - Ternary: the select ternary operator can select between products - these * are broken up into ternary nodes for the individual fields * - * + Call: the call node can use a product value as an argument to another - * function, and can produce a product value as a result - these are - * untouched, and are the sink and source for unbroken product values + * - Call: the call node can use a product value as an argument to another + * function, and can produce a product value as a result. Argument values + * will be constructed at the call site and the return value will be broken + * into individual fields * * - Read: the read node reads primitive fields from product values - these get - * replaced by a direct use of the field value from the broken product value, - * but are retained when the product value is unbroken + * replaced by a direct use of the field value * * - Write: the write node writes primitive fields in product values - these get - * replaced by a direct def of the field value from the broken product value, - * but are retained when the product value is unbroken - * - * The nodes above with the list marker "+" are retained for maintaining API/ABI - * compatability with other Hercules functions and the host code. These are - * called "sink" or "source" nodes in comments below. + * replaced by a direct def of the field value */ -pub fn sroa( - function: &mut Function, - def_use: &ImmutableDefUseMap, - reverse_postorder: &Vec<NodeID>, - typing: &Vec<TypeID>, - types: &Vec<Type>, - constants: &mut Vec<Constant>, -) { - // Determine which sources of product values we want to try breaking up. We - // can determine easily on the soure side if a node produces a product that - // shouldn't be broken up by just examining the node type. However, the way - // that products are used is also important for determining if the product - // can be broken up. We backward dataflow this info to the sources of - // product values. - #[derive(PartialEq, Eq, Clone, Debug)] - enum ProductUseLattice { - // The product value used by this node is eventually used by a sink. - UsedBySink, - // This node uses multiple product values - the stored node ID indicates - // which is eventually used by a sink. This lattice value is produced by - // read and write nodes implementing partial indexing. - SpecificUsedBySink(NodeID), - // This node doesn't use a product node, or the product node it does use - // is not in turn used by a sink. 
- UnusedBySink, - } +pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: &Vec<TypeID>) { + // This map stores a map from NodeID and fields to the NodeID which contains just that field of + // the original value + let mut field_map : HashMap<(NodeID, Vec<Index>), NodeID> = HashMap::new(); - impl Semilattice for ProductUseLattice { - fn meet(a: &Self, b: &Self) -> Self { - match (a, b) { - (Self::UsedBySink, _) | (_, Self::UsedBySink) => Self::UsedBySink, - (Self::SpecificUsedBySink(id1), Self::SpecificUsedBySink(id2)) => { - if id1 == id2 { - Self::SpecificUsedBySink(*id1) - } else { - Self::UsedBySink - } - } - (Self::SpecificUsedBySink(id), _) | (_, Self::SpecificUsedBySink(id)) => { - Self::SpecificUsedBySink(*id) + // First: determine all nodes which interact with products (as described above) + let mut product_nodes = vec![]; + // We track call and return nodes separately since they (may) require constructing new products + // for the call's arguments or the return's value + let mut call_return_nodes = vec![]; + // We track writes separately since they should be processed once their input product has been + // processed, so we handle them after all other nodes (though before reads) + let mut write_nodes = vec![]; + // We also track reads separately because they aren't handled until all other nodes have been + // processed + let mut read_nodes = vec![]; + + let func = editor.func(); + + for node in reverse_postorder { + match func.nodes[node.idx()] { + Node::Phi { .. } | Node::Reduce { .. } | Node::Parameter { .. } + | Node::Constant { .. } + | Node::Ternary { first: _, second: _, third: _, op: TernaryOperator::Select } + if editor.get_type(types[node.idx()]).is_product() => product_nodes.push(node), + + Node::Write { .. } + if editor.get_type(types[node.idx()]).is_product() => write_nodes.push(node), + Node::Read { collect, .. } + if editor.get_type(types[collect.idx()]).is_product() => read_nodes.push(node), + + // We add all calls to the call/return list and check their arguments later + Node::Call { .. } => { + call_return_nodes.push(node); + if editor.get_type(types[node.idx()]).is_product() { + read_nodes.push(node); } - _ => Self::UnusedBySink, } - } + Node::Return { control: _, data } + if editor.get_type(types[data.idx()]).is_product() => + call_return_nodes.push(node), - fn bottom() -> Self { - Self::UsedBySink - } - - fn top() -> Self { - Self::UnusedBySink + _ => () } } - // Run dataflow analysis to find which product values are used by a sink. - let product_uses = backward_dataflow(function, def_use, reverse_postorder, |succ_outs, id| { - match function.nodes[id.idx()] { - Node::Return { - control: _, - data: _, - } => { - if types[typing[id.idx()].idx()].is_product() { - ProductUseLattice::UsedBySink - } else { - ProductUseLattice::UnusedBySink - } - } - Node::Call { - control: _, - function: _, - dynamic_constants: _, - args: _, - } => todo!(), - // For reads and writes, we only want to propagate the use of the - // product to the collect input of the node. - Node::Read { - collect, - indices: _, + // Next, we handle calls and returns. For returns, we will insert nodes that read each field of + // the returned product and then write them into a new product. These writes are not put into + // the list of product nodes since they must remain but the reads are so that they will be + // replaced later on. + // For calls, we do a similar process for each (product) argument. 
Additionally, if the call + // returns a product, we create reads for each field in that product and store it into our + // field map + for node in call_return_nodes { + match &editor.func().nodes[node.idx()] { + Node::Return { control, data } => { + assert!(editor.get_type(types[data.idx()]).is_product()); + let control = *control; + let new_data = reconstruct_product(editor, types[data.idx()], *data); + editor.edit(|mut edit| { + edit.add_node(Node::Return { control, data: new_data }); + edit.delete_node(*node) + }); } - | Node::Write { - collect, - data: _, - indices: _, - } => { - let meet = succ_outs - .iter() - .fold(ProductUseLattice::top(), |acc, latt| { - ProductUseLattice::meet(&acc, latt) - }); - if meet == ProductUseLattice::UnusedBySink { - ProductUseLattice::UnusedBySink - } else { - ProductUseLattice::SpecificUsedBySink(collect) - } - } - // For non-sink nodes. - _ => { - if function.nodes[id.idx()].is_control() { - return ProductUseLattice::UnusedBySink; - } - let meet = succ_outs - .iter() - .fold(ProductUseLattice::top(), |acc, latt| { - ProductUseLattice::meet(&acc, latt) - }); - if let ProductUseLattice::SpecificUsedBySink(meet_id) = meet { - if meet_id == id { - ProductUseLattice::UsedBySink + Node::Call { control, function, dynamic_constants, args } => { + let control = *control; + let function = *function; + let dynamic_constants = dynamic_constants.clone(); + let args = args.clone(); + + // If the call returns a product, we generate reads for each field + let fields = + if editor.get_type(types[node.idx()]).is_product() { + Some(generate_reads(editor, types[node.idx()], *node)) } else { - ProductUseLattice::UnusedBySink + None + }; + + let mut new_args = vec![]; + for arg in args { + if editor.get_type(types[arg.idx()]).is_product() { + new_args.push(reconstruct_product(editor, types[arg.idx()], arg)); + } else { + new_args.push(arg); } - } else { - meet } + editor.edit(|mut edit| { + let new_call = edit.add_node(Node::Call { control, function, dynamic_constants, args: new_args.into() }); + let edit = edit.replace_all_uses(*node, new_call)?; + let edit = edit.delete_node(*node)?; + + match fields { + None => {} + Some(fields) => { + for (idx, node) in fields { + field_map.insert((new_call, idx), node); + } + } + } + + Ok(edit) + }); } + _ => panic!("Processing non-call or return node") } - }); + } - // Only product values introduced as constants can be replaced by scalars. - let to_sroa: Vec<(NodeID, ConstantID)> = product_uses - .into_iter() - .enumerate() - .filter_map(|(node_idx, product_use)| { - if ProductUseLattice::UnusedBySink == product_use - && types[typing[node_idx].idx()].is_product() - { - function.nodes[node_idx] - .try_constant() - .map(|cons_id| (NodeID::new(node_idx), cons_id)) - } else { - None - } - }) - .collect(); + // Now, we process all other non-read/write nodes that deal with products. + // The first step is to identify the NodeIDs that contain each field of each of these nodes + todo!() +} - // Perform SROA. TODO: repair def-use when there are multiple product - // constants to SROA away. - assert!(to_sroa.len() < 2); - for (constant_node_id, constant_id) in to_sroa { - // Get the field constants to replace the product constant with.
- let product_constant = constants[constant_id.idx()].clone(); - let constant_fields = product_constant.try_product_fields().unwrap(); +// Given a product value val of type typ, constructs a copy of that value by extracting all fields +// from that value and then writing them into a new constant +fn reconstruct_product(editor: &mut FunctionEditor, typ: TypeID, val: NodeID) -> NodeID { + let fields = generate_reads(editor, typ, val); + let new_const = generate_constant(editor, typ); - // DFS to find all data nodes that use the product constant. - let to_replace = sroa_dfs(constant_node_id, function, def_use); + // Create a constant node + let mut const_node = None; + editor.edit(|mut edit| { + const_node = Some(edit.add_node(Node::Constant { id: new_const })); + Ok(edit) + }); - // Assemble a mapping from old nodes IDs acting on the product constant - // to new nodes IDs operating on the field constants. - let old_to_new_id_map: HashMap<NodeID, Vec<NodeID>> = to_replace - .iter() - .map(|old_id| match function.nodes[old_id.idx()] { - Node::Phi { - control: _, - data: _, - } - | Node::Reduce { - control: _, - init: _, - reduct: _, - } - | Node::Constant { id: _ } - | Node::Ternary { - op: _, - first: _, - second: _, - third: _, - } - | Node::Write { - collect: _, - data: _, - indices: _, - } => { - let new_ids = (0..constant_fields.len()) - .map(|_| { - let id = NodeID::new(function.nodes.len()); - function.nodes.push(Node::Start); - id - }) - .collect(); - (*old_id, new_ids) - } - Node::Read { - collect: _, - indices: _, - } => (*old_id, vec![]), - _ => panic!("PANIC: Invalid node using a constant product found during SROA."), - }) - .collect(); + // Generate writes for each field + let mut value = const_node.expect("Add node cannot fail"); + for (idx, val) in fields { + editor.edit(|mut edit| { + value = edit.add_node(Node::Write { collect: value, data: val, indices: idx.into() }); + Ok(edit) + }); + } - // Replace the old nodes with the new nodes. Since we've already - // allocated the node IDs, at this point we can iterate through the to- - // replace nodes in an arbitrary order. - for (old_id, new_ids) in &old_to_new_id_map { - // First, add the new nodes to the node list. - let node = function.nodes[old_id.idx()].clone(); - match node { - // Replace the original constant with constants for each field. - Node::Constant { id: _ } => { - for (new_id, field_id) in zip(new_ids.iter(), constant_fields.iter()) { - function.nodes[new_id.idx()] = Node::Constant { id: *field_id }; - } - } - // Replace writes using the constant as the data use with a - // series of writes writing the invidiual constant fields. TODO: - // handle the case where the constant is the collect use of the - // write node. - Node::Write { - collect, - data, - ref indices, - } => { - // Create the write chain. - assert!(old_to_new_id_map.contains_key(&data), "PANIC: Can't handle case where write node depends on constant to SROA in the collect use yet."); - let mut collect_def = collect; - for (idx, (new_id, new_data_def)) in - zip(new_ids.iter(), old_to_new_id_map[&data].iter()).enumerate() - { - let mut new_indices = indices.clone().into_vec(); - new_indices.push(Index::Field(idx)); - function.nodes[new_id.idx()] = Node::Write { - collect: collect_def, - data: *new_data_def, - indices: new_indices.into_boxed_slice(), - }; - collect_def = *new_id; - } + value +} - // Replace uses of the old write with the new write. 
- for user in def_use.get_users(*old_id) { - get_uses_mut(&mut function.nodes[user.idx()]).map(*old_id, collect_def); - } - } - _ => todo!(), - } +// Given a node val of type typ, adds nodes to the function which read all (leaf) fields of val and +// returns a list of pairs of the indices and the node that reads that index +fn generate_reads(editor: &mut FunctionEditor, typ: TypeID, val: NodeID) -> Vec<(Vec<Index>, NodeID)> { + generate_reads_at_index(editor, typ, val, vec![]).into_iter().collect::<_>() +} - // Delete the old node. - function.nodes[old_id.idx()] = Node::Start; +// Given a node val of type which at the indices idx has type typ, construct reads of all (leaf) +// fields within this sub-value of val and return the correspondence list +fn generate_reads_at_index(editor: &mut FunctionEditor, typ: TypeID, val: NodeID, idx: Vec<Index>) -> LinkedList<(Vec<Index>, NodeID)> { + let ts : Option<Vec<TypeID>> = { + if let Some(ts) = editor.get_type(typ).try_product() { + Some(ts.into()) + } else { + None } + }; + + if let Some(ts) = ts { + // For product values, we will recurse down each of its fields with an extended index + // and the appropriate type of that field + let mut result = LinkedList::new(); + for (i, t) in ts.into_iter().enumerate() { + let mut new_idx = idx.clone(); + new_idx.push(Index::Field(i)); + result.append(&mut generate_reads_at_index(editor, t, val, new_idx)); + } + + result + } else { + // For non-product types, we've reached a leaf so we generate the read and return it's + // information + let mut read_id = None; + editor.edit(|mut edit| { + read_id = Some(edit.add_node(Node::Read { collect: val, indices: idx.clone().into() })); + Ok(edit) + }); + + LinkedList::from([(idx, read_id.expect("Add node cannot fail"))]) } } -fn sroa_dfs(src: NodeID, function: &Function, def_uses: &ImmutableDefUseMap) -> Vec<NodeID> { - // Initialize order vector and bitset for tracking which nodes have been - // visited. - let order = Vec::with_capacity(def_uses.num_nodes()); - let visited = bitvec![u8, Lsb0; 0; def_uses.num_nodes()]; - - // Order and visited are threaded through arguments / return pair of - // sroa_dfs_helper for ownership reasons. - let (order, _) = sroa_dfs_helper(src, src, function, def_uses, order, visited); - order +macro_rules! add_const { + ($editor:ident, $const:expr) => {{ + let mut res = None; + $editor.edit(|mut edit| { + res = Some(edit.add_constant($const)); + Ok(edit) + }); + res.expect("Add constant cannot fail") + }} } -fn sroa_dfs_helper( - node: NodeID, - def: NodeID, - function: &Function, - def_uses: &ImmutableDefUseMap, - mut order: Vec<NodeID>, - mut visited: BitVec<u8, Lsb0>, -) -> (Vec<NodeID>, BitVec<u8, Lsb0>) { - if visited[node.idx()] { - // If already visited, return early. - (order, visited) - } else { - // Set visited to true. - visited.set(node.idx(), true); +// Given a type, builds a default constant of that type +fn generate_constant(editor: &mut FunctionEditor, typ: TypeID) -> ConstantID { + let t = editor.get_type(typ).clone(); - // Before iterating users, push this node. 
- order.push(node); - match function.nodes[node.idx()] { - Node::Phi { - control: _, - data: _, + match t { + Type::Product(ts) => { + let mut cs = vec![]; + for t in ts { + cs.push(generate_constant(editor, t)); } - | Node::Reduce { - control: _, - init: _, - reduct: _, - } - | Node::Constant { id: _ } - | Node::Ternary { - op: _, - first: _, - second: _, - third: _, - } => {} - Node::Read { - collect, - indices: _, - } => { - assert_eq!(def, collect); - return (order, visited); - } - Node::Write { - collect, - data, - indices: _, - } => { - if def == data { - return (order, visited); - } - assert_eq!(def, collect); - } - _ => panic!("PANIC: Invalid node using a constant product found during SROA."), + add_const!(editor, Constant::Product(typ, cs.into())) } - - // Iterate over users, if we shouldn't stop here. - for user in def_uses.get_users(node) { - (order, visited) = sroa_dfs_helper(*user, node, function, def_uses, order, visited); + Type::Boolean => add_const!(editor, Constant::Boolean(false)), + Type::Integer8 => add_const!(editor, Constant::Integer8(0)), + Type::Integer16 => add_const!(editor, Constant::Integer16(0)), + Type::Integer32 => add_const!(editor, Constant::Integer32(0)), + Type::Integer64 => add_const!(editor, Constant::Integer64(0)), + Type::UnsignedInteger8 => add_const!(editor, Constant::UnsignedInteger8(0)), + Type::UnsignedInteger16 => add_const!(editor, Constant::UnsignedInteger16(0)), + Type::UnsignedInteger32 => add_const!(editor, Constant::UnsignedInteger32(0)), + Type::UnsignedInteger64 => add_const!(editor, Constant::UnsignedInteger64(0)), + Type::Float32 => add_const!(editor, Constant::Float32(ordered_float::OrderedFloat(0.0))), + Type::Float64 => add_const!(editor, Constant::Float64(ordered_float::OrderedFloat(0.0))), + Type::Summation(ts) => { + let const_id = generate_constant(editor, ts[0]); + add_const!(editor, Constant::Summation(typ, 0, const_id)) } - - (order, visited) + Type::Array(elem, _) => { + add_const!(editor, Constant::Array(typ)) + } + Type::Control => panic!("Cannot create constant of control type") } } diff --git a/juno_frontend/src/lib.rs b/juno_frontend/src/lib.rs index cccadbdd..d4263fe8 100644 --- a/juno_frontend/src/lib.rs +++ b/juno_frontend/src/lib.rs @@ -151,6 +151,12 @@ pub fn compile_ir( if x_dot { pm.add_pass(hercules_opt::pass::Pass::Xdot(true)); } + // TEMPORARY + add_pass!(pm, verify, SROA); + if x_dot { + pm.add_pass(hercules_opt::pass::Pass::Xdot(true)); + } + // TEMPORARY add_pass!(pm, verify, Inline); add_pass!(pm, verify, CCP); add_pass!(pm, verify, DCE); diff --git a/juno_samples/products.jn b/juno_samples/products.jn new file mode 100644 index 00000000..a6dd8862 --- /dev/null +++ b/juno_samples/products.jn @@ -0,0 +1,7 @@ +fn test_call(x : i32, y : f32) -> (i32, f32) { + return (x, y); +} + +fn test(x : i32, y : f32) -> (i32, f32) { + return test_call(x, y); +} -- GitLab
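
For reference, a minimal, self-contained sketch of the leaf-field enumeration that generate_reads_at_index performs: products recurse into each field with an extended index path, and leaves yield the path that a Read node would use. This uses simplified stand-in types (Ty, leaf_paths, plain usize paths) rather than the real hercules_ir API and FunctionEditor, so it is illustrative only, not Hercules code.

// Simplified stand-in for a product type; not the hercules_ir Type enum.
enum Ty {
    Prod(Vec<Ty>),
    Scalar(&'static str),
}

// Mirrors the recursion in generate_reads_at_index: for a product, recurse
// per field with the index path extended; for a leaf, emit the path.
fn leaf_paths(ty: &Ty, prefix: Vec<usize>) -> Vec<(Vec<usize>, &'static str)> {
    match ty {
        Ty::Prod(fields) => fields
            .iter()
            .enumerate()
            .flat_map(|(i, t)| {
                let mut idx = prefix.clone();
                idx.push(i);
                leaf_paths(t, idx)
            })
            .collect(),
        Ty::Scalar(name) => vec![(prefix, *name)],
    }
}

fn main() {
    // (i32, (f32, bool)) has leaves at index paths [0], [1, 0], [1, 1];
    // SROA would emit one read (and later one write) per such path.
    let ty = Ty::Prod(vec![
        Ty::Scalar("i32"),
        Ty::Prod(vec![Ty::Scalar("f32"), Ty::Scalar("bool")]),
    ]);
    for (path, name) in leaf_paths(&ty, vec![]) {
        println!("{:?} -> {}", path, name);
    }
}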