-
Aaron Councilman authoredAaron Councilman authored
utils.rs 20.11 KiB
use std::collections::{HashMap, HashSet};
use nestify::nest;
use hercules_ir::*;
use crate::*;
/*
 * Rewrite a type so that every dynamic constant it mentions that is a key of `substs` is replaced
 * by the value stored for that key. Product and summation types are rewritten member by member;
 * array types have both their element type and their dimensions rewritten. Any other type is
 * returned as is.
 *
 * Returns `ty` itself when nothing inside it changed, otherwise the ID handed back by
 * `edit.add_type` for the rewritten type.
 */
pub fn substitute_dynamic_constants_in_type(
    substs: &HashMap<DynamicConstantID, DynamicConstantID>,
    ty: TypeID,
    edit: &mut FunctionEdit,
) -> TypeID {
    // Take an owned copy of the type first; the recursive calls below need `edit` mutably.
    let current = edit.get_type(ty).clone();
    match current {
        Type::Array(elem, ref extents) => {
            let new_elem = substitute_dynamic_constants_in_type(substs, elem, edit);
            let new_extents = extents
                .iter()
                .map(|&extent| substitute_dynamic_constants(substs, extent, edit))
                .collect();
            if new_elem == elem && new_extents == *extents {
                ty
            } else {
                edit.add_type(Type::Array(new_elem, new_extents))
            }
        }
        Type::Product(ref members) => {
            let rewritten = members
                .iter()
                .map(|&member| substitute_dynamic_constants_in_type(substs, member, edit))
                .collect();
            if rewritten == *members {
                ty
            } else {
                edit.add_type(Type::Product(rewritten))
            }
        }
        Type::Summation(ref alternatives) => {
            let rewritten = alternatives
                .iter()
                .map(|&alternative| substitute_dynamic_constants_in_type(substs, alternative, edit))
                .collect();
            if rewritten == *alternatives {
                ty
            } else {
                edit.add_type(Type::Summation(rewritten))
            }
        }
        // Every remaining kind of type is left untouched.
        _ => ty,
    }
}
/*
 * Rewrite the dynamic constant `dc` using the `substs` map. If `dc` itself is a key, its mapped
 * value is returned immediately (the replacement is not rewritten again). Otherwise the operands
 * of `dc` are rewritten recursively.
 *
 * Returns `dc` itself when no operand changed, otherwise the ID handed back by
 * `edit.add_dynamic_constant` for the rebuilt expression. This takes a mutable edit rather than
 * an editor because new dynamic constants may have to be created along the way.
 */
pub fn substitute_dynamic_constants(
    substs: &HashMap<DynamicConstantID, DynamicConstantID>,
    dc: DynamicConstantID,
    edit: &mut FunctionEdit,
) -> DynamicConstantID {
    // Rewrites each operand of an n-ary dynamic constant, preserving operand order.
    fn rewrite_operands(
        substs: &HashMap<DynamicConstantID, DynamicConstantID>,
        operands: &[DynamicConstantID],
        edit: &mut FunctionEdit,
    ) -> Vec<DynamicConstantID> {
        let mut rewritten = Vec::with_capacity(operands.len());
        for operand in operands {
            rewritten.push(substitute_dynamic_constants(substs, *operand, edit));
        }
        rewritten
    }

    // A direct hit in the map wins over looking inside the expression.
    if let Some(&replacement) = substs.get(&dc) {
        return replacement;
    }

    // Take an owned copy of the expression first; the recursive calls need `edit` mutably.
    let expr = edit.get_dynamic_constant(dc).clone();
    match expr {
        // Leaves have no operands to rewrite.
        DynamicConstant::Constant(_) | DynamicConstant::Parameter(_) => dc,
        DynamicConstant::Add(operands) => {
            let rewritten = rewrite_operands(substs, &operands, edit);
            if rewritten == operands {
                dc
            } else {
                edit.add_dynamic_constant(DynamicConstant::Add(rewritten))
            }
        }
        DynamicConstant::Sub(lhs, rhs) => {
            let new_lhs = substitute_dynamic_constants(substs, lhs, edit);
            let new_rhs = substitute_dynamic_constants(substs, rhs, edit);
            if new_lhs == lhs && new_rhs == rhs {
                dc
            } else {
                edit.add_dynamic_constant(DynamicConstant::Sub(new_lhs, new_rhs))
            }
        }
        DynamicConstant::Mul(operands) => {
            let rewritten = rewrite_operands(substs, &operands, edit);
            if rewritten == operands {
                dc
            } else {
                edit.add_dynamic_constant(DynamicConstant::Mul(rewritten))
            }
        }
        DynamicConstant::Div(lhs, rhs) => {
            let new_lhs = substitute_dynamic_constants(substs, lhs, edit);
            let new_rhs = substitute_dynamic_constants(substs, rhs, edit);
            if new_lhs == lhs && new_rhs == rhs {
                dc
            } else {
                edit.add_dynamic_constant(DynamicConstant::Div(new_lhs, new_rhs))
            }
        }
        DynamicConstant::Rem(lhs, rhs) => {
            let new_lhs = substitute_dynamic_constants(substs, lhs, edit);
            let new_rhs = substitute_dynamic_constants(substs, rhs, edit);
            if new_lhs == lhs && new_rhs == rhs {
                dc
            } else {
                edit.add_dynamic_constant(DynamicConstant::Rem(new_lhs, new_rhs))
            }
        }
        DynamicConstant::Min(operands) => {
            let rewritten = rewrite_operands(substs, &operands, edit);
            if rewritten == operands {
                dc
            } else {
                edit.add_dynamic_constant(DynamicConstant::Min(rewritten))
            }
        }
        DynamicConstant::Max(operands) => {
            let rewritten = rewrite_operands(substs, &operands, edit);
            if rewritten == operands {
                dc
            } else {
                edit.add_dynamic_constant(DynamicConstant::Max(rewritten))
            }
        }
    }
}
/*
 * Rewrite a constant so that every dynamic constant reachable through its type (and, for
 * product and summation constants, through its nested constants) is substituted according to
 * `substs`.
 *
 * Returns `cons` itself when nothing changed, otherwise the ID handed back by
 * `edit.add_constant` for the rebuilt constant.
 */
pub fn substitute_dynamic_constants_in_constant(
    substs: &HashMap<DynamicConstantID, DynamicConstantID>,
    cons: ConstantID,
    edit: &mut FunctionEdit,
) -> ConstantID {
    // Take an owned copy of the constant first; the recursive calls need `edit` mutably.
    let current = edit.get_constant(cons).clone();
    match current {
        Constant::Array(ty) => {
            // Array constants only carry a type, so that is all there is to rewrite.
            let new_ty = substitute_dynamic_constants_in_type(substs, ty, edit);
            if new_ty == ty {
                cons
            } else {
                edit.add_constant(Constant::Array(new_ty))
            }
        }
        Constant::Summation(ty, idx, payload) => {
            let new_ty = substitute_dynamic_constants_in_type(substs, ty, edit);
            let new_payload = substitute_dynamic_constants_in_constant(substs, payload, edit);
            if new_ty == ty && new_payload == payload {
                cons
            } else {
                edit.add_constant(Constant::Summation(new_ty, idx, new_payload))
            }
        }
        Constant::Product(ty, members) => {
            let new_ty = substitute_dynamic_constants_in_type(substs, ty, edit);
            let new_members = members
                .iter()
                .map(|&member| substitute_dynamic_constants_in_constant(substs, member, edit))
                .collect();
            if new_ty == ty && new_members == members {
                cons
            } else {
                edit.add_constant(Constant::Product(new_ty, new_members))
            }
        }
        // Every remaining kind of constant is left untouched.
        _ => cons,
    }
}
/*
 * Apply the `substs` map to the dynamic constants referenced directly by a node, updating the
 * node in place. Only fork factors, constant nodes, dynamic constant nodes, and the dynamic
 * constant arguments of calls are rewritten; every other node kind is left alone.
 */
pub fn substitute_dynamic_constants_in_node(
    substs: &HashMap<DynamicConstantID, DynamicConstantID>,
    node: &mut Node,
    edit: &mut FunctionEdit,
) {
    match node {
        Node::DynamicConstant { id } => {
            *id = substitute_dynamic_constants(substs, *id, edit);
        }
        Node::Constant { id } => {
            *id = substitute_dynamic_constants_in_constant(substs, *id, edit);
        }
        Node::Fork { factors, .. } => {
            for slot in factors.iter_mut() {
                *slot = substitute_dynamic_constants(substs, *slot, edit);
            }
        }
        Node::Call {
            dynamic_constants, ..
        } => {
            for slot in dynamic_constants.iter_mut() {
                *slot = substitute_dynamic_constants(substs, *slot, edit);
            }
        }
        _ => {}
    }
}
/*
 * Top level function that leaves a function with a single return node.
 *
 * If the function already has exactly one return, its ID is returned and nothing is edited.
 * Otherwise a region node joining the control predecessors of all returns is created, one phi
 * per returned value merges the data coming from each old return, the old returns are deleted,
 * and a fresh return using the region and the phis is added. The result is the new return's ID,
 * or None if deleting one of the old returns failed inside the edit.
 *
 * Panics if the function has no return node at all.
 */
pub fn collapse_returns(editor: &mut FunctionEditor) -> Option<NodeID> {
    let returns: Vec<NodeID> = editor
        .func()
        .nodes
        .iter()
        .enumerate()
        .filter(|(_, node)| node.is_return())
        .map(|(idx, _)| NodeID::new(idx))
        .collect();
    assert!(!returns.is_empty());
    if let [only] = returns.as_slice() {
        return Some(*only);
    }

    // Use number `pos` of a return node: position 0 is read as its control predecessor and
    // positions 1.. as the returned data values.
    let use_at =
        |ret: &NodeID, pos: usize| get_uses(&editor.func().nodes[ret.idx()]).as_ref()[pos];

    let preds_before_returns: Box<[NodeID]> = returns.iter().map(|ret| use_at(ret, 0)).collect();
    let num_return_data = editor.func().return_types.len();
    let data_to_return: Vec<Box<[NodeID]>> = (0..num_return_data)
        .map(|slot| returns.iter().map(|ret| use_at(ret, slot + 1)).collect())
        .collect();

    // The region, the phis, the deletions and the new return all happen in one edit.
    let mut new_return = None;
    editor.edit(|mut edit| {
        let region = edit.add_node(Node::Region {
            preds: preds_before_returns,
        });
        let return_vals = data_to_return
            .into_iter()
            .map(|data| {
                edit.add_node(Node::Phi {
                    control: region,
                    data,
                })
            })
            .collect();
        for old_return in returns {
            edit = edit.delete_node(old_return)?;
        }
        let merged = edit.add_node(Node::Return {
            control: region,
            data: return_vals,
        });
        new_return = Some(merged);
        Ok(edit)
    });
    new_return
}
/*
 * Reports whether a function has more than two control nodes, i.e. at least one control node
 * besides the start node and a single return node. Panics if there are fewer than two control
 * nodes.
 */
pub fn contains_between_control_flow(func: &Function) -> bool {
    let mut num_control = 0usize;
    for node in func.nodes.iter() {
        if node.is_control() {
            num_control += 1;
        }
    }
    assert!(num_control >= 2, "PANIC: A Hercules function must have at least two control nodes: a start node and at least one return node.");
    num_control > 2
}
/*
 * Top level function to ensure a Hercules function contains at least one control node that
 * isn't the start or return nodes.
 *
 * When such a node already exists, the first control user of node 0 is returned. Otherwise the
 * lone return is replaced by a single-predecessor region on node 0 followed by a new return
 * with the same data, and the region's ID is returned (None if deleting the old return failed
 * inside the edit).
 */
pub fn ensure_between_control_flow(editor: &mut FunctionEditor) -> Option<NodeID> {
    let start = NodeID::new(0);

    if contains_between_control_flow(editor.func()) {
        let first_control_user = editor
            .get_users(start)
            .find(|id| editor.func().nodes[id.idx()].is_control())
            .unwrap();
        return Some(first_control_user);
    }

    // Skip the first node ID and take the first control node after it; with only two control
    // nodes in the function this is expected to be the return.
    let ret = editor
        .node_ids()
        .skip(1)
        .find(|id| editor.func().nodes[id.idx()].is_control())
        .unwrap();
    let Node::Return {
        control: ret_control,
        data: ref ret_data,
    } = editor.func().nodes[ret.idx()]
    else {
        panic!("PANIC: A Hercules function with only two control nodes must have a return node be the other control node, other than the start node.")
    };
    assert_eq!(ret_control, start, "PANIC: The only other control node in a Hercules function, the return node, is not using the start node.");
    let data = ret_data.clone();

    let mut region_id = None;
    editor.edit(|mut edit| {
        edit = edit.delete_node(ret)?;
        let region = edit.add_node(Node::Region {
            preds: Box::new([start]),
        });
        region_id = Some(region);
        edit.add_node(Node::Return {
            control: region,
            data,
        });
        Ok(edit)
    });
    region_id
}
/*
 * A map with one slot per node, stored as a plain vector. NodeIterator below uses it for its
 * visited flags and indexes it with a node's `idx()`.
 */
