use crate::ir;
use crate::parser;

use juno_utils::env::Env;
use juno_utils::stringtab::StringTable;

use hercules_ir::ir::{Device, Schedule};

use lrlex::DefaultLexerTypes;
use lrpar::NonStreamingLexer;

use std::fmt;
use std::str::FromStr;

/// A source range as returned by the lexer's `line_col`: the (line, column) of the start of a
/// span followed by the (line, column) of its end.
type Location = ((usize, usize), (usize, usize));

/// Errors reported while compiling a parsed schedule into schedule IR.
pub enum ScheduleCompilerError {
    /// A macro was invoked whose name is not bound in any enclosing scope.
    UndefinedMacro(String, Location),
    /// A function-call name matched no known pass, schedule, or device.
    NoSuchPass(String, Location),
    /// A pass or macro was called with the wrong number of arguments.
    IncorrectArguments {
        /// Human-readable description of the accepted argument count(s).
        expected: String,
        /// Number of arguments actually supplied.
        actual: usize,
        loc: Location,
    },
}

impl fmt::Display for ScheduleCompilerError {
    /// Formats the error as `(start_line, start_col) -- (end_line, end_col): message`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::UndefinedMacro(macro_name, ((sl, sc), (el, ec))) => write!(
                f,
                "({}, {}) -- ({}, {}): Undefined macro '{}'",
                sl, sc, el, ec, macro_name
            ),
            Self::NoSuchPass(pass_name, ((sl, sc), (el, ec))) => write!(
                f,
                "({}, {}) -- ({}, {}): Undefined pass '{}'",
                sl, sc, el, ec, pass_name
            ),
            Self::IncorrectArguments {
                expected,
                actual,
                loc: ((sl, sc), (el, ec)),
            } => write!(
                f,
                "({}, {}) -- ({}, {}): Expected {} arguments, found {}",
                sl, sc, el, ec, expected, actual
            ),
        }
    }
}

/// Compiles a parsed schedule into a single top-level `ir::ScheduleStmt::Block`.
///
/// `lexer` is used to recover the source text and positions of spans in the parse tree.
///
/// # Errors
/// Returns the first [`ScheduleCompilerError`] encountered while compiling the statements.
pub fn compile_schedule(
    sched: parser::OperationList,
    lexer: &dyn NonStreamingLexer<DefaultLexerTypes<u32>>,
) -> Result<ir::ScheduleStmt, ScheduleCompilerError> {
    // Macro names are interned to integer ids, which key the macro environment.
    let mut macro_names = StringTable::new();
    let mut macro_env = Env::new();

    // Outermost macro scope; top-level macro declarations are inserted into it.
    macro_env.open_scope();

    let body = compile_ops_as_block(sched, lexer, &mut macro_names, &mut macro_env)?;
    Ok(ir::ScheduleStmt::Block { body })
}

/// A compiled macro definition: recorded when a macro declaration is compiled and expanded
/// (cloned) at every invocation of that macro.
#[derive(Debug, Clone)]
struct MacroInfo {
    /// Parameter names in declaration order; an invocation must supply exactly this many
    /// arguments.
    params: Vec<String>,
    /// Variable name the invocation's selection is bound to before the body runs.
    selection_name: String,
    /// The compiled macro body; it becomes the result of the block an invocation expands to.
    def: ir::ScheduleExp,
}

/// Anything a schedule can invoke by name with function-call syntax: a pass, the special
/// `delete-uncalled` operation, a schedule annotation, or a device annotation.
enum Appliable {
    // A pass; compiled to `ir::ScheduleExp::RunPass`.
    Pass(ir::Pass),
    // DeleteUncalled requires special handling because it changes FunctionIDs, so it is not
    // treated like a pass
    DeleteUncalled,
    // A schedule annotation; compiled to the statement `ir::ScheduleStmt::AddSchedule`.
    Schedule(Schedule),
    // A device annotation; compiled to the statement `ir::ScheduleStmt::AddDevice`.
    Device(Device),
}

