From 9de4d3bb84a0e9c34d6119b3d82f559b1ae89f95 Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Sat, 22 Feb 2025 18:28:49 -0600 Subject: [PATCH 01/15] Make fork coalesce use one edit to handle partial selection properly --- hercules_opt/src/fork_transforms.rs | 47 ++++++++++++----------------- 1 file changed, 19 insertions(+), 28 deletions(-) diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index e635b3c0..ec111e69 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -686,40 +686,33 @@ pub fn fork_coalesce_helper( // CHECKME / FIXME: Might need to be added the other way. new_factors.append(&mut inner_dims.to_vec()); - for tid in inner_tids { - let (fork, dim) = editor.func().nodes[tid.idx()].try_thread_id().unwrap(); - let new_tid = Node::ThreadID { - control: fork, - dimension: dim + num_outer_dims, - }; + editor.edit(|mut edit| { + for tid in inner_tids { + let (fork, dim) = edit.get_node(tid).try_thread_id().unwrap(); + let new_tid = Node::ThreadID { + control: fork, + dimension: dim + num_outer_dims, + }; - editor.edit(|mut edit| { let new_tid = edit.add_node(new_tid); - let edit = edit.replace_all_uses(tid, new_tid)?; - Ok(edit) - }); - } + edit = edit.replace_all_uses(tid, new_tid)?; + } - // Fuse Reductions - for (outer_reduce, inner_reduce) in pairs { - let (_, outer_init, _) = editor.func().nodes[outer_reduce.idx()] - .try_reduce() - .unwrap(); - let (_, inner_init, _) = editor.func().nodes[inner_reduce.idx()] - .try_reduce() - .unwrap(); - editor.edit(|mut edit| { + // Fuse Reductions + for (outer_reduce, inner_reduce) in pairs { + let (_, outer_init, _) = edit.get_node(outer_reduce) + .try_reduce() + .unwrap(); + let (_, inner_init, _) = edit.get_node(inner_reduce) + .try_reduce() + .unwrap(); // Set inner init to outer init. 
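// Each outer reduce was paired above with the unique inner reduce that
// consumes it; once the inner init is redirected to the outer init, the
// remaining uses of the outer reduce can be rerouted and the node deleted.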
edit = edit.replace_all_uses_where(inner_init, outer_init, |usee| *usee == inner_reduce)?; edit = edit.replace_all_uses(outer_reduce, inner_reduce)?; edit = edit.delete_node(outer_reduce)?; + } - Ok(edit) - }); - } - - editor.edit(|mut edit| { let new_fork = Node::Fork { control: outer_pred, factors: new_factors.into(), @@ -734,9 +727,7 @@ pub fn fork_coalesce_helper( edit = edit.delete_node(outer_fork)?; Ok(edit) - }); - - true + }) } pub fn split_any_fork( -- GitLab From 972c1e6ebe2f616181f4a8f7a5d90c500f9ac518 Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Sat, 22 Feb 2025 18:42:41 -0600 Subject: [PATCH 02/15] Propagate ParallelFork in coalesce and interchange --- hercules_opt/src/fork_transforms.rs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index ec111e69..9fceb892 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -719,6 +719,11 @@ pub fn fork_coalesce_helper( }; let new_fork = edit.add_node(new_fork); + if edit.get_schedule(outer_fork).contains(&Schedule::ParallelFork) + && edit.get_schedule(inner_fork).contains(&Schedule::ParallelFork) { + edit = edit.add_schedule(new_fork, Schedule::ParallelFork)?; + } + edit = edit.replace_all_uses(inner_fork, new_fork)?; edit = edit.replace_all_uses(outer_fork, new_fork)?; edit = edit.replace_all_uses(outer_join, inner_join)?; @@ -1271,6 +1276,9 @@ fn fork_interchange( edit = edit.delete_node(old_id)?; } let new_fork = edit.add_node(new_fork); + if edit.get_schedule(fork).contains(&Schedule::ParallelFork) { + edit = edit.add_schedule(new_fork, Schedule::ParallelFork)?; + } edit = edit.replace_all_uses(fork, new_fork)?; edit.delete_node(fork) }); -- GitLab From 199a8a80e72e8377faedea785a879847edf11ff7 Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Sat, 22 Feb 2025 18:44:44 -0600 Subject: [PATCH 03/15] Manual gpu schedule --- juno_samples/matmul/src/gpu.sch | 63 +++++++++++++++++++++++---------- 1 file changed, 44 insertions(+), 19 deletions(-) diff --git a/juno_samples/matmul/src/gpu.sch b/juno_samples/matmul/src/gpu.sch index 76808149..effdc6b2 100644 --- a/juno_samples/matmul/src/gpu.sch +++ b/juno_samples/matmul/src/gpu.sch @@ -1,26 +1,51 @@ -phi-elim(*); +macro optimize!(X) { + gvn(X); + phi-elim(X); + dce(X); + ip-sroa(X); + sroa(X); + dce(X); + gvn(X); + phi-elim(X); + dce(X); +} + +macro codegen!(X) { + gcm(*); + float-collections(*); + dce(*); + gcm(*); +} -forkify(*); -fork-guard-elim(*); -dce(*); +optimize!(*); -fixpoint { - reduce-slf(*); - slf(*); - infer-schedules(*); +fixpoint panic after 20 { + forkify(matmul); + fork-guard-elim(matmul); } -fork-coalesce(*); -infer-schedules(*); -dce(*); -rewrite(*); -fixpoint { - simplify-cfg(*); - dce(*); + +optimize!(*); + +fixpoint panic after 20 { + reduce-slf(matmul); + slf(matmul); + infer-schedules(matmul); } +dce(matmul); -ip-sroa(*); -sroa(*); -dce(*); +// Tile outer and middle loops into 32x32 sized blocks +fork-tile[32, 0, false, true](matmul@outer \ matmul@inner); +// Merge outer and middle loops and interchange so blocks are first +fork-coalesce(matmul@outer \ matmul@inner); +fork-interchange[1, 2](matmul@outer \ matmul@inner); +// Split forks +let split = fork-split(matmul); +// Join the threads and then blocks into a single fork each +fork-coalesce(split.matmul.fj2 \ matmul@inner); +fork-coalesce(split.matmul.fj0 \ split.matmul.fj2); +let auto = auto-outline(*); float-collections(*); -gcm(*); 
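+// Run the auto-outlined matmul on the GPU device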
+gpu(auto.matmul); + +codegen!(*); -- GitLab From 4b1630b2537bd82e2d52c24509797e9c5bef77d5 Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Mon, 24 Feb 2025 14:27:36 -0600 Subject: [PATCH 04/15] Restore gpu schedule --- juno_samples/matmul/src/gpu.sch | 63 ++++++++++----------------------- 1 file changed, 19 insertions(+), 44 deletions(-) diff --git a/juno_samples/matmul/src/gpu.sch b/juno_samples/matmul/src/gpu.sch index effdc6b2..76808149 100644 --- a/juno_samples/matmul/src/gpu.sch +++ b/juno_samples/matmul/src/gpu.sch @@ -1,51 +1,26 @@ -macro optimize!(X) { - gvn(X); - phi-elim(X); - dce(X); - ip-sroa(X); - sroa(X); - dce(X); - gvn(X); - phi-elim(X); - dce(X); -} - -macro codegen!(X) { - gcm(*); - float-collections(*); - dce(*); - gcm(*); -} +phi-elim(*); -optimize!(*); +forkify(*); +fork-guard-elim(*); +dce(*); -fixpoint panic after 20 { - forkify(matmul); - fork-guard-elim(matmul); +fixpoint { + reduce-slf(*); + slf(*); + infer-schedules(*); } - -optimize!(*); - -fixpoint panic after 20 { - reduce-slf(matmul); - slf(matmul); - infer-schedules(matmul); +fork-coalesce(*); +infer-schedules(*); +dce(*); +rewrite(*); +fixpoint { + simplify-cfg(*); + dce(*); } -dce(matmul); -// Tile outer and middle loops into 32x32 sized blocks -fork-tile[32, 0, false, true](matmul@outer \ matmul@inner); -// Merge outer and middle loops and interchange so blocks are first -fork-coalesce(matmul@outer \ matmul@inner); -fork-interchange[1, 2](matmul@outer \ matmul@inner); -// Split forks -let split = fork-split(matmul); -// Join the threads and then blocks into a single fork each -fork-coalesce(split.matmul.fj2 \ matmul@inner); -fork-coalesce(split.matmul.fj0 \ split.matmul.fj2); +ip-sroa(*); +sroa(*); +dce(*); -let auto = auto-outline(*); float-collections(*); -gpu(auto.matmul); - -codegen!(*); +gcm(*); -- GitLab From d179193de2081aec72dfe06baa96dd58612b98a4 Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Mon, 24 Feb 2025 15:28:02 -0600 Subject: [PATCH 05/15] Parallel tiled cpu schedule --- juno_samples/matmul/src/cpu.sch | 63 +++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 juno_samples/matmul/src/cpu.sch diff --git a/juno_samples/matmul/src/cpu.sch b/juno_samples/matmul/src/cpu.sch new file mode 100644 index 00000000..bef45ca2 --- /dev/null +++ b/juno_samples/matmul/src/cpu.sch @@ -0,0 +1,63 @@ +macro optimize!(X) { + gvn(X); + phi-elim(X); + dce(X); + ip-sroa(X); + sroa(X); + dce(X); + gvn(X); + phi-elim(X); + dce(X); +} + +macro codegen!(X) { + gcm(*); + float-collections(*); + dce(*); + gcm(*); +} + +optimize!(*); + +fixpoint panic after 20 { + forkify(matmul); + fork-guard-elim(matmul); +} + +// Mark the whole loop nest as associative, any order of iterations is equivalent +associative(matmul@outer); + +// Tile the outer 2 loops to create 16 parallel threads (each responsible for +// computing one block of the output +let par = matmul@outer \ matmul@inner; +fork-tile[4, 0, false, true](par); +fork-coalesce(par); +fork-interchange[0, 1](par); +fork-interchange[2, 3](par); +fork-interchange[1, 2](par); + +let split = fork-split(*); +fork-coalesce(split.matmul.fj0 \ split.matmul.fj2); +parallel-fork(split.matmul.fj0 \ split.matmul.fj2); + +// Pull the body of the parallel loop out into its own device function +let body = outline(split.matmul.fj2); +cpu(body); + +// Tile the loop nest for cache performance; 16x16x16 tile +fork-tile[16, 0, false, true](body); +fixpoint { fork-coalesce(body); } + 
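+// Interchange to move the tile loops innermost, giving a 16x16x16 inner block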
+fork-interchange[1, 2](body); +fork-interchange[3, 4](body); +fork-interchange[2, 3](body); + +optimize!(*); + +fork-split(body); +reduce-slf(*); +unforkify(body); + +optimize!(*); + +codegen!(*); -- GitLab From 568b399d103e9d91d29754d1f0510df3b949695b Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Mon, 24 Feb 2025 15:29:07 -0600 Subject: [PATCH 06/15] Use cpu schedule --- juno_samples/matmul/build.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/juno_samples/matmul/build.rs b/juno_samples/matmul/build.rs index 0be838c6..d2813388 100644 --- a/juno_samples/matmul/build.rs +++ b/juno_samples/matmul/build.rs @@ -6,6 +6,8 @@ fn main() { JunoCompiler::new() .file_in_src("matmul.jn") .unwrap() + .schedule_in_src("cpu.sch") + .unwrap() .build() .unwrap(); } -- GitLab From aec9cf1fe4bf40f000d472d63e79cdf66b6b0cb8 Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Wed, 26 Feb 2025 11:45:20 -0600 Subject: [PATCH 07/15] parallel cpu schedule --- juno_samples/matmul/src/cpu.sch | 3 +++ 1 file changed, 3 insertions(+) diff --git a/juno_samples/matmul/src/cpu.sch b/juno_samples/matmul/src/cpu.sch index bef45ca2..2fcf3108 100644 --- a/juno_samples/matmul/src/cpu.sch +++ b/juno_samples/matmul/src/cpu.sch @@ -60,4 +60,7 @@ unforkify(body); optimize!(*); +parallel-reduce(split.matmul.fj0); +xdot[true](*); + codegen!(*); -- GitLab From 10f9373682232536ad0c400e4365d88e2f7050f0 Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Wed, 26 Feb 2025 14:49:03 -0600 Subject: [PATCH 08/15] Clean-up cpu schedule --- juno_samples/matmul/src/cpu.sch | 54 +++++++++++++++++++++------------ 1 file changed, 34 insertions(+), 20 deletions(-) diff --git a/juno_samples/matmul/src/cpu.sch b/juno_samples/matmul/src/cpu.sch index 2fcf3108..fb08c254 100644 --- a/juno_samples/matmul/src/cpu.sch +++ b/juno_samples/matmul/src/cpu.sch @@ -10,27 +10,45 @@ macro optimize!(X) { dce(X); } -macro codegen!(X) { - gcm(*); - float-collections(*); - dce(*); - gcm(*); +macro codegen-prep!(X) { + optimize!(X); + gcm(X); + float-collections(X); + dce(X); + gcm(X); } -optimize!(*); +macro forkify!(X) { + fixpoint { + forkify(X); + fork-guard-elim(X); + } +} + +macro fork-tile { + fork-tile[n, 0, false, true](X); +} + +macro parallelize!(X) { + parallel-fork(X); + parallel-reduce(X); +} -fixpoint panic after 20 { - forkify(matmul); - fork-guard-elim(matmul); +macro unforkify!(X) { + fork-split(X); + unforkify(X); } +optimize!(*); +forkify!(*); + // Mark the whole loop nest as associative, any order of iterations is equivalent associative(matmul@outer); // Tile the outer 2 loops to create 16 parallel threads (each responsible for // computing one block of the output let par = matmul@outer \ matmul@inner; -fork-tile[4, 0, false, true](par); +fork-tile; fork-coalesce(par); fork-interchange[0, 1](par); fork-interchange[2, 3](par); @@ -38,29 +56,25 @@ fork-interchange[1, 2](par); let split = fork-split(*); fork-coalesce(split.matmul.fj0 \ split.matmul.fj2); -parallel-fork(split.matmul.fj0 \ split.matmul.fj2); + +parallelize!(split.matmul.fj0 \ split.matmul.fj2); // Pull the body of the parallel loop out into its own device function let body = outline(split.matmul.fj2); cpu(body); // Tile the loop nest for cache performance; 16x16x16 tile -fork-tile[16, 0, false, true](body); +fork-tile; fixpoint { fork-coalesce(body); } - fork-interchange[1, 2](body); fork-interchange[3, 4](body); fork-interchange[2, 3](body); -optimize!(*); - fork-split(body); reduce-slf(*); 
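// Run store-to-load forwarding on the reduces now that each fork is split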
-unforkify(body); - -optimize!(*); -parallel-reduce(split.matmul.fj0); xdot[true](*); -codegen!(*); +unforkify!(body); + +codegen-prep!(*); -- GitLab From 205e68e2d6fbb1649641750bc80a301f0743deeb Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Wed, 26 Feb 2025 18:38:52 -0600 Subject: [PATCH 09/15] Add tuples to scheduling language --- juno_scheduler/src/compile.rs | 23 ++++++++++++++++++++ juno_scheduler/src/ir.rs | 7 ++++++ juno_scheduler/src/lang.y | 32 +++++++++++++++++++++------ juno_scheduler/src/pm.rs | 41 ++++++++++++++++++++++++++++++++++- 4 files changed, 95 insertions(+), 8 deletions(-) diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs index 3c288ca7..7b8c5020 100644 --- a/juno_scheduler/src/compile.rs +++ b/juno_scheduler/src/compile.rs @@ -488,6 +488,29 @@ fn compile_expr( rhs: Box::new(rhs), })) } + parser::Expr::Tuple { + span: _, + exps, + } => { + let exprs = exps.into_iter() + .map(|e| compile_exp_as_expr(e, lexer, macrostab, macros)) + .fold(Ok(vec![]), + |mut res, exp| { + let mut res = res?; + res.push(exp?); + Ok(res) + })?; + Ok(ExprResult::Expr(ir::ScheduleExp::Tuple { exprs })) + } + parser::Expr::TupleField { + span: _, + lhs, + field, + } => { + let lhs = compile_exp_as_expr(*lhs, lexer, macrostab, macros)?; + let field = lexer.span_str(field).parse().expect("Parsing"); + Ok(ExprResult::Expr(ir::ScheduleExp::TupleField { lhs: Box::new(lhs), field })) + } } } diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs index 3a087c0d..71a185ba 100644 --- a/juno_scheduler/src/ir.rs +++ b/juno_scheduler/src/ir.rs @@ -127,6 +127,13 @@ pub enum ScheduleExp { lhs: Box<ScheduleExp>, rhs: Box<ScheduleExp>, }, + Tuple { + exprs: Vec<ScheduleExp>, + }, + TupleField { + lhs: Box<ScheduleExp>, + field: usize, + }, // This is used to "box" a selection by evaluating it at one point and then // allowing it to be used as a selector later on Selection { diff --git a/juno_scheduler/src/lang.y b/juno_scheduler/src/lang.y index 3b030e1d..5903e429 100644 --- a/juno_scheduler/src/lang.y +++ b/juno_scheduler/src/lang.y @@ -56,10 +56,14 @@ Expr -> Expr { Expr::String { span: $span } } | Expr '.' 'ID' { Expr::Field { span: $span, lhs: Box::new($1), field: span_of_tok($3) } } + | Expr '.' 
'INT' + { Expr::TupleField { span: $span, lhs: Box::new($1), field: span_of_tok($3) } } | Expr '@' 'ID' { Expr::Field { span: $span, lhs: Box::new($1), field: span_of_tok($3) } } - | '(' Expr ')' - { $2 } + | '(' Exprs ')' + { exprs_to_expr($span, $2) } + | '[' Exprs ']' + { exprs_to_expr($span, $2) } | '{' Schedule '}' { Expr::BlockExpr { span: $span, body: Box::new($2) } } | '<' Fields '>' @@ -73,14 +77,18 @@ Expr -> Expr ; Args -> Vec<Expr> - : { vec![] } - | '[' Exprs ']' { rev($2) } + : { vec![] } + | '[' RExprs ']' { rev($2) } ; Exprs -> Vec<Expr> - : { vec![] } - | Expr { vec![$1] } - | Expr ',' Exprs { snoc($1, $3) } + : RExprs { rev($1) } + ; + +RExprs -> Vec<Expr> + : { vec![] } + | Expr { vec![$1] } + | Expr ',' RExprs { snoc($1, $3) } ; Fields -> Vec<(Span, Expr)> @@ -180,6 +188,8 @@ pub enum Expr { BlockExpr { span: Span, body: Box<OperationList> }, Record { span: Span, fields: Vec<(Span, Expr)> }, SetOp { span: Span, op: SetOp, lhs: Box<Expr>, rhs: Box<Expr> }, + Tuple { span: Span, exps: Vec<Expr> }, + TupleField { span: Span, lhs: Box<Expr>, field: Span }, } pub enum Selector { @@ -193,3 +203,11 @@ pub struct MacroDecl { pub selection_name: Span, pub def: Box<OperationList>, } + +fn exprs_to_expr(span: Span, mut exps: Vec<Expr>) -> Expr { + if exps.len() == 1 { + exps.pop().unwrap() + } else { + Expr::Tuple { span, exps } + } +} diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 5f2fa4cc..ef9ec038 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -294,6 +294,9 @@ pub enum Value { Record { fields: HashMap<String, Value>, }, + Tuple { + values: Vec<Value>, + }, Everything {}, Selection { selection: Vec<Value>, @@ -371,6 +374,11 @@ impl Value { "Expected code selection, found record".to_string(), )); } + Value::Tuple { .. } => { + return Err(SchedulerError::SemanticError( + "Expected code selection, found tuple".to_string(), + )); + } Value::Integer { .. } => { return Err(SchedulerError::SemanticError( "Expected code selection, found integer".to_string(), @@ -1291,6 +1299,7 @@ fn interp_expr( | Value::Integer { .. } | Value::Boolean { .. } | Value::String { .. } + | Value::Tuple { .. } | Value::SetOp { .. } => Err(SchedulerError::UndefinedField(field.clone())), Value::JunoFunction { func } => { match pm.labels.borrow().iter().position(|s| s == field) { @@ -1463,7 +1472,24 @@ fn interp_expr( } Ok((Value::Selection { selection: values }, changed)) } - }, + } + ScheduleExp::Tuple { exprs } => { + let mut vals = vec![]; + let mut changed = false; + for exp in exprs { + let (val, change) = interp_expr(pm, exp, stringtab, env, functions)?; + vals.push(val); + changed = changed || change; + } + Ok((Value::Tuple { values: vals }, changed)) + } + ScheduleExp::TupleField { lhs, field } => { + let (val, changed) = interp_expr(pm, lhs, stringtab, env, functions)?; + match val { + Value::Tuple { values } if *field < values.len() => Ok((vec_take(values, *field), changed)), + _ => Err(SchedulerError::SemanticError(format!("No field at index {}", field))), + } + } } } @@ -1521,6 +1547,15 @@ fn update_value( Some(Value::Record { fields: new_fields }) } } + // For tuples, if we deleted values like we do for records this would mess up the indices + // which would behave very strangely. 
Instead if any field cannot be updated then we + // eliminate the entire value + Value::Tuple { values } => { + values.into_iter() + .map(|v| update_value(v, func_idx, juno_func_idx)) + .collect::<Option<Vec<_>>>() + .map(|values| Value::Tuple { values }) + } Value::JunoFunction { func } => { juno_func_idx[func.idx] .clone() @@ -3016,3 +3051,7 @@ where }); labels } + +fn vec_take<T>(mut v: Vec<T>, index: usize) -> T { + v.swap_remove(index) +} -- GitLab From 3c8eaae28662fc423b2514d9ede60ab9ef32227f Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Fri, 28 Feb 2025 09:51:46 -0600 Subject: [PATCH 10/15] add_node only adds to def use map if not already present --- hercules_opt/src/editor.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hercules_opt/src/editor.rs b/hercules_opt/src/editor.rs index 17cea325..ce344699 100644 --- a/hercules_opt/src/editor.rs +++ b/hercules_opt/src/editor.rs @@ -457,7 +457,7 @@ impl<'a, 'b> FunctionEdit<'a, 'b> { pub fn add_node(&mut self, node: Node) -> NodeID { let id = NodeID::new(self.num_node_ids()); // Added nodes need to have an entry in the def-use map. - self.updated_def_use.insert(id, HashSet::new()); + self.updated_def_use.entry(id).or_insert(HashSet::new()); // Added nodes use other nodes, and we need to update their def-use // entries. for u in get_uses(&node).as_ref() { -- GitLab From 4754a909271bdc228294e501cbaa1aa067fbf599 Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Fri, 28 Feb 2025 09:52:32 -0600 Subject: [PATCH 11/15] Return new fork ID from fork coalesce and interchange --- hercules_opt/src/fork_transforms.rs | 47 +++++++++++++++++------------ 1 file changed, 28 insertions(+), 19 deletions(-) diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index ff0f0283..cb0e7de4 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -578,7 +578,7 @@ pub fn fork_coalesce( // FIXME: This could give us two forks that aren't actually ancestors / related, but then the helper will just return false early. // something like: `fork_joins.postorder_iter().windows(2)` is ideal here. for (inner, outer) in fork_joins.iter().cartesian_product(fork_joins.iter()) { - if fork_coalesce_helper(editor, *outer, *inner, fork_join_map) { + if fork_coalesce_helper(editor, *outer, *inner, fork_join_map).is_some() { return true; } } @@ -587,13 +587,15 @@ pub fn fork_coalesce( /** Opposite of fork split, takes two fork-joins with no control between them, and merges them into a single fork-join. + Returns None if the forks could not be merged and the NodeIDs of the + resulting fork and join if it succeeds in merging them. */ pub fn fork_coalesce_helper( editor: &mut FunctionEditor, outer_fork: NodeID, inner_fork: NodeID, fork_join_map: &HashMap<NodeID, NodeID>, -) -> bool { +) -> Option<(NodeID, NodeID)> { // Check that all reduces in the outer fork are in *simple* cycles with a unique reduce of the inner fork. 
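    // Concretely: each inner reduce must be attached to the inner join, use
    // the corresponding outer reduce as its init, and no reduce may appear
    // in more than one pair.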
let outer_join = fork_join_map[&outer_fork]; @@ -621,19 +623,19 @@ pub fn fork_coalesce_helper( reduct: _, } = inner_reduce_node else { - return false; + return None; }; // FIXME: check this condition better (i.e reduce might not be attached to join) if *inner_control != inner_join { - return false; + return None; }; if *inner_init != outer_reduce { - return false; + return None; }; if pairs.contains_left(&outer_reduce) || pairs.contains_right(&inner_reduce) { - return false; + return None; } else { pairs.insert(outer_reduce, inner_reduce); } @@ -645,11 +647,11 @@ pub fn fork_coalesce_helper( .filter(|node| editor.func().nodes[node.idx()].is_control()) .next() else { - return false; + return None; }; if user != inner_fork { - return false; + return None; } let Some(user) = editor @@ -657,11 +659,11 @@ pub fn fork_coalesce_helper( .filter(|node| editor.func().nodes[node.idx()].is_control()) .next() else { - return false; + return None; }; if user != outer_join { - return false; + return None; } // Checklist: @@ -709,10 +711,10 @@ pub fn fork_coalesce_helper( let (_, inner_init, _) = editor.func().nodes[inner_reduce.idx()] .try_reduce() .unwrap(); - editor.edit(|mut edit| { + let success = editor.edit(|mut edit| { // Set inner init to outer init. edit = - edit.replace_all_uses_where(inner_init, outer_init, |usee| *usee == inner_reduce)?; + edit.replace_all_uses_where(inner_init, outer_init, |usee| *usee == inner_reduce )?; edit = edit.replace_all_uses(outer_reduce, inner_reduce)?; edit = edit.delete_node(outer_reduce)?; @@ -720,12 +722,15 @@ pub fn fork_coalesce_helper( }); } + let mut new_fork = NodeID::new(0); + let new_join = inner_join; // We reuse the inner join as the join of the new fork + editor.edit(|mut edit| { - let new_fork = Node::Fork { + let new_fork_node = Node::Fork { control: outer_pred, factors: new_factors.into(), }; - let new_fork = edit.add_node(new_fork); + new_fork = edit.add_node(new_fork_node); edit = edit.replace_all_uses(inner_fork, new_fork)?; edit = edit.replace_all_uses(outer_fork, new_fork)?; @@ -737,7 +742,7 @@ pub fn fork_coalesce_helper( Ok(edit) }); - true + Some((new_fork, new_join)) } pub fn split_any_fork( @@ -760,7 +765,7 @@ pub fn split_any_fork( * Useful for code generation. A single iteration of `fork_split` only splits * at most one fork-join, it must be called repeatedly to split all fork-joins. */ -pub(crate) fn split_fork( +pub fn split_fork( editor: &mut FunctionEditor, fork: NodeID, join: NodeID, @@ -1215,13 +1220,13 @@ pub fn fork_interchange_all_forks( } } -fn fork_interchange( +pub fn fork_interchange( editor: &mut FunctionEditor, fork: NodeID, join: NodeID, first_dim: usize, second_dim: usize, -) { +) -> Option<NodeID> { // Check that every reduce on the join is parallel or associative. let nodes = &editor.func().nodes; let schedules = &editor.func().schedules; @@ -1234,7 +1239,7 @@ fn fork_interchange( }) { // If not, we can't necessarily do interchange. 
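        // Swapping the two dimensions permutes the iteration order, which is
        // only sound when every reduction is order-insensitive.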
- return; + return None; } let Node::Fork { @@ -1276,6 +1281,7 @@ fn fork_interchange( let mut factors = factors.clone(); factors.swap(first_dim, second_dim); let new_fork = Node::Fork { control, factors }; + let mut new_fork_id = None; editor.edit(|mut edit| { for (old_id, new_tid) in fix_tids { let new_id = edit.add_node(new_tid); @@ -1283,9 +1289,12 @@ fn fork_interchange( edit = edit.delete_node(old_id)?; } let new_fork = edit.add_node(new_fork); + new_fork_id = Some(new_fork); edit = edit.replace_all_uses(fork, new_fork)?; edit.delete_node(fork) }); + + new_fork_id } /* -- GitLab From 751ded74756f4d650e23e8cac8ce55081106c8cb Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Fri, 28 Feb 2025 10:59:01 -0600 Subject: [PATCH 12/15] Improve fork coalesce check for intermediate control --- hercules_opt/src/fork_transforms.rs | 20 ++++---------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index cb0e7de4..12b91194 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -642,27 +642,15 @@ pub fn fork_coalesce_helper( } // Check for control between join-join and fork-fork - let Some(user) = editor - .get_users(outer_fork) - .filter(|node| editor.func().nodes[node.idx()].is_control()) - .next() - else { - return None; - }; + let (control, _) = editor.node(inner_fork).try_fork().unwrap(); - if user != inner_fork { + if control != outer_fork { return None; } - let Some(user) = editor - .get_users(inner_join) - .filter(|node| editor.func().nodes[node.idx()].is_control()) - .next() - else { - return None; - }; + let control = editor.node(outer_join).try_join().unwrap(); - if user != outer_join { + if control != inner_join { return None; } -- GitLab From 2d577f79433c36f7c00e2368f9505dfee0efb644 Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Fri, 28 Feb 2025 10:59:16 -0600 Subject: [PATCH 13/15] Fork-reshape --- juno_samples/matmul/build.rs | 2 + juno_samples/matmul/src/cpu.sch | 62 +++++++++ juno_scheduler/src/compile.rs | 1 + juno_scheduler/src/ir.rs | 3 + juno_scheduler/src/lang.y | 12 +- juno_scheduler/src/pm.rs | 214 ++++++++++++++++++++++++++++++++ 6 files changed, 284 insertions(+), 10 deletions(-) create mode 100644 juno_samples/matmul/src/cpu.sch diff --git a/juno_samples/matmul/build.rs b/juno_samples/matmul/build.rs index 0be838c6..d2813388 100644 --- a/juno_samples/matmul/build.rs +++ b/juno_samples/matmul/build.rs @@ -6,6 +6,8 @@ fn main() { JunoCompiler::new() .file_in_src("matmul.jn") .unwrap() + .schedule_in_src("cpu.sch") + .unwrap() .build() .unwrap(); } diff --git a/juno_samples/matmul/src/cpu.sch b/juno_samples/matmul/src/cpu.sch new file mode 100644 index 00000000..7577aebb --- /dev/null +++ b/juno_samples/matmul/src/cpu.sch @@ -0,0 +1,62 @@ +macro optimize!(X) { + gvn(X); + phi-elim(X); + dce(X); + ip-sroa(X); + sroa(X); + dce(X); + gvn(X); + phi-elim(X); + dce(X); +} + +macro codegen-prep!(X) { + optimize!(X); + gcm(X); + float-collections(X); + dce(X); + gcm(X); +} + +macro forkify!(X) { + fixpoint { + forkify(X); + fork-guard-elim(X); + } +} + +macro fork-tile { + fork-tile[n, 0, false, true](X); +} + +macro parallelize!(X) { + parallel-fork(X); + parallel-reduce(X); +} + +macro unforkify!(X) { + fork-split(X); + unforkify(X); +} + +optimize!(*); +forkify!(*); +associative(matmul@outer); + +// Parallelize by computing output array as 16 chunks +let par = matmul@outer \ matmul@inner; 
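+// fork-tile splits each of the two parallel loops in two; fork-reshape then
+// fuses dims 1 and 3 into a single 16-iteration fork that runs in parallel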
+fork-tile; +let res = fork-reshape[[1, 3], [0], [2]](par); let outer = res.0; let inner = res.1; +parallelize!(outer \ inner); + +let body = outline(inner); +cpu(body); + +// Tile for cache, assuming 64B cache lines +fork-split(body); +fork-tile; +let res = fork-reshape[[0, 2, 4, 1, 3], [5]](body); let outer = res.0; let inner = res.1; + +reduce-slf(inner); +unforkify!(body); +codegen-prep!(*); diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs index 7b8c5020..0f473f0d 100644 --- a/juno_scheduler/src/compile.rs +++ b/juno_scheduler/src/compile.rs @@ -134,6 +134,7 @@ impl FromStr for Appliable { "fork-extend" => Ok(Appliable::Pass(ir::Pass::ForkExtend)), "fork-unroll" | "unroll" => Ok(Appliable::Pass(ir::Pass::ForkUnroll)), "fork-fusion" | "fusion" => Ok(Appliable::Pass(ir::Pass::ForkFusion)), + "fork-reshape" => Ok(Appliable::Pass(ir::Pass::ForkReshape)), "lift-dc-math" => Ok(Appliable::Pass(ir::Pass::LiftDCMath)), "loop-bound-canon" => Ok(Appliable::Pass(ir::Pass::LoopBoundCanon)), "outline" => Ok(Appliable::Pass(ir::Pass::Outline)), diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs index 71a185ba..cf4b6558 100644 --- a/juno_scheduler/src/ir.rs +++ b/juno_scheduler/src/ir.rs @@ -20,6 +20,7 @@ pub enum Pass { ForkFusion, ForkGuardElim, ForkInterchange, + ForkReshape, ForkSplit, ForkUnroll, Forkify, @@ -57,6 +58,7 @@ impl Pass { Pass::ForkExtend => num == 1, Pass::ForkFissionBufferize => num == 2 || num == 1, Pass::ForkInterchange => num == 2, + Pass::ForkReshape => true, Pass::Print => num == 1, Pass::Rename => num == 1, Pass::SROA => num == 0 || num == 1, @@ -73,6 +75,7 @@ impl Pass { Pass::ForkExtend => "1", Pass::ForkFissionBufferize => "1 or 2", Pass::ForkInterchange => "2", + Pass::ForkReshape => "any", Pass::Print => "1", Pass::Rename => "1", Pass::SROA => "0 or 1", diff --git a/juno_scheduler/src/lang.y b/juno_scheduler/src/lang.y index 5903e429..69d50014 100644 --- a/juno_scheduler/src/lang.y +++ b/juno_scheduler/src/lang.y @@ -61,9 +61,9 @@ Expr -> Expr | Expr '@' 'ID' { Expr::Field { span: $span, lhs: Box::new($1), field: span_of_tok($3) } } | '(' Exprs ')' - { exprs_to_expr($span, $2) } + { Expr::Tuple { span: $span, exps: $2 } } | '[' Exprs ']' - { exprs_to_expr($span, $2) } + { Expr::Tuple { span: $span, exps: $2 } } | '{' Schedule '}' { Expr::BlockExpr { span: $span, body: Box::new($2) } } | '<' Fields '>' @@ -203,11 +203,3 @@ pub struct MacroDecl { pub selection_name: Span, pub def: Box<OperationList>, } - -fn exprs_to_expr(span: Span, mut exps: Vec<Expr>) -> Expr { - if exps.len() == 1 { - exps.pop().unwrap() - } else { - Expr::Tuple { span, exps } - } -} diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index ef9ec038..36259718 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -2901,6 +2901,220 @@ fn run_pass( pm.delete_gravestones(); pm.clear_analyses(); } + Pass::ForkReshape => { + let mut shape = vec![]; + let mut loops = BTreeSet::new(); + let mut fork_count = 0; + + for arg in args { + let Value::Tuple { values } = arg else { + return Err(SchedulerError::PassError { + pass: "fork-reshape".to_string(), + error: "expected each argument to be a list of integers".to_string(), + }); + }; + + let mut indices = vec![]; + for val in values { + let Value::Integer { val: idx } = val else { + return Err(SchedulerError::PassError { + pass: "fork-reshape".to_string(), + error: "expected each argument to be a list of integers".to_string(), + }); + }; + indices.push(idx); + loops.insert(idx); + fork_count 
+= 1; + } + shape.push(indices); + } + + if loops != (0..fork_count).collect() { + return Err(SchedulerError::PassError { + pass: "fork-reshape".to_string(), + error: "expected forks to be numbered sequentially from 0 and used exactly once".to_string(), + }); + } + + let Some((nodes, func_id)) = selection_as_set(pm, selection) else { + return Err(SchedulerError::PassError { + pass: "fork-reshape".to_string(), + error: "must be applied to nodes in a single function".to_string(), + }); + }; + let func = func_id.idx(); + + pm.make_def_uses(); + pm.make_fork_join_maps(); + pm.make_loops(); + pm.make_reduce_cycles(); + + let def_uses = pm.def_uses.take().unwrap(); + let mut fork_join_maps = pm.fork_join_maps.take().unwrap(); + let loops = pm.loops.take().unwrap(); + let reduce_cycles = pm.reduce_cycles.take().unwrap(); + + let def_use = &def_uses[func]; + let fork_join_map = &mut fork_join_maps[func]; + let loops = &loops[func]; + let reduce_cycles = &reduce_cycles[func]; + + let mut editor = FunctionEditor::new( + &mut pm.functions[func], + func_id, + &pm.constants, + &pm.dynamic_constants, + &pm.types, + &pm.labels, + def_use, + ); + + // There should be exactly one fork nest in the selection and it should contain + // exactly fork_count forks (counting each dimension of each fork) + // We determine the loops (ordered top-down) that are contained in the selection + // (in particular the header is in the selection) and its a fork-join (the header + // is a fork) + let mut loops = loops.bottom_up_loops().into_iter().rev() + .filter(|(header, _)| nodes.contains(header) && editor.node(header).is_fork()); + let Some((top_fork_head, top_fork_body)) = loops.next() else { + return Err(SchedulerError::PassError { + pass: "fork-reshape".to_string(), + error: format!("expected {} forks found 0 in {}", fork_count, editor.func().name), + }); + }; + // All the remaining forks need to be contained in the top fork body + let mut forks = vec![top_fork_head]; + let mut num_dims = editor.node(top_fork_head).try_fork().unwrap().1.len(); + for (head, _) in loops { + if !top_fork_body[head.idx()] { + return Err(SchedulerError::PassError { + pass: "fork-reshape".to_string(), + error: "selection includes multiple non-nested forks".to_string(), + }); + } else { + forks.push(head); + num_dims += editor.node(head).try_fork().unwrap().1.len(); + } + } + + if num_dims != fork_count { + return Err(SchedulerError::PassError { + pass: "fork-reshape".to_string(), + error: format!("expected {} forks, found {} in {}", fork_count, num_dims, pm.functions[func].name), + }); + } + + // Now, we coalesce all of these forks into one so that we can interchange them + let mut forks = forks.into_iter(); + let top_fork = forks.next().unwrap(); + let mut cur_fork = top_fork; + for next_fork in forks { + let Some((new_fork, new_join)) = fork_coalesce_helper(&mut editor, cur_fork, next_fork, &fork_join_map) else { + return Err(SchedulerError::PassError { + pass: "fork-reshape".to_string(), + error: "failed to coalesce forks".to_string(), + }); + }; + cur_fork = new_fork; + fork_join_map.insert(new_fork, new_join); + } + let join = *fork_join_map.get(&cur_fork).unwrap(); + + // Now we have just one fork and we can perform the interchanges we need + // To do this, we track two maps: from original index to current index and from + // current index to original index + let mut orig_to_cur = (0..fork_count).collect::<Vec<_>>(); + let mut cur_to_orig = (0..fork_count).collect::<Vec<_>>(); + + // Now, starting from the first (outermost) index we 
move the desired fork bound + // into place + for (idx, original_idx) in shape.iter().flat_map(|idx| idx.iter()).enumerate() { + let cur_idx = orig_to_cur[*original_idx]; + let swapping = cur_to_orig[idx]; + + // If the desired factor is already in the correct place, do nothing + if cur_idx == idx { + continue; + } + assert!(idx < cur_idx); + let Some(fork_res) = fork_interchange(&mut editor, cur_fork, join, idx, cur_idx) else { + return Err(SchedulerError::PassError { + pass: "fork-reshape".to_string(), + error: "failed to interchange forks".to_string(), + }); + }; + cur_fork = fork_res; + + // Update our maps + orig_to_cur[*original_idx] = idx; + orig_to_cur[swapping] = cur_idx; + cur_to_orig[idx] = *original_idx; + cur_to_orig[cur_idx] = swapping; + } + + // Finally we split the fork into the desired pieces. We do this by first splitting + // the fork into individual forks and then coalesce the chunks together + // Not sure how split_fork could fail, so if it does panic is fine + let (forks, joins) = split_fork(&mut editor, cur_fork, join, &reduce_cycles).unwrap(); + + for (fork, join) in forks.iter().zip(joins.iter()) { + fork_join_map.insert(*fork, *join); + } + + // Finally coalesce the chunks together + let mut fork_idx = 0; + let mut final_forks = vec![]; + for chunk in shape.iter() { + let chunk_len = chunk.len(); + + let mut cur_fork = forks[fork_idx]; + for i in 1..chunk_len { + let next_fork = forks[fork_idx + i]; + // Again, not sure at this point how coalesce could fail, so panic if it + // does + let (new_fork, new_join) = fork_coalesce_helper(&mut editor, cur_fork, next_fork, &fork_join_map).unwrap(); + cur_fork = new_fork; + fork_join_map.insert(new_fork, new_join); + } + + fork_idx += chunk_len; + final_forks.push(cur_fork); + } + + // Label each fork and return the labels + // We've trashed our analyses at this point, so rerun them so that we can determine the + // nodes in each of the result fork-joins + pm.clear_analyses(); + pm.make_def_uses(); + pm.make_nodes_in_fork_joins(); + + let def_uses = pm.def_uses.take().unwrap(); + let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap(); + + let def_use = &def_uses[func]; + let nodes_in_fork_joins = &nodes_in_fork_joins[func]; + + let mut editor = FunctionEditor::new( + &mut pm.functions[func], + func_id, + &pm.constants, + &pm.dynamic_constants, + &pm.types, + &pm.labels, + def_use, + ); + + let labels = create_labels_for_node_sets(&mut editor, final_forks.into_iter().map(|fork| nodes_in_fork_joins[&fork].iter().copied())) + .into_iter() + .map(|(_, label)| Value::Label { labels: vec![LabelInfo { func: func_id, label }] }) + .collect(); + + result = Value::Tuple { values: labels }; + changed = true; + + pm.delete_gravestones(); + pm.clear_analyses(); + } Pass::WritePredication => { assert!(args.is_empty()); for func in build_selection(pm, selection, false) { -- GitLab From 39e2ea9372530618a5f6aab8cf6c8b371849b637 Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Fri, 28 Feb 2025 11:16:08 -0600 Subject: [PATCH 14/15] Allow let-tuple destruction in scheduling language --- juno_samples/matmul/src/cpu.sch | 5 ++--- juno_scheduler/src/compile.rs | 18 ++++++++++++++++++ juno_scheduler/src/lang.y | 3 +++ 3 files changed, 23 insertions(+), 3 deletions(-) diff --git a/juno_samples/matmul/src/cpu.sch b/juno_samples/matmul/src/cpu.sch index 7577aebb..69f1811d 100644 --- a/juno_samples/matmul/src/cpu.sch +++ b/juno_samples/matmul/src/cpu.sch @@ -46,16 +46,15 @@ associative(matmul@outer); // 
Parallelize by computing output array as 16 chunks let par = matmul@outer \ matmul@inner; fork-tile; -let res = fork-reshape[[1, 3], [0], [2]](par); let outer = res.0; let inner = res.1; +let (outer, inner, _) = fork-reshape[[1, 3], [0], [2]](par); parallelize!(outer \ inner); let body = outline(inner); cpu(body); // Tile for cache, assuming 64B cache lines -fork-split(body); fork-tile; -let res = fork-reshape[[0, 2, 4, 1, 3], [5]](body); let outer = res.0; let inner = res.1; +let (outer, inner) = fork-reshape[[0, 2, 4, 1, 3], [5]](body); reduce-slf(inner); unforkify!(body); diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs index 0f473f0d..51ba3699 100644 --- a/juno_scheduler/src/compile.rs +++ b/juno_scheduler/src/compile.rs @@ -207,6 +207,24 @@ fn compile_stmt( exp: compile_exp_as_expr(expr, lexer, macrostab, macros)?, }]) } + parser::Stmt::LetsStmt { span: _, vars, expr } => { + let tmp = format!("{}_tmp", macros.uniq()); + Ok(std::iter::once(ir::ScheduleStmt::Let { + var: tmp.clone(), + exp: compile_exp_as_expr(expr, lexer, macrostab, macros)?, + }).chain(vars.into_iter().enumerate() + .map(|(idx, v)| { + let var = lexer.span_str(v).to_string(); + ir::ScheduleStmt::Let { + var, + exp: ir::ScheduleExp::TupleField { + lhs: Box::new(ir::ScheduleExp::Variable { var: tmp.clone() }), + field: idx, + } + } + }) + ).collect()) + } parser::Stmt::AssignStmt { span: _, var, rhs } => { let var = lexer.span_str(var).to_string(); Ok(vec![ir::ScheduleStmt::Assign { diff --git a/juno_scheduler/src/lang.y b/juno_scheduler/src/lang.y index 69d50014..451f035b 100644 --- a/juno_scheduler/src/lang.y +++ b/juno_scheduler/src/lang.y @@ -19,6 +19,8 @@ Schedule -> OperationList Stmt -> Stmt : 'let' 'ID' '=' Expr ';' { Stmt::LetStmt { span: $span, var: span_of_tok($2), expr: $4 } } + | 'let' '(' Ids ')' '=' Expr ';' + { Stmt::LetsStmt { span: $span, vars: rev($3), expr: $6 } } | 'ID' '=' Expr ';' { Stmt::AssignStmt { span: $span, var: span_of_tok($1), rhs: $3 } } | Expr ';' @@ -157,6 +159,7 @@ pub enum OperationList { pub enum Stmt { LetStmt { span: Span, var: Span, expr: Expr }, + LetsStmt { span: Span, vars: Vec<Span>, expr: Expr }, AssignStmt { span: Span, var: Span, rhs: Expr }, ExprStmt { span: Span, exp: Expr }, Fixpoint { span: Span, limit: FixpointLimit, body: Box<OperationList> }, -- GitLab From 7370cb51e84d2992fcde703ed4bd8090250e9000 Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Fri, 28 Feb 2025 11:47:28 -0600 Subject: [PATCH 15/15] Formatting --- hercules_opt/src/fork_transforms.rs | 19 ++++---- hercules_samples/matmul/src/main.rs | 3 +- juno_scheduler/src/compile.rs | 52 +++++++++++--------- juno_scheduler/src/pm.rs | 75 ++++++++++++++++++++--------- 4 files changed, 93 insertions(+), 56 deletions(-) diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index 9a16c99c..e6db0345 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -693,15 +693,11 @@ pub fn fork_coalesce_helper( } // Fuse Reductions for (outer_reduce, inner_reduce) in pairs { - let (_, outer_init, _) = edit.get_node(outer_reduce) - .try_reduce() - .unwrap(); - let (_, inner_init, _) = edit.get_node(inner_reduce) - .try_reduce() - .unwrap(); + let (_, outer_init, _) = edit.get_node(outer_reduce).try_reduce().unwrap(); + let (_, inner_init, _) = edit.get_node(inner_reduce).try_reduce().unwrap(); // Set inner init to outer init. 
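            // The predicate limits the rewrite to the inner reduce, so only
            // its init edge is redirected to the outer init.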
edit = - edit.replace_all_uses_where(inner_init, outer_init, |usee| *usee == inner_reduce )?; + edit.replace_all_uses_where(inner_init, outer_init, |usee| *usee == inner_reduce)?; edit = edit.replace_all_uses(outer_reduce, inner_reduce)?; edit = edit.delete_node(outer_reduce)?; } @@ -712,8 +708,13 @@ pub fn fork_coalesce_helper( }; new_fork = edit.add_node(new_fork_node); - if edit.get_schedule(outer_fork).contains(&Schedule::ParallelFork) - && edit.get_schedule(inner_fork).contains(&Schedule::ParallelFork) { + if edit + .get_schedule(outer_fork) + .contains(&Schedule::ParallelFork) + && edit + .get_schedule(inner_fork) + .contains(&Schedule::ParallelFork) + { edit = edit.add_schedule(new_fork, Schedule::ParallelFork)?; } diff --git a/hercules_samples/matmul/src/main.rs b/hercules_samples/matmul/src/main.rs index 27727664..00e5b873 100644 --- a/hercules_samples/matmul/src/main.rs +++ b/hercules_samples/matmul/src/main.rs @@ -25,7 +25,8 @@ fn main() { let a = HerculesImmBox::from(a.as_ref()); let b = HerculesImmBox::from(b.as_ref()); let mut r = runner!(matmul); - let mut c: HerculesMutBox<i32> = HerculesMutBox::from(r.run(I as u64, J as u64, K as u64, a.to(), b.to()).await); + let mut c: HerculesMutBox<i32> = + HerculesMutBox::from(r.run(I as u64, J as u64, K as u64, a.to(), b.to()).await); assert_eq!(c.as_slice(), correct_c.as_ref()); }); } diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs index 32d2c8d1..9d020c64 100644 --- a/juno_scheduler/src/compile.rs +++ b/juno_scheduler/src/compile.rs @@ -208,23 +208,27 @@ fn compile_stmt( exp: compile_exp_as_expr(expr, lexer, macrostab, macros)?, }]) } - parser::Stmt::LetsStmt { span: _, vars, expr } => { + parser::Stmt::LetsStmt { + span: _, + vars, + expr, + } => { let tmp = format!("{}_tmp", macros.uniq()); Ok(std::iter::once(ir::ScheduleStmt::Let { var: tmp.clone(), exp: compile_exp_as_expr(expr, lexer, macrostab, macros)?, - }).chain(vars.into_iter().enumerate() - .map(|(idx, v)| { - let var = lexer.span_str(v).to_string(); - ir::ScheduleStmt::Let { - var, - exp: ir::ScheduleExp::TupleField { - lhs: Box::new(ir::ScheduleExp::Variable { var: tmp.clone() }), - field: idx, - } - } - }) - ).collect()) + }) + .chain(vars.into_iter().enumerate().map(|(idx, v)| { + let var = lexer.span_str(v).to_string(); + ir::ScheduleStmt::Let { + var, + exp: ir::ScheduleExp::TupleField { + lhs: Box::new(ir::ScheduleExp::Variable { var: tmp.clone() }), + field: idx, + }, + } + })) + .collect()) } parser::Stmt::AssignStmt { span: _, var, rhs } => { let var = lexer.span_str(var).to_string(); @@ -508,17 +512,14 @@ fn compile_expr( rhs: Box::new(rhs), })) } - parser::Expr::Tuple { - span: _, - exps, - } => { - let exprs = exps.into_iter() + parser::Expr::Tuple { span: _, exps } => { + let exprs = exps + .into_iter() .map(|e| compile_exp_as_expr(e, lexer, macrostab, macros)) - .fold(Ok(vec![]), - |mut res, exp| { - let mut res = res?; - res.push(exp?); - Ok(res) + .fold(Ok(vec![]), |mut res, exp| { + let mut res = res?; + res.push(exp?); + Ok(res) })?; Ok(ExprResult::Expr(ir::ScheduleExp::Tuple { exprs })) } @@ -529,7 +530,10 @@ fn compile_expr( } => { let lhs = compile_exp_as_expr(*lhs, lexer, macrostab, macros)?; let field = lexer.span_str(field).parse().expect("Parsing"); - Ok(ExprResult::Expr(ir::ScheduleExp::TupleField { lhs: Box::new(lhs), field })) + Ok(ExprResult::Expr(ir::ScheduleExp::TupleField { + lhs: Box::new(lhs), + field, + })) } } } diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index f0b55eca..456df2ed 100644 
--- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -1475,7 +1475,7 @@ fn interp_expr( } Ok((Value::Selection { selection: values }, changed)) } - } + }, ScheduleExp::Tuple { exprs } => { let mut vals = vec![]; let mut changed = false; @@ -1489,8 +1489,13 @@ fn interp_expr( ScheduleExp::TupleField { lhs, field } => { let (val, changed) = interp_expr(pm, lhs, stringtab, env, functions)?; match val { - Value::Tuple { values } if *field < values.len() => Ok((vec_take(values, *field), changed)), - _ => Err(SchedulerError::SemanticError(format!("No field at index {}", field))), + Value::Tuple { values } if *field < values.len() => { + Ok((vec_take(values, *field), changed)) + } + _ => Err(SchedulerError::SemanticError(format!( + "No field at index {}", + field + ))), } } } @@ -1553,12 +1558,11 @@ fn update_value( // For tuples, if we deleted values like we do for records this would mess up the indices // which would behave very strangely. Instead if any field cannot be updated then we // eliminate the entire value - Value::Tuple { values } => { - values.into_iter() - .map(|v| update_value(v, func_idx, juno_func_idx)) - .collect::<Option<Vec<_>>>() - .map(|values| Value::Tuple { values }) - } + Value::Tuple { values } => values + .into_iter() + .map(|v| update_value(v, func_idx, juno_func_idx)) + .collect::<Option<Vec<_>>>() + .map(|values| Value::Tuple { values }), Value::JunoFunction { func } => { juno_func_idx[func.idx] .clone() @@ -2963,7 +2967,9 @@ fn run_pass( if loops != (0..fork_count).collect() { return Err(SchedulerError::PassError { pass: "fork-reshape".to_string(), - error: "expected forks to be numbered sequentially from 0 and used exactly once".to_string(), + error: + "expected forks to be numbered sequentially from 0 and used exactly once" + .to_string(), }); } @@ -3005,12 +3011,19 @@ fn run_pass( // We determine the loops (ordered top-down) that are contained in the selection // (in particular the header is in the selection) and its a fork-join (the header // is a fork) - let mut loops = loops.bottom_up_loops().into_iter().rev() + let mut loops = loops + .bottom_up_loops() + .into_iter() + .rev() .filter(|(header, _)| nodes.contains(header) && editor.node(header).is_fork()); let Some((top_fork_head, top_fork_body)) = loops.next() else { return Err(SchedulerError::PassError { pass: "fork-reshape".to_string(), - error: format!("expected {} forks found 0 in {}", fork_count, editor.func().name), + error: format!( + "expected {} forks found 0 in {}", + fork_count, + editor.func().name + ), }); }; // All the remaining forks need to be contained in the top fork body @@ -3031,7 +3044,10 @@ fn run_pass( if num_dims != fork_count { return Err(SchedulerError::PassError { pass: "fork-reshape".to_string(), - error: format!("expected {} forks, found {} in {}", fork_count, num_dims, pm.functions[func].name), + error: format!( + "expected {} forks, found {} in {}", + fork_count, num_dims, pm.functions[func].name + ), }); } @@ -3040,7 +3056,9 @@ fn run_pass( let top_fork = forks.next().unwrap(); let mut cur_fork = top_fork; for next_fork in forks { - let Some((new_fork, new_join)) = fork_coalesce_helper(&mut editor, cur_fork, next_fork, &fork_join_map) else { + let Some((new_fork, new_join)) = + fork_coalesce_helper(&mut editor, cur_fork, next_fork, &fork_join_map) + else { return Err(SchedulerError::PassError { pass: "fork-reshape".to_string(), error: "failed to coalesce forks".to_string(), @@ -3068,7 +3086,8 @@ fn run_pass( continue; } assert!(idx < cur_idx); - let Some(fork_res) = 
fork_interchange(&mut editor, cur_fork, join, idx, cur_idx) else { + let Some(fork_res) = fork_interchange(&mut editor, cur_fork, join, idx, cur_idx) + else { return Err(SchedulerError::PassError { pass: "fork-reshape".to_string(), error: "failed to interchange forks".to_string(), @@ -3103,7 +3122,9 @@ fn run_pass( let next_fork = forks[fork_idx + i]; // Again, not sure at this point how coalesce could fail, so panic if it // does - let (new_fork, new_join) = fork_coalesce_helper(&mut editor, cur_fork, next_fork, &fork_join_map).unwrap(); + let (new_fork, new_join) = + fork_coalesce_helper(&mut editor, cur_fork, next_fork, &fork_join_map) + .unwrap(); cur_fork = new_fork; fork_join_map.insert(new_fork, new_join); } @@ -3118,13 +3139,13 @@ fn run_pass( pm.clear_analyses(); pm.make_def_uses(); pm.make_nodes_in_fork_joins(); - + let def_uses = pm.def_uses.take().unwrap(); let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap(); let def_use = &def_uses[func]; let nodes_in_fork_joins = &nodes_in_fork_joins[func]; - + let mut editor = FunctionEditor::new( &mut pm.functions[func], func_id, @@ -3135,10 +3156,20 @@ fn run_pass( def_use, ); - let labels = create_labels_for_node_sets(&mut editor, final_forks.into_iter().map(|fork| nodes_in_fork_joins[&fork].iter().copied())) - .into_iter() - .map(|(_, label)| Value::Label { labels: vec![LabelInfo { func: func_id, label }] }) - .collect(); + let labels = create_labels_for_node_sets( + &mut editor, + final_forks + .into_iter() + .map(|fork| nodes_in_fork_joins[&fork].iter().copied()), + ) + .into_iter() + .map(|(_, label)| Value::Label { + labels: vec![LabelInfo { + func: func_id, + label, + }], + }) + .collect(); result = Value::Tuple { values: labels }; changed = true; -- GitLab
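Taken together, the tuple values from patch 09, fork-reshape from patch 13,
and let-destructuring from patch 14 compose into a small surface syntax for
restructuring fork nests. A minimal sketch, assuming a hypothetical function
foo whose fork nest is labeled foo@nest and carries three fork dimensions
numbered 0-2 from the outermost:

// Each bracketed list names the original dimensions to fuse into one fork
// of the result, outermost first; every index 0..n must appear exactly once.
let (par, seq) = fork-reshape[[0, 2], [1]](foo@nest);
parallel-fork(par \ seq);   // run the fused [0, 2] fork across threads
fork-split(seq);
unforkify(seq);             // serialize the remaining dimension

If the indices are not a permutation of 0..n, or the selection does not name
exactly one fork nest in a single function, the pass reports an error instead
of reshaping.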