use std::collections::{HashMap, HashSet};
use std::fmt::{Error, Write};

use hercules_ir::*;

//use egg::*;

use crate::*;

/// Entry point for the math-expression rewriting pass.
///
/// The pass is currently disabled. The `egg`-based implementation it would run
/// (which rewrites reductions into library calls such as GEMM) is preserved in
/// the block comment that follows this function. It is not compiled because
/// depending on `egg` increases compile times.
///
/// # Panics
///
/// Always panics. Callers must not schedule this pass until it is re-enabled.
// The parameters are kept so the signature matches the disabled implementation
// and existing callers keep compiling. While the body only panics they are
// intentionally unused, so suppress the lint locally rather than renaming them.
#[allow(unused_variables)]
pub fn rewrite_math_expressions(
    editor: &mut FunctionEditor,
    device: Device,
    typing: &Vec<TypeID>,
    fork_join_map: &HashMap<NodeID, NodeID>,
    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
    reduce_einsums: &(MathEnv, HashMap<NodeID, MathID>),
) {
    panic!("PANIC: The rewrite math expressions pass is currently disabled, as including egg increases compile times and we're not using it currently.");
}

/*
define_language! {
    enum MathLanguage {
        "zero" = Zero,
        "one" = One,

        ForkDim(i64),
        "tid" = ThreadID(Id),
        "sum" = SumReduction(Box<[Id]>),
        "array" = Comprehension(Box<[Id]>),

        "read" = Read(Box<[Id]>),

        "+" = Add([Id; 2]),
        "*" = Mul([Id; 2]),

        "library_gemm" = LibraryGemm([Id; 2]),

        Opaque(Symbol),
    }
}

fn make_rules() -> Vec<Rewrite<MathLanguage, ()>> {
    vec![
        rewrite!("add-zero"; "(+ zero ?a)" => "?a"),
        rewrite!("mul-zero"; "(* zero ?a)" => "zero"),
        rewrite!("mul-one"; "(* one ?a)" => "?a"),
        rewrite!("library-gemm"; "(array ?i ?k (sum ?j (* (read ?i ?j ?A) (read ?j ?k ?B))))" => "(library_gemm ?A ?B)"),
    ]
}

pub fn rewrite_math_expressions(
    editor: &mut FunctionEditor,
    device: Device,
    typing: &Vec<TypeID>,
    fork_join_map: &HashMap<NodeID, NodeID>,
    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
    reduce_einsums: &(MathEnv, HashMap<NodeID, MathID>),
) {
    let join_fork_map: HashMap<_, _> = fork_join_map
        .into_iter()
        .map(|(fork, join)| (*join, *fork))
        .collect();

    // Step 1: figure out how many fork-joins each reduce is in. We want to
    // rewrite outer reductions before inner ones.
    let mut depth: HashMap<NodeID, u32> = HashMap::new();
    for (_, inside) in nodes_in_fork_joins {
        for id in inside {
            if editor.func().nodes[id.idx()].is_reduce() {
                *depth.entry(*id).or_default() += 1;
            }
        }
    }
    let mut reduces: Vec<(u32, NodeID)> =
        depth.into_iter().map(|(id, depth)| (depth, id)).collect();
    reduces.sort();

    for (_, reduce) in reduces {
        let (join, init, _) = editor.func().nodes[reduce.idx()].try_reduce().unwrap();
        let fork = join_fork_map[&join];

        // Step 2: convert the reduce to an expression in egg and rewrite it.
        let mut s = String::new();
        egg_print_math_expr(reduce_einsums.1[&reduce], &reduce_einsums.0, &mut s).unwrap();

        let expr: RecExpr<MathLanguage> = s.parse().unwrap();
        let egraph = Runner::default().with_expr(&expr).run(&make_rules()).egraph;

        // Step 3: search the saturated e-graph for nodes matching patterns of
        // known library functions (currently only GEMM).
        let gemm_pattern: Pattern<MathLanguage> = "(library_gemm ?A ?B)".parse().unwrap();
        let mut matches = gemm_pattern.search(&egraph);
        if matches.len() > 0
            && let m = matches.remove(0)
            && m.substs.len() > 0
        {
            let left_id = get_node_ids_from_subst("?A", &m.substs[0], &egraph);
            let right_id = get_node_ids_from_subst("?B", &m.substs[0], &egraph);
            let call = Node::LibraryCall {
                library_function: LibraryFunction::GEMM,
                args: vec![init, left_id, right_id].into_boxed_slice(),
                ty: typing[reduce.idx()],
                device,
            };
            let success = editor.edit(|mut edit| {
                let call = edit.add_node(call);
                edit.replace_all_uses_where(reduce, call, |id| {
                    !nodes_in_fork_joins[&fork].contains(id)
                })
            });
            if success {
                return;
            }
        }
    }
}

fn get_node_ids_from_subst(var: &str, subst: &Subst, egraph: &EGraph<MathLanguage, ()>) -> NodeID {
    let id = *subst.get(var.parse::<Var>().unwrap()).unwrap();
    let expr = egraph.id_to_expr(id);
    let MathLanguage::Opaque(sym) = expr.last().unwrap() else {
        todo!();
    };
    let sym = sym.as_str();
    assert!(sym.chars().nth(0).unwrap() == 'n');
    NodeID::new(sym[1..].parse().unwrap())
}

fn egg_print_math_expr<W: Write>(id: MathID, env: &MathEnv, w: &mut W) -> Result<(), Error> {
    match env[id.idx()] {
        MathExpr::Zero(_) => write!(w, "zero"),
        MathExpr::One(_) => write!(w, "one"),
        MathExpr::OpaqueNode(id) => write!(w, "n{}", id.idx()),
        MathExpr::ThreadID(dim) => write!(w, "{}", dim.0),
        MathExpr::SumReduction(id, ref dims) => {
            write!(w, "(sum")?;
            for dim in dims {
                write!(w, " {}", dim.0)?;
            }
            write!(w, " ")?;
            egg_print_math_expr(id, env, w)?;
            write!(w, ")")
        }
        MathExpr::Comprehension(id, ref dims) => {
            write!(w, "(array")?;
            for dim in dims {
                write!(w, " {}", dim.0)?;
            }
            write!(w, " ")?;
            egg_print_math_expr(id, env, w)?;
            write!(w, ")")
        }
        MathExpr::Read(id, ref pos) => {
            write!(w, "(read")?;
            for pos in pos {
                write!(w, " ")?;
                egg_print_math_expr(*pos, env, w)?;
            }
            write!(w, " ")?;
            egg_print_math_expr(id, env, w)?;
            write!(w, ")")
        }
        MathExpr::Binary(op, left, right) => {
            write!(w, "(")?;
            match op {
                BinaryOperator::Add => write!(w, "+ "),
                BinaryOperator::Mul => write!(w, "* "),
                _ => Err(Error::default()),
            }?;
            egg_print_math_expr(left, env, w)?;
            write!(w, " ")?;
            egg_print_math_expr(right, env, w)?;
            write!(w, ")")
        }
        _ => Err(Error::default()),
    }
}
*/