use std::collections::{BTreeSet, HashMap, HashSet};
use std::iter::once;

use hercules_ir::ir::*;

use crate::*;

/*
 * Infer parallel fork-joins. These are fork-joins where every reduction
 * variable is parallel, that is, every Reduce node on the join carries the
 * ParallelReduce schedule.
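 *
 * A hypothetical sketch of the pattern (node names are illustrative only):
 *
 *   fork = fork(pred, 16)
 *   join = join(body)
 *   red  = reduce(join, init, reduct)   // carries ParallelReduce
 *
 * Once every Reduce on `join` is ParallelReduce, `fork` receives the
 * ParallelFork schedule.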
 */
pub fn infer_parallel_fork(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>) {
    for id in editor.node_ids() {
        let func = editor.func();
        if !func.nodes[id.idx()].is_fork() {
            continue;
        }
        let join_id = fork_join_map[&id];
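        // Users of the join are its Reduce nodes plus control successors; the
        // fork is parallel only if every such Reduce already carries the
        // ParallelReduce schedule.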
        let all_parallel_reduce = editor.get_users(join_id).all(|user| {
            func.schedules[user.idx()].contains(&Schedule::ParallelReduce)
                || func.nodes[user.idx()].is_control()
        });
        if all_parallel_reduce {
            editor.edit(|edit| edit.add_schedule(id, Schedule::ParallelFork));
        }
    }
}

/*
 * Infer parallel reductions consisting of a simple cycle between a Reduce
 * node and a Write node, where the indices of the Write are position indices
 * equal to the ThreadID nodes attached to the corresponding Fork, and the
 * data written is not in the Reduce node's cycle. This procedure also adds
 * the ParallelReduce schedule to Reduce nodes reducing over an already
 * parallelized Reduce, as long as the base Write node is also positionally
 * indexed by the ThreadIDs of the outer fork. In other words, the complete
 * Reduce chain is annotated with ParallelReduce, as long as each ThreadID
 * dimension appears in the positional indexing of the original Write.
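 *
 * A hypothetical sketch of the tight cycle being matched (node names are
 * illustrative only):
 *
 *   tid = thread_id(fork, 0)
 *   red = reduce(join, init, w)
 *   w   = write(red, data, [position(tid)])   // `data` not in red's cycle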
 */
pub fn infer_parallel_reduce(
    editor: &mut FunctionEditor,
    fork_join_map: &HashMap<NodeID, NodeID>,
    reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
) {
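    // Invert the fork-join map, so that the fork corresponding to a join can
    // be looked up from the join a Reduce node uses as its control.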
    let join_fork_map: HashMap<_, _> = fork_join_map
        .iter()
        .map(|(fork, join)| (*join, *fork))
        .collect();
    for id in editor.node_ids() {
        let func = editor.func();
        if !func.nodes[id.idx()].is_reduce() {
            continue;
        }

        let mut first_control = None;
        let mut last_reduce = id;
        let mut chain_id = id;

        // Walk down the Reduce chain until we reach the Reduce potentially
        // forming a cycle with the Write. Note the control node (the join) of
        // the first Reduce, since its corresponding fork tells us which
        // ThreadIDs to look for in the Write.
        while let Node::Reduce {
            control,
            init: _,
            reduct,
        } = func.nodes[chain_id.idx()]
        {
            if first_control.is_none() {
                first_control = Some(control);
            }

            last_reduce = chain_id;
            chain_id = reduct;
        }
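        // At this point, `last_reduce` is the last Reduce in the chain and
        // `chain_id` is its reduct, the candidate Write (or a Phi selecting
        // between the Reduce and a Write).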

        // If the reduct is a Phi whose data inputs are all the last Reduce
        // except for exactly one other value, that value may be the Write we
        // are looking for, so continue the search from it.
        if let Node::Phi {
            control: _,
            ref data,
        } = func.nodes[chain_id.idx()]
            && data.len()
                == data
                    .iter()
                    .filter(|phi_use| **phi_use == last_reduce)
                    .count()
                    + 1
        {
            chain_id = *data
                .iter()
                .find(|phi_use| **phi_use != last_reduce)
                .unwrap();
        }

        // Check for a Write-Reduce tight cycle.
        if let Node::Write {
            collect,
            data,
            indices,
        } = &func.nodes[chain_id.idx()]
            && *collect == last_reduce
            && !reduce_cycles[&last_reduce].contains(data)
        {
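            // The Write must be positionally indexed by the ThreadIDs of the
            // fork corresponding to this chain's join, so that each thread
            // writes a disjoint part of the collection.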
            let is_parallel = indices_parallel_over_forks(
                editor,
                indices,
                once(join_fork_map[&first_control.unwrap()]),
            );

            if is_parallel {
                editor.edit(|edit| edit.add_schedule(id, Schedule::ParallelReduce));
            }
        }
    }
}

/*
 * Infer monoid reduction loops. Exactly one of the associative operation's
 * operands must be the Reduce node itself, and no other operand may be in the
 * Reduce node's cycle.
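 *
 * A hypothetical sketch (names are illustrative only):
 *
 *   red = reduce(join, zero, add)
 *   add = red + x               // `x` is outside red's cycle
 *
 * Addition is associative with identity 0, so `red` is a monoid reduction.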
 */
#[rustfmt::skip]
pub fn infer_monoid_reduce(
    editor: &mut FunctionEditor,
    reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
) {
    let is_binop_monoid = |op| {
        op == BinaryOperator::Add
            || op == BinaryOperator::Mul
            || op == BinaryOperator::Or
            || op == BinaryOperator::And
    };
    let is_intrinsic_monoid =
        |intrinsic| intrinsic == Intrinsic::Max || intrinsic == Intrinsic::Min;
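    // Max and Min are also associative and have identities (the type's
    // minimum and maximum values, respectively), so intrinsic calls to them
    // form monoids as well.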

    for id in editor.node_ids() {
        let func = editor.func();
        if let Node::Reduce {
            control: _,
            init: _,
            reduct,
        } = func.nodes[id.idx()]
            && (matches!(func.nodes[reduct.idx()], Node::Binary { left, right, op }
                if ((left == id && !reduce_cycles[&id].contains(&right)) ||
                    (right == id && !reduce_cycles[&id].contains(&left))) &&
                    is_binop_monoid(op))
                || matches!(&func.nodes[reduct.idx()], Node::IntrinsicCall { intrinsic, args }
                if (args.contains(&id) && is_intrinsic_monoid(*intrinsic) &&
                    args.iter().filter(|arg| **arg != id).all(|arg| !reduce_cycles[&id].contains(arg)))))
        {
            editor.edit(|edit| edit.add_schedule(id, Schedule::MonoidReduce));
        }
    }
}

/*
 * From the analysis result of which constants don't need to be reset, add the
 * NoResetConstant schedule to those constant nodes.
 */
pub fn infer_no_reset_constants(
    editor: &mut FunctionEditor,
    no_reset_constants: &BTreeSet<NodeID>,
) {
    for id in no_reset_constants {
        editor.edit(|edit| edit.add_schedule(*id, Schedule::NoResetConstant));
    }
}