impl Appliable {
    /// Returns whether `num` is an acceptable number of arguments for this appliable.
    ///
    /// Passes decide for themselves. `delete-uncalled`, schedules, and devices take no
    /// arguments.
    fn is_valid_num_args(&self, num: usize) -> bool {
        match self {
            Appliable::Pass(pass) => pass.is_valid_num_args(num),
            // Variants are listed explicitly (no `_` arm) so that adding a new variant
            // forces this decision to be revisited.
            Appliable::DeleteUncalled | Appliable::Schedule(_) | Appliable::Device(_) => {
                num == 0
            }
        }
    }

    /// Returns a human-readable description of the accepted argument count(s), used in
    /// `IncorrectArguments` error messages.
    fn valid_arg_nums(&self) -> &'static str {
        match self {
            Appliable::Pass(pass) => pass.valid_arg_nums(),
            Appliable::DeleteUncalled | Appliable::Schedule(_) | Appliable::Device(_) => "0",
        }
    }
}

impl FromStr for Appliable {
    type Err = String;

    /// Resolves a name to the pass, schedule, or device it denotes, accepting every listed
    /// alias. Matching is exact; `compile_expr` lower-cases names before calling this.
    ///
    /// On failure, the unrecognized name itself is returned as the error.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Passes (and their aliases).
        fn pass_named(name: &str) -> Option<ir::Pass> {
            let pass = match name {
                "array-slf" => ir::Pass::ArraySLF,
                "array-to-product" | "array-to-prod" | "a2p" => ir::Pass::ArrayToProduct,
                "auto-outline" => ir::Pass::AutoOutline,
                "ccp" => ir::Pass::CCP,
                "crc" | "collapse-read-chains" => ir::Pass::CRC,
                "clean-monoid-reduces" => ir::Pass::CleanMonoidReduces,
                "const-inline" => ir::Pass::ConstInline,
                "dce" => ir::Pass::DCE,
                "float-collections" | "collections" => ir::Pass::FloatCollections,
                "fork-guard-elim" => ir::Pass::ForkGuardElim,
                "fork-split" => ir::Pass::ForkSplit,
                "forkify" => ir::Pass::Forkify,
                "gcm" | "bbs" => ir::Pass::GCM,
                "gvn" => ir::Pass::GVN,
                "infer-schedules" => ir::Pass::InferSchedules,
                "inline" => ir::Pass::Inline,
                "ip-sroa" | "interprocedural-sroa" => ir::Pass::InterproceduralSROA,
                "fork-fission-bufferize" | "fork-fission" => ir::Pass::ForkFissionBufferize,
                "fork-dim-merge" => ir::Pass::ForkDimMerge,
                "fork-interchange" => ir::Pass::ForkInterchange,
                "fork-chunk" | "fork-tile" => ir::Pass::ForkChunk,
                "fork-extend" => ir::Pass::ForkExtend,
                "fork-unroll" | "unroll" => ir::Pass::ForkUnroll,
                "fork-fusion" | "fusion" => ir::Pass::ForkFusion,
                "fork-reshape" => ir::Pass::ForkReshape,
                "lift-dc-math" => ir::Pass::LiftDCMath,
                "loop-bound-canon" => ir::Pass::LoopBoundCanon,
                "outline" => ir::Pass::Outline,
                "phi-elim" => ir::Pass::PhiElim,
                "predication" => ir::Pass::Predication,
                "reduce-slf" => ir::Pass::ReduceSLF,
                "rename" => ir::Pass::Rename,
                "reuse-products" => ir::Pass::ReuseProducts,
                "rewrite" | "rewrite-math" | "rewrite-math-expressions" => {
                    ir::Pass::RewriteMathExpressions
                }
                "simplify-cfg" => ir::Pass::SimplifyCFG,
                "slf" | "store-load-forward" => ir::Pass::SLF,
                "sroa" => ir::Pass::SROA,
                "unforkify" => ir::Pass::Unforkify,
                "unforkify-one" => ir::Pass::UnforkifyOne,
                "fork-coalesce" => ir::Pass::ForkCoalesce,
                "verify" => ir::Pass::Verify,
                "xdot" => ir::Pass::Xdot,
                "serialize" => ir::Pass::Serialize,
                "write-predication" => ir::Pass::WritePredication,
                "print" => ir::Pass::Print,
                _ => return None,
            };
            Some(pass)
        }

