From 2021ea3b0b5f45add9409b4c649470ac7a642ac6 Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Tue, 19 Nov 2024 08:14:04 -0600 Subject: [PATCH] Comment change in editor and format --- hercules_opt/src/editor.rs | 2 + hercules_opt/src/pass.rs | 6 +- hercules_opt/src/sroa.rs | 224 ++++++++++++++++++++++++++----------- 3 files changed, 159 insertions(+), 73 deletions(-) diff --git a/hercules_opt/src/editor.rs b/hercules_opt/src/editor.rs index bf7cede2..2410a8ef 100644 --- a/hercules_opt/src/editor.rs +++ b/hercules_opt/src/editor.rs @@ -505,6 +505,8 @@ pub fn repair_plan(plan: &mut Plan, new_function: &Function, edits: &[Edit]) { // Step 2: drop schedules for deleted nodes and create empty schedule lists // for added nodes. for deleted in total_edit.0.iter() { + // Nodes that were created and deleted using the same editor don't have + // an existing schedule, so ignore them if deleted.idx() < plan.schedules.len() { plan.schedules[deleted.idx()] = vec![]; } diff --git a/hercules_opt/src/pass.rs b/hercules_opt/src/pass.rs index a7d48efb..f8310ab7 100644 --- a/hercules_opt/src/pass.rs +++ b/hercules_opt/src/pass.rs @@ -506,11 +506,7 @@ impl PassManager { &types_ref, &def_uses[idx], ); - sroa( - &mut editor, - &reverse_postorders[idx], - &typing[idx] - ); + sroa(&mut editor, &reverse_postorders[idx], &typing[idx]); self.module.constants = constants_ref.take(); self.module.dynamic_constants = dynamic_constants_ref.take(); diff --git a/hercules_opt/src/sroa.rs b/hercules_opt/src/sroa.rs index 2dd0049b..d8e2021a 100644 --- a/hercules_opt/src/sroa.rs +++ b/hercules_opt/src/sroa.rs @@ -43,33 +43,43 @@ use crate::*; pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: &Vec<TypeID>) { // This map stores a map from NodeID to an index tree which can be used to lookup the NodeID // that contains the corresponding fields of the original value - let mut field_map : HashMap<NodeID, IndexTree<NodeID>> = HashMap::new(); + let mut field_map: HashMap<NodeID, IndexTree<NodeID>> = HashMap::new(); // First: determine all nodes which interact with products (as described above) - let mut product_nodes : Vec<NodeID> = vec![]; + let mut product_nodes: Vec<NodeID> = vec![]; // We track call and return nodes separately since they (may) require constructing new products // for the call's arguments or the return's value - let mut call_return_nodes : Vec<NodeID> = vec![]; + let mut call_return_nodes: Vec<NodeID> = vec![]; let func = editor.func(); for node in reverse_postorder { match func.nodes[node.idx()] { - Node::Phi { .. } | Node::Reduce { .. } | Node::Parameter { .. } - | Node::Constant { .. } | Node::Write { .. } - | Node::Ternary { first: _, second: _, third: _, op: TernaryOperator::Select } - if editor.get_type(types[node.idx()]).is_product() => product_nodes.push(*node), - - Node::Read { collect, .. } - if editor.get_type(types[collect.idx()]).is_product() => product_nodes.push(*node), + Node::Phi { .. } + | Node::Reduce { .. } + | Node::Parameter { .. } + | Node::Constant { .. } + | Node::Write { .. } + | Node::Ternary { + first: _, + second: _, + third: _, + op: TernaryOperator::Select, + } if editor.get_type(types[node.idx()]).is_product() => product_nodes.push(*node), + + Node::Read { collect, .. } if editor.get_type(types[collect.idx()]).is_product() => { + product_nodes.push(*node) + } // We add all calls to the call/return list and check their arguments later Node::Call { .. } => call_return_nodes.push(*node), Node::Return { control: _, data } if editor.get_type(types[data.idx()]).is_product() => - call_return_nodes.push(*node), + { + call_return_nodes.push(*node) + } - _ => () + _ => (), } } @@ -85,36 +95,54 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: Node::Return { control, data } => { assert!(editor.get_type(types[data.idx()]).is_product()); let control = *control; - let new_data = reconstruct_product(editor, types[data.idx()], *data, &mut product_nodes); + let new_data = + reconstruct_product(editor, types[data.idx()], *data, &mut product_nodes); editor.edit(|mut edit| { - edit.add_node(Node::Return { control, data: new_data }); + edit.add_node(Node::Return { + control, + data: new_data, + }); edit.delete_node(node) }); } - Node::Call { control, function, dynamic_constants, args } => { + Node::Call { + control, + function, + dynamic_constants, + args, + } => { let control = *control; let function = *function; let dynamic_constants = dynamic_constants.clone(); let args = args.clone(); // If the call returns a product, we generate reads for each field - let fields = - if editor.get_type(types[node.idx()]).is_product() { - Some(generate_reads(editor, types[node.idx()], node)) - } else { - None - }; + let fields = if editor.get_type(types[node.idx()]).is_product() { + Some(generate_reads(editor, types[node.idx()], node)) + } else { + None + }; let mut new_args = vec![]; for arg in args { if editor.get_type(types[arg.idx()]).is_product() { - new_args.push(reconstruct_product(editor, types[arg.idx()], arg, &mut product_nodes)); + new_args.push(reconstruct_product( + editor, + types[arg.idx()], + arg, + &mut product_nodes, + )); } else { new_args.push(arg); } } editor.edit(|mut edit| { - let new_call = edit.add_node(Node::Call { control, function, dynamic_constants, args: new_args.into() }); + let new_call = edit.add_node(Node::Call { + control, + function, + dynamic_constants, + args: new_args.into(), + }); let edit = edit.replace_all_uses(node, new_call)?; let edit = edit.delete_node(node)?; @@ -128,15 +156,29 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: Ok(edit) }); } - _ => panic!("Processing non-call or return node") + _ => panic!("Processing non-call or return node"), } } enum WorkItem { Unhandled(NodeID), - AllocatedPhi { control: NodeID, data: Vec<NodeID>, fields: IndexTree<NodeID> }, - AllocatedReduce { control: NodeID, init: NodeID, reduct: NodeID, fields: IndexTree<NodeID> }, - AllocatedTernary { first: NodeID, second: NodeID, third: NodeID, fields: IndexTree<NodeID> }, + AllocatedPhi { + control: NodeID, + data: Vec<NodeID>, + fields: IndexTree<NodeID>, + }, + AllocatedReduce { + control: NodeID, + init: NodeID, + reduct: NodeID, + fields: IndexTree<NodeID>, + }, + AllocatedTernary { + first: NodeID, + second: NodeID, + third: NodeID, + fields: IndexTree<NodeID>, + }, } // Now, we process the other nodes that deal with products. @@ -155,15 +197,17 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: field_map.insert(node, generate_constant_fields(editor, id)); to_delete.push(node); } - _ => { worklist.push_back(WorkItem::Unhandled(node)); } + _ => { + worklist.push_back(WorkItem::Unhandled(node)); + } } } // Now, we process the remaining nodes, allocating NodeIDs for them and updating the field_map. // We track the current NodeID and add nodes whenever possible, otherwise we return values to // the worklist - let mut next_id : usize = editor.func().nodes.len(); - let mut cur_id : usize = editor.func().nodes.len(); + let mut next_id: usize = editor.func().nodes.len(); + let mut cur_id: usize = editor.func().nodes.len(); while let Some(mut item) = worklist.pop_front() { if let WorkItem::Unhandled(node) = item { match &editor.func().nodes[node.idx()] { @@ -177,32 +221,59 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: item = WorkItem::AllocatedPhi { control, data: data.into(), - fields }; + fields, + }; } - Node::Reduce { control, init, reduct } => { + Node::Reduce { + control, + init, + reduct, + } => { let control = *control; let init = *init; let reduct = *reduct; let fields = allocate_fields(editor, types[node.idx()], &mut next_id); field_map.insert(node, fields.clone()); - item = WorkItem::AllocatedReduce { control, init, reduct, fields }; + item = WorkItem::AllocatedReduce { + control, + init, + reduct, + fields, + }; } - Node::Ternary { first, second, third, .. } => { + Node::Ternary { + first, + second, + third, + .. + } => { let first = *first; let second = *second; let third = *third; let fields = allocate_fields(editor, types[node.idx()], &mut next_id); field_map.insert(node, fields.clone()); - item = WorkItem::AllocatedTernary { first, second, third, fields }; + item = WorkItem::AllocatedTernary { + first, + second, + third, + fields, + }; } - Node::Write { collect, data, indices } => { + Node::Write { + collect, + data, + indices, + } => { if let Some(index_map) = field_map.get(collect) { if editor.get_type(types[data.idx()]).is_product() { if let Some(data_idx) = field_map.get(data) { - field_map.insert(node, index_map.clone().replace(indices, data_idx.clone())); + field_map.insert( + node, + index_map.clone().replace(indices, data_idx.clone()), + ); to_delete.push(node); } else { worklist.push_back(WorkItem::Unhandled(node)); @@ -220,9 +291,7 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: let read_info = index_map.lookup(indices); match read_info { IndexTree::Leaf(field) => { - editor.edit(|edit| { - edit.replace_all_uses(node, *field) - }); + editor.edit(|edit| edit.replace_all_uses(node, *field)); } _ => {} } @@ -233,12 +302,12 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: } } - _ => panic!("Unexpected node type") + _ => panic!("Unexpected node type"), } } match item { WorkItem::Unhandled(_) => {} - _ => todo!() + _ => todo!(), } } @@ -268,7 +337,7 @@ impl<T: std::fmt::Debug> IndexTree<T> { if let Index::Field(i) = idx[n] { match self { IndexTree::Leaf(_) => panic!("Invalid field"), - IndexTree::Node(ts) => ts[i].lookup_idx(idx, n+1), + IndexTree::Node(ts) => ts[i].lookup_idx(idx, n + 1), } } else { // TODO: This could be hit because of an array inside of a product @@ -291,11 +360,11 @@ impl<T: std::fmt::Debug> IndexTree<T> { IndexTree::Node(mut ts) => { if i + 1 == ts.len() { let t = ts.pop().unwrap(); - ts.push(t.set_idx(idx, val, n+1)); + ts.push(t.set_idx(idx, val, n + 1)); } else { let mut t = ts.pop().unwrap(); std::mem::swap(&mut ts[i], &mut t); - t = t.set_idx(idx, val, n+1); + t = t.set_idx(idx, val, n + 1); std::mem::swap(&mut ts[i], &mut t); ts.push(t); } @@ -322,11 +391,11 @@ impl<T: std::fmt::Debug> IndexTree<T> { IndexTree::Node(mut ts) => { if i + 1 == ts.len() { let t = ts.pop().unwrap(); - ts.push(t.replace_idx(idx, val, n+1)); + ts.push(t.replace_idx(idx, val, n + 1)); } else { let mut t = ts.pop().unwrap(); std::mem::swap(&mut ts[i], &mut t); - t = t.replace_idx(idx, val, n+1); + t = t.replace_idx(idx, val, n + 1); std::mem::swap(&mut ts[i], &mut t); ts.push(t); } @@ -342,12 +411,16 @@ impl<T: std::fmt::Debug> IndexTree<T> { } fn for_each<F>(&self, mut f: F) - where F: FnMut(&Vec<Index>, &T) { + where + F: FnMut(&Vec<Index>, &T), + { self.for_each_idx(&mut vec![], &mut f); } fn for_each_idx<F>(&self, idx: &mut Vec<Index>, f: &mut F) - where F: FnMut(&Vec<Index>, &T) { + where + F: FnMut(&Vec<Index>, &T), + { match self { IndexTree::Leaf(t) => f(idx, t), IndexTree::Node(ts) => { @@ -365,7 +438,12 @@ impl<T: std::fmt::Debug> IndexTree<T> { // from that value and then writing them into a new constant // This process also adds all the read nodes that are generated into the read_list so that the // reads can be eliminated by later parts of SROA -fn reconstruct_product(editor: &mut FunctionEditor, typ: TypeID, val: NodeID, read_list: &mut Vec<NodeID>) -> NodeID { +fn reconstruct_product( + editor: &mut FunctionEditor, + typ: TypeID, + val: NodeID, + read_list: &mut Vec<NodeID>, +) -> NodeID { let fields = generate_reads(editor, typ, val); let new_const = generate_constant(editor, typ); @@ -381,7 +459,11 @@ fn reconstruct_product(editor: &mut FunctionEditor, typ: TypeID, val: NodeID, re fields.for_each(|idx: &Vec<Index>, val: &NodeID| { read_list.push(*val); editor.edit(|mut edit| { - value = edit.add_node(Node::Write { collect: value, data: *val, indices: idx.clone().into() }); + value = edit.add_node(Node::Write { + collect: value, + data: *val, + indices: idx.clone().into(), + }); Ok(edit) }); }); @@ -398,13 +480,17 @@ fn generate_reads(editor: &mut FunctionEditor, typ: TypeID, val: NodeID) -> Inde // Given a node val of type which at the indices idx has type typ, construct reads of all (leaf) // fields within this sub-value of val and return the correspondence list -fn generate_reads_at_index(editor: &mut FunctionEditor, typ: TypeID, val: NodeID, idx: Vec<Index>) -> IndexTree<NodeID> { - let ts : Option<Vec<TypeID>> = - if let Some(ts) = editor.get_type(typ).try_product() { - Some(ts.into()) - } else { - None - }; +fn generate_reads_at_index( + editor: &mut FunctionEditor, + typ: TypeID, + val: NodeID, + idx: Vec<Index>, +) -> IndexTree<NodeID> { + let ts: Option<Vec<TypeID>> = if let Some(ts) = editor.get_type(typ).try_product() { + Some(ts.into()) + } else { + None + }; if let Some(ts) = ts { // For product values, we will recurse down each of its fields with an extended index @@ -421,7 +507,10 @@ fn generate_reads_at_index(editor: &mut FunctionEditor, typ: TypeID, val: NodeID // information let mut read_id = None; editor.edit(|mut edit| { - read_id = Some(edit.add_node(Node::Read { collect: val, indices: idx.clone().into() })); + read_id = Some(edit.add_node(Node::Read { + collect: val, + indices: idx.clone().into(), + })); Ok(edit) }); @@ -437,7 +526,7 @@ macro_rules! add_const { Ok(edit) }); res.expect("Add constant cannot fail") - }} + }}; } // Given a type, builds a default constant of that type @@ -470,14 +559,14 @@ fn generate_constant(editor: &mut FunctionEditor, typ: TypeID) -> ConstantID { Type::Array(elem, _) => { add_const!(editor, Constant::Array(typ)) } - Type::Control => panic!("Cannot create constant of control type") + Type::Control => panic!("Cannot create constant of control type"), } } // Given a constant cnst adds node to the function which are the constant values of each field and // returns a list of pairs of indices and the node that holds that index fn generate_constant_fields(editor: &mut FunctionEditor, cnst: ConstantID) -> IndexTree<NodeID> { - let cs : Option<Vec<ConstantID>> = + let cs: Option<Vec<ConstantID>> = if let Some(cs) = editor.get_constant(cnst).try_product_fields() { Some(cs.into()) } else { @@ -503,12 +592,11 @@ fn generate_constant_fields(editor: &mut FunctionEditor, cnst: ConstantID) -> In // Given a type, return a list of the fields and new NodeIDs for them, with NodeIDs starting at the // id provided fn allocate_fields(editor: &FunctionEditor, typ: TypeID, id: &mut usize) -> IndexTree<NodeID> { - let ts : Option<Vec<TypeID>> = - if let Some(ts) = editor.get_type(typ).try_product() { - Some(ts.into()) - } else { - None - }; + let ts: Option<Vec<TypeID>> = if let Some(ts) = editor.get_type(typ).try_product() { + Some(ts.into()) + } else { + None + }; if let Some(ts) = ts { let mut fields = vec![]; -- GitLab