diff --git a/hercules_opt/src/editor.rs b/hercules_opt/src/editor.rs
index 0e332a0033c50585160242f04bfdcceb37f87ad5..6f0fdf4dcb04e4e9d5adfec80053be3ecfc2b08d 100644
--- a/hercules_opt/src/editor.rs
+++ b/hercules_opt/src/editor.rs
@@ -457,7 +457,7 @@ impl<'a, 'b> FunctionEdit<'a, 'b> {
     pub fn add_node(&mut self, node: Node) -> NodeID {
         let id = NodeID::new(self.num_node_ids());
         // Added nodes need to have an entry in the def-use map.
-        self.updated_def_use.insert(id, HashSet::new());
+        self.updated_def_use.entry(id).or_insert(HashSet::new());
         // Added nodes use other nodes, and we need to update their def-use
         // entries.
         for u in get_uses(&node).as_ref() {
diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index ff0f0283767996914e8f9b2274ed9a6d538b1812..e6db0345def31324243cdee2bdcb6b5cca5d9a7b 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -578,7 +578,7 @@ pub fn fork_coalesce(
     // FIXME: This could give us two forks that aren't actually ancestors / related, but then the helper will just return false early.
     // something like: `fork_joins.postorder_iter().windows(2)` is ideal here.
     for (inner, outer) in fork_joins.iter().cartesian_product(fork_joins.iter()) {
-        if fork_coalesce_helper(editor, *outer, *inner, fork_join_map) {
+        if fork_coalesce_helper(editor, *outer, *inner, fork_join_map).is_some() {
             return true;
         }
     }
@@ -587,13 +587,15 @@
 
 /** Opposite of fork split, takes two fork-joins with no control between them,
  and merges them into a single fork-join.
+ Returns None if the forks could not be merged, or the NodeIDs of the
+ resulting fork and join if it succeeds in merging them.
 */
 pub fn fork_coalesce_helper(
     editor: &mut FunctionEditor,
     outer_fork: NodeID,
     inner_fork: NodeID,
     fork_join_map: &HashMap<NodeID, NodeID>,
-) -> bool {
+) -> Option<(NodeID, NodeID)> {
     // Check that all reduces in the outer fork are in *simple* cycles with a unique reduce of the inner fork.
 
     let outer_join = fork_join_map[&outer_fork];
@@ -621,47 +623,35 @@ pub fn fork_coalesce_helper(
             reduct: _,
         } = inner_reduce_node
         else {
-            return false;
+            return None;
         };
 
         // FIXME: check this condition better (i.e reduce might not be attached to join)
         if *inner_control != inner_join {
-            return false;
+            return None;
         };
         if *inner_init != outer_reduce {
-            return false;
+            return None;
         };
 
         if pairs.contains_left(&outer_reduce) || pairs.contains_right(&inner_reduce) {
-            return false;
+            return None;
         } else {
             pairs.insert(outer_reduce, inner_reduce);
         }
     }
 
     // Check for control between join-join and fork-fork
-    let Some(user) = editor
-        .get_users(outer_fork)
-        .filter(|node| editor.func().nodes[node.idx()].is_control())
-        .next()
-    else {
-        return false;
-    };
+    let (control, _) = editor.node(inner_fork).try_fork().unwrap();
 
-    if user != inner_fork {
-        return false;
+    if control != outer_fork {
+        return None;
     }
 
-    let Some(user) = editor
-        .get_users(inner_join)
-        .filter(|node| editor.func().nodes[node.idx()].is_control())
-        .next()
-    else {
-        return false;
-    };
+    let control = editor.node(outer_join).try_join().unwrap();
 
-    if user != outer_join {
-        return false;
+    if control != inner_join {
+        return None;
     }
 
     // Checklist:
@@ -686,46 +676,47 @@
     // CHECKME / FIXME: Might need to be added the other way.
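+    // The merged fork's factors are the outer fork's factors followed by the
+    // inner fork's, which is why the inner thread IDs below have their
+    // dimension shifted up by num_outer_dims.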
     new_factors.append(&mut inner_dims.to_vec());
-    for tid in inner_tids {
-        let (fork, dim) = editor.func().nodes[tid.idx()].try_thread_id().unwrap();
-        let new_tid = Node::ThreadID {
-            control: fork,
-            dimension: dim + num_outer_dims,
-        };
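+    // new_fork starts as a placeholder NodeID; it is overwritten inside the
+    // edit closure below and only meaningful if the edit succeeds.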
+    let mut new_fork = NodeID::new(0);
+    let new_join = inner_join; // We'll reuse the inner join as the join of the new fork
+
+    let success = editor.edit(|mut edit| {
+        for tid in inner_tids {
+            let (fork, dim) = edit.get_node(tid).try_thread_id().unwrap();
+            let new_tid = Node::ThreadID {
+                control: fork,
+                dimension: dim + num_outer_dims,
+            };
 
-        editor.edit(|mut edit| {
             let new_tid = edit.add_node(new_tid);
-            let mut edit = edit.replace_all_uses(tid, new_tid)?;
+            edit = edit.replace_all_uses(tid, new_tid)?;
             edit.sub_edit(tid, new_tid);
 
-            Ok(edit)
-        });
-    }
-
-    // Fuse Reductions
-    for (outer_reduce, inner_reduce) in pairs {
-        let (_, outer_init, _) = editor.func().nodes[outer_reduce.idx()]
-            .try_reduce()
-            .unwrap();
-        let (_, inner_init, _) = editor.func().nodes[inner_reduce.idx()]
-            .try_reduce()
-            .unwrap();
-        editor.edit(|mut edit| {
+        }
+
+        // Fuse Reductions
+        for (outer_reduce, inner_reduce) in pairs {
+            let (_, outer_init, _) = edit.get_node(outer_reduce).try_reduce().unwrap();
+            let (_, inner_init, _) = edit.get_node(inner_reduce).try_reduce().unwrap();
             // Set inner init to outer init.
             edit =
                 edit.replace_all_uses_where(inner_init, outer_init, |usee| *usee == inner_reduce)?;
             edit = edit.replace_all_uses(outer_reduce, inner_reduce)?;
             edit = edit.delete_node(outer_reduce)?;
+        }
 
-            Ok(edit)
-        });
-    }
-
-    editor.edit(|mut edit| {
-        let new_fork = Node::Fork {
+        let new_fork_node = Node::Fork {
             control: outer_pred,
             factors: new_factors.into(),
         };
-        let new_fork = edit.add_node(new_fork);
+        new_fork = edit.add_node(new_fork_node);
+
+        if edit
+            .get_schedule(outer_fork)
+            .contains(&Schedule::ParallelFork)
+            && edit
+                .get_schedule(inner_fork)
+                .contains(&Schedule::ParallelFork)
+        {
+            edit = edit.add_schedule(new_fork, Schedule::ParallelFork)?;
+        }
 
         edit = edit.replace_all_uses(inner_fork, new_fork)?;
         edit = edit.replace_all_uses(outer_fork, new_fork)?;
@@ -737,7 +728,11 @@
         Ok(edit)
     });
 
-    true
+    if success {
+        Some((new_fork, new_join))
+    } else {
+        None
+    }
 }
 
 pub fn split_any_fork(
@@ -760,7 +755,7 @@
  * Useful for code generation. A single iteration of `fork_split` only splits
  * at most one fork-join, it must be called repeatedly to split all fork-joins.
  */
-pub(crate) fn split_fork(
+pub fn split_fork(
     editor: &mut FunctionEditor,
     fork: NodeID,
     join: NodeID,
@@ -1215,13 +1210,13 @@ pub fn fork_interchange_all_forks(
     }
 }
 
-fn fork_interchange(
+pub fn fork_interchange(
     editor: &mut FunctionEditor,
     fork: NodeID,
     join: NodeID,
     first_dim: usize,
     second_dim: usize,
-) {
+) -> Option<NodeID> {
     // Check that every reduce on the join is parallel or associative.
     let nodes = &editor.func().nodes;
     let schedules = &editor.func().schedules;
@@ -1234,7 +1229,7 @@ fn fork_interchange(
     }) {
         // If not, we can't necessarily do interchange.
-        return;
+        return None;
     }
 
     let Node::Fork {
@@ -1276,6 +1271,7 @@
     let mut factors = factors.clone();
     factors.swap(first_dim, second_dim);
     let new_fork = Node::Fork { control, factors };
+    let mut new_fork_id = None;
     editor.edit(|mut edit| {
         for (old_id, new_tid) in fix_tids {
             let new_id = edit.add_node(new_tid);
@@ -1283,9 +1279,17 @@
             edit = edit.delete_node(old_id)?;
         }
         let new_fork = edit.add_node(new_fork);
+        if edit.get_schedule(fork).contains(&Schedule::ParallelFork) {
+            edit = edit.add_schedule(new_fork, Schedule::ParallelFork)?;
+        }
         edit = edit.replace_all_uses(fork, new_fork)?;
-        edit.delete_node(fork)
+        edit = edit.delete_node(fork)?;
+
+        new_fork_id = Some(new_fork);
+        Ok(edit)
     });
+
+    new_fork_id
 }
 
 /*
diff --git a/hercules_samples/matmul/src/main.rs b/hercules_samples/matmul/src/main.rs
index 277276648e905186bfeb54714fb00f7275f17b22..00e5b873e2061f98b911876900890deb8b3abcef 100644
--- a/hercules_samples/matmul/src/main.rs
+++ b/hercules_samples/matmul/src/main.rs
@@ -25,7 +25,8 @@ fn main() {
         let a = HerculesImmBox::from(a.as_ref());
         let b = HerculesImmBox::from(b.as_ref());
         let mut r = runner!(matmul);
-        let mut c: HerculesMutBox<i32> = HerculesMutBox::from(r.run(I as u64, J as u64, K as u64, a.to(), b.to()).await);
+        let mut c: HerculesMutBox<i32> =
+            HerculesMutBox::from(r.run(I as u64, J as u64, K as u64, a.to(), b.to()).await);
         assert_eq!(c.as_slice(), correct_c.as_ref());
     });
 }
diff --git a/juno_samples/matmul/build.rs b/juno_samples/matmul/build.rs
index 0be838c620761e8726590e2dbaf7bfdb7a82e3df..d2813388e0e7a1d7bd1696ffbb641e629096e2c2 100644
--- a/juno_samples/matmul/build.rs
+++ b/juno_samples/matmul/build.rs
@@ -6,6 +6,8 @@ fn main() {
     JunoCompiler::new()
         .file_in_src("matmul.jn")
         .unwrap()
+        .schedule_in_src("cpu.sch")
+        .unwrap()
         .build()
         .unwrap();
 }
diff --git a/juno_samples/matmul/src/cpu.sch b/juno_samples/matmul/src/cpu.sch
new file mode 100644
index 0000000000000000000000000000000000000000..69f1811d08a5691de6dc10fa86f0720e416fb85a
--- /dev/null
+++ b/juno_samples/matmul/src/cpu.sch
@@ -0,0 +1,61 @@
+macro optimize!(X) {
+  gvn(X);
+  phi-elim(X);
+  dce(X);
+  ip-sroa(X);
+  sroa(X);
+  dce(X);
+  gvn(X);
+  phi-elim(X);
+  dce(X);
+}
+
+macro codegen-prep!(X) {
+  optimize!(X);
+  gcm(X);
+  float-collections(X);
+  dce(X);
+  gcm(X);
+}
+
+macro forkify!(X) {
+  fixpoint {
+    forkify(X);
+    fork-guard-elim(X);
+  }
+}
+
+macro fork-tile!(X, n) {
+  fork-tile[n, 0, false, true](X);
+}
+
+macro parallelize!(X) {
+  parallel-fork(X);
+  parallel-reduce(X);
+}
+
+macro unforkify!(X) {
+  fork-split(X);
+  unforkify(X);
+}
+
+optimize!(*);
+forkify!(*);
+associative(matmul@outer);
+
+// Parallelize by computing output array as 16 chunks
+let par = matmul@outer \ matmul@inner;
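+// fork-reshape coalesces the tiled forks into a single fork, interchanges its
+// dimensions into the order given by the index groups, splits it back into one
+// fork-join per group, and returns a label per resulting fork-join; the
+// tuple-let binds those labels.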
+fork-tile!(par, 16);
+let (outer, inner, _) = fork-reshape[[1, 3], [0], [2]](par);
+parallelize!(outer \ inner);
+
+let body = outline(inner);
+cpu(body);
+
+// Tile for cache, assuming 64B cache lines
+fork-tile!(body, 16);
+let (outer, inner) = fork-reshape[[0, 2, 4, 1, 3], [5]](body);
+
+reduce-slf(inner);
+unforkify!(body);
+codegen-prep!(*);
diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs
index bd27350a26d58f4e729b24d0026f12cd13ca7195..9d020c64ccef3b9c0a79694876c5b0ace606f938 100644
--- a/juno_scheduler/src/compile.rs
+++ b/juno_scheduler/src/compile.rs
@@ -135,6 +135,7 @@ impl FromStr for Appliable {
             "fork-extend" => Ok(Appliable::Pass(ir::Pass::ForkExtend)),
             "fork-unroll" | "unroll" => Ok(Appliable::Pass(ir::Pass::ForkUnroll)),
             "fork-fusion" | "fusion" => Ok(Appliable::Pass(ir::Pass::ForkFusion)),
+            "fork-reshape" => Ok(Appliable::Pass(ir::Pass::ForkReshape)),
             "lift-dc-math" => Ok(Appliable::Pass(ir::Pass::LiftDCMath)),
             "loop-bound-canon" => Ok(Appliable::Pass(ir::Pass::LoopBoundCanon)),
             "outline" => Ok(Appliable::Pass(ir::Pass::Outline)),
@@ -207,6 +208,28 @@ fn compile_stmt(
                 exp: compile_exp_as_expr(expr, lexer, macrostab, macros)?,
             }])
         }
+        parser::Stmt::LetsStmt {
+            span: _,
+            vars,
+            expr,
+        } => {
+            let tmp = format!("{}_tmp", macros.uniq());
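+            // Desugar `let (a, b) = e;` into `let tmp = e; let a = tmp.0;
+            // let b = tmp.1;`, using a fresh temporary so names cannot clash.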
+            Ok(std::iter::once(ir::ScheduleStmt::Let {
+                var: tmp.clone(),
+                exp: compile_exp_as_expr(expr, lexer, macrostab, macros)?,
+            })
+            .chain(vars.into_iter().enumerate().map(|(idx, v)| {
+                let var = lexer.span_str(v).to_string();
+                ir::ScheduleStmt::Let {
+                    var,
+                    exp: ir::ScheduleExp::TupleField {
+                        lhs: Box::new(ir::ScheduleExp::Variable { var: tmp.clone() }),
+                        field: idx,
+                    },
+                }
+            }))
+            .collect())
+        }
         parser::Stmt::AssignStmt { span: _, var, rhs } => {
             let var = lexer.span_str(var).to_string();
             Ok(vec![ir::ScheduleStmt::Assign {
@@ -489,6 +512,29 @@ fn compile_expr(
                 rhs: Box::new(rhs),
             }))
         }
+        parser::Expr::Tuple { span: _, exps } => {
+            let exprs = exps
+                .into_iter()
+                .map(|e| compile_exp_as_expr(e, lexer, macrostab, macros))
+                .fold(Ok(vec![]), |res, exp| {
+                    let mut res = res?;
+                    res.push(exp?);
+                    Ok(res)
+                })?;
+            Ok(ExprResult::Expr(ir::ScheduleExp::Tuple { exprs }))
+        }
+        parser::Expr::TupleField {
+            span: _,
+            lhs,
+            field,
+        } => {
+            let lhs = compile_exp_as_expr(*lhs, lexer, macrostab, macros)?;
+            let field = lexer.span_str(field).parse().expect("Parsing");
+            Ok(ExprResult::Expr(ir::ScheduleExp::TupleField {
+                lhs: Box::new(lhs),
+                field,
+            }))
+        }
     }
 }
diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs
index 6aa85fe53689cf015497e56850ef0c197ccbdae0..ab1495b816c99452560d03c0addf77a5aec18974 100644
--- a/juno_scheduler/src/ir.rs
+++ b/juno_scheduler/src/ir.rs
@@ -21,6 +21,7 @@ pub enum Pass {
     ForkFusion,
     ForkGuardElim,
     ForkInterchange,
+    ForkReshape,
     ForkSplit,
     ForkUnroll,
     Forkify,
@@ -59,6 +60,7 @@ impl Pass {
             Pass::ForkExtend => num == 1,
             Pass::ForkFissionBufferize => num == 2 || num == 1,
             Pass::ForkInterchange => num == 2,
+            Pass::ForkReshape => true,
             Pass::InterproceduralSROA => num == 0 || num == 1,
             Pass::Print => num == 1,
             Pass::Rename => num == 1,
@@ -76,6 +78,7 @@ impl Pass {
             Pass::ForkExtend => "1",
             Pass::ForkFissionBufferize => "1 or 2",
             Pass::ForkInterchange => "2",
+            Pass::ForkReshape => "any",
             Pass::InterproceduralSROA => "0 or 1",
             Pass::Print => "1",
             Pass::Rename => "1",
@@ -130,6 +133,13 @@ pub enum ScheduleExp {
         lhs: Box<ScheduleExp>,
         rhs: Box<ScheduleExp>,
     },
+    Tuple {
+        exprs: Vec<ScheduleExp>,
+    },
+    TupleField {
+        lhs: Box<ScheduleExp>,
+        field: usize,
+    },
     // This is used to "box" a selection by evaluating it at one point and then
     // allowing it to be used as a selector later on
     Selection {
diff --git a/juno_scheduler/src/lang.y b/juno_scheduler/src/lang.y
index 3b030e1d42bdb970cdfa67d21c4198dc89edea9e..451f035b8122ac1a694567327a44776c9326fff2 100644
--- a/juno_scheduler/src/lang.y
+++ b/juno_scheduler/src/lang.y
@@ -19,6 +19,8 @@ Schedule -> OperationList
 Stmt -> Stmt
   : 'let' 'ID' '=' Expr ';'
       { Stmt::LetStmt { span: $span, var: span_of_tok($2), expr: $4 } }
+  | 'let' '(' Ids ')' '=' Expr ';'
+      { Stmt::LetsStmt { span: $span, vars: rev($3), expr: $6 } }
   | 'ID' '=' Expr ';'
       { Stmt::AssignStmt { span: $span, var: span_of_tok($1), rhs: $3 } }
   | Expr ';'
@@ -56,10 +58,14 @@ Expr -> Expr
       { Expr::String { span: $span } }
   | Expr '.' 'ID'
       { Expr::Field { span: $span, lhs: Box::new($1), field: span_of_tok($3) } }
+  | Expr '.' 'INT'
+      { Expr::TupleField { span: $span, lhs: Box::new($1), field: span_of_tok($3) } }
   | Expr '@' 'ID'
       { Expr::Field { span: $span, lhs: Box::new($1), field: span_of_tok($3) } }
-  | '(' Expr ')'
-      { $2 }
+  | '(' Exprs ')'
+      { Expr::Tuple { span: $span, exps: $2 } }
+  | '[' Exprs ']'
+      { Expr::Tuple { span: $span, exps: $2 } }
   | '{' Schedule '}'
       { Expr::BlockExpr { span: $span, body: Box::new($2) } }
   | '<' Fields '>'
@@ -73,14 +79,18 @@ Expr -> Expr
   ;
 
 Args -> Vec<Expr>
-  :                  { vec![] }
-  | '[' Exprs ']'    { rev($2) }
+  :                  { vec![] }
+  | '[' RExprs ']'   { rev($2) }
   ;
 
 Exprs -> Vec<Expr>
-  :                  { vec![] }
-  | Expr             { vec![$1] }
-  | Expr ',' Exprs   { snoc($1, $3) }
+  : RExprs           { rev($1) }
+  ;
+
+RExprs -> Vec<Expr>
+  :                  { vec![] }
+  | Expr             { vec![$1] }
+  | Expr ',' RExprs  { snoc($1, $3) }
   ;
 
 Fields -> Vec<(Span, Expr)>
@@ -149,6 +159,7 @@ pub enum OperationList {
 
 pub enum Stmt {
   LetStmt { span: Span, var: Span, expr: Expr },
+  LetsStmt { span: Span, vars: Vec<Span>, expr: Expr },
   AssignStmt { span: Span, var: Span, rhs: Expr },
   ExprStmt { span: Span, exp: Expr },
   Fixpoint { span: Span, limit: FixpointLimit, body: Box<OperationList> },
@@ -180,6 +191,8 @@ pub enum Expr {
   BlockExpr { span: Span, body: Box<OperationList> },
   Record { span: Span, fields: Vec<(Span, Expr)> },
   SetOp { span: Span, op: SetOp, lhs: Box<Expr>, rhs: Box<Expr> },
+  Tuple { span: Span, exps: Vec<Expr> },
+  TupleField { span: Span, lhs: Box<Expr>, field: Span },
 }
 
 pub enum Selector {
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 70d8e4278169ebdbe9985e00ede161acbe05c24d..456df2eda49b93a6c80327a090b6f6606ae711bb 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -294,6 +294,9 @@ pub enum Value {
     Record {
         fields: HashMap<String, Value>,
     },
+    Tuple {
+        values: Vec<Value>,
+    },
     Everything {},
     Selection {
         selection: Vec<Value>,
@@ -371,6 +374,11 @@ impl Value {
                     "Expected code selection, found record".to_string(),
                 ));
             }
+            Value::Tuple { .. } => {
+                return Err(SchedulerError::SemanticError(
+                    "Expected code selection, found tuple".to_string(),
+                ));
+            }
             Value::Integer { .. } => {
                 return Err(SchedulerError::SemanticError(
                     "Expected code selection, found integer".to_string(),
@@ -1294,6 +1302,7 @@ fn interp_expr(
                 | Value::Integer { .. }
                 | Value::Boolean { .. }
                 | Value::String { .. }
+                | Value::Tuple { .. }
                 | Value::SetOp { .. } => Err(SchedulerError::UndefinedField(field.clone())),
                 Value::JunoFunction { func } => {
                     match pm.labels.borrow().iter().position(|s| s == field) {
@@ -1467,6 +1476,28 @@ fn interp_expr(
                 Ok((Value::Selection { selection: values }, changed))
             }
         },
+        ScheduleExp::Tuple { exprs } => {
+            let mut vals = vec![];
+            let mut changed = false;
+            for exp in exprs {
+                let (val, change) = interp_expr(pm, exp, stringtab, env, functions)?;
+                vals.push(val);
+                changed = changed || change;
+            }
+            Ok((Value::Tuple { values: vals }, changed))
+        }
+        ScheduleExp::TupleField { lhs, field } => {
+            let (val, changed) = interp_expr(pm, lhs, stringtab, env, functions)?;
+            match val {
+                Value::Tuple { values } if *field < values.len() => {
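+                    // vec_take moves the value out of the tuple without
+                    // cloning; the rest of the tuple is dropped.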
+                    Ok((vec_take(values, *field), changed))
+                }
+                _ => Err(SchedulerError::SemanticError(format!(
+                    "No field at index {}",
+                    field
+                ))),
+            }
+        }
     }
 }
@@ -1524,6 +1555,14 @@ fn update_value(
                 Some(Value::Record { fields: new_fields })
             }
         }
+        // For tuples, if we deleted values like we do for records, it would mess up
+        // the indices, which would behave very strangely. Instead, if any element
+        // cannot be updated, we eliminate the entire value.
+        Value::Tuple { values } => values
+            .into_iter()
+            .map(|v| update_value(v, func_idx, juno_func_idx))
+            .collect::<Option<Vec<_>>>()
+            .map(|values| Value::Tuple { values }),
         Value::JunoFunction { func } => {
             juno_func_idx[func.idx]
                 .clone()
@@ -2897,6 +2936,247 @@ fn run_pass(
             pm.delete_gravestones();
             pm.clear_analyses();
         }
+        Pass::ForkReshape => {
+            let mut shape = vec![];
+            let mut loops = BTreeSet::new();
+            let mut fork_count = 0;
+
+            for arg in args {
+                let Value::Tuple { values } = arg else {
+                    return Err(SchedulerError::PassError {
+                        pass: "fork-reshape".to_string(),
+                        error: "expected each argument to be a list of integers".to_string(),
+                    });
+                };
+
+                let mut indices = vec![];
+                for val in values {
+                    let Value::Integer { val: idx } = val else {
+                        return Err(SchedulerError::PassError {
+                            pass: "fork-reshape".to_string(),
+                            error: "expected each argument to be a list of integers"
+                                .to_string(),
+                        });
+                    };
+                    indices.push(idx);
+                    loops.insert(idx);
+                    fork_count += 1;
+                }
+                shape.push(indices);
+            }
+
+            if loops != (0..fork_count).collect() {
+                return Err(SchedulerError::PassError {
+                    pass: "fork-reshape".to_string(),
+                    error:
+                        "expected forks to be numbered sequentially from 0 and used exactly once"
+                            .to_string(),
+                });
+            }
+
+            let Some((nodes, func_id)) = selection_as_set(pm, selection) else {
+                return Err(SchedulerError::PassError {
+                    pass: "fork-reshape".to_string(),
+                    error: "must be applied to nodes in a single function".to_string(),
+                });
+            };
+            let func = func_id.idx();
+
+            pm.make_def_uses();
+            pm.make_fork_join_maps();
+            pm.make_loops();
+            pm.make_reduce_cycles();
+
+            let def_uses = pm.def_uses.take().unwrap();
+            let mut fork_join_maps = pm.fork_join_maps.take().unwrap();
+            let loops = pm.loops.take().unwrap();
+            let reduce_cycles = pm.reduce_cycles.take().unwrap();
+
+            let def_use = &def_uses[func];
+            let fork_join_map = &mut fork_join_maps[func];
+            let loops = &loops[func];
+            let reduce_cycles = &reduce_cycles[func];
+
+            let mut editor = FunctionEditor::new(
+                &mut pm.functions[func],
+                func_id,
+                &pm.constants,
+                &pm.dynamic_constants,
+                &pm.types,
+                &pm.labels,
+                def_use,
+            );
+
+            // There should be exactly one fork nest in the selection, and it should
+            // contain exactly fork_count forks (counting each dimension of each fork).
+            // We determine the loops (ordered top-down) whose header is contained in
+            // the selection and is a fork, i.e. the loop is a fork-join.
+            let mut loops = loops
+                .bottom_up_loops()
+                .into_iter()
+                .rev()
+                .filter(|(header, _)| nodes.contains(header) && editor.node(header).is_fork());
+            let Some((top_fork_head, top_fork_body)) = loops.next() else {
+                return Err(SchedulerError::PassError {
+                    pass: "fork-reshape".to_string(),
+                    error: format!(
+                        "expected {} forks, found 0 in {}",
+                        fork_count,
+                        editor.func().name
+                    ),
+                });
+            };
+            // All the remaining forks need to be contained in the top fork body
+            let mut forks = vec![top_fork_head];
+            let mut num_dims = editor.node(top_fork_head).try_fork().unwrap().1.len();
+            for (head, _) in loops {
+                if !top_fork_body[head.idx()] {
+                    return Err(SchedulerError::PassError {
+                        pass: "fork-reshape".to_string(),
+                        error: "selection includes multiple non-nested forks".to_string(),
+                    });
+                } else {
+                    forks.push(head);
+                    num_dims += editor.node(head).try_fork().unwrap().1.len();
+                }
+            }
+
+            if num_dims != fork_count {
+                return Err(SchedulerError::PassError {
+                    pass: "fork-reshape".to_string(),
+                    error: format!(
+                        "expected {} forks, found {} in {}",
+                        fork_count, num_dims, pm.functions[func].name
+                    ),
+                });
+            }
+
+            // Now, we coalesce all of these forks into one so that we can interchange them
+            let mut forks = forks.into_iter();
+            let top_fork = forks.next().unwrap();
+            let mut cur_fork = top_fork;
+            for next_fork in forks {
+                let Some((new_fork, new_join)) =
+                    fork_coalesce_helper(&mut editor, cur_fork, next_fork, &fork_join_map)
+                else {
+                    return Err(SchedulerError::PassError {
+                        pass: "fork-reshape".to_string(),
+                        error: "failed to coalesce forks".to_string(),
+                    });
+                };
+                cur_fork = new_fork;
+                fork_join_map.insert(new_fork, new_join);
+            }
+            let join = *fork_join_map.get(&cur_fork).unwrap();
+
+            // Now we have just one fork and we can perform the interchanges we need.
+            // To do this, we track two maps: from original index to current index and
+            // from current index to original index.
+            let mut orig_to_cur = (0..fork_count).collect::<Vec<_>>();
+            let mut cur_to_orig = (0..fork_count).collect::<Vec<_>>();
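+            // Example: for shape [[1], [0]], step 0 swaps dims 0 and 1 to move
+            // original dim 1 outermost (orig_to_cur = cur_to_orig = [1, 0]);
+            // step 1 then finds original dim 0 already in place.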
+
+            // Now, starting from the first (outermost) index we move the desired
+            // fork bound into place
+            for (idx, original_idx) in shape.iter().flat_map(|chunk| chunk.iter()).enumerate() {
+                let cur_idx = orig_to_cur[*original_idx];
+                let swapping = cur_to_orig[idx];
+
+                // If the desired factor is already in the correct place, do nothing
+                if cur_idx == idx {
+                    continue;
+                }
+                assert!(idx < cur_idx);
+                let Some(fork_res) = fork_interchange(&mut editor, cur_fork, join, idx, cur_idx)
+                else {
+                    return Err(SchedulerError::PassError {
+                        pass: "fork-reshape".to_string(),
+                        error: "failed to interchange forks".to_string(),
+                    });
+                };
+                cur_fork = fork_res;
+
+                // Update our maps
+                orig_to_cur[*original_idx] = idx;
+                orig_to_cur[swapping] = cur_idx;
+                cur_to_orig[idx] = *original_idx;
+                cur_to_orig[cur_idx] = swapping;
+            }
+
+            // Finally we split the fork into the desired pieces. We do this by first
+            // splitting the fork into individual forks and then coalescing the chunks
+            // back together.
+            // Not sure how split_fork could fail, so if it does, panicking is fine.
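+            // For example, with shape [[1, 3], [0], [2]] the single 4-dim fork
+            // is split into four 1-dim forks, which are then coalesced back
+            // into a 2-dim fork, a 1-dim fork, and a 1-dim fork.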
+            let (forks, joins) = split_fork(&mut editor, cur_fork, join, &reduce_cycles).unwrap();
+
+            for (fork, join) in forks.iter().zip(joins.iter()) {
+                fork_join_map.insert(*fork, *join);
+            }
+
+            // Finally coalesce the chunks together
+            let mut fork_idx = 0;
+            let mut final_forks = vec![];
+            for chunk in shape.iter() {
+                let chunk_len = chunk.len();
+
+                let mut cur_fork = forks[fork_idx];
+                for i in 1..chunk_len {
+                    let next_fork = forks[fork_idx + i];
+                    // Again, not sure at this point how coalesce could fail, so panic
+                    // if it does
+                    let (new_fork, new_join) =
+                        fork_coalesce_helper(&mut editor, cur_fork, next_fork, &fork_join_map)
+                            .unwrap();
+                    cur_fork = new_fork;
+                    fork_join_map.insert(new_fork, new_join);
+                }
+
+                fork_idx += chunk_len;
+                final_forks.push(cur_fork);
+            }
+
+            // Label each fork and return the labels.
+            // We've trashed our analyses at this point, so rerun them so that we can
+            // determine the nodes in each of the resulting fork-joins.
+            pm.clear_analyses();
+            pm.make_def_uses();
+            pm.make_nodes_in_fork_joins();
+
+            let def_uses = pm.def_uses.take().unwrap();
+            let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap();
+
+            let def_use = &def_uses[func];
+            let nodes_in_fork_joins = &nodes_in_fork_joins[func];
+
+            let mut editor = FunctionEditor::new(
+                &mut pm.functions[func],
+                func_id,
+                &pm.constants,
+                &pm.dynamic_constants,
+                &pm.types,
+                &pm.labels,
+                def_use,
+            );
+
+            let labels = create_labels_for_node_sets(
+                &mut editor,
+                final_forks
+                    .into_iter()
+                    .map(|fork| nodes_in_fork_joins[&fork].iter().copied()),
+            )
+            .into_iter()
+            .map(|(_, label)| Value::Label {
+                labels: vec![LabelInfo {
+                    func: func_id,
+                    label,
+                }],
+            })
+            .collect();
+
+            result = Value::Tuple { values: labels };
+            changed = true;
+
+            pm.delete_gravestones();
+            pm.clear_analyses();
+        }
         Pass::WritePredication => {
             assert!(args.is_empty());
             for func in build_selection(pm, selection, false) {
@@ -3047,3 +3327,7 @@ where
         });
     labels
 }
+
+fn vec_take<T>(mut v: Vec<T>, index: usize) -> T {
+    v.swap_remove(index)
+}
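+
+// swap_remove is O(1): it moves the last element into `index`; the reordering
+// is harmless because the caller discards the remaining values.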