use std::collections::{HashMap, HashSet};
use std::iter::zip;

use bitvec::{order::Lsb0, vec::BitVec};
use hercules_ir::{ir::*, LoopTree};

use crate::*;

type NodeVec = BitVec<u8, Lsb0>;

/// Computes the set of nodes considered to be "inside" the fork-join headed by
/// `fork`.
///
/// `inner_control` is indexed by node index; a set bit marks a node that
/// belongs to the fork-join. An index past the end of the bit vector is
/// treated as "not inside".
///
/// The result is the intersection of two sets grown from a common list of seed
/// nodes (reduces whose control is inside the fork-join, plus thread-ID users
/// of `fork`):
///   * everything returned by `walk_all_users_stop_on` for each seed, and
///   * everything returned by `walk_all_uses_stop_on` for each seed, minus the
///     stop set itself.
/// The seeds themselves are added to both sets before intersecting.
///
/// NOTE(review): the two walk helpers are defined elsewhere; presumably they
/// are transitive traversals that do not continue through the stop set.
pub fn calculate_fork_nodes(
    editor: &FunctionEditor,
    inner_control: &NodeVec,
    fork: NodeID,
) -> HashSet<NodeID> {
    // A node is inside the fork-join only if it has a bit and that bit is set.
    let in_fork = |id: NodeID| inner_control.get(id.idx()).map_or(false, |bit| *bit);

    // Traversals stop on phis / reduces whose control is outside of the loop,
    // and on control nodes that are themselves outside of the loop.
    let stop_on: HashSet<NodeID> = editor
        .node_ids()
        .filter(|id| {
            let node = &editor.func().nodes[id.idx()];
            let gated_outside = match node {
                Node::Phi { control, .. } | Node::Reduce { control, .. } => !in_fork(*control),
                _ => false,
            };
            gated_outside || (node.is_control() && !in_fork(*id))
        })
        .collect();

    // Seeds: reduces controlled from inside the fork-join, followed by the
    // thread IDs hanging off the fork.
    let seeds: Vec<NodeID> = editor
        .node_ids()
        .filter(|id| match &editor.func().nodes[id.idx()] {
            Node::Reduce { control, .. } => in_fork(*control),
            _ => false,
        })
        .chain(
            editor
                .get_users(fork)
                .filter(|user| editor.node(user).is_thread_id()),
        )
        .collect();

    // Everything reachable from the seeds by following users.
    let mut downstream: HashSet<NodeID> = seeds.iter().copied().collect();
    for &seed in &seeds {
        downstream.extend(walk_all_users_stop_on(seed, editor, stop_on.clone()));
    }

    // Everything reachable from the seeds by following uses, with the stop
    // nodes themselves removed afterwards.
    let mut upstream: HashSet<NodeID> = seeds.iter().copied().collect();
    for &seed in &seeds {
        upstream.extend(walk_all_uses_stop_on(seed, editor, stop_on.clone()));
    }
    upstream.retain(|id| !stop_on.contains(id));

    downstream.intersection(&upstream).copied().collect()
}
/*
 * Convert forks back into loops right before codegen when a backend is not
 * lowering a fork-join to vector / parallel code. Lowering fork-joins into
 * sequential loops in LLVM is actually not entirely trivial, so it's easier to
 * just do this transformation within Hercules IR.
 */

// FIXME: Only works on fully split fork nests.
/// Attempts to unforkify every fork found among the loop headers of
/// `loop_tree`.
///
/// Loop headers are visited in the reverse of the order produced by
/// `LoopTree::bottom_up_loops`. Headers that are not fork nodes are skipped.
/// The join paired with each fork is looked up in `fork_join_map`.
///
/// # Panics
///
/// Panics if a fork header has no entry in `fork_join_map`.
pub fn unforkify_all(
    editor: &mut FunctionEditor,
    fork_join_map: &HashMap<NodeID, NodeID>,
    loop_tree: &LoopTree,
) {
    for entry in loop_tree.bottom_up_loops().into_iter().rev() {
        let header = entry.0;
        if editor.node(header).is_fork() {
            // Whether the individual conversion succeeded is not inspected;
            // every fork header gets an attempt.
            unforkify(editor, header, fork_join_map[&header], loop_tree);
        }
    }
}

/// Unforkifies at most one fork-join: stops after the first fork for which
/// `unforkify` reports success.
///
/// Loop headers are visited in the reverse of the order produced by
/// `LoopTree::bottom_up_loops`. Headers that are not fork nodes are skipped,
/// and forks whose conversion returns `false` do not end the search.
///
/// # Panics
///
/// Panics if a visited fork header has no entry in `fork_join_map`.
pub fn unforkify_one(
    editor: &mut FunctionEditor,
    fork_join_map: &HashMap<NodeID, NodeID>,
    loop_tree: &LoopTree,
) {
    for entry in loop_tree.bottom_up_loops().into_iter().rev() {
        let header = entry.0;
        // Short-circuits: the conversion is only attempted on fork headers,
        // and the first successful conversion ends the search.
        if editor.node(header).is_fork()
            && unforkify(editor, header, fork_join_map[&header], loop_tree)
        {
            break;
        }
    }
}

