use hercules_ir::ir::*;

use crate::*;

/*
 * Top-level function for running interprocedural SROA.
 *
 * IP SROA expects that all nodes in all functions provided to it can be edited,
 * since it needs to be able to modify both the functions whose types are being
 * changed and call sites of those functions. What functions to run IP SROA on
 * is therefore specified by a separate argument.
 *
 * This optimization also takes an allow_sroa_arrays argument (like non-IP
 * SROA) which controls whether it will break up products of arrays.
 *
 * Arguments:
 * - editors: one editor per function, indexed by FunctionID.
 * - types: the type of every node, indexed by function and then by node. It is
 *   only read (never updated) here, and only for call-site data projections.
 * - func_selection: for each function, whether its parameter and return types
 *   should be broken up.
 * - allow_sroa_arrays: whether products containing arrays may be broken up.
 *
 * Panics if any node that has to be edited is not mutable.
 */
pub fn interprocedural_sroa(
    editors: &mut Vec<FunctionEditor>,
    types: &Vec<Vec<TypeID>>,
    func_selection: &Vec<bool>,
    allow_sroa_arrays: bool,
) {
    // Only products are broken up, and products containing arrays only when
    // allow_sroa_arrays is set.
    let can_sroa_type = |editor: &FunctionEditor, typ: TypeID| {
        editor.get_type(typ).is_product()
            && (allow_sroa_arrays || !type_contains_array(editor, typ))
    };

    // For every callee, the (caller, call node) pairs of calls to it, gathered
    // before any function is edited.
    let callsites = get_callsites(editors);

    for ((func_id, apply), callsites) in (0..func_selection.len())
        .map(FunctionID::new)
        .zip(func_selection.iter())
        .zip(callsites.into_iter())
    {
        if !apply {
            continue;
        }

        let editor: &mut FunctionEditor = &mut editors[func_id.idx()];
        let param_types = &editor.func().param_types.to_vec();
        let return_types = &editor.func().return_types.to_vec();

        // We determine the new param/return types of the function and track a
        // map that tells us how the old param/return values are constructed
        // from the new ones.
        let mut new_param_types = vec![];
        let mut old_param_type_map = vec![];
        let mut new_return_types = vec![];
        let mut old_return_type_map = vec![];
        let mut changed = false;

        for par_typ in param_types.iter() {
            if !can_sroa_type(editor, *par_typ) {
                old_param_type_map.push(IndexTree::Leaf(new_param_types.len()));
                new_param_types.push(*par_typ);
            } else {
                let (types, index) = sroa_type(editor, *par_typ, new_param_types.len());
                old_param_type_map.push(index);
                new_param_types.extend(types);
                changed = true;
            }
        }

        for ret_typ in return_types.iter() {
            if !can_sroa_type(editor, *ret_typ) {
                old_return_type_map.push(IndexTree::Leaf(new_return_types.len()));
                new_return_types.push(*ret_typ);
            } else {
                let (types, index) = sroa_type(editor, *ret_typ, new_return_types.len());
                old_return_type_map.push(index);
                new_return_types.extend(types);
                changed = true;
            }
        }

        // If the param/return types aren't changed by IP SROA, skip to the next
        // function.
        if !changed {
            continue;
        }

        // Modify each parameter in the current function and the param types.
        let mut param_nodes: Vec<_> = vec![vec![]; param_types.len()];
        for id in editor.node_ids() {
            if let Some(idx) = editor.func().nodes[id.idx()].try_parameter() {
                param_nodes[idx].push(id);
            }
        }
        let success = editor.edit(|mut edit| {
            for (idx, ids) in param_nodes.into_iter().enumerate() {
                let new_indices = &old_param_type_map[idx];
                let built = if let IndexTree::Leaf(new_idx) = new_indices {
                    edit.add_node(Node::Parameter { index: *new_idx })
                } else {
                    // Rebuild the original product by writing each new
                    // parameter into a zero constant at its field position.
                    let prod_ty = param_types[idx];
                    let cons = edit.add_zero_constant(prod_ty);
                    let mut cons = edit.add_node(Node::Constant { id: cons });
                    new_indices.for_each(|idx: &Vec<Index>, param_idx: &usize| {
                        let param = edit.add_node(Node::Parameter { index: *param_idx });
                        cons = edit.add_node(Node::Write {
                            collect: cons,
                            data: param,
                            indices: idx.clone().into_boxed_slice(),
                        });
                    });
                    cons
                };
                for id in ids {
                    edit = edit.replace_all_uses(id, built)?;
                    edit = edit.delete_node(id)?;
                }
            }

            edit.set_param_types(new_param_types);
            Ok(edit)
        });
        assert!(success, "IP SROA expects to be able to edit everything, specify what functions to IP SROA via the func_selection argument");

        // Modify each return in the current function and the return types.
        let return_nodes: Vec<_> = editor
            .node_ids()
            .filter(|id| editor.func().nodes[id.idx()].is_return())
            .collect();
        let success = editor.edit(|mut edit| {
            for node in return_nodes {
                let Node::Return { control, data } = edit.get_node(node) else {
                    panic!()
                };
                let control = *control;
                let data = data.to_vec();

                let mut new_data = vec![];
                for (idx, (data_id, update_info)) in
                    data.into_iter().zip(old_return_type_map.iter()).enumerate()
                {
                    if let IndexTree::Leaf(new_idx) = update_info {
                        // Unchanged return value
                        assert!(new_data.len() == *new_idx);
                        new_data.push(data_id);
                    } else {
                        // SROA'd return value: read out each field and return
                        // the fields individually, in index-tree order.
                        let reads = generate_reads_edit(&mut edit, return_types[idx], data_id);
                        reads.zip(update_info).for_each(|_, (read_id, ret_idx)| {
                            assert!(new_data.len() == **ret_idx);
                            new_data.push(*read_id);
                        });
                    }
                }

                let new_ret = edit.add_node(Node::Return {
                    control,
                    data: new_data.into(),
                });
                edit.sub_edit(node, new_ret);
                edit = edit.delete_node(node)?;
            }

            edit.set_return_types(new_return_types);
            Ok(edit)
        });
        assert!(success, "IP SROA expects to be able to edit everything, specify what functions to IP SROA via the func_selection argument");

        // Finally, update calls of this function.
        for (caller, callsite) in callsites {
            let editor = &mut editors[caller.idx()];
            assert!(editor.func_id() == caller);

            // Rebuild every value the old call returned from the new call's
            // (possibly split up) results.
            let projs = editor.get_users(callsite).collect::<Vec<_>>();
            for proj_id in projs {
                let Node::DataProjection { data: _, selection } = editor.node(proj_id) else {
                    panic!("Call has a non data-projection user");
                };
                let new_return_info = &old_return_type_map[*selection];
                let typ = types[caller.idx()][proj_id.idx()];
                replace_returned_value(editor, proj_id, typ, new_return_info, callsite);
            }

            // Replace the call itself with one passing the split-up arguments.
            let (control, callee, dc_args, args) =
                editor.func().nodes[callsite.idx()].try_call().unwrap();
            let dc_args = dc_args.clone();
            let args = args.clone();
            let success = editor.edit(|mut edit| {
                let mut new_args = vec![];
                for (idx, (data_id, update_info)) in
                    args.iter().zip(old_param_type_map.iter()).enumerate()
                {
                    if let IndexTree::Leaf(new_idx) = update_info {
                        // Unchanged parameter value
                        assert!(new_args.len() == *new_idx);
                        new_args.push(*data_id);
                    } else {
                        // SROA'd parameter value
                        let reads = generate_reads_edit(&mut edit, param_types[idx], *data_id);
                        reads.zip(update_info).for_each(|_, (read_id, ret_idx)| {
                            assert!(new_args.len() == **ret_idx);
                            new_args.push(*read_id);
                        });
                    }
                }
                let new_call = edit.add_node(Node::Call {
                    control,
                    function: callee,
                    dynamic_constants: dc_args,
                    args: new_args.into_boxed_slice(),
                });
                edit = edit.replace_all_uses(callsite, new_call)?;
                edit = edit.delete_node(callsite)?;
                Ok(edit)
            });
            assert!(success, "IP SROA expects to be able to edit everything, specify what functions to IP SROA via the func_selection argument");
        }
    }
}

// Flattens `typ` for IP SROA.
//
// Returns the non-product leaf types of `typ` in field order (nested products
// are expanded recursively), together with an index tree mirroring the product
// structure of `typ`. The tree's leaves hold the positions of those leaf types
// in the flattened list, numbered starting at `type_index`. A non-product type
// flattens to itself at position `type_index`.
fn sroa_type(
    editor: &FunctionEditor,
    typ: TypeID,
    type_index: usize,
) -> (Vec<TypeID>, IndexTree<usize>) {
    let ty = editor.get_type(typ);
    let Type::Product(fields) = &*ty else {
        return (vec![typ], IndexTree::Leaf(type_index));
    };

    let mut flat = Vec::new();
    let mut subtrees = Vec::new();
    for field in fields.iter() {
        // Each field's leaves are numbered after everything flattened so far.
        let (field_types, subtree) = sroa_type(editor, *field, type_index + flat.len());
        flat.extend(field_types);
        subtrees.push(subtree);
    }
    (flat, IndexTree::Node(subtrees))
}

// Collects, for every function, the call sites that call it.
//
// The result is indexed by callee FunctionID; each entry lists
// (caller FunctionID, call NodeID) pairs. Panics if any call node is not
// mutable in its caller's editor, since IP SROA must be able to rewrite it.
fn get_callsites(editors: &Vec<FunctionEditor>) -> Vec<Vec<(FunctionID, NodeID)>> {
    let mut by_callee: Vec<Vec<(FunctionID, NodeID)>> = vec![vec![]; editors.len()];

    for editor in editors.iter() {
        let caller = editor.func_id();
        for (idx, node) in editor.func().nodes.iter().enumerate() {
            let Some((_, callee, _, _)) = node.try_call() else {
                continue;
            };
            let call_id = NodeID::new(idx);
            assert!(
                editor.is_mutable(call_id),
                "IP SROA expects to be able to edit everything, specify what functions to IP SROA via the func_selection argument"
            );
            by_callee[callee.idx()].push((caller, call_id));
        }
    }

    by_callee
}

// Replaces a projection node (from before the function signature change) based on the of_new_call
// description (which tells how to construct the value from the new returned values).
fn replace_returned_value(
    editor: &mut FunctionEditor,
    proj_id: NodeID,
    proj_typ: TypeID,
    of_new_call: &IndexTree<usize>,
    call_node: NodeID,
) {
    let constant = generate_constant(editor, proj_typ);

    let success = editor.edit(|mut edit| {
        let mut new_val = edit.add_node(Node::Constant { id: constant });
        of_new_call.for_each(|idx, selection| {
            let new_proj = edit.add_node(Node::DataProjection {
                data: call_node,
                selection: *selection,
            });
            new_val = edit.add_node(Node::Write {
                collect: new_val,
                data: new_proj,
                indices: idx.clone().into(),
            });
        });

        edit = edit.replace_all_uses(proj_id, new_val)?;
        edit.delete_node(proj_id)
    });
    assert!(success);
}