From b2d0899df264c2081a979798311877bd70c81632 Mon Sep 17 00:00:00 2001 From: Xavier Routh <xrouth2@illinois.edu> Date: Thu, 30 Jan 2025 00:53:03 -0600 Subject: [PATCH] forkify iv use condition refined --- hercules_opt/src/fork_transforms.rs | 47 +--- hercules_opt/src/forkify.rs | 136 ++++-------- hercules_opt/src/ivar.rs | 205 +----------------- .../hercules_interpreter/src/interpreter.rs | 3 - juno_samples/matmul/src/main.rs | 17 +- juno_samples/matmul/src/matmul.jn | 38 ++-- juno_samples/matmul/src/sched.sch | 76 +++++++ 7 files changed, 167 insertions(+), 355 deletions(-) create mode 100644 juno_samples/matmul/src/sched.sch diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index 14145f57..c0196ca0 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -97,7 +97,7 @@ pub fn find_reduce_dependencies<'a>( recurse(function, reduce, fork, &mut depdendent, &mut visited); // Return node IDs that are dependent - let a: Vec<_> = depdendent + let ret_val: Vec<_> = depdendent .iter() .enumerate() .filter_map(|(idx, dependent)| { @@ -109,7 +109,7 @@ pub fn find_reduce_dependencies<'a>( }) .collect(); - a + ret_val } pub fn copy_subgraph( @@ -119,7 +119,9 @@ pub fn copy_subgraph( HashSet<NodeID>, HashMap<NodeID, NodeID>, Vec<(NodeID, NodeID)>, -) // set of all nodes, set of new nodes, outside node. s.t old node-> outside node exists as an edge. +) // returns all new nodes, a map from old nodes to new nodes, and + // a vec of pairs of nodes (old node, outside node) s.t old node -> outside node, + // outside means not part of the original subgraph. { let mut map: HashMap<NodeID, NodeID> = HashMap::new(); let mut new_nodes: HashSet<NodeID> = HashSet::new(); @@ -314,25 +316,9 @@ pub fn fork_reduce_fission_helper<'a>( fork: NodeID, ) -> (NodeID, NodeID) { - // returns Fork, Join pair { - let join = fork_join_map[&fork]; - // If there is control in between then j give up. let mut new_control_pred: NodeID = original_control_pred; - - // Get nodes to copy - // let factors: Box<[DynamicConstantID]> = edit..nodes[fork.idx()].try_fork().unwrap().1.into(); - - // None of this matters, just assume we have DCE for control flow. - // Make new fork put it after the existing loop (deal with dependencies later.) - // Make new join, put it after fork (FIXME: THIS IS WRONG) - // Make copies of all control + data nodes, including the reduce and join, with equivalent uses / users, mark them as NEW - // - Need an editor utility to copy a subsection of the graph. - // 1) Edges going into the subsection stay the same, i.e something new still *uses* something old. - // 2) Edges leaving the subsection need to be handled by the user, (can't force outgoing new edges into nodes) - // return a list of outgoing (but unattatached) edges + the old destination to the programmer. - // Important edges are: Reduces, // NOTE: @@ -341,17 +327,6 @@ pub fn fork_reduce_fission_helper<'a>( // - we can simply refuse // - or we can duplicate B - // OR we can allow reduces to end up in multiple forks, (no restrictions on the reduce->fork mapping function). - // And complain when user doesn't put them in the same fork correctly. - // for now, DONT HANDLE IT. LOL. - - // NOTE: - // - - // Replace all - // Replace all uses of (fork, reduce, ) w/ predicate that they are the newly copied nodes. 
- // repalce uses - let mut new_fork = NodeID::new(0); let mut new_join = NodeID::new(0); @@ -422,10 +397,10 @@ pub fn fork_coalesce( }); let fork_joins: Vec<_> = fork_joins.collect(); - // FIXME: postorder traversal. + // FIXME: Add a postorder traversal to optimize this. - // Fixme: This could give us two forks that aren't actually ancestors / related, but then the helper will just retunr false early. - //for (inner, outer) in fork_joins.windows(2) { + // FIXME: This could give us two forks that aren't actually ancestors / related, but then the helper will just return false early. + // something like: `fork_joins.postorder_iter().windows(2)` is ideal here. for (inner, outer) in fork_joins.iter().cartesian_product(fork_joins.iter()) { if fork_coalesce_helper(editor, *outer, *inner, fork_join_map) { return true; @@ -513,11 +488,11 @@ pub fn fork_coalesce_helper( return false; } + // Checklist: // Increment inner TIDs - // Add outers dimension to front of inner fork. + // Add outer fork's dimension to front of inner fork. // Fuse reductions // - Initializer becomes outer initializer - // - // Replace uses of outer fork w/ inner fork. // Replace uses of outer join w/ inner join. // Delete outer fork-join @@ -532,7 +507,7 @@ pub fn fork_coalesce_helper( let num_outer_dims = outer_dims.len(); let mut new_factors = outer_dims.to_vec(); - // CHECK ME: Might need to be added the other way. + // CHECKME / FIXME: Might need to be added the other way. new_factors.append(&mut inner_dims.to_vec()); for tid in inner_tids { diff --git a/hercules_opt/src/forkify.rs b/hercules_opt/src/forkify.rs index c7acfe6b..abd0aaca 100644 --- a/hercules_opt/src/forkify.rs +++ b/hercules_opt/src/forkify.rs @@ -2,6 +2,7 @@ extern crate bitvec; extern crate hercules_ir; extern crate nestify; +use core::panic; use std::collections::HashMap; use std::collections::HashSet; use std::iter::zip; @@ -26,7 +27,6 @@ use crate::walk_all_users; use crate::walk_all_users_stop_on; use crate::walk_all_uses; use crate::walk_all_uses_stop_on; -use crate::BasicInductionVariable; use crate::DenseNodeMap; use crate::FunctionEditor; use crate::InductionVariable; @@ -212,7 +212,7 @@ pub fn forkify_loop( // we currently have. let loop_nodes = calculate_loop_nodes(editor, l); - // // Check reductionable phis, only PHIs depending on the loop are considered, + // Check phis to see if they are reductionable, only PHIs depending on the loop are considered, let candidate_phis: Vec<_> = editor .get_users(l.header) .filter(|id| function.nodes[id.idx()].is_phi()) @@ -223,21 +223,9 @@ pub fn forkify_loop( .into_iter() .collect(); - // START EDITING - - // What we do is: - // 1) Find a (the?) basic induction variable, create a ThreadID + Fork + Join over it. - // 2) Turn reductionable PHIs into reduces (including the redcutionable PHI) - // - a) If the PHI is the IV: - // Uses of the IV become: - // 1) Inside the loop: Uses of the ThreadID - // 2) Outside the loop: Uses of the reduction node. - // - b) if the PHI is not the IV: - // Make it a reduce - let function = editor.func(); - // TOOD: Handle multiple loop body lasts. + // TODO: Handle multiple loop body lasts. // If there are multiple candidates for loop body last, return false. if editor .get_uses(loop_if) @@ -257,23 +245,41 @@ pub fn forkify_loop( return false; } - // 1) If there is any control between header and loop condition, exit. 
- let header_control_users: Vec<_> = editor - .get_users(l.header) - .filter(|id| function.nodes[id.idx()].is_control()) - .collect(); + let phi_latches: Vec<_> = reductionable_phis.iter().map(|phi| { + let LoopPHI::Reductionable { phi, data_cycle, continue_latch, is_associative } = phi else {unreachable!()}; + continue_latch + }).collect(); - // Outside uses of IV, then exit; - if editor - .get_users(canonical_iv.phi()) - .any(|node| !loop_nodes.contains(&node)) - { + let stop_on: HashSet<_> = editor.node_ids().filter(|node| { + if editor.node(node).is_phi() { + return true; + } + if editor.node(node).is_reduce() { + return true; + } + if editor.node(node).is_control() { + return true; + } + if phi_latches.contains(&node) { + return true; + } + + false + }).collect(); + + + // If the IV has users outside the loop, then exit, + // unless the outside user is through the loop latch of a reducing phi; + // in that case we know how to replace the edge, so it's fine. + let iv_users: Vec<_> = walk_all_users_stop_on(canonical_iv.phi(), editor, stop_on.clone()).collect(); + + if iv_users.iter().any(|node| !loop_nodes.contains(&node) && *node != loop_if) { return false; } // Start Transformation: - // Graft everyhting between header and loop condition + // Graft everything between header and loop condition // Attach join to right before header (after loop_body_last, unless loop body last *is* the header). // Attach fork to right after loop_continue_projection. @@ -285,7 +291,7 @@ pub fn forkify_loop( let bound_dc_id = { let mut max_id = DynamicConstantID::new(0); editor.edit(|mut edit| { - // FIXME: Maybe add dynamic constant should intern? + // FIXME: Maybe add_dynamic_constant should intern? let one_id = edit.add_dynamic_constant(DynamicConstant::Constant(1)); max_id = edit.add_dynamic_constant(DynamicConstant::Max(one_id, bound_dc_id)); Ok(edit) @@ -293,7 +299,7 @@ pub fn forkify_loop( max_id }; - // // FIXME (@xrouth), handle control in loop body. + // FIXME: (@xrouth) double check handling of control in loop body. editor.edit(|mut edit| { let fork = Node::Fork { control: loop_pred, @@ -314,21 +320,6 @@ pub fn forkify_loop( Ok(edit) }); - // let function = editor.func(); - - // let update = *zip( - // editor.get_uses(l.header), - // function.nodes[canonical_iv.phi().idx()] - // .try_phi() - // .unwrap() - // .1 - // .iter(), - // ) - // .filter(|(c, _)| *c == loop_body_last) - // .next() - // .unwrap() - // .1; - let function = editor.func(); let (_, factors) = function.nodes[fork_id.idx()].try_fork().unwrap(); let dimension = factors.len() - 1; @@ -341,15 +332,6 @@ pub fn forkify_loop( }; let thread_id_id = edit.add_node(thread_id); - // let iv_reduce = Node::Reduce { - // control: join_id, - // init: basic_iv.initializer, - // reduct: update, - // }; - - // If a user occurs after the loop is finished, we replace it with the DC that is the IV bound, - // If a user occurs inside the loop, we replace it with the IV. - - // Replace uses that are inside with the thread id edit = edit.replace_all_uses_where(canonical_iv.phi(), thread_id_id, |node| { loop_nodes.contains(node) @@ -372,7 +354,7 @@ pub fn forkify_loop( is_associative, } = reduction_phi else { - continue; + panic!(); }; let function = editor.func(); @@ -451,11 +433,10 @@ impl LoopPHI { /** Checks some conditions on loop variables that will need to be converted into reductions to be forkified. - To convert a phi into a reduce we need to check that every cycle containing the PHI does not contain any other PHI.
-I think this restriction can be loosened (more specified) - - Every cycle *in the loop* containing the PHI does not contain any other PHI. Or something, IDK. - - -We also need to make it not control dependent on anything other than the loop header. */ + - The phi is in a cycle *in the loop* with itself. + - Every cycle *in the loop* containing the phi does not contain any other phi of the loop header. + - The phi does not immediately (not blocked by another phi or another reduce) use any other phis of the loop header. + */ pub fn analyze_phis<'a>( editor: &'a FunctionEditor, natural_loop: &'a Loop, @@ -473,9 +454,6 @@ pub fn analyze_phis<'a>( if *control != natural_loop.header { return true; } - // if !natural_loop.control[control.idx()] { - // return true; - // } } // External Reduce if let Node::Reduce { @@ -491,9 +469,8 @@ pub fn analyze_phis<'a>( } } - // External Control + // Data Cycles Only if data.is_control() { - //&& !natural_loop.control[node.idx()] { return true; } @@ -503,11 +480,6 @@ pub fn analyze_phis<'a>( // TODO: We may need to stop on exiting the loop for looking for data cycles. let uses = walk_all_uses_stop_on(*phi, editor, stop_on.clone()); - // .filter(|node| - // { - // // Get rid of nodes in stop_on - // !stop_on.contains(node) - // }); let users = walk_all_users_stop_on(*phi, editor, stop_on.clone()); let other_stop_on: HashSet<NodeID> = editor @@ -531,7 +503,6 @@ pub fn analyze_phis<'a>( // External Control if data.is_control() { - //&& !natural_loop.control[node.idx()] { return true; } @@ -551,11 +522,6 @@ pub fn analyze_phis<'a>( if uses_for_dependance.any(|node| phis.contains(&node) && node != *phi) { LoopPHI::LoopDependant(*phi) } - // // If this phi is used by other phis in the loop, FIXME: include reduces, they are the same as phis right? - // // DOn't go through nodes that would become a reduction. - // else if set2.clone().iter().any(|node| phis.contains(node) && node != phi ) { - // LoopPHI::UsedByDependant(*phi) - // } else if intersection.clone().iter().any(|node| true) { let continue_idx = editor .get_uses(natural_loop.header) @@ -564,16 +530,12 @@ pub fn analyze_phis<'a>( let loop_continue_latch = editor.node(phi).try_phi().unwrap().1[continue_idx]; - // Phis on the frontier of the intersection, i.e in uses_for_dependance need - // to have headers + // PHIs on the frontier of the candidate phi's uses, i.e. those in uses_for_dependance, need + // to have headers that postdominate the loop continue latch. The value of the PHI used needs to be defined + // by the time the reduce is triggered (at the end of the loop's internal control). - // FIXME: Need to postdominate the loop continue latch - // The phi's region needs to postdominate all PHI / Reduceses (that are in the control of the loop, i.e that or uses of the loop_continue_latch) - // that it uses, not going through phis / reduces, - // - - // let uses = // No nodes in data cycles with this phi (in the loop) are used outside the loop, besides the loop_continue_latch. + // If some other node in the cycle is used, there is not a valid node to assign it after making the cycle a reduce.
if intersection .iter() .filter(|node| **node != loop_continue_latch) @@ -590,14 +552,8 @@ pub fn analyze_phis<'a>( return LoopPHI::LoopDependant(*phi); } - // if tehre are separate types of ops, or any non associative ops, then its not associative - - // Extract ops - // let is_associative = intersection.iter().filter_map(|node| match editor.node(node) { - // Node::Unary { input, op } => todo!(), - // Node::Binary { left, right, op } => todo!(), - // Node::Ternary { first, second, third, op } => todo!(), - // }); + // FIXME: Do we want to calculate associativity here, there might be a case where this information is used in forkify + // i.e as described above. let is_associative = false; // No nodes in the data cycle are used outside of the loop, besides the latched value of the phi diff --git a/hercules_opt/src/ivar.rs b/hercules_opt/src/ivar.rs index 7f76b0f5..bde3bde3 100644 --- a/hercules_opt/src/ivar.rs +++ b/hercules_opt/src/ivar.rs @@ -25,12 +25,7 @@ use self::hercules_ir::ir::*; use crate::*; -/** - * This represents induction vairable analysis, to be used by forkify! - */ -/* ASIDE: (@xrouth) I want a word for something that can be 'queried', but doesn't reveal anything about the underlying data structure, -single loop only... */ #[derive(Debug)] pub struct LoopVarianceInfo { @@ -60,19 +55,6 @@ impl Loop { all_loop_nodes } } -nest! { -/** Represents a basic induction variable. - NOTE (@xrouth): May switch to using SCEV to represent induction vairables, for now we assume only basic induction variables - with a constant update (not even linear). Eventually add dynamic constant updates, and linear updates - */ -#[derive(Clone, Copy, Debug, PartialEq)] -pub struct BasicInductionVariable { - pub node: NodeID, - pub initializer: NodeID, - pub update: NodeID, - pub final_value: Option<NodeID>, -} -} // nest nest! { #[derive(Clone, Copy, Debug, PartialEq)]* @@ -83,9 +65,7 @@ nest! { update: NodeID, final_value: Option<NodeID>, }, - SCEV(NodeID), - //ScevAdd(NodeID, NodeID), - // ScevMul(NodeID, NodeID), + SCEV(NodeID), // TODO @(xrouth) } } @@ -101,30 +81,8 @@ impl InductionVariable { InductionVariable::SCEV(_) => todo!(), } } - - // Editor has become just a 'context' that everything needs. This is similar to how analyses / passes are structured, - // but editor forces recomputation / bookkeeping of simple / more commonly used info (even though it really is just def use, constants, dyn_constants) - // While if the pass writer wants more complicated info, like analyses results, they have to thread it through the pass manager. - // This seems fine. - // pub fn update_i64(&self, editor: &FunctionEditor) -> Option<i64> { - // match self { - // InductionVariable::Basic { node, initializer, update, final_value } => { - // match editor.node(update) { - // Node::Constant {id } => match *editor.get_constant(*id) { - // Constant::UnsignedInteger64(v) => v.try_into().ok(), - // _ => None, - // }, - // _ => None, - // } - // }, - // InductionVariable::SCEV(node_id) => todo!(), - // } - // } - - // It would be nice for functions, as they (kinda) automatically capture 'self' to also automatically capture a 'context' that is in the same scope, - // so I don't have to keep passing a context into every function that needs one. - // } + // TODO: Optimize. pub fn calculate_loop_nodes(editor: &FunctionEditor, natural_loop: &Loop) -> HashSet<NodeID> { // Stop on PHIs / reduces outside of loop. 
@@ -170,11 +128,6 @@ pub fn calculate_loop_nodes(editor: &FunctionEditor, natural_loop: &Loop) -> Has }) .collect(); - // let all_users: HashSet<_> = editor.get_users(natural_loop.header).filter(|node| editor.func().nodes[node.idx()].is_phi()) - // .flat_map(|phi| walk_all_users_stop_on(phi, editor, stop_on.clone())) - // .chain(editor.get_users(natural_loop.header).filter(|node| editor.func().nodes[node.idx()].is_phi())) - // .collect(); - let all_users: HashSet<NodeID> = phis .clone() .iter() @@ -186,26 +139,17 @@ pub fn calculate_loop_nodes(editor: &FunctionEditor, natural_loop: &Loop) -> Has .clone() .iter() .flat_map(|phi| walk_all_uses_stop_on(*phi, editor, stop_on.clone())) - .chain(phis) + .chain(phis.clone()) .filter(|node| { // Get rid of nodes in stop_on !stop_on.contains(node) }) .collect(); - // let all_uses: HashSet<_> = editor.get_users(natural_loop.header).filter(|node| editor.func().nodes[node.idx()].is_phi()) - // .flat_map(|phi| walk_all_uses_stop_on(phi, editor, stop_on.clone())) - // .chain(editor.get_users(natural_loop.header).filter(|node| editor.func().nodes[node.idx()].is_phi())) - // .filter(|node| - // { - // // Get rid of nodes in stop_on - // !stop_on.contains(node) - // }) - // .collect(); - - all_users.intersection(&all_uses).cloned().collect() + + all_users.intersection(&all_uses).chain(phis.iter()).cloned().collect() } -/** returns PHIs that are *in* a loop */ +/** returns PHIs that are on any regions inside the loop. */ pub fn get_all_loop_phis<'a>( function: &'a Function, l: &'a Loop, @@ -323,7 +267,7 @@ pub enum LoopExit { if_node: NodeID, condition_node: NodeID, }, - Unconditional(NodeID) // Probably a region. + Unconditional(NodeID) } } @@ -335,6 +279,7 @@ pub fn get_loop_exit_conditions( // impl IntoIterator<Item = LoopExit> // DFS Traversal on loop control subgraph until we find a node that is outside the loop, find the last IF on this path. let mut last_if_on_path: DenseNodeMap<Option<NodeID>> = vec![None; function.nodes.len()]; + // FIXME: (@xrouth) THIS IS MOST CERTAINLY BUGGED // this might be bugged... i.e might need to udpate `last if` even if already defined. // needs to be `saturating` kinda, more iterative. May need to visit nodes more than once? @@ -380,140 +325,6 @@ pub fn get_loop_exit_conditions( }) } -pub fn match_canonicalization_bound( - editor: &mut FunctionEditor, - natural_loop: &Loop, - loop_condition: NodeID, - loop_if: NodeID, - ivar: BasicInductionVariable, -) -> Option<NodeID> { - // Match for code generated by loop canon - let Node::Phi { control, data } = &editor.func().nodes[loop_condition.idx()] else { - unreachable!() - }; - - if *control != natural_loop.header { - return None; - } - - let continue_idx = editor - .get_uses(natural_loop.header) - .position(|node| natural_loop.control[node.idx()]) - .unwrap(); - - let init_idx = 1 - continue_idx; - - // FIXME: Handle multiple loop entries - if editor.get_uses(natural_loop.header).len() > 2 { - todo!() - } - - let Node::Constant { id } = &editor.func().nodes[data[init_idx].idx()] else { - return None; - }; - - // Check that the ID is true. - let Constant::Boolean(val) = *editor.get_constant(*id) else { - return None; - }; - if val != true { - return None; - }; - - // Check other phi input. - - // FIXME: Factor this out into diff loop analysis. 
- let Node::Binary { left, right, op } = &editor.func().nodes[data[continue_idx].idx()].clone() - else { - return None; - }; - - let BinaryOperator::LT = op else { return None }; - - let bound = &editor.func().nodes[right.idx()]; - if !(bound.is_constant() || bound.is_dynamic_constant()) { - return None; - }; - let bound = match bound { - Node::Constant { id } => { - let constant = editor.get_constant(*id).clone(); - let Constant::UnsignedInteger64(v) = constant else { - return None; - }; - let mut b = DynamicConstantID::new(0); - editor.edit(|mut edit| { - b = edit.add_dynamic_constant(DynamicConstant::Constant(v.try_into().unwrap())); - Ok(edit) - }); - // Return the ID of the dynamic constant that is generated from the constant - // or dynamic constant that is the existing loop bound - b - } - Node::DynamicConstant { id } => *id, - _ => unreachable!(), - }; - - let Node::Binary { - left: add_left, - right: add_right, - op: add_op, - } = &editor.func().nodes[left.idx()] - else { - return None; - }; - - let (phi, inc) = if let Node::Phi { control, data } = &editor.func().nodes[add_left.idx()] { - (add_left, add_right) - } else if let Node::Phi { control, data } = &editor.func().nodes[add_right.idx()] { - (add_right, add_left) - } else { - return None; - }; - - // Check Constant - let Node::Constant { id } = &editor.func().nodes[inc.idx()] else { - return None; - }; - - if !editor.get_constant(*id).is_one() { - return None; - } - - // Check PHI - let Node::Phi { - control: outer_control, - data: outer_data, - } = &editor.func().nodes[phi.idx()] - else { - unreachable!() - }; - - // FIXME: Multiple loop predecessors. - if outer_data[continue_idx] != *left { - return None; - }; - - let Node::Constant { id } = &editor.func().nodes[outer_data[init_idx].idx()] else { - return None; - }; - - if !editor.get_constant(*id).is_zero() { - return None; - } - - // All checks passed, make new DC - let mut final_node = NodeID::new(0); - - editor.edit(|mut edit| { - let one = edit.add_dynamic_constant(DynamicConstant::Constant(1)); - let max_dc = edit.add_dynamic_constant(DynamicConstant::Max(one, bound)); - final_node = edit.add_node(Node::DynamicConstant { id: max_dc }); - Ok(edit) - }); - - Some(final_node) -} - pub fn has_const_fields(editor: &FunctionEditor, ivar: InductionVariable) -> bool { match ivar { InductionVariable::Basic { diff --git a/hercules_test/hercules_interpreter/src/interpreter.rs b/hercules_test/hercules_interpreter/src/interpreter.rs index 1ef70561..730f6216 100644 --- a/hercules_test/hercules_interpreter/src/interpreter.rs +++ b/hercules_test/hercules_interpreter/src/interpreter.rs @@ -668,9 +668,6 @@ impl<'a> FunctionExecutionState<'a> { .get(InterpreterVal::array_idx(&extents, &array_indices)) .unwrap_or(&InterpreterVal::Undef(type_id)) .clone(); - if let InterpreterVal::Undef(_) = ret { - panic!("bad read!") - } ret } else { panic!("PANIC: Position index on not an array") diff --git a/juno_samples/matmul/src/main.rs b/juno_samples/matmul/src/main.rs index e40c429d..fa5d1f04 100644 --- a/juno_samples/matmul/src/main.rs +++ b/juno_samples/matmul/src/main.rs @@ -8,9 +8,9 @@ juno_build::juno!("matmul"); fn main() { async_std::task::block_on(async { - const I: usize = 4; - const J: usize = 4; - const K: usize = 4; + const I: usize = 256; + const J: usize = 64; + const K: usize = 128; let a: Box<[i32]> = (0..I * J).map(|_| random::<i32>() % 100).collect(); let b: Box<[i32]> = (0..J * K).map(|_| random::<i32>() % 100).collect(); let mut correct_c: Box<[i32]> = (0..I * K).map(|_| 
0).collect(); @@ -24,14 +24,10 @@ fn main() { let a = HerculesCPURef::from_slice(&a); let b = HerculesCPURef::from_slice(&b); let mut r = runner!(matmul); - let c = r - .run(I as u64, J as u64, K as u64, a.clone(), b.clone()) - .await; + let c = r.run(I as u64, J as u64, K as u64, a.clone(), b.clone()).await; assert_eq!(c.as_slice::<i32>(), &*correct_c); - let mut r = runner!(tiled_2_matmul); - let tiled_c = r - .run(I as u64, J as u64, K as u64, a.clone(), b.clone()) - .await; + let mut r = runner!(tiled_64_matmul); + let tiled_c = r.run(I as u64, J as u64, K as u64, a.clone(), b.clone()).await; assert_eq!(tiled_c.as_slice::<i32>(), &*correct_c); }); } @@ -40,3 +36,4 @@ fn main() { fn matmul_test() { main(); } + diff --git a/juno_samples/matmul/src/matmul.jn b/juno_samples/matmul/src/matmul.jn index 92c25710..ca9be73a 100644 --- a/juno_samples/matmul/src/matmul.jn +++ b/juno_samples/matmul/src/matmul.jn @@ -15,33 +15,33 @@ fn matmul<n : usize, m : usize, l : usize>(a : i32[n, m], b : i32[m, l]) -> i32[ } #[entry] -fn tiled_2_matmul<n : usize, m : usize, l : usize>(a : i32[n, m], b : i32[m, l]) -> i32[n, l] { +fn tiled_64_matmul<n : usize, m : usize, l : usize>(a : i32[n, m], b : i32[m, l]) -> i32[n, l] { let res : i32[n, l]; - let atile : i32[2, 2]; - let btile : i32[2, 2]; - let ctile : i32[2, 2]; + let atile : i32[64, 64]; + let btile : i32[64, 64]; + let ctile : i32[64, 64]; - for bi = 0 to n / 2 { - for bk = 0 to l / 2 { - for ti = 0 to 2 { - for tk = 0 to 2 { + for bi = 0 to n / 64 { + for bk = 0 to l / 64 { + for ti = 0 to 64 { + for tk = 0 to 64 { atile[ti, tk] = 0; btile[ti, tk] = 0; ctile[ti, tk] = 0; } } - for tile_idx = 0 to m / 2 { - for ti = 0 to 2 { - for tk = 0 to 2 { - atile[ti, tk] = a[bi * 2 + ti, tile_idx * 2 + tk]; - btile[ti, tk] = b[tile_idx * 2 + ti, bk * 2 + tk]; + for tile_idx = 0 to m / 64 { + for ti = 0 to 64 { + for tk = 0 to 64 { + atile[ti, tk] = a[bi * 64 + ti, tile_idx * 64 + tk]; + btile[ti, tk] = b[tile_idx * 64 + ti, bk * 64 + tk]; } } - for ti = 0 to 2 { - for tk = 0 to 2 { + for ti = 0 to 64 { + for tk = 0 to 64 { let c_acc = ctile[ti, tk]; - for inner_idx = 0 to 2 { + for inner_idx = 0 to 64 { c_acc += atile[ti, inner_idx] * btile[inner_idx, tk]; } ctile[ti, tk] = c_acc; @@ -49,9 +49,9 @@ fn tiled_2_matmul<n : usize, m : usize, l : usize>(a : i32[n, m], b : i32[m, l]) } } - for ti = 0 to 2 { - for tk = 0 to 2 { - res[bi * 2 + ti, bk * 2 + tk] = ctile[ti, tk]; + for ti = 0 to 64 { + for tk = 0 to 64 { + res[bi * 64 + ti, bk * 64 + tk] = ctile[ti, tk]; } } } diff --git a/juno_samples/matmul/src/sched.sch b/juno_samples/matmul/src/sched.sch new file mode 100644 index 00000000..3999f923 --- /dev/null +++ b/juno_samples/matmul/src/sched.sch @@ -0,0 +1,76 @@ +macro juno-setup!(X) { + gvn(X); + dce(X); + phi-elim(X); +} + +macro default!(X) { + dce(X); + crc(X); + dce(X); + slf(X); + dce(X); + inline(X); + ip-sroa(X); + sroa(X); + phi-elim(X); + dce(X); + ccp(X); + dce(X); + gvn(X); + dce(X); + write-predication(X); + phi-elim(X); + dce(X); + crc(X); + dce(X); + slf(X); + dce(X); + predication(X); + dce(X); + ccp(X); + dce(X); + gvn(X); + dce(X); + lift-dc-math(X); + dce(X); + gvn(X); + dce(X); +} + +macro codegen-prep!(X) { + verify(*); + ip-sroa(*); + sroa(*); + infer-schedules(X); + dce(X); + gcm(X); + dce(X); + phi-elim(X); + float-collections(X); + gcm(X); +} + +juno-setup!(*); +default!(*); +// your stuff here. 
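// The fixpoint block below presumably reruns these passes until none of them
// changes the IR, giving up after at most 13 iterations: forkify turns natural
// loops with a canonical induction variable into fork-joins, fork-guard-elim is
// assumed to remove the zero-trip guard branches around those forks, and
// fork-coalesce fuses directly nested fork-joins (see fork_coalesce_helper in
// this patch), with phi-elim and dce as cleanup between iterations.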
+ +fixpoint stop after 13 { + forkify(*); + fork-guard-elim(*); + fork-coalesce(*); + phi-elim(*); + dce(*); +} + +xdot[true](*); +// serialize(*); + +fork-split(*); +unforkify(*); + +gvn(*); +dce(*); + +auto-outline(*); +codegen-prep!(*); -- GitLab
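As a sanity check on the refined IV-use condition above, the sketch below is a minimal, standalone rendering of the idea: accept the loop only if every transitive user of the IV, walking no further than reduction latches, stays inside the loop or is the loop's exit if. The toy graph, integer node IDs, and walk_users_stop_on helper are hypothetical stand-ins for FunctionEditor, NodeID, and walk_all_users_stop_on, and the real stop set also contains every phi, reduce, and control node; this is not the hercules_opt API.

    use std::collections::{HashMap, HashSet, VecDeque};

    // Toy def-use walk: yields every transitive user of `start`, but does not
    // expand past nodes in `stop_on` (they are still recorded).
    fn walk_users_stop_on(
        users: &HashMap<u32, Vec<u32>>,
        start: u32,
        stop_on: &HashSet<u32>,
    ) -> HashSet<u32> {
        let mut seen = HashSet::new();
        let mut queue: VecDeque<u32> = users.get(&start).cloned().unwrap_or_default().into();
        while let Some(n) = queue.pop_front() {
            if !seen.insert(n) {
                continue;
            }
            if stop_on.contains(&n) {
                continue; // recorded, but the walk stops here
            }
            if let Some(next) = users.get(&n) {
                queue.extend(next.iter().copied());
            }
        }
        seen
    }

    fn main() {
        // Hypothetical node IDs: 0 = IV phi, 1 = IV update (add), 2 = loop if,
        // 3 = continue latch of a reductionable phi, 4 = a user outside the loop.
        let users = HashMap::from([(0, vec![1, 2]), (1, vec![3]), (3, vec![4])]);
        let loop_nodes: HashSet<u32> = HashSet::from([0, 1, 2, 3]);
        let loop_if = 2;

        // Stopping on the reduction latch (node 3) keeps the walk from reaching
        // the outside user (node 4), so the check accepts this loop.
        let stop_on = HashSet::from([3]);
        let iv_users = walk_users_stop_on(&users, 0, &stop_on);
        let forkifiable = iv_users.iter().all(|n| loop_nodes.contains(n) || *n == loop_if);
        println!("forkifiable: {forkifiable}"); // prints "forkifiable: true"
    }

Dropping node 3 from stop_on lets the walk reach node 4 and the check then rejects the loop, mirroring how the patch only tolerates outside uses of the IV that flow through a reduction phi's continue latch.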