Skip to content
Snippets Groups Projects
utils.rs 13.47 KiB
use std::iter::zip;

use hercules_ir::def_use::*;
use hercules_ir::ir::*;

use crate::*;

/*
 * Substitute all uses of dynamic constant A with dynamic constant B inside a
 * type. Product and summation types are rewritten member by member; array
 * types have their element type rewritten and then their dimensions; any other
 * type is returned as-is. When nothing inside the type changed, `ty` itself is
 * returned; otherwise the rebuilt type is added with `edit.add_type` and the
 * resulting ID is returned.
 */
pub(crate) fn substitute_dynamic_constants_in_type(
    dc_a: DynamicConstantID,
    dc_b: DynamicConstantID,
    ty: TypeID,
    edit: &mut FunctionEdit,
) -> TypeID {
    // Take an owned copy of the type so `edit` is free to be mutably borrowed
    // by the recursive calls below.
    let current = edit.get_type(ty).clone();
    let replacement = match current {
        Type::Product(fields) => {
            let new_fields = fields
                .iter()
                .map(|&field| substitute_dynamic_constants_in_type(dc_a, dc_b, field, edit))
                .collect();
            if new_fields != fields {
                Some(Type::Product(new_fields))
            } else {
                None
            }
        }
        Type::Summation(variants) => {
            let new_variants = variants
                .iter()
                .map(|&variant| substitute_dynamic_constants_in_type(dc_a, dc_b, variant, edit))
                .collect();
            if new_variants != variants {
                Some(Type::Summation(new_variants))
            } else {
                None
            }
        }
        Type::Array(elem, dims) => {
            let new_elem = substitute_dynamic_constants_in_type(dc_a, dc_b, elem, edit);
            let new_dims = dims
                .iter()
                .map(|&dim| substitute_dynamic_constants(dc_a, dc_b, dim, edit))
                .collect();
            if new_elem != elem || new_dims != dims {
                Some(Type::Array(new_elem, new_dims))
            } else {
                None
            }
        }
        _ => None,
    };

    // Only intern a new type when something actually changed.
    match replacement {
        Some(new_type) => edit.add_type(new_type),
        None => ty,
    }
}

/*
 * Substitute all uses of a dynamic constant A with dynamic constant B in a
 * dynamic constant C. Return the substituted version of C, once memoized. Takes
 * a mutable edit instead of an editor since this may create new dynamic
 * constants, which can only be done inside an edit.
 */
