From 6013f4791dbc2e49765e8117365f21cfbdb31f86 Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Mon, 18 Nov 2024 21:20:00 -0600 Subject: [PATCH] Acyclic SROA working (mostly) --- hercules_opt/src/sroa.rs | 369 +++++++++++++++++++++++++++++++++------ juno_frontend/src/lib.rs | 2 +- 2 files changed, 317 insertions(+), 54 deletions(-) diff --git a/hercules_opt/src/sroa.rs b/hercules_opt/src/sroa.rs index ae5d76dc..b9380197 100644 --- a/hercules_opt/src/sroa.rs +++ b/hercules_opt/src/sroa.rs @@ -1,6 +1,6 @@ extern crate hercules_ir; -use std::collections::{HashMap, LinkedList}; +use std::collections::{HashMap, LinkedList, VecDeque}; use self::hercules_ir::dataflow::*; use self::hercules_ir::ir::*; @@ -41,46 +41,33 @@ use crate::*; * replaced by a direct def of the field value */ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: &Vec<TypeID>) { - // This map stores a map from NodeID and fields to the NodeID which contains just that field of - // the original value - let mut field_map : HashMap<(NodeID, Vec<Index>), NodeID> = HashMap::new(); + // This map stores a map from NodeID to an index tree which can be used to lookup the NodeID + // that contains the corresponding fields of the original value + let mut field_map : HashMap<NodeID, IndexTree<NodeID>> = HashMap::new(); // First: determine all nodes which interact with products (as described above) - let mut product_nodes = vec![]; + let mut product_nodes : Vec<NodeID> = vec![]; // We track call and return nodes separately since they (may) require constructing new products // for the call's arguments or the return's value - let mut call_return_nodes = vec![]; - // We track writes separately since they should be processed once their input product has been - // processed, so we handle them after all other nodes (though before reads) - let mut write_nodes = vec![]; - // We also track reads separately because they aren't handled until all other nodes have been - // processed - let mut read_nodes = vec![]; + let mut call_return_nodes : Vec<NodeID> = vec![]; let func = editor.func(); for node in reverse_postorder { match func.nodes[node.idx()] { Node::Phi { .. } | Node::Reduce { .. } | Node::Parameter { .. } - | Node::Constant { .. } + | Node::Constant { .. } | Node::Write { .. } | Node::Ternary { first: _, second: _, third: _, op: TernaryOperator::Select } - if editor.get_type(types[node.idx()]).is_product() => product_nodes.push(node), + if editor.get_type(types[node.idx()]).is_product() => product_nodes.push(*node), - Node::Write { .. } - if editor.get_type(types[node.idx()]).is_product() => write_nodes.push(node), Node::Read { collect, .. } - if editor.get_type(types[collect.idx()]).is_product() => read_nodes.push(node), + if editor.get_type(types[collect.idx()]).is_product() => product_nodes.push(*node), // We add all calls to the call/return list and check their arguments later - Node::Call { .. } => { - call_return_nodes.push(node); - if editor.get_type(types[node.idx()]).is_product() { - read_nodes.push(node); - } - } + Node::Call { .. } => call_return_nodes.push(*node), Node::Return { control: _, data } if editor.get_type(types[data.idx()]).is_product() => - call_return_nodes.push(node), + call_return_nodes.push(*node), _ => () } @@ -98,10 +85,10 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: Node::Return { control, data } => { assert!(editor.get_type(types[data.idx()]).is_product()); let control = *control; - let new_data = reconstruct_product(editor, types[data.idx()], *data); + let new_data = reconstruct_product(editor, types[data.idx()], *data, &mut product_nodes); editor.edit(|mut edit| { edit.add_node(Node::Return { control, data: new_data }); - edit.delete_node(*node) + edit.delete_node(node) }); } Node::Call { control, function, dynamic_constants, args } => { @@ -113,7 +100,7 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: // If the call returns a product, we generate reads for each field let fields = if editor.get_type(types[node.idx()]).is_product() { - Some(generate_reads(editor, types[node.idx()], *node)) + Some(generate_reads(editor, types[node.idx()], node)) } else { None }; @@ -121,23 +108,21 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: let mut new_args = vec![]; for arg in args { if editor.get_type(types[arg.idx()]).is_product() { - new_args.push(reconstruct_product(editor, types[arg.idx()], arg)); + new_args.push(reconstruct_product(editor, types[arg.idx()], arg, &mut product_nodes)); } else { new_args.push(arg); } } - editor.dit(|mut edit| { + editor.edit(|mut edit| { let new_call = edit.add_node(Node::Call { control, function, dynamic_constants, args: new_args.into() }); - let edit = edit.replace_all_uses(*node, new_call)?; - let edit = edit.delete_node(*node)?; + let edit = edit.replace_all_uses(node, new_call)?; + let edit = edit.delete_node(node)?; match fields { None => {} Some(fields) => { - for (idx, node) in fields { - field_map.insert((new_call, idx), node); - } - } + field_map.insert(new_call, fields); + } } Ok(edit) @@ -147,14 +132,243 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: } } - // Now, we process all other non-read/write nodes that deal with products. - // The first step is to identify the NodeIDs that contain each field of each of these nodes - todo!() + enum WorkItem { + Unhandled(NodeID), + AllocatedPhi { control: NodeID, data: Vec<NodeID>, fields: IndexTree<NodeID> }, + AllocatedReduce { control: NodeID, init: NodeID, reduct: NodeID, fields: IndexTree<NodeID> }, + AllocatedTernary { first: NodeID, second: NodeID, third: NodeID, fields: IndexTree<NodeID> }, + } + + // Now, we process the other nodes that deal with products. + // The first step is to assign new NodeIDs to the nodes that will be split into multiple: phi, + // reduce, parameter, constant, and ternary. + // We do this in several steps: first we break apart parameters and constants + let mut to_delete = vec![]; + let mut worklist: VecDeque<WorkItem> = VecDeque::new(); + + for node in product_nodes { + match editor.func().nodes[node.idx()] { + Node::Parameter { .. } => { + field_map.insert(node, generate_reads(editor, types[node.idx()], node)); + } + Node::Constant { id } => { + field_map.insert(node, generate_constant_fields(editor, id)); + to_delete.push(node); + } + _ => { worklist.push_back(WorkItem::Unhandled(node)); } + } + } + + // Now, we process the remaining nodes, allocating NodeIDs for them and updating the field_map. + // We track the current NodeID and add nodes whenever possible, otherwise we return values to + // the worklist + let mut next_id : usize = editor.func().nodes.len(); + let mut cur_id : usize = editor.func().nodes.len(); + while let Some(mut item) = worklist.pop_front() { + if let WorkItem::Unhandled(node) = item { + match &editor.func().nodes[node.idx()] { + // For phi, reduce, and ternary, we break them apart into separate nodes for each field + Node::Phi { control, data } => { + let control = *control; + let data = data.clone(); + let fields = allocate_fields(editor, types[node.idx()], &mut next_id); + field_map.insert(node, fields.clone()); + + item = WorkItem::AllocatedPhi { + control, + data: data.into(), + fields }; + } + Node::Reduce { control, init, reduct } => { + let control = *control; + let init = *init; + let reduct = *reduct; + let fields = allocate_fields(editor, types[node.idx()], &mut next_id); + field_map.insert(node, fields.clone()); + + item = WorkItem::AllocatedReduce { control, init, reduct, fields }; + } + Node::Ternary { first, second, third, .. } => { + let first = *first; + let second = *second; + let third = *third; + let fields = allocate_fields(editor, types[node.idx()], &mut next_id); + field_map.insert(node, fields.clone()); + + item = WorkItem::AllocatedTernary { first, second, third, fields }; + } + + Node::Write { collect, data, indices } => { + if let Some(index_map) = field_map.get(collect) { + if editor.get_type(types[data.idx()]).is_product() { + if let Some(data_idx) = field_map.get(data) { + field_map.insert(node, index_map.clone().replace(indices, data_idx.clone())); + to_delete.push(node); + } else { + worklist.push_back(WorkItem::Unhandled(node)); + } + } else { + field_map.insert(node, index_map.clone().set(indices, *data)); + to_delete.push(node); + } + } else { + worklist.push_back(WorkItem::Unhandled(node)); + } + } + Node::Read { collect, indices } => { + if let Some(index_map) = field_map.get(collect) { + let read_info = index_map.lookup(indices); + match read_info { + IndexTree::Leaf(field) => { + editor.edit(|edit| { + edit.replace_all_uses(node, *field) + }); + } + _ => {} + } + field_map.insert(node, read_info.clone()); + to_delete.push(node); + } else { + worklist.push_back(WorkItem::Unhandled(node)); + } + } + + _ => panic!("Unexpected node type") + } + } + match item { + WorkItem::Unhandled(_) => {} + _ => todo!() + } + } + + // Actually deleting nodes seems to break things right now + /* + println!("{:?}", to_delete); + editor.edit(|mut edit| { + for node in to_delete { + edit = edit.delete_node(node)? + } + Ok(edit) + }); + */ +} + +// An index tree is used to store results at many index lists +#[derive(Clone, Debug)] +enum IndexTree<T> { + Leaf(T), + Node(Vec<IndexTree<T>>), +} + +impl<T: std::fmt::Debug> IndexTree<T> { + fn lookup(&self, idx: &[Index]) -> &IndexTree<T> { + self.lookup_idx(idx, 0) + } + + fn lookup_idx(&self, idx: &[Index], n: usize) -> &IndexTree<T> { + if n < idx.len() { + if let Index::Field(i) = idx[n] { + match self { + IndexTree::Leaf(_) => panic!("Invalid field"), + IndexTree::Node(ts) => ts[i].lookup_idx(idx, n+1), + } + } else { + // TODO: This could be hit because of an array inside of a product + panic!("Error handling lookup of field"); + } + } else { + self + } + } + + fn set(self, idx: &[Index], val: T) -> IndexTree<T> { + self.set_idx(idx, val, 0) + } + + fn set_idx(self, idx: &[Index], val: T, n: usize) -> IndexTree<T> { + if n < idx.len() { + if let Index::Field(i) = idx[n] { + match self { + IndexTree::Leaf(_) => panic!("Invalid field"), + IndexTree::Node(mut ts) => { + if i + 1 == ts.len() { + let t = ts.pop().unwrap(); + ts.push(t.set_idx(idx, val, n+1)); + } else { + let mut t = ts.pop().unwrap(); + std::mem::swap(&mut ts[i], &mut t); + t = t.set_idx(idx, val, n+1); + std::mem::swap(&mut ts[i], &mut t); + ts.push(t); + } + IndexTree::Node(ts) + } + } + } else { + panic!("Error handling set of field"); + } + } else { + IndexTree::Leaf(val) + } + } + + fn replace(self, idx: &[Index], val: IndexTree<T>) -> IndexTree<T> { + self.replace_idx(idx, val, 0) + } + + fn replace_idx(self, idx: &[Index], val: IndexTree<T>, n: usize) -> IndexTree<T> { + if n < idx.len() { + if let Index::Field(i) = idx[n] { + match self { + IndexTree::Leaf(_) => panic!("Invalid field"), + IndexTree::Node(mut ts) => { + if i + 1 == ts.len() { + let t = ts.pop().unwrap(); + ts.push(t.replace_idx(idx, val, n+1)); + } else { + let mut t = ts.pop().unwrap(); + std::mem::swap(&mut ts[i], &mut t); + t = t.replace_idx(idx, val, n+1); + std::mem::swap(&mut ts[i], &mut t); + ts.push(t); + } + IndexTree::Node(ts) + } + } + } else { + panic!("Error handling set of field"); + } + } else { + val + } + } + + fn for_each<F>(&self, mut f: F) + where F: FnMut(&Vec<Index>, &T) { + self.for_each_idx(&mut vec![], &mut f); + } + + fn for_each_idx<F>(&self, idx: &mut Vec<Index>, f: &mut F) + where F: FnMut(&Vec<Index>, &T) { + match self { + IndexTree::Leaf(t) => f(idx, t), + IndexTree::Node(ts) => { + for (i, t) in ts.iter().enumerate() { + idx.push(Index::Field(i)); + t.for_each_idx(idx, f); + idx.pop(); + } + } + } + } } // Given a product value val of type typ, constructs a copy of that value by extracting all fields // from that value and then writing them into a new constant -fn reconstruct_product(editor: &mut FunctionEditor, typ: TypeID, val: NodeID) -> NodeID { +// This process also adds all the read nodes that are generated into the read_list so that the +// reads can be eliminated by later parts of SROA +fn reconstruct_product(editor: &mut FunctionEditor, typ: TypeID, val: NodeID, read_list: &mut Vec<NodeID>) -> NodeID { let fields = generate_reads(editor, typ, val); let new_const = generate_constant(editor, typ); @@ -167,44 +381,44 @@ fn reconstruct_product(editor: &mut FunctionEditor, typ: TypeID, val: NodeID) -> // Generate writes for each field let mut value = const_node.expect("Add node cannot fail"); - for (idx, val) in fields { + fields.for_each(|idx: &Vec<Index>, val: &NodeID| { + read_list.push(*val); editor.edit(|mut edit| { - value = edit.add_node(Node::Write { collect: value, data: val, indices: idx.into() }); + value = edit.add_node(Node::Write { collect: value, data: *val, indices: idx.clone().into() }); Ok(edit) }); - } + }); value } // Given a node val of type typ, adds nodes to the function which read all (leaf) fields of val and // returns a list of pairs of the indices and the node that reads that index -fn generate_reads(editor: &mut FunctionEditor, typ: TypeID, val: NodeID) -> Vec<(Vec<Index>, NodeID)> { - generate_reads_at_index(editor, typ, val, vec![]).into_iter().collect::<_>() +fn generate_reads(editor: &mut FunctionEditor, typ: TypeID, val: NodeID) -> IndexTree<NodeID> { + let res = generate_reads_at_index(editor, typ, val, vec![]); + res } // Given a node val of type which at the indices idx has type typ, construct reads of all (leaf) // fields within this sub-value of val and return the correspondence list -fn generate_reads_at_index(editor: &mut FunctionEditor, typ: TypeID, val: NodeID, idx: Vec<Index>) -> LinkedList<(Vec<Index>, NodeID)> { - let ts : Option<Vec<TypeID>> = { +fn generate_reads_at_index(editor: &mut FunctionEditor, typ: TypeID, val: NodeID, idx: Vec<Index>) -> IndexTree<NodeID> { + let ts : Option<Vec<TypeID>> = if let Some(ts) = editor.get_type(typ).try_product() { Some(ts.into()) } else { None - } - }; + }; if let Some(ts) = ts { // For product values, we will recurse down each of its fields with an extended index // and the appropriate type of that field - let mut result = LinkedList::new(); + let mut fields = vec![]; for (i, t) in ts.into_iter().enumerate() { let mut new_idx = idx.clone(); new_idx.push(Index::Field(i)); - result.append(&mut generate_reads_at_index(editor, t, val, new_idx)); + fields.push(generate_reads_at_index(editor, t, val, new_idx)); } - - result + IndexTree::Node(fields) } else { // For non-product types, we've reached a leaf so we generate the read and return it's // information @@ -214,7 +428,7 @@ fn generate_reads_at_index(editor: &mut FunctionEditor, typ: TypeID, val: NodeID Ok(edit) }); - LinkedList::from([(idx, read_id.expect("Add node cannot fail"))]) + IndexTree::Leaf(read_id.expect("Add node canont fail")) } } @@ -262,3 +476,52 @@ fn generate_constant(editor: &mut FunctionEditor, typ: TypeID) -> ConstantID { Type::Control => panic!("Cannot create constant of control type") } } + +// Given a constant cnst adds node to the function which are the constant values of each field and +// returns a list of pairs of indices and the node that holds that index +fn generate_constant_fields(editor: &mut FunctionEditor, cnst: ConstantID) -> IndexTree<NodeID> { + let cs : Option<Vec<ConstantID>> = + if let Some(cs) = editor.get_constant(cnst).try_product_fields() { + Some(cs.into()) + } else { + None + }; + + if let Some(cs) = cs { + let mut fields = vec![]; + for c in cs { + fields.push(generate_constant_fields(editor, c)); + } + IndexTree::Node(fields) + } else { + let mut node = None; + editor.edit(|mut edit| { + node = Some(edit.add_node(Node::Constant { id: cnst })); + Ok(edit) + }); + IndexTree::Leaf(node.expect("Add node cannot fail")) + } +} + +// Given a type, return a list of the fields and new NodeIDs for them, with NodeIDs starting at the +// id provided +fn allocate_fields(editor: &FunctionEditor, typ: TypeID, id: &mut usize) -> IndexTree<NodeID> { + let ts : Option<Vec<TypeID>> = + if let Some(ts) = editor.get_type(typ).try_product() { + Some(ts.into()) + } else { + None + }; + + if let Some(ts) = ts { + let mut fields = vec![]; + for t in ts { + fields.push(allocate_fields(editor, t, id)); + } + IndexTree::Node(fields) + } else { + let node = *id; + *id += 1; + IndexTree::Leaf(NodeID::new(node)) + } +} diff --git a/juno_frontend/src/lib.rs b/juno_frontend/src/lib.rs index d4263fe8..c3dc262a 100644 --- a/juno_frontend/src/lib.rs +++ b/juno_frontend/src/lib.rs @@ -157,7 +157,7 @@ pub fn compile_ir( pm.add_pass(hercules_opt::pass::Pass::Xdot(true)); } // TEMPORARY - add_pass!(pm, verify, Inline); + //add_pass!(pm, verify, Inline); add_pass!(pm, verify, CCP); add_pass!(pm, verify, DCE); add_pass!(pm, verify, GVN); -- GitLab