Skip to content
Snippets Groups Projects
quantizer.py 3.99 KiB
Newer Older
  • Learn to ignore specific revisions
  • import os
    from copy import deepcopy
    from pathlib import Path
    from typing import Union
    
    import distiller
    import torch
    import yaml
    from distiller.data_loggers import collect_quant_stats
    from distiller.quantization import PostTrainLinearQuantizer
    from torch import nn
    from torch.utils.data import DataLoader
    
    from .datasets import CIFAR
    from .miniera import MiniERA
    
    # Path arguments may be given either as plain strings or as pathlib paths.
    PathLike = Union[str, Path]

    # Activation-statistics YAML in the working directory: `quantize` creates it
    # through distiller's collect_quant_stats only when it is missing.
    STATS_FILENAME = "acts_quantization_stats.yaml"
    # Per-layer quantization parameters YAML that `generate_calib_file` reads.
    QUANT_FILENAME = "layer_quant_params.yaml"

    # torch layer type -> operator label used in the HPVM calibration file.
    LAYER_HPVM_NAME = {
        # elementwise / structural ops
        nn.ReLU: "relu",
        # parametrized ops
        nn.Linear: "gemm",
        nn.Conv2d: "conv",
        # pooling and output
        nn.MaxPool2d: "pool",
        nn.Softmax: "softmax",
        # nn.Parameter is not a layer: it is a stand-in key for the "add" lines
        # written after every Conv2d/Linear layer in generate_calib_file.
        nn.Parameter: "add",
    }

    # torch layer type -> attribute/group name distiller uses as the key prefix
    # ("<group>.<index>") when naming MiniERA's layers.
    LAYER_DISTILLER_NAME = {
        nn.Linear: "fcs",
        nn.Conv2d: "convs",
        nn.Softmax: "softmax",
    }
    
    
    def quantize(
        dataset_path: PathLike,
        model_chkpt: PathLike,
        strat: str = "NONE",
        output: PathLike = "calib.txt",
    ):
        """Post-training quantize a MiniERA checkpoint and emit an HPVM calibration file.

        Steps:
          1. Load the dataset (``input.bin`` / ``labels.bin``) and the checkpoint.
          2. If ``STATS_FILENAME`` does not exist in the current working
             directory, collect activation statistics with distiller by running
             ``evaluate`` over the dataset. An existing file is reused as-is,
             even if it was produced for another model or dataset -- delete it
             to force re-collection.
          3. Build a ``PostTrainLinearQuantizer`` (symmetric, 8-bit activations,
             32-bit accumulators) on a copy of the model and save its per-layer
             parameters (read back from ``QUANT_FILENAME`` in step 4).
          4. Convert those parameters with ``generate_calib_file``.

        :param dataset_path: Directory holding ``input.bin`` and ``labels.bin``.
        :param model_chkpt: Path to a ``MiniERA`` state_dict checkpoint.
        :param strat: Activation-clipping mode forwarded as distiller's
            ``clip_acts``; one of 'NONE', 'AVG', 'N_STD', 'GAUSS', 'LAPLACE'.
        :param output: Path of the calibration file to write.
        """
        print("Quantizing...")
        dataset_path = Path(dataset_path)
        dataset = CIFAR(dataset_path / "input.bin", dataset_path / "labels.bin")
        dataloader = DataLoader(dataset, batch_size=1)

        # Load Model. map_location="cpu" keeps checkpoints saved from a GPU
        # loadable on CPU-only hosts; the model and dummy input below are CPU.
        model = MiniERA()
        model.load_state_dict(torch.load(model_chkpt, map_location="cpu"))
        # Quantize the inference-mode model. Previously eval mode was only set
        # as a side effect of `evaluate`, i.e. only when statistics were
        # (re)collected in this run, so a cached stats file left the model in
        # training mode.
        model.eval()

        # Collect Pre Quantization Stats
        distiller.utils.assign_layer_fq_names(model)

        if not os.path.isfile(STATS_FILENAME):
            # generates STATS_FILENAME in the current directory
            collect_quant_stats(
                model, lambda model: evaluate(model, dataloader), save_dir="."
            )

        # Generate Quantized Scales
        quantizer = PostTrainLinearQuantizer(
            deepcopy(model),
            model_activation_stats=STATS_FILENAME,
            mode="SYMMETRIC",
            bits_activations=8,
            bits_accum=32,
            clip_acts=strat,
        )
        # NOTE(review): input shape is hard-coded for 3x32x32 (CIFAR-sized)
        # images; confirm if MiniERA is ever fed other resolutions.
        dummy_input = torch.rand(1, 3, 32, 32)
        quantizer.prepare_model(dummy_input)
        quantizer.save_per_layer_parameters()

        print("Quantization process finished.")
        # converts .yaml file stats to hpvm standard
        generate_calib_file(model, output)
    
    
    def generate_calib_file(model: MiniERA, output: PathLike):
        """Convert distiller's per-layer quantization parameters into an HPVM calibration file.

        Reads ``QUANT_FILENAME`` from the current working directory and writes
        one ``<label>:\\t<scale>`` line per tensor:
          * ``input`` -- derived from the recorded min/max of the first conv's
            input activations;
          * one line per layer of ``model.convs``, ``model.fcs`` and
            ``model.softmax`` (in that order), labelled ``<hpvm name><k>`` with a
            1-based per-type index;
          * after each Conv2d/Linear layer, an extra ``add<k>`` line carrying the
            same scale as that layer.

        :param model: The MiniERA whose layers are enumerated.
        :param output: Path of the calibration file to write.
        :raises KeyError: If a layer type is missing from ``LAYER_HPVM_NAME`` /
            ``LAYER_DISTILLER_NAME``, or the YAML lacks an expected key.
        """
        from collections import Counter

        print("Generating calibration file...")
        with open(QUANT_FILENAME, "r") as stream:
            scale_data = yaml.safe_load(stream)

        lines = []
        # add scales for input: max |x| / 127, matching the SYMMETRIC 8-bit
        # activation setting used when quantizing.
        # fmt: off
        input_min_max = scale_data["convs.0"]["model_activation_stats"]["convs.0"]["inputs"][0]
        # fmt: on
        input_scale = max(abs(input_min_max["min"]), abs(input_min_max["max"])) / 127
        lines.append(f"input:\t{input_scale}\n")

        # because of definition of miniera
        layers = [*model.convs, *model.fcs, model.softmax]
        # Running per-type index: 0-based in distiller keys ("convs.0"), 1-based
        # in HPVM labels ("conv1"). A Counter starts every type at 0, so it no
        # longer has to duplicate the key set of the name tables above.
        layer_count = Counter()
        add_label = LAYER_HPVM_NAME[nn.Parameter]
        # add scales for layers
        for layer in layers:
            layer_type = type(layer)
            hpvm_name = LAYER_HPVM_NAME[layer_type]
            distiller_typename = LAYER_DISTILLER_NAME[layer_type]
            layer_idx = layer_count[layer_type]
            layer_count[layer_type] += 1

            # The calibration file stores the reciprocal of distiller's
            # output_scale for this layer.
            scale_key = f"{distiller_typename}.{layer_idx}.output_scale"
            layer_scale = 1 / scale_data["linear_quant_params"][scale_key]
            lines.append(f"{hpvm_name}{layer_idx + 1}:\t{layer_scale}\n")

            if isinstance(layer, (nn.Conv2d, nn.Linear)):
                # Presumably the bias add that follows conv/gemm in HPVM; it
                # reuses this layer's output scale.
                add_idx = layer_count[nn.Parameter]
                layer_count[nn.Parameter] += 1
                lines.append(f"{add_label}{add_idx + 1}:\t{layer_scale}\n")

        # The file is only written, so plain "w" (not "w+") is sufficient.
        with open(output, "w") as f:
            f.writelines(lines)
        print(f"Calibration file generated to {output}")
    
    
    @torch.no_grad()
    def evaluate(model: nn.Module, dataloader: DataLoader):
        """Return the per-sample mean cross-entropy loss of ``model`` over ``dataloader``.

        ``quantize`` uses this as the forward pass during which distiller
        collects activation statistics. The model is switched to eval mode (and
        left there), and autograd is disabled for the whole pass.

        :param model: Any module; its output is passed to ``F.cross_entropy``
            together with the targets (nothing MiniERA-specific is used).
        :param dataloader: Yields ``(data, targets)`` batches of any size.
        :returns: Mean loss over all samples, as a 0-dim tensor.
        :raises ValueError: If ``dataloader`` yields no samples.
        """
        from torch.nn import functional as F

        # Turn on evaluation mode which disables dropout.
        model.eval()
        total_loss = 0
        n_samples = 0
        for data, targets in dataloader:
            output = model(data)
            # Sum per-sample losses so batches of different sizes (e.g. a short
            # final batch) are weighted by how many samples they contain.
            total_loss += F.cross_entropy(output, targets, reduction="sum")
            n_samples += len(data)
        if n_samples == 0:
            raise ValueError("evaluate() received an empty dataloader")
        # Normalize by sample count, not batch count: dividing by
        # len(dataloader) was only correct for batch_size=1.
        return total_loss / n_samples