"""Post-training quantization utilities: collect activation statistics with
Distiller, quantize a model to int8, and emit an HPVM calibration file."""
  • import os
    from copy import deepcopy
    from pathlib import Path
    from typing import Union
    
    from shutil import move
    
    
    import distiller
    import torch
    
    Yifan Zhao's avatar
    Yifan Zhao committed
    from torch.utils.data.dataset import Dataset
    
    import yaml
    from distiller.data_loggers import collect_quant_stats
    from distiller.quantization import PostTrainLinearQuantizer
    from torch import nn
    from torch.utils.data import DataLoader
    
    from .datasets import CIFAR
    
# A filesystem path, given either as `str` or `pathlib.Path`.
PathLike = Union[str, Path]

# Produced by distiller's `collect_quant_stats` (activation ranges).
STATS_FILENAME = "acts_quantization_stats.yaml"
# Produced by `PostTrainLinearQuantizer.prepare_model` (per-layer scales).
QUANT_FILENAME = "layer_quant_params.yaml"

# Also produced by `prepare_model`; unused here and deleted by `quantize`.
QUANT_AFTER_FILENAME = "quant_stats_after_prepare_model.yaml"

# torch layer class -> layer-name prefix used in the HPVM calibration file
# (indices are appended per type: "conv1", "conv2", ..., see
# `generate_calib_file`). `nn.Parameter` stands for the bias add of
# conv/linear layers.
LAYER_HPVM_NAME = {
    nn.ReLU: "relu",
    nn.Linear: "gemm",
    nn.Conv2d: "conv",
    nn.MaxPool2d: "pool",
    nn.Softmax: "softmax",
    nn.Parameter: "add",
}
# torch layer class -> distiller layer-group name; presumably keys into
# distiller's stats files (not referenced in this chunk — verify callers).
LAYER_DISTILLER_NAME = {
    nn.ReLU: "softmax",  # All point-wise layers use softmax's scale!
    nn.Linear: "fcs",
    nn.Conv2d: "convs",
    nn.MaxPool2d: "softmax",
    nn.Softmax: "softmax",
}
    
    
    def quantize(
    
        model: nn.Module,
    
    Yifan Zhao's avatar
    Yifan Zhao committed
        dataset: Dataset,
    
        strat: str = "NONE",
    
        working_dir: PathLike = ".",
        output_name: str = "calib.txt",
    
    ):
        # possible quant strats ['NONE', 'AVG', 'N_STD', 'GAUSS', 'LAPLACE']
        print("Quantizing...")
        dataloader = DataLoader(dataset, batch_size=1)
    
        # Collect Pre Quantization Stats
        distiller.utils.assign_layer_fq_names(model)
    
    
        working_dir = Path(working_dir)
        stats_file = (working_dir / STATS_FILENAME).as_posix()
        if not os.path.isfile(stats_file):
            # generates `stats_file`
    
            collect_quant_stats(
    
    Yifan Zhao's avatar
    Yifan Zhao committed
                model, lambda model: get_loss(model, dataloader), save_dir=working_dir
    
            )
    
        # Generate Quantized Scales
    
    Yifan Zhao's avatar
    Yifan Zhao committed
        new_model = deepcopy(model)
    
        quantizer = PostTrainLinearQuantizer(
    
    Yifan Zhao's avatar
    Yifan Zhao committed
            new_model,
    
            model_activation_stats=stats_file,
    
            mode="SYMMETRIC",
            bits_activations=8,
            bits_accum=32,
            clip_acts=strat,
        )
        dummy_input = torch.rand(1, 3, 32, 32)
    
        # generates QUANT_FILENAME and QUANT_AFTER_FILENAME in current dir
    
        quantizer.prepare_model(dummy_input)
    
        # Let's move it to our working dir
        move(QUANT_FILENAME, working_dir / QUANT_FILENAME)
        # We don't need QUANT_AFTER_FILENAME, remove it
        Path(QUANT_AFTER_FILENAME).unlink()
    
    Yifan Zhao's avatar
    Yifan Zhao committed
        print(f"Quantization process finished; accuracy {evaluate(new_model, dataloader)}%")
    
        # converts .yaml file stats to hpvm standard
    
        generate_calib_file(model, working_dir, working_dir / output_name)
        return working_dir / output_name
    
    def generate_calib_file(model: nn.Module, working_dir: Path, output_file: Path):
    
        print("Generating calibration file...")
    
        with open(working_dir / QUANT_FILENAME, "r") as stream:
    
            scale_data = yaml.safe_load(stream)
    
        lines = []
        # add scales for input
        # fmt: off
        input_min_max = scale_data["convs.0"]["model_activation_stats"]["convs.0"]["inputs"][0]
        # fmt: on
        input_scale = max(abs(input_min_max["min"]), abs(input_min_max["max"])) / 127
        lines.append(f"input:\t{input_scale}\n")
    
        layer_count = {
            nn.ReLU: 0,
            nn.Linear: 0,
            nn.Conv2d: 0,
            nn.MaxPool2d: 0,
            nn.Softmax: 0,
            nn.Parameter: 0,
        }
        # add scales for layers
    
        quant_params = scale_data["linear_quant_params"]
        for name, layer in model.named_modules():
            scale_key = f"{name}.output_scale"
            if scale_key not in quant_params:
                continue
            layer_scale = 1 / quant_params[scale_key]
    
    
            hpvm_name = LAYER_HPVM_NAME[type(layer)]
            layer_idx = layer_count[type(layer)]
            layer_count[type(layer)] += 1
            lines.append(f"{hpvm_name}{layer_idx + 1}:\t{layer_scale}\n")
    
            if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):
                # include 'add' scale
                add_hpvm_name = LAYER_HPVM_NAME[nn.Parameter]
                add_idx = layer_count[nn.Parameter]
                layer_count[nn.Parameter] += 1
                lines.append(f"{add_hpvm_name}{add_idx + 1}:\t{layer_scale}\n")
    
    
        with open(output_file, "w") as f:
    
            f.writelines(lines)
    
        print(f"Calibration file generated to {output_file}")
    
    Yifan Zhao's avatar
    Yifan Zhao committed
    def get_loss(model: nn.Module, dataloader: DataLoader):
    
        from torch.nn import functional as F
    
        # Turn on evaluation mode which disables dropout.
        model.eval()
        total_loss = 0
        for batch in dataloader:
            data, targets = batch
            output = model(data)
            total_loss += len(data) * F.cross_entropy(output, targets)
        return total_loss / len(dataloader)
    
    Yifan Zhao's avatar
    Yifan Zhao committed
    
    
    @torch.no_grad()
    def evaluate(model: nn.Module, dataloader: DataLoader):
        model.eval()
        correct = 0
        total = 0
        for data in dataloader:
            images, labels = data[0], data[1]
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        return 100 * correct / total