use std::collections::{HashMap, HashSet};

use nestify::nest;

use hercules_ir::*;

use crate::*;

/*
 * Substitute all uses of dynamic constants in a type that are keys in the substs map with the
 * dynamic constant value for that key. Return the substituted version of the type, once memoized.
 */
pub fn substitute_dynamic_constants_in_type(
    substs: &HashMap<DynamicConstantID, DynamicConstantID>,
    ty: TypeID,
    edit: &mut FunctionEdit,
) -> TypeID {
    // Look inside the type for references to dynamic constants. A clone of the type is matched
    // on, presumably so that `edit` stays free to be borrowed mutably by the recursive calls.
    let ty_clone = edit.get_type(ty).clone();
    match ty_clone {
        // Product: recursively substitute in each field type.
        Type::Product(ref fields) => {
            let new_fields = fields
                .map(|field_id| substitute_dynamic_constants_in_type(substs, *field_id, edit))
            // Compare against the original fields to detect whether any substitution happened.
            if new_fields != *fields {
            } else {
        // Summation: recursively substitute in each variant type.
        Type::Summation(ref variants) => {
            let new_variants = variants
                .map(|variant_id| substitute_dynamic_constants_in_type(substs, *variant_id, edit))
            // Compare against the original variants to detect whether anything changed.
            if new_variants != *variants {
            } else {
        // Array: substitute in the element type and in every dimension. The dimensions are
        // dynamic constant IDs, so they go through `substitute_dynamic_constants`.
        Type::Array(elem_ty, ref dims) => {
            let new_elem_ty = substitute_dynamic_constants_in_type(substs, elem_ty, edit);
            let new_dims = dims
                .map(|dim_id| substitute_dynamic_constants(substs, *dim_id, edit))
            // Only add a new array type when the element type or some dimension changed.
            if new_elem_ty != elem_ty || new_dims != *dims {
                edit.add_type(Type::Array(new_elem_ty, new_dims))
            } else {
        // Every other kind of type is returned unchanged.
        _ => ty,

/*
 * Substitute all uses of dynamic constants in a dynamic constant dc that are keys in the
 * substs map and replace them with their appropriate replacement values. Return the substituted
 * version of dc, once memoized. Takes a mutable edit instead of an editor since this may create
 * new dynamic constants, which can only be done inside an edit.
 */
pub fn substitute_dynamic_constants(
    substs: &HashMap<DynamicConstantID, DynamicConstantID>,
    dc: DynamicConstantID,
    edit: &mut FunctionEdit,
) -> DynamicConstantID {
    // If this dynamic constant should be substituted, just return the substitution
    if let Some(subst) = substs.get(&dc) {
        return *subst;

    // Look inside the dynamic constant to perform substitution in its children
    let dc_clone = edit.get_dynamic_constant(dc).clone();
    match dc_clone {
        // Constant and Parameter variants are returned unchanged; the direct lookup in `substs`
        // above has already been tried for them.
        DynamicConstant::Constant(_) | DynamicConstant::Parameter(_) => dc,
        // Add: substitute in every operand of the list.
        DynamicConstant::Add(xs) => {
            let new_xs = xs
                .map(|x| substitute_dynamic_constants(substs, *x, edit))
            // Compare with the original operands to see whether anything changed.
            if new_xs != xs {
            } else {
        // Sub: substitute in both operands; a new dynamic constant is added only on change.
        DynamicConstant::Sub(left, right) => {
            let new_left = substitute_dynamic_constants(substs, left, edit);
            let new_right = substitute_dynamic_constants(substs, right, edit);
            if new_left != left || new_right != right {
                edit.add_dynamic_constant(DynamicConstant::Sub(new_left, new_right))
            } else {
        // Mul: substitute in every operand of the list.
        DynamicConstant::Mul(xs) => {
            let new_xs = xs
                .map(|x| substitute_dynamic_constants(substs, *x, edit))
            if new_xs != xs {
            } else {
        // Div: substitute in both operands; a new dynamic constant is added only on change.
        DynamicConstant::Div(left, right) => {
            let new_left = substitute_dynamic_constants(substs, left, edit);
            let new_right = substitute_dynamic_constants(substs, right, edit);
            if new_left != left || new_right != right {
                edit.add_dynamic_constant(DynamicConstant::Div(new_left, new_right))
            } else {
        // Rem: substitute in both operands; a new dynamic constant is added only on change.
        DynamicConstant::Rem(left, right) => {
            let new_left = substitute_dynamic_constants(substs, left, edit);
            let new_right = substitute_dynamic_constants(substs, right, edit);
            if new_left != left || new_right != right {
                edit.add_dynamic_constant(DynamicConstant::Rem(new_left, new_right))
            } else {
        // Min: substitute in every operand of the list.
        DynamicConstant::Min(xs) => {
            let new_xs = xs
                .map(|x| substitute_dynamic_constants(substs, *x, edit))
            if new_xs != xs {
            } else {
        // Max: substitute in every operand of the list.
        DynamicConstant::Max(xs) => {
            let new_xs = xs
                .map(|x| substitute_dynamic_constants(substs, *x, edit))
            if new_xs != xs {
            } else {

/*
 * Substitute all uses of the dynamic constants specified by the subst map in a constant. Return
 * the substituted version of the constant, once memoized.
 */
pub fn substitute_dynamic_constants_in_constant(
    substs: &HashMap<DynamicConstantID, DynamicConstantID>,
    cons: ConstantID,
    edit: &mut FunctionEdit,
) -> ConstantID {
    // Look inside the constant for references to dynamic constants. These can only appear
    // through the types carried by the constant or through its nested constants.
    let cons_clone = edit.get_constant(cons).clone();
    match cons_clone {
        // Product: substitute in the product's type and in each field constant.
        Constant::Product(ty, fields) => {
            let new_ty = substitute_dynamic_constants_in_type(substs, ty, edit);
            let new_fields = fields
                .map(|field_id| substitute_dynamic_constants_in_constant(substs, *field_id, edit))
            // Only add a new constant when the type or some field changed.
            if new_ty != ty || new_fields != fields {
                edit.add_constant(Constant::Product(new_ty, new_fields))
            } else {
        // Summation: substitute in the type and in the held variant constant. The variant index
        // `idx` is carried over untouched.
        Constant::Summation(ty, idx, variant) => {
            let new_ty = substitute_dynamic_constants_in_type(substs, ty, edit);
            let new_variant = substitute_dynamic_constants_in_constant(substs, variant, edit);
            if new_ty != ty || new_variant != variant {
                edit.add_constant(Constant::Summation(new_ty, idx, new_variant))
            } else {
        // Array: the constant only carries a type, so only the type is substituted.
        Constant::Array(ty) => {
            let new_ty = substitute_dynamic_constants_in_type(substs, ty, edit);
            if new_ty != ty {
            } else {
        // Every other kind of constant is returned unchanged.
        _ => cons,

/*
 * Substitute all uses of the dynamic constants specified by the subst map in a node.
 */
pub fn substitute_dynamic_constants_in_node(
    substs: &HashMap<DynamicConstantID, DynamicConstantID>,
    node: &mut Node,
    edit: &mut FunctionEdit,
) {
    // The node is rewritten in place; nothing is returned.
    match node {
        // Fork: every factor is replaced by its substituted dynamic constant.
        Node::Fork {
            control: _,
        } => {
            for factor in factors.into_iter() {
                *factor = substitute_dynamic_constants(substs, *factor, edit);
        // Constant node: the referenced constant may mention dynamic constants.
        Node::Constant { id } => {
            *id = substitute_dynamic_constants_in_constant(substs, *id, edit);
        // Dynamic constant node: substitute the referenced dynamic constant directly.
        Node::DynamicConstant { id } => {
            *id = substitute_dynamic_constants(substs, *id, edit);
        // Call: only the dynamic constant arguments are rewritten; the control, function, and
        // args fields are ignored by the pattern.
        Node::Call {
            control: _,
            function: _,
            args: _,
        } => {
            for dc_arg in dynamic_constants.into_iter() {
                *dc_arg = substitute_dynamic_constants(substs, *dc_arg, edit);
        // All other nodes are left untouched.
        _ => {}

/*
 * Top level function to make a function have only a single return.
 */
pub fn collapse_returns(editor: &mut FunctionEditor) -> Option<NodeID> {
    // Gather the IDs of all return nodes in the function.
    let returns: Vec<NodeID> = (0..editor.func().nodes.len())
        .filter(|idx| editor.func().nodes[*idx].is_return())
    // With exactly one return there is nothing to collapse; hand back that return.
    if returns.len() == 1 {
        return Some(returns[0]);
    // The first use of every return node. These become the predecessors of the new region
    // below - presumably each is the return's control input; verify against `get_uses`.
    let preds_before_returns: Box<[NodeID]> = returns
        .map(|ret_id| get_uses(&editor.func().nodes[ret_id.idx()]).as_ref()[0])

    // For each returned value position, the matching use of every return node. The `idx + 1`
    // offset skips the first use, which was consumed above.
    let num_return_data = editor.func().return_types.len();
    let data_to_return: Vec<Box<[NodeID]>> = (0..num_return_data)
        .map(|idx| {
                .map(|ret_id| get_uses(&editor.func().nodes[ret_id.idx()]).as_ref()[idx + 1])

    // All of the old returns get replaced in a single edit.
    let mut new_return = None;
    editor.edit(|mut edit| {
        // A region joins the control paths that previously led into the separate returns.
        let region = edit.add_node(Node::Region {
            preds: preds_before_returns,
        // One phi, controlled by the region, is created per returned value.
        let return_vals = data_to_return
            .map(|data| {
                edit.add_node(Node::Phi {
                    control: region,
        // Remove the old returns; a failed deletion propagates out of the edit via `?`.
        for ret in returns {
            edit = edit.delete_node(ret)?;
        // The single new return uses the region for control and the phis for data.
        new_return = Some(edit.add_node(Node::Return {
            control: region,
            data: return_vals,

/*
 * Check whether a function has any control node beyond the two it must have. Returns true when
 * the function has more than two control nodes. Panics if it has fewer than two.
 */
pub fn contains_between_control_flow(func: &Function) -> bool {
    // Count every node that reports itself as a control node.
    let num_control = func.nodes.iter().filter(|node| node.is_control()).count();
    assert!(num_control >= 2, "PANIC: A Hercules function must have at least two control nodes: a start node and at least one return node.");
    num_control > 2

/*
 * Top level function to ensure a Hercules function contains at least one
 * control node that isn't the start or return nodes.
 */
pub fn ensure_between_control_flow(editor: &mut FunctionEditor) -> Option<NodeID> {
    // Only functions without any in-between control node need to be modified.
    if !contains_between_control_flow(editor.func()) {
        // Search the control nodes for the one that is expected to be the return.
        let ret = editor
            .filter(|id| editor.func().nodes[id.idx()].is_control())
        // The node found must be a return; anything else is treated as a broken invariant.
        let Node::Return { control, ref data } = editor.func().nodes[ret.idx()] else {
            panic!("PANIC: A Hercules function with only two control nodes must have a return node be the other control node, other than the start node.")
        // The return is required to hang directly off node 0, which the message below
        // identifies as the start node.
        assert_eq!(control, NodeID::new(0), "PANIC: The only other control node in a Hercules function, the return node, is not using the start node.");
        // Clone the returned data so it no longer borrows from the editor during the edit.
        let data = data.clone();
        let mut region_id = None;
        editor.edit(|mut edit| {
            // Replace the old return with a region fed by node 0, followed by a new return
            // that is controlled by that region.
            edit = edit.delete_node(ret)?;
            region_id = Some(edit.add_node(Node::Region {
                preds: Box::new([NodeID::new(0)]),
            edit.add_node(Node::Return {
                control: region_id.unwrap(),
    } else {
                // In-between control flow already exists; this branch searches the existing
                // control nodes instead of creating a new one.
                .filter(|id| editor.func().nodes[id.idx()].is_control())

// Map from nodes to values stored as a plain vector. `NodeIterator` indexes its `visited` map of
// this type with `NodeID::idx()`.
pub type DenseNodeMap<T> = Vec<T>;
// Map from nodes to values stored as a hash map keyed by `NodeID`.
pub type SparseNodeMap<T> = HashMap<NodeID, T>;

// State of a walk over the nodes of a function, following either uses or users edges. Instances
// are built by the `walk_all_*` functions and consumed through the `Iterator` implementation.
nest! {
#[derive(Clone, Debug)]
pub struct NodeIterator<'a> {
    // Which kind of edge the walk follows; the constructors set `Direction::Uses` or
    // `Direction::Users`.
    pub direction:
        #[derive(Clone, Debug, PartialEq)]
        pub enum Direction {
    // Per-node flag, indexed by `NodeID::idx()`, set once a node has been popped and processed.
    visited: DenseNodeMap<bool>,
    // Nodes waiting to be visited; `next` pops from the end.
    stack: Vec<NodeID>,
    func: &'a FunctionEditor<'a>, // Maybe this is an enum, def use can be gotten from the function or from the editor.
    // `stop condition`, then return all nodes that caused stoppage i.e the frontier of the search.
    stop_on: HashSet<NodeID>, // Don't add neighbors of these.

/*
 * Create an iterator that walks in the uses direction, starting from (and including) node. No
 * stop nodes are set.
 */
pub fn walk_all_uses<'a>(node: NodeID, editor: &'a FunctionEditor<'a>) -> NodeIterator<'a> {
    // One visited flag per node currently in the function.
    let len = editor.func().nodes.len();
    NodeIterator {
        direction: Direction::Uses,
        visited: vec![false; len],
        // The start node itself seeds the stack.
        stack: vec![node],
        func: editor,
        stop_on: HashSet::new(),

/*
 * Create an iterator that walks in the users direction, starting from (and including) node. No
 * stop nodes are set.
 */
pub fn walk_all_users<'a>(node: NodeID, editor: &'a FunctionEditor<'a>) -> NodeIterator<'a> {
    // One visited flag per node currently in the function.
    let len = editor.func().nodes.len();
    NodeIterator {
        direction: Direction::Users,
        visited: vec![false; len],
        // The start node itself seeds the stack.
        stack: vec![node],
        func: editor,
        stop_on: HashSet::new(),

/*
 * Create an iterator that walks in the uses direction with a caller provided set of stop nodes.
 * Unlike walk_all_uses, the walk is seeded with the uses of node rather than with node itself.
 */
pub fn walk_all_uses_stop_on<'a>(
    node: NodeID,
    editor: &'a FunctionEditor<'a>,
    stop_on: HashSet<NodeID>,
) -> NodeIterator<'a> {
    // One visited flag per node currently in the function.
    let len = editor.func().nodes.len();
    // The direct uses of the start node seed the stack.
    let uses = editor.get_uses(node).collect();
    NodeIterator {
        direction: Direction::Uses,
        visited: vec![false; len],
        stack: uses,
        func: editor,

/*
 * Create an iterator that walks in the users direction with a caller provided set of stop nodes.
 * Unlike walk_all_users, the walk is seeded with the users of node rather than with node itself.
 */
pub fn walk_all_users_stop_on<'a>(
    node: NodeID,
    editor: &'a FunctionEditor<'a>,
    stop_on: HashSet<NodeID>,
) -> NodeIterator<'a> {
    // One visited flag per node currently in the function.
    let len = editor.func().nodes.len();
    // The direct users of the start node seed the stack.
    let users = editor.get_users(node).collect();
    NodeIterator {
        direction: Direction::Users,
        visited: vec![false; len],
        stack: users,
        func: editor,

// Iteration yields each reachable node at most once, tracked through the `visited` flags.
impl<'a> Iterator for NodeIterator<'a> {
    type Item = NodeID;

    fn next(&mut self) -> Option<Self::Item> {
        // Pop pending nodes until one is found that has not been visited yet.
        while let Some(current) = self.stack.pop() {
            if !self.visited[current.idx()] {
                self.visited[current.idx()] = true;

                // Neighbors are only looked at for nodes outside of `stop_on`; which neighbors
                // depends on the direction of the walk.
                if !self.stop_on.contains(&current) {
                    if self.direction == Direction::Uses {
                        for neighbor in self.func.get_uses(current) {
                    } else {
                        for neighbor in self.func.get_users(current) {

                // The newly visited node is the next item of the walk.
                return Some(current);

/*
 * Materializes an einsum expression into an IR node tree. Replaces thread IDs
 * with provided node IDs. Doesn't materialize reductions or comprehensions.
 */
pub fn materialize_simple_einsum_expr(
    edit: &mut FunctionEdit,
    id: MathID,
    env: &MathEnv,
    dim_substs: &[NodeID],
) -> NodeID {
    // Dispatch on the expression stored for `id` in the environment. Sub-expressions are
    // materialized recursively before the node that uses them is added.
    match env[id.idx()] {
        // Zero: a zero constant of the given type, wrapped in a constant node.
        MathExpr::Zero(ty) => {
            let cons_id = edit.add_zero_constant(ty);
            edit.add_node(Node::Constant { id: cons_id })
        // One: a one constant of the given type, wrapped in a constant node.
        MathExpr::One(ty) => {
            let cons_id = edit.add_one_constant(ty);
            edit.add_node(Node::Constant { id: cons_id })
        // Opaque nodes already exist in the IR; their ID is returned as is.
        MathExpr::OpaqueNode(id) => id,
        // Thread IDs are replaced by the node supplied for that dimension in `dim_substs`.
        // Indexing panics if the dimension is out of range for the slice.
        MathExpr::ThreadID(dim) => dim_substs[dim.0],
        // Read: materialize the collection, then a single position index built from the
        // materialized index expressions.
        MathExpr::Read(collect, ref indices) => {
            let collect = materialize_simple_einsum_expr(edit, collect, env, dim_substs);
            let indices = Box::new([Index::Position(
                    .map(|idx| materialize_simple_einsum_expr(edit, *idx, env, dim_substs))
            edit.add_node(Node::Read { collect, indices })
        // Unary: materialize the operand, then add the operation node.
        MathExpr::Unary(op, input) => {
            let input = materialize_simple_einsum_expr(edit, input, env, dim_substs);
            edit.add_node(Node::Unary { op, input })
        // Binary: materialize the left operand first, then the right one.
        MathExpr::Binary(op, left, right) => {
            let left = materialize_simple_einsum_expr(edit, left, env, dim_substs);
            let right = materialize_simple_einsum_expr(edit, right, env, dim_substs);
            edit.add_node(Node::Binary { op, left, right })
        // Ternary: materialize the three operands in order.
        MathExpr::Ternary(op, first, second, third) => {
            let first = materialize_simple_einsum_expr(edit, first, env, dim_substs);
            let second = materialize_simple_einsum_expr(edit, second, env, dim_substs);
            let third = materialize_simple_einsum_expr(edit, third, env, dim_substs);
            edit.add_node(Node::Ternary {
        // Intrinsic call: materialize every argument, then add the call node.
        MathExpr::IntrinsicFunc(intrinsic, ref args) => {
            let args = args
                .map(|id| materialize_simple_einsum_expr(edit, *id, env, dim_substs))
            edit.add_node(Node::IntrinsicCall { intrinsic, args })
        // Any other expression kind is not supported here and panics.
        _ => panic!(),

/*
 * Get the node IDs referred to in position indices in an indices set.
 */
pub fn node_indices(indices: &[Index]) -> impl Iterator<Item = NodeID> + '_ {
        // Only `Index::Position` entries are of interest.
        .filter_map(|index| {
            if let Index::Position(indices) = index {
            } else {
        // Flatten the per-index position lists into a single stream of node IDs.
        .flat_map(|pos| pos.iter())
        // Yield the IDs by value rather than by reference.
        .map(|id| *id)

/*
 * Checks if a set of indices is fully parallel over a set of forks - that is,
 * every thread ID from every fork appears at least once in positions in the
 * indices set.
 */
pub fn indices_parallel_over_forks<I>(
    editor: &FunctionEditor,
    indices: &[Index],
    mut forks: I,
) -> bool
    I: Iterator<Item = NodeID>,
    // Get the forks corresponding to position uses of bare thread ids.
    // Besides a bare thread ID, two offset forms are accepted: `thread_id + c` and
    // `c + thread_id`, where `c` is a constant or dynamic constant node.
    let nodes = &editor.func().nodes;
    let fork_thread_id_pairs = node_indices(indices).filter_map(|id| {
        // Form 1: the position is a thread ID node itself.
        if let Node::ThreadID { control, dimension } = nodes[id.idx()] {
            Some((control, dimension))
        // Form 2: an addition with the thread ID on the left and the constant on the right.
        } else if let Node::Binary {
            op: BinaryOperator::Add,
            left: tid,
            right: cons,
        } = nodes[id.idx()]
            && let Node::ThreadID { control, dimension } = nodes[tid.idx()]
            && (nodes[cons.idx()].is_constant() || nodes[cons.idx()].is_dynamic_constant())
            Some((control, dimension))
        // Form 3: the mirrored addition, with the constant on the left.
        } else if let Node::Binary {
            op: BinaryOperator::Add,
            left: cons,
            right: tid,
        } = nodes[id.idx()]
            && let Node::ThreadID { control, dimension } = nodes[tid.idx()]
            && (nodes[cons.idx()].is_constant() || nodes[cons.idx()].is_dynamic_constant())
            Some((control, dimension))
        } else {
    // Group the dimensions that were seen by the fork they belong to.
    let mut rep_forks = HashMap::<NodeID, Vec<usize>>::new();
    for (fork, dim) in fork_thread_id_pairs {

    // If each fork the query is over is represented and each of its dimensions
    // is represented, then the indices are parallel over the forks.
    forks.all(|fork| {
        // A fork with no thread ID among the indices fails the check immediately.
        let Some(mut rep_dims) = rep_forks.remove(&fork) else {
            return false;

        // Compare the number of represented dimensions with the length of the second
        // component of `try_fork()` - presumably the fork's factors; TODO confirm.
        nodes[fork.idx()].try_fork().unwrap().1.len() == rep_dims.len()

/*
 * Check whether the node id can be treated as zero. The checks are or'ed together: the node's
 * constant is zero, its dynamic constant is zero, or the node is undef.
 */
pub fn is_zero(editor: &FunctionEditor, id: NodeID) -> bool {
    let nodes = &editor.func().nodes;
        // Constant case: ask the referenced constant.
        .map(|id| editor.get_constant(id).is_zero())
        || nodes[id.idx()]
            // Dynamic constant case: ask the referenced dynamic constant.
            .map(|id| editor.get_dynamic_constant(id).is_zero())
        // An undef node is accepted as well.
        || nodes[id.idx()].is_undef()

/*
 * Check whether the node id can be treated as one. The checks are or'ed together: the node's
 * constant is one, its dynamic constant is one, or the node is undef.
 */
pub fn is_one(editor: &FunctionEditor, id: NodeID) -> bool {
    let nodes = &editor.func().nodes;
        // Constant case: ask the referenced constant.
        .map(|id| editor.get_constant(id).is_one())
        || nodes[id.idx()]
            // Dynamic constant case: ask the referenced dynamic constant.
            .map(|id| editor.get_dynamic_constant(id).is_one())
        // An undef node is accepted as well.
        || nodes[id.idx()].is_undef()

/*
 * Check whether the node id can be treated as the largest value, as reported by the is_largest
 * methods of its constant or dynamic constant. An undef node is accepted as well.
 */
pub fn is_largest(editor: &FunctionEditor, id: NodeID) -> bool {
    let nodes = &editor.func().nodes;
        // Constant case: ask the referenced constant.
        .map(|id| editor.get_constant(id).is_largest())
        || nodes[id.idx()]
            // Dynamic constant case: ask the referenced dynamic constant.
            .map(|id| editor.get_dynamic_constant(id).is_largest())
        // An undef node is accepted as well.
        || nodes[id.idx()].is_undef()

/*
 * Check whether the node id can be treated as the smallest value, as reported by the is_smallest
 * methods of its constant or dynamic constant. An undef node is accepted as well.
 */
pub fn is_smallest(editor: &FunctionEditor, id: NodeID) -> bool {
    let nodes = &editor.func().nodes;
        // Constant case: ask the referenced constant.
        .map(|id| editor.get_constant(id).is_smallest())
        || nodes[id.idx()]
            // Dynamic constant case: ask the referenced dynamic constant.
            .map(|id| editor.get_dynamic_constant(id).is_smallest())
        // An undef node is accepted as well.
        || nodes[id.idx()].is_undef()