Skip to content
Snippets Groups Projects
utils.rs 20.8 KiB
Newer Older
  • Learn to ignore specific revisions
  • use std::collections::{HashMap, HashSet};
    
    use hercules_ir::*;
    
    
    Ryan Ziegler's avatar
    Ryan Ziegler committed
    use crate::*;
    
    /*
    
     * Substitute all uses of dynamic constants in a type that are keys in the substs map with the
     * dynamic constant value for that key. Return the substituted version of the type, once memoized.
    
    rarbore2's avatar
    rarbore2 committed
    pub fn substitute_dynamic_constants_in_type(
    
        substs: &HashMap<DynamicConstantID, DynamicConstantID>,
    
    Ryan Ziegler's avatar
    Ryan Ziegler committed
        ty: TypeID,
        edit: &mut FunctionEdit,
    ) -> TypeID {
        // Look inside the type for references to dynamic constants.
        let ty_clone = edit.get_type(ty).clone();
        match ty_clone {
            Type::Product(ref fields) => {
                let new_fields = fields
                    .into_iter()
    
                    .map(|field_id| substitute_dynamic_constants_in_type(substs, *field_id, edit))
    
    Ryan Ziegler's avatar
    Ryan Ziegler committed
                    .collect();
                if new_fields != *fields {
                    edit.add_type(Type::Product(new_fields))
                } else {
                    ty
                }
            }
            Type::Summation(ref variants) => {
                let new_variants = variants
                    .into_iter()
    
                    .map(|variant_id| substitute_dynamic_constants_in_type(substs, *variant_id, edit))
    
    Ryan Ziegler's avatar
    Ryan Ziegler committed
                    .collect();
                if new_variants != *variants {
                    edit.add_type(Type::Summation(new_variants))
                } else {
                    ty
                }
            }
            Type::Array(elem_ty, ref dims) => {
    
                let new_elem_ty = substitute_dynamic_constants_in_type(substs, elem_ty, edit);
    
    Ryan Ziegler's avatar
    Ryan Ziegler committed
                let new_dims = dims
                    .into_iter()
    
                    .map(|dim_id| substitute_dynamic_constants(substs, *dim_id, edit))
    
    Ryan Ziegler's avatar
    Ryan Ziegler committed
                    .collect();
                if new_elem_ty != elem_ty || new_dims != *dims {
                    edit.add_type(Type::Array(new_elem_ty, new_dims))
                } else {
                    ty
                }
            }
            _ => ty,
        }
    }
    
    /*
    
     * Substitute all uses of dynamic constants in a dynamic constant dc that are keys in the
     * substs map and replace them with their appropriate replacement values. Return the substituted
     * version of dc, once memoized. Takes a mutable edit instead of an editor since this may create
     * new dynamic constants, which can only be done inside an edit.
    
    rarbore2's avatar
    rarbore2 committed
    pub fn substitute_dynamic_constants(
    
        substs: &HashMap<DynamicConstantID, DynamicConstantID>,
        dc: DynamicConstantID,
    
    Ryan Ziegler's avatar
    Ryan Ziegler committed
        edit: &mut FunctionEdit,
    ) -> DynamicConstantID {
    
        // If this dynamic constant should be substituted, just return the substitution
        if let Some(subst) = substs.get(&dc) {
            return *subst;
    
        // Look inside the dynamic constant to perform substitution in its children
        let dc_clone = edit.get_dynamic_constant(dc).clone();
    
    Ryan Ziegler's avatar
    Ryan Ziegler committed
        match dc_clone {
    
            DynamicConstant::Constant(_) | DynamicConstant::Parameter(_) => dc,
            DynamicConstant::Add(xs) => {
                let new_xs = xs
                    .iter()
                    .map(|x| substitute_dynamic_constants(substs, *x, edit))
                    .collect::<Vec<_>>();
                if new_xs != xs {
                    edit.add_dynamic_constant(DynamicConstant::Add(new_xs))
    
    Ryan Ziegler's avatar
    Ryan Ziegler committed
                } else {
    
    Ryan Ziegler's avatar
    Ryan Ziegler committed
                }
            }
            DynamicConstant::Sub(left, right) => {
    
                let new_left = substitute_dynamic_constants(substs, left, edit);
                let new_right = substitute_dynamic_constants(substs, right, edit);
    
    Ryan Ziegler's avatar
    Ryan Ziegler committed
                if new_left != left || new_right != right {
                    edit.add_dynamic_constant(DynamicConstant::Sub(new_left, new_right))
                } else {
    
            DynamicConstant::Mul(xs) => {
                let new_xs = xs
                    .iter()
                    .map(|x| substitute_dynamic_constants(substs, *x, edit))
                    .collect::<Vec<_>>();
                if new_xs != xs {
                    edit.add_dynamic_constant(DynamicConstant::Mul(new_xs))
    
    Ryan Ziegler's avatar
    Ryan Ziegler committed
                } else {
    
    Ryan Ziegler's avatar
    Ryan Ziegler committed
                }
            }
            DynamicConstant::Div(left, right) => {
    
                let new_left = substitute_dynamic_constants(substs, left, edit);
                let new_right = substitute_dynamic_constants(substs, right, edit);
    
    Ryan Ziegler's avatar
    Ryan Ziegler committed
                if new_left != left || new_right != right {
                    edit.add_dynamic_constant(DynamicConstant::Div(new_left, new_right))
                } else {
    
    Ryan Ziegler's avatar
    Ryan Ziegler committed
                }
            }
            DynamicConstant::Rem(left, right) => {
    
                let new_left = substitute_dynamic_constants(substs, left, edit);
                let new_right = substitute_dynamic_constants(substs, right, edit);
    
    Ryan Ziegler's avatar
    Ryan Ziegler committed
                if new_left != left || new_right != right {
                    edit.add_dynamic_constant(DynamicConstant::Rem(new_left, new_right))
    
                } else {
    
            DynamicConstant::Min(xs) => {
                let new_xs = xs
                    .iter()
                    .map(|x| substitute_dynamic_constants(substs, *x, edit))
                    .collect::<Vec<_>>();
                if new_xs != xs {
                    edit.add_dynamic_constant(DynamicConstant::Min(new_xs))
    
                } else {
    
            DynamicConstant::Max(xs) => {
                let new_xs = xs
                    .iter()
                    .map(|x| substitute_dynamic_constants(substs, *x, edit))
                    .collect::<Vec<_>>();
                if new_xs != xs {
                    edit.add_dynamic_constant(DynamicConstant::Max(new_xs))
    
    Ryan Ziegler's avatar
    Ryan Ziegler committed
                } else {
    
     * Substitute all uses of the dynamic constants specified by the subst map in a constant. Return
     * the substituted version of the constant, once memozied.
    
    rarbore2's avatar
    rarbore2 committed
    pub fn substitute_dynamic_constants_in_constant(
    
        substs: &HashMap<DynamicConstantID, DynamicConstantID>,
    
    Ryan Ziegler's avatar
    Ryan Ziegler committed
        cons: ConstantID,
        edit: &mut FunctionEdit,
    ) -> ConstantID {
        // Look inside the type for references to dynamic constants.
        let cons_clone = edit.get_constant(cons).clone();
        match cons_clone {
            Constant::Product(ty, fields) => {
    
                let new_ty = substitute_dynamic_constants_in_type(substs, ty, edit);
    
    Ryan Ziegler's avatar
    Ryan Ziegler committed
                let new_fields = fields
                    .iter()
    
                    .map(|field_id| substitute_dynamic_constants_in_constant(substs, *field_id, edit))
    
    Ryan Ziegler's avatar
    Ryan Ziegler committed
                    .collect();
                if new_ty != ty || new_fields != fields {
                    edit.add_constant(Constant::Product(new_ty, new_fields))
                } else {
                    cons
                }
            }
            Constant::Summation(ty, idx, variant) => {
    
                let new_ty = substitute_dynamic_constants_in_type(substs, ty, edit);
                let new_variant = substitute_dynamic_constants_in_constant(substs, variant, edit);
    
    Ryan Ziegler's avatar
    Ryan Ziegler committed
                if new_ty != ty || new_variant != variant {
                    edit.add_constant(Constant::Summation(new_ty, idx, new_variant))
                } else {
                    cons
                }
            }
            Constant::Array(ty) => {
    
                let new_ty = substitute_dynamic_constants_in_type(substs, ty, edit);
    
    Ryan Ziegler's avatar
    Ryan Ziegler committed
                if new_ty != ty {
                    edit.add_constant(Constant::Array(new_ty))
                } else {
                    cons
                }
            }
            _ => cons,
        }
    }
    
    /*
    
     * Substitute all uses of the dynamic constants specified by the subst map in a node.
    
    rarbore2's avatar
    rarbore2 committed
    pub fn substitute_dynamic_constants_in_node(
    
        substs: &HashMap<DynamicConstantID, DynamicConstantID>,
    
    Ryan Ziegler's avatar
    Ryan Ziegler committed
        node: &mut Node,
        edit: &mut FunctionEdit,
    ) {
        match node {
            Node::Fork {
                control: _,
                factors,
            } => {
                for factor in factors.into_iter() {
    
                    *factor = substitute_dynamic_constants(substs, *factor, edit);
    
    Ryan Ziegler's avatar
    Ryan Ziegler committed
                }
            }
            Node::Constant { id } => {
    
                *id = substitute_dynamic_constants_in_constant(substs, *id, edit);
    
    Ryan Ziegler's avatar
    Ryan Ziegler committed
            }
            Node::DynamicConstant { id } => {
    
                *id = substitute_dynamic_constants(substs, *id, edit);
    
    Ryan Ziegler's avatar
    Ryan Ziegler committed
            }
            Node::Call {
                control: _,
                function: _,
                dynamic_constants,
                args: _,
            } => {
                for dc_arg in dynamic_constants.into_iter() {
    
                    *dc_arg = substitute_dynamic_constants(substs, *dc_arg, edit);
    
    rarbore2's avatar
    rarbore2 committed
    
    /*
     * Top level function to make a function have only a single return.
     */
    
    Aaron Councilman's avatar
    Aaron Councilman committed
    pub fn collapse_returns(editor: &mut FunctionEditor) -> Option<NodeID> {
    
    rarbore2's avatar
    rarbore2 committed
        let returns: Vec<NodeID> = (0..editor.func().nodes.len())
            .filter(|idx| editor.func().nodes[*idx].is_return())
            .map(NodeID::new)
            .collect();
        assert!(!returns.is_empty());
        if returns.len() == 1 {
            return Some(returns[0]);
        }
    
    Aaron Councilman's avatar
    Aaron Councilman committed
        let preds_before_returns: Box<[NodeID]> = returns
    
    rarbore2's avatar
    rarbore2 committed
            .iter()
            .map(|ret_id| get_uses(&editor.func().nodes[ret_id.idx()]).as_ref()[0])
            .collect();
    
    Aaron Councilman's avatar
    Aaron Councilman committed
    
        let num_return_data = editor.func().return_types.len();
        let data_to_return: Vec<Box<[NodeID]>> = (0..num_return_data)
            .map(|idx| {
                returns
                    .iter()
                    .map(|ret_id| get_uses(&editor.func().nodes[ret_id.idx()]).as_ref()[idx + 1])
                    .collect()
            })
    
    rarbore2's avatar
    rarbore2 committed
            .collect();
    
        // All of the old returns get replaced in a single edit.
        let mut new_return = None;
        editor.edit(|mut edit| {
            let region = edit.add_node(Node::Region {
    
    Aaron Councilman's avatar
    Aaron Councilman committed
                preds: preds_before_returns,
    
    rarbore2's avatar
    rarbore2 committed
            });
    
    Aaron Councilman's avatar
    Aaron Councilman committed
            let return_vals = data_to_return
                .into_iter()
                .map(|data| {
                    edit.add_node(Node::Phi {
                        control: region,
                        data,
                    })
                })
                .collect();
    
    rarbore2's avatar
    rarbore2 committed
            for ret in returns {
                edit = edit.delete_node(ret)?;
            }
            new_return = Some(edit.add_node(Node::Return {
                control: region,
    
    Aaron Councilman's avatar
    Aaron Councilman committed
                data: return_vals,
    
    rarbore2's avatar
    rarbore2 committed
            }));
            Ok(edit)
        });
        new_return
    }
    
    rarbore2's avatar
    rarbore2 committed
    pub fn contains_between_control_flow(func: &Function) -> bool {
    
    Russel Arbore's avatar
    Russel Arbore committed
        let num_control = func.nodes.iter().filter(|node| node.is_control()).count();
        assert!(num_control >= 2, "PANIC: A Hercules function must have at least two control nodes: a start node and at least one return node.");
        num_control > 2
    }
    
    /*
     * Top level function to ensure a Hercules function contains at least one
     * control node that isn't the start or return nodes.
     */
    
    Aaron Councilman's avatar
    Aaron Councilman committed
    pub fn ensure_between_control_flow(editor: &mut FunctionEditor) -> Option<NodeID> {
    
    Russel Arbore's avatar
    Russel Arbore committed
        if !contains_between_control_flow(editor.func()) {
            let ret = editor
                .node_ids()
                .skip(1)
                .filter(|id| editor.func().nodes[id.idx()].is_control())
                .next()
                .unwrap();
    
    Aaron Councilman's avatar
    Aaron Councilman committed
            let Node::Return { control, ref data } = editor.func().nodes[ret.idx()] else {
    
    Russel Arbore's avatar
    Russel Arbore committed
                panic!("PANIC: A Hercules function with only two control nodes must have a return node be the other control node, other than the start node.")
            };
            assert_eq!(control, NodeID::new(0), "PANIC: The only other control node in a Hercules function, the return node, is not using the start node.");
    
    Aaron Councilman's avatar
    Aaron Councilman committed
            let data = data.clone();
    
    Russel Arbore's avatar
    Russel Arbore committed
            let mut region_id = None;
            editor.edit(|mut edit| {
                edit = edit.delete_node(ret)?;
                region_id = Some(edit.add_node(Node::Region {
                    preds: Box::new([NodeID::new(0)]),
                }));
                edit.add_node(Node::Return {
                    control: region_id.unwrap(),
                    data,
                });
                Ok(edit)
            });
            region_id
        } else {
            Some(
                editor
                    .get_users(NodeID::new(0))
                    .filter(|id| editor.func().nodes[id.idx()].is_control())
                    .next()
                    .unwrap(),
            )
        }
    }
    
    pub type DenseNodeMap<T> = Vec<T>;
    pub type SparseNodeMap<T> = HashMap<NodeID, T>;
    
    nest! {
    //
    #[derive(Clone, Debug)]
    pub struct NodeIterator<'a> {
        pub direction:
            #[derive(Clone, Debug, PartialEq)]
            pub enum Direction {
                Uses,
                Users,
            },
        visited: DenseNodeMap<bool>,
        stack: Vec<NodeID>,
        func: &'a FunctionEditor<'a>, // Maybe this is an enum, def use can be gotten from the function or from the editor.
        // `stop condition`, then return all nodes that caused stoppage i.e the frontier of the search.
        stop_on: HashSet<NodeID>, // Don't add neighbors of these.
    }
    }
    
    pub fn walk_all_uses<'a>(node: NodeID, editor: &'a FunctionEditor<'a>) -> NodeIterator<'a> {
        let len = editor.func().nodes.len();
        NodeIterator {
            direction: Direction::Uses,
            visited: vec![false; len],
            stack: vec![node],
            func: editor,
            stop_on: HashSet::new(),
        }
    }
    
    pub fn walk_all_users<'a>(node: NodeID, editor: &'a FunctionEditor<'a>) -> NodeIterator<'a> {
        let len = editor.func().nodes.len();
        NodeIterator {
            direction: Direction::Users,
            visited: vec![false; len],
            stack: vec![node],
            func: editor,
            stop_on: HashSet::new(),
        }
    }
    
    pub fn walk_all_uses_stop_on<'a>(
        node: NodeID,
        editor: &'a FunctionEditor<'a>,
        stop_on: HashSet<NodeID>,
    ) -> NodeIterator<'a> {
        let len = editor.func().nodes.len();
        let uses = editor.get_uses(node).collect();
        NodeIterator {
            direction: Direction::Uses,
            visited: vec![false; len],
            stack: uses,
            func: editor,
            stop_on,
        }
    }
    
    pub fn walk_all_users_stop_on<'a>(
        node: NodeID,
        editor: &'a FunctionEditor<'a>,
        stop_on: HashSet<NodeID>,
    ) -> NodeIterator<'a> {
        let len = editor.func().nodes.len();
        let users = editor.get_users(node).collect();
        NodeIterator {
            direction: Direction::Users,
            visited: vec![false; len],
            stack: users,
            func: editor,
            stop_on,
        }
    }
    
    impl<'a> Iterator for NodeIterator<'a> {
        type Item = NodeID;

        /*
         * Pops nodes until an unvisited one turns up, marks it visited, pushes its
         * neighbors in the configured direction (unless it is a stop node), and
         * yields it. Returns None once the stack is exhausted.
         */
        fn next(&mut self) -> Option<Self::Item> {
            while let Some(node) = self.stack.pop() {
                if self.visited[node.idx()] {
                    continue;
                }
                self.visited[node.idx()] = true;

                // Stop nodes are yielded, but the search doesn't expand past them.
                if !self.stop_on.contains(&node) {
                    match self.direction {
                        Direction::Uses => self.stack.extend(self.func.get_uses(node)),
                        Direction::Users => self.stack.extend(self.func.get_users(node)),
                    }
                }

                return Some(node);
            }
            None
        }
    }
    
    
    /*
     * Materializes an einsum expression into an IR node tree. Replaces thread IDs
     * with the provided node IDs: `dim_substs[d]` stands in for the thread ID of
     * dimension d. Opaque nodes are reused as-is, zero and one become new
     * constant nodes, and reads, unary, binary, ternary, and intrinsic
     * expressions are materialized recursively into new nodes. Doesn't
     * materialize reductions or comprehensions.
     *
     * Panics on any other kind of expression, or if a thread ID's dimension is
     * out of range of `dim_substs`.
     */
    pub fn materialize_simple_einsum_expr(
        edit: &mut FunctionEdit,
        id: MathID,
        env: &MathEnv,
        dim_substs: &[NodeID],
    ) -> NodeID {
        match env[id.idx()] {
            MathExpr::Zero(ty) => {
                let cons_id = edit.add_zero_constant(ty);
                edit.add_node(Node::Constant { id: cons_id })
            }
            MathExpr::One(ty) => {
                let cons_id = edit.add_one_constant(ty);
                edit.add_node(Node::Constant { id: cons_id })
            }
            MathExpr::OpaqueNode(id) => id,
            MathExpr::ThreadID(dim) => dim_substs[dim.0],
            MathExpr::Read(collect, ref indices) => {
                let collect = materialize_simple_einsum_expr(edit, collect, env, dim_substs);
                // All of the read's indices become a single position index.
                let indices = Box::new([Index::Position(
                    indices
                        .iter()
                        .map(|idx| materialize_simple_einsum_expr(edit, *idx, env, dim_substs))
                        .collect(),
                )]);
                edit.add_node(Node::Read { collect, indices })
            }
            MathExpr::Unary(op, input) => {
                let input = materialize_simple_einsum_expr(edit, input, env, dim_substs);
                edit.add_node(Node::Unary { op, input })
            }
            MathExpr::Binary(op, left, right) => {
                let left = materialize_simple_einsum_expr(edit, left, env, dim_substs);
                let right = materialize_simple_einsum_expr(edit, right, env, dim_substs);
                edit.add_node(Node::Binary { op, left, right })
            }
            MathExpr::Ternary(op, first, second, third) => {
                let first = materialize_simple_einsum_expr(edit, first, env, dim_substs);
                let second = materialize_simple_einsum_expr(edit, second, env, dim_substs);
                let third = materialize_simple_einsum_expr(edit, third, env, dim_substs);
                edit.add_node(Node::Ternary {
                    op,
                    first,
                    second,
                    third,
                })
            }
            MathExpr::IntrinsicFunc(intrinsic, ref args) => {
                let args = args
                    .iter()
                    .map(|id| materialize_simple_einsum_expr(edit, *id, env, dim_substs))
                    .collect();
                edit.add_node(Node::IntrinsicCall { intrinsic, args })
            }
            _ => panic!("PANIC: materialize_simple_einsum_expr can't materialize this math expression (reductions and comprehensions aren't supported)."),
        }
    }
    
    rarbore2's avatar
    rarbore2 committed
    
    /*
     * Get the node IDs referred to in position indices in a indices set.
     */
    pub fn node_indices(indices: &[Index]) -> impl Iterator<Item = NodeID> + '_ {
        indices
            .iter()
            .filter_map(|index| {
                if let Index::Position(indices) = index {
                    Some(indices)
                } else {
                    None
                }
            })
            .flat_map(|pos| pos.iter())
            .map(|id| *id)
    }
    
    /*
     * Checks if a set of indices is fully parallel over a set of forks - that is,
     * every thread ID from every fork appears at least once in positions in the
     * indices set.
     */
    pub fn indices_parallel_over_forks<I>(
        editor: &FunctionEditor,
        indices: &[Index],
        mut forks: I,
    ) -> bool
    where
        I: Iterator<Item = NodeID>,
    {
        // Get the forks corresponding to position uses of bare thread ids.
        let nodes = &editor.func().nodes;
        let fork_thread_id_pairs = node_indices(indices).filter_map(|id| {
            if let Node::ThreadID { control, dimension } = nodes[id.idx()] {
                Some((control, dimension))
    
    rarbore2's avatar
    rarbore2 committed
            } else if let Node::Binary {
                op: BinaryOperator::Add,
                left: tid,
                right: cons,
            } = nodes[id.idx()]
                && let Node::ThreadID { control, dimension } = nodes[tid.idx()]
                && (nodes[cons.idx()].is_constant() || nodes[cons.idx()].is_dynamic_constant())
            {
                Some((control, dimension))
            } else if let Node::Binary {
                op: BinaryOperator::Add,
                left: cons,
                right: tid,
            } = nodes[id.idx()]
                && let Node::ThreadID { control, dimension } = nodes[tid.idx()]
                && (nodes[cons.idx()].is_constant() || nodes[cons.idx()].is_dynamic_constant())
            {
                Some((control, dimension))
    
    rarbore2's avatar
    rarbore2 committed
            } else {
                None
            }
        });
        let mut rep_forks = HashMap::<NodeID, Vec<usize>>::new();
        for (fork, dim) in fork_thread_id_pairs {
            rep_forks.entry(fork).or_default().push(dim);
        }
    
        // If each fork the query is over is represented and each of its dimensions
        // is represented, then the indices are parallel over the forks.
        forks.all(|fork| {
            let Some(mut rep_dims) = rep_forks.remove(&fork) else {
                return false;
            };
    
            rep_dims.sort();
            rep_dims.dedup();
            nodes[fork.idx()].try_fork().unwrap().1.len() == rep_dims.len()
        })
    }
    
    rarbore2's avatar
    rarbore2 committed
    
    pub fn is_zero(editor: &FunctionEditor, id: NodeID) -> bool {
        let nodes = &editor.func().nodes;
        nodes[id.idx()]
            .try_constant()
            .map(|id| editor.get_constant(id).is_zero())
            .unwrap_or(false)
            || nodes[id.idx()]
                .try_dynamic_constant()
                .map(|id| editor.get_dynamic_constant(id).is_zero())
                .unwrap_or(false)
            || nodes[id.idx()].is_undef()
    }
    
    pub fn is_one(editor: &FunctionEditor, id: NodeID) -> bool {
        let nodes = &editor.func().nodes;
        nodes[id.idx()]
            .try_constant()
            .map(|id| editor.get_constant(id).is_one())
            .unwrap_or(false)
            || nodes[id.idx()]
                .try_dynamic_constant()
                .map(|id| editor.get_dynamic_constant(id).is_one())
                .unwrap_or(false)
            || nodes[id.idx()].is_undef()
    }
    
    rarbore2's avatar
    rarbore2 committed
    
    pub fn is_largest(editor: &FunctionEditor, id: NodeID) -> bool {
        let nodes = &editor.func().nodes;
        nodes[id.idx()]
            .try_constant()
            .map(|id| editor.get_constant(id).is_largest())
            .unwrap_or(false)
            || nodes[id.idx()]
                .try_dynamic_constant()
                .map(|id| editor.get_dynamic_constant(id).is_largest())
                .unwrap_or(false)
            || nodes[id.idx()].is_undef()
    }
    
    pub fn is_smallest(editor: &FunctionEditor, id: NodeID) -> bool {
        let nodes = &editor.func().nodes;
        nodes[id.idx()]
            .try_constant()
            .map(|id| editor.get_constant(id).is_smallest())
            .unwrap_or(false)
            || nodes[id.idx()]
                .try_dynamic_constant()
                .map(|id| editor.get_dynamic_constant(id).is_smallest())
                .unwrap_or(false)
            || nodes[id.idx()].is_undef()
    }