Skip to content
Snippets Groups Projects
Commit f092ae38 authored by Russel Arbore's avatar Russel Arbore
Browse files

more backprop opt

parent 54438766
No related branches found
No related tags found
2 merge requests!215Large benches,!214More optimizations
Pipeline #202010 passed
This commit is part of merge request !214. Comments created here will be created in the context of that merge request.
......@@ -319,12 +319,12 @@ pub fn fork_fission<'a>(
.collect();
let mut created_forks = Vec::new();
// This does the reduction fission
// This does the reduction fission
for fork in forks {
let join = fork_join_map[&fork.0];
// FIXME: Don't make multiple forks for reduces that are in cycles with each other.
// FIXME: Don't make multiple forks for reduces that are in cycles with each other.
let reduce_partition = default_reduce_partition(editor, fork.0, join);
if !editor.func().labels[fork.0.idx()].contains(&fork_label) {
......@@ -332,14 +332,19 @@ pub fn fork_fission<'a>(
}
if editor.is_mutable(fork.0) {
created_forks = fork_reduce_fission_helper(editor, fork_join_map, reduce_partition, nodes_in_fork_joins, fork.0);
created_forks = fork_reduce_fission_helper(
editor,
fork_join_map,
reduce_partition,
nodes_in_fork_joins,
fork.0,
);
if created_forks.is_empty() {
continue;
} else {
return created_forks;
}
}
}
created_forks
......@@ -503,13 +508,17 @@ pub fn fork_reduce_fission_helper<'a>(
let mut new_forks = Vec::new();
let mut new_control_pred: NodeID = editor.get_uses(fork).filter(|n| editor.node(n).is_control()).next().unwrap();
let mut new_control_pred: NodeID = editor
.get_uses(fork)
.filter(|n| editor.node(n).is_control())
.next()
.unwrap();
let mut new_fork = NodeID::new(0);
let mut new_join = NodeID::new(0);
let subgraph = &nodes_in_fork_joins[&fork];
let subgraph = &nodes_in_fork_joins[&fork];
// Gets everything between fork & join that this reduce needs. (ALL CONTROL)
editor.edit(|mut edit| {
for reduce in reduce_partition {
......@@ -522,7 +531,7 @@ pub fn fork_reduce_fission_helper<'a>(
new_fork = mapping[&fork];
new_forks.push(new_fork);
new_join = mapping[&join];
// Attach new_fork after control_pred
let (old_control_pred, _) = edit.get_node(new_fork).try_fork().unwrap().clone();
edit = edit.replace_all_uses_where(old_control_pred, new_control_pred, |usee| {
......@@ -532,7 +541,7 @@ pub fn fork_reduce_fission_helper<'a>(
// Replace uses of reduce
edit = edit.replace_all_uses(reduce, mapping[&reduce])?;
new_control_pred = new_join;
};
}
// Replace original join w/ new final join
edit = edit.replace_all_uses_where(join, new_join, |_| true)?;
......@@ -1502,6 +1511,10 @@ fn fork_fusion(
* element. This aides in parallelizing outer loops. Looks only at reduces with
* the monoid reduce schedule, since that indicates a particular structure which
* is annoying to check for again.
*
* Looks for would-be monoid reduces, if not for a gate on the reduction.
* Partially predicate the gated reduction to allow for a proper monoid
* reduction.
*/
pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
for id in editor.node_ids() {
......@@ -1512,7 +1525,7 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
let Some((_, init, reduct)) = nodes[id.idx()].try_reduce() else {
continue;
};
let out_uses: Vec<_> = editor.get_users(id).filter(|id| *id != reduct).collect();
let out_users: Vec<_> = editor.get_users(id).filter(|id| *id != reduct).collect();
match nodes[reduct.idx()] {
Node::Binary {
......@@ -1532,7 +1545,7 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
left: init,
right: id,
});
for u in out_uses {
for u in out_users {
edit.sub_edit(u, final_op);
}
edit.replace_all_uses_where(id, final_op, |u| *u != reduct && *u != final_op)
......@@ -1555,7 +1568,7 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
left: init,
right: id,
});
for u in out_uses {
for u in out_users {
edit.sub_edit(u, final_op);
}
edit.replace_all_uses_where(id, final_op, |u| *u != reduct && *u != final_op)
......@@ -1574,7 +1587,7 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
intrinsic: Intrinsic::Max,
args: Box::new([init, id]),
});
for u in out_uses {
for u in out_users {
edit.sub_edit(u, final_op);
}
edit.replace_all_uses_where(id, final_op, |u| *u != reduct && *u != final_op)
......@@ -1593,7 +1606,7 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
intrinsic: Intrinsic::Min,
args: Box::new([init, id]),
});
for u in out_uses {
for u in out_users {
edit.sub_edit(u, final_op);
}
edit.replace_all_uses_where(id, final_op, |u| *u != reduct && *u != final_op)
......@@ -1602,6 +1615,65 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
_ => {}
}
}
for id in editor.node_ids() {
if !editor.func().schedules[id.idx()].contains(&Schedule::MonoidReduce) {
continue;
}
let nodes = &editor.func().nodes;
let Some((control, init, reduct)) = nodes[id.idx()].try_reduce() else {
continue;
};
if let Node::Phi {
control: phi_control,
ref data,
} = nodes[reduct.idx()]
&& data.len() == 2
&& data.contains(&id)
&& let other = *data
.into_iter()
.filter(|other| **other != id)
.next()
.unwrap()
&& let Node::Binary {
op: BinaryOperator::Add,
left,
right,
} = nodes[other.idx()]
&& ((left == id) ^ (right == id))
{
let gated_input = if left == id { right } else { left };
let data = data.clone();
editor.edit(|mut edit| {
let zero = edit.add_zero_constant(typing[id.idx()]);
let zero = edit.add_node(Node::Constant { id: zero });
let phi = edit.add_node(Node::Phi {
control: phi_control,
data: data
.iter()
.map(|phi_use| if *phi_use == id { zero } else { gated_input })
.collect(),
});
let new_reduce_id = NodeID::new(edit.num_node_ids());
let new_reduct_id = NodeID::new(edit.num_node_ids() + 1);
let new_reduce = Node::Reduce {
control,
init,
reduct: new_reduct_id,
};
let new_add = Node::Binary {
op: BinaryOperator::Add,
left: new_reduce_id,
right: phi,
};
let new_reduce = edit.add_node(new_reduce);
edit.add_node(new_add);
edit = edit.replace_all_uses(id, new_reduce)?;
edit = edit.delete_node(id)?;
Ok(edit)
});
}
}
}
/*
......
......@@ -126,11 +126,24 @@ fn remove_useless_fork_joins(
// Third, get rid of fork-joins.
for (fork, join) in fork_join_map {
if editor.get_users(*fork).len() == 1 && editor.get_users(*join).len() == 1 {
if editor.get_users(*join).len() == 1 {
let fork_use = get_uses(&editor.func().nodes[fork.idx()]).as_ref()[0];
let join_use = get_uses(&editor.func().nodes[join.idx()]).as_ref()[0];
let tids: Vec<_> = editor
.get_users(*fork)
.filter(|id| editor.func().nodes[id.idx()].is_thread_id())
.collect();
editor.edit(|mut edit| {
if !tids.is_empty() {
let u64_ty = edit.add_type(Type::UnsignedInteger64);
let zero = edit.add_zero_constant(u64_ty);
let zero = edit.add_node(Node::Constant { id: zero });
for tid in tids {
edit = edit.replace_all_uses(tid, zero)?;
edit = edit.delete_node(tid)?;
}
}
edit = edit.replace_all_uses(*join, join_use)?;
edit = edit.replace_all_uses(*fork, fork_use)?;
edit = edit.delete_node(*fork)?;
......
......@@ -33,7 +33,11 @@ fixpoint {
reduce-slf(*);
simpl!(*);
fork-tile[16, 0, false, true](layer_forward@inner_loop);
fork-extend[32](layer_forward@inner_loop);
clean-monoid-reduces(layer_forward);
simpl!(layer_forward);
fork-tile[32, 0, false, true](layer_forward@inner_loop);
clean-monoid-reduces(layer_forward);
let out = fork-split(layer_forward@inner_loop);
clean-monoid-reduces(layer_forward);
simpl!(layer_forward);
......@@ -47,5 +51,4 @@ fork-tile[32, 0, false, true](adjust_weights);
fork-split(adjust_weights);
simpl!(adjust_weights);
xdot[true](*);
gcm(*);
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment