Skip to content
Snippets Groups Projects
Commit fdb26f84 authored by rarbore2's avatar rarbore2
Browse files

Merge branch 'alternate-fork-tile' into 'main'

add alternate order to fork tilling

See merge request !169
parents 54f9d981 cb5a7b4e
No related branches found
No related tags found
1 merge request!169add alternate order to fork tilling
Pipeline #201606 passed
......@@ -856,6 +856,7 @@ pub fn chunk_all_forks_unguarded(
fork_join_map: &HashMap<NodeID, NodeID>,
dim_idx: usize,
tile_size: usize,
order: bool,
) -> () {
// Add dc
let mut dc_id = DynamicConstantID::new(0);
......@@ -864,19 +865,31 @@ pub fn chunk_all_forks_unguarded(
Ok(edit)
});
let order = match order {
true => &TileOrder::TileInner,
false => &TileOrder::TileOuter,
};
for (fork, _) in fork_join_map {
chunk_fork_unguarded(editor, *fork, dim_idx, dc_id);
chunk_fork_unguarded(editor, *fork, dim_idx, dc_id, order);
}
}
// Splits a dimension of a single fork join into multiple.
// Iterates an outer loop original_dim / tile_size times
// adds a tile_size loop as the inner loop
// Assumes that tile size divides original dim evenly.
enum TileOrder {
TileInner,
TileOuter,
}
pub fn chunk_fork_unguarded(
editor: &mut FunctionEditor,
fork: NodeID,
dim_idx: usize,
tile_size: DynamicConstantID,
order: &TileOrder,
) -> () {
// tid_dim_idx = tid_dim_idx * tile_size + tid_(dim_idx + 1)
let Node::Fork {
......@@ -893,63 +906,128 @@ pub fn chunk_fork_unguarded(
.map(|f| (f, editor.node(f).clone()))
.collect();
editor.edit(|mut edit| {
let outer = DynamicConstant::div(new_factors[dim_idx], tile_size);
new_factors.insert(dim_idx + 1, tile_size);
new_factors[dim_idx] = edit.add_dynamic_constant(outer);
let new_fork = Node::Fork {
control: old_control,
factors: new_factors.into(),
};
let new_fork = edit.add_node(new_fork);
edit = edit.replace_all_uses(fork, new_fork)?;
edit.sub_edit(fork, new_fork);
for (tid, node) in fork_users {
let Node::ThreadID {
control: _,
dimension: tid_dim,
} = node
else {
continue;
};
if tid_dim > dim_idx {
let new_tid = Node::ThreadID {
control: new_fork,
dimension: tid_dim + 1,
match order {
TileOrder::TileInner => {
editor.edit(|mut edit| {
let outer = DynamicConstant::div(new_factors[dim_idx], tile_size);
new_factors.insert(dim_idx + 1, tile_size);
new_factors[dim_idx] = edit.add_dynamic_constant(outer);
let new_fork = Node::Fork {
control: old_control,
factors: new_factors.into(),
};
let new_tid = edit.add_node(new_tid);
edit = edit.replace_all_uses(tid, new_tid)?;
edit.sub_edit(tid, new_tid);
edit = edit.delete_node(tid)?;
} else if tid_dim == dim_idx {
let tile_tid = Node::ThreadID {
control: new_fork,
dimension: tid_dim + 1,
let new_fork = edit.add_node(new_fork);
edit = edit.replace_all_uses(fork, new_fork)?;
edit.sub_edit(fork, new_fork);
for (tid, node) in fork_users {
let Node::ThreadID {
control: _,
dimension: tid_dim,
} = node
else {
continue;
};
if tid_dim > dim_idx {
let new_tid = Node::ThreadID {
control: new_fork,
dimension: tid_dim + 1,
};
let new_tid = edit.add_node(new_tid);
edit = edit.replace_all_uses(tid, new_tid)?;
edit.sub_edit(tid, new_tid);
edit = edit.delete_node(tid)?;
} else if tid_dim == dim_idx {
let tile_tid = Node::ThreadID {
control: new_fork,
dimension: tid_dim + 1,
};
let tile_tid = edit.add_node(tile_tid);
let tile_size = edit.add_node(Node::DynamicConstant { id: tile_size });
let mul = edit.add_node(Node::Binary {
left: tid,
right: tile_size,
op: BinaryOperator::Mul,
});
let add = edit.add_node(Node::Binary {
left: mul,
right: tile_tid,
op: BinaryOperator::Add,
});
edit.sub_edit(tid, add);
edit.sub_edit(tid, tile_tid);
edit = edit.replace_all_uses_where(tid, add, |usee| *usee != mul)?;
}
}
edit = edit.delete_node(fork)?;
Ok(edit)
});
},
TileOrder::TileOuter => {
editor.edit(|mut edit| {
let inner = DynamicConstant::div(new_factors[dim_idx], tile_size);
new_factors.insert(dim_idx, tile_size);
let inner_dc_id = edit.add_dynamic_constant(inner);
let new_fork = Node::Fork {
control: old_control,
factors: new_factors.into(),
};
let tile_tid = edit.add_node(tile_tid);
let tile_size = edit.add_node(Node::DynamicConstant { id: tile_size });
let mul = edit.add_node(Node::Binary {
left: tid,
right: tile_size,
op: BinaryOperator::Mul,
});
let add = edit.add_node(Node::Binary {
left: mul,
right: tile_tid,
op: BinaryOperator::Add,
});
edit.sub_edit(tid, add);
edit.sub_edit(tid, tile_tid);
edit = edit.replace_all_uses_where(tid, add, |usee| *usee != mul)?;
}
let new_fork = edit.add_node(new_fork);
edit = edit.replace_all_uses(fork, new_fork)?;
edit.sub_edit(fork, new_fork);
for (tid, node) in fork_users {
let Node::ThreadID {
control: _,
dimension: tid_dim,
} = node
else {
continue;
};
if tid_dim > dim_idx {
let new_tid = Node::ThreadID {
control: new_fork,
dimension: tid_dim + 1,
};
let new_tid = edit.add_node(new_tid);
edit = edit.replace_all_uses(tid, new_tid)?;
edit.sub_edit(tid, new_tid);
edit = edit.delete_node(tid)?;
} else if tid_dim == dim_idx {
let tile_tid = Node::ThreadID {
control: new_fork,
dimension: tid_dim,
};
let tile_tid = edit.add_node(tile_tid);
let inner_dc = edit.add_node(Node::DynamicConstant { id: inner_dc_id } );
let mul = edit.add_node(Node::Binary {
left: tid,
right: inner_dc,
op: BinaryOperator::Mul,
});
let add = edit.add_node(Node::Binary {
left: mul,
right: tile_tid,
op: BinaryOperator::Add,
});
edit.sub_edit(tid, add);
edit.sub_edit(tid, tile_tid);
edit = edit.replace_all_uses_where(tid, add, |usee| *usee != mul)?;
}
}
edit = edit.delete_node(fork)?;
Ok(edit)
});
}
edit = edit.delete_node(fork)?;
Ok(edit)
});
}
}
pub fn merge_all_fork_dims(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>) {
......
......@@ -54,7 +54,7 @@ gvn(*);
phi-elim(*);
dce(*);
fork-tile[32, 0, false](test6@loop);
fork-tile[32, 0, false, true](test6@loop);
let out = fork-split(test6@loop);
let out = outline(out.test6.fj1);
cpu(out);
......
......@@ -52,7 +52,7 @@ slf(auto.test2);
infer-schedules(auto.test2);
fork-interchange[0, 1](auto.test2);
fork-tile[32, 0, false](test6@loop);
fork-tile[32, 0, false, true](test6@loop);
let out = fork-split(test6@loop);
let out = auto-outline(test6);
gpu(out.test6);
......
......@@ -43,7 +43,7 @@ impl Pass {
pub fn num_args(&self) -> usize {
match self {
Pass::Xdot => 1,
Pass::ForkChunk => 3,
Pass::ForkChunk => 4,
Pass::ForkFissionBufferize => 2,
Pass::ForkInterchange => 2,
_ => 0,
......
......@@ -2118,7 +2118,7 @@ fn run_pass(
pm.clear_analyses();
}
Pass::ForkChunk => {
assert_eq!(args.len(), 3);
assert_eq!(args.len(), 4);
let Some(Value::Integer { val: tile_size }) = args.get(0) else {
return Err(SchedulerError::PassError {
pass: "forkChunk".to_string(),
......@@ -2137,6 +2137,12 @@ fn run_pass(
error: "expected boolean argument".to_string(),
});
};
let Some(Value::Boolean { val: tile_order }) = args.get(3) else {
return Err(SchedulerError::PassError {
pass: "forkChunk".to_string(),
error: "expected boolean argument".to_string(),
});
};
assert!(!*guarded_flag);
pm.make_fork_join_maps();
......@@ -2148,7 +2154,7 @@ fn run_pass(
let Some(mut func) = func else {
continue;
};
chunk_all_forks_unguarded(&mut func, fork_join_map, *dim_idx, *tile_size);
chunk_all_forks_unguarded(&mut func, fork_join_map, *dim_idx, *tile_size, *tile_order);
changed |= func.modified();
}
pm.delete_gravestones();
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment