unforkify.rs 5.45 KiB
use std::collections::HashMap;
use std::iter::zip;
use hercules_ir::ir::*;
use crate::*;
/// Convert forks back into loops right before codegen when a backend is not
/// lowering a fork-join to vector / parallel code. Lowering fork-joins into
/// sequential loops in LLVM is actually not entirely trivial, so it's easier to
/// just do this transformation within Hercules IR.
///
/// Each one-dimensional fork-join in `fork_join_map` (fork -> join) becomes a
/// loop whose exit test sits at the top:
///
/// ```text
/// fork_control -> region -> if (indvar != factor)
///                             |- proj_back (selection 1) -> body -> back edge to region
///                             |- proj_exit (selection 0) -> former users of the join
/// ```
///
/// The fork becomes the loop-header region and the join becomes the `if`.
/// Thread IDs are replaced by a single induction-variable phi that starts at 0,
/// is incremented by 1 on the back edge, and is compared against the fork's
/// dynamic-constant factor. Each reduce is replaced by a phi over its `init`
/// (on entry) and `reduct` (on the back edge) values.
///
/// Fork-joins that do not have exactly one dimension are skipped. This relies
/// on a separate pass to split multi-dimensional fork-joins beforehand.
///
/// # Panics
///
/// Panics if the helper constants cannot be added to the function, or if an
/// entry of `fork_join_map` does not map a fork node to a join node.
pub fn unforkify(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>) {
    // Initial value and step of the induction variable.
    let mut zero_cons_id = ConstantID::new(0);
    let mut one_cons_id = ConstantID::new(0);
    assert!(editor.edit(|mut edit| {
        zero_cons_id = edit.add_constant(Constant::UnsignedInteger64(0));
        one_cons_id = edit.add_constant(Constant::UnsignedInteger64(1));
        Ok(edit)
    }));

    // `HashMap` iteration order is randomized per process. Visit the fork-joins
    // in a fixed order so the resulting node numbering is reproducible.
    let mut fork_joins: Vec<(NodeID, NodeID)> = fork_join_map
        .iter()
        .map(|(fork, join)| (*fork, *join))
        .collect();
    fork_joins.sort_by_key(|(fork, _)| fork.idx());

    // Convert the fork to a region, thread IDs to a single phi, reduces to
    // phis, and the join to a branch at the top of the loop. The previous
    // control insides of the fork-join should become the successor of the
    // projection taken while the loop continues, and what was the use of the
    // join should become a use of the exit projection.
    for (fork, join) in fork_joins {
        let nodes = &editor.func().nodes;
        let (fork_control, factors) = nodes[fork.idx()].try_fork().unwrap();
        if factors.len() != 1 {
            // Only one-dimensional fork-joins are converted for now. Rely on
            // the pass that splits multi-dimensional fork-joins. A fork with
            // no dimensions has no trip count to loop to, so it is skipped
            // too instead of panicking on `factors[0]`.
            continue;
        }
        let factor = factors[0];
        let join_control = nodes[join.idx()].try_join().unwrap();
        let tids: Vec<_> = editor
            .get_users(fork)
            .filter(|id| nodes[id.idx()].is_thread_id())
            .collect();
        let reduces: Vec<_> = editor
            .get_users(join)
            .filter(|id| nodes[id.idx()].is_reduce())
            .collect();

        // Predict the IDs the editor assigns to the new nodes. They must be
        // allocated in the same order the nodes are added below; the
        // `assert_eq!`s inside the edit check this.
        let mut next_id = editor.node_ids().len();
        let mut fresh_id = || {
            let id = NodeID::new(next_id);
            next_id += 1;
            id
        };
        let region_id = fresh_id();
        let if_id = fresh_id();
        let proj_back_id = fresh_id();
        let proj_exit_id = fresh_id();
        let zero_id = fresh_id();
        let one_id = fresh_id();
        let indvar_id = fresh_id();
        let add_id = fresh_id();
        let dc_id = fresh_id();
        let neq_id = fresh_id();
        let phi_ids: Vec<NodeID> = reduces.iter().map(|_| fresh_id()).collect();

        // Back edge into the loop header: the last control node of the body,
        // or the continuing projection itself when the body is empty.
        let back_edge = if join_control == fork {
            proj_back_id
        } else {
            join_control
        };
        let region = Node::Region {
            preds: Box::new([fork_control, back_edge]),
        };
        let if_node = Node::If {
            control: region_id,
            cond: neq_id,
        };
        let proj_back = Node::Projection {
            control: if_id,
            selection: 1,
        };
        let proj_exit = Node::Projection {
            control: if_id,
            selection: 0,
        };
        let zero = Node::Constant { id: zero_cons_id };
        let one = Node::Constant { id: one_cons_id };
        // Phi data are ordered like the region's preds: [entry, back edge].
        let indvar = Node::Phi {
            control: region_id,
            data: Box::new([zero_id, add_id]),
        };
        let add = Node::Binary {
            op: BinaryOperator::Add,
            left: indvar_id,
            right: one_id,
        };
        let dc = Node::DynamicConstant { id: factor };
        let neq = Node::Binary {
            op: BinaryOperator::NE,
            left: indvar_id,
            right: dc_id,
        };
        let phis: Vec<_> = reduces
            .iter()
            .map(|reduce_id| {
                let (_, init, reduct) = nodes[reduce_id.idx()].try_reduce().unwrap();
                Node::Phi {
                    control: region_id,
                    data: Box::new([init, reduct]),
                }
            })
            .collect();

        // NOTE(review): the success flag of this edit is not checked. Presumably
        // a rejected edit leaves this fork-join unconverted; confirm against
        // `FunctionEditor::edit`.
        editor.edit(|mut edit| {
            assert_eq!(edit.add_node(region), region_id);
            assert_eq!(edit.add_node(if_node), if_id);
            assert_eq!(edit.add_node(proj_back), proj_back_id);
            assert_eq!(edit.add_node(proj_exit), proj_exit_id);
            assert_eq!(edit.add_node(zero), zero_id);
            assert_eq!(edit.add_node(one), one_id);
            assert_eq!(edit.add_node(indvar), indvar_id);
            assert_eq!(edit.add_node(add), add_id);
            assert_eq!(edit.add_node(dc), dc_id);
            assert_eq!(edit.add_node(neq), neq_id);
            for (phi_id, phi) in zip(phi_ids.iter().copied(), phis) {
                assert_eq!(edit.add_node(phi), phi_id);
            }

            // Body control hangs off the continuing projection; whatever
            // followed the join now follows the exit projection.
            edit = edit.replace_all_uses(fork, proj_back_id)?;
            edit = edit.replace_all_uses(join, proj_exit_id)?;
            edit.sub_edit(fork, region_id);
            edit.sub_edit(join, if_id);
            for tid in tids.iter() {
                edit.sub_edit(*tid, indvar_id);
                edit = edit.replace_all_uses(*tid, indvar_id)?;
            }
            for (reduce, phi_id) in zip(reduces.iter(), phi_ids.iter()) {
                edit.sub_edit(*reduce, *phi_id);
                edit = edit.replace_all_uses(*reduce, *phi_id)?;
            }

            // The fork-join nodes have no remaining users; remove them.
            edit = edit.delete_node(fork)?;
            edit = edit.delete_node(join)?;
            for tid in tids {
                edit = edit.delete_node(tid)?;
            }
            for reduce in reduces {
                edit = edit.delete_node(reduce)?;
            }
            Ok(edit)
        });
    }
}