/// Rewrites the one-dimensional fork-join (`fork`, `join`) into a guarded
/// sequential loop.
///
/// The generated control structure is:
///   * a guard `if (0 < N)` on the fork's control predecessor, where `N` is
///     the fork's single dynamic-constant factor;
///   * a loop header region fed by the guard's taken projection and by the
///     loop's back edge;
///   * a loop `if (i + 1 != N)` placed on the join's control predecessor, whose
///     projection 1 is the back edge and whose projection 0 exits the loop;
///   * a merge region fed by the guard's skipped projection and the loop exit,
///     which takes over the uses of `join`.
///
/// Thread IDs of the fork are replaced by the induction variable phi `i`
/// (`0` on entry, `i + 1` on the back edge). Each reduce on the join becomes
/// two phis: one on the loop header for users inside the fork-join (as
/// computed by `calculate_fork_nodes`) and one on the merge region for all
/// other users.
///
/// Returns `false` without touching the function if the fork has more than one
/// factor; otherwise returns the result of the final `editor.edit` call.
///
/// # Panics
///
/// Panics if `fork` is not a fork node or `join` is not a join node, if the
/// edit adding the loop constants is rejected, or if newly added nodes do not
/// receive the node IDs predicted from the current node count.
pub fn unforkify(
    editor: &mut FunctionEditor,
    fork: NodeID,
    join: NodeID,
    loop_tree: &LoopTree,
) -> bool {
    // For now, don't convert multi-dimensional fork-joins. Rely on pass that
    // splits fork-joins. This is checked before anything else so that a
    // rejected fork neither adds constants to the function nor pays for the
    // fork-node analysis below.
    if editor.func().nodes[fork.idx()].try_fork().unwrap().1.len() > 1 {
        return false;
    }

    // The IDs are placeholders; the edit below overwrites both.
    let mut zero_cons_id = ConstantID::new(0);
    let mut one_cons_id = ConstantID::new(0);
    assert!(editor.edit(|mut edit| {
        zero_cons_id = edit.add_constant(Constant::UnsignedInteger64(0));
        one_cons_id = edit.add_constant(Constant::UnsignedInteger64(1));
        Ok(edit)
    }));

    // Convert the fork to a region, thread IDs to a single phi, reduces to
    // phis, and the join to a branch at the top of the loop. The previous
    // control inside of the fork-join should become the successor of the true
    // projection node, and what was the use of the join should become a use of
    // the new region.
    let fork_nodes = calculate_fork_nodes(editor, loop_tree.nodes_in_loop_bitvec(fork), fork);

    let nodes = &editor.func().nodes;
    let (fork_control, factors) = nodes[fork.idx()].try_fork().unwrap();
    let join_control = nodes[join.idx()].try_join().unwrap();
    let tids: Vec<_> = editor
        .get_users(fork)
        .filter(|id| nodes[id.idx()].is_thread_id())
        .collect();
    let reduces: Vec<_> = editor
        .get_users(join)
        .filter(|id| nodes[id.idx()].is_reduce())
        .collect();

    // Predict the IDs the new nodes will receive so they can reference each
    // other before being added. The order here must match the order of the
    // `add_node` calls in the edit below, which asserts each prediction.
    let num_nodes = editor.node_ids().len();
    let region_id = NodeID::new(num_nodes);
    let if_id = NodeID::new(num_nodes + 1);
    let proj_back_id = NodeID::new(num_nodes + 2);
    let proj_exit_id = NodeID::new(num_nodes + 3);
    let zero_id = NodeID::new(num_nodes + 4);
    let one_id = NodeID::new(num_nodes + 5);
    let indvar_id = NodeID::new(num_nodes + 6);
    let add_id = NodeID::new(num_nodes + 7);
    let dc_id = NodeID::new(num_nodes + 8);
    let neq_id = NodeID::new(num_nodes + 9);

    let guard_if_id = NodeID::new(num_nodes + 10);
    let guard_join_id = NodeID::new(num_nodes + 11);
    let guard_taken_proj_id = NodeID::new(num_nodes + 12);
    let guard_skipped_proj_id = NodeID::new(num_nodes + 13);
    let guard_cond_id = NodeID::new(num_nodes + 14);

    // One loop-header phi and one merge-region phi per reduce.
    let phi_ids = (num_nodes + 15..num_nodes + 15 + reduces.len()).map(NodeID::new);
    let s = num_nodes + 15 + reduces.len();
    let join_phi_ids = (s..s + reduces.len()).map(NodeID::new);

    // Guard: only enter the loop when the trip count is non-zero (0 < N).
    let guard_cond = Node::Binary {
        left: zero_id,
        right: dc_id,
        op: BinaryOperator::LT,
    };
    let guard_if = Node::If {
        control: fork_control,
        cond: guard_cond_id,
    };
    let guard_taken_proj = Node::Projection {
        control: guard_if_id,
        selection: 1,
    };
    let guard_skipped_proj = Node::Projection {
        control: guard_if_id,
        selection: 0,
    };
    // Merge point of "loop skipped" and "loop exited".
    let guard_join = Node::Region {
        preds: Box::new([guard_skipped_proj_id, proj_exit_id]),
    };

    // Loop header: entered from the guard or via the back edge.
    let region = Node::Region {
        preds: Box::new([guard_taken_proj_id, proj_back_id]),
    };
    // Loop branch, placed where the join used to be.
    let if_node = Node::If {
        control: join_control,
        cond: neq_id,
    };
    let proj_back = Node::Projection {
        control: if_id,
        selection: 1,
    };
    let proj_exit = Node::Projection {
        control: if_id,
        selection: 0,
    };
    let zero = Node::Constant { id: zero_cons_id };
    let one = Node::Constant { id: one_cons_id };
    // Induction variable: 0 on entry, previous value + 1 on the back edge.
    let indvar = Node::Phi {
        control: region_id,
        data: Box::new([zero_id, add_id]),
    };
    let add = Node::Binary {
        op: BinaryOperator::Add,
        left: indvar_id,
        right: one_id,
    };
    let dc = Node::DynamicConstant { id: factors[0] };
    // Keep looping while the incremented induction variable differs from N.
    let neq = Node::Binary {
        op: BinaryOperator::NE,
        left: add_id,
        right: dc_id,
    };
    // Both phis carry [init, reduct]; their input order lines up with the
    // predecessor order of the loop header and the merge region respectively.
    let (phis, join_phis): (Vec<_>, Vec<_>) = reduces
        .iter()
        .map(|reduce_id| {
            let (_, init, reduct) = nodes[reduce_id.idx()].try_reduce().unwrap();
            (
                Node::Phi {
                    control: region_id,
                    data: Box::new([init, reduct]),
                },
                Node::Phi {
                    control: guard_join_id,
                    data: Box::new([init, reduct]),
                },
            )
        })
        .unzip();

    editor.edit(|mut edit| {
        assert_eq!(edit.add_node(region), region_id);
        assert_eq!(edit.add_node(if_node), if_id);
        assert_eq!(edit.add_node(proj_back), proj_back_id);
        assert_eq!(edit.add_node(proj_exit), proj_exit_id);
        assert_eq!(edit.add_node(zero), zero_id);
        assert_eq!(edit.add_node(one), one_id);
        assert_eq!(edit.add_node(indvar), indvar_id);
        assert_eq!(edit.add_node(add), add_id);
        assert_eq!(edit.add_node(dc), dc_id);
        assert_eq!(edit.add_node(neq), neq_id);
        assert_eq!(edit.add_node(guard_if), guard_if_id);
        assert_eq!(edit.add_node(guard_join), guard_join_id);
        assert_eq!(edit.add_node(guard_taken_proj), guard_taken_proj_id);
        assert_eq!(edit.add_node(guard_skipped_proj), guard_skipped_proj_id);
        assert_eq!(edit.add_node(guard_cond), guard_cond_id);

        for (phi_id, phi) in zip(phi_ids.clone(), &phis) {
            assert_eq!(edit.add_node(phi.clone()), phi_id);
        }
        for (phi_id, phi) in zip(join_phi_ids.clone(), &join_phis) {
            assert_eq!(edit.add_node(phi.clone()), phi_id);
        }

        // Users of the fork now hang off the loop header; users of the join
        // now hang off the merge region.
        edit = edit.replace_all_uses(fork, region_id)?;
        edit = edit.replace_all_uses_where(join, guard_join_id, |usee| *usee != if_id)?;
        edit.sub_edit(fork, region_id);
        edit.sub_edit(join, if_id);
        for tid in tids.iter() {
            edit.sub_edit(*tid, indvar_id);
            edit = edit.replace_all_uses(*tid, indvar_id)?;
        }
        for (((reduce, phi_id), phi), join_phi_id) in
            zip(reduces.iter(), phi_ids).zip(phis).zip(join_phi_ids)
        {
            edit.sub_edit(*reduce, phi_id);
            let Node::Phi { control: _, data } = phi else {
                panic!()
            };
            // Users outside the fork-join observe the value after the loop
            // (or the initial value if the loop was skipped).
            edit = edit
                .replace_all_uses_where(*reduce, join_phi_id, |usee| !fork_nodes.contains(usee))?;
            // Users inside the fork-join, and the reduct node itself, observe
            // the loop-carried value.
            edit = edit.replace_all_uses_where(*reduce, phi_id, |usee| {
                fork_nodes.contains(usee) || *usee == data[1]
            })?;
            edit = edit.delete_node(*reduce)?;
        }

        edit = edit.delete_node(fork)?;
        edit = edit.delete_node(join)?;
        for tid in tids {
            edit = edit.delete_node(tid)?;
        }

        Ok(edit)
    })
}