diff --git a/hercules_opt/src/editor.rs b/hercules_opt/src/editor.rs
index f6a00c8582b8289062d353bc06a5b65e32daac19..e6db74596fbb85f03333742dc1e8e035a6a9f821 100644
--- a/hercules_opt/src/editor.rs
+++ b/hercules_opt/src/editor.rs
@@ -4,8 +4,6 @@
 use std::collections::{BTreeMap, HashMap, HashSet};
 use std::mem::take;
 use std::ops::Deref;
-use nestify::nest;
-
 use bitvec::prelude::*;
 use either::Either;
 
@@ -156,10 +154,6 @@ impl<'a: 'b, 'b> FunctionEditor<'a> {
         self.modified
     }
 
-    pub fn node(&self, node: impl Borrow<NodeID>) -> &Node {
-        &self.function.nodes[node.borrow().idx()]
-    }
-
     pub fn edit<F>(&'b mut self, edit: F) -> bool
     where
         F: FnOnce(FunctionEdit<'a, 'b>) -> Result<FunctionEdit<'a, 'b>, FunctionEdit<'a, 'b>>,
@@ -342,6 +336,10 @@ impl<'a: 'b, 'b> FunctionEditor<'a> {
         self.function_id
     }
 
+    pub fn node(&self, node: impl Borrow<NodeID>) -> &Node {
+        &self.function.nodes[node.borrow().idx()]
+    }
+
     pub fn get_types(&self) -> Ref<'_, Vec<Type>> {
         self.types.borrow()
     }
@@ -363,7 +361,7 @@ impl<'a: 'b, 'b> FunctionEditor<'a> {
             .as_ref()
            .into_iter()
             .map(|x| *x)
-            .collect::<Vec<_>>() // @(xrouth): wtf???
+            .collect::<Vec<_>>()
             .into_iter()
     }
 
diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index edf26911aa4a43854842d986d794638dacdc7d5a..5a6d5ff26d34effab7c3e581b8c767a75851937f 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -107,7 +107,7 @@ pub fn find_reduce_dependencies<'a>(
         })
         .collect();
 
-    ret_val 
+    ret_val
 }
 
 pub fn copy_subgraph(
@@ -119,7 +119,7 @@ pub fn copy_subgraph(
     Vec<(NodeID, NodeID)>,
 ) // returns all new nodes, a map from old nodes to new nodes, and
   // a vec of pairs of nodes (old node, outside node) s.t old node -> outside node,
-  // outside means not part of the original subgraph. 
+  // outside means not part of the original subgraph.
 {
     let mut map: HashMap<NodeID, NodeID> = HashMap::new();
     let mut new_nodes: HashSet<NodeID> = HashSet::new();
@@ -395,7 +395,7 @@ pub fn fork_coalesce(
     });
 
     let fork_joins: Vec<_> = fork_joins.collect();
-    // FIXME: Add a postorder traversal to optimize this. 
+    // FIXME: Add a postorder traversal to optimize this.
     // FIXME: This could give us two forks that aren't actually ancestors / related, but then the helper will just return false early.
     // something like: `fork_joins.postorder_iter().windows(2)` is ideal here.
 
@@ -486,7 +486,7 @@ pub fn fork_coalesce_helper(
         return false;
     }
 
-    // Checklist: 
+    // Checklist:
     // Increment inner TIDs
     // Add outer fork's dimension to front of inner fork.
     // Fuse reductions
diff --git a/hercules_opt/src/forkify.rs b/hercules_opt/src/forkify.rs
index 10a8fe215985f6c07c791e846984846defd43cf4..fd4fc838bae662c2d411ec804198b483bc183bb8 100644
--- a/hercules_opt/src/forkify.rs
+++ b/hercules_opt/src/forkify.rs
@@ -241,35 +241,52 @@ pub fn forkify_loop(
         return false;
     }
 
-    let phi_latches: Vec<_> = reductionable_phis.iter().map(|phi| {
-        let LoopPHI::Reductionable { phi, data_cycle, continue_latch, is_associative } = phi else {unreachable!()};
-        continue_latch
-    }).collect();
+    let phi_latches: Vec<_> = reductionable_phis
+        .iter()
+        .map(|phi| {
+            let LoopPHI::Reductionable {
+                phi,
+                data_cycle,
+                continue_latch,
+                is_associative,
+            } = phi
+            else {
+                unreachable!()
+            };
+            continue_latch
+        })
+        .collect();
 
-    let stop_on: HashSet<_> = editor.node_ids().filter(|node| {
-        if editor.node(node).is_phi() {
-            return true;
-        }
-        if editor.node(node).is_reduce() {
-            return true;
-        }
-        if editor.node(node).is_control() {
-            return true;
-        }
-        if phi_latches.contains(&node) {
-            return true;
-        }
+    let stop_on: HashSet<_> = editor
+        .node_ids()
+        .filter(|node| {
+            if editor.node(node).is_phi() {
+                return true;
+            }
+            if editor.node(node).is_reduce() {
+                return true;
+            }
+            if editor.node(node).is_control() {
+                return true;
+            }
+            if phi_latches.contains(&node) {
+                return true;
+            }
+
+            false
+        })
+        .collect();
 
-        false
-    }).collect();
-    
-    // Outside loop users of IV, then exit; 
-    // Unless the outside user is through the loop latch of a reducing phi, 
+    // Unless the outside user is through the loop latch of a reducing phi,
     // then we know how to replace this edge, so its fine!
 
-    let iv_users: Vec<_> = walk_all_users_stop_on(canonical_iv.phi(), editor, stop_on.clone()).collect();
-    
-    if iv_users.iter().any(|node| !loop_nodes.contains(&node) && *node != loop_if) {
+    let iv_users: Vec<_> =
+        walk_all_users_stop_on(canonical_iv.phi(), editor, stop_on.clone()).collect();
+
+    if iv_users
+        .iter()
+        .any(|node| !loop_nodes.contains(&node) && *node != loop_if)
+    {
         return false;
     }
 
@@ -429,9 +446,9 @@ impl LoopPHI {
 
 /**
 Checks some conditions on loop variables that will need to be converted into reductions to be forkified.
- - The phi is in a cycle *in the loop* with itself. 
+ - The phi is in a cycle *in the loop* with itself.
  - Every cycle *in the loop* containing the phi does not contain any other phi of the loop header.
- - The phi does not immediatley (not blocked by another phi or another reduce) use any other phis of the loop header. 
+ - The phi does not immediately (not blocked by another phi or another reduce) use any other phis of the loop header.
 */
 pub fn analyze_phis<'a>(
     editor: &'a FunctionEditor,
@@ -473,7 +490,7 @@ pub fn analyze_phis<'a>(
                 return false;
             })
             .collect();
-        
+
         let continue_idx = editor
             .get_uses(natural_loop.header)
             .position(|node| natural_loop.control[node.idx()])
@@ -512,10 +529,9 @@ pub fn analyze_phis<'a>(
                 return false;
            })
             .collect();
 
-        
-        
-        let mut uses_for_dependance = walk_all_users_stop_on(loop_continue_latch, editor, other_stop_on);
+        let mut uses_for_dependance =
+            walk_all_users_stop_on(loop_continue_latch, editor, other_stop_on);
 
         let set1: HashSet<_> = HashSet::from_iter(uses);
         let set2: HashSet<_> = HashSet::from_iter(users);