pub(crate) fn substitute_dynamic_constants(
    dc_a: DynamicConstantID,
    dc_b: DynamicConstantID,
    dc_c: DynamicConstantID,
    edit: &mut FunctionEdit,
) -> DynamicConstantID {
    // If C is just A, then just replace all of C with B.
    if dc_a == dc_c {
        return dc_b;
    }

    // Since we substitute non-sense dynamic constant IDs earlier, we explicitly
    // check that the provided ID to replace inside of is valid. Otherwise,
    // ignore.
    if dc_c.idx() >= edit.num_dynamic_constants() {
        return dc_c;
    }

    // If C is not just A, look inside of it to possibly substitute a child DC.
    let dc_clone = edit.get_dynamic_constant(dc_c).clone();
    match dc_clone {
        DynamicConstant::Constant(_) | DynamicConstant::Parameter(_) => dc_c,
        // This is a certified Rust moment.
        DynamicConstant::Add(left, right) => {
            let new_left = substitute_dynamic_constants(dc_a, dc_b, left, edit);
            let new_right = substitute_dynamic_constants(dc_a, dc_b, right, edit);
            if new_left != left || new_right != right {
                edit.add_dynamic_constant(DynamicConstant::Add(new_left, new_right))
            } else {
                dc_c
            }
        }
        DynamicConstant::Sub(left, right) => {
            let new_left = substitute_dynamic_constants(dc_a, dc_b, left, edit);
            let new_right = substitute_dynamic_constants(dc_a, dc_b, right, edit);
            if new_left != left || new_right != right {
                edit.add_dynamic_constant(DynamicConstant::Sub(new_left, new_right))
            } else {
                dc_c
            }
        }
        DynamicConstant::Mul(left, right) => {
            let new_left = substitute_dynamic_constants(dc_a, dc_b, left, edit);
            let new_right = substitute_dynamic_constants(dc_a, dc_b, right, edit);
            if new_left != left || new_right != right {
                edit.add_dynamic_constant(DynamicConstant::Mul(new_left, new_right))
            } else {
                dc_c
            }
        }
        DynamicConstant::Div(left, right) => {
            let new_left = substitute_dynamic_constants(dc_a, dc_b, left, edit);
            let new_right = substitute_dynamic_constants(dc_a, dc_b, right, edit);
            if new_left != left || new_right != right {
                edit.add_dynamic_constant(DynamicConstant::Div(new_left, new_right))
            } else {
                dc_c
            }
        }
        DynamicConstant::Rem(left, right) => {
            let new_left = substitute_dynamic_constants(dc_a, dc_b, left, edit);
            let new_right = substitute_dynamic_constants(dc_a, dc_b, right, edit);
            if new_left != left || new_right != right {
                edit.add_dynamic_constant(DynamicConstant::Rem(new_left, new_right))
            } else {
                dc_c
            }
        }
        DynamicConstant::Min(left, right) => {
            let new_left = substitute_dynamic_constants(dc_a, dc_b, left, edit);
            let new_right = substitute_dynamic_constants(dc_a, dc_b, right, edit);
            if new_left != left || new_right != right {
                edit.add_dynamic_constant(DynamicConstant::Min(new_left, new_right))
            } else {
                dc_c
            }
        }
        DynamicConstant::Max(left, right) => {
            let new_left = substitute_dynamic_constants(dc_a, dc_b, left, edit);
            let new_right = substitute_dynamic_constants(dc_a, dc_b, right, edit);
            if new_left != left || new_right != right {
                edit.add_dynamic_constant(DynamicConstant::Max(new_left, new_right))
            } else {
                dc_c
            }
        }
    }
}

/*
 * Substitute all uses of dynamic constant A with dynamic constant B inside a
 * constant. Product, summation, and array constants have their type rewritten;
 * product and summation constants additionally have their sub-constants
 * rewritten recursively. Every other constant is returned unchanged. When
 * nothing inside the constant changed, `cons` itself is returned; otherwise the
 * rebuilt constant is added with `edit.add_constant` and the resulting ID is
 * returned.
 */
pub(crate) fn substitute_dynamic_constants_in_constant(
    dc_a: DynamicConstantID,
    dc_b: DynamicConstantID,
    cons: ConstantID,
    edit: &mut FunctionEdit,
) -> ConstantID {
    // Work on an owned copy of the constant so `edit` can be mutably borrowed
    // by the recursive calls below.
    let current = edit.get_constant(cons).clone();
    let replacement = match current {
        Constant::Product(ty, fields) => {
            let new_ty = substitute_dynamic_constants_in_type(dc_a, dc_b, ty, edit);
            let new_fields = fields
                .iter()
                .map(|&field| substitute_dynamic_constants_in_constant(dc_a, dc_b, field, edit))
                .collect();
            if new_ty != ty || new_fields != fields {
                Some(Constant::Product(new_ty, new_fields))
            } else {
                None
            }
        }
        Constant::Summation(ty, idx, variant) => {
            let new_ty = substitute_dynamic_constants_in_type(dc_a, dc_b, ty, edit);
            let new_variant = substitute_dynamic_constants_in_constant(dc_a, dc_b, variant, edit);
            if new_ty != ty || new_variant != variant {
                Some(Constant::Summation(new_ty, idx, new_variant))
            } else {
                None
            }
        }
        Constant::Array(ty) => {
            let new_ty = substitute_dynamic_constants_in_type(dc_a, dc_b, ty, edit);
            if new_ty != ty {
                Some(Constant::Array(new_ty))
            } else {
                None
            }
        }
        _ => None,
    };

    // Only intern a new constant when something actually changed.
    replacement.map_or(cons, |new_cons| edit.add_constant(new_cons))
}

