From 7370cb51e84d2992fcde703ed4bd8090250e9000 Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Fri, 28 Feb 2025 11:47:28 -0600 Subject: [PATCH] Formatting --- hercules_opt/src/fork_transforms.rs | 19 ++++---- hercules_samples/matmul/src/main.rs | 3 +- juno_scheduler/src/compile.rs | 52 +++++++++++--------- juno_scheduler/src/pm.rs | 75 ++++++++++++++++++++--------- 4 files changed, 93 insertions(+), 56 deletions(-) diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index 9a16c99c..e6db0345 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -693,15 +693,11 @@ pub fn fork_coalesce_helper( } // Fuse Reductions for (outer_reduce, inner_reduce) in pairs { - let (_, outer_init, _) = edit.get_node(outer_reduce) - .try_reduce() - .unwrap(); - let (_, inner_init, _) = edit.get_node(inner_reduce) - .try_reduce() - .unwrap(); + let (_, outer_init, _) = edit.get_node(outer_reduce).try_reduce().unwrap(); + let (_, inner_init, _) = edit.get_node(inner_reduce).try_reduce().unwrap(); // Set inner init to outer init. edit = - edit.replace_all_uses_where(inner_init, outer_init, |usee| *usee == inner_reduce )?; + edit.replace_all_uses_where(inner_init, outer_init, |usee| *usee == inner_reduce)?; edit = edit.replace_all_uses(outer_reduce, inner_reduce)?; edit = edit.delete_node(outer_reduce)?; } @@ -712,8 +708,13 @@ pub fn fork_coalesce_helper( }; new_fork = edit.add_node(new_fork_node); - if edit.get_schedule(outer_fork).contains(&Schedule::ParallelFork) - && edit.get_schedule(inner_fork).contains(&Schedule::ParallelFork) { + if edit + .get_schedule(outer_fork) + .contains(&Schedule::ParallelFork) + && edit + .get_schedule(inner_fork) + .contains(&Schedule::ParallelFork) + { edit = edit.add_schedule(new_fork, Schedule::ParallelFork)?; } diff --git a/hercules_samples/matmul/src/main.rs b/hercules_samples/matmul/src/main.rs index 27727664..00e5b873 100644 --- a/hercules_samples/matmul/src/main.rs +++ b/hercules_samples/matmul/src/main.rs @@ -25,7 +25,8 @@ fn main() { let a = HerculesImmBox::from(a.as_ref()); let b = HerculesImmBox::from(b.as_ref()); let mut r = runner!(matmul); - let mut c: HerculesMutBox<i32> = HerculesMutBox::from(r.run(I as u64, J as u64, K as u64, a.to(), b.to()).await); + let mut c: HerculesMutBox<i32> = + HerculesMutBox::from(r.run(I as u64, J as u64, K as u64, a.to(), b.to()).await); assert_eq!(c.as_slice(), correct_c.as_ref()); }); } diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs index 32d2c8d1..9d020c64 100644 --- a/juno_scheduler/src/compile.rs +++ b/juno_scheduler/src/compile.rs @@ -208,23 +208,27 @@ fn compile_stmt( exp: compile_exp_as_expr(expr, lexer, macrostab, macros)?, }]) } - parser::Stmt::LetsStmt { span: _, vars, expr } => { + parser::Stmt::LetsStmt { + span: _, + vars, + expr, + } => { let tmp = format!("{}_tmp", macros.uniq()); Ok(std::iter::once(ir::ScheduleStmt::Let { var: tmp.clone(), exp: compile_exp_as_expr(expr, lexer, macrostab, macros)?, - }).chain(vars.into_iter().enumerate() - .map(|(idx, v)| { - let var = lexer.span_str(v).to_string(); - ir::ScheduleStmt::Let { - var, - exp: ir::ScheduleExp::TupleField { - lhs: Box::new(ir::ScheduleExp::Variable { var: tmp.clone() }), - field: idx, - } - } - }) - ).collect()) + }) + .chain(vars.into_iter().enumerate().map(|(idx, v)| { + let var = lexer.span_str(v).to_string(); + ir::ScheduleStmt::Let { + var, + exp: ir::ScheduleExp::TupleField { + lhs: Box::new(ir::ScheduleExp::Variable { var: tmp.clone() }), + field: idx, + }, + } + })) + .collect()) } parser::Stmt::AssignStmt { span: _, var, rhs } => { let var = lexer.span_str(var).to_string(); @@ -508,17 +512,14 @@ fn compile_expr( rhs: Box::new(rhs), })) } - parser::Expr::Tuple { - span: _, - exps, - } => { - let exprs = exps.into_iter() + parser::Expr::Tuple { span: _, exps } => { + let exprs = exps + .into_iter() .map(|e| compile_exp_as_expr(e, lexer, macrostab, macros)) - .fold(Ok(vec![]), - |mut res, exp| { - let mut res = res?; - res.push(exp?); - Ok(res) + .fold(Ok(vec![]), |mut res, exp| { + let mut res = res?; + res.push(exp?); + Ok(res) })?; Ok(ExprResult::Expr(ir::ScheduleExp::Tuple { exprs })) } @@ -529,7 +530,10 @@ fn compile_expr( } => { let lhs = compile_exp_as_expr(*lhs, lexer, macrostab, macros)?; let field = lexer.span_str(field).parse().expect("Parsing"); - Ok(ExprResult::Expr(ir::ScheduleExp::TupleField { lhs: Box::new(lhs), field })) + Ok(ExprResult::Expr(ir::ScheduleExp::TupleField { + lhs: Box::new(lhs), + field, + })) } } } diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index f0b55eca..456df2ed 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -1475,7 +1475,7 @@ fn interp_expr( } Ok((Value::Selection { selection: values }, changed)) } - } + }, ScheduleExp::Tuple { exprs } => { let mut vals = vec![]; let mut changed = false; @@ -1489,8 +1489,13 @@ fn interp_expr( ScheduleExp::TupleField { lhs, field } => { let (val, changed) = interp_expr(pm, lhs, stringtab, env, functions)?; match val { - Value::Tuple { values } if *field < values.len() => Ok((vec_take(values, *field), changed)), - _ => Err(SchedulerError::SemanticError(format!("No field at index {}", field))), + Value::Tuple { values } if *field < values.len() => { + Ok((vec_take(values, *field), changed)) + } + _ => Err(SchedulerError::SemanticError(format!( + "No field at index {}", + field + ))), } } } @@ -1553,12 +1558,11 @@ fn update_value( // For tuples, if we deleted values like we do for records this would mess up the indices // which would behave very strangely. Instead if any field cannot be updated then we // eliminate the entire value - Value::Tuple { values } => { - values.into_iter() - .map(|v| update_value(v, func_idx, juno_func_idx)) - .collect::<Option<Vec<_>>>() - .map(|values| Value::Tuple { values }) - } + Value::Tuple { values } => values + .into_iter() + .map(|v| update_value(v, func_idx, juno_func_idx)) + .collect::<Option<Vec<_>>>() + .map(|values| Value::Tuple { values }), Value::JunoFunction { func } => { juno_func_idx[func.idx] .clone() @@ -2963,7 +2967,9 @@ fn run_pass( if loops != (0..fork_count).collect() { return Err(SchedulerError::PassError { pass: "fork-reshape".to_string(), - error: "expected forks to be numbered sequentially from 0 and used exactly once".to_string(), + error: + "expected forks to be numbered sequentially from 0 and used exactly once" + .to_string(), }); } @@ -3005,12 +3011,19 @@ fn run_pass( // We determine the loops (ordered top-down) that are contained in the selection // (in particular the header is in the selection) and its a fork-join (the header // is a fork) - let mut loops = loops.bottom_up_loops().into_iter().rev() + let mut loops = loops + .bottom_up_loops() + .into_iter() + .rev() .filter(|(header, _)| nodes.contains(header) && editor.node(header).is_fork()); let Some((top_fork_head, top_fork_body)) = loops.next() else { return Err(SchedulerError::PassError { pass: "fork-reshape".to_string(), - error: format!("expected {} forks found 0 in {}", fork_count, editor.func().name), + error: format!( + "expected {} forks found 0 in {}", + fork_count, + editor.func().name + ), }); }; // All the remaining forks need to be contained in the top fork body @@ -3031,7 +3044,10 @@ fn run_pass( if num_dims != fork_count { return Err(SchedulerError::PassError { pass: "fork-reshape".to_string(), - error: format!("expected {} forks, found {} in {}", fork_count, num_dims, pm.functions[func].name), + error: format!( + "expected {} forks, found {} in {}", + fork_count, num_dims, pm.functions[func].name + ), }); } @@ -3040,7 +3056,9 @@ fn run_pass( let top_fork = forks.next().unwrap(); let mut cur_fork = top_fork; for next_fork in forks { - let Some((new_fork, new_join)) = fork_coalesce_helper(&mut editor, cur_fork, next_fork, &fork_join_map) else { + let Some((new_fork, new_join)) = + fork_coalesce_helper(&mut editor, cur_fork, next_fork, &fork_join_map) + else { return Err(SchedulerError::PassError { pass: "fork-reshape".to_string(), error: "failed to coalesce forks".to_string(), @@ -3068,7 +3086,8 @@ fn run_pass( continue; } assert!(idx < cur_idx); - let Some(fork_res) = fork_interchange(&mut editor, cur_fork, join, idx, cur_idx) else { + let Some(fork_res) = fork_interchange(&mut editor, cur_fork, join, idx, cur_idx) + else { return Err(SchedulerError::PassError { pass: "fork-reshape".to_string(), error: "failed to interchange forks".to_string(), @@ -3103,7 +3122,9 @@ fn run_pass( let next_fork = forks[fork_idx + i]; // Again, not sure at this point how coalesce could fail, so panic if it // does - let (new_fork, new_join) = fork_coalesce_helper(&mut editor, cur_fork, next_fork, &fork_join_map).unwrap(); + let (new_fork, new_join) = + fork_coalesce_helper(&mut editor, cur_fork, next_fork, &fork_join_map) + .unwrap(); cur_fork = new_fork; fork_join_map.insert(new_fork, new_join); } @@ -3118,13 +3139,13 @@ fn run_pass( pm.clear_analyses(); pm.make_def_uses(); pm.make_nodes_in_fork_joins(); - + let def_uses = pm.def_uses.take().unwrap(); let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap(); let def_use = &def_uses[func]; let nodes_in_fork_joins = &nodes_in_fork_joins[func]; - + let mut editor = FunctionEditor::new( &mut pm.functions[func], func_id, @@ -3135,10 +3156,20 @@ fn run_pass( def_use, ); - let labels = create_labels_for_node_sets(&mut editor, final_forks.into_iter().map(|fork| nodes_in_fork_joins[&fork].iter().copied())) - .into_iter() - .map(|(_, label)| Value::Label { labels: vec![LabelInfo { func: func_id, label }] }) - .collect(); + let labels = create_labels_for_node_sets( + &mut editor, + final_forks + .into_iter() + .map(|fork| nodes_in_fork_joins[&fork].iter().copied()), + ) + .into_iter() + .map(|(_, label)| Value::Label { + labels: vec![LabelInfo { + func: func_id, + label, + }], + }) + .collect(); result = Value::Tuple { values: labels }; changed = true; -- GitLab