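"""Post-training quantization of a PyTorch model with the distiller library.

Collects activation statistics, derives symmetric 8-bit scales with
distiller's PostTrainLinearQuantizer, and writes them out in HPVM's
calibration-file format.
"""
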
from copy import deepcopy
from pathlib import Path
from shutil import move
from typing import Union

import distiller
import torch
import yaml
from distiller.data_loggers import collect_quant_stats
from distiller.quantization import PostTrainLinearQuantizer
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset

PathLike = Union[str, Path]
STATS_FILENAME = "acts_quantization_stats.yaml"
QUANT_FILENAME = "layer_quant_params.yaml"
QUANT_AFTER_FILENAME = "quant_stats_after_prepare_model.yaml"
LAYER_HPVM_NAME = {
    nn.ReLU: "relu",
    nn.Linear: "gemm",
    nn.Conv2d: "conv",
    nn.MaxPool2d: "pool",
    nn.Softmax: "softmax",
    nn.Parameter: "add",  # sentinel for the bias add of Conv2d/Linear layers
}
LAYER_DISTILLER_NAME = {
    nn.ReLU: "softmax",  # All point-wise layers use softmax's scale!
    nn.Linear: "fcs",
    nn.Conv2d: "convs",
    nn.MaxPool2d: "softmax",
    nn.Softmax: "softmax",
}


def quantize(
    model: nn.Module,
    dataset: Dataset,
    strat: str = "NONE",
    working_dir: PathLike = ".",
    output_name: str = "calib.txt",
    eval_batchsize: int = 128,
):
    """Post-train quantize `model` over `dataset` and emit an HPVM calibration file.

    `strat` selects distiller's activation clipping mode: one of
    'NONE', 'AVG', 'N_STD', 'GAUSS', or 'LAPLACE'.
    Returns the path of the generated calibration file.
    """
    print("Quantizing...")
    dataloader = DataLoader(dataset, batch_size=eval_batchsize)

    # Collect Pre Quantization Stats
    distiller.utils.assign_layer_fq_names(model)

    working_dir = Path(working_dir)
    stats_file = working_dir / STATS_FILENAME
    if not stats_file.is_file():
        # generates `stats_file`; collect_quant_stats invokes the test
        # function with the model as its only argument, so bind the
        # dataloader here
        collect_quant_stats(
            model,
            lambda model: get_loss(model, dataloader),
            save_dir=working_dir.as_posix(),
        )

    # Generate Quantized Scales
    new_model = deepcopy(model)
    quantizer = PostTrainLinearQuantizer(
        new_model,
        model_activation_stats=stats_file.as_posix(),
        mode="SYMMETRIC",
        bits_activations=8,
        bits_accum=32,
        clip_acts=strat,
    )
    # a single 3x32x32 input (CIFAR-sized); adjust for other input shapes
    dummy_input = torch.rand(1, 3, 32, 32)
    # generates QUANT_FILENAME and QUANT_AFTER_FILENAME in current dir
    quantizer.prepare_model(dummy_input)
    # Let's move it to our working dir
    move(QUANT_FILENAME, working_dir / QUANT_FILENAME)
    # We don't need QUANT_AFTER_FILENAME, remove it
    Path(QUANT_AFTER_FILENAME).unlink()

    print(f"Quantization process finished; accuracy {evaluate(new_model, dataloader)}%")
    # convert the distiller YAML scales into HPVM's calibration format
    generate_calib_file(model, working_dir, working_dir / output_name)
    return working_dir / output_name


def generate_calib_file(model: nn.Module, working_dir: Path, output_file: Path):
    print("Generating calibration file...")
    with open(working_dir / QUANT_FILENAME, "r") as stream:
        scale_data = yaml.safe_load(stream)

    lines = []
    # add the scale for the network input; assumes the model registers its
    # first convolution as "convs.0"
    # fmt: off
    input_min_max = scale_data["convs.0"]["model_activation_stats"]["convs.0"]["inputs"][0]
    # fmt: on
    # symmetric int8: map the largest input magnitude onto 127 steps
    input_scale = max(abs(input_min_max["min"]), abs(input_min_max["max"])) / 127
    lines.append(f"input:\t{input_scale}\n")

    layer_count = {
        nn.ReLU: 0,
        nn.Linear: 0,
        nn.Conv2d: 0,
        nn.MaxPool2d: 0,
        nn.Softmax: 0,
        nn.Parameter: 0,
    }
    # add scales for layers
    quant_params = scale_data["linear_quant_params"]
    for name, layer in model.named_modules():
        scale_key = f"{name}.output_scale"
        if scale_key not in quant_params:
            continue
        # distiller stores a multiplier (quantized = real * scale); the
        # calibration file records its reciprocal, the per-step size,
        # matching `input_scale` above
        layer_scale = 1 / quant_params[scale_key]

        hpvm_name = LAYER_HPVM_NAME[type(layer)]
        layer_idx = layer_count[type(layer)]
        layer_count[type(layer)] += 1
        lines.append(f"{hpvm_name}{layer_idx + 1}:\t{layer_scale}\n")

        if isinstance(layer, (nn.Conv2d, nn.Linear)):
            # the bias add reuses its layer's output scale
            add_hpvm_name = LAYER_HPVM_NAME[nn.Parameter]
            add_idx = layer_count[nn.Parameter]
            layer_count[nn.Parameter] += 1
            lines.append(f"{add_hpvm_name}{add_idx + 1}:\t{layer_scale}\n")

    with open(output_file, "w") as f:
        f.writelines(lines)
    print(f"Calibration file generated to {output_file}")


@torch.no_grad()
def get_loss(model: nn.Module, dataloader: DataLoader):
    from torch.nn import functional as F

    # Turn on evaluation mode which disables dropout.
    model.eval()
    total_loss = 0.0
    total_samples = 0
    for data, targets in dataloader:
        output = model(data)
        # weight each batch by its size so the result is a per-sample mean
        total_loss += len(data) * F.cross_entropy(output, targets).item()
        total_samples += len(data)
    return total_loss / total_samples


@torch.no_grad()
def evaluate(model: nn.Module, dataloader: DataLoader):
    """Top-1 accuracy of `model` over `dataloader`, in percent."""
    model.eval()
    correct = 0
    total = 0
    for images, labels in dataloader:
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    return 100 * correct / total
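

# Minimal usage sketch; `MyNet` and `load_test_set` are placeholders for your
# own model and test Dataset (inputs must match the 3x32x32 dummy input used
# in `quantize`):
#
#     model = MyNet()
#     model.load_state_dict(torch.load("model.pth"))
#     calib_file = quantize(model, load_test_set(), strat="AVG",
#                           working_dir="quant_out")
#     print(f"HPVM calibration file at {calib_file}")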