Skip to content
Snippets Groups Projects
Commit d95a7e49 authored by rarbore2's avatar rarbore2
Browse files

BFS optimization

parent 2ae082d1
No related branches found
No related tags found
1 merge request!225BFS optimization
...@@ -1515,6 +1515,10 @@ fn fork_fusion( ...@@ -1515,6 +1515,10 @@ fn fork_fusion(
* Looks for would-be monoid reduces, if not for a gate on the reduction. * Looks for would-be monoid reduces, if not for a gate on the reduction.
* Partially predicate the gated reduction to allow for a proper monoid * Partially predicate the gated reduction to allow for a proper monoid
* reduction. * reduction.
*
* Also looks for monoid reduces that occur through a gated write to a single
* location, and lifts them to a proper monoid reduction followed by a single
* gated write.
*/ */
pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) { pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
for id in editor.node_ids() { for id in editor.node_ids() {
...@@ -1676,6 +1680,121 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) { ...@@ -1676,6 +1680,121 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
}); });
} }
} }
// Second rewrite of this pass. For every reduce node whose `reduct` is a
// two-input phi that picks between the reduce itself and a write into the
// reduce, replace the per-iteration gated write with a boolean AND reduction
// over the branch condition plus one write that is gated on the reduction's
// final value. Any node that does not match the expected shape is skipped.
for id in editor.node_ids() {
// Identify the reduce/write/phi cycle through which a sparse AND reduction
// is occurring.
let nodes = &editor.func().nodes;
let Some((join, init, reduct)) = nodes[id.idx()].try_reduce() else {
continue;
};
// The reduce's control input is expected to be a join; `unwrap` panics if
// it is not. `join_pred` is later required to be a region, so it is
// presumably the join's control predecessor — TODO confirm.
let join_pred = nodes[join.idx()].try_join().unwrap();
// First control node among the join's users; panics if there is none.
let join_succ = editor
.get_users(join)
.filter(|id| nodes[id.idx()].is_control())
.next()
.unwrap();
// The reduct must be a phi with exactly two inputs.
// NOTE(review): the phi's own control input is discarded here and is never
// compared against `join_pred` — confirm that is intended.
let Some((_, phi_data)) = nodes[reduct.idx()].try_phi() else {
continue;
};
if phi_data.len() != 2 {
continue;
}
// One phi input must be the reduce itself; keep the other one.
let phi_other_use = if phi_data[0] == id {
phi_data[1]
} else if phi_data[1] == id {
phi_data[0]
} else {
continue;
};
// The other phi input must be a write whose collection is the reduce.
let Some((collect, data, indices)) = nodes[phi_other_use.idx()].try_write() else {
continue;
};
if collect != id {
continue;
}
// Skip writes that use any position index.
if indices.into_iter().any(|idx| idx.is_position()) {
continue;
}
// Only handle writes whose data passes `is_false` (presumably the boolean
// constant `false`, matching the `false` constant written below — confirm
// against the helper's definition).
if !is_false(editor, data) {
continue;
}
// `join_pred` must be a region with exactly two predecessors, both of
// which are control projections of the same if node.
let Some(preds) = nodes[join_pred.idx()].try_region() else {
continue;
};
if preds.len() != 2 {
continue;
}
let Some((if1, _)) = nodes[preds[0].idx()].try_control_proj() else {
continue;
};
// `sel` is the selection value of the second predecessor's projection.
let Some((if2, sel)) = nodes[preds[1].idx()].try_control_proj() else {
continue;
};
if if1 != if2 {
continue;
}
// `cond` is the if's condition; it is `mut` because it may be replaced by
// its negation inside the edit below.
let Some((_, mut cond)) = nodes[if1.idx()].try_if() else {
continue;
};
// Transform to a monoid reduction and a single gated write.
// NOTE(review): the phi inputs are indexed by the second projection's
// selection value. This appears to rely on the two projections having
// selections 0 and 1, so that `phi_data[sel]` is the input coming from the
// selection-1 branch — confirm.
let negated = phi_other_use == phi_data[sel];
// Take an owned boxed copy of the indices; it is moved into the new write.
let indices = indices.to_vec().into_boxed_slice();
// The value returned by `editor.edit` is not inspected.
editor.edit(|mut edit| {
// Boolean constants: `t` is the reduction's initial value, `f` is the
// data stored by the rebuilt write.
let t = edit.add_constant(Constant::Boolean(true));
let t = edit.add_node(Node::Constant { id: t });
let f = edit.add_constant(Constant::Boolean(false));
let f = edit.add_node(Node::Constant { id: f });
// When the original write sat on the `phi_data[sel]` side, reduce over the
// negated condition instead.
if negated {
cond = edit.add_node(Node::Unary {
op: UnaryOperator::Not,
input: cond,
});
}
// The reduce and the AND refer to each other, so their IDs are computed
// before either node is added. This relies on `add_node` assigning the
// next two IDs in order starting at `num_node_ids()`; nothing may be added
// between here and the two `add_node` calls below.
let reduce_id = NodeID::new(edit.num_node_ids());
let and_id = NodeID::new(edit.num_node_ids() + 1);
// New reduce on the same join: starts at `true`, reduct is the AND.
edit.add_node(Node::Reduce {
control: join,
init: t,
reduct: and_id,
});
// AND of the (possibly negated) condition with the running reduction.
edit.add_node(Node::Binary {
op: BinaryOperator::And,
left: cond,
right: reduce_id,
});
// Branch after the join on the reduction's value, then merge both
// projections in a new region.
let new_if = edit.add_node(Node::If {
control: join,
cond: reduce_id,
});
let cpj1 = edit.add_node(Node::ControlProjection {
control: new_if,
selection: 0,
});
let cpj2 = edit.add_node(Node::ControlProjection {
control: new_if,
selection: 1,
});
let region = edit.add_node(Node::Region {
preds: Box::new([cpj1, cpj2]),
});
// Single write of `false` into the original reduce's initial collection at
// the same indices as the original write.
let write = edit.add_node(Node::Write {
collect: init,
data: f,
indices,
});
// Phi inputs are listed in the same order as the region's predecessors:
// `write` alongside the selection-0 projection, `init` alongside the
// selection-1 projection.
let phi = edit.add_node(Node::Phi {
control: region,
data: Box::new([write, init]),
});
// Redirect uses of the old reduce to the new phi, leaving the old write and
// old phi untouched. `?` returns early from the closure if this fails.
edit = edit.replace_all_uses_where(id, phi, |other_id| {
*other_id != phi_other_use && *other_id != reduct
})?;
// Only the join's control successor is redirected to the new region.
edit.replace_all_uses_where(join, region, |id| *id == join_succ)
});
}
} }
/* /*
......
...@@ -54,9 +54,13 @@ if !feature("seq") { ...@@ -54,9 +54,13 @@ if !feature("seq") {
inline(bfs@cost_init, bfs@loop1, bfs@loop2); inline(bfs@cost_init, bfs@loop1, bfs@loop2);
init = init_body; init = init_body;
} }
fork-tile[8, 0, false, true](init, traverse, collect);
delete-uncalled(*); delete-uncalled(*);
const-inline(*); const-inline(*);
clean-monoid-reduces(collect);
simpl!(*);
fork-tile[8, 0, false, true](init, traverse, collect);
clean-monoid-reduces(collect);
simpl!(*); simpl!(*);
fork-split(init, traverse, collect); fork-split(init, traverse, collect);
......
...@@ -39,6 +39,7 @@ fixpoint { ...@@ -39,6 +39,7 @@ fixpoint {
simpl!(collect); simpl!(collect);
fork-tile[1024, 0, false, true](init, traverse, collect); fork-tile[1024, 0, false, true](init, traverse, collect);
fork-split(init, traverse, collect); let out = fork-split(init, traverse, collect);
simpl!(*);
gcm(*); gcm(*);
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment