"""Post-training quantization helpers built on Distiller.

`quantize` collects activation statistics for a model, runs Distiller's
`PostTrainLinearQuantizer` on a copy of it, and converts the resulting
per-layer scales into a plain-text calibration file.
"""
import os
from copy import deepcopy
from pathlib import Path
from shutil import move
from typing import Union

import distiller
import torch
import yaml
from distiller.data_loggers import collect_quant_stats
from distiller.quantization import PostTrainLinearQuantizer
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset

PathLike = Union[str, Path]

STATS_FILENAME = "acts_quantization_stats.yaml"
QUANT_FILENAME = "layer_quant_params.yaml"
QUANT_AFTER_FILENAME = "quant_stats_after_prepare_model.yaml"

# Prefix used for each layer type in the generated calibration file.
# `nn.Parameter` stands for the bias "add" emitted after conv/linear layers.
LAYER_HPVM_NAME = {
    nn.ReLU: "relu",
    nn.Linear: "gemm",
    nn.Conv2d: "conv",
    nn.MaxPool2d: "pool",
    nn.Softmax: "softmax",
    nn.Parameter: "add",
}
LAYER_DISTILLER_NAME = {
    nn.ReLU: "softmax",  # All point-wise layers use softmax's scale!
    nn.Linear: "fcs",
    nn.Conv2d: "convs",
    nn.MaxPool2d: "softmax",
    nn.Softmax: "softmax",
}


def quantize(
    model: nn.Module,
    dataset: Dataset,
    strat: str = "NONE",
    working_dir: PathLike = ".",
    output_name: str = "calib.txt",
    eval_batchsize: int = 128,
) -> Path:
    """Quantize a copy of `model` and write a calibration file.

    Args:
        model: Network to calibrate. Its layers get fully-qualified names
            assigned in place; the quantizer itself works on a deep copy.
        dataset: Data used to collect activation statistics and to report
            the accuracy of the quantized copy.
        strat: Activation clipping mode forwarded to the quantizer as
            `clip_acts`; one of 'NONE', 'AVG', 'N_STD', 'GAUSS', 'LAPLACE'.
        working_dir: Directory holding the intermediate YAML files and the
            output. Created if it does not exist.
        output_name: File name of the calibration file inside `working_dir`.
        eval_batchsize: Batch size of the internal `DataLoader`.

    Returns:
        Path of the generated calibration file.
    """
    # possible quant strats ['NONE', 'AVG', 'N_STD', 'GAUSS', 'LAPLACE']
    print("Quantizing...")
    dataloader = DataLoader(dataset, batch_size=eval_batchsize)

    # Collect Pre Quantization Stats
    distiller.utils.assign_layer_fq_names(model)

    working_dir = Path(working_dir)
    # The stats and params files are written below; make sure the target exists.
    working_dir.mkdir(parents=True, exist_ok=True)
    stats_file = (working_dir / STATS_FILENAME).as_posix()
    # Reuse previously collected statistics when they are already on disk.
    if not os.path.isfile(stats_file):
        # generates `stats_file`
        collect_quant_stats(model, get_loss, dataloader, save_dir=working_dir)

    # Generate Quantized Scales
    new_model = deepcopy(model)
    quantizer = PostTrainLinearQuantizer(
        new_model,
        model_activation_stats=stats_file,
        mode="SYMMETRIC",
        bits_activations=8,
        bits_accum=32,
        clip_acts=strat,
    )
    # NOTE: hard-coded single 3x32x32 input; models with another input shape
    # need a different dummy tensor here.
    dummy_input = torch.rand(1, 3, 32, 32)
    # generates QUANT_FILENAME and QUANT_AFTER_FILENAME in current dir
    quantizer.prepare_model(dummy_input)
    # Let's move it to our working dir
    move(QUANT_FILENAME, working_dir / QUANT_FILENAME)
    # We don't need QUANT_AFTER_FILENAME, remove it
    Path(QUANT_AFTER_FILENAME).unlink()

    print(f"Quantization process finished; accuracy {evaluate(new_model, dataloader)}%")

    # converts .yaml file stats to hpvm standard
    output_file = working_dir / output_name
    generate_calib_file(model, working_dir, output_file)
    return output_file


