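"""Post-training quantization of PyTorch DNNs with Distiller.

Collects activation statistics, runs Distiller's PostTrainLinearQuantizer,
and converts the resulting per-layer scales into an HPVM calibration file.
"""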
import os
from copy import deepcopy
from pathlib import Path
from shutil import move
from typing import Union
import distiller
import torch
import yaml
from distiller.data_loggers import collect_quant_stats
from distiller.quantization import PostTrainLinearQuantizer
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
PathLike = Union[str, Path]
STATS_FILENAME = "acts_quantization_stats.yaml"
QUANT_FILENAME = "layer_quant_params.yaml"
QUANT_AFTER_FILENAME = "quant_stats_after_prepare_model.yaml"
LAYER_HPVM_NAME = {
nn.ReLU: "relu",
nn.Linear: "gemm",
nn.Conv2d: "conv",
nn.MaxPool2d: "pool",
nn.Softmax: "softmax",
nn.Parameter: "add",
}
LAYER_DISTILLER_NAME = {
nn.ReLU: "softmax", # All point-wise layers use softmax's scale!
nn.Linear: "fcs",
nn.Conv2d: "convs",
nn.MaxPool2d: "softmax",
nn.Softmax: "softmax",
}
def quantize(
model: nn.Module,
dataset: Dataset,
strat: str = "NONE",
working_dir: PathLike = ".",
output_name: str = "calib.txt",
eval_batchsize: int = 128,
):
    """Quantize `model` to int8 using Distiller post-training quantization.

    `strat` selects the activation clipping strategy; Distiller accepts
    "NONE", "AVG", "N_STD", "GAUSS", and "LAPLACE".
    Generates the activation-stats and layer-quant-params YAML files under
    `working_dir`, writes the HPVM calibration file to
    `working_dir / output_name`, and returns that path.
    """
print("Quantizing...")
dataloader = DataLoader(dataset, batch_size=eval_batchsize)
    # Collect pre-quantization activation stats
distiller.utils.assign_layer_fq_names(model)
working_dir = Path(working_dir)
stats_file = (working_dir / STATS_FILENAME).as_posix()
if not os.path.isfile(stats_file):
# generates `stats_file`
        # Distiller's collect_quant_stats expects a test_fn called as test_fn(model=model),
        # so wrap get_loss to bind the dataloader.
        collect_quant_stats(
            model, lambda model: get_loss(model, dataloader), save_dir=working_dir
        )
# Generate Quantized Scales
new_model = deepcopy(model)
quantizer = PostTrainLinearQuantizer(
new_model,
model_activation_stats=stats_file,
mode="SYMMETRIC",
bits_activations=8,
bits_accum=32,
clip_acts=strat,
)
    # Dummy input used to trace the model; shape assumes CIFAR-10-sized (3, 32, 32) images
    dummy_input = torch.rand(1, 3, 32, 32)
# generates QUANT_FILENAME and QUANT_AFTER_FILENAME in current dir
quantizer.prepare_model(dummy_input)
# Let's move it to our working dir
move(QUANT_FILENAME, working_dir / QUANT_FILENAME)
# We don't need QUANT_AFTER_FILENAME, remove it
Path(QUANT_AFTER_FILENAME).unlink()
print(f"Quantization process finished; accuracy {evaluate(new_model, dataloader)}%")
# converts .yaml file stats to hpvm standard
generate_calib_file(model, working_dir, working_dir / output_name)
return working_dir / output_name
def generate_calib_file(model: nn.Module, working_dir: Path, output_file: Path):
print("Generating calibration file...")
with open(working_dir / QUANT_FILENAME, "r") as stream:
scale_data = yaml.safe_load(stream)
lines = []
# add scales for input
# fmt: off
input_min_max = scale_data["convs.0"]["model_activation_stats"]["convs.0"]["inputs"][0]
# fmt: on
input_scale = max(abs(input_min_max["min"]), abs(input_min_max["max"])) / 127
lines.append(f"input:\t{input_scale}\n")
layer_count = {
nn.ReLU: 0,
nn.Linear: 0,
nn.Conv2d: 0,
nn.MaxPool2d: 0,
nn.Softmax: 0,
nn.Parameter: 0,
}
# add scales for layers
quant_params = scale_data["linear_quant_params"]
for name, layer in model.named_modules():
scale_key = f"{name}.output_scale"
if scale_key not in quant_params:
continue
layer_scale = 1 / quant_params[scale_key]
hpvm_name = LAYER_HPVM_NAME[type(layer)]
layer_idx = layer_count[type(layer)]
layer_count[type(layer)] += 1
lines.append(f"{hpvm_name}{layer_idx + 1}:\t{layer_scale}\n")
if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):
# include 'add' scale
add_hpvm_name = LAYER_HPVM_NAME[nn.Parameter]
add_idx = layer_count[nn.Parameter]
layer_count[nn.Parameter] += 1
lines.append(f"{add_hpvm_name}{add_idx + 1}:\t{layer_scale}\n")
with open(output_file, "w") as f:
f.writelines(lines)
print(f"Calibration file generated to {output_file}")
@torch.no_grad()
def get_loss(model: nn.Module, dataloader: DataLoader):
from torch.nn import functional as F
# Turn on evaluation mode which disables dropout.
model.eval()
total_loss = 0
for batch in dataloader:
data, targets = batch
output = model(data)
total_loss += len(data) * F.cross_entropy(output, targets)
    # total_loss is a batch-size-weighted sum, so normalize by the dataset size
    return total_loss / len(dataloader.dataset)
@torch.no_grad()
def evaluate(model: nn.Module, dataloader: DataLoader):
model.eval()
correct = 0
total = 0
for data in dataloader:
images, labels = data[0], data[1]
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
return 100 * correct / total
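

# A minimal usage sketch (not part of the original module). The checkpoint path,
# dataset, and working directory below are placeholders; a CIFAR-10-shaped input
# is assumed, matching the 1x3x32x32 dummy input used in `quantize`.
if __name__ == "__main__":
    from torchvision import datasets, transforms

    model = torch.load("model.pth.tar")  # hypothetical pretrained model checkpoint
    dataset = datasets.CIFAR10(
        "data/", train=False, download=True, transform=transforms.ToTensor()
    )
    calib_file = quantize(model, dataset, strat="NONE", working_dir="tuning")
    print(f"Wrote calibration file to {calib_file}")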