pub type DenseNodeMap<T> = Vec<T>;
/*
 * A map holding entries only for some nodes, keyed by NodeID.
 */
pub type SparseNodeMap<T> = HashMap<NodeID, T>;
// Iterator state for walking the uses or the users of nodes in a function (see the `walk_all_*`
// constructors and the Iterator impl below). The `nest!` macro from the `nestify` crate is used
// so that the `Direction` enum can be declared inline as the type of the `direction` field.
nest! {
//
#[derive(Clone, Debug)]
pub struct NodeIterator<'a> {
// Which neighbors get pushed when a node is visited: `Uses` follows `get_uses`, `Users` follows
// `get_users`.
pub direction:
#[derive(Clone, Debug, PartialEq)]
pub enum Direction {
Uses,
Users,
},
// One flag per node, indexed by `NodeID::idx()`; set the first time a node is yielded.
visited: DenseNodeMap<bool>,
// Nodes waiting to be visited; `next` pops from the end of this vector.
stack: Vec<NodeID>,
func: &'a FunctionEditor<'a>, // Maybe this is an enum, def use can be gotten from the function or from the editor.
// `stop condition`, then return all nodes that caused stoppage i.e the frontier of the search.
// Nodes in this set are still yielded by the iterator, but their neighbors are not pushed.
stop_on: HashSet<NodeID>, // Don't add neighbors of these.
}
}
/*
 * Build an iterator that starts at `node` (which is itself yielded) and keeps following uses
 * transitively. No nodes stop the walk.
 */
pub fn walk_all_uses<'a>(node: NodeID, editor: &'a FunctionEditor<'a>) -> NodeIterator<'a> {
    let visited = vec![false; editor.func().nodes.len()];
    NodeIterator {
        func: editor,
        direction: Direction::Uses,
        stack: vec![node],
        visited,
        stop_on: HashSet::new(),
    }
}
/*
 * Build an iterator that starts at `node` (which is itself yielded) and keeps following users
 * transitively. No nodes stop the walk.
 */
pub fn walk_all_users<'a>(node: NodeID, editor: &'a FunctionEditor<'a>) -> NodeIterator<'a> {
    let visited = vec![false; editor.func().nodes.len()];
    NodeIterator {
        func: editor,
        direction: Direction::Users,
        stack: vec![node],
        visited,
        stop_on: HashSet::new(),
    }
}
/*
 * Build an iterator over the transitive uses of `node`. Unlike `walk_all_uses`, the walk is
 * seeded with the direct uses of `node` rather than with `node` itself. Nodes in `stop_on` are
 * yielded when reached, but the walk does not continue past them.
 */
pub fn walk_all_uses_stop_on<'a>(
    node: NodeID,
    editor: &'a FunctionEditor<'a>,
    stop_on: HashSet<NodeID>,
) -> NodeIterator<'a> {
    let visited = vec![false; editor.func().nodes.len()];
    NodeIterator {
        func: editor,
        direction: Direction::Uses,
        stack: editor.get_uses(node).collect(),
        visited,
        stop_on,
    }
}
/*
 * Build an iterator over the transitive users of `node`. Unlike `walk_all_users`, the walk is
 * seeded with the direct users of `node` rather than with `node` itself. Nodes in `stop_on` are
 * yielded when reached, but the walk does not continue past them.
 */
pub fn walk_all_users_stop_on<'a>(
    node: NodeID,
    editor: &'a FunctionEditor<'a>,
    stop_on: HashSet<NodeID>,
) -> NodeIterator<'a> {
    let visited = vec![false; editor.func().nodes.len()];
    NodeIterator {
        func: editor,
        direction: Direction::Users,
        stack: editor.get_users(node).collect(),
        visited,
        stop_on,
    }
}
/*
 * Stack-driven traversal of the use or user graph. Each call to `next` pops pending nodes until
 * it finds one that has not been yielded yet, marks it visited, pushes its neighbors (uses or
 * users, depending on `direction`) unless the node is in `stop_on`, and yields it. Every node is
 * yielded at most once; the iterator ends when the stack runs dry.
 */
impl<'a> Iterator for NodeIterator<'a> {
    type Item = NodeID;

    fn next(&mut self) -> Option<Self::Item> {
        while let Some(current) = self.stack.pop() {
            // A node can sit on the stack several times (once per neighbor that reached it);
            // only its first pop yields it.
            if self.visited[current.idx()] {
                continue;
            }
            self.visited[current.idx()] = true;

            // Nodes in `stop_on` are still yielded, but the walk does not expand through them.
            if !self.stop_on.contains(&current) {
                match self.direction {
                    Direction::Uses => self.stack.extend(self.func.get_uses(current)),
                    Direction::Users => self.stack.extend(self.func.get_users(current)),
                }
            }
            return Some(current);
        }
        None
    }
}
/*
 * Materializes an einsum expression into an IR node tree and returns the ID of its root node.
 * Thread IDs are replaced with the provided node IDs: a thread ID for dimension `d` becomes
 * `dim_substs[d]`. Opaque nodes are returned as is. Only zero / one constants, reads, unary,
 * binary and ternary operations, and intrinsic calls are handled; any other expression (e.g.
 * reductions or comprehensions) panics.
 */
pub fn materialize_simple_einsum_expr(
    edit: &mut FunctionEdit,
    id: MathID,
    env: &MathEnv,
    dim_substs: &[NodeID],
) -> NodeID {
    match env[id.idx()] {
        // Leaves that already correspond to existing nodes.
        MathExpr::OpaqueNode(node) => node,
        MathExpr::ThreadID(dim) => dim_substs[dim.0],
        // Constant leaves need a fresh constant plus a node referring to it.
        MathExpr::Zero(ty) => {
            let zero = edit.add_zero_constant(ty);
            edit.add_node(Node::Constant { id: zero })
        }
        MathExpr::One(ty) => {
            let one = edit.add_one_constant(ty);
            edit.add_node(Node::Constant { id: one })
        }
        MathExpr::Unary(op, operand) => {
            let input = materialize_simple_einsum_expr(edit, operand, env, dim_substs);
            edit.add_node(Node::Unary { op, input })
        }
        MathExpr::Binary(op, lhs, rhs) => {
            let left = materialize_simple_einsum_expr(edit, lhs, env, dim_substs);
            let right = materialize_simple_einsum_expr(edit, rhs, env, dim_substs);
            edit.add_node(Node::Binary { op, left, right })
        }
        MathExpr::Ternary(op, a, b, c) => {
            let first = materialize_simple_einsum_expr(edit, a, env, dim_substs);
            let second = materialize_simple_einsum_expr(edit, b, env, dim_substs);
            let third = materialize_simple_einsum_expr(edit, c, env, dim_substs);
            edit.add_node(Node::Ternary {
                op,
                first,
                second,
                third,
            })
        }
        MathExpr::Read(source, ref positions) => {
            // The collection is materialized before any of the index expressions.
            let collect = materialize_simple_einsum_expr(edit, source, env, dim_substs);
            let position_nodes = positions
                .iter()
                .map(|&pos| materialize_simple_einsum_expr(edit, pos, env, dim_substs))
                .collect();
            let indices = Box::new([Index::Position(position_nodes)]);
            edit.add_node(Node::Read { collect, indices })
        }
        MathExpr::IntrinsicFunc(intrinsic, ref operands) => {
            let args = operands
                .iter()
                .map(|&operand| materialize_simple_einsum_expr(edit, operand, env, dim_substs))
                .collect();
            edit.add_node(Node::IntrinsicCall { intrinsic, args })
        }
        _ => panic!(),
    }
}
/*
 * Iterate over the node IDs that appear inside the position indices of an index list, in
 * order. Indices of any other kind contribute nothing.
 */
