Skip to content
Snippets Groups Projects

Fork fission bufferize

Merged Xavier Routh requested to merge fork-fission-bufferize into main
All threads resolved!
4 files
+ 90
42
Compare changes
  • Side-by-side
  • Inline
Files
4
@@ -695,6 +695,24 @@ pub(crate) fn split_fork(
}
}
pub fn chunk_all_forks_unguarded(
editor: &mut FunctionEditor,
fork_join_map: &HashMap<NodeID, NodeID>,
dim_idx: usize,
tile_size: usize,
) -> () {
// Add dc
let mut dc_id = DynamicConstantID::new(0);
editor.edit(|mut edit| {
dc_id = edit.add_dynamic_constant(DynamicConstant::Constant(tile_size));
Ok(edit)
});
for (fork, _ ) in fork_join_map {
chunk_fork_unguarded(editor, *fork, dim_idx, dc_id);
}
}
// Splits a dimension of a single fork join into multiple.
// Iterates an outer loop original_dim / tile_size times
// adds a tile_size loop as the inner loop
@@ -711,39 +729,36 @@ pub fn chunk_fork_unguarded(
let mut new_factors: Vec<_> = old_factors.to_vec();
let fork_users: Vec<_> = editor.get_users(fork).collect();
let fork_users: Vec<_> = editor.get_users(fork).map(|f| (f, editor.node(f).clone())).collect();
editor.edit(|mut edit| {
let outer = DynamicConstant::div(new_factors[dim_idx], tile_size);
new_factors.insert(dim_idx + 1, tile_size);
new_factors[dim_idx] = edit.add_dynamic_constant(outer);
let new_fork = Node::Fork { control: old_control, factors: new_factors.into() };
let new_fork = edit.add_node(new_fork);
edit = edit.replace_all_uses(fork, new_fork)?;
for tid in fork_users {
let Node::ThreadID { control: _, dimension: tid_dim } = editor.node(tid).clone() else { continue };
editor.edit(|mut edit| {
for (tid, node) in fork_users {
let Node::ThreadID { control: _, dimension: tid_dim } = node else {continue};
if tid_dim > dim_idx {
let new_tid = Node::ThreadID { control: fork, dimension: tid_dim + 1 };
let new_tid = Node::ThreadID { control: new_fork, dimension: tid_dim + 1 };
let new_tid = edit.add_node(new_tid);
edit.replace_all_uses(tid, new_tid)
edit = edit.replace_all_uses(tid, new_tid)?;
} else if tid_dim == dim_idx {
let tile_tid = Node::ThreadID { control: fork, dimension: tid_dim + 1 };
let tile_tid = Node::ThreadID { control: new_fork, dimension: tid_dim + 1 };
let tile_tid = edit.add_node(tile_tid);
let tile_size = edit.add_node(Node::DynamicConstant { id: tile_size });
let mul = edit.add_node(Node::Binary { left: tid, right: tile_size, op: BinaryOperator::Mul });
let add = edit.add_node(Node::Binary { left: mul, right: tile_tid, op: BinaryOperator::Add });
edit.replace_all_uses_where(tid, add, |usee| *usee != mul )
} else {
Ok(edit)
edit = edit.replace_all_uses_where(tid, add, |usee| *usee != mul )?;
}
});
}
editor.edit(|mut edit| {
let outer = DynamicConstant::div(new_factors[dim_idx], tile_size);
new_factors.insert(dim_idx + 1, tile_size);
new_factors[dim_idx] = edit.add_dynamic_constant(outer);
let new_fork = Node::Fork { control: old_control, factors: new_factors.into() };
let new_fork = edit.add_node(new_fork);
edit.replace_all_uses(fork, new_fork)
}
edit = edit.delete_node(fork)?;
Ok(edit)
});
}
@@ -791,9 +806,8 @@ pub fn fork_dim_merge(
let mut new_factors: Vec<_> = old_factors.to_vec();
let fork_users: Vec<_> = editor.get_users(fork).collect();
let fork_users: Vec<_> = editor.get_users(fork).map(|f| (f, editor.node(f).clone())).collect();
let mut new_nodes = vec![];
@@ -801,6 +815,7 @@ pub fn fork_dim_merge(
let inner_dc_id = new_factors[inner_idx];
let mut new_fork_id = NodeID::new(0);
editor.edit(|mut edit| {
new_factors[outer_idx] = edit.add_dynamic_constant(DynamicConstant::mul(new_factors[outer_idx], new_factors[inner_idx]));
new_factors.remove(inner_idx);
@@ -809,22 +824,20 @@ pub fn fork_dim_merge(
let new_fork = edit.add_node(new_fork);
new_fork_id = new_fork;
edit.sub_edit(fork, new_fork);
edit = edit.replace_all_uses(fork, new_fork)?;
edit.delete_node(fork)
});
for tid in fork_users {
let Node::ThreadID { control: _, dimension: tid_dim } = editor.node(tid).clone() else { continue };
edit = edit.delete_node(fork)?;
println!("tid: {:?}", tid);
editor.edit(|mut edit| {
for (tid, node) in fork_users {
// FIXME: DO we want sub edits in this?
let Node::ThreadID { control: _, dimension: tid_dim } = node else { continue };
if tid_dim > inner_idx {
let new_tid = Node::ThreadID { control: new_fork_id, dimension: tid_dim - 1 };
let new_tid = edit.add_node(new_tid);
edit.replace_all_uses(tid, new_tid)
edit = edit.replace_all_uses(tid, new_tid)?;
edit.sub_edit(tid, new_tid);
} else if tid_dim == outer_idx {
let outer_tid = Node::ThreadID { control: new_fork_id, dimension: outer_idx };
let outer_tid = edit.add_node(outer_tid);
@@ -834,8 +847,8 @@ pub fn fork_dim_merge(
// inner_idx % dim(outer_idx)
let rem = edit.add_node(Node::Binary { left: outer_tid, right: outer_dc, op: BinaryOperator::Rem});
edit.replace_all_uses(tid, rem)
edit.sub_edit(tid, rem);
edit = edit.replace_all_uses(tid, rem)?;
} else if tid_dim == inner_idx {
let outer_tid = Node::ThreadID { control: new_fork_id, dimension: outer_idx };
let outer_tid = edit.add_node(outer_tid);
@@ -843,13 +856,12 @@ pub fn fork_dim_merge(
let outer_dc = edit.add_node(Node::DynamicConstant { id: outer_dc_id });
// inner_idx / dim(outer_idx)
let div = edit.add_node(Node::Binary { left: outer_tid, right: outer_dc, op: BinaryOperator::Div});
edit.replace_all_uses(tid, div)
} else {
Ok(edit)
edit.sub_edit(tid, div);
edit = edit.replace_all_uses(tid, div)?;
}
});
};
}
Ok(edit)
});
return new_fork_id;
Loading