@@ -526,19 +542,16 @@ pub fn analyze_phis<'a>(
         // we use `phis` because this phi can actually contain the loop iv and its fine.
         if uses_for_dependance.any(|node| phis.contains(&node) && node != *phi) {
             LoopPHI::LoopDependant(*phi)
-        }
-        else if intersection.clone().iter().any(|node| true) {
-
-
+        } else if intersection.clone().iter().any(|node| true) {
             // PHIs on the frontier of the uses by the candidate phi, i.e in uses_for_dependance need
             // to have headers that postdominate the loop continue latch. The value of the PHI used needs to be defined
             // by the time the reduce is triggered (at the end of the loop's internal control).
 
             // No nodes in data cycles with this phi (in the loop) are used outside the loop, besides the loop_continue_latch.
-            // If some other node in the cycle is used, there is not a valid node to assign it after making the cycle a reduce. 
+            // If some other node in the cycle is used, there is not a valid node to assign it after making the cycle a reduce.
             if intersection
                 .iter()
-                .filter(|node| **node != loop_continue_latch )
+                .filter(|node| **node != loop_continue_latch)
                 .filter(|node| !(editor.node(*node).is_reduce() || editor.node(*node).is_phi()))
                 .any(|data_node| {
                     editor
@@ -553,8 +566,8 @@ pub fn analyze_phis<'a>(
                 return LoopPHI::LoopDependant(*phi);
             }
 
-            // FIXME: Do we want to calculate associativity here, there might be a case where this information is used in forkify
-            // i.e as described above. 
+            // FIXME: Do we want to calculate associativity here, there might be a case where this information is used in forkify
+            // i.e as described above.
             let is_associative = false;
 
             // No nodes in the data cycle are used outside of the loop, besides the latched value of the phi
diff --git a/hercules_opt/src/ivar.rs b/hercules_opt/src/ivar.rs
index 1f31e22088f170e2726241fc8796a3b496e81af3..15f9416c1a203b7cbd55b5d1515538e41714cb2c 100644
--- a/hercules_opt/src/ivar.rs
+++ b/hercules_opt/src/ivar.rs
@@ -139,7 +139,11 @@ pub fn calculate_loop_nodes(editor: &FunctionEditor, natural_loop: &Loop) -> Has
         })
         .collect();
 
-    all_users.intersection(&all_uses).chain(phis.iter()).cloned().collect()
+    all_users
+        .intersection(&all_uses)
+        .chain(phis.iter())
+        .cloned()
+        .collect()
 }
 /** returns PHIs that are on any regions inside the loop.
 */
diff --git a/hercules_test/hercules_tests/tests/fork_transform_tests.rs b/hercules_test/hercules_tests/tests/fork_transform_tests.rs
index 432fdda029e0b1fec52cd20857430df9ddd5387d..3799ca0ac7e8abe9907603269692fbd438c4e33d 100644
--- a/hercules_test/hercules_tests/tests/fork_transform_tests.rs
+++ b/hercules_test/hercules_tests/tests/fork_transform_tests.rs
@@ -18,7 +18,7 @@ fn fission_simple1() {
     println!("result: {:?}", result_1);
 
     let sched = Some(default_schedule![
-        Verify, //Xdot, 
+        Verify, //Xdot,
         Unforkify, //Xdot,
         DCE, Verify,
     ]);
diff --git a/juno_samples/matmul/src/main.rs b/juno_samples/matmul/src/main.rs
index fa5d1f04d48cdf48cf377e8f3d08de80d30e688e..624ee5652a78d9c2ab7bc84d3974bf2df5b02838 100644
--- a/juno_samples/matmul/src/main.rs
+++ b/juno_samples/matmul/src/main.rs
@@ -24,10 +24,14 @@ fn main() {
         let a = HerculesCPURef::from_slice(&a);
         let b = HerculesCPURef::from_slice(&b);
         let mut r = runner!(matmul);
-        let c = r.run(I as u64, J as u64, K as u64, a.clone(), b.clone()).await;
+        let c = r
+            .run(I as u64, J as u64, K as u64, a.clone(), b.clone())
+            .await;
         assert_eq!(c.as_slice::<i32>(), &*correct_c);
         let mut r = runner!(tiled_64_matmul);
-        let tiled_c = r.run(I as u64, J as u64, K as u64, a.clone(), b.clone()).await;
+        let tiled_c = r
+            .run(I as u64, J as u64, K as u64, a.clone(), b.clone())
+            .await;
         assert_eq!(tiled_c.as_slice::<i32>(), &*correct_c);
     });
 }
@@ -36,4 +40,3 @@ fn main() {
 fn matmul_test() {
     main();
 }
-