diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index 66cb896d793bc3eb2cb56153a98653cc826a831d..ed340a7604cc2a4bc8c78f1e3743a8343fbb04ea 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -728,6 +728,14 @@ impl Type { None } } + + pub fn try_product(&self) -> Option<&[TypeID]> { + if let Type::Product(ts) = self { + Some(ts) + } else { + None + } + } } impl Constant { diff --git a/hercules_opt/src/editor.rs b/hercules_opt/src/editor.rs index fb5b1c1887d37fc6e909994e12376185fb1468dd..e6d8e3e076c05c15f183892a3c2517fc22448006 100644 --- a/hercules_opt/src/editor.rs +++ b/hercules_opt/src/editor.rs @@ -555,7 +555,11 @@ pub fn repair_plan(plan: &mut Plan, new_function: &Function, edits: &[Edit]) { // Step 2: drop schedules for deleted nodes and create empty schedule lists // for added nodes. for deleted in total_edit.0.iter() { - plan.schedules[deleted.idx()] = vec![]; + // Nodes that were created and deleted using the same editor don't have + // an existing schedule, so ignore them + if deleted.idx() < plan.schedules.len() { + plan.schedules[deleted.idx()] = vec![]; + } } if !total_edit.1.is_empty() { assert_eq!( diff --git a/hercules_opt/src/pass.rs b/hercules_opt/src/pass.rs index fc7d894c32dba5eb085d666600db875bbb718644..3494e010c7b69d2558dc4437759efd89c8cd7030 100644 --- a/hercules_opt/src/pass.rs +++ b/hercules_opt/src/pass.rs @@ -543,16 +543,33 @@ impl PassManager { let reverse_postorders = self.reverse_postorders.as_ref().unwrap(); let typing = self.typing.as_ref().unwrap(); for idx in 0..self.module.functions.len() { - sroa( + let constants_ref = + RefCell::new(std::mem::take(&mut self.module.constants)); + let dynamic_constants_ref = + RefCell::new(std::mem::take(&mut self.module.dynamic_constants)); + let types_ref = RefCell::new(std::mem::take(&mut self.module.types)); + let mut editor = FunctionEditor::new( &mut self.module.functions[idx], + &constants_ref, + &dynamic_constants_ref, + &types_ref, &def_uses[idx], - &reverse_postorders[idx], - &typing[idx], - &self.module.types, - &mut self.module.constants, ); + sroa(&mut editor, &reverse_postorders[idx], &typing[idx]); + + self.module.constants = constants_ref.take(); + self.module.dynamic_constants = dynamic_constants_ref.take(); + self.module.types = types_ref.take(); + + let edits = &editor.edits(); + if let Some(plans) = self.plans.as_mut() { + repair_plan(&mut plans[idx], &self.module.functions[idx], edits); + } + let grave_mapping = self.module.functions[idx].delete_gravestones(); + if let Some(plans) = self.plans.as_mut() { + plans[idx].fix_gravestones(&grave_mapping); + } } - self.legacy_repair_plan(); self.clear_analyses(); } Pass::Inline => { diff --git a/hercules_opt/src/sroa.rs b/hercules_opt/src/sroa.rs index cb5ecd252bfe7238a69b43c02f081ecb0fc4a7fe..59ee4a8ad54cea9627ae85185fe5fe04198e72cc 100644 --- a/hercules_opt/src/sroa.rs +++ b/hercules_opt/src/sroa.rs @@ -1,15 +1,12 @@ -extern crate bitvec; extern crate hercules_ir; -use std::collections::HashMap; -use std::iter::zip; - -use self::bitvec::prelude::*; +use std::collections::{BTreeMap, HashMap, LinkedList, VecDeque}; use self::hercules_ir::dataflow::*; -use self::hercules_ir::def_use::*; use self::hercules_ir::ir::*; +use crate::*; + /* * Top level function to run SROA, intraprocedurally. Product values can be used * and created by a relatively small number of nodes. Here are *all* of them: @@ -20,11 +17,11 @@ use self::hercules_ir::ir::*; * - Reduce: similarly to phis, reduce nodes can cycle product values through * reduction loops - these get broken up into reduces on the fields * - * + Return: can return a product - these are untouched, and are the sinks for - * unbroken product values + * - Return: can return a product - the product values will be constructed + * at the return site * - * + Parameter: can introduce a product - these are untouched, and are the - * sources for unbroken product values + * - Parameter: can introduce a product - reads will be introduced for each + * field * * - Constant: can introduce a product - these are broken up into constants for * the individual fields @@ -32,334 +29,796 @@ use self::hercules_ir::ir::*; * - Ternary: the select ternary operator can select between products - these * are broken up into ternary nodes for the individual fields * - * + Call: the call node can use a product value as an argument to another - * function, and can produce a product value as a result - these are - * untouched, and are the sink and source for unbroken product values + * - Call: the call node can use a product value as an argument to another + * function, and can produce a product value as a result. Argument values + * will be constructed at the call site and the return value will be broken + * into individual fields * * - Read: the read node reads primitive fields from product values - these get - * replaced by a direct use of the field value from the broken product value, - * but are retained when the product value is unbroken + * replaced by a direct use of the field value * * - Write: the write node writes primitive fields in product values - these get - * replaced by a direct def of the field value from the broken product value, - * but are retained when the product value is unbroken - * - * The nodes above with the list marker "+" are retained for maintaining API/ABI - * compatability with other Hercules functions and the host code. These are - * called "sink" or "source" nodes in comments below. + * replaced by a direct def of the field value */ -pub fn sroa( - function: &mut Function, - def_use: &ImmutableDefUseMap, - reverse_postorder: &Vec<NodeID>, - typing: &Vec<TypeID>, - types: &Vec<Type>, - constants: &mut Vec<Constant>, -) { - // Determine which sources of product values we want to try breaking up. We - // can determine easily on the soure side if a node produces a product that - // shouldn't be broken up by just examining the node type. However, the way - // that products are used is also important for determining if the product - // can be broken up. We backward dataflow this info to the sources of - // product values. - #[derive(PartialEq, Eq, Clone, Debug)] - enum ProductUseLattice { - // The product value used by this node is eventually used by a sink. - UsedBySink, - // This node uses multiple product values - the stored node ID indicates - // which is eventually used by a sink. This lattice value is produced by - // read and write nodes implementing partial indexing. - SpecificUsedBySink(NodeID), - // This node doesn't use a product node, or the product node it does use - // is not in turn used by a sink. - UnusedBySink, +pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: &Vec<TypeID>) { + // This map stores a map from NodeID to an index tree which can be used to lookup the NodeID + // that contains the corresponding fields of the original value + let mut field_map: HashMap<NodeID, IndexTree<NodeID>> = HashMap::new(); + + // First: determine all nodes which interact with products (as described above) + let mut product_nodes: Vec<NodeID> = vec![]; + // We track call and return nodes separately since they (may) require constructing new products + // for the call's arguments or the return's value + let mut call_return_nodes: Vec<NodeID> = vec![]; + + let func = editor.func(); + + for node in reverse_postorder { + match func.nodes[node.idx()] { + Node::Phi { .. } + | Node::Reduce { .. } + | Node::Parameter { .. } + | Node::Constant { .. } + | Node::Write { .. } + | Node::Ternary { + first: _, + second: _, + third: _, + op: TernaryOperator::Select, + } if editor.get_type(types[node.idx()]).is_product() => product_nodes.push(*node), + + Node::Read { collect, .. } if editor.get_type(types[collect.idx()]).is_product() => { + product_nodes.push(*node) + } + + // We add all calls to the call/return list and check their arguments later + Node::Call { .. } => call_return_nodes.push(*node), + Node::Return { control: _, data } + if editor.get_type(types[data.idx()]).is_product() => + { + call_return_nodes.push(*node) + } + + _ => (), + } } - impl Semilattice for ProductUseLattice { - fn meet(a: &Self, b: &Self) -> Self { - match (a, b) { - (Self::UsedBySink, _) | (_, Self::UsedBySink) => Self::UsedBySink, - (Self::SpecificUsedBySink(id1), Self::SpecificUsedBySink(id2)) => { - if id1 == id2 { - Self::SpecificUsedBySink(*id1) + // Next, we handle calls and returns. For returns, we will insert nodes that read each field of + // the returned product and then write them into a new product. These writes are not put into + // the list of product nodes since they must remain but the reads are so that they will be + // replaced later on. + // For calls, we do a similar process for each (product) argument. Additionally, if the call + // returns a product, we create reads for each field in that product and store it into our + // field map + for node in call_return_nodes { + match &editor.func().nodes[node.idx()] { + Node::Return { control, data } => { + assert!(editor.get_type(types[data.idx()]).is_product()); + let control = *control; + let new_data = + reconstruct_product(editor, types[data.idx()], *data, &mut product_nodes); + editor.edit(|mut edit| { + edit.add_node(Node::Return { + control, + data: new_data, + }); + edit.delete_node(node) + }); + } + Node::Call { + control, + function, + dynamic_constants, + args, + } => { + let control = *control; + let function = *function; + let dynamic_constants = dynamic_constants.clone(); + let args = args.clone(); + + // If the call returns a product, we generate reads for each field + let fields = if editor.get_type(types[node.idx()]).is_product() { + Some(generate_reads(editor, types[node.idx()], node)) + } else { + None + }; + + let mut new_args = vec![]; + for arg in args { + if editor.get_type(types[arg.idx()]).is_product() { + new_args.push(reconstruct_product( + editor, + types[arg.idx()], + arg, + &mut product_nodes, + )); } else { - Self::UsedBySink + new_args.push(arg); } } - (Self::SpecificUsedBySink(id), _) | (_, Self::SpecificUsedBySink(id)) => { - Self::SpecificUsedBySink(*id) - } - _ => Self::UnusedBySink, + editor.edit(|mut edit| { + let new_call = edit.add_node(Node::Call { + control, + function, + dynamic_constants, + args: new_args.into(), + }); + let edit = edit.replace_all_uses(node, new_call)?; + let edit = edit.delete_node(node)?; + + match fields { + None => {} + Some(fields) => { + field_map.insert(new_call, fields); + } + } + + Ok(edit) + }); } + _ => panic!("Processing non-call or return node"), } + } - fn bottom() -> Self { - Self::UsedBySink - } + #[derive(Debug)] + enum WorkItem { + Unhandled(NodeID), + AllocatedPhi { + control: NodeID, + data: Vec<NodeID>, + node: NodeID, + fields: IndexTree<NodeID>, + }, + AllocatedReduce { + control: NodeID, + init: NodeID, + reduct: NodeID, + node: NodeID, + fields: IndexTree<NodeID>, + }, + AllocatedTernary { + cond: NodeID, + thn: NodeID, + els: NodeID, + node: NodeID, + fields: IndexTree<NodeID>, + }, + } - fn top() -> Self { - Self::UnusedBySink + // Now, we process the other nodes that deal with products. + // The first step is to assign new NodeIDs to the nodes that will be split into multiple: phi, + // reduce, parameter, constant, and ternary. + // We do this in several steps: first we break apart parameters and constants + let mut to_delete = vec![]; + let mut worklist: VecDeque<WorkItem> = VecDeque::new(); + + for node in product_nodes { + match editor.func().nodes[node.idx()] { + Node::Parameter { .. } => { + field_map.insert(node, generate_reads(editor, types[node.idx()], node)); + } + Node::Constant { id } => { + field_map.insert(node, generate_constant_fields(editor, id)); + to_delete.push(node); + } + _ => { + worklist.push_back(WorkItem::Unhandled(node)); + } } } - // Run dataflow analysis to find which product values are used by a sink. - let product_uses = backward_dataflow(function, def_use, reverse_postorder, |succ_outs, id| { - match function.nodes[id.idx()] { - Node::Return { - control: _, - data: _, - } => { - if types[typing[id.idx()].idx()].is_product() { - ProductUseLattice::UsedBySink - } else { - ProductUseLattice::UnusedBySink + // Now, we process the remaining nodes, allocating NodeIDs for them and updating the field_map. + // We track the current NodeID and add nodes to a set we maintain of nodes to add (since we + // need to add nodes in a particular order we wait to do that until the end). If we don't have + // enough information to process a particular node, we add it back to the worklist + let mut next_id: usize = editor.func().nodes.len(); + let mut to_insert = BTreeMap::new(); + let mut to_replace: Vec<(NodeID, NodeID)> = vec![]; + + while let Some(mut item) = worklist.pop_front() { + if let WorkItem::Unhandled(node) = item { + match &editor.func().nodes[node.idx()] { + // For phi, reduce, and ternary, we break them apart into separate nodes for each field + Node::Phi { control, data } => { + let control = *control; + let data = data.clone(); + let fields = allocate_fields(editor, types[node.idx()], &mut next_id); + field_map.insert(node, fields.clone()); + + item = WorkItem::AllocatedPhi { + control, + data: data.into(), + node, + fields, + }; } + Node::Reduce { + control, + init, + reduct, + } => { + let control = *control; + let init = *init; + let reduct = *reduct; + let fields = allocate_fields(editor, types[node.idx()], &mut next_id); + field_map.insert(node, fields.clone()); + + item = WorkItem::AllocatedReduce { + control, + init, + reduct, + node, + fields, + }; + } + Node::Ternary { + first, + second, + third, + .. + } => { + let first = *first; + let second = *second; + let third = *third; + let fields = allocate_fields(editor, types[node.idx()], &mut next_id); + field_map.insert(node, fields.clone()); + + item = WorkItem::AllocatedTernary { + cond: first, + thn: second, + els: third, + node, + fields, + }; + } + + Node::Write { + collect, + data, + indices, + } => { + if let Some(index_map) = field_map.get(collect) { + if editor.get_type(types[data.idx()]).is_product() { + if let Some(data_idx) = field_map.get(data) { + field_map.insert( + node, + index_map.clone().replace(indices, data_idx.clone()), + ); + to_delete.push(node); + } else { + worklist.push_back(WorkItem::Unhandled(node)); + } + } else { + field_map.insert(node, index_map.clone().set(indices, *data)); + to_delete.push(node); + } + } else { + worklist.push_back(WorkItem::Unhandled(node)); + } + } + Node::Read { collect, indices } => { + if let Some(index_map) = field_map.get(collect) { + let read_info = index_map.lookup(indices); + match read_info { + IndexTree::Leaf(field) => { + to_replace.push((node, *field)); + } + _ => {} + } + field_map.insert(node, read_info.clone()); + to_delete.push(node); + } else { + worklist.push_back(WorkItem::Unhandled(node)); + } + } + + _ => panic!("Unexpected node type"), } - Node::Call { - control: _, - function: _, - dynamic_constants: _, - args: _, - } => todo!(), - // For reads and writes, we only want to propagate the use of the - // product to the collect input of the node. - Node::Read { - collect, - indices: _, - } - | Node::Write { - collect, - data: _, - indices: _, + } + match item { + WorkItem::Unhandled(_) => {} + WorkItem::AllocatedPhi { + control, + data, + node, + fields, } => { - let meet = succ_outs - .iter() - .fold(ProductUseLattice::top(), |acc, latt| { - ProductUseLattice::meet(&acc, latt) + let mut data_fields = vec![]; + let mut ready = true; + for val in data.iter() { + if let Some(val_fields) = field_map.get(val) { + data_fields.push(val_fields); + } else { + ready = false; + break; + } + } + + if ready { + fields.zip_list(data_fields).for_each(|idx, (res, data)| { + to_insert.insert( + res.idx(), + Node::Phi { + control, + data: data.into_iter().map(|n| **n).collect::<Vec<_>>().into(), + }, + ); }); - if meet == ProductUseLattice::UnusedBySink { - ProductUseLattice::UnusedBySink + to_delete.push(node); } else { - ProductUseLattice::SpecificUsedBySink(collect) + worklist.push_back(WorkItem::AllocatedPhi { + control, + data, + node, + fields, + }); } } - // For non-sink nodes. - _ => { - if function.nodes[id.idx()].is_control() { - return ProductUseLattice::UnusedBySink; - } - let meet = succ_outs - .iter() - .fold(ProductUseLattice::top(), |acc, latt| { - ProductUseLattice::meet(&acc, latt) + WorkItem::AllocatedReduce { + control, + init, + reduct, + node, + fields, + } => { + if let (Some(init_fields), Some(reduct_fields)) = + (field_map.get(&init), field_map.get(&reduct)) + { + fields.zip(init_fields).zip(reduct_fields).for_each( + |idx, ((res, init), reduct)| { + to_insert.insert( + res.idx(), + Node::Reduce { + control, + init: **init, + reduct: **reduct, + }, + ); + }, + ); + to_delete.push(node); + } else { + worklist.push_back(WorkItem::AllocatedReduce { + control, + init, + reduct, + node, + fields, }); - if let ProductUseLattice::SpecificUsedBySink(meet_id) = meet { - if meet_id == id { - ProductUseLattice::UsedBySink - } else { - ProductUseLattice::UnusedBySink - } + } + } + WorkItem::AllocatedTernary { + cond, + thn, + els, + node, + fields, + } => { + if let (Some(thn_fields), Some(els_fields)) = + (field_map.get(&thn), field_map.get(&els)) + { + fields + .zip(thn_fields) + .zip(els_fields) + .for_each(|idx, ((res, thn), els)| { + to_insert.insert( + res.idx(), + Node::Ternary { + first: cond, + second: **thn, + third: **els, + op: TernaryOperator::Select, + }, + ); + }); + to_delete.push(node); } else { - meet + worklist.push_back(WorkItem::AllocatedTernary { + cond, + thn, + els, + node, + fields, + }); } } } + } + + // Create new nodes nodes + editor.edit(|mut edit| { + for (node_id, node) in to_insert { + let id = edit.add_node(node); + assert_eq!(node_id, id.idx()); + } + Ok(edit) }); - // Only product values introduced as constants can be replaced by scalars. - let to_sroa: Vec<(NodeID, ConstantID)> = product_uses - .into_iter() - .enumerate() - .filter_map(|(node_idx, product_use)| { - if ProductUseLattice::UnusedBySink == product_use - && types[typing[node_idx].idx()].is_product() - { - function.nodes[node_idx] - .try_constant() - .map(|cons_id| (NodeID::new(node_idx), cons_id)) - } else { - None + // Replace uses of old reads + // Because a read that is being replaced could also be the node some other read is being + // replaced by (if the first read is then written into a product that is then read from again) + // we need to track what nodes have already been replaced (and by what) so we can properly + // replace uses without leaving users of nodes that should be deleted. + // replaced_by tracks what a node has been replaced by while replaced_of tracks everything that + // maps to a particular node (which is needed to maintain the data structure efficiently) + let mut replaced_by: HashMap<NodeID, NodeID> = HashMap::new(); + let mut replaced_of: HashMap<NodeID, Vec<NodeID>> = HashMap::new(); + for (old, new) in to_replace { + let new = match replaced_by.get(&new) { + Some(res) => *res, + None => new, + }; + + editor.edit(|edit| edit.replace_all_uses(old, new)); + replaced_by.insert(old, new); + + let mut replaced = vec![]; + match replaced_of.get_mut(&old) { + Some(res) => { + std::mem::swap(res, &mut replaced); } - }) - .collect(); - - // Perform SROA. TODO: repair def-use when there are multiple product - // constants to SROA away. - assert!(to_sroa.len() < 2); - for (constant_node_id, constant_id) in to_sroa { - // Get the field constants to replace the product constant with. - let product_constant = constants[constant_id.idx()].clone(); - let constant_fields = product_constant.try_product_fields().unwrap(); - - // DFS to find all data nodes that use the product constant. - let to_replace = sroa_dfs(constant_node_id, function, def_use); - - // Assemble a mapping from old nodes IDs acting on the product constant - // to new nodes IDs operating on the field constants. - let old_to_new_id_map: HashMap<NodeID, Vec<NodeID>> = to_replace - .iter() - .map(|old_id| match function.nodes[old_id.idx()] { - Node::Phi { - control: _, - data: _, + None => {} + } + + let new_of = match replaced_of.get_mut(&new) { + Some(res) => res, + None => { + replaced_of.insert(new, vec![]); + replaced_of.get_mut(&new).unwrap() + } + }; + new_of.push(old); + + for n in replaced { + replaced_by.insert(n, new); + new_of.push(n); + } + } + + // Remove nodes + editor.edit(|mut edit| { + for node in to_delete { + edit = edit.delete_node(node)? + } + Ok(edit) + }); +} + +// An index tree is used to store results at many index lists +#[derive(Clone, Debug)] +enum IndexTree<T> { + Leaf(T), + Node(Vec<IndexTree<T>>), +} + +impl<T: std::fmt::Debug> IndexTree<T> { + fn lookup(&self, idx: &[Index]) -> &IndexTree<T> { + self.lookup_idx(idx, 0) + } + + fn lookup_idx(&self, idx: &[Index], n: usize) -> &IndexTree<T> { + if n < idx.len() { + if let Index::Field(i) = idx[n] { + match self { + IndexTree::Leaf(_) => panic!("Invalid field"), + IndexTree::Node(ts) => ts[i].lookup_idx(idx, n + 1), } - | Node::Reduce { - control: _, - init: _, - reduct: _, + } else { + // TODO: This could be hit because of an array inside of a product + panic!("Error handling lookup of field"); + } + } else { + self + } + } + + fn set(self, idx: &[Index], val: T) -> IndexTree<T> { + self.set_idx(idx, val, 0) + } + + fn set_idx(self, idx: &[Index], val: T, n: usize) -> IndexTree<T> { + if n < idx.len() { + if let Index::Field(i) = idx[n] { + match self { + IndexTree::Leaf(_) => panic!("Invalid field"), + IndexTree::Node(mut ts) => { + if i + 1 == ts.len() { + let t = ts.pop().unwrap(); + ts.push(t.set_idx(idx, val, n + 1)); + } else { + let mut t = ts.pop().unwrap(); + std::mem::swap(&mut ts[i], &mut t); + t = t.set_idx(idx, val, n + 1); + std::mem::swap(&mut ts[i], &mut t); + ts.push(t); + } + IndexTree::Node(ts) + } } - | Node::Constant { id: _ } - | Node::Ternary { - op: _, - first: _, - second: _, - third: _, + } else { + panic!("Error handling set of field"); + } + } else { + IndexTree::Leaf(val) + } + } + + fn replace(self, idx: &[Index], val: IndexTree<T>) -> IndexTree<T> { + self.replace_idx(idx, val, 0) + } + + fn replace_idx(self, idx: &[Index], val: IndexTree<T>, n: usize) -> IndexTree<T> { + if n < idx.len() { + if let Index::Field(i) = idx[n] { + match self { + IndexTree::Leaf(_) => panic!("Invalid field"), + IndexTree::Node(mut ts) => { + if i + 1 == ts.len() { + let t = ts.pop().unwrap(); + ts.push(t.replace_idx(idx, val, n + 1)); + } else { + let mut t = ts.pop().unwrap(); + std::mem::swap(&mut ts[i], &mut t); + t = t.replace_idx(idx, val, n + 1); + std::mem::swap(&mut ts[i], &mut t); + ts.push(t); + } + IndexTree::Node(ts) + } } - | Node::Write { - collect: _, - data: _, - indices: _, - } => { - let new_ids = (0..constant_fields.len()) - .map(|_| { - let id = NodeID::new(function.nodes.len()); - function.nodes.push(Node::Start); - id - }) - .collect(); - (*old_id, new_ids) + } else { + panic!("Error handling set of field"); + } + } else { + val + } + } + + fn zip<'a, A>(self, other: &'a IndexTree<A>) -> IndexTree<(T, &'a A)> { + match (self, other) { + (IndexTree::Leaf(t), IndexTree::Leaf(a)) => IndexTree::Leaf((t, a)), + (IndexTree::Node(t), IndexTree::Node(a)) => { + let mut fields = vec![]; + for (t, a) in t.into_iter().zip(a.iter()) { + fields.push(t.zip(a)); } - Node::Read { - collect: _, - indices: _, - } => (*old_id, vec![]), - _ => panic!("PANIC: Invalid node using a constant product found during SROA."), - }) - .collect(); - - // Replace the old nodes with the new nodes. Since we've already - // allocated the node IDs, at this point we can iterate through the to- - // replace nodes in an arbitrary order. - for (old_id, new_ids) in &old_to_new_id_map { - // First, add the new nodes to the node list. - let node = function.nodes[old_id.idx()].clone(); - match node { - // Replace the original constant with constants for each field. - Node::Constant { id: _ } => { - for (new_id, field_id) in zip(new_ids.iter(), constant_fields.iter()) { - function.nodes[new_id.idx()] = Node::Constant { id: *field_id }; + IndexTree::Node(fields) + } + _ => panic!("IndexTrees do not have the same fields, cannot zip"), + } + } + + fn zip_list<'a, A>(self, others: Vec<&'a IndexTree<A>>) -> IndexTree<(T, Vec<&'a A>)> { + match self { + IndexTree::Leaf(t) => { + let mut res = vec![]; + for other in others { + match other { + IndexTree::Leaf(a) => res.push(a), + _ => panic!("IndexTrees do not have the same fields, cannot zip"), } } - // Replace writes using the constant as the data use with a - // series of writes writing the invidiual constant fields. TODO: - // handle the case where the constant is the collect use of the - // write node. - Node::Write { - collect, - data, - ref indices, - } => { - // Create the write chain. - assert!(old_to_new_id_map.contains_key(&data), "PANIC: Can't handle case where write node depends on constant to SROA in the collect use yet."); - let mut collect_def = collect; - for (idx, (new_id, new_data_def)) in - zip(new_ids.iter(), old_to_new_id_map[&data].iter()).enumerate() - { - let mut new_indices = indices.clone().into_vec(); - new_indices.push(Index::Field(idx)); - function.nodes[new_id.idx()] = Node::Write { - collect: collect_def, - data: *new_data_def, - indices: new_indices.into_boxed_slice(), - }; - collect_def = *new_id; - } - - // Replace uses of the old write with the new write. - for user in def_use.get_users(*old_id) { - get_uses_mut(&mut function.nodes[user.idx()]).map(*old_id, collect_def); + IndexTree::Leaf((t, res)) + } + IndexTree::Node(t) => { + let mut fields: Vec<Vec<&'a IndexTree<A>>> = vec![vec![]; t.len()]; + for other in others { + match other { + IndexTree::Node(a) => { + for (i, a) in a.iter().enumerate() { + fields[i].push(a); + } + } + _ => panic!("IndexTrees do not have the same fields, cannot zip"), } } - _ => todo!(), + IndexTree::Node( + t.into_iter() + .zip(fields.into_iter()) + .map(|(t, f)| t.zip_list(f)) + .collect(), + ) } + } + } + + fn for_each<F>(&self, mut f: F) + where + F: FnMut(&Vec<Index>, &T), + { + self.for_each_idx(&mut vec![], &mut f); + } - // Delete the old node. - function.nodes[old_id.idx()] = Node::Start; + fn for_each_idx<F>(&self, idx: &mut Vec<Index>, f: &mut F) + where + F: FnMut(&Vec<Index>, &T), + { + match self { + IndexTree::Leaf(t) => f(idx, t), + IndexTree::Node(ts) => { + for (i, t) in ts.iter().enumerate() { + idx.push(Index::Field(i)); + t.for_each_idx(idx, f); + idx.pop(); + } + } } } } -fn sroa_dfs(src: NodeID, function: &Function, def_uses: &ImmutableDefUseMap) -> Vec<NodeID> { - // Initialize order vector and bitset for tracking which nodes have been - // visited. - let order = Vec::with_capacity(def_uses.num_nodes()); - let visited = bitvec![u8, Lsb0; 0; def_uses.num_nodes()]; +// Given a product value val of type typ, constructs a copy of that value by extracting all fields +// from that value and then writing them into a new constant +// This process also adds all the read nodes that are generated into the read_list so that the +// reads can be eliminated by later parts of SROA +fn reconstruct_product( + editor: &mut FunctionEditor, + typ: TypeID, + val: NodeID, + read_list: &mut Vec<NodeID>, +) -> NodeID { + let fields = generate_reads(editor, typ, val); + let new_const = generate_constant(editor, typ); + + // Create a constant node + let mut const_node = None; + editor.edit(|mut edit| { + const_node = Some(edit.add_node(Node::Constant { id: new_const })); + Ok(edit) + }); + + // Generate writes for each field + let mut value = const_node.expect("Add node cannot fail"); + fields.for_each(|idx: &Vec<Index>, val: &NodeID| { + read_list.push(*val); + editor.edit(|mut edit| { + value = edit.add_node(Node::Write { + collect: value, + data: *val, + indices: idx.clone().into(), + }); + Ok(edit) + }); + }); + + value +} - // Order and visited are threaded through arguments / return pair of - // sroa_dfs_helper for ownership reasons. - let (order, _) = sroa_dfs_helper(src, src, function, def_uses, order, visited); - order +// Given a node val of type typ, adds nodes to the function which read all (leaf) fields of val and +// returns a list of pairs of the indices and the node that reads that index +fn generate_reads(editor: &mut FunctionEditor, typ: TypeID, val: NodeID) -> IndexTree<NodeID> { + let res = generate_reads_at_index(editor, typ, val, vec![]); + res } -fn sroa_dfs_helper( - node: NodeID, - def: NodeID, - function: &Function, - def_uses: &ImmutableDefUseMap, - mut order: Vec<NodeID>, - mut visited: BitVec<u8, Lsb0>, -) -> (Vec<NodeID>, BitVec<u8, Lsb0>) { - if visited[node.idx()] { - // If already visited, return early. - (order, visited) +// Given a node val of type which at the indices idx has type typ, construct reads of all (leaf) +// fields within this sub-value of val and return the correspondence list +fn generate_reads_at_index( + editor: &mut FunctionEditor, + typ: TypeID, + val: NodeID, + idx: Vec<Index>, +) -> IndexTree<NodeID> { + let ts: Option<Vec<TypeID>> = if let Some(ts) = editor.get_type(typ).try_product() { + Some(ts.into()) } else { - // Set visited to true. - visited.set(node.idx(), true); - - // Before iterating users, push this node. - order.push(node); - match function.nodes[node.idx()] { - Node::Phi { - control: _, - data: _, - } - | Node::Reduce { - control: _, - init: _, - reduct: _, - } - | Node::Constant { id: _ } - | Node::Ternary { - op: _, - first: _, - second: _, - third: _, - } => {} - Node::Read { - collect, - indices: _, - } => { - assert_eq!(def, collect); - return (order, visited); - } - Node::Write { - collect, - data, - indices: _, - } => { - if def == data { - return (order, visited); - } - assert_eq!(def, collect); + None + }; + + if let Some(ts) = ts { + // For product values, we will recurse down each of its fields with an extended index + // and the appropriate type of that field + let mut fields = vec![]; + for (i, t) in ts.into_iter().enumerate() { + let mut new_idx = idx.clone(); + new_idx.push(Index::Field(i)); + fields.push(generate_reads_at_index(editor, t, val, new_idx)); + } + IndexTree::Node(fields) + } else { + // For non-product types, we've reached a leaf so we generate the read and return it's + // information + let mut read_id = None; + editor.edit(|mut edit| { + read_id = Some(edit.add_node(Node::Read { + collect: val, + indices: idx.clone().into(), + })); + Ok(edit) + }); + + IndexTree::Leaf(read_id.expect("Add node canont fail")) + } +} + +macro_rules! add_const { + ($editor:ident, $const:expr) => {{ + let mut res = None; + $editor.edit(|mut edit| { + res = Some(edit.add_constant($const)); + Ok(edit) + }); + res.expect("Add constant cannot fail") + }}; +} + +// Given a type, builds a default constant of that type +fn generate_constant(editor: &mut FunctionEditor, typ: TypeID) -> ConstantID { + let t = editor.get_type(typ).clone(); + + match t { + Type::Product(ts) => { + let mut cs = vec![]; + for t in ts { + cs.push(generate_constant(editor, t)); } - _ => panic!("PANIC: Invalid node using a constant product found during SROA."), + add_const!(editor, Constant::Product(typ, cs.into())) + } + Type::Boolean => add_const!(editor, Constant::Boolean(false)), + Type::Integer8 => add_const!(editor, Constant::Integer8(0)), + Type::Integer16 => add_const!(editor, Constant::Integer16(0)), + Type::Integer32 => add_const!(editor, Constant::Integer32(0)), + Type::Integer64 => add_const!(editor, Constant::Integer64(0)), + Type::UnsignedInteger8 => add_const!(editor, Constant::UnsignedInteger8(0)), + Type::UnsignedInteger16 => add_const!(editor, Constant::UnsignedInteger16(0)), + Type::UnsignedInteger32 => add_const!(editor, Constant::UnsignedInteger32(0)), + Type::UnsignedInteger64 => add_const!(editor, Constant::UnsignedInteger64(0)), + Type::Float32 => add_const!(editor, Constant::Float32(ordered_float::OrderedFloat(0.0))), + Type::Float64 => add_const!(editor, Constant::Float64(ordered_float::OrderedFloat(0.0))), + Type::Summation(ts) => { + let const_id = generate_constant(editor, ts[0]); + add_const!(editor, Constant::Summation(typ, 0, const_id)) } + Type::Array(elem, _) => { + add_const!(editor, Constant::Array(typ)) + } + Type::Control => panic!("Cannot create constant of control type"), + } +} - // Iterate over users, if we shouldn't stop here. - for user in def_uses.get_users(node) { - (order, visited) = sroa_dfs_helper(*user, node, function, def_uses, order, visited); +// Given a constant cnst adds node to the function which are the constant values of each field and +// returns a list of pairs of indices and the node that holds that index +fn generate_constant_fields(editor: &mut FunctionEditor, cnst: ConstantID) -> IndexTree<NodeID> { + let cs: Option<Vec<ConstantID>> = + if let Some(cs) = editor.get_constant(cnst).try_product_fields() { + Some(cs.into()) + } else { + None + }; + + if let Some(cs) = cs { + let mut fields = vec![]; + for c in cs { + fields.push(generate_constant_fields(editor, c)); } + IndexTree::Node(fields) + } else { + let mut node = None; + editor.edit(|mut edit| { + node = Some(edit.add_node(Node::Constant { id: cnst })); + Ok(edit) + }); + IndexTree::Leaf(node.expect("Add node cannot fail")) + } +} + +// Given a type, return a list of the fields and new NodeIDs for them, with NodeIDs starting at the +// id provided +fn allocate_fields(editor: &FunctionEditor, typ: TypeID, id: &mut usize) -> IndexTree<NodeID> { + let ts: Option<Vec<TypeID>> = if let Some(ts) = editor.get_type(typ).try_product() { + Some(ts.into()) + } else { + None + }; - (order, visited) + if let Some(ts) = ts { + let mut fields = vec![]; + for t in ts { + fields.push(allocate_fields(editor, t, id)); + } + IndexTree::Node(fields) + } else { + let node = *id; + *id += 1; + IndexTree::Leaf(NodeID::new(node)) } } diff --git a/juno_frontend/src/lib.rs b/juno_frontend/src/lib.rs index de4c1aa612e7d7df9fa35ed719a4bfa60bcf1d09..b5ce14565393217189887fe4805939ef3c4e48fb 100644 --- a/juno_frontend/src/lib.rs +++ b/juno_frontend/src/lib.rs @@ -69,11 +69,11 @@ impl fmt::Display for ErrorMessage { match self { ErrorMessage::SemanticError(errs) => { for err in errs { - write!(f, "{}", err)?; + write!(f, "{}\n", err)?; } } ErrorMessage::SchedulingError(msg) => { - write!(f, "{}", msg)?; + write!(f, "{}\n", msg)?; } } Ok(()) @@ -152,11 +152,20 @@ pub fn compile_ir( pm.add_pass(hercules_opt::pass::Pass::Xdot(true)); } add_pass!(pm, verify, Inline); + // Run SROA pretty early (though after inlining which can make SROA more effective) so that + // CCP, GVN, etc. can work on the result of SROA + add_pass!(pm, verify, InterproceduralSROA); + add_pass!(pm, verify, SROA); + // We run phi-elim again because SROA can introduce new phis that might be able to be + // simplified + add_verified_pass!(pm, verify, PhiElim); + if x_dot { + pm.add_pass(hercules_opt::pass::Pass::Xdot(true)); + } add_pass!(pm, verify, CCP); add_pass!(pm, verify, DCE); add_pass!(pm, verify, GVN); add_pass!(pm, verify, DCE); - add_pass!(pm, verify, InterproceduralSROA); if x_dot { pm.add_pass(hercules_opt::pass::Pass::Xdot(true)); } diff --git a/juno_samples/products.jn b/juno_samples/products.jn new file mode 100644 index 0000000000000000000000000000000000000000..d39ca2464d35d63c40f688f7f7d1354d1ab0a56a --- /dev/null +++ b/juno_samples/products.jn @@ -0,0 +1,12 @@ +fn test_call(x : i32, y : f32) -> (i32, f32) { + let res = (x, y); + for i = 0 to 10 { + res.0 += 1; + } + return res; +} + +fn test(x : i32, y : f32) -> (f32, i32) { + let res = test_call(x, y); + return (res.1, res.0); +}