// unforkify.rs (authored by Aaron Councilman)
use std::collections::{HashMap, HashSet};
use std::iter::zip;
use bitvec::{order::Lsb0, vec::BitVec};
use hercules_ir::{ir::*, LoopTree};
use crate::*;
type NodeVec = BitVec<u8, Lsb0>;
/*
 * Compute the set of nodes that lie "between" the seeds of a fork-join in both
 * directions. The seeds are every reduce whose control node is marked in
 * `inner_control`, plus every thread ID node that uses `fork`. The result is
 * the intersection of:
 *   - the seeds together with everything returned by walking users from them,
 *   - the seeds together with everything returned by walking uses from them,
 *     with any node in the stop set removed.
 * The stop set handed to both walks contains phis and reduces whose control
 * node is not marked in `inner_control`, and control nodes that are themselves
 * not marked in `inner_control`. An index that is out of range for
 * `inner_control` is treated as not marked.
 */
pub fn calculate_fork_nodes(
    editor: &FunctionEditor,
    inner_control: &NodeVec,
    fork: NodeID,
) -> HashSet<NodeID> {
    // A node counts as internal only when its bit exists and is set; an ID
    // beyond the end of the bit vector is considered external.
    let is_inside = |id: NodeID| match inner_control.get(id.idx()) {
        Some(bit) => *bit,
        None => false,
    };

    // Nodes at which the walks below are told to stop.
    let stop_on: HashSet<NodeID> = editor
        .node_ids()
        .filter(|id| match &editor.func().nodes[id.idx()] {
            // Phi / reduce attached to external control.
            Node::Phi { control, .. } | Node::Reduce { control, .. } if !is_inside(*control) => {
                true
            }
            // Everything else is a stop node only if it is external control.
            other => other.is_control() && !is_inside(*id),
        })
        .collect();

    // Walk roots: reduces on internal control, followed by the fork's thread
    // IDs.
    let seeds: Vec<NodeID> = editor
        .node_ids()
        .filter(|id| {
            matches!(
                &editor.func().nodes[id.idx()],
                Node::Reduce { control, .. } if is_inside(*control)
            )
        })
        .chain(
            editor
                .get_users(fork)
                .filter(|user| editor.node(user).is_thread_id()),
        )
        .collect();

    // The seeds are members of both sets; only the "uses" side is filtered
    // against the stop set.
    let mut via_users: HashSet<NodeID> = seeds.iter().copied().collect();
    let mut via_uses: HashSet<NodeID> = seeds
        .iter()
        .copied()
        .filter(|id| !stop_on.contains(id))
        .collect();
    for &seed in &seeds {
        via_users.extend(walk_all_users_stop_on(seed, editor, stop_on.clone()));
        via_uses.extend(
            walk_all_uses_stop_on(seed, editor, stop_on.clone())
                .filter(|id| !stop_on.contains(id)),
        );
    }

    via_users.intersection(&via_uses).copied().collect()
}
/*
 * Turn fork-joins back into sequential loops. This is meant to run right
 * before codegen for backends that do not lower a fork-join to vector /
 * parallel code: lowering fork-joins into sequential loops in LLVM is not
 * entirely trivial, so the transformation is done within Hercules IR instead.
 *
 * Every loop header reported by the loop tree is visited, in the reverse of
 * the order `bottom_up_loops` returns them. Headers that are not fork nodes are
 * skipped. The success flag returned for each individual fork is ignored.
 */
// FIXME: Only works on fully split fork nests.
pub fn unforkify_all(
    editor: &mut FunctionEditor,
    fork_join_map: &HashMap<NodeID, NodeID>,
    loop_tree: &LoopTree,
) {
    for entry in loop_tree.bottom_up_loops().into_iter().rev() {
        let header = entry.0;
        if editor.node(header).is_fork() {
            // Indexing panics if the fork has no recorded join.
            let join = fork_join_map[&header];
            unforkify(editor, header, join, loop_tree);
        }
    }
}
/*
 * Convert at most one fork-join into a sequential loop. Loop headers are
 * visited in the reverse of the order `bottom_up_loops` returns them; headers
 * that are not fork nodes are skipped, and the function returns as soon as one
 * call to `unforkify` reports success.
 */
pub fn unforkify_one(
    editor: &mut FunctionEditor,
    fork_join_map: &HashMap<NodeID, NodeID>,
    loop_tree: &LoopTree,
) {
    for entry in loop_tree.bottom_up_loops().into_iter().rev() {
        let header = entry.0;
        if editor.node(header).is_fork() {
            // Indexing panics if the fork has no recorded join.
            let join = fork_join_map[&header];
            let converted = unforkify(editor, header, join, loop_tree);
            if converted {
                return;
            }
        }
    }
}
/*
 * Convert a single fork-join (`fork` and its matching `join`) into a
 * sequential loop guarded by a trip-count check. Returns false without
 * restructuring anything when the fork has more than one factor (the two
 * integer constants added up front are still added in that case); otherwise
 * returns the result of the final `editor.edit` call.
 *
 * Shape of the control flow that is built:
 *   - guard_if, on the fork's control predecessor, tests `0 < dc` where `dc`
 *     is the fork's single factor. Its selection-1 projection enters the loop
 *     region, its selection-0 projection goes straight to guard_join.
 *   - region (the loop header) merges the guard-taken projection and the back
 *     edge.
 *   - if_node, on the join's control predecessor, tests `indvar + 1 != dc`.
 *     Its selection-1 projection is the back edge, its selection-0 projection
 *     is the loop exit feeding guard_join.
 *   - guard_join merges the guard-skipped projection and the loop exit.
 * Thread IDs become the induction variable phi; each reduce becomes one phi on
 * the loop header and one phi on guard_join.
 */
