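"""Post-training quantization of the MiniERA CIFAR classifier.

Collects activation statistics over the calibration dataset, runs distiller's
PostTrainLinearQuantizer to derive per-layer scales, and converts those scales
into an HPVM calibration file.
"""
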
from copy import deepcopy
from pathlib import Path
from typing import Union
from shutil import move

import distiller
import torch
import yaml
from distiller.data_loggers import collect_quant_stats
from distiller.quantization import PostTrainLinearQuantizer
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader

from .datasets import CIFAR
from .miniera import MiniERA

PathLike = Union[str, Path]
STATS_FILENAME = "acts_quantization_stats.yaml"
QUANT_FILENAME = "layer_quant_params.yaml"
QUANT_AFTER_FILENAME = "quant_stats_after_prepare_model.yaml"
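# PyTorch layer type -> operator name used in the HPVM calibration file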
LAYER_HPVM_NAME = {
    nn.ReLU: "relu",
    nn.Linear: "gemm",
    nn.Conv2d: "conv",
    nn.MaxPool2d: "pool",
    nn.Softmax: "softmax",
    nn.Parameter: "add",
}
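# PyTorch layer type -> distiller-side module name prefix in MiniERA
# (cf. the "convs.0" key used in generate_calib_file below)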
LAYER_DISTILLER_NAME = {
    nn.ReLU: "softmax",  # All point-wise layers use softmax's scale!
    nn.Linear: "fcs",
    nn.Conv2d: "convs",
    nn.MaxPool2d: "softmax",
    nn.Softmax: "softmax",
}


def quantize(
    model: nn.Module,
    dataset_path: PathLike,
    strat: str = "NONE",
    working_dir: PathLike = ".",
    output_name: str = "calib.txt",
):
    """Post-train-quantize `model` and emit an HPVM calibration file.

    `strat` selects distiller's activation-clipping strategy; valid values are
    'NONE', 'AVG', 'N_STD', 'GAUSS' and 'LAPLACE'. Returns the path to the
    generated calibration file (`working_dir / output_name`).
    """
    print("Quantizing...")
    dataset_path = Path(dataset_path)
    dataset = CIFAR.from_file(dataset_path / "input.bin", dataset_path / "labels.bin")
    dataloader = DataLoader(dataset, batch_size=1)

    # Collect pre-quantization activation statistics
    distiller.utils.assign_layer_fq_names(model)

    working_dir = Path(working_dir)
    stats_file = working_dir / STATS_FILENAME
    if not stats_file.is_file():
        # Runs the model over the dataset once and writes `stats_file`
        collect_quant_stats(
            model, lambda model: evaluate(model, dataloader), save_dir=working_dir
        )

    # Generate Quantized Scales
    quantizer = PostTrainLinearQuantizer(
        deepcopy(model),
        model_activation_stats=stats_file.as_posix(),
        mode="SYMMETRIC",
        bits_activations=8,
        bits_accum=32,
        clip_acts=strat,
    )
    dummy_input = torch.rand(1, 3, 32, 32)  # one CIFAR-sized (NCHW) input
    # generates QUANT_FILENAME and QUANT_AFTER_FILENAME in current dir
    quantizer.prepare_model(dummy_input)
    # Let's move it to our working dir
    move(QUANT_FILENAME, working_dir / QUANT_FILENAME)
    # We don't need QUANT_AFTER_FILENAME, remove it
    Path(QUANT_AFTER_FILENAME).unlink()

    print("Quantization process finished.")
    # Convert the distiller YAML stats into HPVM's calibration format
    generate_calib_file(model, working_dir, working_dir / output_name)
    return working_dir / output_name
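

# A minimal usage sketch (the checkpoint path and dataset directory below are
# hypothetical; `dataset_path` must contain input.bin and labels.bin, and we
# assume MiniERA() constructs the network with default arguments):
#
#     model = MiniERA()
#     model.load_state_dict(torch.load("miniera.pth"))
#     calib_file = quantize(model, "data/cifar", working_dir="out")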


def generate_calib_file(model: nn.Module, working_dir: Path, output_file: Path):
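    """Convert distiller's per-layer quant params into an HPVM calibration file.

    Each output line has the form "<op name><index>:<TAB><scale>", with indices
    counted per op name starting at 1.
    """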
    print("Generating calibration file...")
    with open(working_dir / QUANT_FILENAME, "r") as stream:
        scale_data = yaml.safe_load(stream)

    lines = []
    # Input scale: derived from the observed input range of the first conv
    # fmt: off
    input_min_max = scale_data["convs.0"]["model_activation_stats"]["convs.0"]["inputs"][0]
    # fmt: on
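    # Symmetric int8 scale: the larger of |min| and |max| maps to integer 127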
    input_scale = max(abs(input_min_max["min"]), abs(input_min_max["max"])) / 127
    lines.append(f"input:\t{input_scale}\n")

    # Per-type counters: named_modules() yields layers in MiniERA's definition
    # order, so each op gets a consecutive 1-based index in the calib file
    layer_count = {layer_type: 0 for layer_type in LAYER_HPVM_NAME}
    # add scales for layers
    quant_params = scale_data["linear_quant_params"]
    for name, layer in model.named_modules():
        scale_key = f"{name}.output_scale"
        if scale_key not in quant_params:
            continue
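        # distiller's output_scale maps real values to integers; the calib file
        # stores the inverse (real value per integer step), like input_scale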
        layer_scale = 1 / quant_params[scale_key]

        hpvm_name = LAYER_HPVM_NAME[type(layer)]
        layer_idx = layer_count[type(layer)]
        layer_count[type(layer)] += 1
        lines.append(f"{hpvm_name}{layer_idx + 1}:\t{layer_scale}\n")

        if isinstance(layer, (nn.Conv2d, nn.Linear)):
            # the trailing 'add' (bias) node shares this layer's output scale
            add_hpvm_name = LAYER_HPVM_NAME[nn.Parameter]
            add_idx = layer_count[nn.Parameter]
            layer_count[nn.Parameter] += 1
            lines.append(f"{add_hpvm_name}{add_idx + 1}:\t{layer_scale}\n")

    with open(output_file, "w") as f:
        f.writelines(lines)
    print(f"Calibration file generated to {output_file}")


@torch.no_grad()
def evaluate(model: MiniERA, dataloader: DataLoader):
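    """Return `model`'s average cross-entropy loss over `dataloader`."""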
    # Turn on evaluation mode which disables dropout.
    model.eval()
    total_loss = 0.0
    for data, targets in dataloader:
        output = model(data)
        # cross_entropy returns the batch mean; weight by batch size so the
        # result is a per-sample average even if the final batch is short
        total_loss += len(data) * F.cross_entropy(output, targets).item()
    return total_loss / len(dataloader.dataset)