        // Target devices.
        fn device_named(name: &str) -> Option<Device> {
            match name {
                "cpu" | "llvm" => Some(Device::LLVM),
                "gpu" | "cuda" | "nvidia" => Some(Device::CUDA),
                "host" | "rust" | "rust-async" => Some(Device::AsyncRust),
                _ => None,
            }
        }

        // Schedule annotations.
        fn schedule_named(name: &str) -> Option<Schedule> {
            match name {
                "monoid" | "associative" => Some(Schedule::MonoidReduce),
                "parallel-fork" => Some(Schedule::ParallelFork),
                "parallel-reduce" => Some(Schedule::ParallelReduce),
                "no-memset" | "no-reset" => Some(Schedule::NoResetConstant),
                "task-parallel" | "async-call" => Some(Schedule::AsyncCall),
                _ => None,
            }
        }

        // `delete-uncalled` is not a pass: it changes FunctionIDs and compiles to its own
        // IR node.
        if s == "delete-uncalled" {
            return Ok(Appliable::DeleteUncalled);
        }

        // The three name sets are disjoint, so lookup order does not affect the result.
        pass_named(s)
            .map(Appliable::Pass)
            .or_else(|| device_named(s).map(Appliable::Device))
            .or_else(|| schedule_named(s).map(Appliable::Schedule))
            .ok_or_else(|| s.to_string())
    }
}

/// Compiles a statement list into a flat sequence of IR statements, for use as a block body.
///
/// A trailing expression, if present, is compiled in statement position (see
/// `compile_exp_as_stmt`), so any value it produces is discarded.
///
/// # Errors
/// Returns the first error encountered. Statements are compiled in source order, so a macro
/// declaration is visible to every statement after it.
fn compile_ops_as_block(
    sched: parser::OperationList,
    lexer: &dyn NonStreamingLexer<DefaultLexerTypes<u32>>,
    macrostab: &mut StringTable,
    macros: &mut Env<usize, MacroInfo>,
) -> Result<Vec<ir::ScheduleStmt>, ScheduleCompilerError> {
    // Walk the cons-list iteratively, appending to one vector. A recursive walk would use
    // stack depth proportional to the statement count and would re-copy the compiled tail
    // at every level, which is quadratic overall.
    let mut res = vec![];
    let mut rest = sched;
    loop {
        match rest {
            parser::OperationList::NilStmt() => return Ok(res),
            parser::OperationList::FinalExpr(expr) => {
                res.push(compile_exp_as_stmt(expr, lexer, macrostab, macros)?);
                return Ok(res);
            }
            parser::OperationList::ConsStmt(stmt, ops) => {
                res.extend(compile_stmt(stmt, lexer, macrostab, macros)?);
                rest = *ops;
            }
        }
    }
}

/// Compiles one statement into zero or more IR statements.
///
/// * `let x = e` and `x = e` become a single `Let` / `Assign`.
/// * A multi-variable `let` binds the value to a temporary named from `macros.uniq()`, then
///   binds each variable to the corresponding tuple field of that temporary.
/// * An expression statement is compiled via `compile_exp_as_stmt`.
/// * A `fixpoint` compiles its body inside a fresh macro scope.
/// * A macro declaration emits no IR; it only records the compiled macro in `macros`.
///
/// # Errors
/// Propagates any error from compiling nested expressions, blocks, or macro bodies.
///
/// # Panics
/// If a fixpoint limit's text does not parse as an integer.
fn compile_stmt(
    stmt: parser::Stmt,
    lexer: &dyn NonStreamingLexer<DefaultLexerTypes<u32>>,
    macrostab: &mut StringTable,
    macros: &mut Env<usize, MacroInfo>,
) -> Result<Vec<ir::ScheduleStmt>, ScheduleCompilerError> {
    let stmts = match stmt {
        parser::Stmt::LetStmt { var, expr, .. } => {
            let name = lexer.span_str(var).to_string();
            let exp = compile_exp_as_expr(expr, lexer, macrostab, macros)?;
            vec![ir::ScheduleStmt::Let { var: name, exp }]
        }
        parser::Stmt::LetsStmt { vars, expr, .. } => {
            // Evaluate the right-hand side once, then project each field out of it.
            let tuple_var = format!("{}_tmp", macros.uniq());
            let tuple_exp = compile_exp_as_expr(expr, lexer, macrostab, macros)?;
            let mut out = vec![ir::ScheduleStmt::Let {
                var: tuple_var.clone(),
                exp: tuple_exp,
            }];
            for (field, v) in vars.into_iter().enumerate() {
                out.push(ir::ScheduleStmt::Let {
                    var: lexer.span_str(v).to_string(),
                    exp: ir::ScheduleExp::TupleField {
                        lhs: Box::new(ir::ScheduleExp::Variable {
                            var: tuple_var.clone(),
                        }),
                        field,
                    },
                });
            }
            out
        }
        parser::Stmt::AssignStmt { var, rhs, .. } => {
            let name = lexer.span_str(var).to_string();
            let exp = compile_exp_as_expr(rhs, lexer, macrostab, macros)?;
            vec![ir::ScheduleStmt::Assign { var: name, exp }]
        }
        parser::Stmt::ExprStmt { exp, .. } => {
            vec![compile_exp_as_stmt(exp, lexer, macrostab, macros)?]
        }
        parser::Stmt::Fixpoint { limit, body, .. } => {
            let limit = match limit {
                parser::FixpointLimit::NoLimit { .. } => ir::FixpointLimit::NoLimit(),
                parser::FixpointLimit::PrintIter { .. } => ir::FixpointLimit::PrintIter(),
                parser::FixpointLimit::StopAfter { limit, .. } => ir::FixpointLimit::StopAfter(
                    lexer
                        .span_str(limit)
                        .parse()
                        .expect("Parsing ensures integer"),
                ),
                parser::FixpointLimit::PanicAfter { limit, .. } => {
                    ir::FixpointLimit::PanicAfter(
                        lexer
                            .span_str(limit)
                            .parse()
                            .expect("Parsing ensures integer"),
                    )
                }
            };

            // Macros declared in the loop body must not escape it. The scope is closed
            // before `?` so that an error does not leave it open.
            macros.open_scope();
            let compiled = compile_ops_as_block(*body, lexer, macrostab, macros);
            macros.close_scope();

            vec![ir::ScheduleStmt::Fixpoint {
                body: Box::new(ir::ScheduleStmt::Block { body: compiled? }),
                limit,
            }]
        }
        parser::Stmt::MacroDecl { def, .. } => {
            let parser::MacroDecl {
                name,
                params,
                selection_name,
                def: macro_body,
            } = def;
            let macro_id = macrostab.lookup_string(lexer.span_str(name).to_string());
            let selection_name = lexer.span_str(selection_name).to_string();
            let params: Vec<String> = params
                .into_iter()
                .map(|p| lexer.span_str(p).to_string())
                .collect();

            // The macro is registered only after its body compiles, so it cannot refer to
            // itself.
            let info = compile_macro_def(
                *macro_body,
                params,
                selection_name,
                lexer,
                macrostab,
                macros,
            )?;
            macros.insert(macro_id, info);
            vec![]
        }
    };
    Ok(stmts)
}

/// Compiles an expression that appears in statement position.
///
/// Constructs that already compile to a statement (such as schedule or device annotations,
/// or a block with no trailing expression) are returned unchanged. Any other expression is
/// evaluated and its value bound to `_`.
///
/// # Errors
/// Propagates any error from `compile_expr`.
fn compile_exp_as_stmt(
    expr: parser::Expr,
    lexer: &dyn NonStreamingLexer<DefaultLexerTypes<u32>>,
    macrostab: &mut StringTable,
    macros: &mut Env<usize, MacroInfo>,
) -> Result<ir::ScheduleStmt, ScheduleCompilerError> {
    let compiled = compile_expr(expr, lexer, macrostab, macros)?;
    let stmt = match compiled {
        ExprResult::Stmt(stm) => stm,
        ExprResult::Expr(exp) => ir::ScheduleStmt::Let {
            var: String::from("_"),
            exp,
        },
    };
    Ok(stmt)
}

/// Compiles an expression that appears in value position.
///
/// If the construct compiles to a statement, it is wrapped in a block whose result is the
/// empty record.
///
/// # Errors
/// Propagates any error from `compile_expr`.
fn compile_exp_as_expr(
    expr: parser::Expr,
    lexer: &dyn NonStreamingLexer<DefaultLexerTypes<u32>>,
    macrostab: &mut StringTable,
    macros: &mut Env<usize, MacroInfo>,
) -> Result<ir::ScheduleExp, ScheduleCompilerError> {
    let exp = match compile_expr(expr, lexer, macrostab, macros)? {
        ExprResult::Expr(exp) => exp,
        ExprResult::Stmt(stm) => {
            let empty_record = ir::ScheduleExp::Record { fields: vec![] };
            ir::ScheduleExp::Block {
                body: vec![stm],
                res: Box::new(empty_record),
            }
        }
    };
    Ok(exp)
}

/// Outcome of compiling an expression. Some constructs (schedule and device annotations,
/// blocks without a trailing expression) compile to a statement instead of a value.
enum ExprResult {
    // A value-producing IR expression.
    Expr(ir::ScheduleExp),
    // An IR statement; callers needing a value wrap it in a block returning the empty record.
    Stmt(ir::ScheduleStmt),
}

fn compile_expr(
    expr: parser::Expr,
    lexer: &dyn NonStreamingLexer<DefaultLexerTypes<u32>>,
    macrostab: &mut StringTable,
    macros: &mut Env<usize, MacroInfo>,
) -> Result<ExprResult, ScheduleCompilerError> {
    match expr {
        parser::Expr::Function {
            span,
            name,
            args,
            selection,
        } => {
            let func: Appliable = lexer
                .span_str(name)
                .to_lowercase()
                .parse()
                .map_err(|s| ScheduleCompilerError::NoSuchPass(s, lexer.line_col(name)))?;

            if !func.is_valid_num_args(args.len()) {
                return Err(ScheduleCompilerError::IncorrectArguments {
                    expected: func.valid_arg_nums().to_string(),
                    actual: args.len(),
                    loc: lexer.line_col(span),
                });
            }

            let mut arg_vals = vec![];
            for arg in args {
                arg_vals.push(compile_exp_as_expr(arg, lexer, macrostab, macros)?);
            }

            let selection = compile_selector(selection, lexer, macrostab, macros)?;

            match func {
                Appliable::Pass(pass) => Ok(ExprResult::Expr(ir::ScheduleExp::RunPass {
                    pass,
                    args: arg_vals,
                    on: selection,
                })),
                Appliable::DeleteUncalled => {
                    Ok(ExprResult::Expr(ir::ScheduleExp::DeleteUncalled {
                        on: selection,
                    }))
                }
                Appliable::Schedule(sched) => Ok(ExprResult::Stmt(ir::ScheduleStmt::AddSchedule {
                    sched,
                    on: selection,
                })),
                Appliable::Device(device) => Ok(ExprResult::Stmt(ir::ScheduleStmt::AddDevice {
                    device,
                    on: selection,
                })),
            }
        }
        parser::Expr::Macro {
            span,
            name,
            args,
            selection,
        } => {
            let name_str = lexer.span_str(name).to_string();
            let macro_id = macrostab.lookup_string(name_str.clone());
            let Some(macro_def) = macros.lookup(&macro_id) else {
                return Err(ScheduleCompilerError::UndefinedMacro(
                    name_str,
                    lexer.line_col(name),
                ));
            };
            let macro_def: MacroInfo = macro_def.clone();
            let MacroInfo {
                params,
                selection_name,
                def,
            } = macro_def;

            if args.len() != params.len() {
                return Err(ScheduleCompilerError::IncorrectArguments {
                    expected: params.len().to_string(),
                    actual: args.len(),
                    loc: lexer.line_col(span),
                });
            }

            // To initialize the macro's arguments, we have to do this in two steps, we first
            // evaluate all of the arguments and store them into new variables, using names that
            // cannot conflict with other values in the program and then we assign those variables
            // to the macro's parameters; this avoids any shadowing issues, for instance:
            // macro![3, x] where macro!'s arguments are named x and y becomes
            //      let #0 = 3; let #1 = x; let x = #0; let y = #1;
            // which has the desired semantics, as opposed to
            //      let x = 3; let y = x;
            let mut arg_eval = vec![];
            let mut arg_setters = vec![];

            for (i, (exp, var)) in args.into_iter().zip(params.into_iter()).enumerate() {
                let tmp = format!("#{}", i);
                arg_eval.push(ir::ScheduleStmt::Let {
                    var: tmp.clone(),
                    exp: compile_exp_as_expr(exp, lexer, macrostab, macros)?,
                });
                arg_setters.push(ir::ScheduleStmt::Let {
                    var,
                    exp: ir::ScheduleExp::Variable { var: tmp },
                });
            }

            // Set the selection
            arg_eval.push(ir::ScheduleStmt::Let {
                var: selection_name,
                exp: ir::ScheduleExp::Selection {
                    selection: compile_selector(selection, lexer, macrostab, macros)?,
                },
            });

            // Combine the evaluation and initialization code
            arg_eval.extend(arg_setters);

            Ok(ExprResult::Expr(ir::ScheduleExp::Block {
                body: arg_eval,
                res: Box::new(def),
            }))
        }
        parser::Expr::Variable { span } => {
            let var = lexer.span_str(span).to_string();
            Ok(ExprResult::Expr(ir::ScheduleExp::Variable { var }))
        }
        parser::Expr::Integer { span } => {
            let val: usize = lexer.span_str(span).parse().expect("Parsing");
            Ok(ExprResult::Expr(ir::ScheduleExp::Integer { val }))
        }
        parser::Expr::Boolean { span: _, val } => {
            Ok(ExprResult::Expr(ir::ScheduleExp::Boolean { val }))
        }
        parser::Expr::String { span } => {
            let string = lexer.span_str(span);
            let val = string[1..string.len() - 1].to_string();
            Ok(ExprResult::Expr(ir::ScheduleExp::String { val }))
        }
        parser::Expr::Field {
            span: _,
            lhs,
            field,
        } => {
            let field = lexer.span_str(field).to_string();
            let lhs = compile_exp_as_expr(*lhs, lexer, macrostab, macros)?;
            Ok(ExprResult::Expr(ir::ScheduleExp::Field {
                collect: Box::new(lhs),
                field,
            }))
        }
        parser::Expr::BlockExpr { span: _, body } => {
            compile_ops_as_expr(*body, lexer, macrostab, macros)
        }
        parser::Expr::Record { span: _, fields } => {
            let mut result = vec![];
            for (name, expr) in fields {
                let name = lexer.span_str(name).to_string();
                let expr = compile_exp_as_expr(expr, lexer, macrostab, macros)?;
                result.push((name, expr));
            }
            Ok(ExprResult::Expr(ir::ScheduleExp::Record { fields: result }))
        }
        parser::Expr::SetOp {
            span: _,
            op,
            lhs,
            rhs,
        } => {
            let lhs = compile_exp_as_expr(*lhs, lexer, macrostab, macros)?;
            let rhs = compile_exp_as_expr(*rhs, lexer, macrostab, macros)?;
            Ok(ExprResult::Expr(ir::ScheduleExp::SetOp {
                op,
                lhs: Box::new(lhs),
                rhs: Box::new(rhs),
            }))
        }
        parser::Expr::Tuple { span: _, exps } => {
            let exprs = exps
                .into_iter()
                .map(|e| compile_exp_as_expr(e, lexer, macrostab, macros))
                .fold(Ok(vec![]), |mut res, exp| {
                    let mut res = res?;
                    res.push(exp?);
                    Ok(res)
                })?;
            Ok(ExprResult::Expr(ir::ScheduleExp::Tuple { exprs }))
        }
        parser::Expr::TupleField {
            span: _,
            lhs,
            field,
        } => {
            let lhs = compile_exp_as_expr(*lhs, lexer, macrostab, macros)?;
            let field = lexer.span_str(field).parse().expect("Parsing");
            Ok(ExprResult::Expr(ir::ScheduleExp::TupleField {
                lhs: Box::new(lhs),
                field,
            }))
        }
    }
}