def generate_calib_file(model: nn.Module, working_dir: Path, output_file: Path):
    """Convert the quantizer's YAML parameters into a calibration text file.

    Reads `QUANT_FILENAME` from `working_dir` and writes one
    ``<name><index>:\\t<scale>`` line per quantized layer, preceded by an
    ``input`` line. Each conv/linear layer is followed by an ``add`` line
    carrying the same scale.

    Args:
        model: Model whose `named_modules()` names match the keys in the
            YAML file.
        working_dir: Directory containing `QUANT_FILENAME`.
        output_file: Destination path of the calibration file.

    Raises:
        KeyError: If the YAML file has no "convs.0" entry or no
            "linear_quant_params" section, or if a layer with an output
            scale has a type missing from `LAYER_HPVM_NAME`.
    """
    print("Generating calibration file...")
    with open(working_dir / QUANT_FILENAME, "r") as stream:
        scale_data = yaml.safe_load(stream)

    # add scales for input
    # The input range is taken from the recorded inputs of module "convs.0".
    # fmt: off
    input_min_max = scale_data["convs.0"]["model_activation_stats"]["convs.0"]["inputs"][0]
    # fmt: on
    # Symmetric 8-bit range: the largest magnitude maps to 127.
    input_scale = max(abs(input_min_max["min"]), abs(input_min_max["max"])) / 127
    lines = [f"input:\t{input_scale}\n"]

    # Running per-type counters; indices in the output are 1-based.
    layer_count = dict.fromkeys(LAYER_HPVM_NAME, 0)

    # add scales for layers
    quant_params = scale_data["linear_quant_params"]
    for name, layer in model.named_modules():
        scale_key = f"{name}.output_scale"
        if scale_key not in quant_params:
            continue
        # The file stores the reciprocal of the value we emit.
        layer_scale = 1 / quant_params[scale_key]

        layer_type = type(layer)
        hpvm_name = LAYER_HPVM_NAME[layer_type]
        layer_count[layer_type] += 1
        lines.append(f"{hpvm_name}{layer_count[layer_type]}:\t{layer_scale}\n")

        if isinstance(layer, (nn.Conv2d, nn.Linear)):
            # include 'add' scale
            add_hpvm_name = LAYER_HPVM_NAME[nn.Parameter]
            layer_count[nn.Parameter] += 1
            lines.append(
                f"{add_hpvm_name}{layer_count[nn.Parameter]}:\t{layer_scale}\n"
            )

    with open(output_file, "w") as f:
        f.writelines(lines)
    print(f"Calibration file generated to {output_file}")


@torch.no_grad()
def get_loss(model: nn.Module, dataloader: DataLoader):
    """Return the mean per-sample cross-entropy of `model` over `dataloader`.

    Each batch's mean loss is weighted by its size, and the sum is divided
    by the total number of samples seen, so a smaller final batch is
    weighted correctly.

    Raises:
        ZeroDivisionError: If the dataloader yields no samples.
    """
    # Turn on evaluation mode which disables dropout.
    model.eval()
    total_loss = 0
    total_samples = 0
    for data, targets in dataloader:
        output = model(data)
        batch_size = len(data)
        # cross_entropy averages over the batch; scale back to a sum.
        total_loss += batch_size * F.cross_entropy(output, targets)
        total_samples += batch_size
    return total_loss / total_samples


@torch.no_grad()
def evaluate(model: nn.Module, dataloader: DataLoader = None):
    """Return the top-1 accuracy of `model` on `dataloader`, in percent.

    Each batch is indexed as ``(images, labels, ...)``; the predicted class
    is the arg-max over dimension 1 of the model output.

    Raises:
        ZeroDivisionError: If the dataloader yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    for batch in dataloader:
        images, labels = batch[0], batch[1]
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    return 100 * correct / total