From 9405679f40d32498d7a6a46bbfb17830b470bc99 Mon Sep 17 00:00:00 2001 From: Guy Jacob <guy.jacob@intel.com> Date: Wed, 15 May 2019 14:06:12 +0300 Subject: [PATCH] Activation Histograms (#254) Added a collector for activation histograms (sub-class of ActivationStatsCollector). It is stats-based, meaning it requires pre-computed min/max stats per tensor. This is done in order to prevent the need to save all of the activation tensors throughout the run. The stats are expected in the format generated by QuantCalibrationStatsCollector. Details: * Implemented ActivationHistogramsCollector * Added Jupyter notebook showcasing activation histograms * Implemented helper function that performs the stats collection pass and histograms pass in one go * Also added separate helper function just for quantization stats collection * Integrated in image classification sample * data_loaders.py: Added option to have a fixed subset throughout within the same session. Using it to keep the same subset between the stats collection and histograms collection phases. * Other changes: * Calling assign_layer_fq_names in base-class of collectors. We do this since the collectors, as implemented so far, assume this is done. So makes sense to just do it in the base class instead of expecting the user to do it. * Enforcing a non-parallel model for quantization stats and histograms collectors * Jupyter notebooks - add utility function to enable loggers in notebooks. This allows us to see any logging done by Distiller APIs called from notebooks. --- distiller/apputils/data_loaders.py | 41 +- distiller/data_loggers/collector.py | 295 +++++++++++++- .../compress_classifier.py | 71 ++-- examples/classifier_compression/parser.py | 5 + jupyter/activation_histograms.ipynb | 367 ++++++++++++++++++ jupyter/distiller_jupyter_helpers.ipynb | 8 +- jupyter/logging.conf | 32 ++ 7 files changed, 763 insertions(+), 56 deletions(-) create mode 100644 jupyter/activation_histograms.ipynb create mode 100755 jupyter/logging.conf diff --git a/distiller/apputils/data_loaders.py b/distiller/apputils/data_loaders.py index a6a667d..3cc1a24 100755 --- a/distiller/apputils/data_loaders.py +++ b/distiller/apputils/data_loaders.py @@ -19,7 +19,6 @@ This code will help with the image classification datasets: ImageNet and CIFAR10 """ -import logging import os import torch import torchvision.transforms as transforms @@ -29,14 +28,12 @@ import numpy as np import distiller - -msglogger = logging.getLogger() - DATASETS_NAMES = ['imagenet', 'cifar10'] def load_data(dataset, data_dir, batch_size, workers, validation_split=0.1, deterministic=False, - effective_train_size=1., effective_valid_size=1., effective_test_size=1.): + effective_train_size=1., effective_valid_size=1., effective_test_size=1., + fixed_subset=False): """Load a dataset. Args: @@ -50,13 +47,16 @@ def load_data(dataset, data_dir, batch_size, workers, validation_split=0.1, dete effective_train/valid/test_size: portion of the datasets to load on each epoch. The subset is chosen randomly each time. For the training and validation sets, this is applied AFTER the split to those sets according to the validation_split parameter + fixed_subset: set to True to keep the same subset of data throughout the run (the size of the subset + is still determined according to the effective_train/valid/test_size args) """ if dataset not in DATASETS_NAMES: raise ValueError('load_data does not support dataset %s" % dataset') datasets_fn = cifar10_get_datasets if dataset == 'cifar10' else imagenet_get_datasets return get_data_loaders(datasets_fn, data_dir, batch_size, workers, validation_split=validation_split, deterministic=deterministic, effective_train_size=effective_train_size, - effective_valid_size=effective_valid_size, effective_test_size=effective_test_size) + effective_valid_size=effective_valid_size, effective_test_size=effective_test_size, + fixed_subset=fixed_subset) def cifar10_get_datasets(data_dir): @@ -155,11 +155,9 @@ class SwitchingSubsetRandomSampler(Sampler): data_source (Dataset): dataset to sample from subset_size (float): value in (0..1], representing the portion of dataset to sample at each enumeration. """ - def __init__(self, data_source, subset_size): - if subset_size <= 0 or subset_size > 1: - raise ValueError('subset_size must be in (0..1]') + def __init__(self, data_source, effective_size): self.data_source = data_source - self.subset_length = int(np.floor(len(self.data_source) * subset_size)) + self.subset_length = _get_subset_length(data_source, effective_size) def __iter__(self): # Randomizing in the same way as in torch.utils.data.sampler.SubsetRandomSampler to maintain @@ -172,8 +170,23 @@ class SwitchingSubsetRandomSampler(Sampler): return self.subset_length +def _get_subset_length(data_source, effective_size): + if effective_size <= 0 or effective_size > 1: + raise ValueError('effective_size must be in (0..1]') + return int(np.floor(len(data_source) * effective_size)) + + +def _get_sampler(data_source, effective_size, fixed_subset=False): + if fixed_subset: + subset_length = _get_subset_length(data_source, effective_size) + indices = torch.randperm(len(data_source)) + subset_indices = indices[:subset_length] + return torch.utils.data.SubsetRandomSampler(subset_indices) + return SwitchingSubsetRandomSampler(data_source, effective_size) + + def get_data_loaders(datasets_fn, data_dir, batch_size, num_workers, validation_split=0.1, deterministic=False, - effective_train_size=1., effective_valid_size=1., effective_test_size=1.): + effective_train_size=1., effective_valid_size=1., effective_test_size=1., fixed_subset=False): train_dataset, test_dataset = datasets_fn(data_dir) worker_init_fn = None @@ -192,7 +205,7 @@ def get_data_loaders(datasets_fn, data_dir, batch_size, num_workers, validation_ valid_indices, train_indices = __split_list(indices, validation_split) - train_sampler = SwitchingSubsetRandomSampler(train_indices, effective_train_size) + train_sampler = _get_sampler(train_indices, effective_train_size, fixed_subset) train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, num_workers=num_workers, pin_memory=True, @@ -200,14 +213,14 @@ def get_data_loaders(datasets_fn, data_dir, batch_size, num_workers, validation_ valid_loader = None if valid_indices: - valid_sampler = SwitchingSubsetRandomSampler(valid_indices, effective_valid_size) + valid_sampler = _get_sampler(valid_indices, effective_valid_size, fixed_subset) valid_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, sampler=valid_sampler, num_workers=num_workers, pin_memory=True, worker_init_fn=worker_init_fn) test_indices = list(range(len(test_dataset))) - test_sampler = SwitchingSubsetRandomSampler(test_indices, effective_test_size) + test_sampler = _get_sampler(test_indices, effective_test_size, fixed_subset) test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, sampler=test_sampler, num_workers=num_workers, pin_memory=True) diff --git a/distiller/data_loggers/collector.py b/distiller/data_loggers/collector.py index 9b0350b..fc96aa0 100755 --- a/distiller/data_loggers/collector.py +++ b/distiller/data_loggers/collector.py @@ -14,7 +14,8 @@ # limitations under the License. # -from functools import partial +from functools import partial, reduce +import operator import xlsxwriter import yaml import os @@ -25,11 +26,14 @@ import torch from torchnet.meter import AverageValueMeter import logging from math import sqrt +import matplotlib.pyplot as plt import distiller msglogger = logging.getLogger() __all__ = ['SummaryActivationStatsCollector', 'RecordsActivationStatsCollector', - 'QuantCalibrationStatsCollector', 'collector_context', 'collectors_context'] + 'QuantCalibrationStatsCollector', 'ActivationHistogramsCollector', + 'collect_quant_stats', 'collect_histograms', + 'collector_context', 'collectors_context'] class ActivationStatsCollector(object): @@ -47,13 +51,7 @@ class ActivationStatsCollector(object): ActivationStatsCollector uses the forward hook of modules in order to access the feature-maps. This is both slow and limits us to seeing only the outputs of torch.Modules. We can remove some of the slowness, by choosing to log only specific layers or use it only - during validation or test. By default, we only log torch.nn.ReLU activations. - - The layer names are mangled, because torch.Modules don't have names and we need to invent - a unique name per layer. To assign human-readable names, it is advisable to invoke the following - before starting the statistics collection: - - distiller.utils.assign_layer_fq_names(model) + during validation or test. This can be achieved using the `classes` argument. """ def __init__(self, model, stat_name, classes): """ @@ -72,6 +70,10 @@ class ActivationStatsCollector(object): self.classes = classes self.fwd_hook_handles = [] + # The layer names are mangled, because torch.Modules don't have names and we need to invent + # a unique, human-readable name per layer. + distiller.utils.assign_layer_fq_names(model) + def value(self): """Return a dictionary containing {layer_name: statistic}""" activation_stats = OrderedDict() @@ -316,6 +318,13 @@ class _QuantStatsRecord(object): self.output = self.create_records_dict() +def _verify_no_dataparallel(model): + if torch.nn.DataParallel in [type(m) for m in model.modules()]: + raise ValueError('Model contains DataParallel modules, which can cause inaccurate stats collection. ' + 'Either create a model without DataParallel modules, or call ' + 'distiller.utils.make_non_parallel_copy on the model before invoking the collector') + + class QuantCalibrationStatsCollector(ActivationStatsCollector): """ This class tracks activations stats required for quantization, for each layer and for each input @@ -324,7 +333,25 @@ class QuantCalibrationStatsCollector(ActivationStatsCollector): * Average min / max (calculate min / max per sample and average those) * Overall mean * Overall standard-deviation - Calculated stats are saved to a YAML file. + + The generated stats dict has the following structure per-layer: + 'layer_name': + 'inputs': + 0: + 'min': value + 'max': value + ... + ... + n: + 'min': value + 'max': value + ... + 'output': + 'min': value + 'max': value + ... + Where n is the number of inputs the layer has. + The calculated stats can be saved to a YAML file. If a certain layer operates in-place, that layer's input stats will be overwritten by its output stats. The collector can, optionally, check for such cases at runtime. In addition, a simple mechanism to disable inplace @@ -350,6 +377,9 @@ class QuantCalibrationStatsCollector(ActivationStatsCollector): def __init__(self, model, classes=None, inplace_runtime_check=False, disable_inplace_attrs=False, inplace_attr_names=('inplace',)): super(QuantCalibrationStatsCollector, self).__init__(model, "quant_stats", classes) + + _verify_no_dataparallel(model) + self.batch_idx = 0 self.inplace_runtime_check = inplace_runtime_check @@ -456,6 +486,251 @@ class QuantCalibrationStatsCollector(ActivationStatsCollector): return fname +class ActivationHistogramsCollector(ActivationStatsCollector): + """ + This class collects activation histograms, for each layer and for each input and output tensor. + It requires pre-computed min/max stats per tensor. This is done in order to prevent the need to save + all of the activation tensors throughout the run. The histogram is created once according to these + min/max values, and updated after each iteration. Any value outside the pre-computed range is clamped. + + The generated stats dict has the following structure per-layer: + 'layer_name': + 'inputs': + 0: + 'hist': tensor # Tensor with bin counts + 'bin_centroids': tensor # Tensor with activation values corresponding to center of each bin + ... + n: + 'hist': tensor + 'bin_centroids': tensor + 'output': + 'hist': tensor + 'bin_centroids': tensor + Where n is the number of inputs the layer has. + The generated stats dictionary can be saved to a file. + Optionally, histogram images for all tensor can be saved as well + + Args: + model (torch.nn.Module): The model we are monitoring + activation_stats (str / dict): Either a path to activation stats YAML file, or a dictionary containing + the stats. The stats are expected to be in the same structure as generated by QuantCalibrationStatsCollector. + classes (list): List of class types for which we collect activation statistics. Passing an empty list or + None will collect statistics for all class types. + nbins (int): Number of histogram bins + save_hist_imgs (bool): If set, calling save() will dump images of the histogram plots in addition to saving the + stats dictionary + hist_imgs_ext (str): The file type to be used when saving histogram images + """ + def __init__(self, model, activation_stats, classes=None, nbins=2048, + save_hist_imgs=False, hist_imgs_ext='.svg'): + super(ActivationHistogramsCollector, self).__init__(model, 'histogram', classes) + + _verify_no_dataparallel(model) + + if isinstance(activation_stats, str): + if not os.path.isfile(activation_stats): + raise ValueError("Model activation stats file not found at: " + activation_stats) + msglogger.info('Loading activation stats from: ' + activation_stats) + with open(activation_stats, 'r') as stream: + activation_stats = distiller.utils.yaml_ordered_load(stream) + elif not isinstance(activation_stats, (dict, OrderedDict)): + raise TypeError('model_activation_stats must either be a string, a dict / OrderedDict or None') + + self.act_stats = activation_stats + self.nbins = nbins + self.save_imgs = save_hist_imgs + self.imgs_ext = hist_imgs_ext if hist_imgs_ext[0] == '.' else '.' + hist_imgs_ext + + def _get_min_max(self, *keys): + stats_entry = reduce(operator.getitem, keys, self.act_stats) + return stats_entry['min'], stats_entry['max'] + + def _activation_stats_cb(self, module, inputs, output): + def get_hist(t, stat_min, stat_max): + t_clamped = t.clamp(stat_min, stat_max) + hist = torch.histc(t_clamped.cpu(), bins=self.nbins, min=stat_min, max=stat_max) + return hist + + for idx, input in enumerate(inputs): + stat_min, stat_max = self._get_min_max(module.distiller_name, 'inputs', idx) + curr_hist = get_hist(input, stat_min, stat_max) + module.input_hists[idx] += curr_hist + + stat_min, stat_max = self._get_min_max(module.distiller_name, 'output') + curr_hist = get_hist(output, stat_min, stat_max) + module.output_hist += curr_hist + + def _reset(self, module): + num_inputs = len(self.act_stats[module.distiller_name]['inputs']) + module.input_hists = module.input_hists = [torch.zeros(self.nbins) for _ in range(num_inputs)] + module.output_hist = torch.zeros(self.nbins) + + def _start_counter(self, module): + self._reset(module) + + def _reset_counter(self, module): + if hasattr(module, 'output_hist'): + self._reset(module) + + def _collect_activations_stats(self, module, activation_stats, name=''): + if distiller.utils.has_children(module): + return + if not hasattr(module, 'output_hist'): + return + + def get_hist_entry(min_val, max_val, hist): + od = OrderedDict() + od['hist'] = hist + bin_width = (max_val - min_val) / self.nbins + od['bin_centroids'] = torch.linspace(min_val + bin_width / 2, max_val - bin_width / 2, self.nbins) + return od + + stats_od = OrderedDict() + inputs_od = OrderedDict() + for idx, hist in enumerate(module.input_hists): + inputs_od[idx] = get_hist_entry(*self._get_min_max(module.distiller_name, 'inputs', idx), + module.input_hists[idx]) + + output_od = get_hist_entry(*self._get_min_max(module.distiller_name, 'output'), module.output_hist) + + stats_od['inputs'] = inputs_od + stats_od['output'] = output_od + activation_stats[module.distiller_name] = stats_od + + def save(self, fname): + hist_dict = self.value() + + if not fname.endswith('.pt'): + fname = ".".join([fname, 'pt']) + try: + os.remove(fname) + except OSError: + pass + + torch.save(hist_dict, fname) + + if self.save_imgs: + msglogger.info('Saving histogram images...') + save_dir = os.path.join(os.path.split(fname)[0], 'histogram_imgs') + if not os.path.isdir(save_dir): + os.mkdir(save_dir) + + def save_hist(layer_name, tensor_name, idx, bin_counts, bin_centroids, normed=True): + if normed: + bin_counts = bin_counts / bin_counts.sum() + plt.figure(figsize=(12, 12)) + plt.suptitle('\n'.join((layer_name, tensor_name)), fontsize=18, fontweight='bold') + for subplt_idx, yscale in enumerate(['linear', 'log']): + plt.subplot(2, 1, subplt_idx + 1) + plt.fill_between(bin_centroids, bin_counts, step='mid', antialiased=False) + if yscale == 'linear': + plt.ylim(bottom=0) + plt.title(yscale + ' scale') + plt.yscale(yscale) + plt.xlabel('Activation Value') + plt.ylabel('Normalized Count') + plt.tight_layout(rect=[0, 0, 1, 0.93]) + idx_str = '{:03d}'.format(idx) + plt.savefig(os.path.join(save_dir, '-'.join((idx_str, layer_name, tensor_name)) + self.imgs_ext)) + plt.close() + + cnt = 0 + for layer_name, data in hist_dict.items(): + for idx, od in data['inputs'].items(): + cnt += 1 + save_hist(layer_name, 'input_{}'.format(idx), cnt, od['hist'], od['bin_centroids'], normed=True) + od = data['output'] + cnt += 1 + save_hist(layer_name, 'output', cnt, od['hist'], od['bin_centroids'], normed=True) + msglogger.info('Done') + return fname + + +def collect_quant_stats(model, test_fn, save_dir=None, classes=None, inplace_runtime_check=False, + disable_inplace_attrs=False, inplace_attr_names=('inplace',)): + """ + Helper function for collecting quantization calibration statistics for a model using QuantCalibrationStatsCollector + + Args: + model (nn.Module): The model for which to collect stats + test_fn (function): Test/Evaluation function for the model. It must have an argument named 'model' that + accepts the model. All other arguments should be set in advance (can be done using functools.partial), or + they will be left with their default values. + save_dir (str): Path to directory where stats YAML file will be saved. If None then YAML will not be saved + to disk. + classes (iterable): See QuantCalibrationStatsCollector + inplace_runtime_check (bool): See QuantCalibrationStatsCollector + disable_inplace_attrs (bool): See QuantCalibrationStatsCollector + inplace_attr_names (iterable): See QuantCalibrationStatsCollector + + Returns: + Dictionary with quantization stats (see QuantCalibrationStatsCollector for a description of the dictionary + contents) + """ + msglogger.info('Collecting quantization calibration stats for model') + quant_stats_collector = QuantCalibrationStatsCollector(model, classes=classes, + inplace_runtime_check=inplace_runtime_check, + disable_inplace_attrs=disable_inplace_attrs, + inplace_attr_names=inplace_attr_names) + with collector_context(quant_stats_collector): + test_fn(model=model) + msglogger.info('Stats collection complete') + if save_dir is not None: + save_path = os.path.join(save_dir, 'acts_quantization_stats.yaml') + quant_stats_collector.save(save_path) + msglogger.info('Stats saved to ' + save_path) + + return quant_stats_collector.value() + + +def collect_histograms(model, test_fn, save_dir=None, activation_stats=None, + classes=None, nbins=2048, save_hist_imgs=False, hist_imgs_ext='.svg'): + """ + Helper function for collecting activation histograms for a model using ActivationsHistogramCollector. + Will perform 2 passes - one to collect the required stats and another to collect the histograms. The first + pass can be skipped by passing pre-calculated stats. + + Args: + model (nn.Module): The model for which to collect histograms + test_fn (function): Test/Evaluation function for the model. It must have an argument named 'model' that + accepts the model. All other arguments should be set in advance (can be done using functools.partial), or + they will be left with their default values. + save_dir (str): Path to directory where histograms will be saved. If None then data will not be saved to disk. + activation_stats (str / dict / None): Either a path to activation stats YAML file, or a dictionary containing + the stats. The stats are expected to be in the same structure as generated by QuantCalibrationStatsCollector. + If None, then a stats collection pass will be performed. + classes: See ActivationsHistogramCollector + nbins: See ActivationsHistogramCollector + save_hist_imgs: See ActivationsHistogramCollector + hist_imgs_ext: See ActivationsHistogramCollector + + Returns: + Dictionary with histograms data (See ActivationsHistogramCollector for a description of the dictionary + contents) + """ + msglogger.info('Pass 1: Stats collection') + if activation_stats is not None: + msglogger.info('Pre-computed activation stats passed, skipping stats collection') + else: + activation_stats = collect_quant_stats(model, test_fn, save_dir=save_dir, classes=classes, + inplace_runtime_check=True, disable_inplace_attrs=True) + + msglogger.info('Pass 2: Histograms generation') + histogram_collector = ActivationHistogramsCollector(model, activation_stats, classes=classes, nbins=nbins, + save_hist_imgs=save_hist_imgs, hist_imgs_ext=hist_imgs_ext) + with collector_context(histogram_collector): + test_fn(model=model) + msglogger.info('Histograms generation complete') + if save_dir is not None: + save_path = os.path.join(save_dir, 'acts_histograms.pt') + histogram_collector.save(save_path) + msglogger.info("Histogram data saved to " + save_path) + if save_hist_imgs: + msglogger.info('Histogram images saved in ' + os.path.join(save_dir, 'histogram_imgs')) + + return histogram_collector.value() + + @contextmanager def collector_context(collector): """A context manager for an activation collector""" diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py index 1582f87..9429381 100755 --- a/examples/classifier_compression/compress_classifier.py +++ b/examples/classifier_compression/compress_classifier.py @@ -198,25 +198,18 @@ def main(): if args.summary: return summarize_model(model, args.dataset, which_summary=args.summary) - activations_collectors = create_activation_stats_collectors(model, *args.activation_stats) - if args.qe_calibration: - msglogger.info('Quantization calibration stats collection enabled:') - msglogger.info('\tStats will be collected for {:.1%} of test dataset'.format(args.qe_calibration)) - msglogger.info('\tSetting constant seeds and converting model to serialized execution') - distiller.set_deterministic() - model = distiller.make_non_parallel_copy(model) - activations_collectors.update(create_quantization_stats_collector(model)) - args.evaluate = True - args.effective_test_size = args.qe_calibration + return acts_quant_stats_collection(model, criterion, pylogger, args) + + if args.activation_histograms: + return acts_histogram_collection(model, criterion, pylogger, args) + + activations_collectors = create_activation_stats_collectors(model, *args.activation_stats) # Load the datasets: the dataset to load is inferred from the model name passed # in args.arch. The default dataset is ImageNet, but if args.arch contains the # substring "_cifar", then cifar10 is used. - train_loader, val_loader, test_loader, _ = apputils.load_data( - args.dataset, os.path.expanduser(args.data), args.batch_size, - args.workers, args.validation_split, args.deterministic, - args.effective_train_size, args.effective_valid_size, args.effective_test_size) + train_loader, val_loader, test_loader, _ = load_data(args) msglogger.info('Dataset sizes:\n\ttraining=%d\n\tvalidation=%d\n\ttest=%d', len(train_loader.sampler), len(val_loader.sampler), len(test_loader.sampler)) @@ -678,10 +671,7 @@ def sensitivity_analysis(model, criterion, data_loader, loggers, args, sparsitie def automated_deep_compression(model, criterion, optimizer, loggers, args): - train_loader, val_loader, test_loader, _ = apputils.load_data( - args.dataset, os.path.expanduser(args.data), args.batch_size, - args.workers, args.validation_split, args.deterministic, - args.effective_train_size, args.effective_valid_size, args.effective_test_size) + train_loader, val_loader, test_loader, _ = load_data(args) args.display_confusion = True validate_fn = partial(test, test_loader=test_loader, criterion=criterion, @@ -695,10 +685,7 @@ def automated_deep_compression(model, criterion, optimizer, loggers, args): def greedy(model, criterion, optimizer, loggers, args): - train_loader, val_loader, test_loader, _ = apputils.load_data( - args.dataset, os.path.expanduser(args.data), args.batch_size, - args.workers, args.validation_split, args.deterministic, - args.effective_train_size, args.effective_valid_size, args.effective_test_size) + train_loader, val_loader, test_loader, _ = load_data(args) test_fn = partial(test, test_loader=test_loader, criterion=criterion, loggers=loggers, args=args, activations_collectors=None) @@ -710,6 +697,37 @@ def greedy(model, criterion, optimizer, loggers, args): test_fn, train_fn) +def acts_quant_stats_collection(model, criterion, loggers, args): + msglogger.info('Collecting quantization calibration stats based on {:.1%} of test dataset' + .format(args.qe_calibration)) + model = distiller.utils.make_non_parallel_copy(model) + args.effective_test_size = args.qe_calibration + train_loader, val_loader, test_loader, _ = load_data(args) + test_fn = partial(test, test_loader=test_loader, criterion=criterion, + loggers=loggers, args=args, activations_collectors=None) + collect_quant_stats(model, test_fn, save_dir=msglogger.logdir, + classes=None, inplace_runtime_check=True, disable_inplace_attrs=True) + + +def acts_histogram_collection(model, criterion, loggers, args): + msglogger.info('Collecting activation histograms based on {:.1%} of test dataset' + .format(args.activation_histograms)) + model = distiller.utils.make_non_parallel_copy(model) + args.effective_test_size = args.activation_histograms + train_loader, val_loader, test_loader, _ = load_data(args, fixed_subset=True) + test_fn = partial(test, test_loader=test_loader, criterion=criterion, + loggers=loggers, args=args, activations_collectors=None) + collect_histograms(model, test_fn, save_dir=msglogger.logdir, + classes=None, nbins=2048, save_hist_imgs=True) + + +def load_data(args, fixed_subset=False): + return apputils.load_data(args.dataset, os.path.expanduser(args.data), args.batch_size, + args.workers, args.validation_split, args.deterministic, + args.effective_train_size, args.effective_valid_size, args.effective_test_size, + fixed_subset) + + class missingdict(dict): """This is a little trick to prevent KeyError""" def __missing__(self, key): @@ -729,8 +747,6 @@ def create_activation_stats_collectors(model, *phases): WARNING! Enabling activation statsitics collection will significantly slow down training! """ - distiller.utils.assign_layer_fq_names(model) - genCollectors = lambda: missingdict({ "sparsity": SummaryActivationStatsCollector(model, "sparsity", lambda t: 100 * distiller.utils.sparsity(t)), @@ -747,13 +763,6 @@ def create_activation_stats_collectors(model, *phases): for k in ('train', 'valid', 'test')} -def create_quantization_stats_collector(model): - distiller.utils.assign_layer_fq_names(model) - return {'test': missingdict({'quantization_stats': QuantCalibrationStatsCollector(model, classes=None, - inplace_runtime_check=True, - disable_inplace_attrs=True)})} - - def save_collectors_data(collectors, directory): """Utility function that saves all activation statistics to disk. diff --git a/examples/classifier_compression/parser.py b/examples/classifier_compression/parser.py index 2deddcf..803facf 100755 --- a/examples/classifier_compression/parser.py +++ b/examples/classifier_compression/parser.py @@ -73,6 +73,11 @@ def get_parser(): parser.add_argument('--activation-stats', '--act-stats', nargs='+', metavar='PHASE', default=list(), help='collect activation statistics on phases: train, valid, and/or test' ' (WARNING: this slows down training)') + parser.add_argument('--activation-histograms', '--act-hist', + type=distiller.utils.float_range_argparse_checker(exc_min=True), + metavar='PORTION_OF_TEST_SET', + help='Run the model in evaluation mode on the specified portion of the test dataset and ' + 'generate activation histograms. NOTE: This slows down evaluation significantly') parser.add_argument('--masks-sparsity', dest='masks_sparsity', action='store_true', default=False, help='print masks sparsity table at end of each epoch') parser.add_argument('--param-hist', dest='log_params_histograms', action='store_true', default=False, diff --git a/jupyter/activation_histograms.ipynb b/jupyter/activation_histograms.ipynb new file mode 100644 index 0000000..d59ef93 --- /dev/null +++ b/jupyter/activation_histograms.ipynb @@ -0,0 +1,367 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Activation Histograms\n", + "\n", + "This notebook shows an example of how to generate activation histograms for a specific model and dataset.\n", + "\n", + "## But I Already Know How To Generate Histograms...\n", + "\n", + "If you already generated histograms using Distiller outside this notebook, you can still use it to visualize the data:\n", + "* To load the raw data saved by Distiller and visualize it, go to [this section](#Plot-Histograms)\n", + "* If enabled saving histogram images and want to view them, go to [this section](#Load-Histogram-Images-from-Disk)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "import torch\n", + "import matplotlib.pyplot as plt\n", + "import os\n", + "import math\n", + "import torchnet as tnt\n", + "from ipywidgets import widgets, interact\n", + "\n", + "import distiller\n", + "from distiller.models import create_model\n", + "\n", + "device = torch.device('cuda')\n", + "# device = torch.device('cpu')\n", + "\n", + "# Load some common code and configure logging\n", + "# We do this so we can see the logging output coming from\n", + "# Distiller function calls\n", + "%run './distiller_jupyter_helpers.ipynb'\n", + "msglogger = config_notebooks_logger()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load Your Model\n", + "\n", + "For this example we'll use a pre-trained image classification model.\n", + "\n", + "### Note on Parallelism\n", + "\n", + "Currently, Distiller's implementation of activations histograms collection does not accept models which contain [`DataParallel`](https://pytorch.org/docs/stable/nn.html?highlight=dataparallel#torch.nn.DataParallel) modules. So here we create the model without parallelism to begin with. If you have a model which includes `DataParallel` modules (for example, if loaded from a checkpoint), use the following utlity function to convert the model to serialized execution:\n", + "```python\n", + "model = distiller.utils.make_non_parallel_copy(model)\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "model = create_model(pretrained=True, dataset='imagenet', arch='resnet18', parallel=False)\n", + "model = model.to(device) # Comment out if not applicable" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prepare Data\n", + "\n", + "Usually it is not required to collect histograms based on the entire dataset, and only a representative subset is used (that also helps reduce the runtime).\n", + "* **Subset size:** There is no golden rule for selecting the size of the subset. Anywhere between 1-10% of the validation/test set should work.\n", + "* **Representative data:** Whatever size is chosen, it is important to make sure that the subset is selected in a way that covers as much of the distribution of the data as possible. So, for example, if the dataset is organized by classes by default, we should make sure to select items randomly and not in order.\n", + "\n", + "**Note:** Working on only a subset of the data can be taken care of at data preparation time, or it can be delayed to the actual model evaluation function (for example, executing only a specific number of mini-batches). In this example we take care of it during data preparation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "# We use Distiller's built-in data loading functionality for ImageNet,\n", + "# which takes care of randomizing the data before selecting the subset.\n", + "# While it creates train, validation and test data loaders, we're only\n", + "# interested in the test dataset in this example.\n", + "#\n", + "# Subset size: Here we'll go with 1% of the test set, mostly for the\n", + "# sake of speed. We control this with the 'effective_test_size' argument.\n", + "#\n", + "# We set the 'fixed_subset' argument to make sure we're using the\n", + "# same subset for both phases of histogram collection - more on that below\n", + "\n", + "dataset = 'imagenet'\n", + "dataset_path = '~/datasets/imagenet'\n", + "batch_size = 256\n", + "num_workers = 10\n", + "subset_size = 0.01\n", + "\n", + "_, _, test_loader, _ = distiller.apputils.load_data(\n", + " dataset, os.path.expanduser(dataset_path), batch_size, num_workers,\n", + " effective_test_size=subset_size, fixed_subset=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define the Model Evaluation Function\n", + "\n", + "We define a fairly bare-bones evaluation function. Recording the loss and accuracy isn't strictly necessary for histogram collection. We record them nonetheless, so we can verify the data subset being used achieves results that are on par from what we'd expect from a representative subset." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "def eval_model(data_loader, model):\n", + " print('Evaluating model')\n", + " criterion = torch.nn.CrossEntropyLoss().to(device)\n", + " \n", + " loss = tnt.meter.AverageValueMeter()\n", + " classerr = tnt.meter.ClassErrorMeter(accuracy=True, topk=(1, 5))\n", + "\n", + " total_samples = len(data_loader.sampler)\n", + " batch_size = data_loader.batch_size\n", + " total_steps = math.ceil(total_samples / batch_size)\n", + " print('{0} samples ({1} per mini-batch)'.format(total_samples, batch_size))\n", + "\n", + " # Switch to evaluation mode\n", + " model.eval()\n", + "\n", + " for step, (inputs, target) in enumerate(data_loader):\n", + " print('[{:3d}/{:3d}] ... '.format(step + 1, total_steps), end='', flush=True)\n", + " with torch.no_grad():\n", + " inputs, target = inputs.to(device), target.to(device)\n", + " # compute output from model\n", + " output = model(inputs)\n", + "\n", + " # compute loss and measure accuracy\n", + " loss.add(criterion(output, target).item())\n", + " classerr.add(output.data, target)\n", + " \n", + " print('Top1: {:.3f} Top5: {:.3f} Loss: {:.3f}'.format(\n", + " classerr.value(1), classerr.value(5), loss.mean), flush=True)\n", + " print('----------')\n", + " print('Overall ==> Top1: {:.3f} Top5: {:.3f} Loss: {:.3f}'.format(\n", + " classerr.value(1), classerr.value(5), loss.mean), flush=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Collect Histograms\n", + "\n", + "Histogram collection is implemented using Distiller's \"Collector\" mechanism, specifically in the `ActivationHistogramsCollector` class. It is stats-based, meaning it requires pre-computed min/max values per-tensor to be provided.\n", + "\n", + "The min/max stats are expected as a dictionary with the following structure:\n", + "```YAML\n", + "'layer_name':\n", + " 'inputs':\n", + " 0:\n", + " 'min': value\n", + " 'max': value\n", + " ...\n", + " n:\n", + " 'min': value\n", + " 'max': value\n", + " 'output':\n", + " 'min': value\n", + " 'max': value\n", + "```\n", + "Where n is the number of inputs the layer has. The `QuantCalibrationStatsCollector` collector class generates stats in the required format.\n", + "\n", + "To streamline this process, a utility function is provided: `distiller.data_loggers.collect_histograms`. Given a model and a test function, it will perform the required stats collection followed by histograms collection. If the user has already computed min/max stats beforehand, those can provided as a dict or as a path to a YAML file (as saved by `QuantCalibrationStatsCollector`). In that case, the stats collection pass will be skipped.\n", + "\n", + "### Dataset Perparation in Context of Stats-Based Histograms\n", + "\n", + "If the data used for min/max stats collection is not the same as the data used for histogram collection, it is highly likely that when collecting histograms some values will fall outside the pre-calculated min/max range. When that happens, the value is **clamped**. Assuming the subsets of data used in both cases are representative enough, this shouldn't have a major effect on the results.\n", + "\n", + "One can choose to avoid this issue by making sure we use the same subset of data in both passes. How to make sure of that will, of course, differ from one use case to another. In this example we do this by using the enabling `fixed_subset` flag when calling `load_data` above." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "# The test function passed to 'collect_histograms' must have an \n", + "# argument named 'model' which accepts the model for which histograms\n", + "# are to be collected. 'collect_histograms' will not set any other\n", + "# arguments.\n", + "# We'll use Python's 'partial' to handle the set the rest of the\n", + "# arguments for the test function before calling 'collect_histograms'\n", + "from functools import partial\n", + "test_fn = partial(eval_model, data_loader=test_loader)\n", + "\n", + "# Histogram collection parameters\n", + "\n", + "# 'save_dir': Pass a valid directory path to have the histogram\n", + "# data saved to disk. Pass None to disable saving.\n", + "# 'save_hist_imgs': If save_dir is not None, toggles whether to save\n", + "# histogram images in addition to the raw data\n", + "# 'hist_imgs_ext': Controls the filetype for histogram images\n", + "save_dir = '.'\n", + "save_hist_imgs = True\n", + "hist_imgs_ext = '.png'\n", + "\n", + "# 'activation_stats': Here we pass None so a stats collection pass\n", + "# is performed.\n", + "activation_stats = None\n", + "\n", + "# 'classes': To speed-up the calculation here we use the 'classes'\n", + "# argument so that stats and histograms are collected only for \n", + "# ReLU layers in the model. Pass None to collect for all layers.\n", + "classes = [torch.nn.ReLU]\n", + "\n", + "# 'nbins': Number of histogram bins to use.\n", + "nbins = 2048\n", + "\n", + "hist_dict = distiller.data_loggers.collect_histograms(\n", + " model, test_fn, save_dir=save_dir, activation_stats=activation_stats,\n", + " classes=classes, nbins=nbins, save_hist_imgs=save_hist_imgs, hist_imgs_ext=hist_imgs_ext)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Plot Histograms\n", + "\n", + "The generated dictionary has the following structure (very similar to the structure of the min/max stats dictionary described above):\n", + "```yaml\n", + "'layer_name':\n", + " 'inputs':\n", + " 0:\n", + " 'hist': tensor # Tensor with bin counts\n", + " 'bin_centroids': tensor # Tensor with activation values corresponding to center of each bin\n", + " ...\n", + " n:\n", + " 'hist': tensor\n", + " 'bin_centroids': tensor\n", + " 'output':\n", + " 'hist': tensor\n", + " 'bin_centroids': tensor\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "# Uncomment this line to load saved output from a previous histogram collection run\n", + "# hist_dict = torch.load('acts_histograms.pt')\n", + "\n", + "plt.style.use('seaborn') # pretty matplotlib plots\n", + "\n", + "def draw_hist(layer_name, tensor_name, bin_counts, bin_centroids, normed=True, yscale='linear'):\n", + " if normed:\n", + " bin_counts = bin_counts / bin_counts.sum()\n", + " plt.figure(figsize=(12, 6))\n", + " plt.title('\\n'.join((layer_name, tensor_name)), fontsize=16)\n", + " plt.fill_between(bin_centroids, bin_counts, step='mid', antialiased=False)\n", + " if yscale == 'linear':\n", + " plt.ylim(bottom=0)\n", + " plt.yscale(yscale)\n", + " plt.xlabel('Activation Value')\n", + " plt.ylabel('Normalized Count')\n", + " plt.show()\n", + "\n", + "@interact(layer_name=hist_dict.keys(),\n", + " normalize_bin_counts=True,\n", + " y_axis_scale=['linear', 'log'])\n", + "def draw_layer(layer_name, normalize_bin_counts, y_axis_scale):\n", + " print('\\nSelected layer: ' + layer_name)\n", + " data = hist_dict[layer_name]\n", + " for idx, od in data['inputs'].items():\n", + " draw_hist(layer_name, 'input_{}'.format(idx), od['hist'], od['bin_centroids'],\n", + " normed=normalize_bin_counts, yscale=y_axis_scale)\n", + " od = data['output']\n", + " draw_hist(layer_name, 'output', od['hist'], od['bin_centroids'],\n", + " normed=normalize_bin_counts, yscale=y_axis_scale)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load Histogram Images from Disk\n", + "\n", + "If you enabled saving of histogram images above, or have images from a collection executed externally, you can use the code below to display the images." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "from IPython.display import Image, SVG, display\n", + "import glob\n", + "from collections import OrderedDict\n", + "\n", + "# Set the path to the images directory\n", + "imgs_dir = 'histogram_imgs'\n", + "\n", + "files = sorted(glob.glob(os.path.join(imgs_dir, '*.*')))\n", + "files = [f for f in files if os.path.isfile(f)]\n", + "fnames_map = OrderedDict([(os.path.split(f)[1], f) for f in files])\n", + "\n", + "@interact(file_name=fnames_map)\n", + "def load_image(file_name):\n", + " if file_name.endswith('.svg'):\n", + " display(SVG(filename=file_name))\n", + " else:\n", + " display(Image(filename=file_name))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.5.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/jupyter/distiller_jupyter_helpers.ipynb b/jupyter/distiller_jupyter_helpers.ipynb index cb0aa77..5f1d703 100755 --- a/jupyter/distiller_jupyter_helpers.ipynb +++ b/jupyter/distiller_jupyter_helpers.ipynb @@ -332,7 +332,13 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "import logging\n", + "def config_notebooks_logger():\n", + " logging.config.fileConfig('logging.conf')\n", + " msglogger = logging.getLogger()\n", + " msglogger.info('Logging configured successfully')" + ] } ], "metadata": { diff --git a/jupyter/logging.conf b/jupyter/logging.conf new file mode 100755 index 0000000..41e46d6 --- /dev/null +++ b/jupyter/logging.conf @@ -0,0 +1,32 @@ +[formatters] +keys: simple, time_simple + +[handlers] +keys: console + +[loggers] +keys: root, app_cfg + +[formatter_simple] +format: %(message)s + +[formatter_time_simple] +format: %(asctime)s - %(message)s + +[handler_console] +class: StreamHandler +propagate: 0 +args: [] +formatter: simple + +[logger_root] +level: INFO +propagate: 1 +handlers: console + +[logger_app_cfg] +# Use this logger to log the application configuration and execution environment +level: DEBUG +qualname: app_cfg +propagate: 0 +handlers: console -- GitLab