Commit f764a8aa authored by Neta Zmora

Adjust Jupyter notebooks to interface change in apputils.load_data API

parent 9e787238
%% Cell type:markdown id: tags:
# Activation Histograms
This notebook shows an example of how to generate activation histograms for a specific model and dataset.
## But I Already Know How To Generate Histograms...
If you already generated histograms using Distiller outside this notebook, you can still use it to visualize the data:
* To load the raw data saved by Distiller and visualize it, go to [this section](#Plot-Histograms)
* If you enabled saving histogram images and want to view them, go to [this section](#Load-Histogram-Images-from-Disk)
%% Cell type:code id: tags:
``` python
%matplotlib inline
import torch
import matplotlib.pyplot as plt
import os
import math
import torchnet as tnt
from ipywidgets import widgets, interact
import distiller
from distiller.models import create_model
device = torch.device('cuda')
# device = torch.device('cpu')
# Load some common code and configure logging
# We do this so we can see the logging output coming from
# Distiller function calls
%run './distiller_jupyter_helpers.ipynb'
msglogger = config_notebooks_logger()
```
%% Cell type:markdown id: tags:
## Load Your Model
For this example we'll use a pre-trained image classification model.
### Note on Parallelism
Currently, Distiller's implementation of activation histogram collection does not accept models that contain [`DataParallel`](https://pytorch.org/docs/stable/nn.html?highlight=dataparallel#torch.nn.DataParallel) modules, so here we create the model without parallelism to begin with. If you have a model which includes `DataParallel` modules (for example, one loaded from a checkpoint), use the following utility function to convert it to serial execution:
```python
model = distiller.utils.make_non_parallel_copy(model)
```
%% Cell type:code id: tags:
``` python
model = create_model(pretrained=True, dataset='imagenet', arch='resnet18', parallel=False)
model = model.to(device) # Comment out if not applicable
```
%% Cell type:markdown id: tags:
## Prepare Data
Usually it is not necessary to collect histograms over the entire dataset; a representative subset is used instead (which also helps reduce the runtime).
* **Subset size:** There is no golden rule for selecting the size of the subset. Anywhere between 1-10% of the validation/test set should work.
* **Representative data:** Whatever size is chosen, it is important to make sure that the subset is selected in a way that covers as much of the distribution of the data as possible. So, for example, if the dataset is organized by classes by default, we should make sure to select items randomly and not in order.
**Note:** Working on only a subset of the data can be taken care of at data preparation time, or it can be deferred to the model evaluation function itself (for example, by executing only a fixed number of mini-batches; a sketch of this alternative follows below). In this example we take care of it during data preparation.
%% Cell type:code id: tags:
``` python
# We use Distiller's built-in data loading functionality for ImageNet,
# which takes care of randomizing the data before selecting the subset.
# While it creates train, validation and test data loaders, we're only
# interested in the test dataset in this example.
#
# Subset size: Here we'll go with 1% of the test set, mostly for the
# sake of speed. We control this with the 'effective_test_size' argument.
#
# We set the 'fixed_subset' argument to make sure we're using the
# same subset for both phases of histogram collection - more on that below
dataset = 'imagenet'
dataset_path = '/datasets/imagenet'
arch = 'resnet18'
batch_size = 256
num_workers = 10
subset_size = 0.01
_, _, test_loader, _ = distiller.apputils.load_data(
    dataset, arch, os.path.expanduser(dataset_path),
    batch_size, num_workers,
    effective_test_size=subset_size, fixed_subset=True)
```
%% Cell type:markdown id: tags:
## Define the Model Evaluation Function
We define a fairly bare-bones evaluation function. Recording the loss and accuracy isn't strictly necessary for histogram collection; we record them nonetheless, so we can verify that the data subset being used achieves results on par with what we'd expect from a representative subset.
%% Cell type:code id: tags:
``` python
def eval_model(data_loader, model):
print('Evaluating model')
criterion = torch.nn.CrossEntropyLoss().to(device)
loss = tnt.meter.AverageValueMeter()
classerr = tnt.meter.ClassErrorMeter(accuracy=True, topk=(1, 5))
total_samples = len(data_loader.sampler)
batch_size = data_loader.batch_size
total_steps = math.ceil(total_samples / batch_size)
print('{0} samples ({1} per mini-batch)'.format(total_samples, batch_size))
# Switch to evaluation mode
model.eval()
for step, (inputs, target) in enumerate(data_loader):
print('[{:3d}/{:3d}] ... '.format(step + 1, total_steps), end='', flush=True)
with torch.no_grad():
inputs, target = inputs.to(device), target.to(device)
# compute output from model
output = model(inputs)
# compute loss and measure accuracy
loss.add(criterion(output, target).item())
classerr.add(output.data, target)
print('Top1: {:.3f} Top5: {:.3f} Loss: {:.3f}'.format(
classerr.value(1), classerr.value(5), loss.mean), flush=True)
print('----------')
print('Overall ==> Top1: {:.3f} Top5: {:.3f} Loss: {:.3f}'.format(
classerr.value(1), classerr.value(5), loss.mean), flush=True)
```
%% Cell type:markdown id: tags:
## Collect Histograms
Histogram collection is implemented using Distiller's "Collector" mechanism, specifically in the `ActivationHistogramsCollector` class. It is stats-based, meaning it requires pre-computed min/max values per-tensor to be provided.
The min/max stats are expected as a dictionary with the following structure:
```YAML
'layer_name':
'inputs':
0:
'min': value
'max': value
...
n:
'min': value
'max': value
'output':
'min': value
'max': value
```
Where `n` is the number of inputs the layer has. The `QuantCalibrationStatsCollector` collector class generates stats in the required format.
To streamline this process, a utility function is provided: `distiller.data_loggers.collect_histograms`. Given a model and a test function, it will perform the required stats collection followed by histogram collection. If you have already computed min/max stats beforehand, those can be provided as a dict or as a path to a YAML file (as saved by `QuantCalibrationStatsCollector`). In that case, the stats collection pass is skipped.
### Dataset Preparation in the Context of Stats-Based Histograms
If the data used for min/max stats collection is not the same as the data used for histogram collection, it is highly likely that when collecting histograms some values will fall outside the pre-calculated min/max range. When that happens, the value is **clamped**. Assuming the subsets of data used in both cases are representative enough, this shouldn't have a major effect on the results.
One can avoid this issue by making sure the same subset of data is used in both passes. How to ensure that will, of course, differ from one use case to another. In this example we do it by enabling the `fixed_subset` flag when calling `load_data` above.
%% Cell type:code id: tags:
``` python
# The test function passed to 'collect_histograms' must have an
# argument named 'model' which accepts the model for which histograms
# are to be collected. 'collect_histograms' will not set any other
# arguments.
# We'll use Python's 'partial' to set the rest of the arguments
# for the test function before calling 'collect_histograms'
from functools import partial
test_fn = partial(eval_model, data_loader=test_loader)
# Histogram collection parameters
# 'save_dir': Pass a valid directory path to have the histogram
# data saved to disk. Pass None to disable saving.
# 'save_hist_imgs': If save_dir is not None, toggles whether to save
# histogram images in addition to the raw data
# 'hist_imgs_ext': Controls the filetype for histogram images
save_dir = '.'
save_hist_imgs = True
hist_imgs_ext = '.png'
# 'activation_stats': Here we pass None so a stats collection pass
# is performed.
activation_stats = None
# 'classes': To speed-up the calculation here we use the 'classes'
# argument so that stats and histograms are collected only for
# ReLU layers in the model. Pass None to collect for all layers.
classes = [torch.nn.ReLU]
# 'nbins': Number of histogram bins to use.
nbins = 2048
hist_dict = distiller.data_loggers.collect_histograms(
model, test_fn, save_dir=save_dir, activation_stats=activation_stats,
classes=classes, nbins=nbins, save_hist_imgs=save_hist_imgs, hist_imgs_ext=hist_imgs_ext)
```
%% Cell type:markdown id: tags:
## Plot Histograms
The generated dictionary has the following structure (very similar to the structure of the min/max stats dictionary described above):
```yaml
'layer_name':
'inputs':
0:
'hist': tensor # Tensor with bin counts
'bin_centroids': tensor # Tensor with activation values corresponding to center of each bin
...
n:
'hist': tensor
'bin_centroids': tensor
'output':
'hist': tensor
'bin_centroids': tensor
```
%% Cell type:code id: tags:
``` python
# Uncomment this line to load saved output from a previous histogram collection run
# hist_dict = torch.load('acts_histograms.pt')
plt.style.use('seaborn') # pretty matplotlib plots
def draw_hist(layer_name, tensor_name, bin_counts, bin_centroids, normed=True, yscale='linear'):
if normed:
bin_counts = bin_counts / bin_counts.sum()
plt.figure(figsize=(12, 6))
plt.title('\n'.join((layer_name, tensor_name)), fontsize=16)
plt.fill_between(bin_centroids, bin_counts, step='mid', antialiased=False)
if yscale == 'linear':
plt.ylim(bottom=0)
plt.yscale(yscale)
plt.xlabel('Activation Value')
plt.ylabel('Normalized Count')
plt.show()
@interact(layer_name=hist_dict.keys(),
normalize_bin_counts=True,
y_axis_scale=['linear', 'log'])
def draw_layer(layer_name, normalize_bin_counts, y_axis_scale):
print('\nSelected layer: ' + layer_name)
data = hist_dict[layer_name]
for idx, od in data['inputs'].items():
draw_hist(layer_name, 'input_{}'.format(idx), od['hist'], od['bin_centroids'],
normed=normalize_bin_counts, yscale=y_axis_scale)
od = data['output']
draw_hist(layer_name, 'output', od['hist'], od['bin_centroids'],
normed=normalize_bin_counts, yscale=y_axis_scale)
```
%% Cell type:markdown id: tags:
## Load Histogram Images from Disk
If you enabled saving of histogram images above, or have images from a collection executed externally, you can use the code below to display the images.
%% Cell type:code id: tags:
``` python
from IPython.display import Image, SVG, display
import glob
from collections import OrderedDict
# Set the path to the images directory
imgs_dir = 'histogram_imgs'
files = sorted(glob.glob(os.path.join(imgs_dir, '*.*')))
files = [f for f in files if os.path.isfile(f)]
fnames_map = OrderedDict([(os.path.split(f)[1], f) for f in files])
@interact(file_name=fnames_map)
def load_image(file_name):
if file_name.endswith('.svg'):
display(SVG(filename=file_name))
else:
display(Image(filename=file_name))
```
......
%% Cell type:code id: tags:
``` python
%matplotlib inline
import torch
import torchvision
import torch.nn as nn
from torch.autograd import Variable
# Relative import of code from distiller, w/o installing the package
import os
import sys
import pandas as pd
import distiller
import distiller.models as models
from distiller.apputils import *
```
%% Cell type:markdown id: tags:
## Performance overview
%% Cell type:code id: tags:
``` python
model = models.create_model(pretrained=False, dataset='imagenet', arch='resnet50', parallel=False)
```
%% Cell type:code id: tags:
``` python
dummy_input = Variable(torch.randn(1, 3, 224, 224), requires_grad=False)
df = distiller.model_performance_summary(model, dummy_input, batch_size=1)
display(df)
total_macs = df['MACs'].sum()
print("Total MACs: " + "{:,}".format(total_macs))
```
%% Cell type:markdown id: tags:
### Let's take a look at how our compute is distributed:
%% Cell type:code id: tags:
``` python
print("MAC distribution:")
counts = df['MACs'].value_counts()
print(counts)
```
%% Cell type:markdown id: tags:
### Let's look at which convolution kernel sizes we're using, and how many instances of each:
%% Cell type:code id: tags:
``` python
print("Convolution kernel size distribution:")
counts = df['Attrs'].value_counts()
print(counts)
```
%% Cell type:markdown id: tags:
### Let's look at how the MACs are distributed between the layers and the convolution kernel sizes
%% Cell type:code id: tags:
``` python
def get_layer_color(layer_type, attrs):
if layer_type == "Conv2d":
if attrs == 'k=(1, 1)':
return 'tomato'
elif attrs == 'k=(3, 3)':
return 'limegreen'
else:
return 'steelblue'
return 'indigo'
df_compute = df['MACs']
ax = df_compute.plot.bar(figsize=[15,10], title="MACs",
color=[get_layer_color(layer_type, attrs) for layer_type,attrs in zip(df['Type'], df['Attrs'])])
ax.set_xticklabels(df.Name, rotation=90);
```
%% Cell type:markdown id: tags:
### How do the weights and feature-map footprints distribute across the layers?
%% Cell type:code id: tags:
``` python
df['FM volume'] = df['IFM volume'] + df['OFM volume']
df_footprint = df[['FM volume', 'Weights volume']]
ax = df_footprint.plot.bar(figsize=[15,10], title="Footprint");
ax.set_xticklabels(df.Name, rotation=90);
```
%% Cell type:markdown id: tags:
### How does the arithmetic intensity distribute across the layers?
%% Cell type:code id: tags:
``` python
df_performance = df
df_performance['raw traffic'] = df_footprint['FM volume'] + df_footprint['Weights volume']
df_performance['arithmetic intensity'] = df['MACs'] / df_performance['raw traffic']
df_performance2 = df_performance['arithmetic intensity']
ax = df_performance2.plot.bar(figsize=[15,10], title="Arithmetic Intensity");
ax.set_xticklabels(df.Name, rotation=90);
```
%% Cell type:markdown id: tags:
## ResNet20 channel pruning using SSL
Let's see how many MACs we saved by using SSL to prune filters from ResNet20:
%% Cell type:code id: tags:
``` python
resnet20_dense = models.create_model(pretrained=False, dataset='cifar10', arch='resnet20_cifar', parallel=True)
resnet20_sparse = models.create_model(pretrained=False, dataset='cifar10', arch='resnet20_cifar', parallel=True)
checkpoint_file = "../examples/ssl/checkpoints/checkpoint_trained_channel_regularized_resnet20_finetuned.pth.tar"
load_checkpoint(resnet20_sparse, checkpoint_file);
```
%% Cell type:code id: tags:
``` python
dummy_input = Variable(torch.randn(1, 3, 32, 32), requires_grad=False)
df_dense = distiller.model_performance_summary(resnet20_dense, dummy_input, batch_size=1)
df_sparse = distiller.model_performance_summary(resnet20_sparse, dummy_input, batch_size=1)
dense_macs = df_dense['MACs'].sum()
sparse_macs = df_sparse['MACs'].sum()
print("Dense MACs: " + "{:,}".format(int(dense_macs)))
print("Sparse MACs: " + "{:,}".format(int(sparse_macs)))
print("Saved MACs: %.2f%%" % ((1 - sparse_macs/dense_macs)*100))
```
......
%% Cell type:markdown id: tags:
# Convert Distiller Post-Train Quantization Models to "Native" PyTorch
## Background
As of version 1.3, PyTorch comes with built-in quantization functionality. Details are available [here](https://pytorch.org/docs/stable/quantization.html). Distiller's and PyTorch's implementations are completely unrelated. An advantage of PyTorch's built-in quantization is that it offers optimized 8-bit execution on CPU and export to GLOW. PyTorch does not offer optimized 8-bit execution on GPU (as of version 1.4).
At the moment we are still keeping Distiller's separate API and implementation, but we've added the capability to convert a **post-training quantization** model created in Distiller to a "Distiller-free" model, comprised entirely of PyTorch built-in quantized modules.
Distiller's quantized layers are actually simulated in FP32. Hence, comparing a Distiller model running on CPU to a converted PyTorch built-in model, the latter will be significantly faster on CPU. However, a Distiller model on a GPU is still likely to be faster than a PyTorch model on CPU. So, experimenting with Distiller and converting to PyTorch at the end could be useful. Mileage may vary, of course, depending on the actual HW setup.
Let's see how the conversion works.
%% Cell type:code id: tags:
``` python
import torch
import matplotlib.pyplot as plt
import os
import math
import torchnet as tnt
from ipywidgets import widgets, interact
from copy import deepcopy
from collections import OrderedDict
import distiller
from distiller.models import create_model
import distiller.quantization as quant
# Load some common code and configure logging
# We do this so we can see the logging output coming from
# Distiller function calls
%run './distiller_jupyter_helpers.ipynb'
msglogger = config_notebooks_logger()
```
%% Cell type:markdown id: tags:
## Create Model
%% Cell type:code id: tags:
``` python
# By default, the model is moved to the GPU and parallelized (wrapped with torch.nn.DataParallel)
# If no GPU is available, a non-parallel model is created on the CPU
model = create_model(pretrained=True, dataset='imagenet', arch='resnet18', parallel=True)
```
%% Cell type:markdown id: tags:
## Create Data Loaders
We create separate data loaders for GPU and CPU. Set `batch_size` and `num_workers` to optimal values that match your HW setup.
(Note we reset the seed before creating each data loader, to make sure both loaders consist of the same subset of the test set)
%% Cell type:code id: tags:
``` python
# We use Distiller's built-in data loading functionality for ImageNet
distiller.set_seed(0)
subset_size = 1.0 # To save time, can set to value < 1.0
dataset = 'imagenet'
dataset_path = os.path.expanduser('/datasets/imagenet')
arch = 'resnet18'
batch_size_gpu = 256
num_workers_gpu = 10
_, _, test_loader_gpu, _ = distiller.apputils.load_data(
    dataset, arch, dataset_path,
    batch_size_gpu, num_workers_gpu,
    effective_test_size=subset_size, fixed_subset=True, test_only=True)
```
%% Cell type:code id: tags:
``` python
distiller.set_seed(0)
batch_size_cpu = 44
num_workers_cpu = 10
_, _, test_loader_cpu, _ = distiller.apputils.load_data(
    dataset, arch, dataset_path,
    batch_size_cpu, num_workers_cpu,
    effective_test_size=subset_size, fixed_subset=True, test_only=True)
```
%% Cell type:markdown id: tags:
## Define Evaluation Function
%% Cell type:code id: tags:
``` python
def eval_model(data_loader, model, device, print_freq=10):
print('Evaluating model')
criterion = torch.nn.CrossEntropyLoss().to(device)
loss = tnt.meter.AverageValueMeter()
classerr = tnt.meter.ClassErrorMeter(accuracy=True, topk=(1, 5))
total_samples = len(data_loader.sampler)
batch_size = data_loader.batch_size
total_steps = math.ceil(total_samples / batch_size)
print('{0} samples ({1} per mini-batch)'.format(total_samples, batch_size))
# Switch to evaluation mode
model.eval()
for step, (inputs, target) in enumerate(data_loader):
with torch.no_grad():
inputs, target = inputs.to(device), target.to(device)
# compute output from model
output = model(inputs)
# compute loss and measure accuracy
loss.add(criterion(output, target).item())
classerr.add(output.data, target)
if (step + 1) % print_freq == 0:
print('[{:3d}/{:3d}] Top1: {:.3f} Top5: {:.3f} Loss: {:.3f}'.format(
step + 1, total_steps, classerr.value(1), classerr.value(5), loss.mean), flush=True)
print('----------')
print('Overall ==> Top1: {:.3f} Top5: {:.3f} Loss: {:.3f}'.format(
classerr.value(1), classerr.value(5), loss.mean), flush=True)
```
%% Cell type:markdown id: tags:
## Post-Train Quantize with Distiller
%% Cell type:code id: tags:
``` python
quant_mode = {'activations': 'ASYMMETRIC_UNSIGNED', 'weights': 'SYMMETRIC'}
stats_file = "../examples/quantization/post_train_quant/stats/resnet18_quant_stats.yaml"
dummy_input = distiller.get_dummy_input(input_shape=model.input_shape)
quantizer = quant.PostTrainLinearQuantizer(
deepcopy(model), bits_activations=8, bits_parameters=8, mode=quant_mode,
model_activation_stats=stats_file, overrides=None
)
quantizer.prepare_model(dummy_input)
```
%% Cell type:markdown id: tags:
## Convert to PyTorch Built-In
%% Cell type:code id: tags:
``` python
# Here we trigger the conversion via the Quantizer instance. Later on we show another way which does not
# require the quantizer
pyt_model = quantizer.convert_to_pytorch(dummy_input)
# Note that the converted model is automatically moved to the CPU, regardless
# of the device of the Distiller model
print('Distiller model device:', distiller.model_device(quantizer.model))
print('PyTorch model device:', distiller.model_device(pyt_model))
```
%% Cell type:markdown id: tags:
## Run Evaluation
### Distiller Model on GPU (if available)
%% Cell type:code id: tags:
``` python
if torch.cuda.is_available():
%time eval_model(test_loader_gpu, quantizer.model, 'cuda')
```
%% Cell type:markdown id: tags:
### Distiller Model on CPU
%% Cell type:code id: tags:
``` python
if torch.cuda.is_available():
print('Creating CPU copy of Distiller model')
cpu_model = distiller.make_non_parallel_copy(quantizer.model).cpu()
else:
cpu_model = quantizer.model
%time eval_model(test_loader_cpu, cpu_model, 'cpu', print_freq=60)
```
%% Cell type:markdown id: tags:
### PyTorch Model on CPU
We expect the PyTorch model on CPU to be much faster than the Distiller model on CPU.
%% Cell type:code id: tags:
``` python
%time eval_model(test_loader_cpu, pyt_model, 'cpu', print_freq=60)
```
%% Cell type:markdown id: tags:
## For the Extra-Curious: Comparing the Models
%% Cell type:markdown id: tags:
1. Distiller takes care of quantizing the inputs within its quantized modules, whereas PyTorch quantized modules assume the input is already quantized. Hence, for cases where a module's input is not quantized, we explicitly add a quantization operation for the input. The first layer in the model, `conv1` in ResNet18, is such a case.
2. Both Distiller and native PyTorch support fused ReLU. In Distiller, this is somewhat obscurely indicated by the `clip_half_range` attribute inside `output_quant_settings`. In PyTorch, the module type is explicitly `QuantizedConvReLU2d` (see the sketch after this list for a programmatic check).
%% Cell type:code id: tags:
``` python
print('conv1\n')
print('DISTILLER:\n{}\n'.format(quantizer.model.module.conv1))
print('PyTorch:\n{}\n'.format(pyt_model.conv1))
```
%% Cell type:markdown id: tags:
Example of internal layers which don't require explicit input quantization:
%% Cell type:code id: tags:
``` python
print('layer1.0.conv1')
print(pyt_model.layer1[0].conv1)
print('\nlayer1.0.add')
print(pyt_model.layer1[0].add)
```
%% Cell type:markdown id: tags:
### Automatic de-quantization <--> quantization in the model
For each quantized module in the Distiller implementation, we quantize the input and de-quantize the output.
So, if the user explicitly sets "internal" modules to run in FP32, this is transparent to the other quantized modules (at the cost of redundant quant-dequant operations).
When converting to PyTorch we remove these redundant operations, keeping only those that are still required where the user explicitly decided to run some modules in FP32.
For an example, consider a ResNet "basic block" with a residual connection that contains a downsampling convolution. Let's see how such a block looks in our fully-quantized, converted model:
%% Cell type:code id: tags:
``` python
print(pyt_model.layer2[0])
```
%% Cell type:markdown id: tags:
We can see that all layers are either built-in quantized PyTorch modules, or identity operations representing fused operations. The entire block is quantized, so we don't see any quant-dequant operations in the middle.
Now let's create a new quantized model, and this time leave the 'downsample' module in FP32:
%% Cell type:code id: tags:
``` python
overrides = OrderedDict(
[('layer2.0.downsample.0', OrderedDict([('bits_activations', None), ('bits_weights', None)]))]
)
new_quantizer = quant.PostTrainLinearQuantizer(
deepcopy(model), bits_activations=8, bits_parameters=8, mode=quant_mode,
model_activation_stats=stats_file, overrides=overrides
)
new_quantizer.prepare_model(dummy_input)
new_pyt_model = new_quantizer.convert_to_pytorch(dummy_input)
print(new_pyt_model.layer2[0])
```
%% Cell type:markdown id: tags:
We can see a few differences:
1. The `downsample` module now contains a de-quantize op before the actual convolution
2. The `add` module now contains a quantize op before the actual add. Note that the add operation accepts 2 inputs. In this case the first input (index 0) comes from the `conv2` module, which is quantized. The second input (index 1) comes from the `downsample` module, which we kept in FP32. So, we only need to quantize the input at index 1. We can see this is indeed what happens by looking at the `ModuleDict` inside the `quant` module, and noticing it has only a single key, for index "1".
Let's see how the `add` module would look if we also kept the `conv2` module in FP32:
%% Cell type:code id: tags:
``` python
overrides = OrderedDict(
[('layer2.0.downsample.0', OrderedDict([('bits_activations', None), ('bits_weights', None)])),
('layer2.0.conv2', OrderedDict([('bits_activations', None), ('bits_weights', None)]))]
)
new_quantizer = quant.PostTrainLinearQuantizer(
deepcopy(model), bits_activations=8, bits_parameters=8, mode=quant_mode,
model_activation_stats=stats_file, overrides=overrides
)
new_quantizer.prepare_model(dummy_input)
new_pyt_model = new_quantizer.convert_to_pytorch(dummy_input)
print(new_pyt_model.layer2[0].add)
```
%% Cell type:markdown id: tags:
We can see that now both inputs to the add module are being quantized.
%% Cell type:markdown id: tags:
## Another API for Conversion
In some cases we don't have the actual quantizer, for example when the Distiller quantized model was loaded from a checkpoint. In those cases we can call a `distiller.quantization` module-level function (in fact, the `Quantizer` method we used earlier is a wrapper around this function).
### Save Distiller model to checkpoint
%% Cell type:code id: tags:
``` python
# Save the Distiller model to a checkpoint
distiller.apputils.save_checkpoint(0, 'resnet18', quantizer.model)
```
%% Cell type:markdown id: tags:
### Load Checkpoint
The model is quantized when the checkpoint is loaded
%% Cell type:code id: tags:
``` python
loaded_model = create_model(False, dataset='imagenet', arch='resnet18', parallel=True)
loaded_model = distiller.apputils.load_lean_checkpoint(loaded_model, 'checkpoint.pth.tar')
```
%% Cell type:markdown id: tags:
### Convert and Evaluate
%% Cell type:code id: tags:
``` python
# Convert
loaded_pyt_model = distiller.quantization.convert_distiller_ptq_model_to_pytorch(loaded_model, dummy_input)
# Run evaluation
%time eval_model(test_loader_cpu, loaded_pyt_model, 'cpu', print_freq=60)
# Cleanup
os.remove('checkpoint.pth.tar')
```
......