diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index fa7b55be35426fc3db15124602cee5471c2948df..f62c00c15f9e8715254c515834d8fe63c2715539 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -983,7 +983,7 @@ impl Constant { Constant::Float64(ord) => *ord == OrderedFloat::<f64>(1.0), _ => false, } - } + } } impl DynamicConstant { diff --git a/hercules_opt/src/ccp.rs b/hercules_opt/src/ccp.rs index 68693e8b78a85d71928cfb4dc295cacd864a8189..92d52a716710d255ff9d1a0ef8a9de20a7c88fa9 100644 --- a/hercules_opt/src/ccp.rs +++ b/hercules_opt/src/ccp.rs @@ -677,7 +677,9 @@ fn ccp_flow_function( (BinaryOperator::RSh, Constant::UnsignedInteger64(left_val), Constant::UnsignedInteger64(right_val)) => Some(Constant::UnsignedInteger64(left_val >> right_val)), _ => panic!("Unsupported combination of binary operation and constant values. Did typechecking succeed?") }; - new_cons.map(|c| ConstantLattice::Constant(c)).unwrap_or(ConstantLattice::bottom()) + new_cons + .map(|c| ConstantLattice::Constant(c)) + .unwrap_or(ConstantLattice::bottom()) } else if (left_constant.is_top() && !right_constant.is_bottom()) || (!left_constant.is_bottom() && right_constant.is_top()) { diff --git a/hercules_opt/src/editor.rs b/hercules_opt/src/editor.rs index f9b8b494aef571d58ab25ee34521be21bbd5097a..2444fdb4d5d69dfa64049f4342a6489ab55d6117 100644 --- a/hercules_opt/src/editor.rs +++ b/hercules_opt/src/editor.rs @@ -359,14 +359,15 @@ impl<'a: 'b, 'b> FunctionEditor<'a> { self.dynamic_constants.borrow() } - pub fn get_users(&self, id: NodeID) -> impl ExactSizeIterator<Item = NodeID> + '_ { self.mut_def_use[id.idx()].iter().map(|x| *x) } pub fn get_uses(&self, id: NodeID) -> impl ExactSizeIterator<Item = NodeID> + '_ { get_uses(&self.function.nodes[id.idx()]) - .as_ref().into_iter().map(|x| *x) + .as_ref() + .into_iter() + .map(|x| *x) .collect::<Vec<_>>() // @(xrouth): wtf??? .into_iter() } @@ -794,83 +795,6 @@ impl<'a, 'b> FunctionEdit<'a, 'b> { } } -pub type DenseNodeMap<T> = Vec<T>; -pub type SparseNodeMap<T> = HashMap<NodeID, T>; - -nest! { -// Is this something editor should give... Or is it just for analyses. -// -#[derive(Clone, Debug)] -pub struct NodeIterator<'a> { - pub direction: - #[derive(Clone, Debug, PartialEq)] - enum Direction { - Uses, - Users, - }, - visited: DenseNodeMap<bool>, - stack: Vec<NodeID>, - func: &'a FunctionEditor<'a>, // Maybe this is an enum, def use can be gotten from the function or from the editor. - // `stop condition`, then return all nodes that caused stoppage i.e the frontier of the search. - stop_on: HashSet<NodeID>, // Don't add neighbors of these. 
-}
-}
-
-pub fn walk_all_uses<'a>(node: NodeID, editor: &'a FunctionEditor<'a>) -> NodeIterator<'a> {
-    let len = editor.func().nodes.len();
-    NodeIterator { direction: Direction::Uses, visited: vec![false; len], stack: vec![node], func: editor,
-        stop_on: HashSet::new()}
-}
-
-pub fn walk_all_users<'a>(node: NodeID, editor: &'a FunctionEditor<'a>) -> NodeIterator<'a> {
-    let len = editor.func().nodes.len();
-    NodeIterator { direction: Direction::Users, visited: vec![false; len], stack: vec![node], func: editor,
-        stop_on: HashSet::new()}
-}
-
-pub fn walk_all_uses_stop_on<'a>(node: NodeID, editor: &'a FunctionEditor<'a>, stop_on: HashSet<NodeID>) -> NodeIterator<'a> {
-    let len = editor.func().nodes.len();
-    let uses = editor.get_uses(node).collect();
-    NodeIterator { direction: Direction::Uses, visited: vec![false; len], stack: uses, func: editor,
-        stop_on,}
-}
-
-pub fn walk_all_users_stop_on<'a>(node: NodeID, editor: &'a FunctionEditor<'a>, stop_on: HashSet<NodeID>) -> NodeIterator<'a> {
-    let len = editor.func().nodes.len();
-    let users = editor.get_users(node).collect();
-    NodeIterator { direction: Direction::Users, visited: vec![false; len], stack: users, func: editor,
-        stop_on,}
-}
-
-impl<'a> Iterator for NodeIterator<'a> {
-    type Item = NodeID;
-
-    fn next(&mut self) -> Option<Self::Item> {
-        while let Some(current) = self.stack.pop() {
-
-            if !self.visited[current.idx()]{
-                self.visited[current.idx()] = true;
-
-                if !self.stop_on.contains(&current) {
-                    if self.direction == Direction::Uses {
-                        for neighbor in self.func.get_uses(current) {
-                            self.stack.push(neighbor)
-                        }
-                    } else {
-                        for neighbor in self.func.get_users(current) {
-                            self.stack.push(neighbor)
-                        }
-                    }
-                }
-
-                return Some(current);
-            }
-        }
-        None
-    }
-}
-
-
 #[cfg(test)]
 mod editor_tests {
     #[allow(unused_imports)]
diff --git a/hercules_opt/src/fork_concat_split.rs b/hercules_opt/src/fork_concat_split.rs
index ae4ce72e8f099bde1e22000c872f8118bde5c7e2..1339a38436bcf1db5a613d3cb121d1f67d612a2e 100644
--- a/hercules_opt/src/fork_concat_split.rs
+++ b/hercules_opt/src/fork_concat_split.rs
@@ -43,7 +43,7 @@ pub fn fork_split(
         .collect();
 
     editor.edit(|mut edit| {
-        // Create the forks and a thread ID per fork. 
+        // Create the forks and a thread ID per fork.
let mut acc_fork = fork_control; let mut new_tids = vec![]; for factor in factors { diff --git a/hercules_opt/src/fork_guard_elim.rs b/hercules_opt/src/fork_guard_elim.rs index 8f6a98c4b341c5691396d22eec9a5896b2a1387b..435e63b6eb0a2b8cc91adbb50451a0caddd2b16a 100644 --- a/hercules_opt/src/fork_guard_elim.rs +++ b/hercules_opt/src/fork_guard_elim.rs @@ -37,7 +37,7 @@ use crate::FunctionEditor; // Simplify factors through max enum Factor { Max(usize, DynamicConstantID), - Normal(usize, DynamicConstantID) + Normal(usize, DynamicConstantID), } impl Factor { @@ -49,7 +49,6 @@ impl Factor { } } - struct GuardedFork { fork: NodeID, join: NodeID, @@ -66,10 +65,7 @@ fn guarded_fork( editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>, node: NodeID, -) -> Option< - GuardedFork -> { - +) -> Option<GuardedFork> { let function = editor.func(); // Identify fork nodes @@ -77,21 +73,24 @@ fn guarded_fork( return None; }; - let factors = factors.iter().enumerate().map(|(idx, dc)| { // FIXME: Can we hide .idx() in an impl Index or something so we don't index Vec<Nodes> iwht DynamicConstantId.idx() - let DynamicConstant::Max(l, r) = *editor.get_dynamic_constant(*dc) else {return Factor::Normal(idx, *dc)}; + let DynamicConstant::Max(l, r) = *editor.get_dynamic_constant(*dc) else { + return Factor::Normal(idx, *dc); + }; // There really needs to be a better way to work w/ associativity. - let binding = [(l,r), (r,l)]; + let binding = [(l, r), (r, l)]; let id = binding.iter().find_map(|(a, b)| { - let DynamicConstant::Constant(1) = *editor.get_dynamic_constant(*a) else {return None}; + let DynamicConstant::Constant(1) = *editor.get_dynamic_constant(*a) else { + return None; + }; Some(b) }); - + match id { Some(v) => Factor::Max(idx, *v), - None => Factor::Normal(idx, *dc) + None => Factor::Normal(idx, *dc), } }); @@ -121,32 +120,42 @@ fn guarded_fork( // branchIdx == 1 means the true branch so we want the condition to be // 0 < n or n > 0 if branch_idx == 1 { - [(left, BinaryOperator::LT, right), (right, BinaryOperator::GT, left)].iter().find_map(|(pattern_zero, pattern_op, pattern_factor)| - { + [ + (left, BinaryOperator::LT, right), + (right, BinaryOperator::GT, left), + ] + .iter() + .find_map(|(pattern_zero, pattern_op, pattern_factor)| { // Match Op if op != *pattern_op { - return None + return None; } // Match Zero - if !(function.nodes[pattern_zero.idx()].is_zero_constant(&editor.get_constants()) || editor.node(pattern_zero).is_zero_dc(&editor.get_dynamic_constants())) { - return None + if !(function.nodes[pattern_zero.idx()].is_zero_constant(&editor.get_constants()) + || editor + .node(pattern_zero) + .is_zero_dc(&editor.get_dynamic_constants())) + { + return None; } // Match Factor let factor = factors.clone().find(|factor| { - // This clone on the dc is painful. - match (&function.nodes[pattern_factor.idx()], editor.get_dynamic_constant(factor.get_id()).clone()) { + // This clone on the dc is painful. 
+ match ( + &function.nodes[pattern_factor.idx()], + editor.get_dynamic_constant(factor.get_id()).clone(), + ) { (Node::Constant { id }, DynamicConstant::Constant(v)) => { - let Constant::UnsignedInteger64(pattern_v) = *editor.get_constant(*id) else { + let Constant::UnsignedInteger64(pattern_v) = *editor.get_constant(*id) + else { return false; }; - pattern_v == (v as u64) - }, - (Node::DynamicConstant { id }, _) => { - *id == factor.get_id() - }, - _ => false - } + pattern_v == (v as u64) + } + (Node::DynamicConstant { id }, _) => *id == factor.get_id(), + _ => false, + } }); // return Factor factor @@ -155,35 +164,48 @@ fn guarded_fork( // branchIdx == 0 means the false branch so we want the condition to be // n < 0 or 0 > n else if branch_idx == 0 { - [(right, BinaryOperator::LT, left), (left, BinaryOperator::GT, right)].iter().find_map(|(pattern_zero, pattern_op, pattern_factor)| - { + [ + (right, BinaryOperator::LT, left), + (left, BinaryOperator::GT, right), + ] + .iter() + .find_map(|(pattern_zero, pattern_op, pattern_factor)| { // Match Op if op != *pattern_op { - return None + return None; } // Match Zero - if !(function.nodes[pattern_zero.idx()].is_zero_constant(&editor.get_constants()) || editor.node(pattern_zero).is_zero_dc(&editor.get_dynamic_constants())) { - return None + if !(function.nodes[pattern_zero.idx()].is_zero_constant(&editor.get_constants()) + || editor + .node(pattern_zero) + .is_zero_dc(&editor.get_dynamic_constants())) + { + return None; } // Match Factor - let factor = factors.clone().find(|factor| function.nodes[pattern_factor.idx()].try_dynamic_constant() == Some(factor.get_id())); + // FIXME: Implement dc / constant matching as in case where branch_idx == 1 + let factor = factors.clone().find(|factor| { + function.nodes[pattern_factor.idx()].try_dynamic_constant() + == Some(factor.get_id()) + }); // return Factor factor - }) + }) } else { None } }; - let Some(factor) = factor else {return None}; + let Some(factor) = factor else { return None }; // Identify the join node and its users let join_id = fork_join_map.get(&node)?; // Find the unique control use of the join; if it's not a region we can't // eliminate this guard - let join_control = editor.get_users(*join_id) + let join_control = editor + .get_users(*join_id) .filter(|n| function.nodes[n.idx()].is_region()) .collect::<Vec<_>>(); if join_control.len() != 1 { @@ -218,14 +240,15 @@ fn guarded_fork( let else_branch = *selection; if else_branch == branch_idx { return None; - } + } if if_node2 != if_node { return None; } // Finally, identify the phi nodes associated with the region and match // them with the reduce nodes of the fork-join - let reduce_nodes = editor.get_users(*join_id) + let reduce_nodes = editor + .get_users(*join_id) .filter(|n| function.nodes[n.idx()].is_reduce()) .collect::<HashSet<_>>(); // Construct a map from phi nodes indices to the reduce node index @@ -268,7 +291,7 @@ fn guarded_fork( return None; } - let mut phi_nodes = phi_nodes + let phi_nodes = phi_nodes .into_iter() .map(|(phi, red)| (phi, red.unwrap())) .collect::<HashMap<_, _>>(); @@ -288,7 +311,7 @@ fn guarded_fork( guard_pred: if_pred, guard_join_region: join_control, phi_reduce_map: phi_nodes, - factor + factor, }) } @@ -297,39 +320,57 @@ fn guarded_fork( * Deletes nodes by setting nodes to gravestones. Works with a function already * containing gravestones. 
*/ -pub fn fork_guard_elim( - editor: &mut FunctionEditor, - fork_join_map: &HashMap<NodeID, NodeID>, -) { - let guard_info = editor.node_ids() +pub fn fork_guard_elim(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>) { + let guard_info = editor + .node_ids() .filter_map(|node| guarded_fork(editor, fork_join_map, node)) .collect::<Vec<_>>(); - // (fork_node, factors, guard_node, guard_proj1, guard_proj2, guard_pred, map) - for GuardedFork {fork, join, fork_taken_proj, fork_skipped_proj, guard_pred, phi_reduce_map, factor, guard_if, guard_join_region } in guard_info { + for GuardedFork { + fork, + join, + fork_taken_proj, + fork_skipped_proj, + guard_pred, + phi_reduce_map, + factor, + guard_if, + guard_join_region, + } in guard_info + { let new_fork_info = if let Factor::Max(idx, dc) = factor { - let Node::Fork { control, mut factors } = editor.func().nodes[fork.idx()].clone() else {unreachable!()}; + let Node::Fork { + control, + mut factors, + } = editor.func().nodes[fork.idx()].clone() + else { + unreachable!() + }; factors[idx] = dc; - let new_fork = Node::Fork { control: guard_pred, factors }; + let new_fork = Node::Fork { + control: guard_pred, + factors, + }; Some(new_fork) } else { None }; editor.edit(|mut edit| { - edit = edit.replace_all_uses_where(fork_taken_proj, guard_pred, |usee| *usee == fork)?; + edit = + edit.replace_all_uses_where(fork_taken_proj, guard_pred, |usee| *usee == fork)?; edit = edit.delete_node(guard_if)?; edit = edit.delete_node(fork_taken_proj)?; edit = edit.delete_node(fork_skipped_proj)?; edit = edit.replace_all_uses(guard_join_region, join)?; edit = edit.delete_node(guard_join_region)?; - // Delete region node + // Delete region node for (phi, reduce) in phi_reduce_map.iter() { edit = edit.replace_all_uses(*phi, *reduce)?; edit = edit.delete_node(*phi)?; } - + if let Some(new_fork_info) = new_fork_info { let new_fork = edit.add_node(new_fork_info); edit = edit.replace_all_uses(fork, new_fork)?; diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index 79fedcdcfa37ca41fe64de1f894541df93ee46d9..14145f57e4c80f5299b862a97fadc42c46aa157d 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -1,7 +1,7 @@ use std::collections::{HashMap, HashSet}; use std::ops::Sub; -extern crate hercules_ir; extern crate bimap; +extern crate hercules_ir; use itertools::Itertools; @@ -26,32 +26,45 @@ use crate::{DenseNodeMap, FunctionEditor, Loop, SparseNodeMap}; type ForkID = usize; /** Places each reduce node into its own fork */ -pub fn default_reduce_partition(editor: &FunctionEditor, fork: NodeID, join: NodeID) -> SparseNodeMap<ForkID> { +pub fn default_reduce_partition( + editor: &FunctionEditor, + fork: NodeID, + join: NodeID, +) -> SparseNodeMap<ForkID> { let mut map = SparseNodeMap::new(); - editor.get_users(join) + editor + .get_users(join) .filter(|id| editor.func().nodes[id.idx()].is_reduce()) .enumerate() - .for_each(|(fork, reduce)| { map.insert(reduce, fork); }); + .for_each(|(fork, reduce)| { + map.insert(reduce, fork); + }); map } -// TODO: Refine these conditions. +// TODO: Refine these conditions. 
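// [Editorial sketch; not part of the diff.] `default_reduce_partition` keys each
// reduce attached to the join to its own ForkID, so downstream fission emits one
// fork-join per reduce. This mirrors the actual call site in `fork_fission`
// later in this file:
//
//     let join = fork_join_map[&fork];
//     let reduce_partition = default_reduce_partition(editor, fork, join);
//     let (new_fork, new_join) =
//         fork_reduce_fission_helper(editor, fork_join_map, reduce_partition, control_pred, fork);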
/** */ -pub fn find_reduce_dependencies<'a>(function: &'a Function, reduce: NodeID, fork: NodeID -) -> impl IntoIterator<Item = NodeID> + 'a -{ +pub fn find_reduce_dependencies<'a>( + function: &'a Function, + reduce: NodeID, + fork: NodeID, +) -> impl IntoIterator<Item = NodeID> + 'a { let len = function.nodes.len(); - let mut visited: DenseNodeMap<bool> = vec![false; len]; let mut depdendent: DenseNodeMap<bool> = vec![false; len]; // Does `fork` need to be a parameter here? It never changes. If this was a closure could it just capture it? - fn recurse(function: &Function, node: NodeID, fork: NodeID, - dependent_map: &mut DenseNodeMap<bool>, visited: &mut DenseNodeMap<bool> - ) -> () { // return through dependent_map { + fn recurse( + function: &Function, + node: NodeID, + fork: NodeID, + dependent_map: &mut DenseNodeMap<bool>, + visited: &mut DenseNodeMap<bool>, + ) -> () { + // return through dependent_map { if visited[node.idx()] { return; @@ -70,13 +83,13 @@ pub fn find_reduce_dependencies<'a>(function: &'a Function, reduce: NodeID, fork for used in uses { recurse(function, *used, fork, dependent_map, visited); } - + dependent_map[node.idx()] = uses.iter().map(|id| dependent_map[id.idx()]).any(|a| a); return; } // Note: HACKY, the condition wwe want is 'all nodes on any path from the fork to the reduce (in the forward graph), or the reduce to the fork (in the directed graph) - // cycles break this, but we assume for now that the only cycles are ones that involve the reduce node + // cycles break this, but we assume for now that the only cycles are ones that involve the reduce node // NOTE: (control may break this (i.e loop inside fork) is a cycle that isn't the reduce) // the current solution is just to mark the reduce as dependent at the start of traversing the graph. depdendent[reduce.idx()] = true; @@ -84,42 +97,52 @@ pub fn find_reduce_dependencies<'a>(function: &'a Function, reduce: NodeID, fork recurse(function, reduce, fork, &mut depdendent, &mut visited); // Return node IDs that are dependent - let a: Vec<_> = depdendent.iter().enumerate() - .filter_map(|(idx, dependent)| if *dependent {Some(NodeID::new(idx))} else {None}) + let a: Vec<_> = depdendent + .iter() + .enumerate() + .filter_map(|(idx, dependent)| { + if *dependent { + Some(NodeID::new(idx)) + } else { + None + } + }) .collect(); a } -pub fn copy_subgraph(editor: &mut FunctionEditor, subgraph: HashSet<NodeID>) --> (HashSet<NodeID>, HashMap<NodeID, NodeID>, Vec<(NodeID, NodeID)>) // set of all nodes, set of new nodes, outside node. s.t old node-> outside node exists as an edge. +pub fn copy_subgraph( + editor: &mut FunctionEditor, + subgraph: HashSet<NodeID>, +) -> ( + HashSet<NodeID>, + HashMap<NodeID, NodeID>, + Vec<(NodeID, NodeID)>, +) // set of all nodes, set of new nodes, outside node. s.t old node-> outside node exists as an edge. 
{ let mut map: HashMap<NodeID, NodeID> = HashMap::new(); let mut new_nodes: HashSet<NodeID> = HashSet::new(); - + // Copy nodes for old_id in subgraph.iter() { - editor.edit(|mut edit| - { - let new_id = edit.copy_node(*old_id); - map.insert(*old_id, new_id); - new_nodes.insert(new_id); - Ok(edit) - } - ); + editor.edit(|mut edit| { + let new_id = edit.copy_node(*old_id); + map.insert(*old_id, new_id); + new_nodes.insert(new_id); + Ok(edit) + }); } // Update edges to new nodes for old_id in subgraph.iter() { // Replace all uses of old_id w/ new_id, where the use is in new_node - editor.edit(|edit| - { - edit.replace_all_uses_where(*old_id, map[old_id], |node_id| new_nodes.contains(node_id)) - } - ); + editor.edit(|edit| { + edit.replace_all_uses_where(*old_id, map[old_id], |node_id| new_nodes.contains(node_id)) + }); } - // Get all users that aren't in new_nodes. + // Get all users that aren't in new_nodes. let mut outside_users = Vec::new(); for node in new_nodes.iter() { @@ -133,68 +156,67 @@ pub fn copy_subgraph(editor: &mut FunctionEditor, subgraph: HashSet<NodeID>) (new_nodes, map, outside_users) } -pub fn fork_fission<'a> ( +pub fn fork_fission<'a>( editor: &'a mut FunctionEditor, control_subgraph: &Subgraph, types: &Vec<TypeID>, loop_tree: &LoopTree, fork_join_map: &HashMap<NodeID, NodeID>, -)-> () { - let forks: Vec<_> = editor.func().nodes.iter().enumerate().filter_map(|(idx, node)| { - if node.is_fork() { - Some(NodeID::new(idx)) - } else {None} - }).collect(); +) -> () { + let forks: Vec<_> = editor + .func() + .nodes + .iter() + .enumerate() + .filter_map(|(idx, node)| { + if node.is_fork() { + Some(NodeID::new(idx)) + } else { + None + } + }) + .collect(); let mut control_pred = NodeID::new(0); // This does the reduction fission: - if true { - for fork in forks.clone() { - // FIXME: If there is control in between fork and join, give up. - let join = fork_join_map[&fork]; - let join_pred = editor.func().nodes[join.idx()].try_join().unwrap(); - if join_pred != fork { - todo!("Can't do fork fission on nodes with internal control") - // Inner control LOOPs are hard - // inner control in general *should* work right now without modifications. - } - let reduce_partition = default_reduce_partition(editor, fork, join); - - let (new_fork, new_join) = fork_reduce_fission_helper(editor, fork_join_map, reduce_partition, control_pred, fork); - // control_pred = new_join; + for fork in forks.clone() { + // FIXME: If there is control in between fork and join, don't just give up. + let join = fork_join_map[&fork]; + let join_pred = editor.func().nodes[join.idx()].try_join().unwrap(); + if join_pred != fork { + todo!("Can't do fork fission on nodes with internal control") + // Inner control LOOPs are hard + // inner control in general *should* work right now without modifications. } - } else { - // This does the bufferization: - let edge = (NodeID::new(15), NodeID::new(16)); - // let edge = (NodeID::new(4), NodeID::new(9)); - let mut edges = HashSet::new(); - edges.insert(edge); - let fork = loop_tree.bottom_up_loops().first().unwrap().0; - //let fork = forks.first().unwrap(); - fork_bufferize_fission_helper(editor, fork_join_map, edges, NodeID::new(0), types, fork); + let reduce_partition = default_reduce_partition(editor, fork, join); + + let (new_fork, new_join) = + fork_reduce_fission_helper(editor, fork_join_map, reduce_partition, control_pred, fork); + // control_pred = new_join; } } /** Split a 1D fork into two forks, placing select intermediate data into buffers. 
*/ -pub fn fork_bufferize_fission_helper<'a> ( +pub fn fork_bufferize_fission_helper<'a>( editor: &'a mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>, - bufferized_edges: HashSet<(NodeID, NodeID)>, // Describes what intermediate data should be bufferized. - original_control_pred: NodeID, // What the new fork connects to. + bufferized_edges: HashSet<(NodeID, NodeID)>, // Describes what intermediate data should be bufferized. + original_control_pred: NodeID, // What the new fork connects to. types: &Vec<TypeID>, fork: NodeID, -) -> (NodeID, NodeID) { // Returns the two forks that it generates. - - // TODO: Check that bufferized edges src doesn't depend on anything that comes after the fork. +) -> (NodeID, NodeID) { + // Returns the two forks that it generates. + + // TODO: Check that bufferized edges src doesn't depend on anything that comes after the fork. - // Copy fork + control intermediates + join to new fork + join, - // How does control get partitioned? + // Copy fork + control intermediates + join to new fork + join, + // How does control get partitioned? // (depending on how it affects the data nodes on each side of the bufferized_edges) - // may end up in each loop, fix me later. + // may end up in each loop, fix me later. // place new fork + join after join of first. - // Only handle fork+joins with no inner control for now. + // Only handle fork+joins with no inner control for now. // Create fork + join + Thread control let join = fork_join_map[&fork]; @@ -204,77 +226,95 @@ pub fn fork_bufferize_fission_helper<'a> ( editor.edit(|mut edit| { new_join_id = edit.add_node(Node::Join { control: fork }); let factors = edit.get_node(fork).try_fork().unwrap().1.clone(); - new_fork_id = edit.add_node(Node::Fork { control: new_join_id, factors: factors.into() }); + new_fork_id = edit.add_node(Node::Fork { + control: new_join_id, + factors: factors.into(), + }); edit.replace_all_uses_where(fork, new_fork_id, |usee| *usee == join) }); for (src, dst) in bufferized_edges { // FIXME: Disgusting cloning and allocationing and iterators. - let factors: Vec<_> = editor.func().nodes[fork.idx()].try_fork().unwrap().1.iter().cloned().collect(); - editor.edit(|mut edit| - { - // Create write to buffer - - let thread_stuff_it = factors.into_iter().enumerate(); - - // FIxme: try to use unzip here? Idk why it wasn't working. - let (tids) = thread_stuff_it.clone().map(|(dim, factor)| - ( - edit.add_node(Node::ThreadID { control: fork, dimension: dim }) - ) - ); - - let array_dims = thread_stuff_it.clone().map(|(dim, factor)| - ( - factor - ) - ); - - // Assume 1-d fork only for now. 
- // let tid = edit.add_node(Node::ThreadID { control: fork, dimension: 0 }); - let position_idx = Index::Position(tids.collect::<Vec<_>>().into_boxed_slice()); - let write = edit.add_node(Node::Write { collect: NodeID::new(0), data: src, indices: vec![position_idx].into() }); - let ele_type = types[src.idx()]; - let empty_buffer = edit.add_type(hercules_ir::Type::Array(ele_type, array_dims.collect::<Vec<_>>().into_boxed_slice())); - let empty_buffer = edit.add_zero_constant(empty_buffer); - let empty_buffer = edit.add_node(Node::Constant { id: empty_buffer }); - let reduce = Node::Reduce { control: new_join_id, init: empty_buffer, reduct: write }; - let reduce = edit.add_node(reduce); - // Fix write node - edit = edit.replace_all_uses_where(NodeID::new(0), reduce, |usee| *usee == write)?; - - - // Create read from buffer - let (tids) = thread_stuff_it.clone().map(|(dim, factor)| - ( - edit.add_node(Node::ThreadID { control: new_fork_id, dimension: dim }) - ) - ); - - let position_idx = Index::Position(tids.collect::<Vec<_>>().into_boxed_slice()); - - let read = edit.add_node(Node::Read { collect: reduce, indices: vec![position_idx].into() }); - - edit = edit.replace_all_uses_where(src, read, |usee| *usee == dst)?; - - Ok(edit) - } - ); + let factors: Vec<_> = editor.func().nodes[fork.idx()] + .try_fork() + .unwrap() + .1 + .iter() + .cloned() + .collect(); + editor.edit(|mut edit| { + // Create write to buffer + + let thread_stuff_it = factors.into_iter().enumerate(); + + // FIxme: try to use unzip here? Idk why it wasn't working. + let (tids) = thread_stuff_it.clone().map(|(dim, factor)| { + (edit.add_node(Node::ThreadID { + control: fork, + dimension: dim, + })) + }); + + let array_dims = thread_stuff_it.clone().map(|(dim, factor)| (factor)); + + // Assume 1-d fork only for now. + // let tid = edit.add_node(Node::ThreadID { control: fork, dimension: 0 }); + let position_idx = Index::Position(tids.collect::<Vec<_>>().into_boxed_slice()); + let write = edit.add_node(Node::Write { + collect: NodeID::new(0), + data: src, + indices: vec![position_idx].into(), + }); + let ele_type = types[src.idx()]; + let empty_buffer = edit.add_type(hercules_ir::Type::Array( + ele_type, + array_dims.collect::<Vec<_>>().into_boxed_slice(), + )); + let empty_buffer = edit.add_zero_constant(empty_buffer); + let empty_buffer = edit.add_node(Node::Constant { id: empty_buffer }); + let reduce = Node::Reduce { + control: new_join_id, + init: empty_buffer, + reduct: write, + }; + let reduce = edit.add_node(reduce); + // Fix write node + edit = edit.replace_all_uses_where(NodeID::new(0), reduce, |usee| *usee == write)?; + + // Create read from buffer + let (tids) = thread_stuff_it.clone().map(|(dim, factor)| { + (edit.add_node(Node::ThreadID { + control: new_fork_id, + dimension: dim, + })) + }); + + let position_idx = Index::Position(tids.collect::<Vec<_>>().into_boxed_slice()); + + let read = edit.add_node(Node::Read { + collect: reduce, + indices: vec![position_idx].into(), + }); + + edit = edit.replace_all_uses_where(src, read, |usee| *usee == dst)?; + + Ok(edit) + }); } (fork, new_fork_id) - } /** Split a 1D fork into a separate fork for each reduction. */ -pub fn fork_reduce_fission_helper<'a> ( +pub fn fork_reduce_fission_helper<'a>( editor: &'a mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>, reduce_partition: SparseNodeMap<ForkID>, // Describes how the reduces of the fork should be split, - original_control_pred: NodeID, // What the new fork connects to. 
+ original_control_pred: NodeID, // What the new fork connects to. fork: NodeID, -) -> (NodeID, NodeID) { // returns Fork, Join pair { +) -> (NodeID, NodeID) { + // returns Fork, Join pair { let join = fork_join_map[&fork]; // If there is control in between then j give up. @@ -284,16 +324,16 @@ pub fn fork_reduce_fission_helper<'a> ( // Get nodes to copy // let factors: Box<[DynamicConstantID]> = edit..nodes[fork.idx()].try_fork().unwrap().1.into(); - // None of this matters, just assume we have DCE for control flow. + // None of this matters, just assume we have DCE for control flow. // Make new fork put it after the existing loop (deal with dependencies later.) // Make new join, put it after fork (FIXME: THIS IS WRONG) // Make copies of all control + data nodes, including the reduce and join, with equivalent uses / users, mark them as NEW - // - Need an editor utility to copy a subsection of the graph. + // - Need an editor utility to copy a subsection of the graph. // 1) Edges going into the subsection stay the same, i.e something new still *uses* something old. - // 2) Edges leaving the subsection need to be handled by the user, (can't force outgoing new edges into nodes) - // return a list of outgoing (but unattatached) edges + the old destination to the programmer. - - // Important edges are: Reduces, + // 2) Edges leaving the subsection need to be handled by the user, (can't force outgoing new edges into nodes) + // return a list of outgoing (but unattatached) edges + the old destination to the programmer. + + // Important edges are: Reduces, // NOTE: // Say two reduce are in a fork, s.t reduce A depends on reduce B @@ -306,13 +346,13 @@ pub fn fork_reduce_fission_helper<'a> ( // for now, DONT HANDLE IT. LOL. // NOTE: - // + // // Replace all - // Replace all uses of (fork, reduce, ) w/ predicate that they are the newly copied nodes. + // Replace all uses of (fork, reduce, ) w/ predicate that they are the newly copied nodes. // repalce uses - let mut new_fork = NodeID::new(0); + let mut new_fork = NodeID::new(0); let mut new_join = NodeID::new(0); // Gets everything between fork & join that this reduce needs. 
(ALL CONTROL) @@ -321,28 +361,30 @@ pub fn fork_reduce_fission_helper<'a> ( let function = editor.func(); let subgraph = find_reduce_dependencies(function, reduce, fork); - + let mut subgraph: HashSet<NodeID> = subgraph.into_iter().collect(); - + subgraph.insert(join); subgraph.insert(fork); subgraph.insert(reduce); - + // println!("subgraph for {:?}: \n{:?}", reduce, subgraph); - + let (new_nodes, mapping, _) = copy_subgraph(editor, subgraph); - + // println!("new_nodes: {:?} ", new_nodes); // println!("mapping: {:?} ",mapping); - + new_fork = mapping[&fork]; new_join = mapping[&join]; - + editor.edit(|mut edit| { // Atttach new_fork after control_pred let (old_control_pred, factors) = edit.get_node(new_fork).try_fork().unwrap().clone(); - edit = edit.replace_all_uses_where(old_control_pred, new_control_pred, |usee| *usee == new_fork)?; - + edit = edit.replace_all_uses_where(old_control_pred, new_control_pred, |usee| { + *usee == new_fork + })?; + // Replace uses of reduce edit = edit.replace_all_uses(reduce, mapping[&reduce])?; Ok(edit) @@ -351,7 +393,6 @@ pub fn fork_reduce_fission_helper<'a> ( new_control_pred = new_join; } - editor.edit(|mut edit| { // Replace original join w/ new final join edit = edit.replace_all_uses_where(join, new_join, |_| true)?; @@ -359,7 +400,7 @@ pub fn fork_reduce_fission_helper<'a> ( // Delete original join (all reduce users have been moved) edit = edit.delete_node(join)?; - // Replace all users of original fork, and then delete it, leftover users will be DCE'd. + // Replace all users of original fork, and then delete it, leftover users will be DCE'd. edit = edit.replace_all_uses(fork, new_fork)?; edit.delete_node(fork) }); @@ -372,14 +413,16 @@ pub fn fork_coalesce( loops: &LoopTree, fork_join_map: &HashMap<NodeID, NodeID>, ) -> bool { - - let fork_joins = loops - .bottom_up_loops() - .into_iter() - .filter_map(|(k, _)| if editor.func().nodes[k.idx()].is_fork() {Some(k)} else {None}); + let fork_joins = loops.bottom_up_loops().into_iter().filter_map(|(k, _)| { + if editor.func().nodes[k.idx()].is_fork() { + Some(k) + } else { + None + } + }); let fork_joins: Vec<_> = fork_joins.collect(); - // FIXME: postorder traversal. + // FIXME: postorder traversal. // Fixme: This could give us two forks that aren't actually ancestors / related, but then the helper will just retunr false early. //for (inner, outer) in fork_joins.windows(2) { @@ -391,7 +434,7 @@ pub fn fork_coalesce( return false; } -/** Opposite of fork split, takes two fork-joins +/** Opposite of fork split, takes two fork-joins with no control between them, and merges them into a single fork-join. */ pub fn fork_coalesce_helper( @@ -400,29 +443,43 @@ pub fn fork_coalesce_helper( inner_fork: NodeID, fork_join_map: &HashMap<NodeID, NodeID>, ) -> bool { - // Check that all reduces in the outer fork are in *simple* cycles with a unique reduce of the inner fork. 
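// [Editorial note; not part of the diff.] Schematically, coalescing rewrites two
// perfectly nested fork-joins (no control between the two forks, nor between the
// two joins) into one multi-dimensional fork-join, outer dimensions first:
//
//     fork[m] { fork[n] { body } }   ==>   fork[m, n] { body }
//
// Inner thread IDs are re-indexed by the outer dimension count, and each
// (outer, inner) reduce pair is fused by wiring the outer init into the inner
// reduce, as the helper below implements.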
let outer_join = fork_join_map[&outer_fork]; let inner_join = fork_join_map[&inner_fork]; - - let mut pairs: BiMap<NodeID, NodeID> = BiMap::new(); // Outer <-> Inner - // FIXME: Iterate all control uses of joins to really collect all reduces - // (reduces can be attached to inner control) - for outer_reduce in editor.get_users(outer_join).filter(|node| editor.func().nodes[node.idx()].is_reduce()) { + let mut pairs: BiMap<NodeID, NodeID> = BiMap::new(); // Outer <-> Inner + // FIXME: Iterate all control uses of joins to really collect all reduces + // (reduces can be attached to inner control) + for outer_reduce in editor + .get_users(outer_join) + .filter(|node| editor.func().nodes[node.idx()].is_reduce()) + { // check that inner reduce is of the inner join - let (outer_control, outer_init, outer_reduct) = editor.func().nodes[outer_reduce.idx()].try_reduce().unwrap(); + let (outer_control, outer_init, outer_reduct) = editor.func().nodes[outer_reduce.idx()] + .try_reduce() + .unwrap(); let inner_reduce = outer_reduct; let inner_reduce_node = &editor.func().nodes[outer_reduct.idx()]; - let Node::Reduce { control: inner_control, init: inner_init, reduct: inner_reduct } = inner_reduce_node else {return false}; + let Node::Reduce { + control: inner_control, + init: inner_init, + reduct: inner_reduct, + } = inner_reduce_node + else { + return false; + }; // FIXME: check this condition better (i.e reduce might not be attached to join) - if *inner_control != inner_join {return false}; - if *inner_init != outer_reduce {return false}; + if *inner_control != inner_join { + return false; + }; + if *inner_init != outer_reduce { + return false; + }; if pairs.contains_left(&outer_reduce) || pairs.contains_right(&inner_reduce) { return false; @@ -431,16 +488,27 @@ pub fn fork_coalesce_helper( } } - // Check Control between joins and forks - // FIXME: use control subgraph. - let Some(user) = editor.get_users(outer_fork) - .filter(|node| editor.func().nodes[node.idx()].is_control()).next() else { return false}; + // Check for control between join-join and fork-fork + let Some(user) = editor + .get_users(outer_fork) + .filter(|node| editor.func().nodes[node.idx()].is_control()) + .next() + else { + return false; + }; if user != inner_fork { return false; } - let Some(user) = editor.get_users(inner_join).filter(|node| editor.func().nodes[node.idx()].is_control()).next() else { return false}; + let Some(user) = editor + .get_users(inner_join) + .filter(|node| editor.func().nodes[node.idx()].is_control()) + .next() + else { + return false; + }; + if user != outer_join { return false; } @@ -449,24 +517,30 @@ pub fn fork_coalesce_helper( // Add outers dimension to front of inner fork. // Fuse reductions // - Initializer becomes outer initializer - // - + // - // Replace uses of outer fork w/ inner fork. // Replace uses of outer join w/ inner join. // Delete outer fork-join - let inner_tids: Vec<NodeID> = editor.get_users(inner_fork).filter(|node| editor.func().nodes[node.idx()].is_thread_id()).collect(); + let inner_tids: Vec<NodeID> = editor + .get_users(inner_fork) + .filter(|node| editor.func().nodes[node.idx()].is_thread_id()) + .collect(); let (outer_pred, outer_dims) = editor.func().nodes[outer_fork.idx()].try_fork().unwrap(); let (_, inner_dims) = editor.func().nodes[inner_fork.idx()].try_fork().unwrap(); let num_outer_dims = outer_dims.len(); let mut new_factors = outer_dims.to_vec(); - // FIXME: Might need to be added the other way. + // CHECK ME: Might need to be added the other way. 
new_factors.append(&mut inner_dims.to_vec()); - + for tid in inner_tids { let (fork, dim) = editor.func().nodes[tid.idx()].try_thread_id().unwrap(); - let new_tid = Node::ThreadID { control: fork, dimension: dim + num_outer_dims}; + let new_tid = Node::ThreadID { + control: fork, + dimension: dim + num_outer_dims, + }; editor.edit(|mut edit| { let new_tid = edit.add_node(new_tid); @@ -475,13 +549,18 @@ pub fn fork_coalesce_helper( }); } - // Fuse Reductions + // Fuse Reductions for (outer_reduce, inner_reduce) in pairs { - let (outer_control, outer_init, outer_reduct) = editor.func().nodes[outer_reduce.idx()].try_reduce().unwrap(); - let (inner_control, inner_init, inner_reduct) = editor.func().nodes[inner_reduce.idx()].try_reduce().unwrap(); + let (outer_control, outer_init, outer_reduct) = editor.func().nodes[outer_reduce.idx()] + .try_reduce() + .unwrap(); + let (inner_control, inner_init, inner_reduct) = editor.func().nodes[inner_reduce.idx()] + .try_reduce() + .unwrap(); editor.edit(|mut edit| { // Set inner init to outer init. - edit = edit.replace_all_uses_where(inner_init, outer_init, |usee| *usee == inner_reduce)?; + edit = + edit.replace_all_uses_where(inner_init, outer_init, |usee| *usee == inner_reduce)?; edit = edit.replace_all_uses(outer_reduce, inner_reduce)?; edit = edit.delete_node(outer_reduce)?; @@ -489,22 +568,22 @@ pub fn fork_coalesce_helper( }); } - editor.edit( - |mut edit| { - let new_fork = Node::Fork {control: outer_pred, factors: new_factors.into()}; - let new_fork = edit.add_node(new_fork); - - edit = edit.replace_all_uses(inner_fork, new_fork)?; - edit = edit.replace_all_uses(outer_fork, new_fork)?; - edit = edit.replace_all_uses(outer_join, inner_join)?; - edit = edit.delete_node(outer_join)?; - edit = edit.delete_node(inner_fork)?; - edit = edit.delete_node(outer_fork)?; - - Ok(edit) - } - ); + editor.edit(|mut edit| { + let new_fork = Node::Fork { + control: outer_pred, + factors: new_factors.into(), + }; + let new_fork = edit.add_node(new_fork); + + edit = edit.replace_all_uses(inner_fork, new_fork)?; + edit = edit.replace_all_uses(outer_fork, new_fork)?; + edit = edit.replace_all_uses(outer_join, inner_join)?; + edit = edit.delete_node(outer_join)?; + edit = edit.delete_node(inner_fork)?; + edit = edit.delete_node(outer_fork)?; + + Ok(edit) + }); true - -} \ No newline at end of file +} diff --git a/hercules_opt/src/forkify.rs b/hercules_opt/src/forkify.rs index 82358f91dd810417f6f6451b09ea4ed05c752051..c7acfe6b66b0a3bd5cbe50190673a4a7289368f2 100644 --- a/hercules_opt/src/forkify.rs +++ b/hercules_opt/src/forkify.rs @@ -1,5 +1,5 @@ -extern crate hercules_ir; extern crate bitvec; +extern crate hercules_ir; extern crate nestify; use std::collections::HashMap; @@ -46,30 +46,40 @@ pub fn forkify( loops: &LoopTree, ) -> bool { let natural_loops = loops - .bottom_up_loops() - .into_iter() - .filter(|(k, _)| editor.func().nodes[k.idx()].is_region()); + .bottom_up_loops() + .into_iter() + .filter(|(k, _)| editor.func().nodes[k.idx()].is_region()); let natural_loops: Vec<_> = natural_loops.collect(); - + for l in natural_loops { - // FIXME: Run on all-bottom level loops, as they can be independently optimized without recomputing analyses. - if forkify_loop(editor, control_subgraph, fork_join_map, &Loop { header: l.0, control: l.1.clone()}) { + // FIXME: Run on all-bottom level loops, as they can be independently optimized without recomputing analyses. 
+ if forkify_loop( + editor, + control_subgraph, + fork_join_map, + &Loop { + header: l.0, + control: l.1.clone(), + }, + ) { return true; } - } + } return false; } - /** Given a node used as a loop bound, return a dynamic constant ID. */ -pub fn get_node_as_dc(editor: &mut FunctionEditor, node: NodeID) -> Result<DynamicConstantID, String> { +pub fn get_node_as_dc( + editor: &mut FunctionEditor, + node: NodeID, +) -> Result<DynamicConstantID, String> { // Check for a constant used as loop bound. match editor.node(node) { - Node::DynamicConstant{id: dynamic_constant_id} => { - Ok(*dynamic_constant_id) - } - Node::Constant {id: constant_id} => { + Node::DynamicConstant { + id: dynamic_constant_id, + } => Ok(*dynamic_constant_id), + Node::Constant { id: constant_id } => { let dc = match *editor.get_constant(*constant_id) { Constant::Integer8(x) => DynamicConstant::Constant(x as _), Constant::Integer16(x) => DynamicConstant::Constant(x as _), @@ -83,23 +93,21 @@ pub fn get_node_as_dc(editor: &mut FunctionEditor, node: NodeID) -> Result<Dynam }; let mut b = DynamicConstantID::new(0); - editor.edit( - |mut edit| { - b = edit.add_dynamic_constant(dc); - Ok(edit) - } - ); - // Return the ID of the dynamic constant that is generated from the constant + editor.edit(|mut edit| { + b = edit.add_dynamic_constant(dc); + Ok(edit) + }); + // Return the ID of the dynamic constant that is generated from the constant // or dynamic constant that is the existing loop bound - Ok(b) + Ok(b) } - _ => Err("Blah".to_owned()) + _ => Err("Blah".to_owned()), } } -fn all_same_variant<I, T>(mut iter: I) -> bool +fn all_same_variant<I, T>(mut iter: I) -> bool where - I: Iterator<Item = T> + I: Iterator<Item = T>, { // Empty iterator case - return true let first = match iter.next() { @@ -109,60 +117,79 @@ where // Get discriminant of first item let first_discriminant = std::mem::discriminant(&first); - + // Check all remaining items have same discriminant iter.all(|x| std::mem::discriminant(&x) == first_discriminant) } /** - Top level function to convert natural loops with simple induction variables - into fork-joins. - */ + Top level function to convert natural loops with simple induction variables + into fork-joins. +*/ pub fn forkify_loop( editor: &mut FunctionEditor, control_subgraph: &Subgraph, fork_join_map: &HashMap<NodeID, NodeID>, l: &Loop, ) -> bool { - let function = editor.func(); - let Some(loop_condition) = get_loop_exit_conditions(function, l, control_subgraph) else {return false}; + let Some(loop_condition) = get_loop_exit_conditions(function, l, control_subgraph) else { + return false; + }; - let LoopExit::Conditional { if_node: loop_if, condition_node } = loop_condition.clone() else {return false}; + let LoopExit::Conditional { + if_node: loop_if, + condition_node, + } = loop_condition.clone() + else { + return false; + }; // Compute loop variance let loop_variance = compute_loop_variance(editor, l); let ivs = compute_induction_vars(editor.func(), l, &loop_variance); let ivs = compute_iv_ranges(editor, l, ivs, &loop_condition); - let Some(canonical_iv) = has_canonical_iv(editor, l, &ivs) else {return false}; + let Some(canonical_iv) = has_canonical_iv(editor, l, &ivs) else { + return false; + }; // FIXME: Make sure IV is not used outside the loop. 
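// [Editorial note; not part of the diff.] For intuition, the canonical IV found
// above is a phi in a simple update cycle; its `final_value` is converted by
// `get_node_as_dc` into the fork's dynamic-constant bound, and uses of the phi
// inside the loop later become a ThreadID. Hypothetical IR shape, with invented
// node names, roughly:
//
//     iv      = phi(header, init, iv_next)
//     iv_next = add(iv, step)          // constant update
//     cond    = lt(iv_next, bound)     // bound -> dynamic constant -> fork factor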
// Get bound let bound = match canonical_iv { - InductionVariable::Basic { node, initializer, update, final_value } => final_value.map(|final_value| get_node_as_dc(editor, final_value)).and_then(|r| r.ok()), + InductionVariable::Basic { + node, + initializer, + update, + final_value, + } => final_value + .map(|final_value| get_node_as_dc(editor, final_value)) + .and_then(|r| r.ok()), InductionVariable::SCEV(node_id) => return false, }; - - let Some(bound_dc_id) = bound else {return false}; - + let Some(bound_dc_id) = bound else { + return false; + }; let function = editor.func(); - // Check if it is do-while loop. - let loop_exit_projection = editor.get_users(loop_if) + // Check if it is do-while loop. + let loop_exit_projection = editor + .get_users(loop_if) .filter(|id| !l.control[id.idx()]) .next() .unwrap(); - let loop_continue_projection = editor.get_users(loop_if) + let loop_continue_projection = editor + .get_users(loop_if) .filter(|id| l.control[id.idx()]) .next() .unwrap(); - - let loop_preds: Vec<_> = editor.get_uses(l.header) + + let loop_preds: Vec<_> = editor + .get_uses(l.header) .filter(|id| !l.control[id.idx()]) .collect(); @@ -172,71 +199,83 @@ pub fn forkify_loop( let loop_pred = loop_preds[0]; - if !editor.get_uses(l.header).contains(&loop_continue_projection) { + if !editor + .get_uses(l.header) + .contains(&loop_continue_projection) + { return false; } - // Get all phis used outside of the loop, they need to be reductionable. - // For now just assume all phis will be phis used outside of the loop, except for the canonical iv. - // FIXME: We need a different definiton of `loop_nodes` to check for phis used outside hte loop than the one - // we currently have. + // Get all phis used outside of the loop, they need to be reductionable. + // For now just assume all phis will be phis used outside of the loop, except for the canonical iv. + // FIXME: We need a different definiton of `loop_nodes` to check for phis used outside hte loop than the one + // we currently have. let loop_nodes = calculate_loop_nodes(editor, l); // // Check reductionable phis, only PHIs depending on the loop are considered, - let candidate_phis: Vec<_> = editor.get_users(l.header) - .filter(|id|function.nodes[id.idx()].is_phi()) + let candidate_phis: Vec<_> = editor + .get_users(l.header) + .filter(|id| function.nodes[id.idx()].is_phi()) .filter(|id| *id != canonical_iv.phi()) .collect(); - let reductionable_phis: Vec<_> = analyze_phis(&editor, &l, &candidate_phis, &loop_nodes).into_iter().collect(); - + let reductionable_phis: Vec<_> = analyze_phis(&editor, &l, &candidate_phis, &loop_nodes) + .into_iter() + .collect(); + // START EDITING - + // What we do is: // 1) Find a (the?) basic induction variable, create a ThreadID + Fork + Join over it. - // 2) Turn reductionable PHIs into reduces (including the redcutionable PHI) - // - a) If the PHI is the IV: - // Uses of the IV become: + // 2) Turn reductionable PHIs into reduces (including the redcutionable PHI) + // - a) If the PHI is the IV: + // Uses of the IV become: // 1) Inside the loop: Uses of the ThreadID // 2) Outside the loop: Uses of the reduction node. - // - b) if the PHI is not the IV: + // - b) if the PHI is not the IV: // Make it a reduce - - let function = editor.func(); + let function = editor.func(); // TOOD: Handle multiple loop body lasts. // If there are multiple candidates for loop body last, return false. 
- if editor.get_uses(loop_if) + if editor + .get_uses(loop_if) .filter(|id| l.control[id.idx()]) - .count() > 1 { - return false; - } + .count() + > 1 + { + return false; + } - let loop_body_last = editor.get_uses(loop_if) - .next() - .unwrap(); - - if reductionable_phis.iter() - .any(|phi| !matches!(phi, LoopPHI::Reductionable{..})) { - return false - } + let loop_body_last = editor.get_uses(loop_if).next().unwrap(); + + if reductionable_phis + .iter() + .any(|phi| !matches!(phi, LoopPHI::Reductionable { .. })) + { + return false; + } // 1) If there is any control between header and loop condition, exit. - let header_control_users: Vec<_> = editor.get_users(l.header) + let header_control_users: Vec<_> = editor + .get_users(l.header) .filter(|id| function.nodes[id.idx()].is_control()) .collect(); - + // Outside uses of IV, then exit; - if editor.get_users(canonical_iv.phi()).any(|node| !loop_nodes.contains(&node)) { - return false + if editor + .get_users(canonical_iv.phi()) + .any(|node| !loop_nodes.contains(&node)) + { + return false; } // Start Transformation: // Graft everyhting between header and loop condition // Attach join to right before header (after loop_body_last, unless loop body last *is* the header). - // Attach fork to right after loop_continue_projection. + // Attach fork to right after loop_continue_projection. // // Create fork and join nodes: let mut join_id = NodeID::new(0); @@ -255,25 +294,26 @@ pub fn forkify_loop( }; // // FIXME (@xrouth), handle control in loop body. - editor.edit( - |mut edit| { - let fork = Node::Fork { control: loop_pred, factors: Box::new([bound_dc_id])}; - fork_id = edit.add_node(fork); - - let join = Node::Join { - control: if l.header == loop_body_last { - fork_id - } else { - loop_body_last - }, - }; - - join_id = edit.add_node(join); + editor.edit(|mut edit| { + let fork = Node::Fork { + control: loop_pred, + factors: Box::new([bound_dc_id]), + }; + fork_id = edit.add_node(fork); + + let join = Node::Join { + control: if l.header == loop_body_last { + fork_id + } else { + loop_body_last + }, + }; + + join_id = edit.add_node(join); + + Ok(edit) + }); - Ok(edit) - } - ); - // let function = editor.func(); // let update = *zip( @@ -288,115 +328,101 @@ pub fn forkify_loop( // .next() // .unwrap() // .1; - + let function = editor.func(); let (_, factors) = function.nodes[fork_id.idx()].try_fork().unwrap(); let dimension = factors.len() - 1; // Create ThreadID - editor.edit( - |mut edit| { - let thread_id = Node::ThreadID { - control: fork_id, - dimension: dimension, - }; - let thread_id_id = edit.add_node(thread_id); - - // let iv_reduce = Node::Reduce { - // control: join_id, - // init: basic_iv.initializer, - // reduct: update, - // }; - - // If a user occurs after the loop is finished, we replace it with the DC that is the IV bound, - // If a user occurs inside the loop, we replace it with the IV. - - // Replace uses that are inside with the thread id - edit = edit.replace_all_uses_where(canonical_iv.phi(), thread_id_id, |node| { - loop_nodes.contains(node) - })?; - - // Replace uses that are outside with DC - 1. Or just give up. 
- let bound_dc_node = edit.add_node(Node::DynamicConstant { id: bound_dc_id }); - edit = edit.replace_all_uses_where(canonical_iv.phi(), bound_dc_node, |node| { - !loop_nodes.contains(node) - })?; - - edit.delete_node(canonical_iv.phi()) - } - ); + editor.edit(|mut edit| { + let thread_id = Node::ThreadID { + control: fork_id, + dimension: dimension, + }; + let thread_id_id = edit.add_node(thread_id); + + // let iv_reduce = Node::Reduce { + // control: join_id, + // init: basic_iv.initializer, + // reduct: update, + // }; + + // If a user occurs after the loop is finished, we replace it with the DC that is the IV bound, + // If a user occurs inside the loop, we replace it with the IV. + + // Replace uses that are inside with the thread id + edit = edit.replace_all_uses_where(canonical_iv.phi(), thread_id_id, |node| { + loop_nodes.contains(node) + })?; + + // Replace uses that are outside with DC - 1. Or just give up. + let bound_dc_node = edit.add_node(Node::DynamicConstant { id: bound_dc_id }); + edit = edit.replace_all_uses_where(canonical_iv.phi(), bound_dc_node, |node| { + !loop_nodes.contains(node) + })?; + + edit.delete_node(canonical_iv.phi()) + }); for reduction_phi in reductionable_phis { - let LoopPHI::Reductionable { phi, data_cycle, continue_latch, is_associative } = reduction_phi else {continue}; + let LoopPHI::Reductionable { + phi, + data_cycle, + continue_latch, + is_associative, + } = reduction_phi + else { + continue; + }; let function = editor.func(); let init = *zip( editor.get_uses(l.header), - function.nodes[phi.idx()] - .try_phi() - .unwrap() - .1 - .iter(), - ) - .filter(|(c, _)| *c == loop_pred) - .next() - .unwrap() - .1; - - editor.edit( - |mut edit| { - let reduce = Node::Reduce { - control: join_id, - init, - reduct: continue_latch, - }; - let reduce_id = edit.add_node(reduce); - - edit = edit.replace_all_uses_where(phi, reduce_id, |usee| *usee != reduce_id)?; - edit = edit.replace_all_uses_where(continue_latch, reduce_id, |usee| !loop_nodes.contains(usee ) && *usee != reduce_id)?; - edit.delete_node(phi) - } - ); + function.nodes[phi.idx()].try_phi().unwrap().1.iter(), + ) + .filter(|(c, _)| *c == loop_pred) + .next() + .unwrap() + .1; + + editor.edit(|mut edit| { + let reduce = Node::Reduce { + control: join_id, + init, + reduct: continue_latch, + }; + let reduce_id = edit.add_node(reduce); + + edit = edit.replace_all_uses_where(phi, reduce_id, |usee| *usee != reduce_id)?; + edit = edit.replace_all_uses_where(continue_latch, reduce_id, |usee| { + !loop_nodes.contains(usee) && *usee != reduce_id + })?; + edit.delete_node(phi) + }); } - // Replace all uses of the loop header with the fork - editor.edit( - |mut edit| { - edit.replace_all_uses(l.header, fork_id) - } - ); + editor.edit(|mut edit| edit.replace_all_uses(l.header, fork_id)); - editor.edit( - |mut edit| { - edit.replace_all_uses(loop_continue_projection, fork_id) - } - ); + editor.edit(|mut edit| edit.replace_all_uses(loop_continue_projection, fork_id)); - editor.edit( - |mut edit| { - edit.replace_all_uses(loop_exit_projection, join_id) - } - ); + editor.edit(|mut edit| edit.replace_all_uses(loop_exit_projection, join_id)); // Get rid of loop condition // DCE should get these, but delete them ourselves because we are nice :) - editor.edit( - |mut edit| { - edit = edit.delete_node(loop_continue_projection)?; - edit = edit.delete_node(condition_node)?; // Might have to get rid of other users of this. 
- edit = edit.delete_node(loop_exit_projection)?; - edit = edit.delete_node(loop_if)?; - edit = edit.delete_node(l.header)?; - Ok(edit) - } - ); + editor.edit(|mut edit| { + edit = edit.delete_node(loop_continue_projection)?; + edit = edit.delete_node(condition_node)?; // Might have to get rid of other users of this. + edit = edit.delete_node(loop_exit_projection)?; + edit = edit.delete_node(loop_if)?; + edit = edit.delete_node(l.header)?; + Ok(edit) + }); return true; } - nest! { #[derive(Debug)] pub enum LoopPHI { @@ -414,56 +440,68 @@ nest! { impl LoopPHI { pub fn get_phi(&self) -> NodeID { match self { - LoopPHI::Reductionable {phi, data_cycle, ..} => *phi, + LoopPHI::Reductionable { + phi, data_cycle, .. + } => *phi, LoopPHI::LoopDependant(node_id) => *node_id, LoopPHI::UsedByDependant(node_id) => *node_id, } } } - -/** - Checks some conditions on loop variables that will need to be converted into reductions to be forkified. - To convert a phi into a reduce we need to check that every cycle containing the PHI does not contain any other PHI. - I think this restriction can be loosened (more specified) - - Every cycle *in the loop* containing the PHI does not contain any other PHI. Or something, IDK. - - - We also need to make it not control dependent on anything other than the loop header. */ -pub fn analyze_phis<'a>(editor: &'a FunctionEditor, natural_loop: &'a Loop, phis: &'a [NodeID], loop_nodes: &'a HashSet<NodeID>) - -> impl Iterator<Item = LoopPHI> + 'a -{ +/** +Checks some conditions on loop variables that will need to be converted into reductions to be forkified. + To convert a phi into a reduce we need to check that every cycle containing the PHI does not contain any other PHI. +I think this restriction can be loosened (more specified) + - Every cycle *in the loop* containing the PHI does not contain any other PHI. Or something, IDK. + - +We also need to make it not control dependent on anything other than the loop header. */ +pub fn analyze_phis<'a>( + editor: &'a FunctionEditor, + natural_loop: &'a Loop, + phis: &'a [NodeID], + loop_nodes: &'a HashSet<NodeID>, +) -> impl Iterator<Item = LoopPHI> + 'a { phis.into_iter().map(move |phi| { - let stop_on: HashSet<NodeID> = editor.node_ids().filter(|node| { - let data = &editor.func().nodes[node.idx()]; - - // External Phi - if let Node::Phi { control, data } = data { - if *control != natural_loop.header { - return true; + let stop_on: HashSet<NodeID> = editor + .node_ids() + .filter(|node| { + let data = &editor.func().nodes[node.idx()]; + + // External Phi + if let Node::Phi { control, data } = data { + if *control != natural_loop.header { + return true; + } + // if !natural_loop.control[control.idx()] { + // return true; + // } } - // if !natural_loop.control[control.idx()] { - // return true; - // } - } - // External Reduce - if let Node::Reduce { control, init, reduct} = data { - if !natural_loop.control[control.idx()] { - return true; - } else { - return false; + // External Reduce + if let Node::Reduce { + control, + init, + reduct, + } = data + { + if !natural_loop.control[control.idx()] { + return true; + } else { + return false; + } } - } - // External Control - if data.is_control() {//&& !natural_loop.control[node.idx()] { - return true - } + // External Control + if data.is_control() { + //&& !natural_loop.control[node.idx()] { + return true; + } - return false; + return false; + }) + .collect(); - }).collect(); - - // TODO: We may need to stop on exiting the loop for looking for data cycles. 
+ // TODO: We may need to stop on exiting the loop for looking for data cycles. let uses = walk_all_uses_stop_on(*phi, editor, stop_on.clone()); // .filter(|node| // { @@ -472,74 +510,88 @@ pub fn analyze_phis<'a>(editor: &'a FunctionEditor, natural_loop: &'a Loop, phis // }); let users = walk_all_users_stop_on(*phi, editor, stop_on.clone()); - let other_stop_on: HashSet<NodeID> = editor.node_ids().filter(|node| { - let data = &editor.func().nodes[node.idx()]; - - // Phi, Reduce - if let Node::Phi { control, data } = data { - return true; - } - - if let Node::Reduce { control, init, reduct} = data { - return true; - } + let other_stop_on: HashSet<NodeID> = editor + .node_ids() + .filter(|node| { + let data = &editor.func().nodes[node.idx()]; - // External Control - if data.is_control() {//&& !natural_loop.control[node.idx()] { - return true - } + // Phi, Reduce + if let Node::Phi { control, data } = data { + return true; + } - return false; + if let Node::Reduce { + control, + init, + reduct, + } = data + { + return true; + } - }).collect(); + // External Control + if data.is_control() { + //&& !natural_loop.control[node.idx()] { + return true; + } + return false; + }) + .collect(); let mut uses_for_dependance = walk_all_users_stop_on(*phi, editor, other_stop_on); - + let set1: HashSet<_> = HashSet::from_iter(uses); let set2: HashSet<_> = HashSet::from_iter(users); let intersection: HashSet<_> = set1.intersection(&set2).cloned().collect(); // If this phi uses any other phis the node is loop dependant, - // we use `phis` because this phi can actually contain the loop iv and its fine. + // we use `phis` because this phi can actually contain the loop iv and its fine. if uses_for_dependance.any(|node| phis.contains(&node) && node != *phi) { LoopPHI::LoopDependant(*phi) - } - // // If this phi is used by other phis in the loop, FIXME: include reduces, they are the same as phis right? - // // DOn't go through nodes that would become a reduction. + } + // // If this phi is used by other phis in the loop, FIXME: include reduces, they are the same as phis right? + // // DOn't go through nodes that would become a reduction. // else if set2.clone().iter().any(|node| phis.contains(node) && node != phi ) { // LoopPHI::UsedByDependant(*phi) // } else if intersection.clone().iter().any(|node| true) { - let continue_idx = editor.get_uses(natural_loop.header) + let continue_idx = editor + .get_uses(natural_loop.header) .position(|node| natural_loop.control[node.idx()]) .unwrap(); let loop_continue_latch = editor.node(phi).try_phi().unwrap().1[continue_idx]; - // Phis on the frontier of the intersection, i.e in uses_for_dependance need - // to have headers + // Phis on the frontier of the intersection, i.e in uses_for_dependance need + // to have headers // FIXME: Need to postdominate the loop continue latch - // The phi's region needs to postdominate all PHI / Reduceses (that are in the control of the loop, i.e that or uses of the loop_continue_latch) - // that it uses, not going through phis / reduces, - // + // The phi's region needs to postdominate all PHI / Reduceses (that are in the control of the loop, i.e that or uses of the loop_continue_latch) + // that it uses, not going through phis / reduces, + // - // let uses = + // let uses = // No nodes in data cycles with this phi (in the loop) are used outside the loop, besides the loop_continue_latch. 
-        if intersection.iter()
+        if intersection
+            .iter()
             .filter(|node| **node != loop_continue_latch)
-            .any(|data_node| editor.get_users(*data_node).any(|user| !loop_nodes.contains(&user))) {
-                // This phi can be made into a reduce in different ways, if the cycle is associative (contains all the same kind of associative op)
-                // 3) Split the cycle into two phis, add them or multiply them together at the end.
-                // 4) Split the cycle into two reduces, add them or multiply them together at the end.
-                // Somewhere else should handle this.
-                return LoopPHI::LoopDependant(*phi)
-            }
-
+            .any(|data_node| {
+                editor
+                    .get_users(*data_node)
+                    .any(|user| !loop_nodes.contains(&user))
+            })
+        {
+            // This phi can be made into a reduce in different ways, if the cycle is associative (contains all the same kind of associative op)
+            // 3) Split the cycle into two phis, add them or multiply them together at the end.
+            // 4) Split the cycle into two reduces, add them or multiply them together at the end.
+            // Somewhere else should handle this.
+            return LoopPHI::LoopDependant(*phi);
+        }
+
         // if there are separate types of ops, or any non-associative ops, then it's not associative
-
+
         // Extract ops
         // let is_associative = intersection.iter().filter_map(|node| match editor.node(node) {
         //     Node::Unary { input, op } => todo!(),
@@ -555,11 +607,9 @@ pub fn analyze_phis<'a>(editor: &'a FunctionEditor, natural_loop: &'a Loop, phis
             continue_latch: loop_continue_latch,
             is_associative,
         }
-
-
-        } else { // No cycles exist, this isn't a reduction.
+        } else {
+            // No cycles exist, this isn't a reduction.
             LoopPHI::LoopDependant(*phi)
         }
     })
-
-}
\ No newline at end of file
+}
diff --git a/hercules_opt/src/ivar.rs b/hercules_opt/src/ivar.rs
index 893cf7638c3541ea25b80069e08835091ec337e7..7f76b0f540ddaae756e988f001d3de5bca01c197 100644
--- a/hercules_opt/src/ivar.rs
+++ b/hercules_opt/src/ivar.rs
@@ -1,7 +1,7 @@
-extern crate hercules_ir;
-extern crate slotmap;
 extern crate bitvec;
+extern crate hercules_ir;
 extern crate nestify;
+extern crate slotmap;

 use std::collections::{BTreeMap, HashMap, HashSet, VecDeque};
 use std::path::Iter;
@@ -11,9 +11,9 @@
 use self::nestify::nest;

 use self::hercules_ir::Subgraph;

 use self::bitvec::order::Lsb0;
+use self::bitvec::prelude::*;
 use self::bitvec::vec::BitVec;
 use self::hercules_ir::get_uses;
-use self::bitvec::prelude::*;

 use self::hercules_ir::LoopTree;

@@ -30,13 +30,12 @@ use crate::*;
 */

 /* ASIDE: (@xrouth) I want a word for something that can be 'queried', but doesn't reveal anything about the underlying data structure,
-   single loop only... */
-
+single loop only... */
 #[derive(Debug)]
 pub struct LoopVarianceInfo {
-    pub loop_header: NodeID,
-    pub map: DenseNodeMap<LoopVariance>
+    pub loop_header: NodeID,
+    pub map: DenseNodeMap<LoopVariance>,
 }

 #[derive(Clone, Copy, Debug, PartialEq)]
@@ -48,11 +47,10 @@ pub enum LoopVariance {

 type NodeVec = BitVec<u8, Lsb0>;

-
 #[derive(Clone, Debug)]
 pub struct Loop {
     pub header: NodeID,
-    pub control: NodeVec, //
+    pub control: NodeVec, //
 }

 impl Loop {
@@ -62,8 +60,8 @@ impl Loop {
         all_loop_nodes
     }
 }
-nest!{
-/** Represents a basic induction variable.
+nest! {
+/** Represents a basic induction variable.
 NOTE (@xrouth): May switch to using SCEV to represent induction variables, for now we assume
 only basic induction variables with a constant update (not even linear).
 Eventually add dynamic constant updates, and linear updates
 */
@@ -76,7 +74,7 @@
 pub struct BasicInductionVariable {
 }
 } // nest

-nest!{
+nest!
{ #[derive(Clone, Copy, Debug, PartialEq)]* pub enum InductionVariable { pub Basic { @@ -86,7 +84,7 @@ nest!{ final_value: Option<NodeID>, }, SCEV(NodeID), - //ScevAdd(NodeID, NodeID), + //ScevAdd(NodeID, NodeID), // ScevMul(NodeID, NodeID), } } @@ -94,15 +92,20 @@ nest!{ impl InductionVariable { pub fn phi(&self) -> NodeID { match self { - InductionVariable::Basic { node, initializer, update, final_value } => *node, + InductionVariable::Basic { + node, + initializer, + update, + final_value, + } => *node, InductionVariable::SCEV(_) => todo!(), } } // Editor has become just a 'context' that everything needs. This is similar to how analyses / passes are structured, // but editor forces recomputation / bookkeeping of simple / more commonly used info (even though it really is just def use, constants, dyn_constants) - // While if the pass writer wants more complicated info, like analyses results, they have to thread it through the pass manager. - // This seems fine. + // While if the pass writer wants more complicated info, like analyses results, they have to thread it through the pass manager. + // This seems fine. // pub fn update_i64(&self, editor: &FunctionEditor) -> Option<i64> { // match self { // InductionVariable::Basic { node, initializer, update, final_value } => { @@ -118,19 +121,16 @@ impl InductionVariable { // } // } - // It would be nice for functions, as they (kinda) automatically capture 'self' to also automatically capture a 'context' that is in the same scope, - // so I don't have to keep passing a context into every function that needs one. - // + // It would be nice for functions, as they (kinda) automatically capture 'self' to also automatically capture a 'context' that is in the same scope, + // so I don't have to keep passing a context into every function that needs one. + // } -// TODO: Optimize. -pub fn calculate_loop_nodes( - editor: &FunctionEditor, - natural_loop: &Loop, -) -> HashSet<NodeID> { - - // Stop on PHIs / reduces outside of loop. - let stop_on: HashSet<NodeID> = editor.node_ids().filter( - |node|{ +// TODO: Optimize. +pub fn calculate_loop_nodes(editor: &FunctionEditor, natural_loop: &Loop) -> HashSet<NodeID> { + // Stop on PHIs / reduces outside of loop. 
+ let stop_on: HashSet<NodeID> = editor + .node_ids() + .filter(|node| { let data = &editor.func().nodes[node.idx()]; // External Phi @@ -140,7 +140,12 @@ pub fn calculate_loop_nodes( } } // External Reduce - if let Node::Reduce { control, init, reduct} = data { + if let Node::Reduce { + control, + init, + reduct, + } = data + { if !natural_loop.control[control.idx()] { return true; } @@ -148,32 +153,41 @@ pub fn calculate_loop_nodes( // External Control if data.is_control() && !natural_loop.control[node.idx()] { - return true + return true; } return false; - } - ).collect(); - - let phis: Vec<_> = editor.node_ids().filter(|node| { - let Node::Phi { control, ref data } = editor.func().nodes[node.idx()] else {return false}; - natural_loop.control[control.idx()] - }).collect(); + }) + .collect(); + + let phis: Vec<_> = editor + .node_ids() + .filter(|node| { + let Node::Phi { control, ref data } = editor.func().nodes[node.idx()] else { + return false; + }; + natural_loop.control[control.idx()] + }) + .collect(); // let all_users: HashSet<_> = editor.get_users(natural_loop.header).filter(|node| editor.func().nodes[node.idx()].is_phi()) // .flat_map(|phi| walk_all_users_stop_on(phi, editor, stop_on.clone())) // .chain(editor.get_users(natural_loop.header).filter(|node| editor.func().nodes[node.idx()].is_phi())) // .collect(); - let all_users: HashSet<NodeID> = phis.clone().iter().flat_map(|phi| walk_all_users_stop_on(*phi, editor, stop_on.clone())) - .chain(phis.clone()) - .collect(); + let all_users: HashSet<NodeID> = phis + .clone() + .iter() + .flat_map(|phi| walk_all_users_stop_on(*phi, editor, stop_on.clone())) + .chain(phis.clone()) + .collect(); - let all_uses: HashSet<_> = phis.clone().iter() + let all_uses: HashSet<_> = phis + .clone() + .iter() .flat_map(|phi| walk_all_uses_stop_on(*phi, editor, stop_on.clone())) .chain(phis) - .filter(|node| - { + .filter(|node| { // Get rid of nodes in stop_on !stop_on.contains(node) }) @@ -192,9 +206,15 @@ pub fn calculate_loop_nodes( } /** returns PHIs that are *in* a loop */ -pub fn get_all_loop_phis<'a>(function: &'a Function, l: &'a Loop) -> impl Iterator<Item = NodeID> + 'a { - function.nodes.iter().enumerate().filter_map( - move |(node_id, node)| { +pub fn get_all_loop_phis<'a>( + function: &'a Function, + l: &'a Loop, +) -> impl Iterator<Item = NodeID> + 'a { + function + .nodes + .iter() + .enumerate() + .filter_map(move |(node_id, node)| { if let Some((control, _)) = node.try_phi() { if l.control[control.idx()] { Some(NodeID::new(node_id)) @@ -204,18 +224,17 @@ pub fn get_all_loop_phis<'a>(function: &'a Function, l: &'a Loop) -> impl Iterat } else { None } - } - ) + }) } // FIXME: Need a trait that Editor and Function both implement, that gives us UseDefInfo /** Given a loop determine for each data node if the value might change upon each iteration of the loop */ pub fn compute_loop_variance(editor: &FunctionEditor, l: &Loop) -> LoopVarianceInfo { - // Gather all Phi nodes that are controlled by this loop. + // Gather all Phi nodes that are controlled by this loop. 
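+    // (Phis controlled by the loop header are the seeds of variance: their
+    // values may change every iteration, and variance propagates to anything
+    // that transitively uses them.)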
let mut loop_vars: Vec<NodeID> = vec![]; - for node_id in editor.get_users(l.header) { + for node_id in editor.get_users(l.header) { let node = &editor.func().nodes[node_id.idx()]; if let Some((control, _)) = node.try_phi() { if l.control[control.idx()] { @@ -229,38 +248,42 @@ pub fn compute_loop_variance(editor: &FunctionEditor, l: &Loop) -> LoopVarianceI let mut all_loop_nodes = l.control.clone(); all_loop_nodes.set(l.header.idx(), true); - - let mut variance_map: DenseNodeMap<LoopVariance> = vec![LoopVariance::Unknown; len]; - fn recurse(function: &Function, node: NodeID, all_loop_nodes: &BitVec<u8, Lsb0>, - variance_map: &mut DenseNodeMap<LoopVariance>, visited: &mut DenseNodeMap<bool>) - -> LoopVariance { + let mut variance_map: DenseNodeMap<LoopVariance> = vec![LoopVariance::Unknown; len]; + fn recurse( + function: &Function, + node: NodeID, + all_loop_nodes: &BitVec<u8, Lsb0>, + variance_map: &mut DenseNodeMap<LoopVariance>, + visited: &mut DenseNodeMap<bool>, + ) -> LoopVariance { if visited[node.idx()] { return variance_map[node.idx()]; } visited[node.idx()] = true; - - let node_variance = match variance_map[node.idx()] { + + let node_variance = match variance_map[node.idx()] { LoopVariance::Invariant => LoopVariance::Invariant, LoopVariance::Variant => LoopVariance::Variant, LoopVariance::Unknown => { - let mut node_variance = LoopVariance::Invariant; // Two conditions cause something to be loop variant: for node_use in get_uses(&function.nodes[node.idx()]).as_ref() { // 1) The use is a PHI *controlled* by the loop if let Some((control, data)) = function.nodes[node_use.idx()].try_phi() { - if *all_loop_nodes.get(control.idx()).unwrap() { + if *all_loop_nodes.get(control.idx()).unwrap() { node_variance = LoopVariance::Variant; break; - } + } } - + // 2) Any of the nodes uses are loop variant - if recurse(function, *node_use, all_loop_nodes, variance_map, visited) == LoopVariance::Variant { + if recurse(function, *node_use, all_loop_nodes, variance_map, visited) + == LoopVariance::Variant + { node_variance = LoopVariance::Variant; break; } @@ -271,17 +294,26 @@ pub fn compute_loop_variance(editor: &FunctionEditor, l: &Loop) -> LoopVarianceI node_variance } }; - + return node_variance; } let mut visited: DenseNodeMap<bool> = vec![false; len]; for node in (0..len).map(NodeID::new) { - recurse(editor.func(), node, &all_loop_nodes, &mut variance_map, &mut visited); - }; + recurse( + editor.func(), + node, + &all_loop_nodes, + &mut variance_map, + &mut visited, + ); + } - return LoopVarianceInfo { loop_header: l.header, map: variance_map }; + return LoopVarianceInfo { + loop_header: l.header, + map: variance_map, + }; } nest! { @@ -291,22 +323,27 @@ pub enum LoopExit { if_node: NodeID, condition_node: NodeID, }, - Unconditional(NodeID) // Probably a region. + Unconditional(NodeID) // Probably a region. } } -pub fn get_loop_exit_conditions(function: &Function, l: &Loop, control_subgraph: &Subgraph) -> Option<LoopExit> { // impl IntoIterator<Item = LoopExit> - // DFS Traversal on loop control subgraph until we find a node that is outside the loop, find the last IF on this path. - let mut last_if_on_path: DenseNodeMap<Option<NodeID>> = vec![None; function.nodes.len()]; +pub fn get_loop_exit_conditions( + function: &Function, + l: &Loop, + control_subgraph: &Subgraph, +) -> Option<LoopExit> { + // impl IntoIterator<Item = LoopExit> + // DFS Traversal on loop control subgraph until we find a node that is outside the loop, find the last IF on this path. 
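+    // E.g. in a simple while-style loop, the If deciding between the loop body
+    // and the exit is the last If seen before control leaves the loop, so it
+    // is what gets reported as the exit condition below.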
+ let mut last_if_on_path: DenseNodeMap<Option<NodeID>> = vec![None; function.nodes.len()]; // FIXME: (@xrouth) THIS IS MOST CERTAINLY BUGGED // this might be bugged... i.e might need to udpate `last if` even if already defined. - // needs to be `saturating` kinda, more iterative. May need to visit nodes more than once? + // needs to be `saturating` kinda, more iterative. May need to visit nodes more than once? - // FIXME: (@xrouth) Right now we assume only one exit from the loop, later: check for multiple exits on the loop, + // FIXME: (@xrouth) Right now we assume only one exit from the loop, later: check for multiple exits on the loop, // either as an assertion here or some other part of forkify or analysis. let mut bag_of_control_nodes = vec![l.header]; let mut visited: DenseNodeMap<bool> = vec![false; function.nodes.len()]; - + let mut final_if: Option<NodeID> = None; // do WFS @@ -317,39 +354,50 @@ pub fn get_loop_exit_conditions(function: &Function, l: &Loop, control_subgraph: } visited[node.idx()] = true; - final_if = - if function.nodes[node.idx()].is_if() { - Some(node) - } else { - last_if_on_path[node.idx()] - }; - + final_if = if function.nodes[node.idx()].is_if() { + Some(node) + } else { + last_if_on_path[node.idx()] + }; + if !l.control[node.idx()] { break; } - + for succ in control_subgraph.succs(node) { last_if_on_path[succ.idx()] = final_if; bag_of_control_nodes.push(succ.clone()); } } - final_if.map(|v| {LoopExit::Conditional { - if_node: v, - condition_node: if let Node::If{ control: _, cond } = function.nodes[v.idx()] {cond} else {unreachable!()} - }}) + final_if.map(|v| LoopExit::Conditional { + if_node: v, + condition_node: if let Node::If { control: _, cond } = function.nodes[v.idx()] { + cond + } else { + unreachable!() + }, + }) } - -pub fn match_canonicalization_bound(editor: &mut FunctionEditor, natural_loop: &Loop, loop_condition: NodeID, loop_if: NodeID, ivar: BasicInductionVariable) -> Option<NodeID> { +pub fn match_canonicalization_bound( + editor: &mut FunctionEditor, + natural_loop: &Loop, + loop_condition: NodeID, + loop_if: NodeID, + ivar: BasicInductionVariable, +) -> Option<NodeID> { // Match for code generated by loop canon - let Node::Phi { control, data } = &editor.func().nodes[loop_condition.idx()] else {unreachable!()}; + let Node::Phi { control, data } = &editor.func().nodes[loop_condition.idx()] else { + unreachable!() + }; if *control != natural_loop.header { - return None + return None; } - let continue_idx = editor.get_uses(natural_loop.header) + let continue_idx = editor + .get_uses(natural_loop.header) .position(|node| natural_loop.control[node.idx()]) .unwrap(); @@ -360,121 +408,176 @@ pub fn match_canonicalization_bound(editor: &mut FunctionEditor, natural_loop: & todo!() } - let Node::Constant { id } = &editor.func().nodes[data[init_idx].idx()] else {return None}; + let Node::Constant { id } = &editor.func().nodes[data[init_idx].idx()] else { + return None; + }; - // Check that the ID is true. - let Constant::Boolean(val) = *editor.get_constant(*id) else {return None}; - if val != true {return None}; + // Check that the ID is true. + let Constant::Boolean(val) = *editor.get_constant(*id) else { + return None; + }; + if val != true { + return None; + }; // Check other phi input. // FIXME: Factor this out into diff loop analysis. 
- let Node::Binary { left, right, op } = &editor.func().nodes[data[continue_idx].idx()].clone() else {return None}; + let Node::Binary { left, right, op } = &editor.func().nodes[data[continue_idx].idx()].clone() + else { + return None; + }; + + let BinaryOperator::LT = op else { return None }; - let BinaryOperator::LT = op else {return None}; - let bound = &editor.func().nodes[right.idx()]; - if !(bound.is_constant() || bound.is_dynamic_constant()) {return None}; + if !(bound.is_constant() || bound.is_dynamic_constant()) { + return None; + }; let bound = match bound { Node::Constant { id } => { let constant = editor.get_constant(*id).clone(); - let Constant::UnsignedInteger64(v) = constant else {return None}; + let Constant::UnsignedInteger64(v) = constant else { + return None; + }; let mut b = DynamicConstantID::new(0); - editor.edit( - |mut edit| { - b = edit.add_dynamic_constant(DynamicConstant::Constant(v.try_into().unwrap())); - Ok(edit) - } - ); - // Return the ID of the dynamic constant that is generated from the constant + editor.edit(|mut edit| { + b = edit.add_dynamic_constant(DynamicConstant::Constant(v.try_into().unwrap())); + Ok(edit) + }); + // Return the ID of the dynamic constant that is generated from the constant // or dynamic constant that is the existing loop bound b } Node::DynamicConstant { id } => *id, - _ => unreachable!() + _ => unreachable!(), + }; + + let Node::Binary { + left: add_left, + right: add_right, + op: add_op, + } = &editor.func().nodes[left.idx()] + else { + return None; }; - let Node::Binary { left: add_left, right: add_right, op: add_op } = &editor.func().nodes[left.idx()] else {return None}; - - let (phi, inc) = if let Node::Phi { control, data } = &editor.func().nodes[add_left.idx()] { + let (phi, inc) = if let Node::Phi { control, data } = &editor.func().nodes[add_left.idx()] { (add_left, add_right) - } else if let Node::Phi { control, data } = &editor.func().nodes[add_right.idx()] { + } else if let Node::Phi { control, data } = &editor.func().nodes[add_right.idx()] { (add_right, add_left) } else { return None; }; // Check Constant - let Node::Constant { id } = &editor.func().nodes[inc.idx()] else {return None}; + let Node::Constant { id } = &editor.func().nodes[inc.idx()] else { + return None; + }; if !editor.get_constant(*id).is_one() { return None; } // Check PHI - let Node::Phi { control: outer_control, data: outer_data } = &editor.func().nodes[phi.idx()] else {unreachable!()}; + let Node::Phi { + control: outer_control, + data: outer_data, + } = &editor.func().nodes[phi.idx()] + else { + unreachable!() + }; // FIXME: Multiple loop predecessors. 
- if outer_data[continue_idx] != *left {return None}; + if outer_data[continue_idx] != *left { + return None; + }; - let Node::Constant { id } = &editor.func().nodes[outer_data[init_idx].idx()] else {return None}; + let Node::Constant { id } = &editor.func().nodes[outer_data[init_idx].idx()] else { + return None; + }; if !editor.get_constant(*id).is_zero() { return None; } - // All checks passed, make new DC + // All checks passed, make new DC let mut final_node = NodeID::new(0); - editor.edit( - |mut edit| { - let one = edit.add_dynamic_constant(DynamicConstant::Constant(1)); - let max_dc = edit.add_dynamic_constant(DynamicConstant::Max(one, bound)); - final_node = edit.add_node(Node::DynamicConstant { id: max_dc }); - Ok(edit) - } - ); + editor.edit(|mut edit| { + let one = edit.add_dynamic_constant(DynamicConstant::Constant(1)); + let max_dc = edit.add_dynamic_constant(DynamicConstant::Max(one, bound)); + final_node = edit.add_node(Node::DynamicConstant { id: max_dc }); + Ok(edit) + }); Some(final_node) } pub fn has_const_fields(editor: &FunctionEditor, ivar: InductionVariable) -> bool { match ivar { - InductionVariable::Basic { node, initializer, update, final_value } => { + InductionVariable::Basic { + node, + initializer, + update, + final_value, + } => { if final_value.is_none() { return false; } - [initializer, update].iter().any( - |node| !editor.node(node).is_constant() - ) - }, + [initializer, update] + .iter() + .any(|node| !editor.node(node).is_constant()) + } InductionVariable::SCEV(node_id) => false, } -} +} /* Loop has any IV from range 0....N, N can be dynconst iterates +1 per iteration */ -// IVs need to be bounded... -pub fn has_canonical_iv<'a>(editor: &FunctionEditor, l: &Loop, ivs: &'a[InductionVariable]) -> Option<&'a InductionVariable> { - ivs.iter().find(|iv| { match iv { - InductionVariable::Basic { node, initializer, update, final_value } => { - (editor.node(initializer).is_zero_constant(&editor.get_constants()) || editor.node(initializer).is_zero_dc(&editor.get_dynamic_constants())) - && (editor.node(update).is_one_constant(&editor.get_constants()) || editor.node(update).is_one_dc(&editor.get_dynamic_constants())) - && (final_value.map(|val| editor.node(val).is_constant() || editor.node(val).is_dynamic_constant()).is_some()) +// IVs need to be bounded... +pub fn has_canonical_iv<'a>( + editor: &FunctionEditor, + l: &Loop, + ivs: &'a [InductionVariable], +) -> Option<&'a InductionVariable> { + ivs.iter().find(|iv| match iv { + InductionVariable::Basic { + node, + initializer, + update, + final_value, + } => { + (editor + .node(initializer) + .is_zero_constant(&editor.get_constants()) + || editor + .node(initializer) + .is_zero_dc(&editor.get_dynamic_constants())) + && (editor.node(update).is_one_constant(&editor.get_constants()) + || editor + .node(update) + .is_one_dc(&editor.get_dynamic_constants())) + && (final_value + .map(|val| { + editor.node(val).is_constant() || editor.node(val).is_dynamic_constant() + }) + .is_some()) } InductionVariable::SCEV(node_id) => false, - } }) } // Need a transformation that forces all IVs to be SCEVs of an IV from range 0...N, +1, else places them in a separate loop? -pub fn compute_induction_vars(function: &Function, l: &Loop, loop_variance: &LoopVarianceInfo) - -> Vec<InductionVariable> { - +pub fn compute_induction_vars( + function: &Function, + l: &Loop, + loop_variance: &LoopVarianceInfo, +) -> Vec<InductionVariable> { // 1) Gather PHIs contained in the loop. 
     // FIXME: (@xrouth) Should this just be PHIs controlled by the header?
     let mut loop_vars: Vec<NodeID> = vec![];

-    for (node_id, node) in function.nodes.iter().enumerate() {
+    for (node_id, node) in function.nodes.iter().enumerate() {
         if let Some((control, _)) = node.try_phi() {
             if l.control[control.idx()] {
                 loop_vars.push(NodeID::new(node_id));
@@ -482,22 +585,30 @@
         }
     }

-    // FIXME: (@xrouth) For now, only compute variables that have one assignment,
-    // (look into this:) possibly treat multiple assignment as separate induction variables.
+    // FIXME: (@xrouth) For now, only compute variables that have one assignment,
+    // (look into this:) possibly treat multiple assignment as separate induction variables.

     let mut induction_variables: Vec<InductionVariable> = vec![];

     /* For each PHI controlled by the loop, check how it is modified */

-    // It's initializer needs to be loop invariant, it's update needs to be loop variant.
+    // Its initializer needs to be loop invariant; its update needs to be loop variant.
     for phi_id in loop_vars {
         let phi_node = &function.nodes[phi_id.idx()];
         let (region, data) = phi_node.try_phi().unwrap();
         let region_node = &function.nodes[region.idx()];
-        let Node::Region { preds: region_inputs } = region_node else {continue};
+        let Node::Region {
+            preds: region_inputs,
+        } = region_node
+        else {
+            continue;
+        };

         // The initializer index is the first index of the inputs to the region node that isn't in the loop. (what is loop_header, wtf...)
         // FIXME (@xrouth): If there is control flow in the loop, we won't find ... WHAT
-        let Some(initializer_idx) = region_inputs.iter().position(|&node_id| !l.control[node_id.idx()]) else {
+        let Some(initializer_idx) = region_inputs
+            .iter()
+            .position(|&node_id| !l.control[node_id.idx()])
+        else {
             continue;
         };

@@ -507,30 +618,37 @@
         let initializer = &function.nodes[initializer_id.idx()];

         // In the case of a non 0 starting value:
-        //  - a new dynamic constant or constant may need to be created that is the difference between the initiailizer and the loop bounds.
-        //  Initializer does not necessarily have to be constant, but this is fine for now.
+        //  - a new dynamic constant or constant may need to be created that is the difference between the initializer and the loop bounds.
+        //  Initializer does not necessarily have to be constant, but this is fine for now.
         if !(initializer.is_dynamic_constant() || initializer.is_constant()) {
             continue;
         }

         // Check all data inputs to this phi, that aren't the initializer (i.e. the value that comes from control outside of the loop)
-        // For now we expect only one initializer.
+        // For now we expect only one initializer.
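+        // The match below looks for the classic basic IV shape (sketch):
+        //     iv = phi(init, iv + c)
+        // where `c` is a constant or dynamic constant; only `Add` updates are
+        // recognized for now.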
+        let data_inputs = data
+            .iter()
+            .filter(|data_id| NodeID::new(initializer_idx) != **data_id);
         for data_id in data_inputs {
             let node = &function.nodes[data_id.idx()];
-            for bop in [BinaryOperator::Add] { //, BinaryOperator::Mul, BinaryOperator::Sub] {
+            for bop in [BinaryOperator::Add] {
+                //, BinaryOperator::Mul, BinaryOperator::Sub] {
                 if let Some((a, b)) = node.try_binary(bop) {
-                    let iv = [(a, b), (b, a)].iter().find_map(|(pattern_phi, pattern_const)| {
-                        if *pattern_phi == phi_id && function.nodes[pattern_const.idx()].is_constant() || function.nodes[pattern_const.idx()].is_dynamic_constant() {
-                            return Some(InductionVariable::Basic {
-                                node: phi_id,
-                                initializer: initializer_id,
-                                update: b,
-                                final_value: None,
-                            }) } else {
+                    let iv = [(a, b), (b, a)]
+                        .iter()
+                        .find_map(|(pattern_phi, pattern_const)| {
+                            if *pattern_phi == phi_id
+                                && function.nodes[pattern_const.idx()].is_constant()
+                                || function.nodes[pattern_const.idx()].is_dynamic_constant()
+                            {
+                                return Some(InductionVariable::Basic {
+                                    node: phi_id,
+                                    initializer: initializer_id,
+                                    update: b,
+                                    final_value: None,
+                                });
+                            } else {
                                 None
                             }
                         });
@@ -540,36 +658,46 @@
                 }
             }
         }
-    };
+    }

     induction_variables
 }

 // Find loop iterations
-pub fn compute_iv_ranges(editor: &FunctionEditor, l: &Loop,
-    induction_vars: Vec<InductionVariable>, loop_condition: &LoopExit)
-    -> Vec<InductionVariable> {
-
+pub fn compute_iv_ranges(
+    editor: &FunctionEditor,
+    l: &Loop,
+    induction_vars: Vec<InductionVariable>,
+    loop_condition: &LoopExit,
+) -> Vec<InductionVariable> {
     let (if_node, condition_node) = match loop_condition {
-        LoopExit::Conditional { if_node, condition_node } => (if_node, condition_node),
-        LoopExit::Unconditional(node_id) => todo!()
+        LoopExit::Conditional {
+            if_node,
+            condition_node,
+        } => (if_node, condition_node),
+        LoopExit::Unconditional(node_id) => todo!(),
     };
-
+
     // Find IVs used by the loop condition, not across loop iterations.
     // without leaving the loop.
-    let stop_on: HashSet<_> = editor.node_ids().filter(|node_id|
-        {
+    let stop_on: HashSet<_> = editor
+        .node_ids()
+        .filter(|node_id| {
             if let Node::Phi { control, data } = editor.node(node_id) {
                 *control == l.header
             } else {
                 false
             }
-        }
-    ).collect();
-
+        })
+        .collect();
+
     // Bound IVs used in loop bound.
-    let loop_bound_uses: HashSet<_> = walk_all_uses_stop_on(*condition_node, editor, stop_on).collect();
-    let (loop_bound_ivs, other_ivs): (Vec<InductionVariable>, Vec<InductionVariable>) = induction_vars.into_iter().partition(|f| loop_bound_uses.contains(&f.phi()));
+    let loop_bound_uses: HashSet<_> =
+        walk_all_uses_stop_on(*condition_node, editor, stop_on).collect();
+    let (loop_bound_ivs, other_ivs): (Vec<InductionVariable>, Vec<InductionVariable>) =
+        induction_vars
+            .into_iter()
+            .partition(|f| loop_bound_uses.contains(&f.phi()));

     let Some(iv) = loop_bound_ivs.first() else {
         return other_ivs;
@@ -579,45 +707,67 @@
         return loop_bound_ivs.into_iter().chain(other_ivs).collect();
     }

-    // FIXME: DO linear algerbra to solve for loop bounds with multiple variables involved.
+    // FIXME: Do linear algebra to solve for loop bounds with multiple variables involved.
let final_value = match &editor.func().nodes[condition_node.idx()] { - Node::Phi { control, data } => { - None - }, - Node::Reduce { control, init, reduct } => None, + Node::Phi { control, data } => None, + Node::Reduce { + control, + init, + reduct, + } => None, Node::Parameter { index } => None, Node::Constant { id } => None, Node::Unary { input, op } => None, - Node::Ternary { first, second, third, op } => None, + Node::Ternary { + first, + second, + third, + op, + } => None, Node::Binary { left, right, op } => { match op { BinaryOperator::LT => { // Check for a loop guard condition. // left < right - if *left == iv.phi() && - (editor.func().nodes[right.idx()].is_constant() || editor.func().nodes[right.idx()].is_dynamic_constant()) { - Some(*right) - } + if *left == iv.phi() + && (editor.func().nodes[right.idx()].is_constant() + || editor.func().nodes[right.idx()].is_dynamic_constant()) + { + Some(*right) + } // left + const < right, - else if let Node::Binary { left: inner_left, right: inner_right, op: inner_op } = editor.node(left) { - let pattern = [(inner_left, inner_right), (inner_right, inner_left)].iter().find_map(|(pattern_iv, pattern_constant)| - { - if iv.phi()== **pattern_iv && (editor.node(*pattern_constant).is_constant() || editor.node(*pattern_constant).is_dynamic_constant()) { - // FIXME: pattern_constant can be anything >= loop_update expression, + else if let Node::Binary { + left: inner_left, + right: inner_right, + op: inner_op, + } = editor.node(left) + { + let pattern = [(inner_left, inner_right), (inner_right, inner_left)] + .iter() + .find_map(|(pattern_iv, pattern_constant)| { + if iv.phi() == **pattern_iv + && (editor.node(*pattern_constant).is_constant() + || editor.node(*pattern_constant).is_dynamic_constant()) + { + // FIXME: pattern_constant can be anything >= loop_update expression, let update = match iv { - InductionVariable::Basic { node, initializer, update, final_value } => update, + InductionVariable::Basic { + node, + initializer, + update, + final_value, + } => update, InductionVariable::SCEV(node_id) => todo!(), }; if *pattern_constant == update { Some(*right) } else { None - } + } } else { None } - } - ); + }); pattern.iter().cloned().next() } else { None @@ -635,11 +785,20 @@ pub fn compute_iv_ranges(editor: &FunctionEditor, l: &Loop, }; let basic = match iv { - InductionVariable::Basic { node, initializer, update, final_value: _ } => InductionVariable::Basic { node: *node, initializer: *initializer, update: *update, final_value }, + InductionVariable::Basic { + node, + initializer, + update, + final_value: _, + } => InductionVariable::Basic { + node: *node, + initializer: *initializer, + update: *update, + final_value, + }, InductionVariable::SCEV(node_id) => todo!(), }; - // Propagate bounds to other IVs. + // Propagate bounds to other IVs. 
vec![basic].into_iter().chain(other_ivs).collect() } - diff --git a/hercules_opt/src/lib.rs b/hercules_opt/src/lib.rs index c74f587580db2f198fa91259b82a1513315fa861..01ae1c99ad3613e826801afebdb0e15376ae1377 100644 --- a/hercules_opt/src/lib.rs +++ b/hercules_opt/src/lib.rs @@ -9,11 +9,13 @@ pub mod editor; pub mod float_collections; pub mod fork_concat_split; pub mod fork_guard_elim; +pub mod fork_transforms; pub mod forkify; pub mod gcm; pub mod gvn; pub mod inline; pub mod interprocedural_sroa; +pub mod ivar; pub mod lift_dc_math; pub mod outline; pub mod phi_elim; @@ -21,8 +23,6 @@ pub mod pred; pub mod schedule; pub mod slf; pub mod sroa; -pub mod fork_transforms; -pub mod ivar; pub mod unforkify; pub mod utils; @@ -35,11 +35,13 @@ pub use crate::editor::*; pub use crate::float_collections::*; pub use crate::fork_concat_split::*; pub use crate::fork_guard_elim::*; +pub use crate::fork_transforms::*; pub use crate::forkify::*; pub use crate::gcm::*; pub use crate::gvn::*; pub use crate::inline::*; pub use crate::interprocedural_sroa::*; +pub use crate::ivar::*; pub use crate::lift_dc_math::*; pub use crate::outline::*; pub use crate::phi_elim::*; @@ -47,7 +49,5 @@ pub use crate::pred::*; pub use crate::schedule::*; pub use crate::slf::*; pub use crate::sroa::*; -pub use crate::fork_transforms::*; -pub use crate::ivar::*; pub use crate::unforkify::*; pub use crate::utils::*; diff --git a/hercules_opt/src/schedule.rs b/hercules_opt/src/schedule.rs index 2c8209aae002beffca3c8f7ee40d6986fed49fb0..f9f720bef2c00f97facfaf90bfd880bc06beaf5a 100644 --- a/hercules_opt/src/schedule.rs +++ b/hercules_opt/src/schedule.rs @@ -29,7 +29,7 @@ pub fn infer_parallel_fork(editor: &mut FunctionEditor, fork_join_map: &HashMap< /* * Infer parallel reductions consisting of a simple cycle between a Reduce node * and a Write node, where indices of the Write are position indices using the - * ThreadID nodes attached to the corresponding Fork, and data of the Write is + * ThreadID nodes attached to the corresponding Fork, and data of the Write is * not in the Reduce node's cycle. This procedure also adds the ParallelReduce * schedule to Reduce nodes reducing over a parallelized Reduce, as long as the * base Write node also has position indices of the ThreadID of the outer fork. @@ -37,7 +37,11 @@ pub fn infer_parallel_fork(editor: &mut FunctionEditor, fork_join_map: &HashMap< * as long as each ThreadID dimension appears in the positional indexing of the * original Write. */ -pub fn infer_parallel_reduce(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>, reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>) { +pub fn infer_parallel_reduce( + editor: &mut FunctionEditor, + fork_join_map: &HashMap<NodeID, NodeID>, + reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>, +) { for id in editor.node_ids() { let func = editor.func(); if !func.nodes[id.idx()].is_reduce() { @@ -146,11 +150,17 @@ pub fn infer_vectorizable(editor: &mut FunctionEditor, fork_join_map: &HashMap<N * operation's operands must be the Reduce node, and all other operands must * not be in the Reduce node's cycle. 
*/ -pub fn infer_tight_associative(editor: &mut FunctionEditor, reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>) { - let is_binop_associative = |op| matches!(op, - BinaryOperator::Add | BinaryOperator::Or | BinaryOperator::And | BinaryOperator::Xor); - let is_intrinsic_associative = |intrinsic| matches!(intrinsic, - Intrinsic::Max | Intrinsic::Min); +pub fn infer_tight_associative( + editor: &mut FunctionEditor, + reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>, +) { + let is_binop_associative = |op| { + matches!( + op, + BinaryOperator::Add | BinaryOperator::Or | BinaryOperator::And | BinaryOperator::Xor + ) + }; + let is_intrinsic_associative = |intrinsic| matches!(intrinsic, Intrinsic::Max | Intrinsic::Min); for id in editor.node_ids() { let func = editor.func(); @@ -162,8 +172,8 @@ pub fn infer_tight_associative(editor: &mut FunctionEditor, reduce_cycles: &Hash && (matches!(func.nodes[reduct.idx()], Node::Binary { left, right, op } if ((left == id && !reduce_cycles[&id].contains(&right)) || (right == id && !reduce_cycles[&id].contains(&left))) && - is_binop_associative(op)) || - matches!(&func.nodes[reduct.idx()], Node::IntrinsicCall { intrinsic, args } + is_binop_associative(op)) + || matches!(&func.nodes[reduct.idx()], Node::IntrinsicCall { intrinsic, args } if (args.contains(&id) && is_intrinsic_associative(*intrinsic) && args.iter().filter(|arg| **arg != id).all(|arg| !reduce_cycles[&id].contains(arg))))) { diff --git a/hercules_opt/src/sroa.rs b/hercules_opt/src/sroa.rs index 3bcc689e13b55d6d36ae72cb48b0e40cd39c3abf..66d11d69c33d1a77ce5a54bfd13ad88618916bfd 100644 --- a/hercules_opt/src/sroa.rs +++ b/hercules_opt/src/sroa.rs @@ -389,7 +389,7 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: }, AllocatedTernary { cond: NodeID, - thn: NodeID, + thn: NodeID, els: NodeID, node: NodeID, fields: IndexTree<NodeID>, diff --git a/hercules_opt/src/unforkify.rs b/hercules_opt/src/unforkify.rs index 7e2e267a2b893a6953b6bdd5dbd26521c2f2f285..0efd0b855969dba4b0c8e2d0dfdc9ab2220f6f50 100644 --- a/hercules_opt/src/unforkify.rs +++ b/hercules_opt/src/unforkify.rs @@ -7,27 +7,37 @@ use hercules_ir::{ir::*, LoopTree}; use crate::*; type NodeVec = BitVec<u8, Lsb0>; -pub fn calculate_fork_nodes(editor: &FunctionEditor, inner_control: &NodeVec, fork: NodeID, join: NodeID) -> HashSet<NodeID> { - // Stop on PHIs / reduces outside of loop. - let stop_on: HashSet<NodeID> = editor.node_ids().filter( - |node|{ +pub fn calculate_fork_nodes( + editor: &FunctionEditor, + inner_control: &NodeVec, + fork: NodeID, + join: NodeID, +) -> HashSet<NodeID> { + // Stop on PHIs / reduces outside of loop. 
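+    // Same frontier idea as in forkify: phis / reduces controlled outside
+    // `inner_control`, and external control nodes, bound the use / user walks
+    // below.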
+ let stop_on: HashSet<NodeID> = editor + .node_ids() + .filter(|node| { let data = &editor.func().nodes[node.idx()]; // External Phi if let Node::Phi { control, data } = data { if match inner_control.get(control.idx()) { - Some(v) => !*v, // - None => true, // Doesn't exist, must be external + Some(v) => !*v, // + None => true, // Doesn't exist, must be external } { return true; } - } // External Reduce - if let Node::Reduce { control, init, reduct} = data { + if let Node::Reduce { + control, + init, + reduct, + } = data + { if match inner_control.get(control.idx()) { - Some(v) => !*v, // - None => true, // Doesn't exist, must be external + Some(v) => !*v, // + None => true, // Doesn't exist, must be external } { return true; } @@ -36,37 +46,49 @@ pub fn calculate_fork_nodes(editor: &FunctionEditor, inner_control: &NodeVec, fo // External Control if data.is_control() { return match inner_control.get(node.idx()) { - Some(v) => !*v, // - None => true, // Doesn't exist, must be external - } + Some(v) => !*v, // + None => true, // Doesn't exist, must be external + }; } // else return false; - } - ).collect(); + }) + .collect(); - let reduces: Vec<_> = editor.node_ids().filter(|node| { - let Node::Reduce { control, .. } = editor.func().nodes[node.idx()] else {return false}; - match inner_control.get(control.idx()) { - Some(v) => *v, - None => false, - } - }).chain(editor.get_users(fork).filter(|node| { - editor.node(node).is_thread_id() - })).collect(); + let reduces: Vec<_> = editor + .node_ids() + .filter(|node| { + let Node::Reduce { control, .. } = editor.func().nodes[node.idx()] else { + return false; + }; + match inner_control.get(control.idx()) { + Some(v) => *v, + None => false, + } + }) + .chain( + editor + .get_users(fork) + .filter(|node| editor.node(node).is_thread_id()), + ) + .collect(); - let all_users: HashSet<NodeID> = reduces.clone().iter().flat_map(|phi| walk_all_users_stop_on(*phi, editor, stop_on.clone())) + let all_users: HashSet<NodeID> = reduces + .clone() + .iter() + .flat_map(|phi| walk_all_users_stop_on(*phi, editor, stop_on.clone())) .chain(reduces.clone()) .collect(); - let all_uses: HashSet<_> = reduces.clone().iter() + let all_uses: HashSet<_> = reduces + .clone() + .iter() .flat_map(|phi| walk_all_uses_stop_on(*phi, editor, stop_on.clone())) .chain(reduces) - .filter(|node| - { + .filter(|node| { // Get rid of nodes in stop_on !stop_on.contains(node) - }) + }) .collect(); all_users.intersection(&all_uses).cloned().collect() @@ -77,7 +99,13 @@ pub fn calculate_fork_nodes(editor: &FunctionEditor, inner_control: &NodeVec, fo * sequential loops in LLVM is actually not entirely trivial, so it's easier to * just do this transformation within Hercules IR. */ -pub fn unforkify(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>, loop_tree: &LoopTree) { + +// FIXME: Only works on fully split fork nests. 
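+//
+// Sketch of the rewrite (illustrative, not normative): a one-dimensional
+//     x = fork N { tid => ... reduce(init, reduct) ... } join
+// becomes a guarded sequential loop: an index phi counts up to the thread
+// count, one phi per reduce carries the accumulator, and a zero-trip-count
+// guard (`0 < N`) skips the loop entirely when the fork is empty.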
+pub fn unforkify( + editor: &mut FunctionEditor, + fork_join_map: &HashMap<NodeID, NodeID>, + loop_tree: &LoopTree, +) { let mut zero_cons_id = ConstantID::new(0); let mut one_cons_id = ConstantID::new(0); assert!(editor.edit(|mut edit| { @@ -129,7 +157,7 @@ pub fn unforkify(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, No let add_id = NodeID::new(num_nodes + 7); let dc_id = NodeID::new(num_nodes + 8); let neq_id = NodeID::new(num_nodes + 9); - + let guard_if_id = NodeID::new(num_nodes + 10); let guard_join_id = NodeID::new(num_nodes + 11); let guard_taken_proj_id = NodeID::new(num_nodes + 12); @@ -140,20 +168,29 @@ pub fn unforkify(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, No let s = num_nodes + 15 + reduces.len(); let join_phi_ids = (s..s + reduces.len()).map(NodeID::new); - let guard_cond = Node::Binary { left: zero_id, right: dc_id, op: BinaryOperator::LT}; - let guard_if = Node::If { control: fork_control, cond: guard_cond_id}; - let guard_taken_proj = Node::Projection { control: guard_if_id, selection: 1 }; - let guard_skipped_proj = Node::Projection { control: guard_if_id, selection: 0 }; - let guard_join = Node::Region { preds: Box::new([ - guard_skipped_proj_id, - proj_exit_id, - ])}; + let guard_cond = Node::Binary { + left: zero_id, + right: dc_id, + op: BinaryOperator::LT, + }; + let guard_if = Node::If { + control: fork_control, + cond: guard_cond_id, + }; + let guard_taken_proj = Node::Projection { + control: guard_if_id, + selection: 1, + }; + let guard_skipped_proj = Node::Projection { + control: guard_if_id, + selection: 0, + }; + let guard_join = Node::Region { + preds: Box::new([guard_skipped_proj_id, proj_exit_id]), + }; let region = Node::Region { - preds: Box::new([ - guard_taken_proj_id, - proj_back_id, - ]), + preds: Box::new([guard_taken_proj_id, proj_back_id]), }; let if_node = Node::If { control: join_control, @@ -188,14 +225,16 @@ pub fn unforkify(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, No .iter() .map(|reduce_id| { let (_, init, reduct) = nodes[reduce_id.idx()].try_reduce().unwrap(); - (Node::Phi { - control: region_id, - data: Box::new([init, reduct]), - }, - Node::Phi { - control: guard_join_id, - data: Box::new([init, reduct]) - }) + ( + Node::Phi { + control: region_id, + data: Box::new([init, reduct]), + }, + Node::Phi { + control: guard_join_id, + data: Box::new([init, reduct]), + }, + ) }) .unzip(); @@ -231,13 +270,20 @@ pub fn unforkify(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, No edit.sub_edit(*tid, indvar_id); edit = edit.replace_all_uses(*tid, indvar_id)?; } - for (((reduce, phi_id), phi), join_phi_id) in zip(reduces.iter(), phi_ids).zip(phis).zip(join_phi_ids) { + for (((reduce, phi_id), phi), join_phi_id) in + zip(reduces.iter(), phi_ids).zip(phis).zip(join_phi_ids) + { edit.sub_edit(*reduce, phi_id); - let Node::Phi { control, data } = phi else {panic!()}; - edit = edit.replace_all_uses_where(*reduce, join_phi_id, |usee| !fork_nodes.contains(usee))?; //, |usee| *usee != *reduct)?; - edit = edit.replace_all_uses_where(*reduce, phi_id, |usee| fork_nodes.contains(usee) || *usee == data[1])?; + let Node::Phi { control, data } = phi else { + panic!() + }; + edit = edit.replace_all_uses_where(*reduce, join_phi_id, |usee| { + !fork_nodes.contains(usee) + })?; //, |usee| *usee != *reduct)?; + edit = edit.replace_all_uses_where(*reduce, phi_id, |usee| { + fork_nodes.contains(usee) || *usee == data[1] + })?; edit = edit.delete_node(*reduce)?; - } edit = edit.delete_node(*fork)?; diff 
--git a/hercules_opt/src/utils.rs b/hercules_opt/src/utils.rs
index aa0d53fe32f855d5a1f9ad689fab44ef330e7a76..67225bffeeb8d585caffb85ac2e2a91ab81a6d0b 100644
--- a/hercules_opt/src/utils.rs
+++ b/hercules_opt/src/utils.rs
@@ -1,7 +1,12 @@
+extern crate nestify;
+
+use std::collections::HashMap;
+use std::collections::HashSet;
 use std::iter::zip;

 use hercules_ir::def_use::*;
 use hercules_ir::ir::*;
+use nestify::nest;

 use crate::*;

@@ -376,3 +381,106 @@ pub(crate) fn indices_may_overlap(indices1: &[Index], indices2: &[Index]) -> boo
     // may overlap when one indexes a larger sub-value than the other.
     true
 }
+
+pub type DenseNodeMap<T> = Vec<T>;
+pub type SparseNodeMap<T> = HashMap<NodeID, T>;
+
+nest! {
+// Is this something editor should give... Or is it just for analyses.
+//
+#[derive(Clone, Debug)]
+pub struct NodeIterator<'a> {
+    pub direction:
+        #[derive(Clone, Debug, PartialEq)]
+        enum Direction {
+            Uses,
+            Users,
+        },
+    visited: DenseNodeMap<bool>,
+    stack: Vec<NodeID>,
+    func: &'a FunctionEditor<'a>, // Maybe this is an enum, def use can be gotten from the function or from the editor.
+    // `stop condition`, then return all nodes that caused stoppage i.e the frontier of the search.
+    stop_on: HashSet<NodeID>, // Don't add neighbors of these.
+}
+}
+
+pub fn walk_all_uses<'a>(node: NodeID, editor: &'a FunctionEditor<'a>) -> NodeIterator<'a> {
+    let len = editor.func().nodes.len();
+    NodeIterator {
+        direction: Direction::Uses,
+        visited: vec![false; len],
+        stack: vec![node],
+        func: editor,
+        stop_on: HashSet::new(),
+    }
+}
+
+pub fn walk_all_users<'a>(node: NodeID, editor: &'a FunctionEditor<'a>) -> NodeIterator<'a> {
+    let len = editor.func().nodes.len();
+    NodeIterator {
+        direction: Direction::Users,
+        visited: vec![false; len],
+        stack: vec![node],
+        func: editor,
+        stop_on: HashSet::new(),
+    }
+}
+
+pub fn walk_all_uses_stop_on<'a>(
+    node: NodeID,
+    editor: &'a FunctionEditor<'a>,
+    stop_on: HashSet<NodeID>,
+) -> NodeIterator<'a> {
+    let len = editor.func().nodes.len();
+    let uses = editor.get_uses(node).collect();
+    NodeIterator {
+        direction: Direction::Uses,
+        visited: vec![false; len],
+        stack: uses,
+        func: editor,
+        stop_on,
+    }
+}
+
+pub fn walk_all_users_stop_on<'a>(
+    node: NodeID,
+    editor: &'a FunctionEditor<'a>,
+    stop_on: HashSet<NodeID>,
+) -> NodeIterator<'a> {
+    let len = editor.func().nodes.len();
+    let users = editor.get_users(node).collect();
+    NodeIterator {
+        direction: Direction::Users,
+        visited: vec![false; len],
+        stack: users,
+        func: editor,
+        stop_on,
+    }
+}
+
+impl<'a> Iterator for NodeIterator<'a> {
+    type Item = NodeID;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        while let Some(current) = self.stack.pop() {
+            if !self.visited[current.idx()] {
+                self.visited[current.idx()] = true;
+
+                if !self.stop_on.contains(&current) {
+                    if self.direction == Direction::Uses {
+                        for neighbor in self.func.get_uses(current) {
+                            self.stack.push(neighbor)
+                        }
+                    } else {
+                        for neighbor in self.func.get_users(current) {
+                            self.stack.push(neighbor)
+                        }
+                    }
+                }
+
+                return Some(current);
+            }
+        }
+        None
+    }
+}
diff --git a/hercules_samples/matmul/build.rs b/hercules_samples/matmul/build.rs
index f895af867a019dfd23381a4df2d9a02f80a032f8..c15ca97fa4b0730622f28e6cf16f7ab24de7310a 100644
--- a/hercules_samples/matmul/build.rs
+++ b/hercules_samples/matmul/build.rs
@@ -4,7 +4,7 @@ fn main() {
     JunoCompiler::new()
         .ir_in_src("matmul.hir")
         .unwrap()
-        //.schedule_in_src(if cfg!(feature = "cuda") { "gpu.sch" } else { "cpu.sch" })
+        // .schedule_in_src(if cfg!(feature = "cuda") {
"gpu.sch" } else { "cpu.sch" }) .schedule_in_src("cpu.sch") .unwrap() .build() diff --git a/hercules_test/hercules_interpreter/src/interpreter.rs b/hercules_test/hercules_interpreter/src/interpreter.rs index 9b8e2e9cbee9a064678ae929b944338c049ff561..1ef705612c9e1c8d15bdbfd72a89d0ae16427450 100644 --- a/hercules_test/hercules_interpreter/src/interpreter.rs +++ b/hercules_test/hercules_interpreter/src/interpreter.rs @@ -1,10 +1,9 @@ - +use std::collections::hash_map::Entry::Occupied; use std::collections::HashMap; use std::panic; -use std::collections::hash_map::Entry::Occupied; use itertools::Itertools; -use std::cmp::{min, max}; +use std::cmp::{max, min}; use hercules_ir::*; @@ -44,8 +43,8 @@ pub struct FunctionContext<'a> { fork_join_nest: &'a HashMap<NodeID, Vec<NodeID>>, } -impl <'a> FunctionContext<'a> { - pub fn new ( +impl<'a> FunctionContext<'a> { + pub fn new( control_subgraph: &'a Subgraph, def_use: &'a ImmutableDefUseMap, fork_join_map: &'a HashMap<NodeID, NodeID>, // Map forks -> joins @@ -61,18 +60,43 @@ impl <'a> FunctionContext<'a> { } // TODO: (@xrouth) I feel like this funcitonality should be provided by the manager that holds and allocates dynamic constants & IDs. -pub fn dyn_const_value(dc: &DynamicConstantID, dyn_const_values: &[DynamicConstant], dyn_const_params: &[usize]) -> usize { +pub fn dyn_const_value( + dc: &DynamicConstantID, + dyn_const_values: &[DynamicConstant], + dyn_const_params: &[usize], +) -> usize { let dc = &dyn_const_values[dc.idx()]; match dc { DynamicConstant::Constant(v) => *v, DynamicConstant::Parameter(v) => dyn_const_params[*v], - DynamicConstant::Add(a, b) => dyn_const_value(a, dyn_const_values, dyn_const_params) + dyn_const_value(b, dyn_const_values, dyn_const_params), - DynamicConstant::Sub(a, b) => dyn_const_value(a, dyn_const_values, dyn_const_params) - dyn_const_value(b, dyn_const_values, dyn_const_params), - DynamicConstant::Mul(a, b) => dyn_const_value(a, dyn_const_values, dyn_const_params) * dyn_const_value(b, dyn_const_values, dyn_const_params), - DynamicConstant::Div(a, b) => dyn_const_value(a, dyn_const_values, dyn_const_params) / dyn_const_value(b, dyn_const_values, dyn_const_params), - DynamicConstant::Rem(a, b) => dyn_const_value(a, dyn_const_values, dyn_const_params) % dyn_const_value(b, dyn_const_values, dyn_const_params), - DynamicConstant::Max(a, b) => max(dyn_const_value(a, dyn_const_values, dyn_const_params), dyn_const_value(b, dyn_const_values, dyn_const_params)), - DynamicConstant::Min(a, b) => min(dyn_const_value(a, dyn_const_values, dyn_const_params), dyn_const_value(b, dyn_const_values, dyn_const_params)), + DynamicConstant::Add(a, b) => { + dyn_const_value(a, dyn_const_values, dyn_const_params) + + dyn_const_value(b, dyn_const_values, dyn_const_params) + } + DynamicConstant::Sub(a, b) => { + dyn_const_value(a, dyn_const_values, dyn_const_params) + - dyn_const_value(b, dyn_const_values, dyn_const_params) + } + DynamicConstant::Mul(a, b) => { + dyn_const_value(a, dyn_const_values, dyn_const_params) + * dyn_const_value(b, dyn_const_values, dyn_const_params) + } + DynamicConstant::Div(a, b) => { + dyn_const_value(a, dyn_const_values, dyn_const_params) + / dyn_const_value(b, dyn_const_values, dyn_const_params) + } + DynamicConstant::Rem(a, b) => { + dyn_const_value(a, dyn_const_values, dyn_const_params) + % dyn_const_value(b, dyn_const_values, dyn_const_params) + } + DynamicConstant::Max(a, b) => max( + dyn_const_value(a, dyn_const_values, dyn_const_params), + dyn_const_value(b, dyn_const_values, dyn_const_params), + ), + 
DynamicConstant::Min(a, b) => min( + dyn_const_value(a, dyn_const_values, dyn_const_params), + dyn_const_value(b, dyn_const_values, dyn_const_params), + ), } } @@ -91,7 +115,12 @@ pub struct ControlToken { impl ControlToken { pub fn moved_to(&self, next: NodeID) -> ControlToken { - ControlToken { curr: next, prev: self.curr, thread_indicies: self.thread_indicies.clone(), phi_values: self.phi_values.clone() } + ControlToken { + curr: next, + prev: self.curr, + thread_indicies: self.thread_indicies.clone(), + phi_values: self.phi_values.clone(), + } } } impl<'a> FunctionExecutionState<'a> { @@ -102,9 +131,15 @@ impl<'a> FunctionExecutionState<'a> { function_contexts: &'a Vec<FunctionContext>, dynamic_constant_params: Vec<usize>, ) -> Self { - println!("param types: {:?}", module.functions[function_id.idx()].param_types); + println!( + "param types: {:?}", + module.functions[function_id.idx()].param_types + ); - assert_eq!(args.len(), module.functions[function_id.idx()].param_types.len()); + assert_eq!( + args.len(), + module.functions[function_id.idx()].param_types.len() + ); FunctionExecutionState { args, @@ -138,15 +173,10 @@ impl<'a> FunctionExecutionState<'a> { } /* Drives PHI values of this region for a control token, returns the next control node. */ - pub fn handle_region( - &mut self, - token: &mut ControlToken, - preds: &Box<[NodeID]>, - ) -> NodeID { - + pub fn handle_region(&mut self, token: &mut ControlToken, preds: &Box<[NodeID]>) -> NodeID { let prev = token.prev; let node = token.curr; - + // Gather PHI nodes for this region node. let phis: Vec<NodeID> = self .get_def_use() @@ -193,12 +223,12 @@ impl<'a> FunctionExecutionState<'a> { .try_phi() .expect("PANIC: handle_phi on non-phi node."); let value_node = data[edge]; - + let value = self.handle_data(token, value_node); if VERBOSE { println!("Latching PHI {:?} to {:?}", phi.idx(), value); } - + (phi, value) } @@ -221,7 +251,7 @@ impl<'a> FunctionExecutionState<'a> { for reduction in &reduces { self.handle_reduction(&token, *reduction); } - + let thread_values = self.get_thread_factors(&token, join); // println!("join for: {:?}", token); // dbg!(thread_values.clone()); @@ -231,7 +261,11 @@ impl<'a> FunctionExecutionState<'a> { .and_modify(|v| *v -= 1); if VERBOSE { - println!("join, thread_values : {:?}, {:?}", join, thread_values.clone()); + println!( + "join, thread_values : {:?}, {:?}", + join, + thread_values.clone() + ); } if *self .join_counters @@ -259,15 +293,28 @@ impl<'a> FunctionExecutionState<'a> { // Take the top N entries such that it matches the length of the TRF in the control token. // Get the depth of the control token that is requesting this reduction node. 
- + // Sum over all thread dimensions in nested forks - let fork_levels: usize = nested_forks.iter().map(|ele| - self.get_function().nodes[ele.idx()].try_fork().unwrap().1.len()).sum(); - + let fork_levels: usize = nested_forks + .iter() + .map(|ele| { + self.get_function().nodes[ele.idx()] + .try_fork() + .unwrap() + .1 + .len() + }) + .sum(); + let len = if nested_forks.is_empty() { fork_levels - 1 } else { - fork_levels - (self.get_function().nodes[nested_forks.first().unwrap().idx()].try_fork().unwrap().1.len()) + fork_levels + - (self.get_function().nodes[nested_forks.first().unwrap().idx()] + .try_fork() + .unwrap() + .1 + .len()) }; let mut thread_values = token.thread_indicies.clone(); @@ -276,7 +323,6 @@ impl<'a> FunctionExecutionState<'a> { } pub fn initialize_reduction(&mut self, token_at_fork: &ControlToken, reduce: NodeID) { - let token = token_at_fork; let (control, init, _) = &self.get_function().nodes[reduce.idx()] @@ -286,12 +332,16 @@ impl<'a> FunctionExecutionState<'a> { let thread_values = self.get_thread_factors(token, *control); let init = self.handle_data(&token, *init); - + if VERBOSE { - println!("reduction {:?} initialized to: {:?} on thread {:?}", reduce, init, thread_values); + println!( + "reduction {:?} initialized to: {:?} on thread {:?}", + reduce, init, thread_values + ); } - self.reduce_values.insert((thread_values.clone(), reduce), init); + self.reduce_values + .insert((thread_values.clone(), reduce), init); } // Drive the reduction, this will be invoked for each control token. @@ -305,7 +355,10 @@ impl<'a> FunctionExecutionState<'a> { let data = self.handle_data(&token, *reduct); if VERBOSE { - println!("reduction {:?} write of {:?} on thread {:?}", reduce, data, thread_values); + println!( + "reduction {:?} write of {:?} on thread {:?}", + reduce, data, thread_values + ); } self.reduce_values.insert((thread_values, reduce), data); @@ -315,8 +368,11 @@ impl<'a> FunctionExecutionState<'a> { // println!("Data Node: {} {:?}", node.idx(), &self.get_function().nodes[node.idx()]); // Partial borrow complaint. :/ - match &self.module.functions[self.function_id.idx()].nodes[node.idx()]{ - Node::Phi { control: _, data: _ } => (*token + match &self.module.functions[self.function_id.idx()].nodes[node.idx()] { + Node::Phi { + control: _, + data: _, + } => (*token .phi_values .get(&node) .expect(&format!("PANIC: Phi {:?} value not latched.", node))) @@ -330,23 +386,45 @@ impl<'a> FunctionExecutionState<'a> { .expect("PANIC: No nesting information for thread index!") .clone(); - let num_dims_this_level = (self.get_function().nodes[nested_forks.first().unwrap().idx()].try_fork().unwrap().1.len()); + let num_dims_this_level = (self.get_function().nodes + [nested_forks.first().unwrap().idx()] + .try_fork() + .unwrap() + .1 + .len()); // println!("num forks this level:{:?} ", num_forks_this_level); - // Skip forks until we get to this level. - // How many forks are outer? idfk. - let outer_forks: Vec<NodeID> = nested_forks.iter().cloned().take_while(|fork| *fork != node).collect(); + // Skip forks until we get to this level. + // How many forks are outer? idfk. 
+ let outer_forks: Vec<NodeID> = nested_forks + .iter() + .cloned() + .take_while(|fork| *fork != node) + .collect(); // println!("otuer_forkes: {:?}", outer_forks); - - let fork_levels: usize = outer_forks.iter().skip(1).map(|ele| self.get_function().nodes[ele.idx()].try_fork().unwrap().1.len()).sum(); + + let fork_levels: usize = outer_forks + .iter() + .skip(1) + .map(|ele| { + self.get_function().nodes[ele.idx()] + .try_fork() + .unwrap() + .1 + .len() + }) + .sum(); // println!("nested forks:{:?} ", nested_forks); // println!("fork levels: {:?}", fork_levels); // dimension might need to instead be dimensions - dimension let v = token.thread_indicies[fork_levels + dimension]; // Might have to -1? if VERBOSE { - println!("node: {:?} gives tid: {:?} for thread: {:?}, dim: {:?}", node, v, token.thread_indicies, dimension); + println!( + "node: {:?} gives tid: {:?} for thread: {:?}, dim: {:?}", + node, v, token.thread_indicies, dimension + ); } InterpreterVal::DynamicConstant((v).into()) } @@ -360,13 +438,14 @@ impl<'a> FunctionExecutionState<'a> { let thread_values = self.get_thread_factors(token, *control); // println!("reduction read: {:?}, {:?}", thread_values, node); - let entry = self - .reduce_values - .entry((thread_values.clone(), node)); - + let entry = self.reduce_values.entry((thread_values.clone(), node)); + let val = match entry { Occupied(v) => v.get().clone(), - std::collections::hash_map::Entry::Vacant(_) => panic!("Ctrl token: {:?}, Reduce {:?} has not been initialized!, TV: {:?}", token, node, thread_values), + std::collections::hash_map::Entry::Vacant(_) => panic!( + "Ctrl token: {:?}, Reduce {:?} has not been initialized!, TV: {:?}", + token, node, thread_values + ), }; // println!("value: {:?}", val.clone()); val @@ -379,12 +458,16 @@ impl<'a> FunctionExecutionState<'a> { &self.module.constants, &self.module.types, &self.module.dynamic_constants, - &self.dynamic_constant_params + &self.dynamic_constant_params, ) } Node::DynamicConstant { id } => { - let v = dyn_const_value(id, &self.module.dynamic_constants, &self.dynamic_constant_params); - + let v = dyn_const_value( + id, + &self.module.dynamic_constants, + &self.dynamic_constant_params, + ); + // TODO: Figure out what type / semantics are of thread ID and dynamic const. 
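+            // For now a dynamic constant evaluates to a usize and is surfaced
+            // as an unsigned 64-bit value; the conversion panics ("too big dyn
+            // const!") if it doesn't fit.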
InterpreterVal::UnsignedInteger64(v.try_into().expect("too big dyn const!")) } @@ -425,15 +508,21 @@ impl<'a> FunctionExecutionState<'a> { control, } => { // todo!("call currently dissabled lol"); - let args = args.into_iter() - .map(|arg_node| self.handle_data(token, *arg_node)) - .collect(); - + let args = args + .into_iter() + .map(|arg_node| self.handle_data(token, *arg_node)) + .collect(); - let dynamic_constant_params = dynamic_constants.into_iter() - .map(|id| { - dyn_const_value(id, &self.module.dynamic_constants, &self.dynamic_constant_params) - }).collect_vec(); + let dynamic_constant_params = dynamic_constants + .into_iter() + .map(|id| { + dyn_const_value( + id, + &self.module.dynamic_constants, + &self.dynamic_constant_params, + ) + }) + .collect_vec(); let mut state = FunctionExecutionState::new( args, @@ -453,12 +542,13 @@ impl<'a> FunctionExecutionState<'a> { let result = self.handle_read(token, collection.clone(), indices); if VERBOSE { - println!("{:?} read value : {:?} from {:?}, {:?} at index {:?}", node, result, collect, collection, indices); + println!( + "{:?} read value : {:?} from {:?}, {:?} at index {:?}", + node, result, collect, collection, indices + ); } result } - - } Node::Write { collect, @@ -473,11 +563,7 @@ impl<'a> FunctionExecutionState<'a> { self.handle_write(token, collection, data, indices) } } - Node::Undef { - ty - } => { - InterpreterVal::Undef(*ty) - } + Node::Undef { ty } => InterpreterVal::Undef(*ty), _ => todo!(), } } @@ -489,7 +575,6 @@ impl<'a> FunctionExecutionState<'a> { data: InterpreterVal, indices: &[Index], ) -> InterpreterVal { - // TODO (@xrouth): Recurse on writes correctly let val = match indices.first() { Some(Index::Field(idx)) => { @@ -499,10 +584,8 @@ impl<'a> FunctionExecutionState<'a> { } else { panic!("PANIC: Field index on not a product type") } - }, - None => { - collection } + None => collection, Some(Index::Variant(_)) => todo!(), Some(Index::Position(array_indices)) => { // Arrays also have inner indices... @@ -518,7 +601,13 @@ impl<'a> FunctionExecutionState<'a> { .try_extents() .expect("PANIC: wrong type for array") .into_iter() - .map(|extent| dyn_const_value(extent, &self.module.dynamic_constants, &self.dynamic_constant_params)) + .map(|extent| { + dyn_const_value( + extent, + &self.module.dynamic_constants, + &self.dynamic_constant_params, + ) + }) .collect(); let idx = InterpreterVal::array_idx(&extents, &array_indices); //println!("idx: {:?}", idx); @@ -528,7 +617,6 @@ impl<'a> FunctionExecutionState<'a> { vals[idx] = data; InterpreterVal::Array(type_id, vals) } - } else { panic!("PANIC: Position index on not an array") } @@ -556,10 +644,10 @@ impl<'a> FunctionExecutionState<'a> { .map(|idx| self.handle_data(token, *idx).as_usize()) .collect(); - if VERBOSE{ + if VERBOSE { println!("read at rt indicies: {:?}", array_indices); } - + // TODO: Implemenet . try_array() and other try_conversions on the InterpreterVal type if let InterpreterVal::Array(type_id, vals) = collection { // TODO: Make this its own funciton to reuse w/ array_size @@ -567,15 +655,23 @@ impl<'a> FunctionExecutionState<'a> { .try_extents() .expect("PANIC: wrong type for array") .into_iter() - .map(|extent| dyn_const_value(extent, &self.module.dynamic_constants, &self.dynamic_constant_params)) + .map(|extent| { + dyn_const_value( + extent, + &self.module.dynamic_constants, + &self.dynamic_constant_params, + ) + }) .collect(); - // FIXME: This type may be wrong. 
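// The reads and writes above flatten a multi-dimensional position into one
// element slot. A self-contained sketch of the usual row-major rule, in the
// spirit of `InterpreterVal::array_idx` (whose exact signature may differ):
fn row_major_index(extents: &[usize], indices: &[usize]) -> usize {
    debug_assert_eq!(extents.len(), indices.len());
    indices.iter().zip(extents).fold(0, |acc, (&i, &e)| {
        debug_assert!(i < e, "index out of bounds");
        acc * e + i
    })
}
// For extents [3, 4] and position [2, 1]: (0 * 3 + 2) * 4 + 1 = 9.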
- let ret = vals.get(InterpreterVal::array_idx(&extents, &array_indices)).unwrap_or(&InterpreterVal::Undef(type_id)).clone(); + // FIXME: This type may be wrong. + let ret = vals + .get(InterpreterVal::array_idx(&extents, &array_indices)) + .unwrap_or(&InterpreterVal::Undef(type_id)) + .clone(); if let InterpreterVal::Undef(_) = ret { panic!("bad read!") } ret - } else { panic!("PANIC: Position index on not an array") } @@ -603,10 +699,11 @@ impl<'a> FunctionExecutionState<'a> { let mut live_tokens: Vec<ControlToken> = Vec::new(); live_tokens.push(start_token); - // To do reduction nodes correctly we have to traverse control tokens in a depth-first fashion (i.e immediately handle spawned threads). 'outer: loop { - let mut ctrl_token = live_tokens.pop().expect("PANIC: Interpreter ran out of control tokens without returning."); + let mut ctrl_token = live_tokens + .pop() + .expect("PANIC: Interpreter ran out of control tokens without returning."); // println!( // "\n\nNew Token at: Control State: {} threads: {:?}, {:?}", @@ -614,28 +711,34 @@ impl<'a> FunctionExecutionState<'a> { // ctrl_token.thread_indicies.clone(), // &self.get_function().nodes[ctrl_token.curr.idx()] // ); - // TODO: (@xrouth): Enable this + PHI latch logging wi/ a simple debug flag. + // TODO: (@xrouth): Enable this + PHI latch logging wi/ a simple debug flag. // Tracking PHI vals and control state is very useful for debugging. - if VERBOSE { - println!("control token {} {}", ctrl_token.curr.idx(), &self.get_function().nodes[ctrl_token.curr.idx()].lower_case_name()); + println!( + "control token {} {}", + ctrl_token.curr.idx(), + &self.get_function().nodes[ctrl_token.curr.idx()].lower_case_name() + ); } // TODO: Rust is annoying and can't recognize that this is a partial borrow. - // Can't partial borrow, so need a clone. + // Can't partial borrow, so need a clone. let node = &self.get_function().nodes[ctrl_token.curr.idx()].clone(); let new_tokens = match node { Node::Start => { - let next: NodeID = self.get_control_subgraph().succs(ctrl_token.curr).next().unwrap(); + let next: NodeID = self + .get_control_subgraph() + .succs(ctrl_token.curr) + .next() + .unwrap(); let ctrl_token = ctrl_token.moved_to(next); - + vec![ctrl_token] } Node::Region { preds } => { - - // Updates + // Updates let next = self.handle_region(&mut ctrl_token, &preds); let ctrl_token = ctrl_token.moved_to(next); @@ -666,7 +769,11 @@ impl<'a> FunctionExecutionState<'a> { vec![ctrl_token] } Node::Projection { .. 
} => { - let next: NodeID = self.get_control_subgraph().succs(ctrl_token.curr).next().unwrap(); + let next: NodeID = self + .get_control_subgraph() + .succs(ctrl_token.curr) + .next() + .unwrap(); let ctrl_token = ctrl_token.moved_to(next); @@ -674,18 +781,34 @@ impl<'a> FunctionExecutionState<'a> { } Node::Match { control: _, sum: _ } => todo!(), - Node::Fork { control: _, factors } => { + Node::Fork { + control: _, + factors, + } => { let fork = ctrl_token.curr; // if factors.len() > 1 { // panic!("multi-dimensional forks unimplemented") // } - let factors = factors.iter().map(|f| dyn_const_value(&f, &self.module.dynamic_constants, &self.dynamic_constant_params)).rev(); + let factors = factors + .iter() + .map(|f| { + dyn_const_value( + &f, + &self.module.dynamic_constants, + &self.dynamic_constant_params, + ) + }) + .rev(); let n_tokens: usize = factors.clone().product(); - // Update control token - let next = self.get_control_subgraph().succs(ctrl_token.curr).nth(0).unwrap(); + // Update control token + let next = self + .get_control_subgraph() + .succs(ctrl_token.curr) + .nth(0) + .unwrap(); let ctrl_token = ctrl_token.moved_to(next); let mut tokens_to_add = Vec::with_capacity(n_tokens); @@ -707,7 +830,6 @@ impl<'a> FunctionExecutionState<'a> { tokens_to_add.push(new_token); } - let thread_factors = self.get_thread_factors(&ctrl_token, ctrl_token.curr); // Find join and initialize them, and set their reduction counters as well. @@ -729,7 +851,7 @@ impl<'a> FunctionExecutionState<'a> { } }) .collect(); - + for reduction in reduces { // TODO: Is this the correct reduction? self.initialize_reduction(&ctrl_token, reduction); @@ -737,7 +859,10 @@ impl<'a> FunctionExecutionState<'a> { // println!("tokens_to_add: {:?}", tokens_to_add); if VERBOSE { - println!("tf, fork, join, n_tokens: {:?}, {:?}, {:?}, {:?}", thread_factors, fork, join, n_tokens); + println!( + "tf, fork, join, n_tokens: {:?}, {:?}, {:?}, {:?}", + thread_factors, fork, join, n_tokens + ); } self.join_counters.insert((thread_factors, join), n_tokens); @@ -767,9 +892,6 @@ impl<'a> FunctionExecutionState<'a> { for i in new_tokens { live_tokens.push(i); } - } } } - - diff --git a/hercules_test/hercules_interpreter/src/lib.rs b/hercules_test/hercules_interpreter/src/lib.rs index 7792f95a311bc30f221a01c7ab96aea6fda9f4a2..baf0093e299a4754aca6adecff38af8da7ac60bd 100644 --- a/hercules_test/hercules_interpreter/src/lib.rs +++ b/hercules_test/hercules_interpreter/src/lib.rs @@ -1,7 +1,7 @@ pub mod interpreter; pub mod value; -extern crate postcard; extern crate juno_scheduler; +extern crate postcard; use std::fs::File; use std::io::Read; @@ -10,15 +10,18 @@ use hercules_ir::Module; use hercules_ir::TypeID; use hercules_ir::ID; -pub use juno_scheduler::PassManager; use juno_scheduler::run_schedule_on_hercules; +pub use juno_scheduler::PassManager; pub use crate::interpreter::*; pub use crate::value::*; -// Get a vec of -pub fn into_interp_val(module: &Module, wrapper: InterpreterWrapper, target_ty_id: TypeID) -> InterpreterVal -{ +// Get a vec of +pub fn into_interp_val( + module: &Module, + wrapper: InterpreterWrapper, + target_ty_id: TypeID, +) -> InterpreterVal { match wrapper { InterpreterWrapper::Boolean(v) => InterpreterVal::Boolean(v), InterpreterWrapper::Integer8(v) => InterpreterVal::Integer8(v), @@ -36,31 +39,34 @@ pub fn into_interp_val(module: &Module, wrapper: InterpreterWrapper, target_ty_i InterpreterWrapper::Array(array) => { let ty = &module.types[target_ty_id.idx()]; - let ele_type = 
ty.try_element_type().expect("PANIC: Invalid parameter type"); - // unwrap -> map to rust type, check - + let ele_type = ty + .try_element_type() + .expect("PANIC: Invalid parameter type"); + // unwrap -> map to rust type, check + let mut values = vec![]; - + for i in 0..array.len() { values.push(into_interp_val(module, array[i].clone(), TypeID::new(0))); } - + InterpreterVal::Array(target_ty_id, values.into_boxed_slice()) } } -} +} -pub fn array_from_interp_val<T: Clone>(module: &Module, interp_val: InterpreterVal) -> Vec<T> - where value::InterpreterVal: Into<T> +pub fn array_from_interp_val<T: Clone>(module: &Module, interp_val: InterpreterVal) -> Vec<T> +where + value::InterpreterVal: Into<T>, { - vec![] + vec![] } // Recursively turns rt args into interpreter wrappers. #[macro_export] macro_rules! parse_rt_args { ($arg:expr) => { - { + { let mut values: Vec<InterpreterWrapper> = vec![]; @@ -70,7 +76,7 @@ macro_rules! parse_rt_args { } }; ( $arg:expr, $($tail_args:expr), +) => { - { + { let mut values: Vec<InterpreterWrapper> = vec![]; values.push($arg.into()); @@ -157,20 +163,19 @@ macro_rules! interp_module { }; } - #[macro_export] macro_rules! interp_file_with_passes { ($path:literal, $dynamic_constants:expr, $passes:expr, $($args:expr), *) => { { let module = parse_file($path); - + let result_before = interp_module!(module, $dynamic_constants, $($args), *); let module = run_schedule_on_hercules(module, None).unwrap(); - let result_after = interp_module!(module, $dynamic_constants, $($args), *); + let result_after = interp_module!(module, $dynamic_constants, $($args), *); assert_eq!(result_after, result_before); } }; -} \ No newline at end of file +} diff --git a/hercules_test/hercules_interpreter/src/value.rs b/hercules_test/hercules_interpreter/src/value.rs index 2ca043c2e7a9b25ab11a6ccd7b1b6ea78e2a9890..c84b48492837df0c356d9c76ec6fcc6f2f21f126 100644 --- a/hercules_test/hercules_interpreter/src/value.rs +++ b/hercules_test/hercules_interpreter/src/value.rs @@ -215,10 +215,10 @@ impl<'a> InterpreterVal { ) -> InterpreterVal { // If either are undef, propogate undef if let InterpreterVal::Undef(v) = left { - return InterpreterVal::Undef(v) + return InterpreterVal::Undef(v); } if let InterpreterVal::Undef(v) = right { - return InterpreterVal::Undef(v) + return InterpreterVal::Undef(v); } // Do some type conversion first. @@ -862,7 +862,6 @@ impl<'a> InterpreterVal { } } - pub fn as_i128(&self) -> i128 { match *self { InterpreterVal::Boolean(v) => v.try_into().unwrap(), diff --git a/hercules_test/hercules_tests/tests/fork_transform_tests.rs b/hercules_test/hercules_tests/tests/fork_transform_tests.rs index faae39ace8101f34e626bfae7c61bd0304326114..16813b03e7c236684662c5abdb9e4aaea3978951 100644 --- a/hercules_test/hercules_tests/tests/fork_transform_tests.rs +++ b/hercules_test/hercules_tests/tests/fork_transform_tests.rs @@ -4,39 +4,32 @@ use hercules_interpreter::*; use hercules_ir::ID; use juno_scheduler::ir::*; - extern crate rand; -use juno_scheduler::{default_schedule, run_schedule_on_hercules}; -use rand::Rng; use juno_scheduler::pass; - - +use juno_scheduler::{default_schedule, run_schedule_on_hercules}; +use rand::Rng; #[test] fn fission_simple1() { let module = parse_file("../test_inputs/fork_transforms/fork_fission/simple1.hir"); let dyn_consts = [10]; let params = 2; // TODO: (@xrouth) fix macro to take no params as an option. 
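// The value.rs hunk above reformats the undef-propagation rule: a binary op
// with an Undef operand returns Undef of that operand's type. A toy,
// self-contained rendering (the real InterpreterVal has many more variants):
#[derive(Clone, Debug, PartialEq)]
enum ToyVal {
    U64(u64),
    Undef(u32), // stand-in for the crate's TypeID
}
fn toy_add(l: ToyVal, r: ToyVal) -> ToyVal {
    match (l, r) {
        (ToyVal::Undef(t), _) | (_, ToyVal::Undef(t)) => ToyVal::Undef(t),
        (ToyVal::U64(a), ToyVal::U64(b)) => ToyVal::U64(a.wrapping_add(b)),
    }
}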
- let result_1 = interp_module!(module, 0, dyn_consts, 2); + let result_1 = interp_module!(module, 0, dyn_consts, 2); println!("result: {:?}", result_1); - + let sched = Some(default_schedule![ - Verify, - //Xdot, - Unforkify, - //Xdot, - DCE, - Verify, + Verify, //Xdot, + Unforkify, //Xdot, + DCE, Verify, ]); let module = run_schedule_on_hercules(module, sched).unwrap(); - let result_2 = interp_module!(module, 0, dyn_consts, 2); + let result_2 = interp_module!(module, 0, dyn_consts, 2); println!("result: {:?}", result_2); assert_eq!(result_1, result_2) } - // #[test] // fn fission_simple2() { // let module = parse_file("../test_inputs/fork_transforms/fork_fission/simple2.hir"); @@ -45,7 +38,7 @@ fn fission_simple1() { // let result_1 = interp_module!(module, 0, dyn_consts, 2); // println!("result: {:?}", result_1); - + // let sched: Option<ScheduleStmt> = Some(default_schedule![ // Verify, // ForkFission, @@ -69,7 +62,7 @@ fn fission_simple1() { // let result_1 = interp_module!(module, 0, dyn_consts, 2); // println!("result: {:?}", result_1); - + // let sched: Option<ScheduleStmt> = Some(default_schedule![ // Verify, // ForkFission, @@ -92,7 +85,7 @@ fn fission_simple1() { // let result_1 = interp_module!(module, 0, dyn_consts, 2); // println!("result: {:?}", result_1); - + // let sched: Option<ScheduleStmt> = Some(default_schedule![ // Verify, // ForkFission, @@ -104,4 +97,4 @@ fn fission_simple1() { // let result_2 = interp_module!(module, 0, dyn_consts, 2); // println!("result: {:?}", result_2); // assert_eq!(result_1, result_2) -// } \ No newline at end of file +// } diff --git a/hercules_test/hercules_tests/tests/forkify_tests.rs b/hercules_test/hercules_tests/tests/forkify_tests.rs index 9d123672c9c6b8497c61969eaf6b4aa6f19993a6..025aaad382991761fd4d08354c6e37af8271bf4c 100644 --- a/hercules_test/hercules_tests/tests/forkify_tests.rs +++ b/hercules_test/hercules_tests/tests/forkify_tests.rs @@ -11,52 +11,39 @@ extern crate rand; use juno_scheduler::{default_schedule, run_schedule_on_hercules}; use rand::Rng; - #[test] #[ignore] fn inner_fork_chain() { let module = parse_file("../test_inputs/forkify/inner_fork_chain.hir"); let dyn_consts = [10]; let params = 2; // TODO: (@xrouth) fix macro to take no params as an option. - // let result_1 = interp_module!(module, 0, dyn_consts, 2); + // let result_1 = interp_module!(module, 0, dyn_consts, 2); // println!("result: {:?}", result_1); - - let sched: Option<ScheduleStmt> = Some(default_schedule![ - Verify, - - Forkify, - PhiElim, - - Verify, - ]); + + let sched: Option<ScheduleStmt> = Some(default_schedule![Verify, Forkify, PhiElim, Verify,]); let module = run_schedule_on_hercules(module, sched).unwrap(); - let result_2 = interp_module!(module, 0, dyn_consts, 2); + let result_2 = interp_module!(module, 0, dyn_consts, 2); println!("result: {:?}", result_2); // assert_eq!(result_1, result_2) } - #[test] fn loop_simple_iv() { let module = parse_file("../test_inputs/forkify/loop_simple_iv.hir"); let dyn_consts = [10]; let params = 2; // TODO: (@xrouth) fix macro to take no params as an option. 
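// Every test in these files repeats one shape: interpret, run a schedule,
// interpret again, and require identical results. A hypothetical generic
// distillation of that pattern (not a helper the crate provides):
fn assert_transform_preserves<M, R: PartialEq + std::fmt::Debug>(
    module: M,
    interp: impl Fn(&M) -> R,
    transform: impl FnOnce(M) -> M,
) {
    let before = interp(&module);
    let module = transform(module);
    assert_eq!(before, interp(&module), "schedule changed observable results");
}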
- let result_1 = interp_module!(module, 0, dyn_consts, 2); + let result_1 = interp_module!(module, 0, dyn_consts, 2); println!("result: {:?}", result_1); - - let sched: Option<ScheduleStmt> = Some(default_schedule![ - Verify, - Forkify, - Verify, - ]); + + let sched: Option<ScheduleStmt> = Some(default_schedule![Verify, Forkify, Verify,]); let module = run_schedule_on_hercules(module, sched).unwrap(); - let result_2 = interp_module!(module, 0, dyn_consts, 2); + let result_2 = interp_module!(module, 0, dyn_consts, 2); println!("result: {:?}", result_2); assert_eq!(result_1, result_2) } @@ -67,19 +54,15 @@ fn merged_phi_cycle() { let module = parse_file("../test_inputs/forkify/merged_phi_cycle.hir"); let dyn_consts = [10]; let params = 2; // TODO: (@xrouth) fix macro to take no params as an option. - let result_1 = interp_module!(module, 0, dyn_consts, 2); + let result_1 = interp_module!(module, 0, dyn_consts, 2); println!("result: {:?}", result_1); - - let sched: Option<ScheduleStmt> = Some(default_schedule![ - Verify, - - Verify, - ]); + + let sched: Option<ScheduleStmt> = Some(default_schedule![Verify, Verify,]); let module = run_schedule_on_hercules(module, sched).unwrap(); - let result_2 = interp_module!(module, 0, dyn_consts, 2); + let result_2 = interp_module!(module, 0, dyn_consts, 2); println!("result: {:?}", result_2); assert_eq!(result_1, result_2) } @@ -89,19 +72,15 @@ fn split_phi_cycle() { let module = parse_file("../test_inputs/forkify/split_phi_cycle.hir"); let dyn_consts = [10]; let params = 2; // TODO: (@xrouth) fix macro to take no params as an option. - let result_1 = interp_module!(module, 0, dyn_consts, 2); + let result_1 = interp_module!(module, 0, dyn_consts, 2); println!("result: {:?}", result_1); - - let sched: Option<ScheduleStmt> = Some(default_schedule![ - Verify, - - Verify, - ]); + + let sched: Option<ScheduleStmt> = Some(default_schedule![Verify, Verify,]); let module = run_schedule_on_hercules(module, sched).unwrap(); - let result_2 = interp_module!(module, 0, dyn_consts, 2); + let result_2 = interp_module!(module, 0, dyn_consts, 2); println!("result: {:?}", result_2); assert_eq!(result_1, result_2) } @@ -111,12 +90,12 @@ fn loop_sum() { let module = parse_file("../test_inputs/forkify/loop_sum.hir"); let dyn_consts = [20]; let params = 2; // TODO: (@xrouth) fix macro to take no params as an option. - let result_1 = interp_module!(module, 0, dyn_consts, 2); + let result_1 = interp_module!(module, 0, dyn_consts, 2); println!("result: {:?}", result_1); - + let module = run_schedule_on_hercules(module, None).unwrap(); - let result_2 = interp_module!(module, 0, dyn_consts, 2); + let result_2 = interp_module!(module, 0, dyn_consts, 2); assert_eq!(result_1, result_2); println!("{:?}, {:?}", result_1, result_2); } @@ -126,12 +105,12 @@ fn loop_tid_sum() { let module = parse_file("../test_inputs/forkify/loop_tid_sum.hir"); let dyn_consts = [20]; let params = 2; // TODO: (@xrouth) fix macro to take no params as an option. 
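// What a `loop_sum`-style input computes, modeled in plain Rust (assumed
// equivalent): a counted loop with an accumulator. Forkify rewrites the
// induction variable into a ThreadID and the accumulator PHI into a Reduce.
fn loop_sum_model(n: u64) -> u64 {
    let mut sum = 0;
    for _i in 0..n {
        sum += 1; // accumulator; becomes a Reduce node after forkify
    }
    sum
}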
-    let result_1 = interp_module!(module, 0, dyn_consts, 2); 
+    let result_1 = interp_module!(module, 0, dyn_consts, 2);
     println!("result: {:?}", result_1);
-    
+
     let module = run_schedule_on_hercules(module, None).unwrap();
 
-    let result_2 = interp_module!(module, 0, dyn_consts, 2); 
+    let result_2 = interp_module!(module, 0, dyn_consts, 2);
     assert_eq!(result_1, result_2);
     println!("{:?}, {:?}", result_1, result_2);
 }
 
@@ -142,24 +121,24 @@ fn loop_array_sum() {
     let len = 5;
     let dyn_consts = [len];
     let params = vec![1, 2, 3, 4, 5]; // TODO: (@xrouth) fix macro to take no params as an option.
-    let result_1 = interp_module!(module, 0, dyn_consts, params.clone()); 
+    let result_1 = interp_module!(module, 0, dyn_consts, params.clone());
     println!("result: {:?}", result_1);
-    
+
     let module = run_schedule_on_hercules(module, None).unwrap();
 
-    let result_2 = interp_module!(module, 0, dyn_consts, params); 
+    let result_2 = interp_module!(module, 0, dyn_consts, params);
     assert_eq!(result_1, result_2);
     println!("{:?}, {:?}", result_1, result_2);
 }
 
-/** Nested loop 2 is 2 nested loops with different dyn var parameter dimensions. 
+/** Nested loop 2 is 2 nested loops with different dyn var parameter dimensions.
 * It is a add of 1 for each iteration, so the result should be dim1 x dim2
 * The loop PHIs are structured such that on every outer iteration, inner loop increment is set to the running sum,
- * Notice how there is no outer_var_inc. 
- * 
- * The alternative, seen in nested_loop1, is to intiailize the inner loop to 0 every time, and track 
+ * Notice how there is no outer_var_inc.
+ *
+ * The alternative, seen in nested_loop1, is to initialize the inner loop to 0 every time, and track
 * the outer sum more separaetly.
- * 
+ *
 * Idk what im yapping about.
 */
#[test]
@@ -168,14 +147,13 @@ fn nested_loop2() {
    let len = 5;
    let dyn_consts = [5, 6];
    let params = vec![1, 2, 3, 4, 5]; // TODO: (@xrouth) fix macro to take no params as an option.
-    let result_1 = interp_module!(module, 0, dyn_consts, 2); 
+    let result_1 = interp_module!(module, 0, dyn_consts, 2);
    println!("result: {:?}", result_1);
-    
+
    let module = run_schedule_on_hercules(module, None).unwrap();
 
-    let result_2 = interp_module!(module, 0, dyn_consts, 2); 
+    let result_2 = interp_module!(module, 0, dyn_consts, 2);
    assert_eq!(result_1, result_2);
-
}
 
#[test]
@@ -184,20 +162,19 @@ fn super_nested_loop() {
    let len = 5;
    let dyn_consts = [5, 10, 15];
    let params = vec![1, 2, 3, 4, 5]; // TODO: (@xrouth) fix macro to take no params as an option.
-    let result_1 = interp_module!(module, 0, dyn_consts, 2); 
+    let result_1 = interp_module!(module, 0, dyn_consts, 2);
    println!("result: {:?}", result_1);
-    
+
    let module = run_schedule_on_hercules(module, None).unwrap();
 
-    let result_2 = interp_module!(module, 0, dyn_consts, 2); 
+    let result_2 = interp_module!(module, 0, dyn_consts, 2);
    assert_eq!(result_1, result_2);
}
 
-
/**
- * Tests forkify on a loop where there is control in between the continue projection 
- * and the header. aka control *after* the `loop condition / guard`. This should forkify. 
+ * Tests forkify on a loop where there is control in between the continue projection
+ * and the header. aka control *after* the `loop condition / guard`. This should forkify. 
*/ #[test] fn control_after_condition() { @@ -212,21 +189,20 @@ fn control_after_condition() { *x = rng.gen::<i32>() / 100; } - let result_1 = interp_module!(module, 0, dyn_consts, vec.clone()); + let result_1 = interp_module!(module, 0, dyn_consts, vec.clone()); println!("result: {:?}", result_1); - + let module = run_schedule_on_hercules(module, None).unwrap(); - let result_2 = interp_module!(module, 0, dyn_consts, vec); + let result_2 = interp_module!(module, 0, dyn_consts, vec); assert_eq!(result_1, result_2); - } /** - * Tests forkify on a loop where there is control before the loop condition, so in between the header - * and the loop condition. This should not forkify. - * + * Tests forkify on a loop where there is control before the loop condition, so in between the header + * and the loop condition. This should not forkify. + * * This example is bugged, it reads out of bounds even before forkify. */ #[ignore] @@ -243,21 +219,15 @@ fn control_before_condition() { *x = rng.gen::<i32>() / 100; } - let result_1 = interp_module!(module, 0, dyn_consts, vec.clone()); + let result_1 = interp_module!(module, 0, dyn_consts, vec.clone()); println!("result: {:?}", result_1); - - let sched: Option<ScheduleStmt> = Some(default_schedule![ - Verify, - Forkify, - DCE, - Verify, - ]); + + let sched: Option<ScheduleStmt> = Some(default_schedule![Verify, Forkify, DCE, Verify,]); let module = run_schedule_on_hercules(module, sched).unwrap(); - let result_2 = interp_module!(module, 0, dyn_consts, vec); + let result_2 = interp_module!(module, 0, dyn_consts, vec); assert_eq!(result_1, result_2); - } #[test] @@ -266,30 +236,20 @@ fn nested_tid_sum() { let len = 5; let dyn_consts = [5, 6]; let params = vec![1, 2, 3, 4, 5]; // TODO: (@xrouth) fix macro to take no params as an option. - let result_1 = interp_module!(module, 0, dyn_consts, 2); + let result_1 = interp_module!(module, 0, dyn_consts, 2); println!("result: {:?}", result_1); - - let sched: Option<ScheduleStmt> = Some(default_schedule![ - Verify, - Forkify, - DCE, - Verify, - ]); + + let sched: Option<ScheduleStmt> = Some(default_schedule![Verify, Forkify, DCE, Verify,]); let module = run_schedule_on_hercules(module, sched).unwrap(); - let result_2 = interp_module!(module, 0, dyn_consts, 2); + let result_2 = interp_module!(module, 0, dyn_consts, 2); assert_eq!(result_1, result_2); - let sched: Option<ScheduleStmt> = Some(default_schedule![ - Verify, - Forkify, - DCE, - Verify, - ]); + let sched: Option<ScheduleStmt> = Some(default_schedule![Verify, Forkify, DCE, Verify,]); let module = run_schedule_on_hercules(module, sched).unwrap(); - let result_3 = interp_module!(module, 0, dyn_consts, 2); + let result_3 = interp_module!(module, 0, dyn_consts, 2); println!("{:?}, {:?}, {:?}", result_1, result_2, result_3); } @@ -300,54 +260,38 @@ fn nested_tid_sum_2() { let len = 5; let dyn_consts = [5, 6]; let params = vec![1, 2, 3, 4, 5]; // TODO: (@xrouth) fix macro to take no params as an option. 
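// The nested-loop shape described in the `nested_loop2` doc comment above,
// modeled in plain Rust (assumed equivalent): the inner accumulator is seeded
// with the running sum on every outer iteration, so there is no separate
// outer increment and the result is dim1 * dim2.
fn nested_sum_shape(dim1: u64, dim2: u64) -> u64 {
    let mut sum = 0;
    for _ in 0..dim1 {
        let mut inner = sum; // inner PHI initialized to the running sum
        for _ in 0..dim2 {
            inner += 1;
        }
        sum = inner; // outer PHI takes the inner loop's result directly
    }
    sum
}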
- let result_1 = interp_module!(module, 0, dyn_consts, 2); + let result_1 = interp_module!(module, 0, dyn_consts, 2); println!("result: {:?}", result_1); - - let sched: Option<ScheduleStmt> = Some(default_schedule![ - Verify, - Forkify, - DCE, - Verify, - ]); + + let sched: Option<ScheduleStmt> = Some(default_schedule![Verify, Forkify, DCE, Verify,]); let module = run_schedule_on_hercules(module, sched).unwrap(); - let result_2 = interp_module!(module, 0, dyn_consts, 2); + let result_2 = interp_module!(module, 0, dyn_consts, 2); assert_eq!(result_1, result_2); - let sched: Option<ScheduleStmt> = Some(default_schedule![ - Verify, - Forkify, - DCE, - Verify, - ]); + let sched: Option<ScheduleStmt> = Some(default_schedule![Verify, Forkify, DCE, Verify,]); let module = run_schedule_on_hercules(module, sched).unwrap(); - let result_3 = interp_module!(module, 0, dyn_consts, 2); + let result_3 = interp_module!(module, 0, dyn_consts, 2); println!("{:?}, {:?}, {:?}", result_1, result_2, result_3); } - /** Tests weird control in outer loop for possible 2d fork-join pair. */ #[test] fn inner_fork_complex() { let module = parse_file("../test_inputs/forkify/inner_fork_complex.hir"); let dyn_consts = [5, 6]; let params = vec![1, 2, 3, 4, 5]; // TODO: (@xrouth) fix macro to take no params as an option. - let result_1 = interp_module!(module, 0, dyn_consts, 10); + let result_1 = interp_module!(module, 0, dyn_consts, 10); println!("result: {:?}", result_1); - - let sched: Option<ScheduleStmt> = Some(default_schedule![ - Verify, - Forkify, - DCE, - Verify, - ]); + + let sched: Option<ScheduleStmt> = Some(default_schedule![Verify, Forkify, DCE, Verify,]); let module = run_schedule_on_hercules(module, sched).unwrap(); - let result_2 = interp_module!(module, 0, dyn_consts, 10); + let result_2 = interp_module!(module, 0, dyn_consts, 10); assert_eq!(result_1, result_2); println!("{:?}, {:?}", result_1, result_2); -} \ No newline at end of file +} diff --git a/hercules_test/hercules_tests/tests/interpreter_tests.rs b/hercules_test/hercules_tests/tests/interpreter_tests.rs index e619f18a8a2b7a33dd2b41d5085371924429b05c..69e1920e35815ca0532b5c66bdfd7773515395f0 100644 --- a/hercules_test/hercules_tests/tests/interpreter_tests.rs +++ b/hercules_test/hercules_tests/tests/interpreter_tests.rs @@ -10,27 +10,22 @@ extern crate rand; use juno_scheduler::{default_schedule, run_schedule_on_hercules}; use rand::Rng; - #[test] fn twodeefork() { let module = parse_file("../test_inputs/2d_fork.hir"); let d1 = 2; let d2 = 3; let dyn_consts = [d1, d2]; - let result_1 = interp_module!(module, 0, dyn_consts, 2); + let result_1 = interp_module!(module, 0, dyn_consts, 2); let sched = Some(default_schedule![ - Verify, - ForkSplit, - //Xdot, - Unforkify, - //Xdot, - DCE, - Verify, + Verify, ForkSplit, //Xdot, + Unforkify, //Xdot, + DCE, Verify, ]); let module = run_schedule_on_hercules(module, sched).unwrap(); - let result_2 = interp_module!(module, 0, dyn_consts, 2); + let result_2 = interp_module!(module, 0, dyn_consts, 2); let res = (d1 as i32 * d2 as i32); let result_2: InterpreterWrapper = res.into(); @@ -44,31 +39,26 @@ fn threedee() { let d2 = 3; let d3 = 5; let dyn_consts = [d1, d2, 5]; - let result_1 = interp_module!(module, 0, dyn_consts, 2); + let result_1 = interp_module!(module, 0, dyn_consts, 2); let sched = Some(default_schedule![ - Verify, - ForkSplit, - //Xdot, - Unforkify, - //Xdot, - DCE, - Verify, + Verify, ForkSplit, //Xdot, + Unforkify, //Xdot, + DCE, Verify, ]); let module = run_schedule_on_hercules(module, 
sched).unwrap(); - let result_2 = interp_module!(module, 0, dyn_consts, 2); + let result_2 = interp_module!(module, 0, dyn_consts, 2); let res = (d1 as i32 * d2 as i32 * d3 as i32); let result_2: InterpreterWrapper = res.into(); println!("result: {:?}", result_1); // Should be d1 * d2. } - #[test] fn fivedeefork() { let module = parse_file("../test_inputs/5d_fork.hir"); let dyn_consts = [1, 2, 3, 4, 5]; - let result_1 = interp_module!(module, 0, dyn_consts, 2); + let result_1 = interp_module!(module, 0, dyn_consts, 2); println!("result: {:?}", result_1); // Should be 1 * 2 * 3 * 4 * 5; } diff --git a/hercules_test/hercules_tests/tests/loop_tests.rs b/hercules_test/hercules_tests/tests/loop_tests.rs index 2406360cd4719ec3f4f6a011768cd13c2ff15c2c..29b8692bf787aecff7308208a3c6a6424cd23333 100644 --- a/hercules_test/hercules_tests/tests/loop_tests.rs +++ b/hercules_test/hercules_tests/tests/loop_tests.rs @@ -18,12 +18,11 @@ fn loop_trip_count() { let module = parse_file("../test_inputs/loop_analysis/loop_trip_count.hir"); let dyn_consts = [10]; let params = 2; // TODO: (@xrouth) fix macro to take no params as an option. - let result_1 = interp_module!(module, 0,dyn_consts, 2); + let result_1 = interp_module!(module, 0, dyn_consts, 2); println!("result: {:?}", result_1); } - // Test canonicalization #[test] #[ignore] @@ -31,8 +30,9 @@ fn alternate_bounds_use_after_loop_no_tid() { let len = 1; let dyn_consts = [len]; - let module = parse_file("../test_inputs/loop_analysis/alternate_bounds_use_after_loop_no_tid.hir"); - let result_1 = interp_module!(module, 0,dyn_consts, 3); + let module = + parse_file("../test_inputs/loop_analysis/alternate_bounds_use_after_loop_no_tid.hir"); + let result_1 = interp_module!(module, 0, dyn_consts, 3); println!("result: {:?}", result_1); @@ -43,8 +43,8 @@ fn alternate_bounds_use_after_loop_no_tid() { ]; let module = run_schedule_on_hercules(module, Some(schedule)).unwrap(); - - let result_2 = interp_module!(module, 0,dyn_consts, 3); + + let result_2 = interp_module!(module, 0, dyn_consts, 3); println!("{:?}", result_1); println!("{:?}", result_2); @@ -60,7 +60,7 @@ fn alternate_bounds_use_after_loop() { let a = vec![3, 4, 5, 6]; let module = parse_file("../test_inputs/loop_analysis/alternate_bounds_use_after_loop.hir"); - let result_1 = interp_module!(module, 0,dyn_consts, a.clone()); + let result_1 = interp_module!(module, 0, dyn_consts, a.clone()); println!("result: {:?}", result_1); @@ -72,7 +72,7 @@ fn alternate_bounds_use_after_loop() { let module = run_schedule_on_hercules(module, schedule).unwrap(); - let result_2 = interp_module!(module, 0,dyn_consts, a.clone()); + let result_2 = interp_module!(module, 0, dyn_consts, a.clone()); //println!("{:?}", result_1); println!("{:?}", result_2); @@ -88,7 +88,7 @@ fn alternate_bounds_use_after_loop2() { let a = vec![3, 4, 5, 6]; let module = parse_file("../test_inputs/loop_analysis/alternate_bounds_use_after_loop2.hir"); - let result_1 = interp_module!(module, 0,dyn_consts, a.clone()); + let result_1 = interp_module!(module, 0, dyn_consts, a.clone()); println!("result: {:?}", result_1); @@ -98,7 +98,7 @@ fn alternate_bounds_use_after_loop2() { let module = run_schedule_on_hercules(module, schedule).unwrap(); - let result_2 = interp_module!(module, 0,dyn_consts, a.clone()); + let result_2 = interp_module!(module, 0, dyn_consts, a.clone()); //println!("{:?}", result_1); println!("{:?}", result_2); @@ -119,8 +119,7 @@ fn do_while_separate_body() { let schedule = Some(default_schedule![ ////Xdot,, - PhiElim, - ////Xdot,, 
+ PhiElim, ////Xdot,, Forkify, //Xdot, ]); @@ -140,21 +139,20 @@ fn alternate_bounds_internal_control() { let dyn_consts = [len]; let module = parse_file("../test_inputs/loop_analysis/alternate_bounds_internal_control.hir"); - let result_1 = interp_module!(module, 0,dyn_consts, 3); + let result_1 = interp_module!(module, 0, dyn_consts, 3); println!("result: {:?}", result_1); let schedule = Some(default_schedule![ ////Xdot,, - PhiElim, - ////Xdot,, + PhiElim, ////Xdot,, Forkify, //Xdot, ]); let module = run_schedule_on_hercules(module, schedule).unwrap(); - let result_2 = interp_module!(module, 0,dyn_consts, 3); + let result_2 = interp_module!(module, 0, dyn_consts, 3); println!("{:?}", result_1); println!("{:?}", result_2); @@ -167,21 +165,20 @@ fn alternate_bounds_internal_control2() { let dyn_consts = [len]; let module = parse_file("../test_inputs/loop_analysis/alternate_bounds_internal_control2.hir"); - let result_1 = interp_module!(module, 0,dyn_consts, 3); + let result_1 = interp_module!(module, 0, dyn_consts, 3); println!("result: {:?}", result_1); let schedule = Some(default_schedule![ ////Xdot,, - PhiElim, - ////Xdot,, + PhiElim, ////Xdot,, Forkify, //Xdot, ]); let module = run_schedule_on_hercules(module, schedule).unwrap(); - let result_2 = interp_module!(module, 0,dyn_consts, 3); + let result_2 = interp_module!(module, 0, dyn_consts, 3); println!("{:?}", result_1); println!("{:?}", result_2); @@ -194,13 +191,13 @@ fn alternate_bounds_nested_do_loop() { let dyn_consts = [10, 5]; let module = parse_file("../test_inputs/loop_analysis/alternate_bounds_nested_do_loop.hir"); - let result_1 = interp_module!(module, 0,dyn_consts, 3); + let result_1 = interp_module!(module, 0, dyn_consts, 3); println!("result: {:?}", result_1); let module = run_schedule_on_hercules(module, None).unwrap(); - let result_2 = interp_module!(module, 0,dyn_consts, 3); + let result_2 = interp_module!(module, 0, dyn_consts, 3); println!("{:?}", result_1); println!("{:?}", result_2); @@ -213,14 +210,15 @@ fn alternate_bounds_nested_do_loop_array() { let dyn_consts = [10, 5]; let a = vec![4u64, 4, 4, 4, 4]; - let module = parse_file("../test_inputs/loop_analysis/alternate_bounds_nested_do_loop_array.hir"); - let result_1 = interp_module!(module, 0,dyn_consts, a.clone()); + let module = + parse_file("../test_inputs/loop_analysis/alternate_bounds_nested_do_loop_array.hir"); + let result_1 = interp_module!(module, 0, dyn_consts, a.clone()); println!("result: {:?}", result_1); let module = run_schedule_on_hercules(module, None).unwrap(); - let result_2 = interp_module!(module, 0,dyn_consts, a); + let result_2 = interp_module!(module, 0, dyn_consts, a); println!("{:?}", result_1); println!("{:?}", result_2); @@ -232,14 +230,15 @@ fn alternate_bounds_nested_do_loop_guarded() { let len = 1; let dyn_consts = [3, 2]; - let module = parse_file("../test_inputs/loop_analysis/alternate_bounds_nested_do_loop_guarded.hir"); - let result_1 = interp_module!(module, 0,dyn_consts, 3); + let module = + parse_file("../test_inputs/loop_analysis/alternate_bounds_nested_do_loop_guarded.hir"); + let result_1 = interp_module!(module, 0, dyn_consts, 3); println!("result: {:?}", result_1); let module = run_schedule_on_hercules(module, None).unwrap(); - let result_2 = interp_module!(module, 0,dyn_consts, 3); + let result_2 = interp_module!(module, 0, dyn_consts, 3); println!("{:?}", result_1); println!("{:?}", result_2); @@ -249,16 +248,16 @@ fn alternate_bounds_nested_do_loop_guarded() { let module = run_schedule_on_hercules(module, 
None).unwrap(); - let result_2 = interp_module!(module, 0,dyn_consts, 3); + let result_2 = interp_module!(module, 0, dyn_consts, 3); println!("{:?}", result_1); println!("{:?}", result_2); assert_eq!(result_1, result_2); } -// Tests a do while loop that only iterates once, -// canonicalization *should not* transform this to a while loop, as there is no -// guard that replicates the loop condition. +// Tests a do while loop that only iterates once, +// canonicalization *should not* transform this to a while loop, as there is no +// guard that replicates the loop condition. #[ignore] #[test] fn do_loop_not_continued() { @@ -272,21 +271,21 @@ fn do_loop_not_continued() { // println!("result: {:?}", result_1); } -// Tests a do while loop that is guarded, so should be canonicalized -// It also has +// Tests a do while loop that is guarded, so should be canonicalized +// It also has #[test] fn do_loop_complex_immediate_guarded() { let len = 1; let dyn_consts = [len]; let module = parse_file("../test_inputs/loop_analysis/do_loop_immediate_guard.hir"); - let result_1 = interp_module!(module, 0,dyn_consts, 3); + let result_1 = interp_module!(module, 0, dyn_consts, 3); println!("result: {:?}", result_1); let module = run_schedule_on_hercules(module, None).unwrap(); - let result_2 = interp_module!(module, 0,dyn_consts, 3); + let result_2 = interp_module!(module, 0, dyn_consts, 3); assert_eq!(result_1, result_2); } @@ -298,12 +297,11 @@ fn loop_canonical_sum() { let params = vec![1, 2, 3, 4, 5]; let module = parse_file("../test_inputs/loop_analysis/loop_array_sum.hir"); - let result_1 = interp_module!(module, 0,dyn_consts, params); + let result_1 = interp_module!(module, 0, dyn_consts, params); println!("result: {:?}", result_1); } - #[test] #[ignore] fn antideps_pipeline() { @@ -312,13 +310,13 @@ fn antideps_pipeline() { // FIXME: This path should not leave the crate let module = parse_module_from_hbin("../../juno_samples/antideps/antideps.hbin"); - let result_1 = interp_module!(module, 0,dyn_consts, 9i32); + let result_1 = interp_module!(module, 0, dyn_consts, 9i32); println!("result: {:?}", result_1); let module = run_schedule_on_hercules(module, None).unwrap(); - let result_2 = interp_module!(module, 0,dyn_consts, 9i32); + let result_2 = interp_module!(module, 0, dyn_consts, 9i32); assert_eq!(result_1, result_2); } @@ -330,8 +328,8 @@ fn implicit_clone_pipeline() { // FIXME: This path should not leave the crate let module = parse_module_from_hbin("../../juno_samples/implicit_clone/out.hbin"); - let result_1 = interp_module!(module, 0,dyn_consts, 2u64, 2u64); - + let result_1 = interp_module!(module, 0, dyn_consts, 2u64, 2u64); + println!("result: {:?}", result_1); let schedule = default_schedule![ ////Xdot,, @@ -359,8 +357,8 @@ fn implicit_clone_pipeline() { GCM, ]; let module = run_schedule_on_hercules(module, Some(schedule)).unwrap(); - - let result_2 = interp_module!(module, 0,dyn_consts, 2u64, 2u64); + + let result_2 = interp_module!(module, 0, dyn_consts, 2u64, 2u64); assert_eq!(result_1, result_2); } @@ -382,7 +380,9 @@ fn look_at_local() { } } - let module = parse_module_from_hbin("/home/xavierrouth/dev/hercules/hercules_test/hercules_tests/save_me.hbin"); + let module = parse_module_from_hbin( + "/home/xavierrouth/dev/hercules/hercules_test/hercules_tests/save_me.hbin", + ); let schedule = Some(default_schedule![ ////Xdot,, @@ -394,15 +394,14 @@ fn look_at_local() { let schedule = Some(default_schedule![ ////Xdot,, - Unforkify, - Verify, + Unforkify, Verify, ////Xdot,, ]); - + let module = 
run_schedule_on_hercules(module.clone(), schedule).unwrap(); let result_2 = interp_module!(module, 0, dyn_consts, a.clone(), b.clone()); - + println!("golden: {:?}", correct_c); println!("result: {:?}", result_2); } @@ -410,19 +409,21 @@ fn look_at_local() { #[ignore] fn matmul_pipeline() { let len = 1; - + const I: usize = 4; const J: usize = 4; const K: usize = 4; let a: Vec<i32> = (0i32..(I * J) as i32).map(|v| v + 1).collect(); - let b: Vec<i32> = ((I * J) as i32..(J * K) as i32 + (I * J) as i32).map(|v| v + 1).collect(); + let b: Vec<i32> = ((I * J) as i32..(J * K) as i32 + (I * J) as i32) + .map(|v| v + 1) + .collect(); let a: Vec<i32> = (0..I * J).map(|_| random::<i32>() % 100).collect(); let b: Vec<i32> = (0..J * K).map(|_| random::<i32>() % 100).collect(); let dyn_consts = [I, J, K]; // FIXME: This path should not leave the crate let mut module = parse_module_from_hbin("../../juno_samples/matmul/out.hbin"); - // + // let mut correct_c: Box<[i32]> = (0..I * K).map(|_| 0).collect(); for i in 0..I { for k in 0..K { @@ -437,27 +438,22 @@ fn matmul_pipeline() { println!("golden: {:?}", correct_c); println!("result: {:?}", result_1); - let InterpreterVal::Array(_, d) = result_1.clone() else {panic!()}; - let InterpreterVal::Integer32(value) = d[0] else {panic!()}; + let InterpreterVal::Array(_, d) = result_1.clone() else { + panic!() + }; + let InterpreterVal::Integer32(value) = d[0] else { + panic!() + }; assert_eq!(correct_c[0], value); - let schedule = Some(default_schedule![ - ////Xdot,, - ForkSplit, - ////Xdot,, - ]); - + let schedule = Some(default_schedule![Xdot, ForkSplit, Unforkify, Xdot,]); + module = run_schedule_on_hercules(module, schedule).unwrap(); let result_2 = interp_module!(module, 1, dyn_consts, a.clone(), b.clone()); println!("result: {:?}", result_2); - assert_eq!(result_1, result_2); - - - - - + assert_eq!(result_1, result_2); // Verify, // GVN, @@ -473,4 +469,4 @@ fn matmul_pipeline() { // FloatCollections, // GCM, // //Xdot, -} \ No newline at end of file +} diff --git a/hercules_test/hercules_tests/tests/opt_tests.rs b/hercules_test/hercules_tests/tests/opt_tests.rs index f994f447d213e7afc44c5ebf76c3ec2cc52e95ca..2f85b78b6103d84d7aacb9f35e34b834aea56005 100644 --- a/hercules_test/hercules_tests/tests/opt_tests.rs +++ b/hercules_test/hercules_tests/tests/opt_tests.rs @@ -3,9 +3,8 @@ use std::env; use rand::Rng; use hercules_interpreter::*; -use juno_scheduler::*; use hercules_ir::ID; - +use juno_scheduler::*; // #[test] // fn matmul_int() { @@ -79,7 +78,7 @@ use hercules_ir::ID; // let x: i32 = rand::random(); // let x = x / 32; // let y: i32 = rand::random(); -// let y = y / 32; // prevent overflow, +// let y = y / 32; // prevent overflow, // let result_1 = interp_module!(module, 0, dyn_consts, x, y); // let mut pm = hercules_opt::pass::PassManager::new(module.clone()); @@ -147,7 +146,6 @@ use hercules_ir::ID; // let module = pm.get_module(); // let result_2 = interp_module!(module, 0, dyn_consts, vec); - // assert_eq!(result_1, result_2) // } @@ -192,8 +190,8 @@ use hercules_ir::ID; // #[test] // fn sum_int2_smaller() { -// interp_file_with_passes!("../test_inputs/sum_int2.hir", -// [100], +// interp_file_with_passes!("../test_inputs/sum_int2.hir", +// [100], // vec![ // Pass::Verify, // Pass::CCP, diff --git a/juno_samples/cava/src/main.rs b/juno_samples/cava/src/main.rs index 73a75a94f67edfab905c8c3830191dc342da337f..8ad6824f01d4d94a76f088ba1540ba44d2ce7b71 100644 --- a/juno_samples/cava/src/main.rs +++ b/juno_samples/cava/src/main.rs @@ -59,7 +59,10 @@ fn 
run_cava( tonemap, ) .await - }).as_slice::<u8>().to_vec().into_boxed_slice() + }) + .as_slice::<u8>() + .to_vec() + .into_boxed_slice() } enum Error { diff --git a/juno_samples/matmul/build.rs b/juno_samples/matmul/build.rs index c3ba785e352826362c99760779f1de9001be63d9..511bf483099ba78cc62754b39146517aa3623103 100644 --- a/juno_samples/matmul/build.rs +++ b/juno_samples/matmul/build.rs @@ -4,8 +4,8 @@ fn main() { JunoCompiler::new() .file_in_src("matmul.jn") .unwrap() - // .schedule_in_src("sched.sch") - // .unwrap() + //.schedule_in_src("sched.sch") + //.unwrap() .build() .unwrap(); } diff --git a/juno_samples/matmul/src/main.rs b/juno_samples/matmul/src/main.rs index 6d3b6624252d8d67e8dbb5fb2381e0c7bf4f9618..e40c429d757169a9ccd9d3abce3d52fe5899108f 100644 --- a/juno_samples/matmul/src/main.rs +++ b/juno_samples/matmul/src/main.rs @@ -24,10 +24,14 @@ fn main() { let a = HerculesCPURef::from_slice(&a); let b = HerculesCPURef::from_slice(&b); let mut r = runner!(matmul); - let c = r.run(I as u64, J as u64, K as u64, a.clone(), b.clone()).await; + let c = r + .run(I as u64, J as u64, K as u64, a.clone(), b.clone()) + .await; assert_eq!(c.as_slice::<i32>(), &*correct_c); let mut r = runner!(tiled_2_matmul); - let tiled_c = r.run(I as u64, J as u64, K as u64, a.clone(), b.clone()).await; + let tiled_c = r + .run(I as u64, J as u64, K as u64, a.clone(), b.clone()) + .await; assert_eq!(tiled_c.as_slice::<i32>(), &*correct_c); }); } @@ -36,4 +40,3 @@ fn main() { fn matmul_test() { main(); } - diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs index ee2d0bd684f46b3fce53f7f8e6dc1cd8e0adf2ba..0b3264ac26124c459aa61ac9454f1387ccda4990 100644 --- a/juno_scheduler/src/compile.rs +++ b/juno_scheduler/src/compile.rs @@ -105,7 +105,9 @@ impl FromStr for Appliable { "forkify" => Ok(Appliable::Pass(ir::Pass::Forkify)), "gcm" | "bbs" => Ok(Appliable::Pass(ir::Pass::GCM)), "gvn" => Ok(Appliable::Pass(ir::Pass::GVN)), - "loop-canon" | "loop-canonicalization" => Ok(Appliable::Pass(ir::Pass::LoopCanonicalization)), + "loop-canon" | "loop-canonicalization" => { + Ok(Appliable::Pass(ir::Pass::LoopCanonicalization)) + } "infer-schedules" => Ok(Appliable::Pass(ir::Pass::InferSchedules)), "inline" => Ok(Appliable::Pass(ir::Pass::Inline)), "ip-sroa" | "interprocedural-sroa" => { @@ -122,6 +124,7 @@ impl FromStr for Appliable { "verify" => Ok(Appliable::Pass(ir::Pass::Verify)), "xdot" => Ok(Appliable::Pass(ir::Pass::Xdot)), "serialize" => Ok(Appliable::Pass(ir::Pass::Serialize)), + "write-predication" => Ok(Appliable::Pass(ir::Pass::WritePredication)), "cpu" | "llvm" => Ok(Appliable::Device(Device::LLVM)), "gpu" | "cuda" | "nvidia" => Ok(Appliable::Device(Device::CUDA)), diff --git a/juno_scheduler/src/default.rs b/juno_scheduler/src/default.rs index 88d55b33b1e4073a5e4e0eb4e591fd3317801530..fd45a3713dfa48ac425a0c453707e133198e1d79 100644 --- a/juno_scheduler/src/default.rs +++ b/juno_scheduler/src/default.rs @@ -66,8 +66,9 @@ pub fn default_schedule() -> ScheduleStmt { DCE, GVN, DCE, - /*Forkify,*/ - /*ForkGuardElim,*/ + // Forkify, + // ForkGuardElim, + // ForkCoalesce, DCE, ForkSplit, Unforkify, @@ -83,6 +84,5 @@ pub fn default_schedule() -> ScheduleStmt { DCE, FloatCollections, GCM, - ] } diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 9c705c1cf83a254e74ff163e2b83edc50811b578..33a7b4807e3d0e3a6ec425f4a15951a7086138cd 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -512,7 +512,7 @@ impl PassManager { typing: _, control_subgraphs: _, bbs: _, - 
collection_objects:_, + collection_objects: _, callgraph: _, .. } = self; @@ -1299,17 +1299,17 @@ fn run_pass( let output_file = "out.hbin"; let module = pm.clone().get_module().clone(); let module_contents: Vec<u8> = postcard::to_allocvec(&module).unwrap(); - let mut file = File::create(&output_file) - .expect("PANIC: Unable to open output module file."); + let mut file = + File::create(&output_file).expect("PANIC: Unable to open output module file."); file.write_all(&module_contents) .expect("PANIC: Unable to write output module file contents."); } Pass::ForkSplit => { assert!(args.is_empty()); // FIXME: I'm not sure if this is the correct way to build fixpoint into the PM, - // i.e cloning selection. Does something need to be done to propagate labels between iterations + // i.e cloning selection. Does something need to be done to propagate labels between iterations // of this loop? - + loop { let mut inner_changed = false; pm.make_fork_join_maps(); @@ -1332,7 +1332,6 @@ fn run_pass( pm.clear_analyses(); if !inner_changed { - break; } } @@ -1345,11 +1344,12 @@ fn run_pass( let fork_join_maps = pm.fork_join_maps.take().unwrap(); let loops = pm.loops.take().unwrap(); let control_subgraphs = pm.control_subgraphs.take().unwrap(); - for (((func, fork_join_map), loop_nest), control_subgraph) in build_selection(pm, selection) - .into_iter() - .zip(fork_join_maps.iter()) - .zip(loops.iter()) - .zip(control_subgraphs.iter()) + for (((func, fork_join_map), loop_nest), control_subgraph) in + build_selection(pm, selection) + .into_iter() + .zip(fork_join_maps.iter()) + .zip(loops.iter()) + .zip(control_subgraphs.iter()) { let Some(mut func) = func else { continue; @@ -1700,11 +1700,12 @@ fn run_pass( let fork_join_maps = pm.fork_join_maps.take().unwrap(); let loops = pm.loops.take().unwrap(); let control_subgraphs = pm.control_subgraphs.take().unwrap(); - for (((func, fork_join_map), loop_nest), control_subgraph) in build_selection(pm, selection) - .into_iter() - .zip(fork_join_maps.iter()) - .zip(loops.iter()) - .zip(control_subgraphs.iter()) + for (((func, fork_join_map), loop_nest), control_subgraph) in + build_selection(pm, selection) + .into_iter() + .zip(fork_join_maps.iter()) + .zip(loops.iter()) + .zip(control_subgraphs.iter()) { let Some(mut func) = func else { continue; @@ -1714,7 +1715,7 @@ fn run_pass( } pm.delete_gravestones(); pm.clear_analyses(); - }, + } Pass::WritePredication => { assert!(args.is_empty()); for func in build_selection(pm, selection) { @@ -1794,12 +1795,13 @@ fn run_pass( let loops = pm.loops.take().unwrap(); let control_subgraphs = pm.control_subgraphs.take().unwrap(); let typing = pm.typing.take().unwrap(); - for ((((func, fork_join_map), loop_nest), control_subgraph), typing) in build_selection(pm, selection) - .into_iter() - .zip(fork_join_maps.iter()) - .zip(loops.iter()) - .zip(control_subgraphs.iter()) - .zip(typing.iter()) + for ((((func, fork_join_map), loop_nest), control_subgraph), typing) in + build_selection(pm, selection) + .into_iter() + .zip(fork_join_maps.iter()) + .zip(loops.iter()) + .zip(control_subgraphs.iter()) + .zip(typing.iter()) { let Some(mut func) = func else { continue;