From 2d577f79433c36f7c00e2368f9505dfee0efb644 Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Fri, 28 Feb 2025 10:59:16 -0600 Subject: [PATCH] Fork-reshape --- juno_samples/matmul/build.rs | 2 + juno_samples/matmul/src/cpu.sch | 62 +++++++++ juno_scheduler/src/compile.rs | 1 + juno_scheduler/src/ir.rs | 3 + juno_scheduler/src/lang.y | 12 +- juno_scheduler/src/pm.rs | 214 ++++++++++++++++++++++++++++++++ 6 files changed, 284 insertions(+), 10 deletions(-) create mode 100644 juno_samples/matmul/src/cpu.sch diff --git a/juno_samples/matmul/build.rs b/juno_samples/matmul/build.rs index 0be838c6..d2813388 100644 --- a/juno_samples/matmul/build.rs +++ b/juno_samples/matmul/build.rs @@ -6,6 +6,8 @@ fn main() { JunoCompiler::new() .file_in_src("matmul.jn") .unwrap() + .schedule_in_src("cpu.sch") + .unwrap() .build() .unwrap(); } diff --git a/juno_samples/matmul/src/cpu.sch b/juno_samples/matmul/src/cpu.sch new file mode 100644 index 00000000..7577aebb --- /dev/null +++ b/juno_samples/matmul/src/cpu.sch @@ -0,0 +1,62 @@ +macro optimize!(X) { + gvn(X); + phi-elim(X); + dce(X); + ip-sroa(X); + sroa(X); + dce(X); + gvn(X); + phi-elim(X); + dce(X); +} + +macro codegen-prep!(X) { + optimize!(X); + gcm(X); + float-collections(X); + dce(X); + gcm(X); +} + +macro forkify!(X) { + fixpoint { + forkify(X); + fork-guard-elim(X); + } +} + +macro fork-tile { + fork-tile[n, 0, false, true](X); +} + +macro parallelize!(X) { + parallel-fork(X); + parallel-reduce(X); +} + +macro unforkify!(X) { + fork-split(X); + unforkify(X); +} + +optimize!(*); +forkify!(*); +associative(matmul@outer); + +// Parallelize by computing output array as 16 chunks +let par = matmul@outer \ matmul@inner; +fork-tile; +let res = fork-reshape[[1, 3], [0], [2]](par); let outer = res.0; let inner = res.1; +parallelize!(outer \ inner); + +let body = outline(inner); +cpu(body); + +// Tile for cache, assuming 64B cache lines +fork-split(body); +fork-tile; +let res = fork-reshape[[0, 2, 4, 1, 3], [5]](body); let outer = res.0; let inner = res.1; + +reduce-slf(inner); +unforkify!(body); +codegen-prep!(*); diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs index 7b8c5020..0f473f0d 100644 --- a/juno_scheduler/src/compile.rs +++ b/juno_scheduler/src/compile.rs @@ -134,6 +134,7 @@ impl FromStr for Appliable { "fork-extend" => Ok(Appliable::Pass(ir::Pass::ForkExtend)), "fork-unroll" | "unroll" => Ok(Appliable::Pass(ir::Pass::ForkUnroll)), "fork-fusion" | "fusion" => Ok(Appliable::Pass(ir::Pass::ForkFusion)), + "fork-reshape" => Ok(Appliable::Pass(ir::Pass::ForkReshape)), "lift-dc-math" => Ok(Appliable::Pass(ir::Pass::LiftDCMath)), "loop-bound-canon" => Ok(Appliable::Pass(ir::Pass::LoopBoundCanon)), "outline" => Ok(Appliable::Pass(ir::Pass::Outline)), diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs index 71a185ba..cf4b6558 100644 --- a/juno_scheduler/src/ir.rs +++ b/juno_scheduler/src/ir.rs @@ -20,6 +20,7 @@ pub enum Pass { ForkFusion, ForkGuardElim, ForkInterchange, + ForkReshape, ForkSplit, ForkUnroll, Forkify, @@ -57,6 +58,7 @@ impl Pass { Pass::ForkExtend => num == 1, Pass::ForkFissionBufferize => num == 2 || num == 1, Pass::ForkInterchange => num == 2, + Pass::ForkReshape => true, Pass::Print => num == 1, Pass::Rename => num == 1, Pass::SROA => num == 0 || num == 1, @@ -73,6 +75,7 @@ impl Pass { Pass::ForkExtend => "1", Pass::ForkFissionBufferize => "1 or 2", Pass::ForkInterchange => "2", + Pass::ForkReshape => "any", Pass::Print => "1", Pass::Rename => "1", Pass::SROA => "0 or 1", diff --git a/juno_scheduler/src/lang.y b/juno_scheduler/src/lang.y index 5903e429..69d50014 100644 --- a/juno_scheduler/src/lang.y +++ b/juno_scheduler/src/lang.y @@ -61,9 +61,9 @@ Expr -> Expr | Expr '@' 'ID' { Expr::Field { span: $span, lhs: Box::new($1), field: span_of_tok($3) } } | '(' Exprs ')' - { exprs_to_expr($span, $2) } + { Expr::Tuple { span: $span, exps: $2 } } | '[' Exprs ']' - { exprs_to_expr($span, $2) } + { Expr::Tuple { span: $span, exps: $2 } } | '{' Schedule '}' { Expr::BlockExpr { span: $span, body: Box::new($2) } } | '<' Fields '>' @@ -203,11 +203,3 @@ pub struct MacroDecl { pub selection_name: Span, pub def: Box<OperationList>, } - -fn exprs_to_expr(span: Span, mut exps: Vec<Expr>) -> Expr { - if exps.len() == 1 { - exps.pop().unwrap() - } else { - Expr::Tuple { span, exps } - } -} diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index ef9ec038..36259718 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -2901,6 +2901,220 @@ fn run_pass( pm.delete_gravestones(); pm.clear_analyses(); } + Pass::ForkReshape => { + let mut shape = vec![]; + let mut loops = BTreeSet::new(); + let mut fork_count = 0; + + for arg in args { + let Value::Tuple { values } = arg else { + return Err(SchedulerError::PassError { + pass: "fork-reshape".to_string(), + error: "expected each argument to be a list of integers".to_string(), + }); + }; + + let mut indices = vec![]; + for val in values { + let Value::Integer { val: idx } = val else { + return Err(SchedulerError::PassError { + pass: "fork-reshape".to_string(), + error: "expected each argument to be a list of integers".to_string(), + }); + }; + indices.push(idx); + loops.insert(idx); + fork_count += 1; + } + shape.push(indices); + } + + if loops != (0..fork_count).collect() { + return Err(SchedulerError::PassError { + pass: "fork-reshape".to_string(), + error: "expected forks to be numbered sequentially from 0 and used exactly once".to_string(), + }); + } + + let Some((nodes, func_id)) = selection_as_set(pm, selection) else { + return Err(SchedulerError::PassError { + pass: "fork-reshape".to_string(), + error: "must be applied to nodes in a single function".to_string(), + }); + }; + let func = func_id.idx(); + + pm.make_def_uses(); + pm.make_fork_join_maps(); + pm.make_loops(); + pm.make_reduce_cycles(); + + let def_uses = pm.def_uses.take().unwrap(); + let mut fork_join_maps = pm.fork_join_maps.take().unwrap(); + let loops = pm.loops.take().unwrap(); + let reduce_cycles = pm.reduce_cycles.take().unwrap(); + + let def_use = &def_uses[func]; + let fork_join_map = &mut fork_join_maps[func]; + let loops = &loops[func]; + let reduce_cycles = &reduce_cycles[func]; + + let mut editor = FunctionEditor::new( + &mut pm.functions[func], + func_id, + &pm.constants, + &pm.dynamic_constants, + &pm.types, + &pm.labels, + def_use, + ); + + // There should be exactly one fork nest in the selection and it should contain + // exactly fork_count forks (counting each dimension of each fork) + // We determine the loops (ordered top-down) that are contained in the selection + // (in particular the header is in the selection) and its a fork-join (the header + // is a fork) + let mut loops = loops.bottom_up_loops().into_iter().rev() + .filter(|(header, _)| nodes.contains(header) && editor.node(header).is_fork()); + let Some((top_fork_head, top_fork_body)) = loops.next() else { + return Err(SchedulerError::PassError { + pass: "fork-reshape".to_string(), + error: format!("expected {} forks found 0 in {}", fork_count, editor.func().name), + }); + }; + // All the remaining forks need to be contained in the top fork body + let mut forks = vec![top_fork_head]; + let mut num_dims = editor.node(top_fork_head).try_fork().unwrap().1.len(); + for (head, _) in loops { + if !top_fork_body[head.idx()] { + return Err(SchedulerError::PassError { + pass: "fork-reshape".to_string(), + error: "selection includes multiple non-nested forks".to_string(), + }); + } else { + forks.push(head); + num_dims += editor.node(head).try_fork().unwrap().1.len(); + } + } + + if num_dims != fork_count { + return Err(SchedulerError::PassError { + pass: "fork-reshape".to_string(), + error: format!("expected {} forks, found {} in {}", fork_count, num_dims, pm.functions[func].name), + }); + } + + // Now, we coalesce all of these forks into one so that we can interchange them + let mut forks = forks.into_iter(); + let top_fork = forks.next().unwrap(); + let mut cur_fork = top_fork; + for next_fork in forks { + let Some((new_fork, new_join)) = fork_coalesce_helper(&mut editor, cur_fork, next_fork, &fork_join_map) else { + return Err(SchedulerError::PassError { + pass: "fork-reshape".to_string(), + error: "failed to coalesce forks".to_string(), + }); + }; + cur_fork = new_fork; + fork_join_map.insert(new_fork, new_join); + } + let join = *fork_join_map.get(&cur_fork).unwrap(); + + // Now we have just one fork and we can perform the interchanges we need + // To do this, we track two maps: from original index to current index and from + // current index to original index + let mut orig_to_cur = (0..fork_count).collect::<Vec<_>>(); + let mut cur_to_orig = (0..fork_count).collect::<Vec<_>>(); + + // Now, starting from the first (outermost) index we move the desired fork bound + // into place + for (idx, original_idx) in shape.iter().flat_map(|idx| idx.iter()).enumerate() { + let cur_idx = orig_to_cur[*original_idx]; + let swapping = cur_to_orig[idx]; + + // If the desired factor is already in the correct place, do nothing + if cur_idx == idx { + continue; + } + assert!(idx < cur_idx); + let Some(fork_res) = fork_interchange(&mut editor, cur_fork, join, idx, cur_idx) else { + return Err(SchedulerError::PassError { + pass: "fork-reshape".to_string(), + error: "failed to interchange forks".to_string(), + }); + }; + cur_fork = fork_res; + + // Update our maps + orig_to_cur[*original_idx] = idx; + orig_to_cur[swapping] = cur_idx; + cur_to_orig[idx] = *original_idx; + cur_to_orig[cur_idx] = swapping; + } + + // Finally we split the fork into the desired pieces. We do this by first splitting + // the fork into individual forks and then coalesce the chunks together + // Not sure how split_fork could fail, so if it does panic is fine + let (forks, joins) = split_fork(&mut editor, cur_fork, join, &reduce_cycles).unwrap(); + + for (fork, join) in forks.iter().zip(joins.iter()) { + fork_join_map.insert(*fork, *join); + } + + // Finally coalesce the chunks together + let mut fork_idx = 0; + let mut final_forks = vec![]; + for chunk in shape.iter() { + let chunk_len = chunk.len(); + + let mut cur_fork = forks[fork_idx]; + for i in 1..chunk_len { + let next_fork = forks[fork_idx + i]; + // Again, not sure at this point how coalesce could fail, so panic if it + // does + let (new_fork, new_join) = fork_coalesce_helper(&mut editor, cur_fork, next_fork, &fork_join_map).unwrap(); + cur_fork = new_fork; + fork_join_map.insert(new_fork, new_join); + } + + fork_idx += chunk_len; + final_forks.push(cur_fork); + } + + // Label each fork and return the labels + // We've trashed our analyses at this point, so rerun them so that we can determine the + // nodes in each of the result fork-joins + pm.clear_analyses(); + pm.make_def_uses(); + pm.make_nodes_in_fork_joins(); + + let def_uses = pm.def_uses.take().unwrap(); + let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap(); + + let def_use = &def_uses[func]; + let nodes_in_fork_joins = &nodes_in_fork_joins[func]; + + let mut editor = FunctionEditor::new( + &mut pm.functions[func], + func_id, + &pm.constants, + &pm.dynamic_constants, + &pm.types, + &pm.labels, + def_use, + ); + + let labels = create_labels_for_node_sets(&mut editor, final_forks.into_iter().map(|fork| nodes_in_fork_joins[&fork].iter().copied())) + .into_iter() + .map(|(_, label)| Value::Label { labels: vec![LabelInfo { func: func_id, label }] }) + .collect(); + + result = Value::Tuple { values: labels }; + changed = true; + + pm.delete_gravestones(); + pm.clear_analyses(); + } Pass::WritePredication => { assert!(args.is_empty()); for func in build_selection(pm, selection, false) { -- GitLab