pub fn unforkify(
editor: &mut FunctionEditor,
fork: NodeID,
join: NodeID,
loop_tree: &LoopTree,
) -> bool {
// Placeholder IDs; both are overwritten inside the edit below.
let mut zero_cons_id = ConstantID::new(0);
let mut one_cons_id = ConstantID::new(0);
// Adding the u64 constants 0 and 1 is required to succeed.
assert!(editor.edit(|mut edit| {
zero_cons_id = edit.add_constant(Constant::UnsignedInteger64(0));
one_cons_id = edit.add_constant(Constant::UnsignedInteger64(1));
Ok(edit)
}));
// Convert the fork to a region, thread IDs to a single phi, reduces to
// phis, and the join to a branch at the top of the loop. The previous
// control inside of the fork-join should become the successor of the true
// projection node, and what was the use of the join should become a use of
// the new region.
// fork_nodes is consulted further down to decide which users of each reduce
// are redirected to the loop-header phi versus the phi on guard_join.
let fork_nodes = calculate_fork_nodes(editor, loop_tree.nodes_in_loop_bitvec(fork), fork);
let nodes = &editor.func().nodes;
let (fork_control, factors) = nodes[fork.idx()].try_fork().unwrap();
if factors.len() > 1 {
// For now, don't convert multi-dimensional fork-joins. Rely on pass
// that splits fork-joins.
return false;
}
let join_control = nodes[join.idx()].try_join().unwrap();
// Thread ID nodes hanging off the fork.
let tids: Vec<_> = editor
.get_users(fork)
.filter(|id| nodes[id.idx()].is_thread_id())
.collect();
// Reduce nodes hanging off the join.
let reduces: Vec<_> = editor
.get_users(join)
.filter(|id| nodes[id.idx()].is_reduce())
.collect();
// Precompute the IDs the new nodes will receive. This relies on add_node
// handing out consecutive IDs starting at the current node count, in the
// exact order the nodes are added below; the assert_eq! calls in the edit
// check that assumption.
let num_nodes = editor.node_ids().len();
let region_id = NodeID::new(num_nodes);
let if_id = NodeID::new(num_nodes + 1);
let proj_back_id = NodeID::new(num_nodes + 2);
let proj_exit_id = NodeID::new(num_nodes + 3);
let zero_id = NodeID::new(num_nodes + 4);
let one_id = NodeID::new(num_nodes + 5);
let indvar_id = NodeID::new(num_nodes + 6);
let add_id = NodeID::new(num_nodes + 7);
let dc_id = NodeID::new(num_nodes + 8);
let neq_id = NodeID::new(num_nodes + 9);
let guard_if_id = NodeID::new(num_nodes + 10);
let guard_join_id = NodeID::new(num_nodes + 11);
let guard_taken_proj_id = NodeID::new(num_nodes + 12);
let guard_skipped_proj_id = NodeID::new(num_nodes + 13);
let guard_cond_id = NodeID::new(num_nodes + 14);
// One loop-header phi per reduce, followed by one guard_join phi per reduce.
let phi_ids = (num_nodes + 15..num_nodes + 15 + reduces.len()).map(NodeID::new);
let s = num_nodes + 15 + reduces.len();
let join_phi_ids = (s..s + reduces.len()).map(NodeID::new);
// Guard condition: 0 < dc.
let guard_cond = Node::Binary {
left: zero_id,
right: dc_id,
op: BinaryOperator::LT,
};
// The guard branches on the fork's control predecessor.
let guard_if = Node::If {
control: fork_control,
cond: guard_cond_id,
};
// Selection 1 of the guard leads into the loop (it is a predecessor of the
// loop region below).
let guard_taken_proj = Node::ControlProjection {
control: guard_if_id,
selection: 1,
};
// Selection 0 of the guard bypasses the loop (it is a predecessor of
// guard_join below).
let guard_skipped_proj = Node::ControlProjection {
control: guard_if_id,
selection: 0,
};
// Merge point after the loop: reached either by skipping the loop or by
// exiting it.
let guard_join = Node::Region {
preds: Box::new([guard_skipped_proj_id, proj_exit_id]),
};
// Loop header: entered from the guard or from the back edge.
let region = Node::Region {
preds: Box::new([guard_taken_proj_id, proj_back_id]),
};
// Loop branch, placed on the join's control predecessor.
let if_node = Node::If {
control: join_control,
cond: neq_id,
};
// Selection 1 of the loop branch is the back edge.
let proj_back = Node::ControlProjection {
control: if_id,
selection: 1,
};
// Selection 0 of the loop branch is the loop exit.
let proj_exit = Node::ControlProjection {
control: if_id,
selection: 0,
};
let zero = Node::Constant { id: zero_cons_id };
let one = Node::Constant { id: one_cons_id };
// Induction variable: inputs listed in the same order as the loop region's
// predecessors (zero for the guard entry, indvar + 1 for the back edge).
let indvar = Node::Phi {
control: region_id,
data: Box::new([zero_id, add_id]),
};
let add = Node::Binary {
op: BinaryOperator::Add,
left: indvar_id,
right: one_id,
};
// Only factors[0] is used; multi-factor forks were rejected above.
let dc = Node::DynamicConstant { id: factors[0] };
// Loop condition: indvar + 1 != dc.
let neq = Node::Binary {
op: BinaryOperator::NE,
left: add_id,
right: dc_id,
};
// For every reduce build two phis over (init, reduct): one on the loop
// header and one on guard_join. The inputs are listed in the same order as
// the respective region's predecessors.
let (phis, join_phis): (Vec<_>, Vec<_>) = reduces
.iter()
.map(|reduce_id| {
let (_, init, reduct) = nodes[reduce_id.idx()].try_reduce().unwrap();
(
Node::Phi {
control: region_id,
data: Box::new([init, reduct]),
},
Node::Phi {
control: guard_join_id,
data: Box::new([init, reduct]),
},
)
})
.unzip();
editor.edit(|mut edit| {
// Add the nodes in exactly the order their IDs were precomputed.
assert_eq!(edit.add_node(region), region_id);
assert_eq!(edit.add_node(if_node), if_id);
assert_eq!(edit.add_node(proj_back), proj_back_id);
assert_eq!(edit.add_node(proj_exit), proj_exit_id);
assert_eq!(edit.add_node(zero), zero_id);
assert_eq!(edit.add_node(one), one_id);
assert_eq!(edit.add_node(indvar), indvar_id);
assert_eq!(edit.add_node(add), add_id);
assert_eq!(edit.add_node(dc), dc_id);
assert_eq!(edit.add_node(neq), neq_id);
assert_eq!(edit.add_node(guard_if), guard_if_id);
assert_eq!(edit.add_node(guard_join), guard_join_id);
assert_eq!(edit.add_node(guard_taken_proj), guard_taken_proj_id);
assert_eq!(edit.add_node(guard_skipped_proj), guard_skipped_proj_id);
assert_eq!(edit.add_node(guard_cond), guard_cond_id);
for (phi_id, phi) in zip(phi_ids.clone(), &phis) {
assert_eq!(edit.add_node(phi.clone()), phi_id);
}
for (phi_id, phi) in zip(join_phi_ids.clone(), &join_phis) {
assert_eq!(edit.add_node(phi.clone()), phi_id);
}
// Former users of the fork now hang off the loop header; former users of
// the join hang off guard_join.
// NOTE(review): the predicate presumably receives the using node, so the
// new loop branch is excluded from the join rewrite - confirm against
// replace_all_uses_where.
edit = edit.replace_all_uses(fork, region_id)?;
edit = edit.replace_all_uses_where(join, guard_join_id, |usee| *usee != if_id)?;
// NOTE(review): sub_edit presumably records that the second node stands in
// for the first - confirm against the editor's definition.
edit.sub_edit(fork, region_id);
edit.sub_edit(join, if_id);
// Every thread ID is replaced by the induction variable.
for tid in tids.iter() {
edit.sub_edit(*tid, indvar_id);
edit = edit.replace_all_uses(*tid, indvar_id)?;
}
for (((reduce, phi_id), phi), join_phi_id) in
zip(reduces.iter(), phi_ids).zip(phis).zip(join_phi_ids)
{
edit.sub_edit(*reduce, phi_id);
// Recover the phi's inputs; data[1] is the reduce's `reduct` node.
let Node::Phi { control: _, data } = phi else {
panic!()
};
// Nodes outside fork_nodes are pointed at the guard_join phi.
edit = edit
.replace_all_uses_where(*reduce, join_phi_id, |usee| !fork_nodes.contains(usee))?;
// Nodes inside fork_nodes, and the reduct node itself, are pointed at the
// loop-header phi.
edit = edit.replace_all_uses_where(*reduce, phi_id, |usee| {
fork_nodes.contains(usee) || *usee == data[1]
})?;
edit = edit.delete_node(*reduce)?;
}
// With all uses redirected, remove the original fork, join and thread IDs.
edit = edit.delete_node(fork)?;
edit = edit.delete_node(join)?;
for tid in tids {
edit = edit.delete_node(tid)?;
}
Ok(edit)
})
}