/*
 * Substitute all uses of dynamic constant A with dynamic constant B inside a
 * single node, rewriting the node in place. The dynamic constants reachable
 * from a node are: a fork's factors, a constant node's constant, a dynamic
 * constant node's ID, and a call's dynamic constant arguments. Every other kind
 * of node is left untouched.
 */
pub(crate) fn substitute_dynamic_constants_in_node(
    dc_a: DynamicConstantID,
    dc_b: DynamicConstantID,
    node: &mut Node,
    edit: &mut FunctionEdit,
) {
    match node {
        Node::Fork { factors, .. } => {
            factors.iter_mut().for_each(|factor| {
                *factor = substitute_dynamic_constants(dc_a, dc_b, *factor, edit);
            });
        }
        Node::Call {
            dynamic_constants, ..
        } => {
            dynamic_constants.iter_mut().for_each(|dc_arg| {
                *dc_arg = substitute_dynamic_constants(dc_a, dc_b, *dc_arg, edit);
            });
        }
        Node::Constant { id } => {
            let substituted = substitute_dynamic_constants_in_constant(dc_a, dc_b, *id, edit);
            *id = substituted;
        }
        Node::DynamicConstant { id } => {
            let substituted = substitute_dynamic_constants(dc_a, dc_b, *id, edit);
            *id = substituted;
        }
        _ => {}
    }
}

/*
 * Top level function to make a function have only a single return.
 *
 * If the function already has exactly one return node, nothing is changed and
 * that node's ID is returned. Otherwise, everything happens in a single edit:
 * - a new region is added whose predecessors are the first uses of the
 *   existing returns;
 * - a phi on that region merges their second uses;
 * - the old returns are deleted;
 * - one new return using the region and the phi is added.
 * The new return's ID is returned, or `None` if deleting one of the old
 * returns failed inside the edit. Panics if the function has no return node.
 */
pub fn collapse_returns(editor: &mut FunctionEditor) -> Option<NodeID> {
    let returns: Vec<NodeID> = editor
        .func()
        .nodes
        .iter()
        .enumerate()
        .filter(|(_, node)| node.is_return())
        .map(|(idx, _)| NodeID::new(idx))
        .collect();
    assert!(!returns.is_empty());
    if let [only_return] = returns.as_slice() {
        return Some(*only_return);
    }

    // Split each return's uses: the first feeds the merging region, the second
    // feeds the merging phi.
    let (preds_before_returns, data_to_return): (Vec<NodeID>, Vec<NodeID>) = returns
        .iter()
        .map(|ret_id| {
            let uses = get_uses(&editor.func().nodes[ret_id.idx()]);
            (uses.as_ref()[0], uses.as_ref()[1])
        })
        .unzip();

    // All of the old returns get replaced in a single edit.
    let mut new_return = None;
    editor.edit(|mut edit| {
        let merge_region = edit.add_node(Node::Region {
            preds: preds_before_returns.into_boxed_slice(),
        });
        let merged_data = edit.add_node(Node::Phi {
            control: merge_region,
            data: data_to_return.into_boxed_slice(),
        });
        for old_return in returns {
            edit = edit.delete_node(old_return)?;
        }
        new_return = Some(edit.add_node(Node::Return {
            control: merge_region,
            data: merged_data,
        }));
        Ok(edit)
    });
    new_return
}

/*
 * Returns true when a function has more than two control nodes, i.e. when it
 * has control flow besides its start node and a single return node. Panics if
 * the function has fewer than two control nodes.
 */
pub(crate) fn contains_between_control_flow(func: &Function) -> bool {
    match func.nodes.iter().filter(|node| node.is_control()).count() {
        0 | 1 => panic!("PANIC: A Hercules function must have at least two control nodes: a start node and at least one return node."),
        2 => false,
        _ => true,
    }
}

