Skip to content
Snippets Groups Projects
Commit f092ae38 authored by Russel Arbore's avatar Russel Arbore
Browse files

more backprop opt

parent 54438766
No related branches found
No related tags found
2 merge requests!215Large benches,!214More optimizations
Pipeline #202010 passed
...@@ -319,12 +319,12 @@ pub fn fork_fission<'a>( ...@@ -319,12 +319,12 @@ pub fn fork_fission<'a>(
.collect(); .collect();
let mut created_forks = Vec::new(); let mut created_forks = Vec::new();
// This does the reduction fission // This does the reduction fission
for fork in forks { for fork in forks {
let join = fork_join_map[&fork.0]; let join = fork_join_map[&fork.0];
// FIXME: Don't make multiple forks for reduces that are in cycles with each other. // FIXME: Don't make multiple forks for reduces that are in cycles with each other.
let reduce_partition = default_reduce_partition(editor, fork.0, join); let reduce_partition = default_reduce_partition(editor, fork.0, join);
if !editor.func().labels[fork.0.idx()].contains(&fork_label) { if !editor.func().labels[fork.0.idx()].contains(&fork_label) {
...@@ -332,14 +332,19 @@ pub fn fork_fission<'a>( ...@@ -332,14 +332,19 @@ pub fn fork_fission<'a>(
} }
if editor.is_mutable(fork.0) { if editor.is_mutable(fork.0) {
created_forks = fork_reduce_fission_helper(editor, fork_join_map, reduce_partition, nodes_in_fork_joins, fork.0); created_forks = fork_reduce_fission_helper(
editor,
fork_join_map,
reduce_partition,
nodes_in_fork_joins,
fork.0,
);
if created_forks.is_empty() { if created_forks.is_empty() {
continue; continue;
} else { } else {
return created_forks; return created_forks;
} }
} }
} }
created_forks created_forks
...@@ -503,13 +508,17 @@ pub fn fork_reduce_fission_helper<'a>( ...@@ -503,13 +508,17 @@ pub fn fork_reduce_fission_helper<'a>(
let mut new_forks = Vec::new(); let mut new_forks = Vec::new();
let mut new_control_pred: NodeID = editor.get_uses(fork).filter(|n| editor.node(n).is_control()).next().unwrap(); let mut new_control_pred: NodeID = editor
.get_uses(fork)
.filter(|n| editor.node(n).is_control())
.next()
.unwrap();
let mut new_fork = NodeID::new(0); let mut new_fork = NodeID::new(0);
let mut new_join = NodeID::new(0); let mut new_join = NodeID::new(0);
let subgraph = &nodes_in_fork_joins[&fork]; let subgraph = &nodes_in_fork_joins[&fork];
// Gets everything between fork & join that this reduce needs. (ALL CONTROL) // Gets everything between fork & join that this reduce needs. (ALL CONTROL)
editor.edit(|mut edit| { editor.edit(|mut edit| {
for reduce in reduce_partition { for reduce in reduce_partition {
...@@ -522,7 +531,7 @@ pub fn fork_reduce_fission_helper<'a>( ...@@ -522,7 +531,7 @@ pub fn fork_reduce_fission_helper<'a>(
new_fork = mapping[&fork]; new_fork = mapping[&fork];
new_forks.push(new_fork); new_forks.push(new_fork);
new_join = mapping[&join]; new_join = mapping[&join];
// Atttach new_fork after control_pred // Atttach new_fork after control_pred
let (old_control_pred, _) = edit.get_node(new_fork).try_fork().unwrap().clone(); let (old_control_pred, _) = edit.get_node(new_fork).try_fork().unwrap().clone();
edit = edit.replace_all_uses_where(old_control_pred, new_control_pred, |usee| { edit = edit.replace_all_uses_where(old_control_pred, new_control_pred, |usee| {
...@@ -532,7 +541,7 @@ pub fn fork_reduce_fission_helper<'a>( ...@@ -532,7 +541,7 @@ pub fn fork_reduce_fission_helper<'a>(
// Replace uses of reduce // Replace uses of reduce
edit = edit.replace_all_uses(reduce, mapping[&reduce])?; edit = edit.replace_all_uses(reduce, mapping[&reduce])?;
new_control_pred = new_join; new_control_pred = new_join;
}; }
// Replace original join w/ new final join // Replace original join w/ new final join
edit = edit.replace_all_uses_where(join, new_join, |_| true)?; edit = edit.replace_all_uses_where(join, new_join, |_| true)?;
...@@ -1502,6 +1511,10 @@ fn fork_fusion( ...@@ -1502,6 +1511,10 @@ fn fork_fusion(
* element. This aides in parallelizing outer loops. Looks only at reduces with * element. This aides in parallelizing outer loops. Looks only at reduces with
* the monoid reduce schedule, since that indicates a particular structure which * the monoid reduce schedule, since that indicates a particular structure which
* is annoying to check for again. * is annoying to check for again.
*
* Looks for would-be monoid reduces, if not for a gate on the reduction.
* Partially predicate the gated reduction to allow for a proper monoid
* reduction.
*/ */
pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) { pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
for id in editor.node_ids() { for id in editor.node_ids() {
...@@ -1512,7 +1525,7 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) { ...@@ -1512,7 +1525,7 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
let Some((_, init, reduct)) = nodes[id.idx()].try_reduce() else { let Some((_, init, reduct)) = nodes[id.idx()].try_reduce() else {
continue; continue;
}; };
let out_uses: Vec<_> = editor.get_users(id).filter(|id| *id != reduct).collect(); let out_users: Vec<_> = editor.get_users(id).filter(|id| *id != reduct).collect();
match nodes[reduct.idx()] { match nodes[reduct.idx()] {
Node::Binary { Node::Binary {
...@@ -1532,7 +1545,7 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) { ...@@ -1532,7 +1545,7 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
left: init, left: init,
right: id, right: id,
}); });
for u in out_uses { for u in out_users {
edit.sub_edit(u, final_op); edit.sub_edit(u, final_op);
} }
edit.replace_all_uses_where(id, final_op, |u| *u != reduct && *u != final_op) edit.replace_all_uses_where(id, final_op, |u| *u != reduct && *u != final_op)
...@@ -1555,7 +1568,7 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) { ...@@ -1555,7 +1568,7 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
left: init, left: init,
right: id, right: id,
}); });
for u in out_uses { for u in out_users {
edit.sub_edit(u, final_op); edit.sub_edit(u, final_op);
} }
edit.replace_all_uses_where(id, final_op, |u| *u != reduct && *u != final_op) edit.replace_all_uses_where(id, final_op, |u| *u != reduct && *u != final_op)
...@@ -1574,7 +1587,7 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) { ...@@ -1574,7 +1587,7 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
intrinsic: Intrinsic::Max, intrinsic: Intrinsic::Max,
args: Box::new([init, id]), args: Box::new([init, id]),
}); });
for u in out_uses { for u in out_users {
edit.sub_edit(u, final_op); edit.sub_edit(u, final_op);
} }
edit.replace_all_uses_where(id, final_op, |u| *u != reduct && *u != final_op) edit.replace_all_uses_where(id, final_op, |u| *u != reduct && *u != final_op)
...@@ -1593,7 +1606,7 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) { ...@@ -1593,7 +1606,7 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
intrinsic: Intrinsic::Min, intrinsic: Intrinsic::Min,
args: Box::new([init, id]), args: Box::new([init, id]),
}); });
for u in out_uses { for u in out_users {
edit.sub_edit(u, final_op); edit.sub_edit(u, final_op);
} }
edit.replace_all_uses_where(id, final_op, |u| *u != reduct && *u != final_op) edit.replace_all_uses_where(id, final_op, |u| *u != reduct && *u != final_op)
...@@ -1602,6 +1615,65 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) { ...@@ -1602,6 +1615,65 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
_ => {} _ => {}
} }
} }
for id in editor.node_ids() {
if !editor.func().schedules[id.idx()].contains(&Schedule::MonoidReduce) {
continue;
}
let nodes = &editor.func().nodes;
let Some((control, init, reduct)) = nodes[id.idx()].try_reduce() else {
continue;
};
if let Node::Phi {
control: phi_control,
ref data,
} = nodes[reduct.idx()]
&& data.len() == 2
&& data.contains(&id)
&& let other = *data
.into_iter()
.filter(|other| **other != id)
.next()
.unwrap()
&& let Node::Binary {
op: BinaryOperator::Add,
left,
right,
} = nodes[other.idx()]
&& ((left == id) ^ (right == id))
{
let gated_input = if left == id { right } else { left };
let data = data.clone();
editor.edit(|mut edit| {
let zero = edit.add_zero_constant(typing[id.idx()]);
let zero = edit.add_node(Node::Constant { id: zero });
let phi = edit.add_node(Node::Phi {
control: phi_control,
data: data
.iter()
.map(|phi_use| if *phi_use == id { zero } else { gated_input })
.collect(),
});
let new_reduce_id = NodeID::new(edit.num_node_ids());
let new_reduct_id = NodeID::new(edit.num_node_ids() + 1);
let new_reduce = Node::Reduce {
control,
init,
reduct: new_reduct_id,
};
let new_add = Node::Binary {
op: BinaryOperator::Add,
left: new_reduce_id,
right: phi,
};
let new_reduce = edit.add_node(new_reduce);
edit.add_node(new_add);
edit = edit.replace_all_uses(id, new_reduce)?;
edit = edit.delete_node(id)?;
Ok(edit)
});
}
}
} }
/* /*
......
...@@ -126,11 +126,24 @@ fn remove_useless_fork_joins( ...@@ -126,11 +126,24 @@ fn remove_useless_fork_joins(
// Third, get rid of fork-joins. // Third, get rid of fork-joins.
for (fork, join) in fork_join_map { for (fork, join) in fork_join_map {
if editor.get_users(*fork).len() == 1 && editor.get_users(*join).len() == 1 { if editor.get_users(*join).len() == 1 {
let fork_use = get_uses(&editor.func().nodes[fork.idx()]).as_ref()[0]; let fork_use = get_uses(&editor.func().nodes[fork.idx()]).as_ref()[0];
let join_use = get_uses(&editor.func().nodes[join.idx()]).as_ref()[0]; let join_use = get_uses(&editor.func().nodes[join.idx()]).as_ref()[0];
let tids: Vec<_> = editor
.get_users(*fork)
.filter(|id| editor.func().nodes[id.idx()].is_thread_id())
.collect();
editor.edit(|mut edit| { editor.edit(|mut edit| {
if !tids.is_empty() {
let u64_ty = edit.add_type(Type::UnsignedInteger64);
let zero = edit.add_zero_constant(u64_ty);
let zero = edit.add_node(Node::Constant { id: zero });
for tid in tids {
edit = edit.replace_all_uses(tid, zero)?;
edit = edit.delete_node(tid)?;
}
}
edit = edit.replace_all_uses(*join, join_use)?; edit = edit.replace_all_uses(*join, join_use)?;
edit = edit.replace_all_uses(*fork, fork_use)?; edit = edit.replace_all_uses(*fork, fork_use)?;
edit = edit.delete_node(*fork)?; edit = edit.delete_node(*fork)?;
......
...@@ -33,7 +33,11 @@ fixpoint { ...@@ -33,7 +33,11 @@ fixpoint {
reduce-slf(*); reduce-slf(*);
simpl!(*); simpl!(*);
fork-tile[16, 0, false, true](layer_forward@inner_loop); fork-extend[32](layer_forward@inner_loop);
clean-monoid-reduces(layer_forward);
simpl!(layer_forward);
fork-tile[32, 0, false, true](layer_forward@inner_loop);
clean-monoid-reduces(layer_forward);
let out = fork-split(layer_forward@inner_loop); let out = fork-split(layer_forward@inner_loop);
clean-monoid-reduces(layer_forward); clean-monoid-reduces(layer_forward);
simpl!(layer_forward); simpl!(layer_forward);
...@@ -47,5 +51,4 @@ fork-tile[32, 0, false, true](adjust_weights); ...@@ -47,5 +51,4 @@ fork-tile[32, 0, false, true](adjust_weights);
fork-split(adjust_weights); fork-split(adjust_weights);
simpl!(adjust_weights); simpl!(adjust_weights);
xdot[true](*);
gcm(*); gcm(*);
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment