import os
from copy import deepcopy
from pathlib import Path
from shutil import move
from typing import Union

import distiller
import torch
import yaml
from distiller.data_loggers import collect_quant_stats
from distiller.quantization import PostTrainLinearQuantizer
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset

# Anything accepted as a filesystem path argument (e.g. `working_dir`).
PathLike = Union[str, Path]
# Activation statistics produced by distiller's `collect_quant_stats`; stored
# in the quantization working dir and reused if it already exists there.
STATS_FILENAME = "acts_quantization_stats.yaml"
# Per-layer quantization parameters that `PostTrainLinearQuantizer.prepare_model`
# writes to the current directory; moved into the working dir and read back to
# build the HPVM calibration file.
QUANT_FILENAME = "layer_quant_params.yaml"
# Also written to the current directory by `prepare_model`; not needed, deleted.
QUANT_AFTER_FILENAME = "quant_stats_after_prepare_model.yaml"
# Module type -> entry-name prefix used in the HPVM calibration file. Entries
# are numbered per prefix starting at 1 (e.g. "conv1", "conv2", "relu1").
# `nn.Parameter` is not a layer type here: it keys the extra "add" entry that is
# emitted after every conv / linear layer (NOTE(review): presumably the bias
# add — confirm against the HPVM side).
LAYER_HPVM_NAME = {
    nn.ReLU: "relu",
    nn.Linear: "gemm",
    nn.Conv2d: "conv",
    nn.MaxPool2d: "pool",
    nn.Softmax: "softmax",
    nn.Parameter: "add",
}
# Module type -> distiller layer-group name. Not referenced in this section.
# NOTE(review): "convs"/"fcs" presumably mirror attribute names of the model
# being quantized — confirm against the model definition.
LAYER_DISTILLER_NAME = {
    nn.ReLU: "softmax",  # All point-wise layers use softmax's scale!
    nn.Linear: "fcs",
    nn.Conv2d: "convs",
    nn.MaxPool2d: "softmax",
    nn.Softmax: "softmax",
}


def quantize(
    model: nn.Module,
    dataset: Dataset,
    strat: str = "NONE",
    working_dir: PathLike = ".",
    output_name: str = "calib.txt",
    eval_batchsize: int = 128,
) -> Path:
    """Post-training-quantize ``model`` with distiller and emit an HPVM calibration file.

    Steps:
      1. Collect activation statistics over ``dataset`` (skipped when
         ``working_dir / STATS_FILENAME`` already exists).
      2. Run distiller's ``PostTrainLinearQuantizer`` (symmetric mode, 8-bit
         activations, 32-bit accumulators) on a deep copy of ``model``; the
         original ``model`` is not quantized.
      3. Print the accuracy of the quantized copy and convert the per-layer
         scales into HPVM's calibration-file format.

    Args:
        model: floating-point model to calibrate.
        dataset: dataset yielding ``(input, label)`` pairs, used both for
            statistics collection and for evaluation.
        strat: distiller activation-clipping mode, one of
            'NONE', 'AVG', 'N_STD', 'GAUSS', 'LAPLACE'.
        working_dir: directory for the intermediate yaml files and the output
            file; created if it does not exist.
        output_name: file name of the calibration file inside ``working_dir``.
        eval_batchsize: batch size used for every pass over ``dataset``.

    Returns:
        Path of the generated calibration file.
    """
    print("Quantizing...")
    dataloader = DataLoader(dataset, batch_size=eval_batchsize)
    working_dir = Path(working_dir)
    # Stats, layer params and the calibration file are all written here.
    working_dir.mkdir(parents=True, exist_ok=True)

    # Collect pre-quantization stats.
    distiller.utils.assign_layer_fq_names(model)
    # Passed to distiller as a plain string path.
    stats_file = (working_dir / STATS_FILENAME).as_posix()
    if not os.path.isfile(stats_file):
        # distiller calls the test function as `test_fn(model=model)`, so the
        # dataloader has to be bound here. Generates `stats_file`.
        collect_quant_stats(
            model, lambda model: get_loss(model, dataloader), save_dir=working_dir
        )

    # Generate quantized scales on a copy so `model` stays floating-point.
    new_model = deepcopy(model)
    quantizer = PostTrainLinearQuantizer(
        new_model,
        model_activation_stats=stats_file,
        mode="SYMMETRIC",
        bits_activations=8,
        bits_accum=32,
        clip_acts=strat,
    )
    # Only used to trace the model: a random batch of one whose per-sample
    # shape matches what the dataset actually yields.
    sample_inputs = next(iter(dataloader))[0]
    dummy_input = torch.rand(1, *sample_inputs.shape[1:])
    # Generates QUANT_FILENAME and QUANT_AFTER_FILENAME in the current dir.
    quantizer.prepare_model(dummy_input)
    move(QUANT_FILENAME, working_dir / QUANT_FILENAME)
    # QUANT_AFTER_FILENAME is not needed; remove it if it was produced.
    after_file = Path(QUANT_AFTER_FILENAME)
    if after_file.exists():
        after_file.unlink()

    print(f"Quantization process finished; accuracy {evaluate(new_model, dataloader)}%")
    # Convert the .yaml stats to the HPVM calibration format.
    output_file = working_dir / output_name
    generate_calib_file(model, working_dir, output_file)
    return output_file


def generate_calib_file(model: nn.Module, working_dir: Path, output_file: Path):
    """Translate distiller's layer quantization parameters into an HPVM calibration file.

    Reads ``working_dir / QUANT_FILENAME`` and writes ``output_file`` with one
    ``<name>:\\t<scale>`` line per entry:

    * ``input`` — max(|min|, |max|) of the recorded input of module
      ``"convs.0"``, divided by 127;
    * one line per module of ``model`` (in ``named_modules()`` order) that has
      a ``"<module>.output_scale"`` entry, named with the ``LAYER_HPVM_NAME``
      prefix plus a 1-based per-prefix counter, holding ``1 / output_scale``;
    * after every ``nn.Conv2d`` / ``nn.Linear`` line, an extra ``add<k>`` line
      that repeats the same scale.

    Args:
        model: the un-quantized model whose module names match the yaml keys.
        working_dir: directory containing ``QUANT_FILENAME``.
        output_file: calibration file to (over)write.

    Raises:
        KeyError: if the yaml lacks the expected keys, or a module with an
            output scale has a type missing from ``LAYER_HPVM_NAME``.
    """
    print("Generating calibration file...")
    with open(working_dir / QUANT_FILENAME, "r") as quant_yaml:
        quant_data = yaml.safe_load(quant_yaml)

    # NOTE(review): the input scale is read from the module named "convs.0",
    # which assumes the model's first layer has that name.
    # fmt: off
    in_stats = quant_data["convs.0"]["model_activation_stats"]["convs.0"]["inputs"][0]
    # fmt: on
    in_bound = max(abs(in_stats["min"]), abs(in_stats["max"]))
    out_lines = [f"input:\t{in_bound / 127}\n"]

    # How many entries of each kind have been written so far.
    seen = dict.fromkeys(
        (nn.ReLU, nn.Linear, nn.Conv2d, nn.MaxPool2d, nn.Softmax, nn.Parameter), 0
    )

    def emit(kind, scale):
        # Append the next numbered entry for `kind` (numbering starts at 1).
        prefix = LAYER_HPVM_NAME[kind]
        seen[kind] += 1
        out_lines.append(f"{prefix}{seen[kind]}:\t{scale}\n")

    linear_params = quant_data["linear_quant_params"]
    for mod_name, module in model.named_modules():
        key = f"{mod_name}.output_scale"
        if key in linear_params:
            # NOTE(review): assumes distiller's output_scale is the
            # float->int multiplier, so its inverse is the step size.
            step = 1 / linear_params[key]
            emit(type(module), step)
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                # Conv / linear layers also get an 'add' entry with the same scale.
                emit(nn.Parameter, step)

    with open(output_file, "w") as calib_out:
        calib_out.writelines(out_lines)
    print(f"Calibration file generated to {output_file}")


@torch.no_grad()
def get_loss(model: nn.Module, dataloader: DataLoader):
    """Return the per-sample mean cross-entropy loss of ``model`` over ``dataloader``.

    Args:
        model: model whose outputs are suitable inputs to ``F.cross_entropy``.
        dataloader: iterable of ``(inputs, targets)`` batches.

    Returns:
        0-dim tensor: total loss over all samples divided by the sample count
        (batches of different sizes are weighted correctly).

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    from torch.nn import functional as F

    # Turn on evaluation mode which disables dropout.
    model.eval()
    loss_sum = 0.0
    n_samples = 0
    for data, targets in dataloader:
        output = model(data)
        # Sum over the batch and normalise by the number of *samples* at the
        # end; dividing by the number of batches would scale by batch size.
        loss_sum += F.cross_entropy(output, targets, reduction="sum")
        n_samples += len(data)
    if n_samples == 0:
        raise ValueError("get_loss: dataloader yielded no samples")
    return loss_sum / n_samples


@torch.no_grad()
def evaluate(model: nn.Module, test_dataloader: DataLoader = None):
    """Return the top-1 accuracy of ``model`` on ``test_dataloader`` as a percentage.

    Args:
        model: classifier; the predicted class is the argmax over dim 1 of
            its output.
        test_dataloader: iterable of batches whose first two items are
            ``(inputs, labels)``. Required despite the ``None`` default.

    Returns:
        ``100 * correct / total`` as a float.

    Raises:
        ValueError: if ``test_dataloader`` is None or yields no samples.
    """
    if test_dataloader is None:
        raise ValueError("evaluate: test_dataloader is required")
    model.eval()
    correct = 0
    total = 0
    # Iterate the argument itself (not a module-level name).
    for batch in test_dataloader:
        images, labels = batch[0], batch[1]
        predicted = model(images).argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    if total == 0:
        raise ValueError("evaluate: test_dataloader yielded no samples")
    return 100 * correct / total