pub fn node_indices(indices: &[Index]) -> impl Iterator<Item = NodeID> + '_ {
    indices
        .iter()
        .filter_map(|index| match index {
            Index::Position(positions) => Some(positions.iter().copied()),
            _ => None,
        })
        .flatten()
}
/*
 * Checks if a set of indices is fully parallel over a set of forks - that is, for every fork in
 * `forks`, the number of distinct thread ID dimensions of that fork used directly as position
 * indices equals the fork's number of factors.
 *
 * Only position indices that are bare thread ID nodes are considered. A fork's entry is removed
 * once it has been checked, so a fork that shows up twice in `forks` fails on its second
 * occurrence. Panics if an ID in `forks` with recorded thread IDs is not a fork node.
 */
pub fn indices_parallel_over_forks<I>(
    editor: &FunctionEditor,
    indices: &[Index],
    mut forks: I,
) -> bool
where
    I: Iterator<Item = NodeID>,
{
    let nodes = &editor.func().nodes;

    // For each fork, collect the dimensions of its thread IDs used as positions.
    let mut dims_by_fork: HashMap<NodeID, Vec<usize>> = HashMap::new();
    for id in node_indices(indices) {
        if let Node::ThreadID { control, dimension } = nodes[id.idx()] {
            dims_by_fork.entry(control).or_default().push(dimension);
        }
    }

    // Every queried fork must be present, with as many distinct dimensions as it has factors.
    forks.all(|fork| match dims_by_fork.remove(&fork) {
        None => false,
        Some(mut dims) => {
            dims.sort_unstable();
            dims.dedup();
            dims.len() == nodes[fork.idx()].try_fork().unwrap().1.len()
        }
    })
}
/*
 * Reports whether node `id` is a constant node or a dynamic constant node whose value answers
 * `is_zero()`, or an undef node.
 * NOTE(review): undef presumably counts because an undefined value may stand for any value -
 * confirm.
 */
pub fn is_zero(editor: &FunctionEditor, id: NodeID) -> bool {
    let node = &editor.func().nodes[id.idx()];
    if let Some(cons) = node.try_constant() {
        if editor.get_constant(cons).is_zero() {
            return true;
        }
    }
    if let Some(dc) = node.try_dynamic_constant() {
        if editor.get_dynamic_constant(dc).is_zero() {
            return true;
        }
    }
    node.is_undef()
}
/*
 * Reports whether node `id` is a constant node or a dynamic constant node whose value answers
 * `is_one()`, or an undef node.
 * NOTE(review): undef presumably counts because an undefined value may stand for any value -
 * confirm.
 */
pub fn is_one(editor: &FunctionEditor, id: NodeID) -> bool {
    let node = &editor.func().nodes[id.idx()];
    if let Some(cons) = node.try_constant() {
        if editor.get_constant(cons).is_one() {
            return true;
        }
    }
    if let Some(dc) = node.try_dynamic_constant() {
        if editor.get_dynamic_constant(dc).is_one() {
            return true;
        }
    }
    node.is_undef()
}
/*
 * Reports whether node `id` is a constant node or a dynamic constant node whose value answers
 * `is_largest()`, or an undef node.
 * NOTE(review): undef presumably counts because an undefined value may stand for any value -
 * confirm.
 */
pub fn is_largest(editor: &FunctionEditor, id: NodeID) -> bool {
    let node = &editor.func().nodes[id.idx()];
    if let Some(cons) = node.try_constant() {
        if editor.get_constant(cons).is_largest() {
            return true;
        }
    }
    if let Some(dc) = node.try_dynamic_constant() {
        if editor.get_dynamic_constant(dc).is_largest() {
            return true;
        }
    }
    node.is_undef()
}
/*
 * Reports whether node `id` is a constant node or a dynamic constant node whose value answers
 * `is_smallest()`, or an undef node.
 * NOTE(review): undef presumably counts because an undefined value may stand for any value -
 * confirm.
 */
pub fn is_smallest(editor: &FunctionEditor, id: NodeID) -> bool {
    let node = &editor.func().nodes[id.idx()];
    if let Some(cons) = node.try_constant() {
        if editor.get_constant(cons).is_smallest() {
            return true;
        }
    }
    if let Some(dc) = node.try_dynamic_constant() {
        if editor.get_dynamic_constant(dc).is_smallest() {
            return true;
        }
    }
    node.is_undef()
}