diff --git a/hercules_opt/src/sroa.rs b/hercules_opt/src/sroa.rs index d8e2021a3431ba2b842901022a417e129e796df8..ccc2605c5f3da9fd85eb4d596b54f504ddbea08f 100644 --- a/hercules_opt/src/sroa.rs +++ b/hercules_opt/src/sroa.rs @@ -1,6 +1,6 @@ extern crate hercules_ir; -use std::collections::{HashMap, LinkedList, VecDeque}; +use std::collections::{BTreeMap, HashMap, LinkedList, VecDeque}; use self::hercules_ir::dataflow::*; use self::hercules_ir::ir::*; @@ -165,18 +165,21 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: AllocatedPhi { control: NodeID, data: Vec<NodeID>, + node: NodeID, fields: IndexTree<NodeID>, }, AllocatedReduce { control: NodeID, init: NodeID, reduct: NodeID, + node: NodeID, fields: IndexTree<NodeID>, }, AllocatedTernary { first: NodeID, second: NodeID, third: NodeID, + node: NodeID, fields: IndexTree<NodeID>, }, } @@ -204,10 +207,12 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: } // Now, we process the remaining nodes, allocating NodeIDs for them and updating the field_map. - // We track the current NodeID and add nodes whenever possible, otherwise we return values to - // the worklist + // We track the current NodeID and add nodes to a set we maintain of nodes to add (since we + // need to add nodes in a particular order we wait to do that until the end). 
If we don't have + // enough information to process a particular node, we add it back to the worklist let mut next_id: usize = editor.func().nodes.len(); - let mut cur_id: usize = editor.func().nodes.len(); + let mut to_insert = BTreeMap::new(); + while let Some(mut item) = worklist.pop_front() { if let WorkItem::Unhandled(node) = item { match &editor.func().nodes[node.idx()] { @@ -221,6 +226,7 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: item = WorkItem::AllocatedPhi { control, data: data.into(), + node, fields, }; } @@ -239,6 +245,7 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: control, init, reduct, + node, fields, }; } @@ -258,6 +265,7 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: first, second, third, + node, fields, }; } @@ -307,11 +315,127 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: } match item { WorkItem::Unhandled(_) => {} - _ => todo!(), + WorkItem::AllocatedPhi { + control, + data, + node, + fields, + } => { + let mut data_fields = vec![]; + let mut ready = true; + for val in data.iter() { + if let Some(val_fields) = field_map.get(val) { + data_fields.push(val_fields); + } else { + ready = false; + break; + } + } + + if ready { + fields.zip_list(data_fields).for_each(|idx, (res, data)| { + to_insert.insert( + res.idx(), + Node::Phi { + control, + data: data.into_iter().map(|n| **n).collect::<Vec<_>>().into(), + }, + ); + }); + } else { + worklist.push_back(WorkItem::AllocatedPhi { + control, + data, + node, + fields, + }); + } + } + WorkItem::AllocatedReduce { + control, + init, + reduct, + node, + fields, + } => { + if let (Some(init_fields), Some(reduct_fields)) = + (field_map.get(&init), field_map.get(&reduct)) + { + fields.zip(init_fields).zip(reduct_fields).for_each( + |idx, ((res, init), reduct)| { + to_insert.insert( + res.idx(), + Node::Reduce { + control, + init: **init, + reduct: **reduct, + 
}, + ); + }, + ); + to_delete.push(node); + } else { + worklist.push_back(WorkItem::AllocatedReduce { + control, + init, + reduct, + node, + fields, + }); + } + } + WorkItem::AllocatedTernary { + first, + second, + third, + node, + fields, + } => { + if let (Some(fst_fields), Some(snd_fields), Some(thd_fields)) = ( + field_map.get(&first), + field_map.get(&second), + field_map.get(&third), + ) { + fields + .zip(fst_fields) + .zip(snd_fields) + .zip(thd_fields) + .for_each(|idx, (((res, fst), snd), thd)| { + to_insert.insert( + res.idx(), + Node::Ternary { + first: **fst, + second: **snd, + third: **thd, + op: TernaryOperator::Select, + }, + ); + }); + } else { + worklist.push_back(WorkItem::AllocatedTernary { + first, + second, + third, + node, + fields, + }); + } + } } } - // Actually deleting nodes seems to break things right now + // Create new nodes + for (node_id, node) in to_insert { + assert!(node_id == editor.func().nodes.len()); + println!("Inserting {:?} : {:?}", node_id, node); + editor.edit(|mut edit| { + let id = edit.add_node(node); + assert!(node_id == id.idx()); + Ok(edit) + }); + } + + // Remove nodes editor.edit(|mut edit| { for node in to_delete { edit = edit.delete_node(node)? 
@@ -410,6 +534,54 @@ impl<T: std::fmt::Debug> IndexTree<T> { } } + fn zip<'a, A>(self, other: &'a IndexTree<A>) -> IndexTree<(T, &'a A)> { + match (self, other) { + (IndexTree::Leaf(t), IndexTree::Leaf(a)) => IndexTree::Leaf((t, a)), + (IndexTree::Node(t), IndexTree::Node(a)) => { + let mut fields = vec![]; + for (t, a) in t.into_iter().zip(a.iter()) { + fields.push(t.zip(a)); + } + IndexTree::Node(fields) + } + _ => panic!("IndexTrees do not have the same fields, cannot zip"), + } + } + + fn zip_list<'a, A>(self, others: Vec<&'a IndexTree<A>>) -> IndexTree<(T, Vec<&'a A>)> { + match self { + IndexTree::Leaf(t) => { + let mut res = vec![]; + for other in others { + match other { + IndexTree::Leaf(a) => res.push(a), + _ => panic!("IndexTrees do not have the same fields, cannot zip"), + } + } + IndexTree::Leaf((t, res)) + } + IndexTree::Node(t) => { + let mut fields: Vec<Vec<&'a IndexTree<A>>> = vec![vec![]; t.len()]; + for other in others { + match other { + IndexTree::Node(a) => { + for (i, a) in a.iter().enumerate() { + fields[i].push(a); + } + } + _ => panic!("IndexTrees do not have the same fields, cannot zip"), + } + } + IndexTree::Node( + t.into_iter() + .zip(fields.into_iter()) + .map(|(t, f)| t.zip_list(f)) + .collect(), + ) + } + } + } + fn for_each<F>(&self, mut f: F) where F: FnMut(&Vec<Index>, &T), diff --git a/juno_frontend/src/lib.rs b/juno_frontend/src/lib.rs index 49249fc7d24b576ad3bd7cc794ef3b7303d63626..e54e88fc13ee98373384e6641ed5ded76dd34b8e 100644 --- a/juno_frontend/src/lib.rs +++ b/juno_frontend/src/lib.rs @@ -151,6 +151,10 @@ pub fn compile_ir( if x_dot { pm.add_pass(hercules_opt::pass::Pass::Xdot(true)); } + add_pass!(pm, verify, SROA); + if x_dot { + pm.add_pass(hercules_opt::pass::Pass::Xdot(true)); + } add_pass!(pm, verify, Inline); // Run SROA pretty early (though after inlining which can make SROA more effective) so that // CCP, GVN, etc. 
can work on the result of SROA diff --git a/juno_samples/products.jn b/juno_samples/products.jn index a6dd88620ffdf20dc5451f535bf3f3da0e1bd929..00b66aab5835849c24b11c3af643761f0a8d000f 100644 --- a/juno_samples/products.jn +++ b/juno_samples/products.jn @@ -1,5 +1,7 @@ fn test_call(x : i32, y : f32) -> (i32, f32) { - return (x, y); + let res = (x, y); + if x < 13 { res = (x + 1, y); } + return res; } fn test(x : i32, y : f32) -> (i32, f32) {