use std::collections::HashMap;
use std::collections::HashSet;
use std::iter::zip;
use std::iter::FromIterator;

use itertools::Itertools;
use nestify::nest;

use hercules_ir::*;

use crate::*;

/// Tries to canonicalize the bounds of the loops in `loops`, stopping at the
/// first loop that is successfully rewritten.
///
/// Only loops whose header node is a region are considered, in the order
/// produced by `LoopTree::bottom_up_loops`. A loop is skipped when its header
/// is not mutable in `editor`.
///
/// `fork_join_map` is accepted but is not used by this function.
///
/// Returns `true` as soon as `canonicalize_single_loop_bounds` reports success
/// for one loop, and `false` if it never does.
pub fn loop_bound_canon_toplevel(
    editor: &mut FunctionEditor,
    fork_join_map: &HashMap<NodeID, NodeID>,
    control_subgraph: &Subgraph,
    loops: &LoopTree,
) -> bool {
    // Snapshot the candidate loops up front: the editor is borrowed mutably
    // while a loop is being rewritten, so it cannot be consulted lazily.
    let mut candidates = Vec::new();
    for (header, body) in loops.bottom_up_loops() {
        if editor.func().nodes[header.idx()].is_region() {
            candidates.push((header, body));
        }
    }

    for (header, body) in candidates {
        if !editor.is_mutable(header) {
            continue;
        }

        let candidate = Loop {
            header,
            control: body.clone(),
        };
        if canonicalize_single_loop_bounds(editor, control_subgraph, &candidate) {
            return true;
        }
    }

    false
}

pub fn canonicalize_single_loop_bounds(
    editor: &mut FunctionEditor,
    control_subgraph: &Subgraph,
    l: &Loop,
) -> bool {
    let function = editor.func();

    let Some(loop_condition) = get_loop_exit_conditions(function, l, control_subgraph) else {
        return false;
    };

    let LoopExit::Conditional {
        if_node: loop_if,
        condition_node,
    } = loop_condition.clone()
    else {
        return false;
    };

    let loop_variance = compute_loop_variance(editor, l);
    let ivs = compute_induction_vars(editor.func(), l, &loop_variance);
    let ivs = compute_iv_ranges(editor, l, ivs, &loop_condition);

    if has_canonical_iv(editor, l, &ivs).is_some() {
        return false;
    }

    let loop_bound_iv_phis = get_loop_condition_ivs(editor, l, &ivs, &loop_condition);

    let (loop_bound_ivs, _): (Vec<InductionVariable>, Vec<InductionVariable>) = ivs
        .into_iter()
        .partition(|f| loop_bound_iv_phis.contains(&f.phi()));


    // Assume there is only one loop bound iv.
    if loop_bound_ivs.len() != 1 {
        return false;
    }

    let Some(iv) = loop_bound_ivs.first() else {
        return false;
    };

    let InductionVariable::Basic {
        node: iv_phi,
        initializer,
        final_value,
        update_expression,
        update_value,
    } = iv
    else {
        return false;
    };


    let Some(loop_pred) = editor
        .get_uses(l.header)
        .filter(|node| !l.control[node.idx()])
        .next()
    else {
        return false;
    };

    // If there is a guard, we need to edit it.

    // (init_id, bound_id, binop node, if node).

    // FIXME: This is not always correct, depends on lots of things about the loop IV. 
    let loop_bound_dc = match *editor.node(condition_node) {
        Node::Binary { left, right, op } => match op {
            BinaryOperator::LT => right,
            BinaryOperator::LTE => right,
            BinaryOperator::GT => {return false}
            BinaryOperator::GTE => {return false}
            BinaryOperator::EQ => {return false}
            BinaryOperator::NE => {return false}
            _ => {return false}
        },
        _ => {return false}
    };

    
    // FIXME: This is quite fragile.
    let mut guard_info: Option<(NodeID, NodeID, NodeID, NodeID)> = (|| {
        let Node::ControlProjection {
            control,
            selection: _,
        } = editor.node(loop_pred)
        else {
            return None;
        };

        let Node::If { cond, ..} = editor.node(control) else {
            return None;
        };

        let Node::Binary { left, right, op } = editor.node(cond) else {
            return None;
        };

        let Node::Binary {
            left: _,
            right: r,
            op: loop_op,
        } = editor.node(condition_node)
        else {
            return None;
        };

        if op != loop_op {
            return None;
        }

        if left != initializer {
            return None;
        }

        if right != r {
            return None;
        }

        return Some((*left, *right, *cond, *control));
    })();

    // // If guard is none, if some, make sure it is a good guard! move on
    // if let Some((init_id, bound_id, binop_node, if_node))= potential_guard_info {

    // };

    // let fork_guard_condition =

    // Lift dc math should make all constant into DCs, so these should all be DCs.
    let Node::DynamicConstant { id: init_dc_id } = *editor.node(initializer) else {
        return false;
    };
    let Node::DynamicConstant { id: update_dc_id } = *editor.node(update_value) else {
        return false;
    };

    // We are assuming this is a simple loop bound (i.e only one induction variable involved), so that .
    let Node::DynamicConstant {
        id: loop_bound_dc_id,
    } = *editor.node(loop_bound_dc)
    else {
        return false;
    };

    // We need to do 4 (5) things, which are mostly separate.

    // 0) Make the update into addition.
    // 1) Adjust update to be 1 (and bounds).
    // 2) Make the update a positive value. / Transform the condition into a `<`
    // - Are these separate?
    // 4) Change init to start from 0.

    // 5) Find some way to get fork-guard-elim to work with the new fork.
    // ideally, this goes in fork-guard-elim, but for now we hack it to change the guard condition bounds
    // here when we edit the loop bounds.

    // Right now we are just going to do (4), because I am lazy!

    // Collect info about the loop condition transformation.
    let mut dc_bound_node = match *editor.node(condition_node) {
        Node::Binary { left, right, op } => match op {
            BinaryOperator::LT => {
                if left == *update_expression && editor.node(right).is_dynamic_constant() {
                    right
                } else {
                    return false;
                }
            }
            BinaryOperator::LTE => {
                if left == *update_expression && editor.node(right).is_dynamic_constant() {
                    right
                } else {
                    return false;
                }
            }
            BinaryOperator::GT => todo!(),
            BinaryOperator::GTE => todo!(),
            BinaryOperator::EQ => todo!(),
            BinaryOperator::NE => todo!(),
            BinaryOperator::Or => todo!(),
            BinaryOperator::And => todo!(),
            BinaryOperator::Xor => todo!(),
            _ => panic!(),
        },
        _ => return false,
    };

    let condition_node_data = editor.node(condition_node).clone();

    let Node::DynamicConstant {
        id: mut bound_node_dc_id,
    } = *editor.node(dc_bound_node)
    else {
        return false;
    };

    // If increment is negative (how in the world do we know that...)
    // Increment can be DefinetlyPostiive, Unknown, DefinetlyNegative.
    let misc_guard_thing: Option<Node> =  if let Some((init_id, bound_id, binop_node, if_node)) = guard_info {
        Some(editor.node(binop_node).clone())
    } else {
        None
    };

    let mut condition_node = condition_node;

    let result = editor.edit(|mut edit| {
        // 2) Transform the condition into a < (from <=)
        if let Node::Binary { left, right, op } = condition_node_data {
            if BinaryOperator::LTE == op && left == *update_expression {
                // Change the condition into <
                let new_bop = edit.add_node(Node::Binary { left, right, op: BinaryOperator::LT });
                
                // Change the bound dc to be bound_dc + 1
                let one = DynamicConstant::Constant(1);
                let one = edit.add_dynamic_constant(one);

                let tmp = DynamicConstant::add(bound_node_dc_id, one);
                let new_condition_dc = edit.add_dynamic_constant(tmp);

                let new_dc_bound_node = edit.add_node(Node::DynamicConstant { id: new_condition_dc });

                // // 5) Change loop guard:
                guard_info = if let Some((init_id, bound_id, binop_node, if_node)) = guard_info {
                    // Change binop node
                    let Some(Node::Binary { left, right, op }) =  misc_guard_thing else {unreachable!()};
                    let blah = edit.add_node(Node::DynamicConstant { id:  new_condition_dc});

                    // FIXME: Don't assume that right is the loop bound in the guard. 
                    let new_binop_node = edit.add_node(Node::Binary { left, right: blah, op: BinaryOperator::LT });

                    edit = edit.replace_all_uses_where(binop_node, new_binop_node, |usee| *usee == if_node)?;
                    Some((init_id, bound_id, new_binop_node, if_node))
                } else {guard_info};

                edit = edit.replace_all_uses_where(dc_bound_node, new_dc_bound_node, |usee| *usee == new_bop)?;
                edit = edit.replace_all_uses(condition_node, new_bop)?;

                // Change loop condition
                dc_bound_node = new_dc_bound_node;
                bound_node_dc_id = new_condition_dc;
                condition_node = new_bop;
            }
        };
        Ok(edit)
    });

    
    let update_expr_users: Vec<_> = editor
        .get_users(*update_expression)
        .filter(|node| *node != iv.phi() && *node != condition_node)
        .collect();
    let iv_phi_users: Vec<_> = editor
        .get_users(iv.phi())
        .filter(|node| *node != iv.phi() && *node != *update_expression)
        .collect();

    let result = editor.edit(|mut edit| {
        // 4) Second, change loop IV to go from 0..N.
        // we subtract off init from init and dc_bound_node,
        // and then we add it back to uses of the IV.
        let new_init_dc = DynamicConstant::Constant(0);
        let new_init = Node::DynamicConstant {
            id: edit.add_dynamic_constant(new_init_dc),
        };
        let new_init = edit.add_node(new_init);
        edit = edit.replace_all_uses_where(*initializer, new_init, |usee| *usee == iv.phi())?;

        let new_condition_dc = DynamicConstant::sub(bound_node_dc_id, init_dc_id);
        let new_condition_dc_id = Node::DynamicConstant {
            id: edit.add_dynamic_constant(new_condition_dc),
        };
        let new_condition_dc = edit.add_node(new_condition_dc_id);
        edit = edit
            .replace_all_uses_where(dc_bound_node, new_condition_dc, |usee| *usee == condition_node)?;

        // 5) Change loop guard:
        if let Some((init_id, bound_id, binop_node, if_node)) = guard_info {
            edit = edit.replace_all_uses_where(init_id, new_init, |usee| *usee == binop_node)?;
            edit =
                edit.replace_all_uses_where(bound_id, new_condition_dc, |usee| *usee == binop_node)?;
        }
        

        // 4) Add the offset back to users of the IV update expression
        let new_user = Node::Binary {
            left: *update_expression,
            right: *initializer,
            op: BinaryOperator::Add,
        };
        let new_user = edit.add_node(new_user);
        edit = edit.replace_all_uses_where(*update_expression, new_user, |usee| {
            *usee != iv.phi()
                && *usee != *update_expression
                && *usee != new_user
                && *usee != condition_node
        })?;

        // Add the offset back to users of the IV directly
        let new_user = Node::Binary {
            left: *iv_phi,
            right: *initializer,
            op: BinaryOperator::Add,
        };
        let new_user = edit.add_node(new_user);
        edit = edit.replace_all_uses_where(*iv_phi, new_user, |usee| {
            *usee != iv.phi() && *usee != *update_expression && *usee != new_user
        })?;

        Ok(edit)
    });

    return result;
}