Skip to content
Snippets Groups Projects
Commit 2d577f79 authored by Aaron Councilman's avatar Aaron Councilman
Browse files

Fork-reshape

parent 751ded74
No related branches found
No related tags found
1 merge request: !205 (Fork reshape)
Pipeline #201926 passed
This commit is part of merge request !205. Comments created here will be created in the context of that merge request.
......@@ -6,6 +6,8 @@ fn main() {
JunoCompiler::new()
.file_in_src("matmul.jn")
.unwrap()
.schedule_in_src("cpu.sch")
.unwrap()
.build()
.unwrap();
}
......
// General cleanup pipeline applied to the selection X.
// Runs, in order: gvn, phi-elim, dce; then ip-sroa and sroa followed by dce;
// then gvn, phi-elim and dce once more to clean up after the sroa passes.
macro optimize!(X) {
gvn(X);
phi-elim(X);
dce(X);
ip-sroa(X);
sroa(X);
dce(X);
gvn(X);
phi-elim(X);
dce(X);
}
// Final preparation of the selection X before code generation.
// Re-runs optimize!, then gcm, float-collections and dce, and finishes with a
// second gcm (presumably because float-collections/dce can invalidate the
// first gcm result -- TODO confirm).
macro codegen-prep!(X) {
optimize!(X);
gcm(X);
float-collections(X);
dce(X);
gcm(X);
}
// Repeatedly applies forkify followed by fork-guard-elim to the selection X
// inside a fixpoint block (presumably until neither pass makes further
// changes -- TODO confirm fixpoint semantics).
macro forkify!(X) {
fixpoint {
forkify(X);
fork-guard-elim(X);
}
}
// Shorthand for the fork-tile pass on selection X with tile parameter n and
// the remaining three arguments fixed to 0, false, true.
// NOTE(review): the meaning of the fixed arguments is defined by the
// fork-tile pass itself and is not visible in this file.
macro fork-tile![n](X) {
fork-tile[n, 0, false, true](X);
}
// Applies parallel-fork and then parallel-reduce to the selection X.
macro parallelize!(X) {
parallel-fork(X);
parallel-reduce(X);
}
// Applies fork-split and then unforkify to the selection X, so unforkify
// always runs on forks that have already been through fork-split.
macro unforkify!(X) {
fork-split(X);
unforkify(X);
}
// Schedule body: clean up and forkify everything, then reshape the matmul
// fork nest twice (once for parallelism, once for cache tiling) before
// lowering back out of fork form and preparing for code generation.
optimize!(*);
forkify!(*);
associative(matmul@outer);
// Parallelize by computing output array as 16 chunks
// `par` is everything labeled matmul@outer except what is labeled matmul@inner
// (presumably the fork dimensions over the output array -- TODO confirm).
let par = matmul@outer \ matmul@inner;
fork-tile![4](par);
// fork-reshape takes one index list per resulting fork; each index names one
// fork dimension of the selection (numbered from 0, each used exactly once).
// It returns a tuple with one label per list. Only the first two of the three
// resulting forks are bound here.
let res = fork-reshape[[1, 3], [0], [2]](par); let outer = res.0; let inner = res.1;
parallelize!(outer \ inner);
let body = outline(inner);
cpu(body);
// Tile for cache, assuming 64B cache lines
fork-split(body);
fork-tile![16](body);
// Regroup the six fork dimensions of `body` into two forks: dimensions
// 0, 2, 4, 1, 3 (in that order) form the first fork and dimension 5 the second.
let res = fork-reshape[[0, 2, 4, 1, 3], [5]](body); let outer = res.0; let inner = res.1;
reduce-slf(inner);
unforkify!(body);
codegen-prep!(*);
......@@ -134,6 +134,7 @@ impl FromStr for Appliable {
"fork-extend" => Ok(Appliable::Pass(ir::Pass::ForkExtend)),
"fork-unroll" | "unroll" => Ok(Appliable::Pass(ir::Pass::ForkUnroll)),
"fork-fusion" | "fusion" => Ok(Appliable::Pass(ir::Pass::ForkFusion)),
"fork-reshape" => Ok(Appliable::Pass(ir::Pass::ForkReshape)),
"lift-dc-math" => Ok(Appliable::Pass(ir::Pass::LiftDCMath)),
"loop-bound-canon" => Ok(Appliable::Pass(ir::Pass::LoopBoundCanon)),
"outline" => Ok(Appliable::Pass(ir::Pass::Outline)),
......
......@@ -20,6 +20,7 @@ pub enum Pass {
ForkFusion,
ForkGuardElim,
ForkInterchange,
ForkReshape,
ForkSplit,
ForkUnroll,
Forkify,
......@@ -57,6 +58,7 @@ impl Pass {
Pass::ForkExtend => num == 1,
Pass::ForkFissionBufferize => num == 2 || num == 1,
Pass::ForkInterchange => num == 2,
Pass::ForkReshape => true,
Pass::Print => num == 1,
Pass::Rename => num == 1,
Pass::SROA => num == 0 || num == 1,
......@@ -73,6 +75,7 @@ impl Pass {
Pass::ForkExtend => "1",
Pass::ForkFissionBufferize => "1 or 2",
Pass::ForkInterchange => "2",
Pass::ForkReshape => "any",
Pass::Print => "1",
Pass::Rename => "1",
Pass::SROA => "0 or 1",
......
......@@ -61,9 +61,9 @@ Expr -> Expr
| Expr '@' 'ID'
{ Expr::Field { span: $span, lhs: Box::new($1), field: span_of_tok($3) } }
| '(' Exprs ')'
{ exprs_to_expr($span, $2) }
{ Expr::Tuple { span: $span, exps: $2 } }
| '[' Exprs ']'
{ exprs_to_expr($span, $2) }
{ Expr::Tuple { span: $span, exps: $2 } }
| '{' Schedule '}'
{ Expr::BlockExpr { span: $span, body: Box::new($2) } }
| '<' Fields '>'
......@@ -203,11 +203,3 @@ pub struct MacroDecl {
pub selection_name: Span,
pub def: Box<OperationList>,
}
/// Collapses a parsed list of expressions into a single expression.
///
/// A list holding exactly one expression yields that expression unchanged;
/// any other length (including empty) is wrapped in an `Expr::Tuple` that
/// carries the given `span`.
fn exprs_to_expr(span: Span, exps: Vec<Expr>) -> Expr {
    // Guard clause: anything that is not a singleton becomes a tuple.
    if exps.len() != 1 {
        return Expr::Tuple { span, exps };
    }
    // Exactly one element is present, so taking the first cannot fail.
    exps.into_iter().next().unwrap()
}
......@@ -2901,6 +2901,220 @@ fn run_pass(
pm.delete_gravestones();
pm.clear_analyses();
}
Pass::ForkReshape => {
    // fork-reshape[[i, j, ...], [k, ...], ...](selection)
    //
    // Each argument is a non-empty list of fork-dimension indices. The
    // dimensions of the single fork nest in the selection are numbered from 0
    // (outermost first) and every dimension must appear exactly once across
    // all lists. The nest is rebuilt so that each list becomes one fork whose
    // dimensions are the listed ones, in the listed order; the forks appear in
    // the order of the lists. The result is a tuple with one label per list,
    // identifying the nodes of the corresponding fork-join.
    let mut shape = vec![];
    let mut used_indices = BTreeSet::new();
    let mut fork_count = 0;
    for arg in args {
        let Value::Tuple { values } = arg else {
            return Err(SchedulerError::PassError {
                pass: "fork-reshape".to_string(),
                error: "expected each argument to be a list of integers".to_string(),
            });
        };
        let mut indices = vec![];
        for val in values {
            let Value::Integer { val: idx } = val else {
                return Err(SchedulerError::PassError {
                    pass: "fork-reshape".to_string(),
                    error: "expected each argument to be a list of integers".to_string(),
                });
            };
            indices.push(idx);
            used_indices.insert(idx);
            fork_count += 1;
        }
        // An empty list would describe a fork with no dimensions. Without this
        // check the final coalescing step would either index past the end of
        // the split forks or hand back a fork that a later chunk coalesces
        // away, so report it as a user error instead of panicking.
        if indices.is_empty() {
            return Err(SchedulerError::PassError {
                pass: "fork-reshape".to_string(),
                error: "expected each argument to be a non-empty list of integers".to_string(),
            });
        }
        shape.push(indices);
    }

    // The set of indices must be exactly {0, ..., fork_count - 1}; since
    // fork_count counts every occurrence, this also rejects duplicates.
    if used_indices != (0..fork_count).collect() {
        return Err(SchedulerError::PassError {
            pass: "fork-reshape".to_string(),
            error: "expected forks to be numbered sequentially from 0 and used exactly once".to_string(),
        });
    }

    let Some((nodes, func_id)) = selection_as_set(pm, selection) else {
        return Err(SchedulerError::PassError {
            pass: "fork-reshape".to_string(),
            error: "must be applied to nodes in a single function".to_string(),
        });
    };
    let func = func_id.idx();

    pm.make_def_uses();
    pm.make_fork_join_maps();
    pm.make_loops();
    pm.make_reduce_cycles();

    let def_uses = pm.def_uses.take().unwrap();
    let mut fork_join_maps = pm.fork_join_maps.take().unwrap();
    let loops = pm.loops.take().unwrap();
    let reduce_cycles = pm.reduce_cycles.take().unwrap();

    let def_use = &def_uses[func];
    let fork_join_map = &mut fork_join_maps[func];
    let loops = &loops[func];
    let reduce_cycles = &reduce_cycles[func];

    let mut editor = FunctionEditor::new(
        &mut pm.functions[func],
        func_id,
        &pm.constants,
        &pm.dynamic_constants,
        &pm.types,
        &pm.labels,
        def_use,
    );

    // There should be exactly one fork nest in the selection and it should contain
    // exactly fork_count forks (counting each dimension of each fork)
    // We determine the loops (ordered top-down) that are contained in the selection
    // (in particular the header is in the selection) and that are fork-joins (the
    // header is a fork)
    let mut loops = loops.bottom_up_loops().into_iter().rev()
        .filter(|(header, _)| nodes.contains(header) && editor.node(header).is_fork());
    let Some((top_fork_head, top_fork_body)) = loops.next() else {
        return Err(SchedulerError::PassError {
            pass: "fork-reshape".to_string(),
            error: format!("expected {} forks found 0 in {}", fork_count, editor.func().name),
        });
    };
    // All the remaining forks need to be contained in the top fork body
    let mut forks = vec![top_fork_head];
    let mut num_dims = editor.node(top_fork_head).try_fork().unwrap().1.len();
    for (head, _) in loops {
        if !top_fork_body[head.idx()] {
            return Err(SchedulerError::PassError {
                pass: "fork-reshape".to_string(),
                error: "selection includes multiple non-nested forks".to_string(),
            });
        } else {
            forks.push(head);
            num_dims += editor.node(head).try_fork().unwrap().1.len();
        }
    }

    if num_dims != fork_count {
        return Err(SchedulerError::PassError {
            pass: "fork-reshape".to_string(),
            error: format!("expected {} forks, found {} in {}", fork_count, num_dims, editor.func().name),
        });
    }

    // Now, we coalesce all of these forks into one so that we can interchange them
    let mut forks = forks.into_iter();
    let top_fork = forks.next().unwrap();
    let mut cur_fork = top_fork;
    for next_fork in forks {
        let Some((new_fork, new_join)) = fork_coalesce_helper(&mut editor, cur_fork, next_fork, &fork_join_map) else {
            return Err(SchedulerError::PassError {
                pass: "fork-reshape".to_string(),
                error: "failed to coalesce forks".to_string(),
            });
        };
        cur_fork = new_fork;
        // Keep the map current so later coalesce steps can find the new join
        fork_join_map.insert(new_fork, new_join);
    }
    let join = *fork_join_map.get(&cur_fork).unwrap();

    // Now we have just one fork and we can perform the interchanges we need
    // To do this, we track two maps: from original index to current index and from
    // current index to original index
    let mut orig_to_cur = (0..fork_count).collect::<Vec<_>>();
    let mut cur_to_orig = (0..fork_count).collect::<Vec<_>>();

    // Now, starting from the first (outermost) index we move the desired fork bound
    // into place
    for (idx, original_idx) in shape.iter().flat_map(|idx| idx.iter()).enumerate() {
        let cur_idx = orig_to_cur[*original_idx];
        let swapping = cur_to_orig[idx];
        // If the desired factor is already in the correct place, do nothing
        if cur_idx == idx {
            continue;
        }
        // Positions before idx are already final, so the dimension we want
        // can only be at a later position
        assert!(idx < cur_idx);
        let Some(fork_res) = fork_interchange(&mut editor, cur_fork, join, idx, cur_idx) else {
            return Err(SchedulerError::PassError {
                pass: "fork-reshape".to_string(),
                error: "failed to interchange forks".to_string(),
            });
        };
        cur_fork = fork_res;

        // Update our maps
        orig_to_cur[*original_idx] = idx;
        orig_to_cur[swapping] = cur_idx;
        cur_to_orig[idx] = *original_idx;
        cur_to_orig[cur_idx] = swapping;
    }

    // Finally we split the fork into the desired pieces. We do this by first splitting
    // the fork into individual forks and then coalesce the chunks together
    // Not sure how split_fork could fail, so if it does panic is fine
    let (forks, joins) = split_fork(&mut editor, cur_fork, join, &reduce_cycles).unwrap();

    for (fork, join) in forks.iter().zip(joins.iter()) {
        fork_join_map.insert(*fork, *join);
    }

    // Finally coalesce the chunks together
    // Every chunk is non-empty (checked during argument validation), so
    // forks[fork_idx] always exists and belongs to the current chunk
    let mut fork_idx = 0;
    let mut final_forks = vec![];
    for chunk in shape.iter() {
        let chunk_len = chunk.len();

        let mut cur_fork = forks[fork_idx];
        for i in 1..chunk_len {
            let next_fork = forks[fork_idx + i];
            // Again, not sure at this point how coalesce could fail, so panic if it
            // does
            let (new_fork, new_join) = fork_coalesce_helper(&mut editor, cur_fork, next_fork, &fork_join_map).unwrap();
            cur_fork = new_fork;
            fork_join_map.insert(new_fork, new_join);
        }

        fork_idx += chunk_len;
        final_forks.push(cur_fork);
    }

    // Label each fork and return the labels
    // We've trashed our analyses at this point, so rerun them so that we can determine the
    // nodes in each of the result fork-joins
    pm.clear_analyses();
    pm.make_def_uses();
    pm.make_nodes_in_fork_joins();

    let def_uses = pm.def_uses.take().unwrap();
    let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap();

    let def_use = &def_uses[func];
    let nodes_in_fork_joins = &nodes_in_fork_joins[func];

    let mut editor = FunctionEditor::new(
        &mut pm.functions[func],
        func_id,
        &pm.constants,
        &pm.dynamic_constants,
        &pm.types,
        &pm.labels,
        def_use,
    );

    // One label per requested fork, in the same order as the argument lists
    let labels = create_labels_for_node_sets(&mut editor, final_forks.into_iter().map(|fork| nodes_in_fork_joins[&fork].iter().copied()))
        .into_iter()
        .map(|(_, label)| Value::Label { labels: vec![LabelInfo { func: func_id, label }] })
        .collect();

    result = Value::Tuple { values: labels };
    changed = true;

    pm.delete_gravestones();
    pm.clear_analyses();
}
Pass::WritePredication => {
assert!(args.is_empty());
for func in build_selection(pm, selection, false) {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment