extern crate hercules_ir;

use std::collections::HashMap;
use std::iter::zip;

use self::hercules_ir::callgraph::*;
use self::hercules_ir::def_use::*;
use self::hercules_ir::ir::*;
use self::hercules_ir::schedule::*;

use crate::*;

/*
 * Top level function to run inlining. Currently, inlines every function call,
 * since mutual recursion is not valid in Hercules IR.
 */
pub fn inline(
    editors: &mut [FunctionEditor],
    callgraph: &CallGraph,
    mut plans: Option<&mut Vec<Plan>>,
) {
    // Step 1: run topological sort on the call graph to inline the "deepest"
    // function first. Mutual recursion is not currently supported, so assert
    // that a topological sort exists.
    let mut num_calls: Vec<usize> = (0..editors.len())
        .map(|idx| callgraph.num_callees(FunctionID::new(idx)))
        .collect();
    let mut no_calls_stack: Vec<FunctionID> = num_calls
        .iter()
        .enumerate()
        .filter(|(_, num)| **num == 0)
        .map(|(idx, _)| FunctionID::new(idx))
        .collect();
    let mut topo = vec![];
    while let Some(no_call_func) = no_calls_stack.pop() {
        topo.push(no_call_func);
        for caller in callgraph.get_callers(no_call_func) {
            num_calls[caller.idx()] -= 1;
            if num_calls[caller.idx()] == 0 {
                no_calls_stack.push(*caller);
            }
        }
    }
    assert_eq!(
        topo.len(),
        editors.len(),
        "PANIC: Found mutual recursion in Hercules IR."
    );

    // Step 2: make sure each function has a single return node. If an edit
    // failed to make a function have a single return node, then we can't inline
    // calls of it.
    let single_return_nodes: Vec<_> = editors
        .iter_mut()
        .map(|editor| collapse_returns(editor))
        .collect();

    // Step 3: verify that each possible dynamic constant parameter index has a
    // single unique dynamic constant ID. If this isn't true, dynamic constant
    // substitution won't work, and this should be true anyway!
    let mut found_idxs = HashMap::new();
    for id in editors[0].dynamic_constant_ids() {
        let dc = editors[0].get_dynamic_constant(id);
        if let DynamicConstant::Parameter(idx) = *dc {
            assert!(!found_idxs.contains_key(&idx));
            found_idxs.insert(idx, id);
        }
    }
    let mut dc_param_idx_to_dc_id = vec![];
    for idx in 0..found_idxs.len() {
        dc_param_idx_to_dc_id.push(found_idxs[&idx]);
    }

    // Step 4: run inlining on each function individually. Iterate the functions
    // in topological order.
    for to_inline_id in topo {
        // Since Rust cannot analyze the accesses into an array of mutable
        // references, we need to do some weirdness here to simultaneously get:
        // 1. A mutable reference to the function we're modifying.
        // 2. Shared references to all of the functions called by that function.
        // We need to get the same for plans, if we receive them.
        let callees = callgraph.get_callees(to_inline_id);
        let editor_refs = get_mut_and_immuts(editors, to_inline_id, callees);
        let plan_refs = plans
            .as_mut()
            .map(|plans| get_mut_and_immuts(*plans, to_inline_id, callees));
        inline_func(
            editor_refs.0,
            editor_refs.1,
            plan_refs,
            &single_return_nodes,
            &dc_param_idx_to_dc_id,
        );
    }
}

/*
 * Helper function to get from an array of mutable references:
 * 1. A single mutable reference, to the element at `mut_id`.
 * 2. Several shared references, to the elements at `shared_id`, keyed by ID.
 * Where none of the references alias. We need to use this both for function
 * editors and plans.
 *
 * Repeated entries in `shared_id` are tolerated and collapse to one map entry.
 * Panics if `mut_id` also appears in `shared_id` (that would alias a mutable
 * and a shared reference), or if any ID is out of bounds for `mut_refs`.
 */
fn get_mut_and_immuts<'a, T, I: ID>(
    mut_refs: &'a mut [T],
    mut_id: I,
    shared_id: &[I],
) -> (&'a mut T, HashMap<I, &'a T>) {
    // Visit the requested IDs in increasing order, so each element can be
    // split off the front of the remaining slice. Each element can only be
    // split off once, so duplicates must be removed first.
    let mut all_id = Vec::from(shared_id);
    all_id.sort_unstable();
    all_id.dedup();
    match all_id.binary_search(&mut_id) {
        Ok(_) => panic!("PANIC: The mutable ID is also requested as a shared ID."),
        Err(pos) => all_id.insert(pos, mut_id),
    }

    let mut mut_ref = None;
    let mut shared_refs = HashMap::with_capacity(shared_id.len());
    // `cursor` is the index in `mut_refs` of the first element of `slice`.
    let mut cursor = 0;
    let mut slice = &mut *mut_refs;
    for id in all_id {
        let (skipped, rest) = slice.split_at_mut(id.idx() - cursor);
        cursor += skipped.len() + 1;
        let (head, rest) = rest.split_at_mut(1);
        let item = &mut head[0];
        if id == mut_id {
            mut_ref = Some(item);
        } else {
            shared_refs.insert(id, &*item);
        }
        slice = rest;
    }

    (
        mut_ref.expect("mut_id is always inserted into all_id"),
        shared_refs,
    )
}

/*
 * Run inlining on a single function. Pass a mutable reference to the function
 * to modify and shared references for all called functions.
 *
 * Every call node present in `editor` when this function starts is replaced by
 * a copy of the callee's body. Calls are skipped when the callee does not have
 * a single return node (per `single_return_nodes`). Dynamic constants in the
 * copied body are rewritten from the callee's dynamic constant parameters
 * (looked up in `dc_param_idx_to_dc_id`) to the call's dynamic constant
 * arguments.
 *
 * NOTE(review): `plans` is accepted but currently neither read nor updated.
 */
fn inline_func(
    editor: &mut FunctionEditor,
    called: HashMap<FunctionID, &FunctionEditor>,
    plans: Option<(&mut Plan, HashMap<FunctionID, &Plan>)>,
    single_return_nodes: &Vec<Option<NodeID>>,
    dc_param_idx_to_dc_id: &Vec<DynamicConstantID>,
) {
    // Placeholder dynamic constant IDs for the two-step substitution below.
    // They are taken from the top of the ID space, not from just past
    // `num_dynamic_constants()`. The first substitution step can itself add new
    // dynamic constants (when rewriting inside a compound dynamic constant or
    // an array type). Presumably those are numbered from
    // `num_dynamic_constants()` upwards, as the range check in
    // `substitute_dynamic_constants` implies. Placeholders drawn from that
    // range could therefore collide with real dynamic constants created
    // mid-substitution.
    // NOTE(review): assumes DynamicConstantID can hold values up to u32::MAX
    // and that far fewer dynamic constants than that ever exist.
    let placeholder_dc = |i: usize| DynamicConstantID::new(u32::MAX as usize - i);

    let first_num_nodes = editor.func().nodes.len();
    for id in (0..first_num_nodes).map(NodeID::new) {
        // Break down the call node.
        let Node::Call {
            control,
            function,
            ref dynamic_constants,
            ref args,
        } = editor.func().nodes[id.idx()]
        else {
            continue;
        };

        // Assemble all the info we'll need to do the edit.
        let dcs_a = &dc_param_idx_to_dc_id[..dynamic_constants.len()];
        let dcs_b = dynamic_constants.clone();
        let args = args.clone();
        let old_num_nodes = editor.func().nodes.len();
        let old_id_to_new_id = |old_id: NodeID| NodeID::new(old_id.idx() + old_num_nodes);
        let call_pred = get_uses(&editor.func().nodes[control.idx()]);
        assert_eq!(call_pred.as_ref().len(), 1);
        let call_pred = call_pred.as_ref()[0];
        let called_func = called[&function].func();
        // We can't inline calls to functions with multiple returns.
        let Some(called_return) = single_return_nodes[function.idx()] else {
            continue;
        };
        let called_return_uses = get_uses(&called_func.nodes[called_return.idx()]);
        let called_return_pred = called_return_uses.as_ref()[0];
        let called_return_data = called_return_uses.as_ref()[1];
        // The node that holds the call's result after inlining. Parameter
        // nodes are not copied into the caller; only gravestones are added in
        // their place. So if the callee returns one of its parameters
        // directly, the result is the corresponding call argument, not the
        // gravestone.
        let call_result = match called_func.nodes[called_return_data.idx()] {
            Node::Parameter { index } => args[index],
            _ => old_id_to_new_id(called_return_data),
        };

        // Perform the actual edit. A failed edit is ignored, and the call is
        // then expected to remain in place.
        let _ = editor.edit(|mut edit| {
            // Insert the nodes from the called function. There are a few
            // special cases:
            // - Start: don't add start nodes - later, we'll replace_all_uses on
            //   the start node with the one predecessor of the call's region
            //   node.
            // - Parameter: don't add parameter nodes - later, we'll
            //   replace_all_uses on the parameter nodes with the arguments to
            //   the call node.
            // - Return: don't add return nodes - later, we'll replace_all_uses
            //   on the call's region node with the predecessor to the return
            //   node.
            for (idx, node) in called_func.nodes.iter().enumerate() {
                if node.is_start() || node.is_parameter() || node.is_return() {
                    // We still need to add some node to make sure the IDs line
                    // up. Just add a gravestone.
                    edit.add_node(Node::Start);
                    continue;
                }
                // Get the node from the callee function and replace all the
                // uses with the to-be IDs in the caller function.
                let mut node = node.clone();
                if node.is_fork()
                    || node.is_constant()
                    || node.is_dynamic_constant()
                    || node.is_call()
                {
                    // We have to perform the substitution in two steps. First,
                    // we map every dynamic constant A to a placeholder dynamic
                    // constant ID. Second, we map each placeholder to the
                    // appropriate dynamic constant B. Why not just do this in
                    // one step from A to B? We update dynamic constants one at
                    // a time, so imagine the following A -> B mappings:
                    // ID 0 -> ID 1
                    // ID 1 -> ID 0
                    // First, we apply the first mapping. This changes all
                    // references to dynamic constant 0 to dynamic constant 1.
                    // Then, we apply the second mapping. This updates all
                    // already present references to dynamic constant 1, as well
                    // as the new references we just made in the first step. We
                    // actually want to institute all the updates
                    // *simultaneously*, hence the two step maneuver.
                    for (i, dc_a) in zip(0.., dcs_a) {
                        substitute_dynamic_constants_in_node(
                            *dc_a,
                            placeholder_dc(i),
                            &mut node,
                            &mut edit,
                        );
                    }
                    for (i, dc_b) in zip(0.., dcs_b.iter()) {
                        substitute_dynamic_constants_in_node(
                            placeholder_dc(i),
                            *dc_b,
                            &mut node,
                            &mut edit,
                        );
                    }
                }
                let mut uses = get_uses_mut(&mut node);
                for u in uses.as_mut() {
                    **u = old_id_to_new_id(**u);
                }
                // Add the node and check that the IDs line up.
                let add_id = edit.add_node(node);
                assert_eq!(add_id, old_id_to_new_id(NodeID::new(idx)));
            }

            // Stitch the control use of the inlined start node with the
            // predecessor control node of the call's region.
            let start_node = &called_func.nodes[0];
            assert!(start_node.is_start());
            let start_id = old_id_to_new_id(NodeID::new(0));
            edit = edit.replace_all_uses(start_id, call_pred)?;

            // Stitch the control use of the original call node's region with
            // the predecessor control of the inlined function's return.
            edit = edit.replace_all_uses(
                control,
                if called_return_pred == NodeID::new(0) {
                    call_pred
                } else {
                    old_id_to_new_id(called_return_pred)
                },
            )?;

            // Stitch uses of parameter nodes in the inlined function to the IDs
            // of arguments provided to the call node.
            for (node_idx, node) in called_func.nodes.iter().enumerate() {
                if let Node::Parameter { index } = node {
                    let param_id = old_id_to_new_id(NodeID::new(node_idx));
                    edit = edit.replace_all_uses(param_id, args[*index])?;
                }
            }

            // Finally, redirect users of the call to its result and delete the
            // call node and its region.
            edit = edit.replace_all_uses(id, call_result)?;
            edit = edit.delete_node(control)?;
            edit = edit.delete_node(id)?;

            Ok(edit)
        });
    }
}

/*
 * Substitute all uses of a dynamic constant A with dynamic constant B in a
 * dynamic constant C. Return the substituted version of C, once memoized. Takes
 * a mutable edit instead of an editor since this may create new dynamic
 * constants, which can only be done inside an edit.
 */
fn substitute_dynamic_constants(
    dc_a: DynamicConstantID,
    dc_b: DynamicConstantID,
    dc_c: DynamicConstantID,
    edit: &mut FunctionEdit,
) -> DynamicConstantID {
    // If C is just A, then just replace all of C with B.
    if dc_a == dc_c {
        return dc_b;
    }

    // Since we substitute non-sense dynamic constant IDs earlier, we explicitly
    // check that the provided ID to replace inside of is valid. Otherwise,
    // ignore.
    if dc_c.idx() >= edit.num_dynamic_constants() {
        return dc_c;
    }

    // If C is not just A, look inside of it to possibly substitute a child DC.
    let dc_clone = edit.get_dynamic_constant(dc_c).clone();
    match dc_clone {
        DynamicConstant::Constant(_) | DynamicConstant::Parameter(_) => dc_c,
        // This is a certified Rust moment.
        DynamicConstant::Add(left, right) => {
            let new_left = substitute_dynamic_constants(dc_a, dc_b, left, edit);
            let new_right = substitute_dynamic_constants(dc_a, dc_b, right, edit);
            if new_left != left || new_right != right {
                edit.add_dynamic_constant(DynamicConstant::Add(new_left, new_right))
            } else {
                dc_c
            }
        }
        DynamicConstant::Sub(left, right) => {
            let new_left = substitute_dynamic_constants(dc_a, dc_b, left, edit);
            let new_right = substitute_dynamic_constants(dc_a, dc_b, right, edit);
            if new_left != left || new_right != right {
                edit.add_dynamic_constant(DynamicConstant::Sub(new_left, new_right))
            } else {
                dc_c
            }
        }
        DynamicConstant::Mul(left, right) => {
            let new_left = substitute_dynamic_constants(dc_a, dc_b, left, edit);
            let new_right = substitute_dynamic_constants(dc_a, dc_b, right, edit);
            if new_left != left || new_right != right {
                edit.add_dynamic_constant(DynamicConstant::Mul(new_left, new_right))
            } else {
                dc_c
            }
        }
        DynamicConstant::Div(left, right) => {
            let new_left = substitute_dynamic_constants(dc_a, dc_b, left, edit);
            let new_right = substitute_dynamic_constants(dc_a, dc_b, right, edit);
            if new_left != left || new_right != right {
                edit.add_dynamic_constant(DynamicConstant::Div(new_left, new_right))
            } else {
                dc_c
            }
        }
        DynamicConstant::Rem(left, right) => {
            let new_left = substitute_dynamic_constants(dc_a, dc_b, left, edit);
            let new_right = substitute_dynamic_constants(dc_a, dc_b, right, edit);
            if new_left != left || new_right != right {
                edit.add_dynamic_constant(DynamicConstant::Rem(new_left, new_right))
            } else {
                dc_c
            }
        }
    }
}

/*
 * Replace dynamic constant A with dynamic constant B inside a type. This
 * recurses through product fields, summation variants and array element
 * types, and rewrites array dimensions. Returns a newly added type when
 * something changed, otherwise the original type ID.
 */
fn substitute_dynamic_constants_in_type(
    dc_a: DynamicConstantID,
    dc_b: DynamicConstantID,
    ty: TypeID,
    edit: &mut FunctionEdit,
) -> TypeID {
    // Snapshot the type so `edit` is free to be borrowed mutably while
    // recursing.
    let snapshot = edit.get_type(ty).clone();
    match snapshot {
        Type::Product(ref fields) => {
            let new_fields = fields
                .iter()
                .map(|&field| substitute_dynamic_constants_in_type(dc_a, dc_b, field, edit))
                .collect();
            let changed = new_fields != *fields;
            if changed {
                edit.add_type(Type::Product(new_fields))
            } else {
                ty
            }
        }
        Type::Summation(ref variants) => {
            let new_variants = variants
                .iter()
                .map(|&variant| substitute_dynamic_constants_in_type(dc_a, dc_b, variant, edit))
                .collect();
            let changed = new_variants != *variants;
            if changed {
                edit.add_type(Type::Summation(new_variants))
            } else {
                ty
            }
        }
        Type::Array(elem, ref dims) => {
            // The element type is rewritten before the dimensions.
            let new_elem = substitute_dynamic_constants_in_type(dc_a, dc_b, elem, edit);
            let new_dims = dims
                .iter()
                .map(|&dim| substitute_dynamic_constants(dc_a, dc_b, dim, edit))
                .collect();
            let changed = new_elem != elem || new_dims != *dims;
            if changed {
                edit.add_type(Type::Array(new_elem, new_dims))
            } else {
                ty
            }
        }
        // Every other kind of type is returned untouched.
        _ => ty,
    }
}

/*
 * Replace dynamic constant A with dynamic constant B inside a constant. This
 * rewrites the types of product, summation and array constants, and recurses
 * into product fields and summation payloads. Returns a newly added constant
 * when something changed, otherwise the original constant ID.
 */
fn substitute_dynamic_constants_in_constant(
    dc_a: DynamicConstantID,
    dc_b: DynamicConstantID,
    cons: ConstantID,
    edit: &mut FunctionEdit,
) -> ConstantID {
    // Snapshot the constant so `edit` is free to be borrowed mutably while
    // recursing.
    let snapshot = edit.get_constant(cons).clone();
    match snapshot {
        Constant::Product(product_ty, fields) => {
            // The type is rewritten before the fields.
            let new_ty = substitute_dynamic_constants_in_type(dc_a, dc_b, product_ty, edit);
            let new_fields = fields
                .iter()
                .map(|&field| substitute_dynamic_constants_in_constant(dc_a, dc_b, field, edit))
                .collect();
            let changed = new_ty != product_ty || new_fields != fields;
            if changed {
                edit.add_constant(Constant::Product(new_ty, new_fields))
            } else {
                cons
            }
        }
        Constant::Summation(sum_ty, variant_idx, payload) => {
            // The type is rewritten before the payload.
            let new_ty = substitute_dynamic_constants_in_type(dc_a, dc_b, sum_ty, edit);
            let new_payload = substitute_dynamic_constants_in_constant(dc_a, dc_b, payload, edit);
            let changed = new_ty != sum_ty || new_payload != payload;
            if changed {
                edit.add_constant(Constant::Summation(new_ty, variant_idx, new_payload))
            } else {
                cons
            }
        }
        Constant::Array(array_ty) => {
            let new_ty = substitute_dynamic_constants_in_type(dc_a, dc_b, array_ty, edit);
            let changed = new_ty != array_ty;
            if changed {
                edit.add_constant(Constant::Array(new_ty))
            } else {
                cons
            }
        }
        // Every other kind of constant is returned untouched.
        _ => cons,
    }
}

/*
 * Replace dynamic constant A with dynamic constant B in every dynamic constant
 * a node refers to. That covers fork factors, the ID held by a dynamic
 * constant node, the constant held by a constant node, and a call's dynamic
 * constant arguments. Any other node is left unchanged. The node is updated
 * in place.
 */
fn substitute_dynamic_constants_in_node(
    dc_a: DynamicConstantID,
    dc_b: DynamicConstantID,
    node: &mut Node,
    edit: &mut FunctionEdit,
) {
    match node {
        Node::Fork { factors, .. } => {
            for factor in factors.iter_mut() {
                *factor = substitute_dynamic_constants(dc_a, dc_b, *factor, edit);
            }
        }
        Node::Constant { id } => {
            let rewritten = substitute_dynamic_constants_in_constant(dc_a, dc_b, *id, edit);
            *id = rewritten;
        }
        Node::DynamicConstant { id } => {
            let rewritten = substitute_dynamic_constants(dc_a, dc_b, *id, edit);
            *id = rewritten;
        }
        Node::Call {
            dynamic_constants, ..
        } => {
            for dc_arg in dynamic_constants.iter_mut() {
                *dc_arg = substitute_dynamic_constants(dc_a, dc_b, *dc_arg, edit);
            }
        }
        _ => {}
    }
}

/*
 * Top level function to make a function have only a single return.
 *
 * If the function already has exactly one return node, its ID is returned
 * without editing anything. Otherwise, one edit replaces all return nodes with
 * three new nodes:
 * - a region joining the returns' control predecessors,
 * - a phi over the values they returned,
 * - a single new return node, whose ID is returned.
 * Returns `None` if the edit closure bails out before the new return is
 * added. Panics if the function has no return node.
 */
pub fn collapse_returns(editor: &mut FunctionEditor) -> Option<NodeID> {
    let returns: Vec<NodeID> = editor
        .func()
        .nodes
        .iter()
        .enumerate()
        .filter_map(|(idx, node)| node.is_return().then(|| NodeID::new(idx)))
        .collect();
    assert!(!returns.is_empty());
    if let [only_return] = returns.as_slice() {
        return Some(*only_return);
    }

    // A return node's uses are its control predecessor followed by the data it
    // returns.
    let (preds_before_returns, data_to_return): (Vec<NodeID>, Vec<NodeID>) = returns
        .iter()
        .map(|ret_id| {
            let uses = get_uses(&editor.func().nodes[ret_id.idx()]);
            (uses.as_ref()[0], uses.as_ref()[1])
        })
        .unzip();

    // Replace every old return within a single edit.
    let mut new_return = None;
    editor.edit(|mut edit| {
        let region = edit.add_node(Node::Region {
            preds: preds_before_returns.into_boxed_slice(),
        });
        let phi = edit.add_node(Node::Phi {
            control: region,
            data: data_to_return.into_boxed_slice(),
        });
        for ret in returns {
            edit = edit.delete_node(ret)?;
        }
        let single_return = edit.add_node(Node::Return {
            control: region,
            data: phi,
        });
        new_return = Some(single_return);
        Ok(edit)
    });
    new_return
}