Skip to content
Snippets Groups Projects
Unverified Commit 9405679f authored by Guy Jacob's avatar Guy Jacob Committed by GitHub
Browse files

Activation Histograms (#254)

Added a collector for activation histograms (sub-class of
ActivationStatsCollector). It is stats-based, meaning it requires
pre-computed min/max stats per tensor. This is done in order to prevent
the need to save all of the activation tensors throughout the run.
The stats are expected in the format generated by
QuantCalibrationStatsCollector.

Details:

* Implemented ActivationHistogramsCollector
* Added Jupyter notebook showcasing activation histograms
* Implemented helper function that performs the stats collection pass
  and histograms pass in one go
* Also added separate helper function just for quantization stats
  collection
* Integrated in image classification sample
* data_loaders.py: Added option to have a fixed subset throughout
  within the same session. Using it to keep the same subset between
  the stats collection and histograms collection phases.
* Other changes:
  * Calling assign_layer_fq_names in base-class of collectors. We do
    this since the collectors, as implemented so far, assume this is
    done. So makes sense to just do it in the base class instead of
    expecting the user to do it.
  * Enforcing a non-parallel model for quantization stats and
    histograms collectors
  * Jupyter notebooks - add utility function to enable loggers in
    notebooks. This allows us to see any logging done by Distiller
    APIs called from notebooks.
parent f1f0d753
No related branches found
No related tags found
No related merge requests found
......@@ -19,7 +19,6 @@
This code will help with the image classification datasets: ImageNet and CIFAR10
"""
import logging
import os
import torch
import torchvision.transforms as transforms
......@@ -29,14 +28,12 @@ import numpy as np
import distiller
msglogger = logging.getLogger()
DATASETS_NAMES = ['imagenet', 'cifar10']
def load_data(dataset, data_dir, batch_size, workers, validation_split=0.1, deterministic=False,
effective_train_size=1., effective_valid_size=1., effective_test_size=1.):
effective_train_size=1., effective_valid_size=1., effective_test_size=1.,
fixed_subset=False):
"""Load a dataset.
Args:
......@@ -50,13 +47,16 @@ def load_data(dataset, data_dir, batch_size, workers, validation_split=0.1, dete
effective_train/valid/test_size: portion of the datasets to load on each epoch.
The subset is chosen randomly each time. For the training and validation sets, this is applied AFTER
the split to those sets according to the validation_split parameter
fixed_subset: set to True to keep the same subset of data throughout the run (the size of the subset
is still determined according to the effective_train/valid/test_size args)
"""
if dataset not in DATASETS_NAMES:
raise ValueError('load_data does not support dataset %s" % dataset')
datasets_fn = cifar10_get_datasets if dataset == 'cifar10' else imagenet_get_datasets
return get_data_loaders(datasets_fn, data_dir, batch_size, workers, validation_split=validation_split,
deterministic=deterministic, effective_train_size=effective_train_size,
effective_valid_size=effective_valid_size, effective_test_size=effective_test_size)
effective_valid_size=effective_valid_size, effective_test_size=effective_test_size,
fixed_subset=fixed_subset)
def cifar10_get_datasets(data_dir):
......@@ -155,11 +155,9 @@ class SwitchingSubsetRandomSampler(Sampler):
data_source (Dataset): dataset to sample from
subset_size (float): value in (0..1], representing the portion of dataset to sample at each enumeration.
"""
def __init__(self, data_source, subset_size):
if subset_size <= 0 or subset_size > 1:
raise ValueError('subset_size must be in (0..1]')
def __init__(self, data_source, effective_size):
self.data_source = data_source
self.subset_length = int(np.floor(len(self.data_source) * subset_size))
self.subset_length = _get_subset_length(data_source, effective_size)
def __iter__(self):
# Randomizing in the same way as in torch.utils.data.sampler.SubsetRandomSampler to maintain
......@@ -172,8 +170,23 @@ class SwitchingSubsetRandomSampler(Sampler):
return self.subset_length
def _get_subset_length(data_source, effective_size):
if effective_size <= 0 or effective_size > 1:
raise ValueError('effective_size must be in (0..1]')
return int(np.floor(len(data_source) * effective_size))
def _get_sampler(data_source, effective_size, fixed_subset=False):
if fixed_subset:
subset_length = _get_subset_length(data_source, effective_size)
indices = torch.randperm(len(data_source))
subset_indices = indices[:subset_length]
return torch.utils.data.SubsetRandomSampler(subset_indices)
return SwitchingSubsetRandomSampler(data_source, effective_size)
def get_data_loaders(datasets_fn, data_dir, batch_size, num_workers, validation_split=0.1, deterministic=False,
effective_train_size=1., effective_valid_size=1., effective_test_size=1.):
effective_train_size=1., effective_valid_size=1., effective_test_size=1., fixed_subset=False):
train_dataset, test_dataset = datasets_fn(data_dir)
worker_init_fn = None
......@@ -192,7 +205,7 @@ def get_data_loaders(datasets_fn, data_dir, batch_size, num_workers, validation_
valid_indices, train_indices = __split_list(indices, validation_split)
train_sampler = SwitchingSubsetRandomSampler(train_indices, effective_train_size)
train_sampler = _get_sampler(train_indices, effective_train_size, fixed_subset)
train_loader = torch.utils.data.DataLoader(train_dataset,
batch_size=batch_size, sampler=train_sampler,
num_workers=num_workers, pin_memory=True,
......@@ -200,14 +213,14 @@ def get_data_loaders(datasets_fn, data_dir, batch_size, num_workers, validation_
valid_loader = None
if valid_indices:
valid_sampler = SwitchingSubsetRandomSampler(valid_indices, effective_valid_size)
valid_sampler = _get_sampler(valid_indices, effective_valid_size, fixed_subset)
valid_loader = torch.utils.data.DataLoader(train_dataset,
batch_size=batch_size, sampler=valid_sampler,
num_workers=num_workers, pin_memory=True,
worker_init_fn=worker_init_fn)
test_indices = list(range(len(test_dataset)))
test_sampler = SwitchingSubsetRandomSampler(test_indices, effective_test_size)
test_sampler = _get_sampler(test_indices, effective_test_size, fixed_subset)
test_loader = torch.utils.data.DataLoader(test_dataset,
batch_size=batch_size, sampler=test_sampler,
num_workers=num_workers, pin_memory=True)
......
......@@ -14,7 +14,8 @@
# limitations under the License.
#
from functools import partial
from functools import partial, reduce
import operator
import xlsxwriter
import yaml
import os
......@@ -25,11 +26,14 @@ import torch
from torchnet.meter import AverageValueMeter
import logging
from math import sqrt
import matplotlib.pyplot as plt
import distiller
msglogger = logging.getLogger()
__all__ = ['SummaryActivationStatsCollector', 'RecordsActivationStatsCollector',
'QuantCalibrationStatsCollector', 'collector_context', 'collectors_context']
'QuantCalibrationStatsCollector', 'ActivationHistogramsCollector',
'collect_quant_stats', 'collect_histograms',
'collector_context', 'collectors_context']
class ActivationStatsCollector(object):
......@@ -47,13 +51,7 @@ class ActivationStatsCollector(object):
ActivationStatsCollector uses the forward hook of modules in order to access the
feature-maps. This is both slow and limits us to seeing only the outputs of torch.Modules.
We can remove some of the slowness, by choosing to log only specific layers or use it only
during validation or test. By default, we only log torch.nn.ReLU activations.
The layer names are mangled, because torch.Modules don't have names and we need to invent
a unique name per layer. To assign human-readable names, it is advisable to invoke the following
before starting the statistics collection:
distiller.utils.assign_layer_fq_names(model)
during validation or test. This can be achieved using the `classes` argument.
"""
def __init__(self, model, stat_name, classes):
"""
......@@ -72,6 +70,10 @@ class ActivationStatsCollector(object):
self.classes = classes
self.fwd_hook_handles = []
# The layer names are mangled, because torch.Modules don't have names and we need to invent
# a unique, human-readable name per layer.
distiller.utils.assign_layer_fq_names(model)
def value(self):
"""Return a dictionary containing {layer_name: statistic}"""
activation_stats = OrderedDict()
......@@ -316,6 +318,13 @@ class _QuantStatsRecord(object):
self.output = self.create_records_dict()
def _verify_no_dataparallel(model):
if torch.nn.DataParallel in [type(m) for m in model.modules()]:
raise ValueError('Model contains DataParallel modules, which can cause inaccurate stats collection. '
'Either create a model without DataParallel modules, or call '
'distiller.utils.make_non_parallel_copy on the model before invoking the collector')
class QuantCalibrationStatsCollector(ActivationStatsCollector):
"""
This class tracks activations stats required for quantization, for each layer and for each input
......@@ -324,7 +333,25 @@ class QuantCalibrationStatsCollector(ActivationStatsCollector):
* Average min / max (calculate min / max per sample and average those)
* Overall mean
* Overall standard-deviation
Calculated stats are saved to a YAML file.
The generated stats dict has the following structure per-layer:
'layer_name':
'inputs':
0:
'min': value
'max': value
...
...
n:
'min': value
'max': value
...
'output':
'min': value
'max': value
...
Where n is the number of inputs the layer has.
The calculated stats can be saved to a YAML file.
If a certain layer operates in-place, that layer's input stats will be overwritten by its output stats.
The collector can, optionally, check for such cases at runtime. In addition, a simple mechanism to disable inplace
......@@ -350,6 +377,9 @@ class QuantCalibrationStatsCollector(ActivationStatsCollector):
def __init__(self, model, classes=None, inplace_runtime_check=False,
disable_inplace_attrs=False, inplace_attr_names=('inplace',)):
super(QuantCalibrationStatsCollector, self).__init__(model, "quant_stats", classes)
_verify_no_dataparallel(model)
self.batch_idx = 0
self.inplace_runtime_check = inplace_runtime_check
......@@ -456,6 +486,251 @@ class QuantCalibrationStatsCollector(ActivationStatsCollector):
return fname
class ActivationHistogramsCollector(ActivationStatsCollector):
"""
This class collects activation histograms, for each layer and for each input and output tensor.
It requires pre-computed min/max stats per tensor. This is done in order to prevent the need to save
all of the activation tensors throughout the run. The histogram is created once according to these
min/max values, and updated after each iteration. Any value outside the pre-computed range is clamped.
The generated stats dict has the following structure per-layer:
'layer_name':
'inputs':
0:
'hist': tensor # Tensor with bin counts
'bin_centroids': tensor # Tensor with activation values corresponding to center of each bin
...
n:
'hist': tensor
'bin_centroids': tensor
'output':
'hist': tensor
'bin_centroids': tensor
Where n is the number of inputs the layer has.
The generated stats dictionary can be saved to a file.
Optionally, histogram images for all tensor can be saved as well
Args:
model (torch.nn.Module): The model we are monitoring
activation_stats (str / dict): Either a path to activation stats YAML file, or a dictionary containing
the stats. The stats are expected to be in the same structure as generated by QuantCalibrationStatsCollector.
classes (list): List of class types for which we collect activation statistics. Passing an empty list or
None will collect statistics for all class types.
nbins (int): Number of histogram bins
save_hist_imgs (bool): If set, calling save() will dump images of the histogram plots in addition to saving the
stats dictionary
hist_imgs_ext (str): The file type to be used when saving histogram images
"""
def __init__(self, model, activation_stats, classes=None, nbins=2048,
save_hist_imgs=False, hist_imgs_ext='.svg'):
super(ActivationHistogramsCollector, self).__init__(model, 'histogram', classes)
_verify_no_dataparallel(model)
if isinstance(activation_stats, str):
if not os.path.isfile(activation_stats):
raise ValueError("Model activation stats file not found at: " + activation_stats)
msglogger.info('Loading activation stats from: ' + activation_stats)
with open(activation_stats, 'r') as stream:
activation_stats = distiller.utils.yaml_ordered_load(stream)
elif not isinstance(activation_stats, (dict, OrderedDict)):
raise TypeError('model_activation_stats must either be a string, a dict / OrderedDict or None')
self.act_stats = activation_stats
self.nbins = nbins
self.save_imgs = save_hist_imgs
self.imgs_ext = hist_imgs_ext if hist_imgs_ext[0] == '.' else '.' + hist_imgs_ext
def _get_min_max(self, *keys):
stats_entry = reduce(operator.getitem, keys, self.act_stats)
return stats_entry['min'], stats_entry['max']
def _activation_stats_cb(self, module, inputs, output):
def get_hist(t, stat_min, stat_max):
t_clamped = t.clamp(stat_min, stat_max)
hist = torch.histc(t_clamped.cpu(), bins=self.nbins, min=stat_min, max=stat_max)
return hist
for idx, input in enumerate(inputs):
stat_min, stat_max = self._get_min_max(module.distiller_name, 'inputs', idx)
curr_hist = get_hist(input, stat_min, stat_max)
module.input_hists[idx] += curr_hist
stat_min, stat_max = self._get_min_max(module.distiller_name, 'output')
curr_hist = get_hist(output, stat_min, stat_max)
module.output_hist += curr_hist
def _reset(self, module):
num_inputs = len(self.act_stats[module.distiller_name]['inputs'])
module.input_hists = module.input_hists = [torch.zeros(self.nbins) for _ in range(num_inputs)]
module.output_hist = torch.zeros(self.nbins)
def _start_counter(self, module):
self._reset(module)
def _reset_counter(self, module):
if hasattr(module, 'output_hist'):
self._reset(module)
def _collect_activations_stats(self, module, activation_stats, name=''):
if distiller.utils.has_children(module):
return
if not hasattr(module, 'output_hist'):
return
def get_hist_entry(min_val, max_val, hist):
od = OrderedDict()
od['hist'] = hist
bin_width = (max_val - min_val) / self.nbins
od['bin_centroids'] = torch.linspace(min_val + bin_width / 2, max_val - bin_width / 2, self.nbins)
return od
stats_od = OrderedDict()
inputs_od = OrderedDict()
for idx, hist in enumerate(module.input_hists):
inputs_od[idx] = get_hist_entry(*self._get_min_max(module.distiller_name, 'inputs', idx),
module.input_hists[idx])
output_od = get_hist_entry(*self._get_min_max(module.distiller_name, 'output'), module.output_hist)
stats_od['inputs'] = inputs_od
stats_od['output'] = output_od
activation_stats[module.distiller_name] = stats_od
def save(self, fname):
hist_dict = self.value()
if not fname.endswith('.pt'):
fname = ".".join([fname, 'pt'])
try:
os.remove(fname)
except OSError:
pass
torch.save(hist_dict, fname)
if self.save_imgs:
msglogger.info('Saving histogram images...')
save_dir = os.path.join(os.path.split(fname)[0], 'histogram_imgs')
if not os.path.isdir(save_dir):
os.mkdir(save_dir)
def save_hist(layer_name, tensor_name, idx, bin_counts, bin_centroids, normed=True):
if normed:
bin_counts = bin_counts / bin_counts.sum()
plt.figure(figsize=(12, 12))
plt.suptitle('\n'.join((layer_name, tensor_name)), fontsize=18, fontweight='bold')
for subplt_idx, yscale in enumerate(['linear', 'log']):
plt.subplot(2, 1, subplt_idx + 1)
plt.fill_between(bin_centroids, bin_counts, step='mid', antialiased=False)
if yscale == 'linear':
plt.ylim(bottom=0)
plt.title(yscale + ' scale')
plt.yscale(yscale)
plt.xlabel('Activation Value')
plt.ylabel('Normalized Count')
plt.tight_layout(rect=[0, 0, 1, 0.93])
idx_str = '{:03d}'.format(idx)
plt.savefig(os.path.join(save_dir, '-'.join((idx_str, layer_name, tensor_name)) + self.imgs_ext))
plt.close()
cnt = 0
for layer_name, data in hist_dict.items():
for idx, od in data['inputs'].items():
cnt += 1
save_hist(layer_name, 'input_{}'.format(idx), cnt, od['hist'], od['bin_centroids'], normed=True)
od = data['output']
cnt += 1
save_hist(layer_name, 'output', cnt, od['hist'], od['bin_centroids'], normed=True)
msglogger.info('Done')
return fname
def collect_quant_stats(model, test_fn, save_dir=None, classes=None, inplace_runtime_check=False,
disable_inplace_attrs=False, inplace_attr_names=('inplace',)):
"""
Helper function for collecting quantization calibration statistics for a model using QuantCalibrationStatsCollector
Args:
model (nn.Module): The model for which to collect stats
test_fn (function): Test/Evaluation function for the model. It must have an argument named 'model' that
accepts the model. All other arguments should be set in advance (can be done using functools.partial), or
they will be left with their default values.
save_dir (str): Path to directory where stats YAML file will be saved. If None then YAML will not be saved
to disk.
classes (iterable): See QuantCalibrationStatsCollector
inplace_runtime_check (bool): See QuantCalibrationStatsCollector
disable_inplace_attrs (bool): See QuantCalibrationStatsCollector
inplace_attr_names (iterable): See QuantCalibrationStatsCollector
Returns:
Dictionary with quantization stats (see QuantCalibrationStatsCollector for a description of the dictionary
contents)
"""
msglogger.info('Collecting quantization calibration stats for model')
quant_stats_collector = QuantCalibrationStatsCollector(model, classes=classes,
inplace_runtime_check=inplace_runtime_check,
disable_inplace_attrs=disable_inplace_attrs,
inplace_attr_names=inplace_attr_names)
with collector_context(quant_stats_collector):
test_fn(model=model)
msglogger.info('Stats collection complete')
if save_dir is not None:
save_path = os.path.join(save_dir, 'acts_quantization_stats.yaml')
quant_stats_collector.save(save_path)
msglogger.info('Stats saved to ' + save_path)
return quant_stats_collector.value()
def collect_histograms(model, test_fn, save_dir=None, activation_stats=None,
classes=None, nbins=2048, save_hist_imgs=False, hist_imgs_ext='.svg'):
"""
Helper function for collecting activation histograms for a model using ActivationsHistogramCollector.
Will perform 2 passes - one to collect the required stats and another to collect the histograms. The first
pass can be skipped by passing pre-calculated stats.
Args:
model (nn.Module): The model for which to collect histograms
test_fn (function): Test/Evaluation function for the model. It must have an argument named 'model' that
accepts the model. All other arguments should be set in advance (can be done using functools.partial), or
they will be left with their default values.
save_dir (str): Path to directory where histograms will be saved. If None then data will not be saved to disk.
activation_stats (str / dict / None): Either a path to activation stats YAML file, or a dictionary containing
the stats. The stats are expected to be in the same structure as generated by QuantCalibrationStatsCollector.
If None, then a stats collection pass will be performed.
classes: See ActivationsHistogramCollector
nbins: See ActivationsHistogramCollector
save_hist_imgs: See ActivationsHistogramCollector
hist_imgs_ext: See ActivationsHistogramCollector
Returns:
Dictionary with histograms data (See ActivationsHistogramCollector for a description of the dictionary
contents)
"""
msglogger.info('Pass 1: Stats collection')
if activation_stats is not None:
msglogger.info('Pre-computed activation stats passed, skipping stats collection')
else:
activation_stats = collect_quant_stats(model, test_fn, save_dir=save_dir, classes=classes,
inplace_runtime_check=True, disable_inplace_attrs=True)
msglogger.info('Pass 2: Histograms generation')
histogram_collector = ActivationHistogramsCollector(model, activation_stats, classes=classes, nbins=nbins,
save_hist_imgs=save_hist_imgs, hist_imgs_ext=hist_imgs_ext)
with collector_context(histogram_collector):
test_fn(model=model)
msglogger.info('Histograms generation complete')
if save_dir is not None:
save_path = os.path.join(save_dir, 'acts_histograms.pt')
histogram_collector.save(save_path)
msglogger.info("Histogram data saved to " + save_path)
if save_hist_imgs:
msglogger.info('Histogram images saved in ' + os.path.join(save_dir, 'histogram_imgs'))
return histogram_collector.value()
@contextmanager
def collector_context(collector):
"""A context manager for an activation collector"""
......
......@@ -198,25 +198,18 @@ def main():
if args.summary:
return summarize_model(model, args.dataset, which_summary=args.summary)
activations_collectors = create_activation_stats_collectors(model, *args.activation_stats)
if args.qe_calibration:
msglogger.info('Quantization calibration stats collection enabled:')
msglogger.info('\tStats will be collected for {:.1%} of test dataset'.format(args.qe_calibration))
msglogger.info('\tSetting constant seeds and converting model to serialized execution')
distiller.set_deterministic()
model = distiller.make_non_parallel_copy(model)
activations_collectors.update(create_quantization_stats_collector(model))
args.evaluate = True
args.effective_test_size = args.qe_calibration
return acts_quant_stats_collection(model, criterion, pylogger, args)
if args.activation_histograms:
return acts_histogram_collection(model, criterion, pylogger, args)
activations_collectors = create_activation_stats_collectors(model, *args.activation_stats)
# Load the datasets: the dataset to load is inferred from the model name passed
# in args.arch. The default dataset is ImageNet, but if args.arch contains the
# substring "_cifar", then cifar10 is used.
train_loader, val_loader, test_loader, _ = apputils.load_data(
args.dataset, os.path.expanduser(args.data), args.batch_size,
args.workers, args.validation_split, args.deterministic,
args.effective_train_size, args.effective_valid_size, args.effective_test_size)
train_loader, val_loader, test_loader, _ = load_data(args)
msglogger.info('Dataset sizes:\n\ttraining=%d\n\tvalidation=%d\n\ttest=%d',
len(train_loader.sampler), len(val_loader.sampler), len(test_loader.sampler))
......@@ -678,10 +671,7 @@ def sensitivity_analysis(model, criterion, data_loader, loggers, args, sparsitie
def automated_deep_compression(model, criterion, optimizer, loggers, args):
train_loader, val_loader, test_loader, _ = apputils.load_data(
args.dataset, os.path.expanduser(args.data), args.batch_size,
args.workers, args.validation_split, args.deterministic,
args.effective_train_size, args.effective_valid_size, args.effective_test_size)
train_loader, val_loader, test_loader, _ = load_data(args)
args.display_confusion = True
validate_fn = partial(test, test_loader=test_loader, criterion=criterion,
......@@ -695,10 +685,7 @@ def automated_deep_compression(model, criterion, optimizer, loggers, args):
def greedy(model, criterion, optimizer, loggers, args):
train_loader, val_loader, test_loader, _ = apputils.load_data(
args.dataset, os.path.expanduser(args.data), args.batch_size,
args.workers, args.validation_split, args.deterministic,
args.effective_train_size, args.effective_valid_size, args.effective_test_size)
train_loader, val_loader, test_loader, _ = load_data(args)
test_fn = partial(test, test_loader=test_loader, criterion=criterion,
loggers=loggers, args=args, activations_collectors=None)
......@@ -710,6 +697,37 @@ def greedy(model, criterion, optimizer, loggers, args):
test_fn, train_fn)
def acts_quant_stats_collection(model, criterion, loggers, args):
msglogger.info('Collecting quantization calibration stats based on {:.1%} of test dataset'
.format(args.qe_calibration))
model = distiller.utils.make_non_parallel_copy(model)
args.effective_test_size = args.qe_calibration
train_loader, val_loader, test_loader, _ = load_data(args)
test_fn = partial(test, test_loader=test_loader, criterion=criterion,
loggers=loggers, args=args, activations_collectors=None)
collect_quant_stats(model, test_fn, save_dir=msglogger.logdir,
classes=None, inplace_runtime_check=True, disable_inplace_attrs=True)
def acts_histogram_collection(model, criterion, loggers, args):
msglogger.info('Collecting activation histograms based on {:.1%} of test dataset'
.format(args.activation_histograms))
model = distiller.utils.make_non_parallel_copy(model)
args.effective_test_size = args.activation_histograms
train_loader, val_loader, test_loader, _ = load_data(args, fixed_subset=True)
test_fn = partial(test, test_loader=test_loader, criterion=criterion,
loggers=loggers, args=args, activations_collectors=None)
collect_histograms(model, test_fn, save_dir=msglogger.logdir,
classes=None, nbins=2048, save_hist_imgs=True)
def load_data(args, fixed_subset=False):
return apputils.load_data(args.dataset, os.path.expanduser(args.data), args.batch_size,
args.workers, args.validation_split, args.deterministic,
args.effective_train_size, args.effective_valid_size, args.effective_test_size,
fixed_subset)
class missingdict(dict):
"""This is a little trick to prevent KeyError"""
def __missing__(self, key):
......@@ -729,8 +747,6 @@ def create_activation_stats_collectors(model, *phases):
WARNING! Enabling activation statsitics collection will significantly slow down training!
"""
distiller.utils.assign_layer_fq_names(model)
genCollectors = lambda: missingdict({
"sparsity": SummaryActivationStatsCollector(model, "sparsity",
lambda t: 100 * distiller.utils.sparsity(t)),
......@@ -747,13 +763,6 @@ def create_activation_stats_collectors(model, *phases):
for k in ('train', 'valid', 'test')}
def create_quantization_stats_collector(model):
distiller.utils.assign_layer_fq_names(model)
return {'test': missingdict({'quantization_stats': QuantCalibrationStatsCollector(model, classes=None,
inplace_runtime_check=True,
disable_inplace_attrs=True)})}
def save_collectors_data(collectors, directory):
"""Utility function that saves all activation statistics to disk.
......
......@@ -73,6 +73,11 @@ def get_parser():
parser.add_argument('--activation-stats', '--act-stats', nargs='+', metavar='PHASE', default=list(),
help='collect activation statistics on phases: train, valid, and/or test'
' (WARNING: this slows down training)')
parser.add_argument('--activation-histograms', '--act-hist',
type=distiller.utils.float_range_argparse_checker(exc_min=True),
metavar='PORTION_OF_TEST_SET',
help='Run the model in evaluation mode on the specified portion of the test dataset and '
'generate activation histograms. NOTE: This slows down evaluation significantly')
parser.add_argument('--masks-sparsity', dest='masks_sparsity', action='store_true', default=False,
help='print masks sparsity table at end of each epoch')
parser.add_argument('--param-hist', dest='log_params_histograms', action='store_true', default=False,
......
%% Cell type:markdown id: tags:
# Activation Histograms
This notebook shows an example of how to generate activation histograms for a specific model and dataset.
## But I Already Know How To Generate Histograms...
If you already generated histograms using Distiller outside this notebook, you can still use it to visualize the data:
* To load the raw data saved by Distiller and visualize it, go to [this section](#Plot-Histograms)
* If enabled saving histogram images and want to view them, go to [this section](#Load-Histogram-Images-from-Disk)
%% Cell type:code id: tags:
``` python
import torch
import matplotlib.pyplot as plt
import os
import math
import torchnet as tnt
from ipywidgets import widgets, interact
import distiller
from distiller.models import create_model
device = torch.device('cuda')
# device = torch.device('cpu')
# Load some common code and configure logging
# We do this so we can see the logging output coming from
# Distiller function calls
%run './distiller_jupyter_helpers.ipynb'
msglogger = config_notebooks_logger()
```
%% Cell type:markdown id: tags:
## Load Your Model
For this example we'll use a pre-trained image classification model.
### Note on Parallelism
Currently, Distiller's implementation of activations histograms collection does not accept models which contain [`DataParallel`](https://pytorch.org/docs/stable/nn.html?highlight=dataparallel#torch.nn.DataParallel) modules. So here we create the model without parallelism to begin with. If you have a model which includes `DataParallel` modules (for example, if loaded from a checkpoint), use the following utlity function to convert the model to serialized execution:
```python
model = distiller.utils.make_non_parallel_copy(model)
```
%% Cell type:code id: tags:
``` python
model = create_model(pretrained=True, dataset='imagenet', arch='resnet18', parallel=False)
model = model.to(device) # Comment out if not applicable
```
%% Cell type:markdown id: tags:
## Prepare Data
Usually it is not required to collect histograms based on the entire dataset, and only a representative subset is used (that also helps reduce the runtime).
* **Subset size:** There is no golden rule for selecting the size of the subset. Anywhere between 1-10% of the validation/test set should work.
* **Representative data:** Whatever size is chosen, it is important to make sure that the subset is selected in a way that covers as much of the distribution of the data as possible. So, for example, if the dataset is organized by classes by default, we should make sure to select items randomly and not in order.
**Note:** Working on only a subset of the data can be taken care of at data preparation time, or it can be delayed to the actual model evaluation function (for example, executing only a specific number of mini-batches). In this example we take care of it during data preparation.
%% Cell type:code id: tags:
``` python
# We use Distiller's built-in data loading functionality for ImageNet,
# which takes care of randomizing the data before selecting the subset.
# While it creates train, validation and test data loaders, we're only
# interested in the test dataset in this example.
#
# Subset size: Here we'll go with 1% of the test set, mostly for the
# sake of speed. We control this with the 'effective_test_size' argument.
#
# We set the 'fixed_subset' argument to make sure we're using the
# same subset for both phases of histogram collection - more on that below
dataset = 'imagenet'
dataset_path = '~/datasets/imagenet'
batch_size = 256
num_workers = 10
subset_size = 0.01
_, _, test_loader, _ = distiller.apputils.load_data(
dataset, os.path.expanduser(dataset_path), batch_size, num_workers,
effective_test_size=subset_size, fixed_subset=True)
```
%% Cell type:markdown id: tags:
## Define the Model Evaluation Function
We define a fairly bare-bones evaluation function. Recording the loss and accuracy isn't strictly necessary for histogram collection. We record them nonetheless, so we can verify the data subset being used achieves results that are on par from what we'd expect from a representative subset.
%% Cell type:code id: tags:
``` python
def eval_model(data_loader, model):
print('Evaluating model')
criterion = torch.nn.CrossEntropyLoss().to(device)
loss = tnt.meter.AverageValueMeter()
classerr = tnt.meter.ClassErrorMeter(accuracy=True, topk=(1, 5))
total_samples = len(data_loader.sampler)
batch_size = data_loader.batch_size
total_steps = math.ceil(total_samples / batch_size)
print('{0} samples ({1} per mini-batch)'.format(total_samples, batch_size))
# Switch to evaluation mode
model.eval()
for step, (inputs, target) in enumerate(data_loader):
print('[{:3d}/{:3d}] ... '.format(step + 1, total_steps), end='', flush=True)
with torch.no_grad():
inputs, target = inputs.to(device), target.to(device)
# compute output from model
output = model(inputs)
# compute loss and measure accuracy
loss.add(criterion(output, target).item())
classerr.add(output.data, target)
print('Top1: {:.3f} Top5: {:.3f} Loss: {:.3f}'.format(
classerr.value(1), classerr.value(5), loss.mean), flush=True)
print('----------')
print('Overall ==> Top1: {:.3f} Top5: {:.3f} Loss: {:.3f}'.format(
classerr.value(1), classerr.value(5), loss.mean), flush=True)
```
%% Cell type:markdown id: tags:
## Collect Histograms
Histogram collection is implemented using Distiller's "Collector" mechanism, specifically in the `ActivationHistogramsCollector` class. It is stats-based, meaning it requires pre-computed min/max values per-tensor to be provided.
The min/max stats are expected as a dictionary with the following structure:
```YAML
'layer_name':
'inputs':
0:
'min': value
'max': value
...
n:
'min': value
'max': value
'output':
'min': value
'max': value
```
Where n is the number of inputs the layer has. The `QuantCalibrationStatsCollector` collector class generates stats in the required format.
To streamline this process, a utility function is provided: `distiller.data_loggers.collect_histograms`. Given a model and a test function, it will perform the required stats collection followed by histograms collection. If the user has already computed min/max stats beforehand, those can provided as a dict or as a path to a YAML file (as saved by `QuantCalibrationStatsCollector`). In that case, the stats collection pass will be skipped.
### Dataset Perparation in Context of Stats-Based Histograms
If the data used for min/max stats collection is not the same as the data used for histogram collection, it is highly likely that when collecting histograms some values will fall outside the pre-calculated min/max range. When that happens, the value is **clamped**. Assuming the subsets of data used in both cases are representative enough, this shouldn't have a major effect on the results.
One can choose to avoid this issue by making sure we use the same subset of data in both passes. How to make sure of that will, of course, differ from one use case to another. In this example we do this by using the enabling `fixed_subset` flag when calling `load_data` above.
%% Cell type:code id: tags:
``` python
# The test function passed to 'collect_histograms' must have an
# argument named 'model' which accepts the model for which histograms
# are to be collected. 'collect_histograms' will not set any other
# arguments.
# We'll use Python's 'partial' to handle the set the rest of the
# arguments for the test function before calling 'collect_histograms'
from functools import partial
test_fn = partial(eval_model, data_loader=test_loader)
# Histogram collection parameters
# 'save_dir': Pass a valid directory path to have the histogram
# data saved to disk. Pass None to disable saving.
# 'save_hist_imgs': If save_dir is not None, toggles whether to save
# histogram images in addition to the raw data
# 'hist_imgs_ext': Controls the filetype for histogram images
save_dir = '.'
save_hist_imgs = True
hist_imgs_ext = '.png'
# 'activation_stats': Here we pass None so a stats collection pass
# is performed.
activation_stats = None
# 'classes': To speed-up the calculation here we use the 'classes'
# argument so that stats and histograms are collected only for
# ReLU layers in the model. Pass None to collect for all layers.
classes = [torch.nn.ReLU]
# 'nbins': Number of histogram bins to use.
nbins = 2048
hist_dict = distiller.data_loggers.collect_histograms(
model, test_fn, save_dir=save_dir, activation_stats=activation_stats,
classes=classes, nbins=nbins, save_hist_imgs=save_hist_imgs, hist_imgs_ext=hist_imgs_ext)
```
%% Cell type:markdown id: tags:
## Plot Histograms
The generated dictionary has the following structure (very similar to the structure of the min/max stats dictionary described above):
```yaml
'layer_name':
'inputs':
0:
'hist': tensor # Tensor with bin counts
'bin_centroids': tensor # Tensor with activation values corresponding to center of each bin
...
n:
'hist': tensor
'bin_centroids': tensor
'output':
'hist': tensor
'bin_centroids': tensor
```
%% Cell type:code id: tags:
``` python
# Uncomment this line to load saved output from a previous histogram collection run
# hist_dict = torch.load('acts_histograms.pt')
plt.style.use('seaborn') # pretty matplotlib plots
def draw_hist(layer_name, tensor_name, bin_counts, bin_centroids, normed=True, yscale='linear'):
if normed:
bin_counts = bin_counts / bin_counts.sum()
plt.figure(figsize=(12, 6))
plt.title('\n'.join((layer_name, tensor_name)), fontsize=16)
plt.fill_between(bin_centroids, bin_counts, step='mid', antialiased=False)
if yscale == 'linear':
plt.ylim(bottom=0)
plt.yscale(yscale)
plt.xlabel('Activation Value')
plt.ylabel('Normalized Count')
plt.show()
@interact(layer_name=hist_dict.keys(),
normalize_bin_counts=True,
y_axis_scale=['linear', 'log'])
def draw_layer(layer_name, normalize_bin_counts, y_axis_scale):
print('\nSelected layer: ' + layer_name)
data = hist_dict[layer_name]
for idx, od in data['inputs'].items():
draw_hist(layer_name, 'input_{}'.format(idx), od['hist'], od['bin_centroids'],
normed=normalize_bin_counts, yscale=y_axis_scale)
od = data['output']
draw_hist(layer_name, 'output', od['hist'], od['bin_centroids'],
normed=normalize_bin_counts, yscale=y_axis_scale)
```
%% Cell type:markdown id: tags:
## Load Histogram Images from Disk
If you enabled saving of histogram images above, or have images from a collection executed externally, you can use the code below to display the images.
%% Cell type:code id: tags:
``` python
from IPython.display import Image, SVG, display
import glob
from collections import OrderedDict
# Set the path to the images directory
imgs_dir = 'histogram_imgs'
files = sorted(glob.glob(os.path.join(imgs_dir, '*.*')))
files = [f for f in files if os.path.isfile(f)]
fnames_map = OrderedDict([(os.path.split(f)[1], f) for f in files])
@interact(file_name=fnames_map)
def load_image(file_name):
if file_name.endswith('.svg'):
display(SVG(filename=file_name))
else:
display(Image(filename=file_name))
```
%% Cell type:markdown id: tags:
## Interpreting your pruning and regularization experiments
This notebook contains code to be included in your own notebooks by adding this line at the top of your notebook:<br>
```%run distiller_jupyter_helpers.ipynb```
%% Cell type:code id: tags:
``` python
# Relative import of code from distiller, w/o installing the package
import os
import sys
import distiller.utils
import distiller
import distiller.apputils.checkpoint
```
%% Cell type:code id: tags:
``` python
import torch
import torchvision
import os
import collections
import matplotlib.pyplot as plt
import numpy as np
def to_np(x):
return x.cpu().numpy()
def flatten(weights):
weights = weights.clone().view(weights.numel())
weights = to_np(weights)
return weights
import scipy.stats as stats
def plot_params_hist_single(name, weights_pytorch, remove_zeros=False, kmeans=None):
weights = flatten(weights_pytorch)
if remove_zeros:
weights = weights[weights!=0]
n, bins, patches = plt.hist(weights, bins=200)
plt.title(name)
if kmeans is not None:
labels = kmeans.labels_
centroids = kmeans.cluster_centers_
cnt_coefficients = [len(labels[labels==i]) for i in range(16)]
# Normalize the coefficients so they display in the same range as the float32 histogram
cnt_coefficients = [cnt / 5 for cnt in cnt_coefficients]
centroids, cnt_coefficients = zip(*sorted(zip(centroids, cnt_coefficients)))
cnt_coefficients = list(cnt_coefficients)
centroids = list(centroids)
if remove_zeros:
for i in range(len(centroids)):
if abs(centroids[i]) < 0.0001: # almost zero
centroids.remove(centroids[i])
cnt_coefficients.remove(cnt_coefficients[i])
break
plt.plot(centroids, cnt_coefficients)
zeros = [0] * len(centroids)
plt.plot(centroids, zeros, 'r+', markersize=15)
h = cnt_coefficients
hmean = np.mean(h)
hstd = np.std(h)
pdf = stats.norm.pdf(h, hmean, hstd)
#plt.plot(h, pdf)
plt.show()
print("mean: %f\nstddev: %f" % (weights.mean(), weights.std()))
print("size=%s %d elements" % distiller.size2str(weights_pytorch.size()))
print("min: %.3f\nmax:%.3f" % (weights.min(), weights.max()))
def plot_params_hist(params, which='weight', remove_zeros=False):
for name, weights_pytorch in params.items():
if which not in name:
continue
plot_params_hist_single(name, weights_pytorch, remove_zeros)
def plot_params2d(classifier_weights, figsize, binary_mask=True,
gmin=None, gmax=None,
xlabel="", ylabel="", title=""):
if not isinstance(classifier_weights, list):
classifier_weights = [classifier_weights]
for weights in classifier_weights:
assert weights.dim() in [2,4], "something's wrong"
shape_str = distiller.size2str(weights.size())
volume = distiller.volume(weights)
# Clone because we are going to change the tensor values
if binary_mask:
weights2d = weights.clone()
else:
weights2d = weights
if weights.dim() == 4:
weights2d = weights2d.view(weights.size()[0] * weights.size()[1], -1)
sparsity = len(weights2d[weights2d==0]) / volume
# Move to CPU so we can plot it.
if weights2d.is_cuda:
weights2d = weights2d.cpu()
cmap='seismic'
# create a binary image (non-zero elements are black; zeros are white)
if binary_mask:
cmap='binary'
weights2d[weights2d!=0] = 1
fig = plt.figure(figsize=figsize)
if (not binary_mask) and (gmin is not None) and (gmax is not None):
if isinstance(gmin, torch.Tensor):
gmin = gmin.item()
gmax = gmax.item()
plt.imshow(weights2d, cmap=cmap, vmin=gmin, vmax=gmax)
else:
plt.imshow(weights2d, cmap=cmap, vmin=0, vmax=1)
#plt.figure(figsize=(20,40))
plt.xlabel(xlabel)
plt.ylabel(ylabel)
plt.title(title)
plt.colorbar( pad=0.01, fraction=0.01)
plt.show()
print("sparsity = %.1f%% (nnz=black)" % (sparsity*100))
print("size=%s = %d elements" % (shape_str, volume))
def printk(k):
"""Print the values of the elements of a kernel as a list"""
print(list(k.view(k.numel())))
def plot_param_kernels(weights, layout, size_ctrl, binary_mask=False, color_normalization='Model',
gmin=None, gmax=None, interpolation=None, first_kernel=0):
ofms, ifms = weights.size()[0], weights.size()[1]
kw, kh = weights.size()[2], weights.size()[3]
print("min=%.4f\tmax=%.4f" % (weights.min(), weights.max()))
shape_str = distiller.size2str(weights.size())
volume = distiller.volume(weights)
print("size=%s = %d elements" % (shape_str, volume))
# Clone because we are going to change the tensor values
weights = weights.clone()
if binary_mask:
weights[weights!=0] = 1
# Take the inverse of the pixels, because we want zeros to appear white
#weights = 1 - weights
kernels = weights.view(ofms * ifms, kh, kw)
nrow, ncol = layout[0], layout[1]
# Move to CPU so we can plot it.
if kernels.is_cuda:
kernels = kernels.cpu()
# Plot the graph
plt.gray()
#plt.tight_layout()
fig = plt.figure( figsize=(layout[0]*size_ctrl, layout[1]*size_ctrl) );
# We want to normalize the grayscale brightness levels for all of the images we display (group),
# otherwise, each image is normalized separately and this causes distortion between the different
# filters images we ddisplay.
# We don't normalize across all of the filters images, because the outliers cause the image of each
# filter to be very muted. This is because each group of filters we display usually has low variance
# between the element values of that group.
if color_normalization=='Tensor':
gmin = weights.min()
gmax = weights.max()
elif color_normalization=='Group':
gmin = weights[0:nrow, 0:ncol].min()
gmax = weights[0:nrow, 0:ncol].max()
print("gmin=%.4f\tgmax=%.4f" % (gmin, gmax))
if isinstance(gmin, torch.Tensor):
gmin = gmin.item()
gmax = gmax.item()
i = 0
for row in range(0, nrow):
for col in range (0, ncol):
ax = fig.add_subplot(layout[0], layout[1], i+1)
if binary_mask:
ax.matshow(kernels[first_kernel+i], cmap='binary', vmin=0, vmax=1);
else:
# Use siesmic so that colors around the center are lighter. Red and blue are used
# to represent (and visually separate) negative and positive weights
ax.matshow(kernels[first_kernel+i], cmap='seismic', vmin=gmin, vmax=gmax, interpolation=interpolation);
ax.set(xticks=[], yticks=[])
i += 1
def l1_norm_histogram(weights):
"""Compute a histogram of the L1-norms of the kernels of a weights tensor.
The L1-norm of a kernel is one way to quantify the "magnitude" of the total coeffiecients
making up this kernel.
Another interesting look at filters is to compute a histogram per filter.
"""
ofms, ifms = weights.size()[0], weights.size()[1]
kw, kh = weights.size()[2], weights.size()[3]
kernels = weights.view(ofms * ifms, kh, kw)
if kernels.is_cuda:
kernels = kernels.cpu()
l1_hist = []
for kernel in range(ofms*ifms):
l1_hist.append(kernels[kernel].norm(1))
return l1_hist
def plot_l1_norm_hist(weights):
l1_hist = l1_norm_histogram(weights)
n, bins, patches = plt.hist(l1_hist, bins=200)
plt.title('Kernel L1-norm histograms')
plt.ylabel('Frequency')
plt.xlabel('Kernel L1-norm')
plt.show()
def plot_layer_sizes(which, sparse_model, dense_model):
dense = []
sparse = []
names = []
for name, sparse_weights in sparse_model.state_dict().items():
if ('weight' not in name) or (which!='*' and which not in name):
continue
sparse.append(len(sparse_weights[sparse_weights!=0]))
names.append(name)
for name, dense_weights in dense_model.state_dict().items():
if ('weight' not in name) or (which!='*' and which not in name):
continue
dense.append(dense_weights.numel())
N = len(sparse)
ind = np.arange(N) # the x locations for the groups
fig, ax = plt.subplots()
width = .47
p1 = plt.bar(ind, dense, width = .47, color = '#278DBC')
p2 = plt.bar(ind, sparse, width = 0.35, color = '#000099')
plt.ylabel('Size')
plt.title('Layer sizes')
plt.xticks(rotation='vertical')
plt.xticks(ind, names)
#plt.yticks(np.arange(0, 100, 150))
plt.legend((p1[0], p2[0]), ('Dense', 'Sparse'))
#Remove plot borders
for location in ['right', 'left', 'top', 'bottom']:
ax.spines[location].set_visible(False)
#Fix grid to be horizontal lines only and behind the plots
ax.yaxis.grid(color='gray', linestyle='solid')
ax.set_axisbelow(True)
plt.show()
def conv_param_names(model):
return [param_name for param_name, p in model.state_dict().items()
if (p.dim()>2) and ("weight" in param_name)]
def conv_fc_param_names(model):
return [param_name for param_name, p in model.state_dict().items()
if (p.dim()>1) and ("weight" in param_name)]
def conv_fc_params(model):
return [(param_name,p) for (param_name, p) in model.state_dict()
if (p.dim()>1) and ("weight" in param_name)]
def fc_param_names(model):
return [param_name for param_name, p in model.state_dict().items()
if (p.dim()==2) and ("weight" in param_name)]
```
%% Cell type:code id: tags:
``` python
def plot_bars(which, setA, setAName, setB, setBName, names, title):
N = len(setA)
ind = np.arange(N) # the x locations for the groups
fig, ax = plt.subplots(figsize=(20,10))
width = .47
p1 = plt.bar(ind, setA, width = .47, color = '#278DBC')
p2 = plt.bar(ind, setB, width = 0.35, color = '#000099')
plt.ylabel('Size')
plt.title(title)
plt.xticks(rotation='vertical')
plt.xticks(ind, names)
#plt.yticks(np.arange(0, 100, 150))
plt.legend((p1[0], p2[0]), (setAName, setBName))
#Remove plot borders
for location in ['right', 'left', 'top', 'bottom']:
ax.spines[location].set_visible(False)
#Fix grid to be horizontal lines only and behind the plots
ax.yaxis.grid(color='gray', linestyle='solid')
ax.set_axisbelow(True)
plt.show()
```
%% Cell type:code id: tags:
``` python
import logging
def config_notebooks_logger():
logging.config.fileConfig('logging.conf')
msglogger = logging.getLogger()
msglogger.info('Logging configured successfully')
```
......
[formatters]
keys: simple, time_simple
[handlers]
keys: console
[loggers]
keys: root, app_cfg
[formatter_simple]
format: %(message)s
[formatter_time_simple]
format: %(asctime)s - %(message)s
[handler_console]
class: StreamHandler
propagate: 0
args: []
formatter: simple
[logger_root]
level: INFO
propagate: 1
handlers: console
[logger_app_cfg]
# Use this logger to log the application configuration and execution environment
level: DEBUG
qualname: app_cfg
propagate: 0
handlers: console
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment