From c9445e5256cab3279f3119fd2baa8bfd2e4cc8a8 Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Mon, 18 Nov 2024 14:50:21 -0600 Subject: [PATCH 01/12] Start rewriting SROA --- hercules_ir/src/ir.rs | 8 + hercules_opt/src/pass.rs | 31 ++- hercules_opt/src/sroa.rs | 507 ++++++++++++++++----------------------- juno_frontend/src/lib.rs | 6 + juno_samples/products.jn | 7 + 5 files changed, 250 insertions(+), 309 deletions(-) create mode 100644 juno_samples/products.jn diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index 5cf549a8..1486f5ce 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -728,6 +728,14 @@ impl Type { None } } + + pub fn try_product(&self) -> Option<&[TypeID]> { + if let Type::Product(ts) = self { + Some(ts) + } else { + None + } + } } impl Constant { diff --git a/hercules_opt/src/pass.rs b/hercules_opt/src/pass.rs index cb56f709..a7d48efb 100644 --- a/hercules_opt/src/pass.rs +++ b/hercules_opt/src/pass.rs @@ -494,16 +494,37 @@ impl PassManager { let reverse_postorders = self.reverse_postorders.as_ref().unwrap(); let typing = self.typing.as_ref().unwrap(); for idx in 0..self.module.functions.len() { - sroa( + let constants_ref = + RefCell::new(std::mem::take(&mut self.module.constants)); + let dynamic_constants_ref = + RefCell::new(std::mem::take(&mut self.module.dynamic_constants)); + let types_ref = RefCell::new(std::mem::take(&mut self.module.types)); + let mut editor = FunctionEditor::new( &mut self.module.functions[idx], + &constants_ref, + &dynamic_constants_ref, + &types_ref, &def_uses[idx], + ); + sroa( + &mut editor, &reverse_postorders[idx], - &typing[idx], - &self.module.types, - &mut self.module.constants, + &typing[idx] ); + + self.module.constants = constants_ref.take(); + self.module.dynamic_constants = dynamic_constants_ref.take(); + self.module.types = types_ref.take(); + + let edits = &editor.edits(); + if let Some(plans) = self.plans.as_mut() { + repair_plan(&mut plans[idx], 
&self.module.functions[idx], edits); + } + let grave_mapping = self.module.functions[idx].delete_gravestones(); + if let Some(plans) = self.plans.as_mut() { + plans[idx].fix_gravestones(&grave_mapping); + } } - self.legacy_repair_plan(); self.clear_analyses(); } Pass::Inline => { diff --git a/hercules_opt/src/sroa.rs b/hercules_opt/src/sroa.rs index cb5ecd25..ae5d76dc 100644 --- a/hercules_opt/src/sroa.rs +++ b/hercules_opt/src/sroa.rs @@ -1,15 +1,12 @@ -extern crate bitvec; extern crate hercules_ir; -use std::collections::HashMap; -use std::iter::zip; - -use self::bitvec::prelude::*; +use std::collections::{HashMap, LinkedList}; use self::hercules_ir::dataflow::*; -use self::hercules_ir::def_use::*; use self::hercules_ir::ir::*; +use crate::*; + /* * Top level function to run SROA, intraprocedurally. Product values can be used * and created by a relatively small number of nodes. Here are *all* of them: @@ -20,11 +17,11 @@ use self::hercules_ir::ir::*; * - Reduce: similarly to phis, reduce nodes can cycle product values through * reduction loops - these get broken up into reduces on the fields * - * + Return: can return a product - these are untouched, and are the sinks for - * unbroken product values + * - Return: can return a product - the product values will be constructed + * at the return site * - * + Parameter: can introduce a product - these are untouched, and are the - * sources for unbroken product values + * - Parameter: can introduce a product - reads will be introduced for each + * field * * - Constant: can introduce a product - these are broken up into constants for * the individual fields @@ -32,334 +29,236 @@ use self::hercules_ir::ir::*; * - Ternary: the select ternary operator can select between products - these * are broken up into ternary nodes for the individual fields * - * + Call: the call node can use a product value as an argument to another - * function, and can produce a product value as a result - these are - * untouched, and are the sink 
and source for unbroken product values + * - Call: the call node can use a product value as an argument to another + * function, and can produce a product value as a result. Argument values + * will be constructed at the call site and the return value will be broken + * into individual fields * * - Read: the read node reads primitive fields from product values - these get - * replaced by a direct use of the field value from the broken product value, - * but are retained when the product value is unbroken + * replaced by a direct use of the field value * * - Write: the write node writes primitive fields in product values - these get - * replaced by a direct def of the field value from the broken product value, - * but are retained when the product value is unbroken - * - * The nodes above with the list marker "+" are retained for maintaining API/ABI - * compatability with other Hercules functions and the host code. These are - * called "sink" or "source" nodes in comments below. + * replaced by a direct def of the field value */ -pub fn sroa( - function: &mut Function, - def_use: &ImmutableDefUseMap, - reverse_postorder: &Vec<NodeID>, - typing: &Vec<TypeID>, - types: &Vec<Type>, - constants: &mut Vec<Constant>, -) { - // Determine which sources of product values we want to try breaking up. We - // can determine easily on the soure side if a node produces a product that - // shouldn't be broken up by just examining the node type. However, the way - // that products are used is also important for determining if the product - // can be broken up. We backward dataflow this info to the sources of - // product values. - #[derive(PartialEq, Eq, Clone, Debug)] - enum ProductUseLattice { - // The product value used by this node is eventually used by a sink. - UsedBySink, - // This node uses multiple product values - the stored node ID indicates - // which is eventually used by a sink. This lattice value is produced by - // read and write nodes implementing partial indexing. 
- SpecificUsedBySink(NodeID), - // This node doesn't use a product node, or the product node it does use - // is not in turn used by a sink. - UnusedBySink, - } +pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: &Vec<TypeID>) { + // This map stores a map from NodeID and fields to the NodeID which contains just that field of + // the original value + let mut field_map : HashMap<(NodeID, Vec<Index>), NodeID> = HashMap::new(); - impl Semilattice for ProductUseLattice { - fn meet(a: &Self, b: &Self) -> Self { - match (a, b) { - (Self::UsedBySink, _) | (_, Self::UsedBySink) => Self::UsedBySink, - (Self::SpecificUsedBySink(id1), Self::SpecificUsedBySink(id2)) => { - if id1 == id2 { - Self::SpecificUsedBySink(*id1) - } else { - Self::UsedBySink - } - } - (Self::SpecificUsedBySink(id), _) | (_, Self::SpecificUsedBySink(id)) => { - Self::SpecificUsedBySink(*id) + // First: determine all nodes which interact with products (as described above) + let mut product_nodes = vec![]; + // We track call and return nodes separately since they (may) require constructing new products + // for the call's arguments or the return's value + let mut call_return_nodes = vec![]; + // We track writes separately since they should be processed once their input product has been + // processed, so we handle them after all other nodes (though before reads) + let mut write_nodes = vec![]; + // We also track reads separately because they aren't handled until all other nodes have been + // processed + let mut read_nodes = vec![]; + + let func = editor.func(); + + for node in reverse_postorder { + match func.nodes[node.idx()] { + Node::Phi { .. } | Node::Reduce { .. } | Node::Parameter { .. } + | Node::Constant { .. } + | Node::Ternary { first: _, second: _, third: _, op: TernaryOperator::Select } + if editor.get_type(types[node.idx()]).is_product() => product_nodes.push(node), + + Node::Write { .. 
} + if editor.get_type(types[node.idx()]).is_product() => write_nodes.push(node), + Node::Read { collect, .. } + if editor.get_type(types[collect.idx()]).is_product() => read_nodes.push(node), + + // We add all calls to the call/return list and check their arguments later + Node::Call { .. } => { + call_return_nodes.push(node); + if editor.get_type(types[node.idx()]).is_product() { + read_nodes.push(node); } - _ => Self::UnusedBySink, } - } + Node::Return { control: _, data } + if editor.get_type(types[data.idx()]).is_product() => + call_return_nodes.push(node), - fn bottom() -> Self { - Self::UsedBySink - } - - fn top() -> Self { - Self::UnusedBySink + _ => () } } - // Run dataflow analysis to find which product values are used by a sink. - let product_uses = backward_dataflow(function, def_use, reverse_postorder, |succ_outs, id| { - match function.nodes[id.idx()] { - Node::Return { - control: _, - data: _, - } => { - if types[typing[id.idx()].idx()].is_product() { - ProductUseLattice::UsedBySink - } else { - ProductUseLattice::UnusedBySink - } - } - Node::Call { - control: _, - function: _, - dynamic_constants: _, - args: _, - } => todo!(), - // For reads and writes, we only want to propagate the use of the - // product to the collect input of the node. - Node::Read { - collect, - indices: _, + // Next, we handle calls and returns. For returns, we will insert nodes that read each field of + // the returned product and then write them into a new product. These writes are not put into + // the list of product nodes since they must remain but the reads are so that they will be + // replaced later on. + // For calls, we do a similar process for each (product) argument. 
Additionally, if the call + // returns a product, we create reads for each field in that product and store it into our + // field map + for node in call_return_nodes { + match &editor.func().nodes[node.idx()] { + Node::Return { control, data } => { + assert!(editor.get_type(types[data.idx()]).is_product()); + let control = *control; + let new_data = reconstruct_product(editor, types[data.idx()], *data); + editor.edit(|mut edit| { + edit.add_node(Node::Return { control, data: new_data }); + edit.delete_node(*node) + }); } - | Node::Write { - collect, - data: _, - indices: _, - } => { - let meet = succ_outs - .iter() - .fold(ProductUseLattice::top(), |acc, latt| { - ProductUseLattice::meet(&acc, latt) - }); - if meet == ProductUseLattice::UnusedBySink { - ProductUseLattice::UnusedBySink - } else { - ProductUseLattice::SpecificUsedBySink(collect) - } - } - // For non-sink nodes. - _ => { - if function.nodes[id.idx()].is_control() { - return ProductUseLattice::UnusedBySink; - } - let meet = succ_outs - .iter() - .fold(ProductUseLattice::top(), |acc, latt| { - ProductUseLattice::meet(&acc, latt) - }); - if let ProductUseLattice::SpecificUsedBySink(meet_id) = meet { - if meet_id == id { - ProductUseLattice::UsedBySink + Node::Call { control, function, dynamic_constants, args } => { + let control = *control; + let function = *function; + let dynamic_constants = dynamic_constants.clone(); + let args = args.clone(); + + // If the call returns a product, we generate reads for each field + let fields = + if editor.get_type(types[node.idx()]).is_product() { + Some(generate_reads(editor, types[node.idx()], *node)) } else { - ProductUseLattice::UnusedBySink + None + }; + + let mut new_args = vec![]; + for arg in args { + if editor.get_type(types[arg.idx()]).is_product() { + new_args.push(reconstruct_product(editor, types[arg.idx()], arg)); + } else { + new_args.push(arg); } - } else { - meet } + editor.dit(|mut edit| { + let new_call = edit.add_node(Node::Call { control, 
function, dynamic_constants, args: new_args.into() }); + let edit = edit.replace_all_uses(*node, new_call)?; + let edit = edit.delete_node(*node)?; + + match fields { + None => {} + Some(fields) => { + for (idx, node) in fields { + field_map.insert((new_call, idx), node); + } + } + } + + Ok(edit) + }); } + _ => panic!("Processing non-call or return node") } - }); + } - // Only product values introduced as constants can be replaced by scalars. - let to_sroa: Vec<(NodeID, ConstantID)> = product_uses - .into_iter() - .enumerate() - .filter_map(|(node_idx, product_use)| { - if ProductUseLattice::UnusedBySink == product_use - && types[typing[node_idx].idx()].is_product() - { - function.nodes[node_idx] - .try_constant() - .map(|cons_id| (NodeID::new(node_idx), cons_id)) - } else { - None - } - }) - .collect(); + // Now, we process all other non-read/write nodes that deal with products. + // The first step is to identify the NodeIDs that contain each field of each of these nodes + todo!() +} - // Perform SROA. TODO: repair def-use when there are multiple product - // constants to SROA away. - assert!(to_sroa.len() < 2); - for (constant_node_id, constant_id) in to_sroa { - // Get the field constants to replace the product constant with. - let product_constant = constants[constant_id.idx()].clone(); - let constant_fields = product_constant.try_product_fields().unwrap(); +// Given a product value val of type typ, constructs a copy of that value by extracting all fields +// from that value and then writing them into a new constant +fn reconstruct_product(editor: &mut FunctionEditor, typ: TypeID, val: NodeID) -> NodeID { + let fields = generate_reads(editor, typ, val); + let new_const = generate_constant(editor, typ); - // DFS to find all data nodes that use the product constant. 
- let to_replace = sroa_dfs(constant_node_id, function, def_use); + // Create a constant node + let mut const_node = None; + editor.edit(|mut edit| { + const_node = Some(edit.add_node(Node::Constant { id: new_const })); + Ok(edit) + }); - // Assemble a mapping from old nodes IDs acting on the product constant - // to new nodes IDs operating on the field constants. - let old_to_new_id_map: HashMap<NodeID, Vec<NodeID>> = to_replace - .iter() - .map(|old_id| match function.nodes[old_id.idx()] { - Node::Phi { - control: _, - data: _, - } - | Node::Reduce { - control: _, - init: _, - reduct: _, - } - | Node::Constant { id: _ } - | Node::Ternary { - op: _, - first: _, - second: _, - third: _, - } - | Node::Write { - collect: _, - data: _, - indices: _, - } => { - let new_ids = (0..constant_fields.len()) - .map(|_| { - let id = NodeID::new(function.nodes.len()); - function.nodes.push(Node::Start); - id - }) - .collect(); - (*old_id, new_ids) - } - Node::Read { - collect: _, - indices: _, - } => (*old_id, vec![]), - _ => panic!("PANIC: Invalid node using a constant product found during SROA."), - }) - .collect(); + // Generate writes for each field + let mut value = const_node.expect("Add node cannot fail"); + for (idx, val) in fields { + editor.edit(|mut edit| { + value = edit.add_node(Node::Write { collect: value, data: val, indices: idx.into() }); + Ok(edit) + }); + } - // Replace the old nodes with the new nodes. Since we've already - // allocated the node IDs, at this point we can iterate through the to- - // replace nodes in an arbitrary order. - for (old_id, new_ids) in &old_to_new_id_map { - // First, add the new nodes to the node list. - let node = function.nodes[old_id.idx()].clone(); - match node { - // Replace the original constant with constants for each field. 
- Node::Constant { id: _ } => { - for (new_id, field_id) in zip(new_ids.iter(), constant_fields.iter()) { - function.nodes[new_id.idx()] = Node::Constant { id: *field_id }; - } - } - // Replace writes using the constant as the data use with a - // series of writes writing the invidiual constant fields. TODO: - // handle the case where the constant is the collect use of the - // write node. - Node::Write { - collect, - data, - ref indices, - } => { - // Create the write chain. - assert!(old_to_new_id_map.contains_key(&data), "PANIC: Can't handle case where write node depends on constant to SROA in the collect use yet."); - let mut collect_def = collect; - for (idx, (new_id, new_data_def)) in - zip(new_ids.iter(), old_to_new_id_map[&data].iter()).enumerate() - { - let mut new_indices = indices.clone().into_vec(); - new_indices.push(Index::Field(idx)); - function.nodes[new_id.idx()] = Node::Write { - collect: collect_def, - data: *new_data_def, - indices: new_indices.into_boxed_slice(), - }; - collect_def = *new_id; - } + value +} - // Replace uses of the old write with the new write. - for user in def_use.get_users(*old_id) { - get_uses_mut(&mut function.nodes[user.idx()]).map(*old_id, collect_def); - } - } - _ => todo!(), - } +// Given a node val of type typ, adds nodes to the function which read all (leaf) fields of val and +// returns a list of pairs of the indices and the node that reads that index +fn generate_reads(editor: &mut FunctionEditor, typ: TypeID, val: NodeID) -> Vec<(Vec<Index>, NodeID)> { + generate_reads_at_index(editor, typ, val, vec![]).into_iter().collect::<_>() +} - // Delete the old node. 
- function.nodes[old_id.idx()] = Node::Start; +// Given a node val of type which at the indices idx has type typ, construct reads of all (leaf) +// fields within this sub-value of val and return the correspondence list +fn generate_reads_at_index(editor: &mut FunctionEditor, typ: TypeID, val: NodeID, idx: Vec<Index>) -> LinkedList<(Vec<Index>, NodeID)> { + let ts : Option<Vec<TypeID>> = { + if let Some(ts) = editor.get_type(typ).try_product() { + Some(ts.into()) + } else { + None } + }; + + if let Some(ts) = ts { + // For product values, we will recurse down each of its fields with an extended index + // and the appropriate type of that field + let mut result = LinkedList::new(); + for (i, t) in ts.into_iter().enumerate() { + let mut new_idx = idx.clone(); + new_idx.push(Index::Field(i)); + result.append(&mut generate_reads_at_index(editor, t, val, new_idx)); + } + + result + } else { + // For non-product types, we've reached a leaf so we generate the read and return it's + // information + let mut read_id = None; + editor.edit(|mut edit| { + read_id = Some(edit.add_node(Node::Read { collect: val, indices: idx.clone().into() })); + Ok(edit) + }); + + LinkedList::from([(idx, read_id.expect("Add node cannot fail"))]) } } -fn sroa_dfs(src: NodeID, function: &Function, def_uses: &ImmutableDefUseMap) -> Vec<NodeID> { - // Initialize order vector and bitset for tracking which nodes have been - // visited. - let order = Vec::with_capacity(def_uses.num_nodes()); - let visited = bitvec![u8, Lsb0; 0; def_uses.num_nodes()]; - - // Order and visited are threaded through arguments / return pair of - // sroa_dfs_helper for ownership reasons. - let (order, _) = sroa_dfs_helper(src, src, function, def_uses, order, visited); - order +macro_rules! 
add_const { + ($editor:ident, $const:expr) => {{ + let mut res = None; + $editor.edit(|mut edit| { + res = Some(edit.add_constant($const)); + Ok(edit) + }); + res.expect("Add constant cannot fail") + }} } -fn sroa_dfs_helper( - node: NodeID, - def: NodeID, - function: &Function, - def_uses: &ImmutableDefUseMap, - mut order: Vec<NodeID>, - mut visited: BitVec<u8, Lsb0>, -) -> (Vec<NodeID>, BitVec<u8, Lsb0>) { - if visited[node.idx()] { - // If already visited, return early. - (order, visited) - } else { - // Set visited to true. - visited.set(node.idx(), true); +// Given a type, builds a default constant of that type +fn generate_constant(editor: &mut FunctionEditor, typ: TypeID) -> ConstantID { + let t = editor.get_type(typ).clone(); - // Before iterating users, push this node. - order.push(node); - match function.nodes[node.idx()] { - Node::Phi { - control: _, - data: _, + match t { + Type::Product(ts) => { + let mut cs = vec![]; + for t in ts { + cs.push(generate_constant(editor, t)); } - | Node::Reduce { - control: _, - init: _, - reduct: _, - } - | Node::Constant { id: _ } - | Node::Ternary { - op: _, - first: _, - second: _, - third: _, - } => {} - Node::Read { - collect, - indices: _, - } => { - assert_eq!(def, collect); - return (order, visited); - } - Node::Write { - collect, - data, - indices: _, - } => { - if def == data { - return (order, visited); - } - assert_eq!(def, collect); - } - _ => panic!("PANIC: Invalid node using a constant product found during SROA."), + add_const!(editor, Constant::Product(typ, cs.into())) } - - // Iterate over users, if we shouldn't stop here. 
- for user in def_uses.get_users(node) { - (order, visited) = sroa_dfs_helper(*user, node, function, def_uses, order, visited); + Type::Boolean => add_const!(editor, Constant::Boolean(false)), + Type::Integer8 => add_const!(editor, Constant::Integer8(0)), + Type::Integer16 => add_const!(editor, Constant::Integer16(0)), + Type::Integer32 => add_const!(editor, Constant::Integer32(0)), + Type::Integer64 => add_const!(editor, Constant::Integer64(0)), + Type::UnsignedInteger8 => add_const!(editor, Constant::UnsignedInteger8(0)), + Type::UnsignedInteger16 => add_const!(editor, Constant::UnsignedInteger16(0)), + Type::UnsignedInteger32 => add_const!(editor, Constant::UnsignedInteger32(0)), + Type::UnsignedInteger64 => add_const!(editor, Constant::UnsignedInteger64(0)), + Type::Float32 => add_const!(editor, Constant::Float32(ordered_float::OrderedFloat(0.0))), + Type::Float64 => add_const!(editor, Constant::Float64(ordered_float::OrderedFloat(0.0))), + Type::Summation(ts) => { + let const_id = generate_constant(editor, ts[0]); + add_const!(editor, Constant::Summation(typ, 0, const_id)) } - - (order, visited) + Type::Array(elem, _) => { + add_const!(editor, Constant::Array(typ)) + } + Type::Control => panic!("Cannot create constant of control type") } } diff --git a/juno_frontend/src/lib.rs b/juno_frontend/src/lib.rs index cccadbdd..d4263fe8 100644 --- a/juno_frontend/src/lib.rs +++ b/juno_frontend/src/lib.rs @@ -151,6 +151,12 @@ pub fn compile_ir( if x_dot { pm.add_pass(hercules_opt::pass::Pass::Xdot(true)); } + // TEMPORARY + add_pass!(pm, verify, SROA); + if x_dot { + pm.add_pass(hercules_opt::pass::Pass::Xdot(true)); + } + // TEMPORARY add_pass!(pm, verify, Inline); add_pass!(pm, verify, CCP); add_pass!(pm, verify, DCE); diff --git a/juno_samples/products.jn b/juno_samples/products.jn new file mode 100644 index 00000000..a6dd8862 --- /dev/null +++ b/juno_samples/products.jn @@ -0,0 +1,7 @@ +fn test_call(x : i32, y : f32) -> (i32, f32) { + return (x, y); +} + +fn test(x 
: i32, y : f32) -> (i32, f32) { + return test_call(x, y); +} -- GitLab From 6013f4791dbc2e49765e8117365f21cfbdb31f86 Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Mon, 18 Nov 2024 21:20:00 -0600 Subject: [PATCH 02/12] Acyclic SROA working (mostly) --- hercules_opt/src/sroa.rs | 369 +++++++++++++++++++++++++++++++++------ juno_frontend/src/lib.rs | 2 +- 2 files changed, 317 insertions(+), 54 deletions(-) diff --git a/hercules_opt/src/sroa.rs b/hercules_opt/src/sroa.rs index ae5d76dc..b9380197 100644 --- a/hercules_opt/src/sroa.rs +++ b/hercules_opt/src/sroa.rs @@ -1,6 +1,6 @@ extern crate hercules_ir; -use std::collections::{HashMap, LinkedList}; +use std::collections::{HashMap, LinkedList, VecDeque}; use self::hercules_ir::dataflow::*; use self::hercules_ir::ir::*; @@ -41,46 +41,33 @@ use crate::*; * replaced by a direct def of the field value */ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: &Vec<TypeID>) { - // This map stores a map from NodeID and fields to the NodeID which contains just that field of - // the original value - let mut field_map : HashMap<(NodeID, Vec<Index>), NodeID> = HashMap::new(); + // This map stores a map from NodeID to an index tree which can be used to lookup the NodeID + // that contains the corresponding fields of the original value + let mut field_map : HashMap<NodeID, IndexTree<NodeID>> = HashMap::new(); // First: determine all nodes which interact with products (as described above) - let mut product_nodes = vec![]; + let mut product_nodes : Vec<NodeID> = vec![]; // We track call and return nodes separately since they (may) require constructing new products // for the call's arguments or the return's value - let mut call_return_nodes = vec![]; - // We track writes separately since they should be processed once their input product has been - // processed, so we handle them after all other nodes (though before reads) - let mut write_nodes = vec![]; - // We also track 
reads separately because they aren't handled until all other nodes have been - // processed - let mut read_nodes = vec![]; + let mut call_return_nodes : Vec<NodeID> = vec![]; let func = editor.func(); for node in reverse_postorder { match func.nodes[node.idx()] { Node::Phi { .. } | Node::Reduce { .. } | Node::Parameter { .. } - | Node::Constant { .. } + | Node::Constant { .. } | Node::Write { .. } | Node::Ternary { first: _, second: _, third: _, op: TernaryOperator::Select } - if editor.get_type(types[node.idx()]).is_product() => product_nodes.push(node), + if editor.get_type(types[node.idx()]).is_product() => product_nodes.push(*node), - Node::Write { .. } - if editor.get_type(types[node.idx()]).is_product() => write_nodes.push(node), Node::Read { collect, .. } - if editor.get_type(types[collect.idx()]).is_product() => read_nodes.push(node), + if editor.get_type(types[collect.idx()]).is_product() => product_nodes.push(*node), // We add all calls to the call/return list and check their arguments later - Node::Call { .. } => { - call_return_nodes.push(node); - if editor.get_type(types[node.idx()]).is_product() { - read_nodes.push(node); - } - } + Node::Call { .. 
} => call_return_nodes.push(*node), Node::Return { control: _, data } if editor.get_type(types[data.idx()]).is_product() => - call_return_nodes.push(node), + call_return_nodes.push(*node), _ => () } @@ -98,10 +85,10 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: Node::Return { control, data } => { assert!(editor.get_type(types[data.idx()]).is_product()); let control = *control; - let new_data = reconstruct_product(editor, types[data.idx()], *data); + let new_data = reconstruct_product(editor, types[data.idx()], *data, &mut product_nodes); editor.edit(|mut edit| { edit.add_node(Node::Return { control, data: new_data }); - edit.delete_node(*node) + edit.delete_node(node) }); } Node::Call { control, function, dynamic_constants, args } => { @@ -113,7 +100,7 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: // If the call returns a product, we generate reads for each field let fields = if editor.get_type(types[node.idx()]).is_product() { - Some(generate_reads(editor, types[node.idx()], *node)) + Some(generate_reads(editor, types[node.idx()], node)) } else { None }; @@ -121,23 +108,21 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: let mut new_args = vec![]; for arg in args { if editor.get_type(types[arg.idx()]).is_product() { - new_args.push(reconstruct_product(editor, types[arg.idx()], arg)); + new_args.push(reconstruct_product(editor, types[arg.idx()], arg, &mut product_nodes)); } else { new_args.push(arg); } } - editor.dit(|mut edit| { + editor.edit(|mut edit| { let new_call = edit.add_node(Node::Call { control, function, dynamic_constants, args: new_args.into() }); - let edit = edit.replace_all_uses(*node, new_call)?; - let edit = edit.delete_node(*node)?; + let edit = edit.replace_all_uses(node, new_call)?; + let edit = edit.delete_node(node)?; match fields { None => {} Some(fields) => { - for (idx, node) in fields { - field_map.insert((new_call, idx), node); 
- } - } + field_map.insert(new_call, fields); + } } Ok(edit) @@ -147,14 +132,243 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: } } - // Now, we process all other non-read/write nodes that deal with products. - // The first step is to identify the NodeIDs that contain each field of each of these nodes - todo!() + enum WorkItem { + Unhandled(NodeID), + AllocatedPhi { control: NodeID, data: Vec<NodeID>, fields: IndexTree<NodeID> }, + AllocatedReduce { control: NodeID, init: NodeID, reduct: NodeID, fields: IndexTree<NodeID> }, + AllocatedTernary { first: NodeID, second: NodeID, third: NodeID, fields: IndexTree<NodeID> }, + } + + // Now, we process the other nodes that deal with products. + // The first step is to assign new NodeIDs to the nodes that will be split into multiple: phi, + // reduce, parameter, constant, and ternary. + // We do this in several steps: first we break apart parameters and constants + let mut to_delete = vec![]; + let mut worklist: VecDeque<WorkItem> = VecDeque::new(); + + for node in product_nodes { + match editor.func().nodes[node.idx()] { + Node::Parameter { .. } => { + field_map.insert(node, generate_reads(editor, types[node.idx()], node)); + } + Node::Constant { id } => { + field_map.insert(node, generate_constant_fields(editor, id)); + to_delete.push(node); + } + _ => { worklist.push_back(WorkItem::Unhandled(node)); } + } + } + + // Now, we process the remaining nodes, allocating NodeIDs for them and updating the field_map. 
+ // We track the current NodeID and add nodes whenever possible, otherwise we return values to + // the worklist + let mut next_id : usize = editor.func().nodes.len(); + let mut cur_id : usize = editor.func().nodes.len(); + while let Some(mut item) = worklist.pop_front() { + if let WorkItem::Unhandled(node) = item { + match &editor.func().nodes[node.idx()] { + // For phi, reduce, and ternary, we break them apart into separate nodes for each field + Node::Phi { control, data } => { + let control = *control; + let data = data.clone(); + let fields = allocate_fields(editor, types[node.idx()], &mut next_id); + field_map.insert(node, fields.clone()); + + item = WorkItem::AllocatedPhi { + control, + data: data.into(), + fields }; + } + Node::Reduce { control, init, reduct } => { + let control = *control; + let init = *init; + let reduct = *reduct; + let fields = allocate_fields(editor, types[node.idx()], &mut next_id); + field_map.insert(node, fields.clone()); + + item = WorkItem::AllocatedReduce { control, init, reduct, fields }; + } + Node::Ternary { first, second, third, .. 
} => { + let first = *first; + let second = *second; + let third = *third; + let fields = allocate_fields(editor, types[node.idx()], &mut next_id); + field_map.insert(node, fields.clone()); + + item = WorkItem::AllocatedTernary { first, second, third, fields }; + } + + Node::Write { collect, data, indices } => { + if let Some(index_map) = field_map.get(collect) { + if editor.get_type(types[data.idx()]).is_product() { + if let Some(data_idx) = field_map.get(data) { + field_map.insert(node, index_map.clone().replace(indices, data_idx.clone())); + to_delete.push(node); + } else { + worklist.push_back(WorkItem::Unhandled(node)); + } + } else { + field_map.insert(node, index_map.clone().set(indices, *data)); + to_delete.push(node); + } + } else { + worklist.push_back(WorkItem::Unhandled(node)); + } + } + Node::Read { collect, indices } => { + if let Some(index_map) = field_map.get(collect) { + let read_info = index_map.lookup(indices); + match read_info { + IndexTree::Leaf(field) => { + editor.edit(|edit| { + edit.replace_all_uses(node, *field) + }); + } + _ => {} + } + field_map.insert(node, read_info.clone()); + to_delete.push(node); + } else { + worklist.push_back(WorkItem::Unhandled(node)); + } + } + + _ => panic!("Unexpected node type") + } + } + match item { + WorkItem::Unhandled(_) => {} + _ => todo!() + } + } + + // Actually deleting nodes seems to break things right now + /* + println!("{:?}", to_delete); + editor.edit(|mut edit| { + for node in to_delete { + edit = edit.delete_node(node)? 
+ } + Ok(edit) + }); + */ +} + +// An index tree is used to store results at many index lists +#[derive(Clone, Debug)] +enum IndexTree<T> { + Leaf(T), + Node(Vec<IndexTree<T>>), +} + +impl<T: std::fmt::Debug> IndexTree<T> { + fn lookup(&self, idx: &[Index]) -> &IndexTree<T> { + self.lookup_idx(idx, 0) + } + + fn lookup_idx(&self, idx: &[Index], n: usize) -> &IndexTree<T> { + if n < idx.len() { + if let Index::Field(i) = idx[n] { + match self { + IndexTree::Leaf(_) => panic!("Invalid field"), + IndexTree::Node(ts) => ts[i].lookup_idx(idx, n+1), + } + } else { + // TODO: This could be hit because of an array inside of a product + panic!("Error handling lookup of field"); + } + } else { + self + } + } + + fn set(self, idx: &[Index], val: T) -> IndexTree<T> { + self.set_idx(idx, val, 0) + } + + fn set_idx(self, idx: &[Index], val: T, n: usize) -> IndexTree<T> { + if n < idx.len() { + if let Index::Field(i) = idx[n] { + match self { + IndexTree::Leaf(_) => panic!("Invalid field"), + IndexTree::Node(mut ts) => { + if i + 1 == ts.len() { + let t = ts.pop().unwrap(); + ts.push(t.set_idx(idx, val, n+1)); + } else { + let mut t = ts.pop().unwrap(); + std::mem::swap(&mut ts[i], &mut t); + t = t.set_idx(idx, val, n+1); + std::mem::swap(&mut ts[i], &mut t); + ts.push(t); + } + IndexTree::Node(ts) + } + } + } else { + panic!("Error handling set of field"); + } + } else { + IndexTree::Leaf(val) + } + } + + fn replace(self, idx: &[Index], val: IndexTree<T>) -> IndexTree<T> { + self.replace_idx(idx, val, 0) + } + + fn replace_idx(self, idx: &[Index], val: IndexTree<T>, n: usize) -> IndexTree<T> { + if n < idx.len() { + if let Index::Field(i) = idx[n] { + match self { + IndexTree::Leaf(_) => panic!("Invalid field"), + IndexTree::Node(mut ts) => { + if i + 1 == ts.len() { + let t = ts.pop().unwrap(); + ts.push(t.replace_idx(idx, val, n+1)); + } else { + let mut t = ts.pop().unwrap(); + std::mem::swap(&mut ts[i], &mut t); + t = t.replace_idx(idx, val, n+1); + std::mem::swap(&mut 
ts[i], &mut t); + ts.push(t); + } + IndexTree::Node(ts) + } + } + } else { + panic!("Error handling set of field"); + } + } else { + val + } + } + + fn for_each<F>(&self, mut f: F) + where F: FnMut(&Vec<Index>, &T) { + self.for_each_idx(&mut vec![], &mut f); + } + + fn for_each_idx<F>(&self, idx: &mut Vec<Index>, f: &mut F) + where F: FnMut(&Vec<Index>, &T) { + match self { + IndexTree::Leaf(t) => f(idx, t), + IndexTree::Node(ts) => { + for (i, t) in ts.iter().enumerate() { + idx.push(Index::Field(i)); + t.for_each_idx(idx, f); + idx.pop(); + } + } + } + } } // Given a product value val of type typ, constructs a copy of that value by extracting all fields // from that value and then writing them into a new constant -fn reconstruct_product(editor: &mut FunctionEditor, typ: TypeID, val: NodeID) -> NodeID { +// This process also adds all the read nodes that are generated into the read_list so that the +// reads can be eliminated by later parts of SROA +fn reconstruct_product(editor: &mut FunctionEditor, typ: TypeID, val: NodeID, read_list: &mut Vec<NodeID>) -> NodeID { let fields = generate_reads(editor, typ, val); let new_const = generate_constant(editor, typ); @@ -167,44 +381,44 @@ fn reconstruct_product(editor: &mut FunctionEditor, typ: TypeID, val: NodeID) -> // Generate writes for each field let mut value = const_node.expect("Add node cannot fail"); - for (idx, val) in fields { + fields.for_each(|idx: &Vec<Index>, val: &NodeID| { + read_list.push(*val); editor.edit(|mut edit| { - value = edit.add_node(Node::Write { collect: value, data: val, indices: idx.into() }); + value = edit.add_node(Node::Write { collect: value, data: *val, indices: idx.clone().into() }); Ok(edit) }); - } + }); value } // Given a node val of type typ, adds nodes to the function which read all (leaf) fields of val and // returns a list of pairs of the indices and the node that reads that index -fn generate_reads(editor: &mut FunctionEditor, typ: TypeID, val: NodeID) -> Vec<(Vec<Index>, 
NodeID)> { - generate_reads_at_index(editor, typ, val, vec![]).into_iter().collect::<_>() +fn generate_reads(editor: &mut FunctionEditor, typ: TypeID, val: NodeID) -> IndexTree<NodeID> { + let res = generate_reads_at_index(editor, typ, val, vec![]); + res } // Given a node val of type which at the indices idx has type typ, construct reads of all (leaf) // fields within this sub-value of val and return the correspondence list -fn generate_reads_at_index(editor: &mut FunctionEditor, typ: TypeID, val: NodeID, idx: Vec<Index>) -> LinkedList<(Vec<Index>, NodeID)> { - let ts : Option<Vec<TypeID>> = { +fn generate_reads_at_index(editor: &mut FunctionEditor, typ: TypeID, val: NodeID, idx: Vec<Index>) -> IndexTree<NodeID> { + let ts : Option<Vec<TypeID>> = if let Some(ts) = editor.get_type(typ).try_product() { Some(ts.into()) } else { None - } - }; + }; if let Some(ts) = ts { // For product values, we will recurse down each of its fields with an extended index // and the appropriate type of that field - let mut result = LinkedList::new(); + let mut fields = vec![]; for (i, t) in ts.into_iter().enumerate() { let mut new_idx = idx.clone(); new_idx.push(Index::Field(i)); - result.append(&mut generate_reads_at_index(editor, t, val, new_idx)); + fields.push(generate_reads_at_index(editor, t, val, new_idx)); } - - result + IndexTree::Node(fields) } else { // For non-product types, we've reached a leaf so we generate the read and return it's // information @@ -214,7 +428,7 @@ fn generate_reads_at_index(editor: &mut FunctionEditor, typ: TypeID, val: NodeID Ok(edit) }); - LinkedList::from([(idx, read_id.expect("Add node cannot fail"))]) + IndexTree::Leaf(read_id.expect("Add node cannot fail")) } } @@ -262,3 +476,52 @@ fn generate_constant(editor: &mut FunctionEditor, typ: TypeID) -> ConstantID { Type::Control => panic!("Cannot create constant of control type") } } + +// Given a constant cnst adds node to the function which are the constant values of each field and +// returns a list 
of pairs of indices and the node that holds that index +fn generate_constant_fields(editor: &mut FunctionEditor, cnst: ConstantID) -> IndexTree<NodeID> { + let cs : Option<Vec<ConstantID>> = + if let Some(cs) = editor.get_constant(cnst).try_product_fields() { + Some(cs.into()) + } else { + None + }; + + if let Some(cs) = cs { + let mut fields = vec![]; + for c in cs { + fields.push(generate_constant_fields(editor, c)); + } + IndexTree::Node(fields) + } else { + let mut node = None; + editor.edit(|mut edit| { + node = Some(edit.add_node(Node::Constant { id: cnst })); + Ok(edit) + }); + IndexTree::Leaf(node.expect("Add node cannot fail")) + } +} + +// Given a type, return a list of the fields and new NodeIDs for them, with NodeIDs starting at the +// id provided +fn allocate_fields(editor: &FunctionEditor, typ: TypeID, id: &mut usize) -> IndexTree<NodeID> { + let ts : Option<Vec<TypeID>> = + if let Some(ts) = editor.get_type(typ).try_product() { + Some(ts.into()) + } else { + None + }; + + if let Some(ts) = ts { + let mut fields = vec![]; + for t in ts { + fields.push(allocate_fields(editor, t, id)); + } + IndexTree::Node(fields) + } else { + let node = *id; + *id += 1; + IndexTree::Leaf(NodeID::new(node)) + } +} diff --git a/juno_frontend/src/lib.rs b/juno_frontend/src/lib.rs index d4263fe8..c3dc262a 100644 --- a/juno_frontend/src/lib.rs +++ b/juno_frontend/src/lib.rs @@ -157,7 +157,7 @@ pub fn compile_ir( pm.add_pass(hercules_opt::pass::Pass::Xdot(true)); } // TEMPORARY - add_pass!(pm, verify, Inline); + //add_pass!(pm, verify, Inline); add_pass!(pm, verify, CCP); add_pass!(pm, verify, DCE); add_pass!(pm, verify, GVN); -- GitLab From b44f450c359dd94899de5d700943cb65f379d044 Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Mon, 18 Nov 2024 21:23:10 -0600 Subject: [PATCH 03/12] Re-enabled inlining --- juno_frontend/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/juno_frontend/src/lib.rs 
b/juno_frontend/src/lib.rs index c3dc262a..d4263fe8 100644 --- a/juno_frontend/src/lib.rs +++ b/juno_frontend/src/lib.rs @@ -157,7 +157,7 @@ pub fn compile_ir( pm.add_pass(hercules_opt::pass::Pass::Xdot(true)); } // TEMPORARY - //add_pass!(pm, verify, Inline); + add_pass!(pm, verify, Inline); add_pass!(pm, verify, CCP); add_pass!(pm, verify, DCE); add_pass!(pm, verify, GVN); -- GitLab From 881248849fda7222741aca6bb4c37ccb8864fa1b Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Tue, 19 Nov 2024 08:10:43 -0600 Subject: [PATCH 04/12] Fix editor bug when deleting created node --- hercules_opt/src/editor.rs | 4 +++- hercules_opt/src/sroa.rs | 3 --- juno_frontend/src/lib.rs | 7 +++---- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/hercules_opt/src/editor.rs b/hercules_opt/src/editor.rs index 95fd1669..bf7cede2 100644 --- a/hercules_opt/src/editor.rs +++ b/hercules_opt/src/editor.rs @@ -505,7 +505,9 @@ pub fn repair_plan(plan: &mut Plan, new_function: &Function, edits: &[Edit]) { // Step 2: drop schedules for deleted nodes and create empty schedule lists // for added nodes. for deleted in total_edit.0.iter() { - plan.schedules[deleted.idx()] = vec![]; + if deleted.idx() < plan.schedules.len() { + plan.schedules[deleted.idx()] = vec![]; + } } if !total_edit.1.is_empty() { assert_eq!( diff --git a/hercules_opt/src/sroa.rs b/hercules_opt/src/sroa.rs index b9380197..2dd0049b 100644 --- a/hercules_opt/src/sroa.rs +++ b/hercules_opt/src/sroa.rs @@ -243,15 +243,12 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: } // Actually deleting nodes seems to break things right now - /* - println!("{:?}", to_delete); editor.edit(|mut edit| { for node in to_delete { edit = edit.delete_node(node)? 
} Ok(edit) }); - */ } // An index tree is used to store results at many index lists diff --git a/juno_frontend/src/lib.rs b/juno_frontend/src/lib.rs index d4263fe8..49249fc7 100644 --- a/juno_frontend/src/lib.rs +++ b/juno_frontend/src/lib.rs @@ -151,18 +151,17 @@ pub fn compile_ir( if x_dot { pm.add_pass(hercules_opt::pass::Pass::Xdot(true)); } - // TEMPORARY + add_pass!(pm, verify, Inline); + // Run SROA pretty early (though after inlining which can make SROA more effective) so that + // CCP, GVN, etc. can work on the result of SROA add_pass!(pm, verify, SROA); if x_dot { pm.add_pass(hercules_opt::pass::Pass::Xdot(true)); } - // TEMPORARY - add_pass!(pm, verify, Inline); add_pass!(pm, verify, CCP); add_pass!(pm, verify, DCE); add_pass!(pm, verify, GVN); add_pass!(pm, verify, DCE); - //add_pass!(pm, verify, SROA); if x_dot { pm.add_pass(hercules_opt::pass::Pass::Xdot(true)); } -- GitLab From 2021ea3b0b5f45add9409b4c649470ac7a642ac6 Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Tue, 19 Nov 2024 08:14:04 -0600 Subject: [PATCH 05/12] Comment change in editor and format --- hercules_opt/src/editor.rs | 2 + hercules_opt/src/pass.rs | 6 +- hercules_opt/src/sroa.rs | 224 ++++++++++++++++++++++++++----------- 3 files changed, 159 insertions(+), 73 deletions(-) diff --git a/hercules_opt/src/editor.rs b/hercules_opt/src/editor.rs index bf7cede2..2410a8ef 100644 --- a/hercules_opt/src/editor.rs +++ b/hercules_opt/src/editor.rs @@ -505,6 +505,8 @@ pub fn repair_plan(plan: &mut Plan, new_function: &Function, edits: &[Edit]) { // Step 2: drop schedules for deleted nodes and create empty schedule lists // for added nodes. 
for deleted in total_edit.0.iter() { + // Nodes that were created and deleted using the same editor don't have + // an existing schedule, so ignore them if deleted.idx() < plan.schedules.len() { plan.schedules[deleted.idx()] = vec![]; } diff --git a/hercules_opt/src/pass.rs b/hercules_opt/src/pass.rs index a7d48efb..f8310ab7 100644 --- a/hercules_opt/src/pass.rs +++ b/hercules_opt/src/pass.rs @@ -506,11 +506,7 @@ impl PassManager { &types_ref, &def_uses[idx], ); - sroa( - &mut editor, - &reverse_postorders[idx], - &typing[idx] - ); + sroa(&mut editor, &reverse_postorders[idx], &typing[idx]); self.module.constants = constants_ref.take(); self.module.dynamic_constants = dynamic_constants_ref.take(); diff --git a/hercules_opt/src/sroa.rs b/hercules_opt/src/sroa.rs index 2dd0049b..d8e2021a 100644 --- a/hercules_opt/src/sroa.rs +++ b/hercules_opt/src/sroa.rs @@ -43,33 +43,43 @@ use crate::*; pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: &Vec<TypeID>) { // This map stores a map from NodeID to an index tree which can be used to lookup the NodeID // that contains the corresponding fields of the original value - let mut field_map : HashMap<NodeID, IndexTree<NodeID>> = HashMap::new(); + let mut field_map: HashMap<NodeID, IndexTree<NodeID>> = HashMap::new(); // First: determine all nodes which interact with products (as described above) - let mut product_nodes : Vec<NodeID> = vec![]; + let mut product_nodes: Vec<NodeID> = vec![]; // We track call and return nodes separately since they (may) require constructing new products // for the call's arguments or the return's value - let mut call_return_nodes : Vec<NodeID> = vec![]; + let mut call_return_nodes: Vec<NodeID> = vec![]; let func = editor.func(); for node in reverse_postorder { match func.nodes[node.idx()] { - Node::Phi { .. } | Node::Reduce { .. } | Node::Parameter { .. } - | Node::Constant { .. } | Node::Write { .. 
} - | Node::Ternary { first: _, second: _, third: _, op: TernaryOperator::Select } - if editor.get_type(types[node.idx()]).is_product() => product_nodes.push(*node), - - Node::Read { collect, .. } - if editor.get_type(types[collect.idx()]).is_product() => product_nodes.push(*node), + Node::Phi { .. } + | Node::Reduce { .. } + | Node::Parameter { .. } + | Node::Constant { .. } + | Node::Write { .. } + | Node::Ternary { + first: _, + second: _, + third: _, + op: TernaryOperator::Select, + } if editor.get_type(types[node.idx()]).is_product() => product_nodes.push(*node), + + Node::Read { collect, .. } if editor.get_type(types[collect.idx()]).is_product() => { + product_nodes.push(*node) + } // We add all calls to the call/return list and check their arguments later Node::Call { .. } => call_return_nodes.push(*node), Node::Return { control: _, data } if editor.get_type(types[data.idx()]).is_product() => - call_return_nodes.push(*node), + { + call_return_nodes.push(*node) + } - _ => () + _ => (), } } @@ -85,36 +95,54 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: Node::Return { control, data } => { assert!(editor.get_type(types[data.idx()]).is_product()); let control = *control; - let new_data = reconstruct_product(editor, types[data.idx()], *data, &mut product_nodes); + let new_data = + reconstruct_product(editor, types[data.idx()], *data, &mut product_nodes); editor.edit(|mut edit| { - edit.add_node(Node::Return { control, data: new_data }); + edit.add_node(Node::Return { + control, + data: new_data, + }); edit.delete_node(node) }); } - Node::Call { control, function, dynamic_constants, args } => { + Node::Call { + control, + function, + dynamic_constants, + args, + } => { let control = *control; let function = *function; let dynamic_constants = dynamic_constants.clone(); let args = args.clone(); // If the call returns a product, we generate reads for each field - let fields = - if editor.get_type(types[node.idx()]).is_product() { 
- Some(generate_reads(editor, types[node.idx()], node)) - } else { - None - }; + let fields = if editor.get_type(types[node.idx()]).is_product() { + Some(generate_reads(editor, types[node.idx()], node)) + } else { + None + }; let mut new_args = vec![]; for arg in args { if editor.get_type(types[arg.idx()]).is_product() { - new_args.push(reconstruct_product(editor, types[arg.idx()], arg, &mut product_nodes)); + new_args.push(reconstruct_product( + editor, + types[arg.idx()], + arg, + &mut product_nodes, + )); } else { new_args.push(arg); } } editor.edit(|mut edit| { - let new_call = edit.add_node(Node::Call { control, function, dynamic_constants, args: new_args.into() }); + let new_call = edit.add_node(Node::Call { + control, + function, + dynamic_constants, + args: new_args.into(), + }); let edit = edit.replace_all_uses(node, new_call)?; let edit = edit.delete_node(node)?; @@ -128,15 +156,29 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: Ok(edit) }); } - _ => panic!("Processing non-call or return node") + _ => panic!("Processing non-call or return node"), } } enum WorkItem { Unhandled(NodeID), - AllocatedPhi { control: NodeID, data: Vec<NodeID>, fields: IndexTree<NodeID> }, - AllocatedReduce { control: NodeID, init: NodeID, reduct: NodeID, fields: IndexTree<NodeID> }, - AllocatedTernary { first: NodeID, second: NodeID, third: NodeID, fields: IndexTree<NodeID> }, + AllocatedPhi { + control: NodeID, + data: Vec<NodeID>, + fields: IndexTree<NodeID>, + }, + AllocatedReduce { + control: NodeID, + init: NodeID, + reduct: NodeID, + fields: IndexTree<NodeID>, + }, + AllocatedTernary { + first: NodeID, + second: NodeID, + third: NodeID, + fields: IndexTree<NodeID>, + }, } // Now, we process the other nodes that deal with products. 
@@ -155,15 +197,17 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: field_map.insert(node, generate_constant_fields(editor, id)); to_delete.push(node); } - _ => { worklist.push_back(WorkItem::Unhandled(node)); } + _ => { + worklist.push_back(WorkItem::Unhandled(node)); + } } } // Now, we process the remaining nodes, allocating NodeIDs for them and updating the field_map. // We track the current NodeID and add nodes whenever possible, otherwise we return values to // the worklist - let mut next_id : usize = editor.func().nodes.len(); - let mut cur_id : usize = editor.func().nodes.len(); + let mut next_id: usize = editor.func().nodes.len(); + let mut cur_id: usize = editor.func().nodes.len(); while let Some(mut item) = worklist.pop_front() { if let WorkItem::Unhandled(node) = item { match &editor.func().nodes[node.idx()] { @@ -177,32 +221,59 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: item = WorkItem::AllocatedPhi { control, data: data.into(), - fields }; + fields, + }; } - Node::Reduce { control, init, reduct } => { + Node::Reduce { + control, + init, + reduct, + } => { let control = *control; let init = *init; let reduct = *reduct; let fields = allocate_fields(editor, types[node.idx()], &mut next_id); field_map.insert(node, fields.clone()); - item = WorkItem::AllocatedReduce { control, init, reduct, fields }; + item = WorkItem::AllocatedReduce { + control, + init, + reduct, + fields, + }; } - Node::Ternary { first, second, third, .. } => { + Node::Ternary { + first, + second, + third, + .. 
+ } => { let first = *first; let second = *second; let third = *third; let fields = allocate_fields(editor, types[node.idx()], &mut next_id); field_map.insert(node, fields.clone()); - item = WorkItem::AllocatedTernary { first, second, third, fields }; + item = WorkItem::AllocatedTernary { + first, + second, + third, + fields, + }; } - Node::Write { collect, data, indices } => { + Node::Write { + collect, + data, + indices, + } => { if let Some(index_map) = field_map.get(collect) { if editor.get_type(types[data.idx()]).is_product() { if let Some(data_idx) = field_map.get(data) { - field_map.insert(node, index_map.clone().replace(indices, data_idx.clone())); + field_map.insert( + node, + index_map.clone().replace(indices, data_idx.clone()), + ); to_delete.push(node); } else { worklist.push_back(WorkItem::Unhandled(node)); @@ -220,9 +291,7 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: let read_info = index_map.lookup(indices); match read_info { IndexTree::Leaf(field) => { - editor.edit(|edit| { - edit.replace_all_uses(node, *field) - }); + editor.edit(|edit| edit.replace_all_uses(node, *field)); } _ => {} } @@ -233,12 +302,12 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: } } - _ => panic!("Unexpected node type") + _ => panic!("Unexpected node type"), } } match item { WorkItem::Unhandled(_) => {} - _ => todo!() + _ => todo!(), } } @@ -268,7 +337,7 @@ impl<T: std::fmt::Debug> IndexTree<T> { if let Index::Field(i) = idx[n] { match self { IndexTree::Leaf(_) => panic!("Invalid field"), - IndexTree::Node(ts) => ts[i].lookup_idx(idx, n+1), + IndexTree::Node(ts) => ts[i].lookup_idx(idx, n + 1), } } else { // TODO: This could be hit because of an array inside of a product @@ -291,11 +360,11 @@ impl<T: std::fmt::Debug> IndexTree<T> { IndexTree::Node(mut ts) => { if i + 1 == ts.len() { let t = ts.pop().unwrap(); - ts.push(t.set_idx(idx, val, n+1)); + ts.push(t.set_idx(idx, val, n + 1)); } else { let mut 
t = ts.pop().unwrap(); std::mem::swap(&mut ts[i], &mut t); - t = t.set_idx(idx, val, n+1); + t = t.set_idx(idx, val, n + 1); std::mem::swap(&mut ts[i], &mut t); ts.push(t); } @@ -322,11 +391,11 @@ impl<T: std::fmt::Debug> IndexTree<T> { IndexTree::Node(mut ts) => { if i + 1 == ts.len() { let t = ts.pop().unwrap(); - ts.push(t.replace_idx(idx, val, n+1)); + ts.push(t.replace_idx(idx, val, n + 1)); } else { let mut t = ts.pop().unwrap(); std::mem::swap(&mut ts[i], &mut t); - t = t.replace_idx(idx, val, n+1); + t = t.replace_idx(idx, val, n + 1); std::mem::swap(&mut ts[i], &mut t); ts.push(t); } @@ -342,12 +411,16 @@ impl<T: std::fmt::Debug> IndexTree<T> { } fn for_each<F>(&self, mut f: F) - where F: FnMut(&Vec<Index>, &T) { + where + F: FnMut(&Vec<Index>, &T), + { self.for_each_idx(&mut vec![], &mut f); } fn for_each_idx<F>(&self, idx: &mut Vec<Index>, f: &mut F) - where F: FnMut(&Vec<Index>, &T) { + where + F: FnMut(&Vec<Index>, &T), + { match self { IndexTree::Leaf(t) => f(idx, t), IndexTree::Node(ts) => { @@ -365,7 +438,12 @@ impl<T: std::fmt::Debug> IndexTree<T> { // from that value and then writing them into a new constant // This process also adds all the read nodes that are generated into the read_list so that the // reads can be eliminated by later parts of SROA -fn reconstruct_product(editor: &mut FunctionEditor, typ: TypeID, val: NodeID, read_list: &mut Vec<NodeID>) -> NodeID { +fn reconstruct_product( + editor: &mut FunctionEditor, + typ: TypeID, + val: NodeID, + read_list: &mut Vec<NodeID>, +) -> NodeID { let fields = generate_reads(editor, typ, val); let new_const = generate_constant(editor, typ); @@ -381,7 +459,11 @@ fn reconstruct_product(editor: &mut FunctionEditor, typ: TypeID, val: NodeID, re fields.for_each(|idx: &Vec<Index>, val: &NodeID| { read_list.push(*val); editor.edit(|mut edit| { - value = edit.add_node(Node::Write { collect: value, data: *val, indices: idx.clone().into() }); + value = edit.add_node(Node::Write { + collect: value, + data: 
*val, + indices: idx.clone().into(), + }); Ok(edit) }); }); @@ -398,13 +480,17 @@ fn generate_reads(editor: &mut FunctionEditor, typ: TypeID, val: NodeID) -> Inde // Given a node val of type which at the indices idx has type typ, construct reads of all (leaf) // fields within this sub-value of val and return the correspondence list -fn generate_reads_at_index(editor: &mut FunctionEditor, typ: TypeID, val: NodeID, idx: Vec<Index>) -> IndexTree<NodeID> { - let ts : Option<Vec<TypeID>> = - if let Some(ts) = editor.get_type(typ).try_product() { - Some(ts.into()) - } else { - None - }; +fn generate_reads_at_index( + editor: &mut FunctionEditor, + typ: TypeID, + val: NodeID, + idx: Vec<Index>, +) -> IndexTree<NodeID> { + let ts: Option<Vec<TypeID>> = if let Some(ts) = editor.get_type(typ).try_product() { + Some(ts.into()) + } else { + None + }; if let Some(ts) = ts { // For product values, we will recurse down each of its fields with an extended index @@ -421,7 +507,10 @@ fn generate_reads_at_index(editor: &mut FunctionEditor, typ: TypeID, val: NodeID // information let mut read_id = None; editor.edit(|mut edit| { - read_id = Some(edit.add_node(Node::Read { collect: val, indices: idx.clone().into() })); + read_id = Some(edit.add_node(Node::Read { + collect: val, + indices: idx.clone().into(), + })); Ok(edit) }); @@ -437,7 +526,7 @@ macro_rules! 
add_const { Ok(edit) }); res.expect("Add constant cannot fail") - }} + }}; } // Given a type, builds a default constant of that type @@ -470,14 +559,14 @@ fn generate_constant(editor: &mut FunctionEditor, typ: TypeID) -> ConstantID { Type::Array(elem, _) => { add_const!(editor, Constant::Array(typ)) } - Type::Control => panic!("Cannot create constant of control type") + Type::Control => panic!("Cannot create constant of control type"), } } // Given a constant cnst adds node to the function which are the constant values of each field and // returns a list of pairs of indices and the node that holds that index fn generate_constant_fields(editor: &mut FunctionEditor, cnst: ConstantID) -> IndexTree<NodeID> { - let cs : Option<Vec<ConstantID>> = + let cs: Option<Vec<ConstantID>> = if let Some(cs) = editor.get_constant(cnst).try_product_fields() { Some(cs.into()) } else { @@ -503,12 +592,11 @@ fn generate_constant_fields(editor: &mut FunctionEditor, cnst: ConstantID) -> In // Given a type, return a list of the fields and new NodeIDs for them, with NodeIDs starting at the // id provided fn allocate_fields(editor: &FunctionEditor, typ: TypeID, id: &mut usize) -> IndexTree<NodeID> { - let ts : Option<Vec<TypeID>> = - if let Some(ts) = editor.get_type(typ).try_product() { - Some(ts.into()) - } else { - None - }; + let ts: Option<Vec<TypeID>> = if let Some(ts) = editor.get_type(typ).try_product() { + Some(ts.into()) + } else { + None + }; if let Some(ts) = ts { let mut fields = vec![]; -- GitLab From b8f77a5e2fe55d6008c3958563bb0d6354b740a2 Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Tue, 19 Nov 2024 08:58:41 -0600 Subject: [PATCH 06/12] Start on SROA of remaining nodes --- hercules_opt/src/sroa.rs | 184 +++++++++++++++++++++++++++++++++++++-- juno_frontend/src/lib.rs | 4 + juno_samples/products.jn | 4 +- 3 files changed, 185 insertions(+), 7 deletions(-) diff --git a/hercules_opt/src/sroa.rs b/hercules_opt/src/sroa.rs index 
d8e2021a..ccc2605c 100644 --- a/hercules_opt/src/sroa.rs +++ b/hercules_opt/src/sroa.rs @@ -1,6 +1,6 @@ extern crate hercules_ir; -use std::collections::{HashMap, LinkedList, VecDeque}; +use std::collections::{BTreeMap, HashMap, LinkedList, VecDeque}; use self::hercules_ir::dataflow::*; use self::hercules_ir::ir::*; @@ -165,18 +165,21 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: AllocatedPhi { control: NodeID, data: Vec<NodeID>, + node: NodeID, fields: IndexTree<NodeID>, }, AllocatedReduce { control: NodeID, init: NodeID, reduct: NodeID, + node: NodeID, fields: IndexTree<NodeID>, }, AllocatedTernary { first: NodeID, second: NodeID, third: NodeID, + node: NodeID, fields: IndexTree<NodeID>, }, } @@ -204,10 +207,12 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: } // Now, we process the remaining nodes, allocating NodeIDs for them and updating the field_map. - // We track the current NodeID and add nodes whenever possible, otherwise we return values to - // the worklist + // We track the current NodeID and add nodes to a set we maintain of nodes to add (since we + // need to add nodes in a particular order we wait to do that until the end). 
If we don't have + // enough information to process a particular node, we add it back to the worklist let mut next_id: usize = editor.func().nodes.len(); - let mut cur_id: usize = editor.func().nodes.len(); + let mut to_insert = BTreeMap::new(); + while let Some(mut item) = worklist.pop_front() { if let WorkItem::Unhandled(node) = item { match &editor.func().nodes[node.idx()] { @@ -221,6 +226,7 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: item = WorkItem::AllocatedPhi { control, data: data.into(), + node, fields, }; } @@ -239,6 +245,7 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: control, init, reduct, + node, fields, }; } @@ -258,6 +265,7 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: first, second, third, + node, fields, }; } @@ -307,11 +315,127 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: } match item { WorkItem::Unhandled(_) => {} - _ => todo!(), + WorkItem::AllocatedPhi { + control, + data, + node, + fields, + } => { + let mut data_fields = vec![]; + let mut ready = true; + for val in data.iter() { + if let Some(val_fields) = field_map.get(val) { + data_fields.push(val_fields); + } else { + ready = false; + break; + } + } + + if ready { + fields.zip_list(data_fields).for_each(|idx, (res, data)| { + to_insert.insert( + res.idx(), + Node::Phi { + control, + data: data.into_iter().map(|n| **n).collect::<Vec<_>>().into(), + }, + ); + }); + } else { + worklist.push_back(WorkItem::AllocatedPhi { + control, + data, + node, + fields, + }); + } + } + WorkItem::AllocatedReduce { + control, + init, + reduct, + node, + fields, + } => { + if let (Some(init_fields), Some(reduct_fields)) = + (field_map.get(&init), field_map.get(&reduct)) + { + fields.zip(init_fields).zip(reduct_fields).for_each( + |idx, ((res, init), reduct)| { + to_insert.insert( + res.idx(), + Node::Reduce { + control, + init: **init, + reduct: **reduct, + 
}, + ); + }, + ); + to_delete.push(node); + } else { + worklist.push_back(WorkItem::AllocatedReduce { + control, + init, + reduct, + node, + fields, + }); + } + } + WorkItem::AllocatedTernary { + first, + second, + third, + node, + fields, + } => { + if let (Some(fst_fields), Some(snd_fields), Some(thd_fields)) = ( + field_map.get(&first), + field_map.get(&second), + field_map.get(&third), + ) { + fields + .zip(fst_fields) + .zip(snd_fields) + .zip(thd_fields) + .for_each(|idx, (((res, fst), snd), thd)| { + to_insert.insert( + res.idx(), + Node::Ternary { + first: **fst, + second: **snd, + third: **thd, + op: TernaryOperator::Select, + }, + ); + }); + } else { + worklist.push_back(WorkItem::AllocatedTernary { + first, + second, + third, + node, + fields, + }); + } + } } } - // Actually deleting nodes seems to break things right now + // Create new nodes nodes + for (node_id, node) in to_insert { + assert!(node_id == editor.func().nodes.len()); + println!("Inserting {:?} : {:?}", node_id, node); + editor.edit(|mut edit| { + let id = edit.add_node(node); + assert!(node_id == id.idx()); + Ok(edit) + }); + } + + // Remove nodes editor.edit(|mut edit| { for node in to_delete { edit = edit.delete_node(node)? 
@@ -410,6 +534,54 @@ impl<T: std::fmt::Debug> IndexTree<T> { } } + fn zip<'a, A>(self, other: &'a IndexTree<A>) -> IndexTree<(T, &'a A)> { + match (self, other) { + (IndexTree::Leaf(t), IndexTree::Leaf(a)) => IndexTree::Leaf((t, a)), + (IndexTree::Node(t), IndexTree::Node(a)) => { + let mut fields = vec![]; + for (t, a) in t.into_iter().zip(a.iter()) { + fields.push(t.zip(a)); + } + IndexTree::Node(fields) + } + _ => panic!("IndexTrees do not have the same fields, cannot zip"), + } + } + + fn zip_list<'a, A>(self, others: Vec<&'a IndexTree<A>>) -> IndexTree<(T, Vec<&'a A>)> { + match self { + IndexTree::Leaf(t) => { + let mut res = vec![]; + for other in others { + match other { + IndexTree::Leaf(a) => res.push(a), + _ => panic!("IndexTrees do not have the same fields, cannot zip"), + } + } + IndexTree::Leaf((t, res)) + } + IndexTree::Node(t) => { + let mut fields: Vec<Vec<&'a IndexTree<A>>> = vec![vec![]; t.len()]; + for other in others { + match other { + IndexTree::Node(a) => { + for (i, a) in a.iter().enumerate() { + fields[i].push(a); + } + } + _ => panic!("IndexTrees do not have the same fields, cannot zip"), + } + } + IndexTree::Node( + t.into_iter() + .zip(fields.into_iter()) + .map(|(t, f)| t.zip_list(f)) + .collect(), + ) + } + } + } + fn for_each<F>(&self, mut f: F) where F: FnMut(&Vec<Index>, &T), diff --git a/juno_frontend/src/lib.rs b/juno_frontend/src/lib.rs index 49249fc7..e54e88fc 100644 --- a/juno_frontend/src/lib.rs +++ b/juno_frontend/src/lib.rs @@ -151,6 +151,10 @@ pub fn compile_ir( if x_dot { pm.add_pass(hercules_opt::pass::Pass::Xdot(true)); } + add_pass!(pm, verify, SROA); + if x_dot { + pm.add_pass(hercules_opt::pass::Pass::Xdot(true)); + } add_pass!(pm, verify, Inline); // Run SROA pretty early (though after inlining which can make SROA more effective) so that // CCP, GVN, etc. 
can work on the result of SROA diff --git a/juno_samples/products.jn b/juno_samples/products.jn index a6dd8862..00b66aab 100644 --- a/juno_samples/products.jn +++ b/juno_samples/products.jn @@ -1,5 +1,7 @@ fn test_call(x : i32, y : f32) -> (i32, f32) { - return (x, y); + let res = (x, y); + if x < 13 { res = (x + 1, y); } + return res; } fn test(x : i32, y : f32) -> (i32, f32) { -- GitLab From 8df0ef85e304465d964011d74bb8c6f120681e0a Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Tue, 19 Nov 2024 11:39:40 -0600 Subject: [PATCH 07/12] Fix: delete old phi nodes --- hercules_opt/src/sroa.rs | 3 +-- juno_frontend/src/lib.rs | 7 +++---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/hercules_opt/src/sroa.rs b/hercules_opt/src/sroa.rs index ccc2605c..efa7bd90 100644 --- a/hercules_opt/src/sroa.rs +++ b/hercules_opt/src/sroa.rs @@ -342,6 +342,7 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: }, ); }); + to_delete.push(node); } else { worklist.push_back(WorkItem::AllocatedPhi { control, @@ -426,8 +427,6 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: // Create new nodes nodes for (node_id, node) in to_insert { - assert!(node_id == editor.func().nodes.len()); - println!("Inserting {:?} : {:?}", node_id, node); editor.edit(|mut edit| { let id = edit.add_node(node); assert!(node_id == id.idx()); diff --git a/juno_frontend/src/lib.rs b/juno_frontend/src/lib.rs index e54e88fc..59240fe9 100644 --- a/juno_frontend/src/lib.rs +++ b/juno_frontend/src/lib.rs @@ -151,14 +151,13 @@ pub fn compile_ir( if x_dot { pm.add_pass(hercules_opt::pass::Pass::Xdot(true)); } - add_pass!(pm, verify, SROA); - if x_dot { - pm.add_pass(hercules_opt::pass::Pass::Xdot(true)); - } add_pass!(pm, verify, Inline); // Run SROA pretty early (though after inlining which can make SROA more effective) so that // CCP, GVN, etc. 
can work on the result of SROA add_pass!(pm, verify, SROA); + // We run phi-elim again because SROA can introduce new phis that might be able to be + // simplified + add_verified_pass!(pm, verify, PhiElim); if x_dot { pm.add_pass(hercules_opt::pass::Pass::Xdot(true)); } -- GitLab From 09c58d7553f34d92233fd6c8d75c3a9f44a0ab6c Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Tue, 19 Nov 2024 12:36:23 -0600 Subject: [PATCH 08/12] Fix SROA edit issues --- hercules_opt/src/sroa.rs | 18 ++++++++++++------ juno_frontend/src/lib.rs | 4 ++-- juno_samples/products.jn | 8 +++++++- 3 files changed, 21 insertions(+), 9 deletions(-) diff --git a/hercules_opt/src/sroa.rs b/hercules_opt/src/sroa.rs index efa7bd90..924c8d93 100644 --- a/hercules_opt/src/sroa.rs +++ b/hercules_opt/src/sroa.rs @@ -212,6 +212,7 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: // enough information to process a particular node, we add it back to the worklist let mut next_id: usize = editor.func().nodes.len(); let mut to_insert = BTreeMap::new(); + let mut to_replace : Vec<(NodeID, NodeID)> = vec![]; while let Some(mut item) = worklist.pop_front() { if let WorkItem::Unhandled(node) = item { @@ -299,7 +300,7 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: let read_info = index_map.lookup(indices); match read_info { IndexTree::Leaf(field) => { - editor.edit(|edit| edit.replace_all_uses(node, *field)); + to_replace.push((node, *field)); } _ => {} } @@ -426,12 +427,17 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: } // Create new nodes nodes - for (node_id, node) in to_insert { - editor.edit(|mut edit| { + editor.edit(|mut edit| { + for (node_id, node) in to_insert { let id = edit.add_node(node); - assert!(node_id == id.idx()); - Ok(edit) - }); + assert_eq!(node_id, id.idx()); + } + Ok(edit) + }); + + // Replace uses of old reads + for (old, new) in to_replace { + 
editor.edit(|edit| edit.replace_all_uses(old, new)); } // Remove nodes diff --git a/juno_frontend/src/lib.rs b/juno_frontend/src/lib.rs index 59240fe9..7e6b61fa 100644 --- a/juno_frontend/src/lib.rs +++ b/juno_frontend/src/lib.rs @@ -69,11 +69,11 @@ impl fmt::Display for ErrorMessage { match self { ErrorMessage::SemanticError(errs) => { for err in errs { - write!(f, "{}", err)?; + write!(f, "{}\n", err)?; } } ErrorMessage::SchedulingError(msg) => { - write!(f, "{}", msg)?; + write!(f, "{}\n", msg)?; } } Ok(()) diff --git a/juno_samples/products.jn b/juno_samples/products.jn index 00b66aab..6c14e67a 100644 --- a/juno_samples/products.jn +++ b/juno_samples/products.jn @@ -1,6 +1,12 @@ fn test_call(x : i32, y : f32) -> (i32, f32) { let res = (x, y); - if x < 13 { res = (x + 1, y); } + for i = 0 to 10 { + if i % 2 == 0 { + res.0 += 1; + } else { + res.1 *= 2.0; + } + } return res; } -- GitLab From fb0e5c58edc4b755b3ecb238e751c1d8315b5cc1 Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Wed, 20 Nov 2024 19:04:00 -0600 Subject: [PATCH 09/12] Fix SROA handling of select --- hercules_opt/src/sroa.rs | 45 +++++++++++++++++------------------ hercules_samples/products.hir | 3 +++ juno_samples/products.jn | 6 +---- 3 files changed, 26 insertions(+), 28 deletions(-) create mode 100644 hercules_samples/products.hir diff --git a/hercules_opt/src/sroa.rs b/hercules_opt/src/sroa.rs index 924c8d93..6205421f 100644 --- a/hercules_opt/src/sroa.rs +++ b/hercules_opt/src/sroa.rs @@ -160,6 +160,7 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: } } + #[derive(Debug)] enum WorkItem { Unhandled(NodeID), AllocatedPhi { @@ -176,9 +177,9 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: fields: IndexTree<NodeID>, }, AllocatedTernary { - first: NodeID, - second: NodeID, - third: NodeID, + cond: NodeID, + thn: NodeID, + els: NodeID, node: NodeID, fields: IndexTree<NodeID>, }, @@ -263,9 +264,9 @@ 
pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: field_map.insert(node, fields.clone()); item = WorkItem::AllocatedTernary { - first, - second, - third, + cond: first, + thn: second, + els: third, node, fields, }; @@ -387,37 +388,35 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: } } WorkItem::AllocatedTernary { - first, - second, - third, + cond, + thn, + els, node, fields, } => { - if let (Some(fst_fields), Some(snd_fields), Some(thd_fields)) = ( - field_map.get(&first), - field_map.get(&second), - field_map.get(&third), + if let (Some(thn_fields), Some(els_fields)) = ( + field_map.get(&thn), + field_map.get(&els), ) { fields - .zip(fst_fields) - .zip(snd_fields) - .zip(thd_fields) - .for_each(|idx, (((res, fst), snd), thd)| { + .zip(thn_fields) + .zip(els_fields) + .for_each(|idx, ((res, thn), els)| { to_insert.insert( res.idx(), Node::Ternary { - first: **fst, - second: **snd, - third: **thd, + first: cond, + second: **thn, + third: **els, op: TernaryOperator::Select, }, ); }); } else { worklist.push_back(WorkItem::AllocatedTernary { - first, - second, - third, + cond, + thn, + els, node, fields, }); diff --git a/hercules_samples/products.hir b/hercules_samples/products.hir new file mode 100644 index 00000000..d09bb0fa --- /dev/null +++ b/hercules_samples/products.hir @@ -0,0 +1,3 @@ +fn test(x : prod(i32, f32), y: prod(i32, f32), b: bool) -> prod(i32, f32) + res = select(b, x, y) + r = return(start, res) diff --git a/juno_samples/products.jn b/juno_samples/products.jn index 6c14e67a..b97f1088 100644 --- a/juno_samples/products.jn +++ b/juno_samples/products.jn @@ -1,11 +1,7 @@ fn test_call(x : i32, y : f32) -> (i32, f32) { let res = (x, y); for i = 0 to 10 { - if i % 2 == 0 { - res.0 += 1; - } else { - res.1 *= 2.0; - } + res.0 += 1; } return res; } -- GitLab From 0bcdf93009233ea2522ce6d4c36013bd89d81532 Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Wed, 20 Nov 
2024 19:28:57 -0600 Subject: [PATCH 10/12] Delete selects --- hercules_opt/src/sroa.rs | 1 + hercules_samples/products.hir | 26 +++++++++++++++++++++++--- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/hercules_opt/src/sroa.rs b/hercules_opt/src/sroa.rs index 6205421f..afbc775b 100644 --- a/hercules_opt/src/sroa.rs +++ b/hercules_opt/src/sroa.rs @@ -412,6 +412,7 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: }, ); }); + to_delete.push(node); } else { worklist.push_back(WorkItem::AllocatedTernary { cond, diff --git a/hercules_samples/products.hir b/hercules_samples/products.hir index d09bb0fa..9d191beb 100644 --- a/hercules_samples/products.hir +++ b/hercules_samples/products.hir @@ -1,3 +1,23 @@ -fn test(x : prod(i32, f32), y: prod(i32, f32), b: bool) -> prod(i32, f32) - res = select(b, x, y) - r = return(start, res) +fn test(x : prod(i32, f32), b: bool) -> prod(i32, f32) + zero = constant(u64, 0) + one = constant(i32, 1) + two = constant(u64, 2) + three = constant(f32, 3.0) + + f_ctrl = fork(start, 10) + idx = thread_id(f_ctrl, 0) + + mod2 = rem(idx, two) + is_even = eq(mod2, zero) + field0 = read(res, field(0)) + field1 = read(res, field(1)) + add = add(field0, one) + mul = mul(field1, three) + upd0 = write(res, add, field(0)) + upd1 = write(res, mul, field(1)) + select = select(is_even, upd0, upd1) + + j_ctrl = join(f_ctrl) + res = reduce(j_ctrl, x, select) + + r = return(j_ctrl, res) -- GitLab From 97c4bbcfd3f4a7e4d58a6d7d887b98267156ff2a Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Thu, 21 Nov 2024 09:02:13 -0600 Subject: [PATCH 11/12] Formatting --- hercules_opt/src/sroa.rs | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/hercules_opt/src/sroa.rs b/hercules_opt/src/sroa.rs index afbc775b..7430dbca 100644 --- a/hercules_opt/src/sroa.rs +++ b/hercules_opt/src/sroa.rs @@ -213,7 +213,7 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: 
&Vec<NodeID>, types: // enough information to process a particular node, we add it back to the worklist let mut next_id: usize = editor.func().nodes.len(); let mut to_insert = BTreeMap::new(); - let mut to_replace : Vec<(NodeID, NodeID)> = vec![]; + let mut to_replace: Vec<(NodeID, NodeID)> = vec![]; while let Some(mut item) = worklist.pop_front() { if let WorkItem::Unhandled(node) = item { @@ -394,10 +394,9 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: node, fields, } => { - if let (Some(thn_fields), Some(els_fields)) = ( - field_map.get(&thn), - field_map.get(&els), - ) { + if let (Some(thn_fields), Some(els_fields)) = + (field_map.get(&thn), field_map.get(&els)) + { fields .zip(thn_fields) .zip(els_fields) -- GitLab From cb4634e4cbebb396795ae8d2ceb9f94b408ebe89 Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Fri, 22 Nov 2024 09:47:17 -0600 Subject: [PATCH 12/12] Fix bug when read later used in write - Need to track nodes that are replaced as we perform replacements --- hercules_opt/src/sroa.rs | 36 ++++++++++++++++++++++++++++++++++++ juno_samples/products.jn | 5 +++-- 2 files changed, 39 insertions(+), 2 deletions(-) diff --git a/hercules_opt/src/sroa.rs b/hercules_opt/src/sroa.rs index 7430dbca..59ee4a8a 100644 --- a/hercules_opt/src/sroa.rs +++ b/hercules_opt/src/sroa.rs @@ -435,8 +435,44 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: }); // Replace uses of old reads + // Because a read that is being replaced could also be the node some other read is being + // replaced by (if the first read is then written into a product that is then read from again) + // we need to track what nodes have already been replaced (and by what) so we can properly + // replace uses without leaving users of nodes that should be deleted. 
+    // replaced_by tracks what a node has been replaced by, while replaced_of tracks everything that +    // maps to a particular node (which is needed to maintain the data structure efficiently) +    let mut replaced_by: HashMap<NodeID, NodeID> = HashMap::new(); +    let mut replaced_of: HashMap<NodeID, Vec<NodeID>> = HashMap::new(); for (old, new) in to_replace { +        let new = match replaced_by.get(&new) { +            Some(res) => *res, +            None => new, +        }; + editor.edit(|edit| edit.replace_all_uses(old, new)); + replaced_by.insert(old, new); + +        let mut replaced = vec![]; +        match replaced_of.get_mut(&old) { +            Some(res) => { +                std::mem::swap(res, &mut replaced); +            } +            None => {} +        } + +        let new_of = match replaced_of.get_mut(&new) { +            Some(res) => res, +            None => { +                replaced_of.insert(new, vec![]); +                replaced_of.get_mut(&new).unwrap() +            } +        }; +        new_of.push(old); + +        for n in replaced { +            replaced_by.insert(n, new); +            new_of.push(n); +        } } // Remove nodes diff --git a/juno_samples/products.jn b/juno_samples/products.jn index b97f1088..d39ca246 100644 --- a/juno_samples/products.jn +++ b/juno_samples/products.jn @@ -6,6 +6,7 @@ fn test_call(x : i32, y : f32) -> (i32, f32) { return res; } -fn test(x : i32, y : f32) -> (i32, f32) { - return test_call(x, y); +fn test(x : i32, y : f32) -> (f32, i32) { +    let res = test_call(x, y); + return (res.1, res.0); } -- GitLab