use std::collections::{HashMap, HashSet}; use nestify::nest; use hercules_ir::*; use crate::*; /* * Substitute all uses of dynamic constants in a type that are keys in the substs map with the * dynamic constant value for that key. Return the substituted version of the type, once memoized. */ pub fn substitute_dynamic_constants_in_type( substs: &HashMap<DynamicConstantID, DynamicConstantID>, ty: TypeID, edit: &mut FunctionEdit, ) -> TypeID { // Look inside the type for references to dynamic constants. let ty_clone = edit.get_type(ty).clone(); match ty_clone { Type::Product(ref fields) => { let new_fields = fields .into_iter() .map(|field_id| substitute_dynamic_constants_in_type(substs, *field_id, edit)) .collect(); if new_fields != *fields { edit.add_type(Type::Product(new_fields)) } else { ty } } Type::Summation(ref variants) => { let new_variants = variants .into_iter() .map(|variant_id| substitute_dynamic_constants_in_type(substs, *variant_id, edit)) .collect(); if new_variants != *variants { edit.add_type(Type::Summation(new_variants)) } else { ty } } Type::Array(elem_ty, ref dims) => { let new_elem_ty = substitute_dynamic_constants_in_type(substs, elem_ty, edit); let new_dims = dims .into_iter() .map(|dim_id| substitute_dynamic_constants(substs, *dim_id, edit)) .collect(); if new_elem_ty != elem_ty || new_dims != *dims { edit.add_type(Type::Array(new_elem_ty, new_dims)) } else { ty } } _ => ty, } } /* * Substitute all uses of dynamic constants in a dynamic constant dc that are keys in the * substs map and replace them with their appropriate replacement values. Return the substituted * version of dc, once memoized. Takes a mutable edit instead of an editor since this may create * new dynamic constants, which can only be done inside an edit. */ pub fn substitute_dynamic_constants( substs: &HashMap<DynamicConstantID, DynamicConstantID>, dc: DynamicConstantID, edit: &mut FunctionEdit, ) -> DynamicConstantID { // If this dynamic constant should be substituted, just return the substitution if let Some(subst) = substs.get(&dc) { return *subst; } // Look inside the dynamic constant to perform substitution in its children let dc_clone = edit.get_dynamic_constant(dc).clone(); match dc_clone { DynamicConstant::Constant(_) | DynamicConstant::Parameter(_) => dc, DynamicConstant::Add(xs) => { let new_xs = xs .iter() .map(|x| substitute_dynamic_constants(substs, *x, edit)) .collect::<Vec<_>>(); if new_xs != xs { edit.add_dynamic_constant(DynamicConstant::Add(new_xs)) } else { dc } } DynamicConstant::Sub(left, right) => { let new_left = substitute_dynamic_constants(substs, left, edit); let new_right = substitute_dynamic_constants(substs, right, edit); if new_left != left || new_right != right { edit.add_dynamic_constant(DynamicConstant::Sub(new_left, new_right)) } else { dc } } DynamicConstant::Mul(xs) => { let new_xs = xs .iter() .map(|x| substitute_dynamic_constants(substs, *x, edit)) .collect::<Vec<_>>(); if new_xs != xs { edit.add_dynamic_constant(DynamicConstant::Mul(new_xs)) } else { dc } } DynamicConstant::Div(left, right) => { let new_left = substitute_dynamic_constants(substs, left, edit); let new_right = substitute_dynamic_constants(substs, right, edit); if new_left != left || new_right != right { edit.add_dynamic_constant(DynamicConstant::Div(new_left, new_right)) } else { dc } } DynamicConstant::Rem(left, right) => { let new_left = substitute_dynamic_constants(substs, left, edit); let new_right = substitute_dynamic_constants(substs, right, edit); if new_left != left || new_right != right { edit.add_dynamic_constant(DynamicConstant::Rem(new_left, new_right)) } else { dc } } DynamicConstant::Min(xs) => { let new_xs = xs .iter() .map(|x| substitute_dynamic_constants(substs, *x, edit)) .collect::<Vec<_>>(); if new_xs != xs { edit.add_dynamic_constant(DynamicConstant::Min(new_xs)) } else { dc } } DynamicConstant::Max(xs) => { let new_xs = xs .iter() .map(|x| substitute_dynamic_constants(substs, *x, edit)) .collect::<Vec<_>>(); if new_xs != xs { edit.add_dynamic_constant(DynamicConstant::Max(new_xs)) } else { dc } } } } /* * Substitute all uses of the dynamic constants specified by the subst map in a constant. Return * the substituted version of the constant, once memozied. */ pub fn substitute_dynamic_constants_in_constant( substs: &HashMap<DynamicConstantID, DynamicConstantID>, cons: ConstantID, edit: &mut FunctionEdit, ) -> ConstantID { // Look inside the type for references to dynamic constants. let cons_clone = edit.get_constant(cons).clone(); match cons_clone { Constant::Product(ty, fields) => { let new_ty = substitute_dynamic_constants_in_type(substs, ty, edit); let new_fields = fields .iter() .map(|field_id| substitute_dynamic_constants_in_constant(substs, *field_id, edit)) .collect(); if new_ty != ty || new_fields != fields { edit.add_constant(Constant::Product(new_ty, new_fields)) } else { cons } } Constant::Summation(ty, idx, variant) => { let new_ty = substitute_dynamic_constants_in_type(substs, ty, edit); let new_variant = substitute_dynamic_constants_in_constant(substs, variant, edit); if new_ty != ty || new_variant != variant { edit.add_constant(Constant::Summation(new_ty, idx, new_variant)) } else { cons } } Constant::Array(ty) => { let new_ty = substitute_dynamic_constants_in_type(substs, ty, edit); if new_ty != ty { edit.add_constant(Constant::Array(new_ty)) } else { cons } } _ => cons, } } /* * Substitute all uses of the dynamic constants specified by the subst map in a node. */ pub fn substitute_dynamic_constants_in_node( substs: &HashMap<DynamicConstantID, DynamicConstantID>, node: &mut Node, edit: &mut FunctionEdit, ) { match node { Node::Fork { control: _, factors, } => { for factor in factors.into_iter() { *factor = substitute_dynamic_constants(substs, *factor, edit); } } Node::Constant { id } => { *id = substitute_dynamic_constants_in_constant(substs, *id, edit); } Node::DynamicConstant { id } => { *id = substitute_dynamic_constants(substs, *id, edit); } Node::Call { control: _, function: _, dynamic_constants, args: _, } => { for dc_arg in dynamic_constants.into_iter() { *dc_arg = substitute_dynamic_constants(substs, *dc_arg, edit); } } _ => {} } } /* * Top level function to make a function have only a single return. */ pub fn collapse_returns(editor: &mut FunctionEditor) -> Option<NodeID> { let returns: Vec<NodeID> = (0..editor.func().nodes.len()) .filter(|idx| editor.func().nodes[*idx].is_return()) .map(NodeID::new) .collect(); assert!(!returns.is_empty()); if returns.len() == 1 { return Some(returns[0]); } let preds_before_returns: Box<[NodeID]> = returns .iter() .map(|ret_id| get_uses(&editor.func().nodes[ret_id.idx()]).as_ref()[0]) .collect(); let num_return_data = editor.func().return_types.len(); let data_to_return: Vec<Box<[NodeID]>> = (0..num_return_data) .map(|idx| { returns .iter() .map(|ret_id| get_uses(&editor.func().nodes[ret_id.idx()]).as_ref()[idx + 1]) .collect() }) .collect(); // All of the old returns get replaced in a single edit. let mut new_return = None; editor.edit(|mut edit| { let region = edit.add_node(Node::Region { preds: preds_before_returns, }); let return_vals = data_to_return .into_iter() .map(|data| { edit.add_node(Node::Phi { control: region, data, }) }) .collect(); for ret in returns { edit = edit.delete_node(ret)?; } new_return = Some(edit.add_node(Node::Return { control: region, data: return_vals, })); Ok(edit) }); new_return } pub fn contains_between_control_flow(func: &Function) -> bool { let num_control = func.nodes.iter().filter(|node| node.is_control()).count(); assert!(num_control >= 2, "PANIC: A Hercules function must have at least two control nodes: a start node and at least one return node."); num_control > 2 } /* * Top level function to ensure a Hercules function contains at least one * control node that isn't the start or return nodes. */ pub fn ensure_between_control_flow(editor: &mut FunctionEditor) -> Option<NodeID> { if !contains_between_control_flow(editor.func()) { let ret = editor .node_ids() .skip(1) .filter(|id| editor.func().nodes[id.idx()].is_control()) .next() .unwrap(); let Node::Return { control, ref data } = editor.func().nodes[ret.idx()] else { panic!("PANIC: A Hercules function with only two control nodes must have a return node be the other control node, other than the start node.") }; assert_eq!(control, NodeID::new(0), "PANIC: The only other control node in a Hercules function, the return node, is not using the start node."); let data = data.clone(); let mut region_id = None; editor.edit(|mut edit| { edit = edit.delete_node(ret)?; region_id = Some(edit.add_node(Node::Region { preds: Box::new([NodeID::new(0)]), })); edit.add_node(Node::Return { control: region_id.unwrap(), data, }); Ok(edit) }); region_id } else { Some( editor .get_users(NodeID::new(0)) .filter(|id| editor.func().nodes[id.idx()].is_control()) .next() .unwrap(), ) } } pub type DenseNodeMap<T> = Vec<T>; pub type SparseNodeMap<T> = HashMap<NodeID, T>; nest! { // #[derive(Clone, Debug)] pub struct NodeIterator<'a> { pub direction: #[derive(Clone, Debug, PartialEq)] pub enum Direction { Uses, Users, }, visited: DenseNodeMap<bool>, stack: Vec<NodeID>, func: &'a FunctionEditor<'a>, // Maybe this is an enum, def use can be gotten from the function or from the editor. // `stop condition`, then return all nodes that caused stoppage i.e the frontier of the search. stop_on: HashSet<NodeID>, // Don't add neighbors of these. } } pub fn walk_all_uses<'a>(node: NodeID, editor: &'a FunctionEditor<'a>) -> NodeIterator<'a> { let len = editor.func().nodes.len(); NodeIterator { direction: Direction::Uses, visited: vec![false; len], stack: vec![node], func: editor, stop_on: HashSet::new(), } } pub fn walk_all_users<'a>(node: NodeID, editor: &'a FunctionEditor<'a>) -> NodeIterator<'a> { let len = editor.func().nodes.len(); NodeIterator { direction: Direction::Users, visited: vec![false; len], stack: vec![node], func: editor, stop_on: HashSet::new(), } } pub fn walk_all_uses_stop_on<'a>( node: NodeID, editor: &'a FunctionEditor<'a>, stop_on: HashSet<NodeID>, ) -> NodeIterator<'a> { let len = editor.func().nodes.len(); let uses = editor.get_uses(node).collect(); NodeIterator { direction: Direction::Uses, visited: vec![false; len], stack: uses, func: editor, stop_on, } } pub fn walk_all_users_stop_on<'a>( node: NodeID, editor: &'a FunctionEditor<'a>, stop_on: HashSet<NodeID>, ) -> NodeIterator<'a> { let len = editor.func().nodes.len(); let users = editor.get_users(node).collect(); NodeIterator { direction: Direction::Users, visited: vec![false; len], stack: users, func: editor, stop_on, } } impl<'a> Iterator for NodeIterator<'a> { type Item = NodeID; fn next(&mut self) -> Option<Self::Item> { while let Some(current) = self.stack.pop() { if !self.visited[current.idx()] { self.visited[current.idx()] = true; if !self.stop_on.contains(¤t) { if self.direction == Direction::Uses { for neighbor in self.func.get_uses(current) { self.stack.push(neighbor) } } else { for neighbor in self.func.get_users(current) { self.stack.push(neighbor) } } } return Some(current); } } None } } /* * Materializes an einsum expression into an IR node tree. Replaces thread IDs * with provides node IDs. Doesn't materialize reductions or comprehensions. */ pub fn materialize_simple_einsum_expr( edit: &mut FunctionEdit, id: MathID, env: &MathEnv, dim_substs: &[NodeID], ) -> NodeID { match env[id.idx()] { MathExpr::Zero(ty) => { let cons_id = edit.add_zero_constant(ty); edit.add_node(Node::Constant { id: cons_id }) } MathExpr::One(ty) => { let cons_id = edit.add_one_constant(ty); edit.add_node(Node::Constant { id: cons_id }) } MathExpr::OpaqueNode(id) => id, MathExpr::ThreadID(dim) => dim_substs[dim.0], MathExpr::Read(collect, ref indices) => { let collect = materialize_simple_einsum_expr(edit, collect, env, dim_substs); let indices = Box::new([Index::Position( indices .into_iter() .map(|idx| materialize_simple_einsum_expr(edit, *idx, env, dim_substs)) .collect(), )]); edit.add_node(Node::Read { collect, indices }) } MathExpr::Unary(op, input) => { let input = materialize_simple_einsum_expr(edit, input, env, dim_substs); edit.add_node(Node::Unary { op, input }) } MathExpr::Binary(op, left, right) => { let left = materialize_simple_einsum_expr(edit, left, env, dim_substs); let right = materialize_simple_einsum_expr(edit, right, env, dim_substs); edit.add_node(Node::Binary { op, left, right }) } MathExpr::Ternary(op, first, second, third) => { let first = materialize_simple_einsum_expr(edit, first, env, dim_substs); let second = materialize_simple_einsum_expr(edit, second, env, dim_substs); let third = materialize_simple_einsum_expr(edit, third, env, dim_substs); edit.add_node(Node::Ternary { op, first, second, third, }) } MathExpr::IntrinsicFunc(intrinsic, ref args) => { let args = args .into_iter() .map(|id| materialize_simple_einsum_expr(edit, *id, env, dim_substs)) .collect(); edit.add_node(Node::IntrinsicCall { intrinsic, args }) } _ => panic!(), } } /* * Get the node IDs referred to in position indices in a indices set. */ pub fn node_indices(indices: &[Index]) -> impl Iterator<Item = NodeID> + '_ { indices .iter() .filter_map(|index| { if let Index::Position(indices) = index { Some(indices) } else { None } }) .flat_map(|pos| pos.iter()) .map(|id| *id) } /* * Checks if a set of indices is fully parallel over a set of forks - that is, * every thread ID from every fork appears at least once in positions in the * indices set. */ pub fn indices_parallel_over_forks<I>( editor: &FunctionEditor, indices: &[Index], mut forks: I, ) -> bool where I: Iterator<Item = NodeID>, { // Get the forks corresponding to position uses of bare thread ids. let nodes = &editor.func().nodes; let fork_thread_id_pairs = node_indices(indices).filter_map(|id| { if let Node::ThreadID { control, dimension } = nodes[id.idx()] { Some((control, dimension)) } else if let Node::Binary { op: BinaryOperator::Add, left: tid, right: cons, } = nodes[id.idx()] && let Node::ThreadID { control, dimension } = nodes[tid.idx()] && (nodes[cons.idx()].is_constant() || nodes[cons.idx()].is_dynamic_constant()) { Some((control, dimension)) } else if let Node::Binary { op: BinaryOperator::Add, left: cons, right: tid, } = nodes[id.idx()] && let Node::ThreadID { control, dimension } = nodes[tid.idx()] && (nodes[cons.idx()].is_constant() || nodes[cons.idx()].is_dynamic_constant()) { Some((control, dimension)) } else { None } }); let mut rep_forks = HashMap::<NodeID, Vec<usize>>::new(); for (fork, dim) in fork_thread_id_pairs { rep_forks.entry(fork).or_default().push(dim); } // If each fork the query is over is represented and each of its dimensions // is represented, then the indices are parallel over the forks. forks.all(|fork| { let Some(mut rep_dims) = rep_forks.remove(&fork) else { return false; }; rep_dims.sort(); rep_dims.dedup(); nodes[fork.idx()].try_fork().unwrap().1.len() == rep_dims.len() }) } pub fn is_zero(editor: &FunctionEditor, id: NodeID) -> bool { let nodes = &editor.func().nodes; nodes[id.idx()] .try_constant() .map(|id| editor.get_constant(id).is_zero()) .unwrap_or(false) || nodes[id.idx()] .try_dynamic_constant() .map(|id| editor.get_dynamic_constant(id).is_zero()) .unwrap_or(false) || nodes[id.idx()].is_undef() } pub fn is_one(editor: &FunctionEditor, id: NodeID) -> bool { let nodes = &editor.func().nodes; nodes[id.idx()] .try_constant() .map(|id| editor.get_constant(id).is_one()) .unwrap_or(false) || nodes[id.idx()] .try_dynamic_constant() .map(|id| editor.get_dynamic_constant(id).is_one()) .unwrap_or(false) || nodes[id.idx()].is_undef() } pub fn is_largest(editor: &FunctionEditor, id: NodeID) -> bool { let nodes = &editor.func().nodes; nodes[id.idx()] .try_constant() .map(|id| editor.get_constant(id).is_largest()) .unwrap_or(false) || nodes[id.idx()] .try_dynamic_constant() .map(|id| editor.get_dynamic_constant(id).is_largest()) .unwrap_or(false) || nodes[id.idx()].is_undef() } pub fn is_smallest(editor: &FunctionEditor, id: NodeID) -> bool { let nodes = &editor.func().nodes; nodes[id.idx()] .try_constant() .map(|id| editor.get_constant(id).is_smallest()) .unwrap_or(false) || nodes[id.idx()] .try_dynamic_constant() .map(|id| editor.get_dynamic_constant(id).is_smallest()) .unwrap_or(false) || nodes[id.idx()].is_undef() }