/*
 * Top level function to ensure a Hercules function contains at least one
 * control node that isn't the start or return nodes.
 *
 * If the only control nodes are the start node (node 0) and a single return,
 * that return is deleted and replaced in one edit. The replacement is a new
 * region whose only predecessor is the start node, followed by a new return
 * that uses the region and the original returned data. The new region's ID is
 * returned, or `None` if deleting the old return failed inside the edit.
 *
 * Otherwise nothing is changed, and the first control user of the start node
 * yielded by `get_users` is returned.
 */
pub fn ensure_between_control_flow(editor: &mut FunctionEditor) -> Option<NodeID> {
    if !contains_between_control_flow(editor.func()) {
        // Skip node 0 (the start node); the next control node is the lone
        // return.
        let ret = editor
            .node_ids()
            .skip(1)
            .find(|id| editor.func().nodes[id.idx()].is_control())
            .expect("PANIC: A Hercules function must have a control node other than the start node.");
        let Node::Return { control, data } = editor.func().nodes[ret.idx()] else {
            panic!("PANIC: A Hercules function with only two control nodes must have a return node be the other control node, other than the start node.")
        };
        assert_eq!(control, NodeID::new(0), "PANIC: The only other control node in a Hercules function, the return node, is not using the start node.");
        let mut region_id = None;
        editor.edit(|mut edit| {
            edit = edit.delete_node(ret)?;
            let region = edit.add_node(Node::Region {
                preds: Box::new([NodeID::new(0)]),
            });
            region_id = Some(region);
            edit.add_node(Node::Return {
                control: region,
                data,
            });
            Ok(edit)
        });
        region_id
    } else {
        Some(
            editor
                .get_users(NodeID::new(0))
                .find(|id| editor.func().nodes[id.idx()].is_control())
                .expect("PANIC: The start node of a Hercules function must have a control user."),
        )
    }
}

/*
 * Helper function to tell if two lists of indices have the same structure.
 * The lists must have the same length. At each position, both indices must be
 * field indices with equal numbers, variant indices with equal numbers, or
 * position indices. Position indices are not compared by value; their lengths
 * are asserted to be equal instead.
 */
pub(crate) fn indices_structurally_equivalent(indices1: &[Index], indices2: &[Index]) -> bool {
    if indices1.len() != indices2.len() {
        return false;
    }

    // `all` stops at the first mismatching pair, so the length assertion only
    // runs on position pairs reached before any mismatch.
    zip(indices1, indices2).all(|(index1, index2)| match (index1, index2) {
        (Index::Field(idx1), Index::Field(idx2)) => idx1 == idx2,
        (Index::Variant(idx1), Index::Variant(idx2)) => idx1 == idx2,
        (Index::Position(pos1), Index::Position(pos2)) => {
            assert_eq!(pos1.len(), pos2.len());
            true
        }
        _ => false,
    })
}

/*
 * Helper function to determine if two lists of indices may overlap.
 *
 * Only the common prefix of the two lists is compared. Returns false as soon as
 * two field indices at the same depth differ, and true otherwise. Panics if the
 * two lists use different kinds of index at the same depth (presumably both
 * lists index values of the same type, making that an invariant violation —
 * confirm against callers).
 */
pub(crate) fn indices_may_overlap(indices1: &[Index], indices2: &[Index]) -> bool {
    for pair in zip(indices1, indices2) {
        match pair {
            // Check that the field numbers are the same.
            (Index::Field(idx1), Index::Field(idx2)) => {
                if idx1 != idx2 {
                    return false;
                }
            }
            // Variant indices always may overlap, since it's the same
            // underlying memory. Position indices always may overlap, since the
            // indexing nodes may be the same at runtime.
            (Index::Variant(_), Index::Variant(_)) | (Index::Position(_), Index::Position(_)) => {}
            _ => panic!("PANIC: Two lists of indices checked for overlap use different kinds of index (field, variant, or position) at the same depth."),
        }
    }
    // `zip` will exit as soon as either iterator is done - two sets of indices
    // may overlap when one indexes a larger sub-value than the other.
    true
}