/// Compiles a statement list as a block in value position.
///
/// If the list ends with a trailing expression, the result is an `ir::ScheduleExp::Block`
/// whose value is that expression. Otherwise the result is an `ir::ScheduleStmt::Block`
/// statement.
///
/// # Errors
/// Returns the first error encountered, compiling statements in source order.
fn compile_ops_as_expr(
    sched: parser::OperationList,
    lexer: &dyn NonStreamingLexer<DefaultLexerTypes<u32>>,
    macrostab: &mut StringTable,
    macros: &mut Env<usize, MacroInfo>,
) -> Result<ExprResult, ScheduleCompilerError> {
    let mut stmts = Vec::new();
    let mut cursor = sched;

    // Consume leading statements; stop at the end of the list, yielding the trailing
    // expression if there is one.
    let trailing = loop {
        match cursor {
            parser::OperationList::ConsStmt(stmt, rest) => {
                stmts.extend(compile_stmt(stmt, lexer, macrostab, macros)?);
                cursor = *rest;
            }
            parser::OperationList::NilStmt() => break None,
            parser::OperationList::FinalExpr(expr) => break Some(expr),
        }
    };

    match trailing {
        None => Ok(ExprResult::Stmt(ir::ScheduleStmt::Block { body: stmts })),
        Some(expr) => {
            let value = compile_exp_as_expr(expr, lexer, macrostab, macros)?;
            Ok(ExprResult::Expr(ir::ScheduleExp::Block {
                body: stmts,
                res: Box::new(value),
            }))
        }
    }
}

/// Compiles a selector: `SelectAll` becomes `ir::Selector::Everything()`. Otherwise each
/// listed expression is compiled in order into `ir::Selector::Selection`.
///
/// # Errors
/// Returns the first error from compiling a listed expression.
fn compile_selector(
    sel: parser::Selector,
    lexer: &dyn NonStreamingLexer<DefaultLexerTypes<u32>>,
    macrostab: &mut StringTable,
    macros: &mut Env<usize, MacroInfo>,
) -> Result<ir::Selector, ScheduleCompilerError> {
    let exprs = match sel {
        parser::Selector::SelectAll { .. } => return Ok(ir::Selector::Everything()),
        parser::Selector::SelectExprs { exprs, .. } => exprs,
    };

    let compiled = exprs
        .into_iter()
        .map(|exp| compile_exp_as_expr(exp, lexer, macrostab, macros))
        .collect::<Result<Vec<_>, _>>()?;
    Ok(ir::Selector::Selection(compiled))
}

/// Compiles the body of a macro declaration into a [`MacroInfo`].
///
/// The body is compiled as a block in value position. A body with no trailing expression
/// evaluates to the empty record. The body gets its own macro scope, so helper macros it
/// declares are not visible after the declaration.
///
/// # Errors
/// Propagates any error from compiling the body.
fn compile_macro_def(
    body: parser::OperationList,
    params: Vec<String>,
    selection_name: String,
    lexer: &dyn NonStreamingLexer<DefaultLexerTypes<u32>>,
    macrostab: &mut StringTable,
    macros: &mut Env<usize, MacroInfo>,
) -> Result<MacroInfo, ScheduleCompilerError> {
    // FIXME: The body should be checked in an environment that prohibits running anything on
    // everything (*) and check that only local variables/parameters are used

    // Macros declared inside the body are local to it, as for `fixpoint` bodies. The scope
    // is closed before `?` so that an error does not leave it open.
    macros.open_scope();
    let compiled = compile_ops_as_expr(body, lexer, macrostab, macros);
    macros.close_scope();

    let def = match compiled? {
        ExprResult::Expr(expr) => expr,
        ExprResult::Stmt(stmt) => ir::ScheduleExp::Block {
            body: vec![stmt],
            res: Box::new(ir::ScheduleExp::Record { fields: vec![] }),
        },
    };

    Ok(MacroInfo {
        params,
        selection_name,
        def,
    })
}