diff --git a/README.md b/README.md index e01a111cc436e42b5e317f334c81b9cd59ede025..a767009b81bc21ffadb9e49e9a193ea11a28e739 100755 --- a/README.md +++ b/README.md @@ -37,12 +37,12 @@ Network compression can reduce the memory footprint of a neural network, increas <details><summary><b>What's New in November?</b></summary> <p> - - Quantization: - - To improve quantization results: Added averaging-based activations clipping in SymmetricLinearQuantizer. - - For better control over quantization configuration: Added command line arguments for post-training quantization settings in image classification sample. + - Quantization: Several new features in [range-based linear quantization](https://nervanasystems.github.io/distiller/algo_quantization/index.html#range-based-linear-quantization): - Asymmetric post-training quantization (only symmetric supported so until now) - Quantization aware training for range-based (min-max) symmetric and asymmetric quantization - - Per-channel quantization support in all of the above scenarios + - Per-channel weights quantization support (per output channel) in both training and post-training + - To improve quantization results: Averaging-based activations clipping in post-training quantization. + - More control over post-training quantization configuration: Additional [command line arguments](https://nervanasystems.github.io/distiller/usage/index.html##post-training-quantization) in image classification sample. - Added an implementation of [Dynamic Network Surgery for Efficient DNNs](https://arxiv.org/abs/1608.04493) with: - A sample implementation on ResNet50 which achieves 82.5% compression 75.5% Top1 (-0.6% from TorchVision baseline). - A new SplicingPruner pruning algorithm. @@ -101,22 +101,30 @@ Beware. ## Table of Contents -* [Feature set](#feature-set) -* [Installation](#installation) - + [Clone Distiller](#clone-distiller) - + [Create a Python virtual environment](#create-a-python-virtual-environment) - + [Install dependencies](#install-dependencies) -* [Getting Started](#getting-started) - + [Example invocations of the sample application](#example-invocations-of-the-sample-application) - + [Explore the sample Jupyter notebooks](#explore-the-sample-jupyter-notebooks) -* [Set up the classification datasets](#set-up-the-classification-datasets) -* [Running the tests](#running-the-tests) -* [Generating the HTML documentation site](#generating-the-html-documentation-site) -* [Versioning](#versioning) -* [License](#license) -* [Citation](#citation) -* [Acknowledgments](#acknowledgments) -* [Disclaimer](#disclaimer) +- [Table of Contents](#table-of-contents) +- [Highlighted features](#highlighted-features) +- [Installation](#installation) + - [Clone Distiller](#clone-distiller) + - [Create a Python virtual environment](#create-a-python-virtual-environment) + - [Using virtualenv](#using-virtualenv) + - [Using venv](#using-venv) + - [Activate the environment](#activate-the-environment) + - [Install dependencies](#install-dependencies) +- [Getting Started](#getting-started) + - [Example invocations of the sample application](#example-invocations-of-the-sample-application) + - [Training-only](#training-only) + - [Getting parameter statistics of a sparsified model](#getting-parameter-statistics-of-a-sparsified-model) + - [Post-training quantization](#post-training-quantization) + - [Explore the sample Jupyter notebooks](#explore-the-sample-jupyter-notebooks) +- [Set up the classification datasets](#set-up-the-classification-datasets) +- [Running the 
tests](#running-the-tests) +- [Generating the HTML documentation site](#generating-the-html-documentation-site) +- [Built With](#built-with) +- [Versioning](#versioning) +- [License](#license) +- [Citation](#citation) +- [Acknowledgments](#acknowledgments) +- [Disclaimer](#disclaimer) ## Highlighted features @@ -144,8 +152,8 @@ Beware. - Group Lasso an group variance regularization * **Quantization** - Automatic mechanism to transform existing models to quantized versions, with customizable bit-width configuration for different layers. No need to re-write the model for different quantization methods. - - Support for [training with quantization](https://nervanasystems.github.io/distiller/quantization/index.html#training-with-quantization) in the loop - - One-shot 8-bit quantization of trained full-precision models + - Post-training quantization of trained full-precision models + - Support for [quantization-aware training](https://nervanasystems.github.io/distiller/quantization/index.html#quantization-aware-training) in the loop * **Knowledge distillation** - Training with [knowledge distillation](https://nervanasystems.github.io/distiller/knowledge_distillation/index.html), in conjunction with the other available pruning / regularization / quantization methods. * **Conditional computation** @@ -232,7 +240,7 @@ For more details, there are some other resources you can refer to: ### Example invocations of the sample application + [Training-only](#training-only) + [Getting parameter statistics of a sparsified model](#getting-parameter-statistics-of-a-sparsified-model) -+ [8-bit quantization](#8-bit-quantization) ++ [Post-training quantization](#post-training-quantization) #### Training-only The following will invoke training-only (no compression) of a network named 'simplenet' on the CIFAR10 dataset. This is roughly based on TorchVision's sample Imagenet training application, so it should look familiar if you've used that application. In this example we don't invoke any compression mechanisms: we just train because for fine-tuning after pruning, training is an essential part.<br> @@ -273,14 +281,14 @@ $ python3 compress_classifier.py --resume=../ssl/checkpoints/checkpoint_trained_ ``` <center> <img src="imgs/ch_compute_stats.png"></center> -#### 8-bit quantization +#### Post-training quantization This example performs 8-bit quantization of ResNet20 for CIFAR10. We've included in the git repository the checkpoint of a ResNet20 model that we've trained with 32-bit floats, so we'll take this model and quantize it: ``` -$ python3 compress_classifier.py -a resnet20_cifar ../../../data.cifar10 --resume ../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar --quantize --evaluate +$ python3 compress_classifier.py -a resnet20_cifar ../../../data.cifar10 --resume ../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar --quantize-eval --evaluate ``` -The command-line above will save a checkpoint named `quantized_checkpoint.pth.tar` containing the quantized model parameters. +The command-line above will save a checkpoint named `quantized_checkpoint.pth.tar` containing the quantized model parameters. See more examples [here](https://github.com/NervanaSystems/distiller/blob/master/examples/quantization/post_training_quant.md). 
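If you'd rather invoke post-training quantization programmatically instead of through the sample script, a minimal sketch using the classes exported from `distiller.quantization` in this change could look like the following (the model here is a placeholder for any trained FP32 model; run your usual evaluation loop on the result):

```python
from distiller.quantization import PostTrainLinearQuantizer, LinearQuantMode

model = load_my_trained_model()   # placeholder: obtain a trained FP32 model from a checkpoint
model.eval()                      # the range-based wrappers are evaluation-only

# Replace Conv2d / Linear modules with range-based linear quantization wrappers.
# Asymmetric-unsigned mode with per-channel weights and averaging-based clipping is
# shown as one possible configuration.
quantizer = PostTrainLinearQuantizer(model, bits_activations=8, bits_parameters=8,
                                     mode=LinearQuantMode.ASYMMETRIC_UNSIGNED,
                                     per_channel_wts=True, clip_acts=True)
quantizer.prepare_model()

# 'model' is now quantized in place -- evaluate it as usual.
```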
### Explore the sample Jupyter notebooks The set of notebooks that come with Distiller is described [here](https://nervanasystems.github.io/distiller/jupyter/index.html#using-the-distiller-notebooks), which also explains the steps to install the Jupyter notebook server.<br> diff --git a/distiller/quantization/__init__.py b/distiller/quantization/__init__.py index e174005241b14f5f6f1f0fe97dcabb9c850489a6..171655be429a17d55336070d10b8cb8f55096c03 100644 --- a/distiller/quantization/__init__.py +++ b/distiller/quantization/__init__.py @@ -15,7 +15,8 @@ # from .quantizer import Quantizer -from .range_linear import RangeLinearQuantWrapper, RangeLinearQuantParamLayerWrapper, SymmetricLinearQuantizer +from .range_linear import RangeLinearQuantWrapper, RangeLinearQuantParamLayerWrapper, PostTrainLinearQuantizer, \ + LinearQuantMode, QuantAwareTrainRangeLinearQuantizer from .clipped_linear import LinearQuantizeSTE, ClippedLinearQuantization, WRPNQuantizer, DorefaQuantizer, PACTQuantizer del quantizer diff --git a/distiller/quantization/clipped_linear.py b/distiller/quantization/clipped_linear.py index 2312dec5414bec8556101d13901a8ed07226aea5..d7fedf94c322d034d2c48fed3e9c86b644627c8f 100644 --- a/distiller/quantization/clipped_linear.py +++ b/distiller/quantization/clipped_linear.py @@ -27,33 +27,17 @@ msglogger = logging.getLogger() ### -class LinearQuantizeSTE(torch.autograd.Function): - @staticmethod - def forward(ctx, input, scale_factor, dequantize, inplace): - if inplace: - ctx.mark_dirty(input) - output = linear_quantize(input, scale_factor, inplace) - if dequantize: - output = linear_dequantize(output, scale_factor, inplace) - return output - - @staticmethod - def backward(ctx, grad_output): - # Straight-through estimator - return grad_output, None, None, None - - class LearnedClippedLinearQuantizeSTE(torch.autograd.Function): @staticmethod def forward(ctx, input, clip_val, num_bits, dequantize, inplace): ctx.save_for_backward(input, clip_val) if inplace: ctx.mark_dirty(input) - scale_factor = asymmetric_linear_quantization_scale_factor(num_bits, 0, clip_val.data[0]) + scale, zero_point = asymmetric_linear_quantization_params(num_bits, 0, clip_val.data[0], signed=False) output = clamp(input, 0, clip_val.data[0], inplace) - output = linear_quantize(output, scale_factor, inplace) + output = linear_quantize(output, scale, zero_point, inplace) if dequantize: - output = linear_dequantize(output, scale_factor, inplace) + output = linear_dequantize(output, scale, zero_point, inplace) return output @staticmethod @@ -76,13 +60,13 @@ class ClippedLinearQuantization(nn.Module): super(ClippedLinearQuantization, self).__init__() self.num_bits = num_bits self.clip_val = clip_val - self.scale_factor = asymmetric_linear_quantization_scale_factor(num_bits, 0, clip_val) + self.scale, self.zero_point = asymmetric_linear_quantization_params(num_bits, 0, clip_val, signed=False) self.dequantize = dequantize self.inplace = inplace def forward(self, input): input = clamp(input, 0, self.clip_val, self.inplace) - input = LinearQuantizeSTE.apply(input, self.scale_factor, self.dequantize, self.inplace) + input = LinearQuantizeSTE.apply(input, self.scale, self.zero_point, self.dequantize, self.inplace) return input def __repr__(self): @@ -100,7 +84,8 @@ class LearnedClippedLinearQuantization(nn.Module): self.inplace = inplace def forward(self, input): - input = LearnedClippedLinearQuantizeSTE.apply(input, self.clip_val, self.num_bits, self.dequantize, self.inplace) + input = LearnedClippedLinearQuantizeSTE.apply(input, 
self.clip_val, self.num_bits, + self.dequantize, self.inplace) return input def __repr__(self): @@ -124,10 +109,10 @@ class WRPNQuantizer(Quantizer): bits_weights=bits_weights, bits_overrides=bits_overrides, train_with_fp_copy=True, quantize_bias=quantize_bias) - def wrpn_quantize_param(param_fp, num_bits): - scale_factor = symmetric_linear_quantization_scale_factor(num_bits, 1) + def wrpn_quantize_param(param_fp, param_meta): + scale, zero_point = symmetric_linear_quantization_params(param_meta.num_bits, 1) out = param_fp.clamp(-1, 1) - out = LinearQuantizeSTE.apply(out, scale_factor, True, False) + out = LinearQuantizeSTE.apply(out, scale, zero_point, True, False) return out def relu_replace_fn(module, name, qbits_map): @@ -141,11 +126,11 @@ class WRPNQuantizer(Quantizer): self.replacement_factory[nn.ReLU] = relu_replace_fn -def dorefa_quantize_param(param_fp, num_bits): - scale_factor = asymmetric_linear_quantization_scale_factor(num_bits, 0, 1) +def dorefa_quantize_param(param_fp, param_meta): + scale, zero_point = asymmetric_linear_quantization_params(param_meta.num_bits, 0, 1, signed=False) out = param_fp.tanh() out = out / (2 * out.abs().max()) + 0.5 - out = LinearQuantizeSTE.apply(out, scale_factor, True, False) + out = LinearQuantizeSTE.apply(out, scale, zero_point, True, False) out = 2 * out - 1 return out diff --git a/distiller/quantization/q_utils.py b/distiller/quantization/q_utils.py index 348721106225665708e813cab992afdabefa7190..306f319689aa6994c997b4017dd555498260d323 100644 --- a/distiller/quantization/q_utils.py +++ b/distiller/quantization/q_utils.py @@ -17,15 +17,30 @@ import torch -def symmetric_linear_quantization_scale_factor(num_bits, saturation_val): +def symmetric_linear_quantization_params(num_bits, saturation_val): # Leave one bit for sign n = 2 ** (num_bits - 1) - 1 - return n / saturation_val + scale = n / saturation_val + if isinstance(scale, torch.Tensor): + zero_point = torch.zeros_like(scale) + else: + zero_point = 0.0 + return scale, zero_point -def asymmetric_linear_quantization_scale_factor(num_bits, saturation_min, saturation_max): +def asymmetric_linear_quantization_params(num_bits, saturation_min, saturation_max, + integral_zero_point=True, signed=False): n = 2 ** num_bits - 1 - return n / (saturation_max - saturation_min) + scale = n / (saturation_max - saturation_min) + zero_point = scale * saturation_min + if integral_zero_point: + if isinstance(zero_point, torch.Tensor): + zero_point = zero_point.round() + else: + zero_point = float(round(zero_point)) + if signed: + zero_point += 2 ** (num_bits - 1) + return scale, zero_point def clamp(input, min, max, inplace=False): @@ -35,35 +50,48 @@ def clamp(input, min, max, inplace=False): return torch.clamp(input, min, max) -def linear_quantize(input, scale_factor, inplace=False): +def linear_quantize(input, scale, zero_point, inplace=False): if inplace: - input.mul_(scale_factor).round_() + input.mul_(scale).sub_(zero_point).round_() return input - return torch.round(scale_factor * input) + return torch.round(scale * input - zero_point) -def linear_quantize_clamp(input, scale_factor, clamp_min, clamp_max, inplace=False): - output = linear_quantize(input, scale_factor, inplace) +def linear_quantize_clamp(input, scale, zero_point, clamp_min, clamp_max, inplace=False): + output = linear_quantize(input, scale, zero_point, inplace) return clamp(output, clamp_min, clamp_max, inplace) -def linear_dequantize(input, scale_factor, inplace=False): +def linear_dequantize(input, scale, zero_point, inplace=False): if 
inplace: - input.div_(scale_factor) + input.add_(zero_point).div_(scale) return input - return input / scale_factor + return (input + zero_point) / scale + + +def get_tensor_min_max(t, per_dim=None): + if per_dim is None: + return t.min(), t.max() + if per_dim > t.dim(): + raise ValueError('Got per_dim={0}, but tensor only has {1} dimensions', per_dim, t.dim()) + view_dims = [t.shape[i] for i in range(per_dim + 1)] + [-1] + tv = t.view(*view_dims) + return tv.min(dim=-1)[0], tv.max(dim=-1)[0] + +def get_tensor_avg_min_max(t, across_dim=None): + min_per_dim, max_per_dim = get_tensor_min_max(t, per_dim=across_dim) + return min_per_dim.mean(), max_per_dim.mean() -def get_tensor_max_abs(tensor): - return max(abs(tensor.max().item()), abs(tensor.min().item())) +def get_tensor_max_abs(t, per_dim=None): + min_val, max_val = get_tensor_min_max(t, per_dim=per_dim) + return torch.max(min_val.abs_(), max_val.abs_()) -def get_tensor_avg_max_abs_across_batch(tensor): - # Assume batch is at dim 0 - tv = tensor.view(tensor.size()[0], -1) - avg_max = tv.max(dim=1)[0].mean().item() - avg_min = tv.min(dim=1)[0].mean().item() - return max(abs(avg_max), abs(avg_min)) + +def get_tensor_avg_max_abs(t, across_dim=None): + avg_min, avg_max = get_tensor_avg_min_max(t, across_dim=across_dim) + return torch.max(avg_min.abs_(), avg_max.abs_()) def get_quantized_range(num_bits, signed=True): @@ -71,3 +99,19 @@ def get_quantized_range(num_bits, signed=True): n = 2 ** (num_bits - 1) return -n, n - 1 return 0, 2 ** num_bits - 1 + + +class LinearQuantizeSTE(torch.autograd.Function): + @staticmethod + def forward(ctx, input, scale, zero_point, dequantize, inplace): + if inplace: + ctx.mark_dirty(input) + output = linear_quantize(input, scale, zero_point, inplace) + if dequantize: + output = linear_dequantize(output, scale, zero_point, inplace) + return output + + @staticmethod + def backward(ctx, grad_output): + # Straight-through estimator + return grad_output, None, None, None, None diff --git a/distiller/quantization/quantizer.py b/distiller/quantization/quantizer.py index 0a5b976db78994627eca2c89a6bd24abe5794c02..3e9083ae4d01727c1f05b0e0143325701f0471be 100644 --- a/distiller/quantization/quantizer.py +++ b/distiller/quantization/quantizer.py @@ -33,7 +33,7 @@ def has_bias(module): return hasattr(module, 'bias') and module.bias is not None -def hack_float_backup_parameter(module, name): +def hack_float_backup_parameter(module, name, num_bits): try: data = dict(module.named_parameters())[name].data except KeyError: @@ -42,6 +42,17 @@ def hack_float_backup_parameter(module, name): delattr(module, name) module.register_buffer(name, torch.zeros_like(data)) + first = False + if not hasattr(module, 'repr_mod'): + setattr(module, 'repr_mod', ', \nDistiller_QuantAwareTrain: ') + first = True + module.original_extra_repr = module.extra_repr + module.extra_repr = lambda: module.original_extra_repr() + module.repr_mod + + if not first: + module.repr_mod += ' ; ' + module.repr_mod += '{0} --> {1} bits'.format(name, num_bits) + class _ParamToQuant(object): def __init__(self, module, module_name, fp_attr_name, q_attr_name, num_bits): @@ -146,6 +157,11 @@ class Quantizer(object): self.module_qbits_map[module_name] = qbits def prepare_model(self): + self._prepare_model_impl() + + msglogger.info('Quantized model:\n\n{0}\n'.format(self.model)) + + def _prepare_model_impl(self): r""" Iterates over the model and replaces modules with their quantized counterparts as defined by self.replacement_factory @@ -164,7 +180,7 @@ class 
Quantizer(object): continue fp_attr_name = param_name if self.train_with_fp_copy: - hack_float_backup_parameter(module, param_name) + hack_float_backup_parameter(module, param_name, qbits.wts) fp_attr_name = FP_BKP_PREFIX + param_name self.params_to_quantize.append(_ParamToQuant(module, module_name, fp_attr_name, param_name, qbits.wts)) @@ -178,8 +194,6 @@ class Quantizer(object): new_optimizer = optimizer_type(self._get_updated_optimizer_params_groups(), **self.optimizer.defaults) self.optimizer.__setstate__({'param_groups': new_optimizer.param_groups}) - msglogger.info('Quantized model:\n\n{0}\n'.format(self.model)) - def _pre_process_container(self, container, prefix=''): # Iterate through model, insert quantization functions as appropriate for name, module in container.named_children(): @@ -220,11 +234,10 @@ class Quantizer(object): def quantize_params(self): """ - Quantize all parameters using the parameters using self.param_quantization_fn (using the defined number - of bits for each parameter) + Quantize all parameters using self.param_quantization_fn (with the defined number of bits for each parameter) """ for ptq in self.params_to_quantize: - q_param = self.param_quantization_fn(getattr(ptq.module, ptq.fp_attr_name), ptq.num_bits) + q_param = self.param_quantization_fn(getattr(ptq.module, ptq.fp_attr_name), ptq) if self.train_with_fp_copy: setattr(ptq.module, ptq.q_attr_name, q_param) else: diff --git a/distiller/quantization/range_linear.py b/distiller/quantization/range_linear.py index 863920b644f36a4254e754489c054165325e816e..c484b9d66fbe71883b2002a3018226115042bd1c 100644 --- a/distiller/quantization/range_linear.py +++ b/distiller/quantization/range_linear.py @@ -15,13 +15,58 @@ # import torch.nn as nn +from enum import Enum +from collections import OrderedDict +import distiller.utils from .quantizer import Quantizer from .q_utils import * -### -# Range-based linear quantization -### + +class LinearQuantMode(Enum): + SYMMETRIC = 1 + ASYMMETRIC_UNSIGNED = 2 + ASYMMETRIC_SIGNED = 3 + + +def verify_mode(mode): + if isinstance(mode, str): + try: + return LinearQuantMode[mode] + except KeyError: + raise ValueError('Unknown quantization mode string') + elif isinstance(mode, LinearQuantMode): + return mode + else: + raise TypeError("'mode' argument can be either a string or member of {0}".format(LinearQuantMode.__name__)) + + +############################################################################### +# Post Training +############################################################################### + + +def _get_tensor_quantization_params(tensor, num_bits, mode, clip=False, per_channel=False): + if per_channel and tensor.dim() not in [2, 4]: + raise ValueError('Per channel quantization possible only with 2D or 4D tensors (linear or conv layer weights)') + dim = 0 if clip or per_channel else None + if mode == LinearQuantMode.SYMMETRIC: + sat_fn = get_tensor_avg_max_abs if clip else get_tensor_max_abs + sat_val = sat_fn(tensor, dim) + scale, zp = symmetric_linear_quantization_params(num_bits, sat_val) + else: # Asymmetric mode + sat_fn = get_tensor_avg_min_max if clip else get_tensor_min_max + sat_min, sat_max = sat_fn(tensor, dim) + signed = mode == LinearQuantMode.ASYMMETRIC_SIGNED + scale, zp = asymmetric_linear_quantization_params(num_bits, sat_min, sat_max, signed=signed) + + if per_channel: + # Reshape scale and zero_points so they can be broadcast properly with the weight tensor + dims = [scale.shape[0]] + [1] * (tensor.dim() - 1) + scale = scale.view(dims) + zp = 
zp.view(dims) + + return scale, zp class RangeLinearQuantWrapper(nn.Module): @@ -32,63 +77,107 @@ class RangeLinearQuantWrapper(nn.Module): wrapped_module (torch.nn.Module): Module to be wrapped num_bits_acts (int): Number of bits used for inputs and output quantization num_bits_accum (int): Number of bits allocated for the accumulator of intermediate integer results + mode (LinearQuantMode): Quantization mode to use (symmetric / asymmetric-signed/unsigned) clip_acts (bool): If true, will clip activations instead of using absolute min/max. At the moment clipping is done by averaging over the max absolute values of samples within a batch. More methods might be added in the future. """ - def __init__(self, wrapped_module, num_bits_acts, num_bits_accum=32, clip_acts=False): + + def __init__(self, wrapped_module, num_bits_acts, num_bits_accum=32, mode=LinearQuantMode.SYMMETRIC, + clip_acts=False): super(RangeLinearQuantWrapper, self).__init__() self.wrapped_module = wrapped_module self.num_bits_acts = num_bits_acts self.num_bits_accum = num_bits_accum + self.mode = mode self.clip_acts = clip_acts - self.acts_sat_val_func = get_tensor_avg_max_abs_across_batch if clip_acts else get_tensor_max_abs - self.acts_min_q_val, self.acts_max_q_val = get_quantized_range(num_bits_acts, signed=True) + # Controls whether output is de-quantized at end of forward op. Meant as a debug / test flag only + # (note that if False, the quantized output will be returned, but without any quantization parameters, + # so other than inspecting the contents there's not much to do with it) + self._dequant_out = True + + signed = mode != LinearQuantMode.ASYMMETRIC_UNSIGNED + self.acts_min_q_val, self.acts_max_q_val = get_quantized_range(num_bits_acts, signed=signed) + # The accumulator is always signed self.accum_min_q_val, self.accum_max_q_val = get_quantized_range(num_bits_accum, signed=True) def forward(self, *inputs): - in_scales = self.pre_quantized_forward(*inputs) + if self.training: + raise RuntimeError(self.__class__.__name__ + " can only be used in eval mode") + + in_scales, in_zero_points = self.get_inputs_quantization_params(*inputs) # Quantize inputs inputs_q = [] for idx, input in enumerate(inputs): - input_q = linear_quantize_clamp(input.data, in_scales[idx], self.acts_min_q_val, self.acts_max_q_val, - inplace=False) + input_q = linear_quantize_clamp(input.data, in_scales[idx], in_zero_points[idx], + self.acts_min_q_val, self.acts_max_q_val, inplace=False) inputs_q.append(torch.autograd.Variable(input_q)) # Forward through wrapped module - accum = self.wrapped_module.forward(*inputs_q) - clamp(accum.data, self.accum_min_q_val, self.accum_max_q_val, inplace=True) + accum = self.quantized_forward(*inputs_q) # Re-quantize accumulator to quantized output range - requant_scale, out_scale = self.post_quantized_forward(accum) - out_q = linear_quantize_clamp(accum.data, requant_scale, self.acts_min_q_val, self.acts_max_q_val, inplace=True) + out_scale, out_zero_point = self.get_output_quantization_params(accum) + requant_scale, requant_zero_point = self.get_accum_to_output_re_quantization_params(out_scale, out_zero_point) + out_q = linear_quantize_clamp(accum.data, requant_scale, requant_zero_point, + self.acts_min_q_val, self.acts_max_q_val, inplace=True) + + if not self._dequant_out: + return torch.autograd.Variable(out_q) # De-quantize back to FP32 - out_f = linear_dequantize(out_q, out_scale, inplace=True) + out_f = linear_dequantize(out_q, out_scale, out_zero_point, inplace=True) return 
torch.autograd.Variable(out_f) - def pre_quantized_forward(self, *inputs): + def get_inputs_quantization_params(self, *inputs): """ - Calculate input scale factors and perform any action required before quantization of inputs. + Calculate input quantization parameters (scale and zero-point) Should be overridden by all subclasses :param inputs: Current input tensors passed to forward method - :return: List of scale factors per input + :return: Tuple of 2 lists - list of scales per input and list of zero-point per input + """ + raise NotImplementedError + + def quantized_forward(self, *inputs_q): + """ + Perform forward pass with quantized inputs and return quantized outputs + + :param inputs_q: Tensor (or list of tensors) with quantized input values + :return: Tensor with quantized output values """ raise NotImplementedError - def post_quantized_forward(self, accumulator): + def get_output_quantization_params(self, accumulator): """ - Calculate re-quantization scale factor (for converting the intermediate integer accumulator to output range), - and output scale factor. + Calculate quantization parameters (scale and zero-point) for the output. + This is used for: + * Calculating the accumulator-to-output re-quantization parameters + (see get_accum_to_output_re_quantization_params) + * De-quantizing the output back to FP32 + + Should be overridden by all subclasses :param accumulator: Tensor with accumulator values - :return: Tuple of (re-quantization scale factor, output scale factor) + :return: Tuple of scale and zero-point + """ + raise NotImplementedError + + def get_accum_to_output_re_quantization_params(self, output_scale, output_zero_point): + """ + Calculate quantization parameters (scale and zero-point) for re-quantization, that is: + Converting the intermediate integer accumulator to the output range + + Should be overridden by all subclasses + + :param output_scale: Output scale factor + :param output_zero_point: Output zero-point + :return: Tuple of scale and zero-point """ raise NotImplementedError @@ -98,73 +187,131 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper): Linear range-based quantization wrappers for layers with weights and bias (namely torch.nn.ConvNd and torch.nn.Linear) + Assume: + + x_q = round(scale_x * x_f) - zero_point_x + + Hence: + + x_f = 1/scale_x * x_q + zero_point_x + + (And the same for y_q, w_q and b_q) + + So, we get: (use "zp" as abbreviation for zero_point) + + y_f = x_f * w_f + b_f + + y_q = scale_y * y_f + zp_y = scale_y * (x_f * w_f + b_f) + zp_y = + + scale_y scale_x * scale_w + = ------------------- * ((x_q + zp_x) * (w_q + zp_w) + ------------------- * (b_q + zp_b)) + zp_y + scale_x * scale_w scale_b + Args: wrapped_module (torch.nn.Module): Module to be wrapped num_bits_acts (int): Number of bits used for inputs and output quantization num_bits_params (int): Number of bits used for parameters (weights and bias) quantization num_bits_accum (int): Number of bits allocated for the accumulator of intermediate integer results + mode (LinearQuantMode): Quantization mode to use (symmetric / asymmetric-signed/unsigned) clip_acts (bool): See RangeLinearQuantWrapper """ - def __init__(self, wrapped_module, num_bits_acts, num_bits_params, num_bits_accum=32, clip_acts=False): - super(RangeLinearQuantParamLayerWrapper, self).__init__(wrapped_module, num_bits_acts, - num_bits_accum, clip_acts) + def __init__(self, wrapped_module, num_bits_acts, num_bits_params, num_bits_accum=32, + mode=LinearQuantMode.SYMMETRIC, clip_acts=False, 
per_channel_wts=False): + super(RangeLinearQuantParamLayerWrapper, self).__init__(wrapped_module, num_bits_acts, num_bits_accum, mode, + clip_acts) if not isinstance(wrapped_module, (nn.Conv2d, nn.Linear)): raise ValueError(self.__class__.__name__ + ' can wrap only Conv2D and Linear modules') self.num_bits_params = num_bits_params - self.params_min_q_val, self.params_max_q_val = get_quantized_range(num_bits_params, signed=True) + self.per_channel_wts = per_channel_wts + + self.params_min_q_val, self.params_max_q_val = get_quantized_range( + num_bits_params, signed=mode != LinearQuantMode.ASYMMETRIC_UNSIGNED) # Quantize weights - overwrite FP32 weights - self.w_scale = symmetric_linear_quantization_scale_factor(num_bits_params, - get_tensor_max_abs(wrapped_module.weight)) - linear_quantize_clamp(wrapped_module.weight.data, self.w_scale, self.params_min_q_val, self.params_max_q_val, - inplace=True) + w_scale, w_zero_point = _get_tensor_quantization_params(wrapped_module.weight, num_bits_params, self.mode, + per_channel=per_channel_wts) + + self.register_buffer('w_scale', w_scale) + self.register_buffer('w_zero_point', w_zero_point) + linear_quantize_clamp(wrapped_module.weight.data, self.w_scale, self.w_zero_point, self.params_min_q_val, + self.params_max_q_val, inplace=True) # Quantize bias self.has_bias = hasattr(wrapped_module, 'bias') and wrapped_module.bias is not None if self.has_bias: - self.b_scale = symmetric_linear_quantization_scale_factor(num_bits_params, - get_tensor_max_abs(wrapped_module.bias)) - base_b_q = linear_quantize_clamp(wrapped_module.bias.data, self.b_scale, + b_scale, b_zero_point = _get_tensor_quantization_params(wrapped_module.bias, num_bits_params, self.mode) + self.register_buffer('b_scale', b_scale) + self.register_buffer('b_zero_point', b_zero_point) + base_b_q = linear_quantize_clamp(wrapped_module.bias.data, self.b_scale, self.b_zero_point, self.params_min_q_val, self.params_max_q_val) # Dynamic ranges - save in auxiliary buffer, requantize each time based on dynamic input scale factor self.register_buffer('base_b_q', base_b_q) + self.current_in_scale = 1 + self.current_in_zero_point = 0 self.current_accum_scale = 1 - def pre_quantized_forward(self, input): - super(RangeLinearQuantParamLayerWrapper, self).forward(input) + def get_inputs_quantization_params(self, input): + self.current_in_scale, self.current_in_zero_point = _get_tensor_quantization_params(input, self.num_bits_acts, + self.mode, + clip=self.clip_acts) + return [self.current_in_scale], [self.current_in_zero_point] + + def quantized_forward(self, input_q): + # See class documentation for quantized calculation details. 
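+        # Note: the accumulator returned below holds sum((x_q + zp_x) * (w_q + zp_w)) plus the re-scaled bias, so its +        # effective scale relative to y_f is scale_x * scale_w. With per-channel weights, the trailing singleton +        # dimensions of w_scale are squeezed so the per-channel scales broadcast against the wrapped module's output.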
+ + self.current_accum_scale = self.current_in_scale * self.w_scale + if self.per_channel_wts: + self.current_accum_scale = self.current_accum_scale.squeeze(dim=-1) - def pre_quantized_forward(self, input): - in_scale = symmetric_linear_quantization_scale_factor(self.num_bits_acts, - self.acts_sat_val_func(input)) - self.current_accum_scale = in_scale * self.w_scale if self.has_bias: - # Re-quantize bias to match x * w scale: b_q' = (in_scale * w_scale / b_scale) * b_q - self.wrapped_module.bias.data = linear_quantize_clamp(self.base_b_q, self.current_accum_scale / self.b_scale, + # Re-quantize bias to match x * w scale: b_q' = (in_scale * w_scale / b_scale) * (b_q + b_zero_point) + self.wrapped_module.bias.data = linear_quantize_clamp(self.base_b_q + self.b_zero_point, + self.current_accum_scale / self.b_scale, 0, self.accum_min_q_val, self.accum_max_q_val) - return [in_scale] - def post_quantized_forward(self, accumulator): - accum_max_abs = self.acts_sat_val_func(accumulator) - y_f_max_abs = accum_max_abs / self.current_accum_scale - out_scale = symmetric_linear_quantization_scale_factor(self.num_bits_acts, y_f_max_abs) - requant_scale = out_scale / self.current_accum_scale - return requant_scale, out_scale + # Note the main terms within the summation is: + # (x_q + zp_x) * (w_q + zp_w) + # In a performance-optimized solution, we would expand the parentheses and perform the computation similar + # to what is described here: + # https://github.com/google/gemmlowp/blob/master/doc/low-precision.md#efficient-handling-of-offsets + # However, for now we're more concerned with simplicity rather than speed. So we'll just add the zero points + # to the input and weights and pass those to the wrapped model. Functionally, since at this point we're + # dealing solely with integer values, the results are the same either way. + + if self.mode != LinearQuantMode.SYMMETRIC: + input_q += self.current_in_zero_point + self.wrapped_module.weight.data += self.w_zero_point + + accum = self.wrapped_module.forward(input_q) + clamp(accum.data, self.accum_min_q_val, self.accum_max_q_val, inplace=True) + + if self.mode != LinearQuantMode.SYMMETRIC: + self.wrapped_module.weight.data -= self.w_zero_point + return accum + + def get_output_quantization_params(self, accumulator): + y_f = accumulator / self.current_accum_scale + return _get_tensor_quantization_params(y_f, self.num_bits_acts, self.mode, clip=self.clip_acts) + + def get_accum_to_output_re_quantization_params(self, output_scale, output_zero_point): + return output_scale / self.current_accum_scale, output_zero_point def extra_repr(self): - tmpstr = 'wrapped_module: ' + self.wrapped_module.__repr__() + '\n' - tmpstr += 'num_bits_acts={0}, num_bits_params={1}, num_bits_accum={2}'.format(self.num_bits_acts, - self.num_bits_params, - self.num_bits_accum) + '\n' - tmpstr += 'clip_acts={0}'.format(self.clip_acts) + tmpstr = 'mode={0}, '.format(str(self.mode).split('.')[1]) + tmpstr += 'num_bits_acts={0}, num_bits_params={1}, num_bits_accum={2}, '.format(self.num_bits_acts, + self.num_bits_params, + self.num_bits_accum) + tmpstr += 'clip_acts={0}, per_channel_wts={1}'.format(self.clip_acts, self.per_channel_wts) return tmpstr -class SymmetricLinearQuantizer(Quantizer): +class PostTrainLinearQuantizer(Quantizer): """ - Applies symmetric, range-based linear quantization to a model. + Applies range-based linear quantization to a model. 
+ This quantizer is expected to be executed at evaluation only, on a pre-trained model Currently, the following Modules are supported: torch.nn.Conv2d, torch.nn.Linear Args: @@ -175,22 +322,173 @@ class SymmetricLinearQuantizer(Quantizer): A common practice is to not clip the activations of the last layer before softmax. Applicable only if clip_acts is True. """ - def __init__(self, model, bits_activations=8, bits_parameters=8, bits_accum=32, - clip_acts=False, no_clip_layers=[]): - super(SymmetricLinearQuantizer, self).__init__(model, bits_activations=bits_activations, + def __init__(self, model, bits_activations=8, bits_parameters=8, bits_accum=32, mode=LinearQuantMode.SYMMETRIC, + clip_acts=False, no_clip_layers=[], per_channel_wts=False): + super(PostTrainLinearQuantizer, self).__init__(model, bits_activations=bits_activations, bits_weights=bits_parameters, train_with_fp_copy=False) + + mode = verify_mode(mode) self.model.quantizer_metadata = {'type': type(self), 'params': {'bits_activations': bits_activations, - 'bits_parameters': bits_parameters}} + 'bits_parameters': bits_parameters, + 'bits_accum': bits_accum, + 'mode': str(mode).split('.')[1], 'clip_acts': clip_acts, + 'no_clip_layers': no_clip_layers, + 'per_channel_wts': per_channel_wts}} def replace_fn(module, name, qbits_map): - clip = self.clip_acts and name not in no_clip_layers + clip = self.clip_acts and distiller.utils.normalize_module_name(name) not in no_clip_layers return RangeLinearQuantParamLayerWrapper(module, qbits_map[name].acts, qbits_map[name].wts, - num_bits_accum=self.bits_accum, clip_acts=clip) + num_bits_accum=self.bits_accum, mode=mode, clip_acts=clip, + per_channel_wts=per_channel_wts) self.clip_acts = clip_acts self.no_clip_layers = no_clip_layers self.bits_accum = bits_accum + self.mode = mode self.replacement_factory[nn.Conv2d] = replace_fn self.replacement_factory[nn.Linear] = replace_fn + + +############################################################################### +# Quantization-aware training +############################################################################### + + +def update_ema(biased_ema, value, decay, step): + biased_ema = biased_ema * decay + (1 - decay) * value + unbiased_ema = biased_ema / (1 - decay ** step) # Bias correction + return unbiased_ema + + +def inputs_quantize_wrapped_forward(self, input): + input = self.inputs_quant(input) + return self.original_forward(input) + + +class FakeLinearQuantization(nn.Module): + def __init__(self, num_bits=8, mode=LinearQuantMode.SYMMETRIC, ema_decay=0.999, dequantize=True, inplace=False): + super(FakeLinearQuantization, self).__init__() + + self.num_bits = num_bits + self.mode = mode + self.dequantize = dequantize + self.inplace = inplace + + # We track activations ranges with exponential moving average, as proposed by Jacob et al., 2017 + # https://arxiv.org/abs/1712.05877 + # We perform bias correction on the EMA, so we keep both unbiased and biased values and the iterations count + # For a simple discussion of this see here: + # https://www.coursera.org/lecture/deep-neural-network/bias-correction-in-exponentially-weighted-averages-XjuhD + self.register_buffer('ema_decay', torch.tensor(ema_decay)) + self.register_buffer('tracked_min_biased', torch.zeros(1)) + self.register_buffer('tracked_min', torch.zeros(1)) + self.register_buffer('tracked_max_biased', torch.zeros(1)) + self.register_buffer('tracked_max', torch.zeros(1)) + self.register_buffer('iter_count', torch.zeros(1)) + self.register_buffer('scale', torch.ones(1)) + 
self.register_buffer('zero_point', torch.zeros(1)) + + def forward(self, input): + with torch.no_grad(): + current_min, current_max = get_tensor_min_max(input) + self.iter_count = self.iter_count + 1 + self.tracked_min = update_ema(self.tracked_min_biased, current_min, self.ema_decay, self.iter_count) + self.tracked_max = update_ema(self.tracked_max_biased, current_max, self.ema_decay, self.iter_count) + + if self.mode == LinearQuantMode.SYMMETRIC: + max_abs = max(abs(self.tracked_min), abs(self.tracked_max)) + actual_min, actual_max = -max_abs, max_abs + self.scale, self.zero_point = symmetric_linear_quantization_params(self.num_bits, max_abs) + else: + actual_min, actual_max = self.tracked_min, self.tracked_max + signed = self.mode == LinearQuantMode.ASYMMETRIC_SIGNED + self.scale, self.zero_point = asymmetric_linear_quantization_params(self.num_bits, self.tracked_min, + self.tracked_max, signed=signed) + + input = clamp(input, actual_min.item(), actual_max.item(), False) + input = LinearQuantizeSTE.apply(input, self.scale, self.zero_point, self.dequantize, False) + + return input + + def extra_repr(self): + mode_str = str(self.mode).split('.')[1] + return 'mode={0}, num_bits={1}, ema_decay={2:.4f})'.format(mode_str, self.num_bits, self.ema_decay) + + +class QuantAwareTrainRangeLinearQuantizer(Quantizer): + def __init__(self, model, optimizer=None, bits_activations=32, bits_weights=32, bits_overrides=OrderedDict(), + quantize_bias=True, mode=LinearQuantMode.SYMMETRIC, ema_decay=0.999, per_channel_wts=False, + quantize_inputs=True, num_bits_inputs=None): + super(QuantAwareTrainRangeLinearQuantizer, self).__init__(model, optimizer=optimizer, + bits_activations=bits_activations, + bits_weights=bits_weights, + bits_overrides=bits_overrides, + quantize_bias=quantize_bias, + train_with_fp_copy=True) + + mode = verify_mode(mode) + + self.model.quantizer_metadata['params']['mode'] = str(mode).split('.')[1] + self.model.quantizer_metadata['params']['ema_decay'] = ema_decay + self.model.quantizer_metadata['params']['per_channel_wts'] = per_channel_wts + self.model.quantizer_metadata['params']['quantize_inputs'] = quantize_inputs + + # Keeping some parameters for input quantization + self.quantize_inputs = quantize_inputs + if num_bits_inputs is not None: + self.num_bits_inputs = num_bits_inputs + else: + self.num_bits_inputs = bits_activations + self.mode = mode + self.decay = ema_decay + self.per_channel_wts = per_channel_wts + + def linear_quantize_param(param_fp, param_meta): + perch = per_channel_wts and param_fp.dim() in [2, 4] + with torch.no_grad(): + scale, zero_point = _get_tensor_quantization_params(param_fp, param_meta.num_bits, mode, + per_channel=perch) + m = param_meta.module + setattr(m, param_meta.q_attr_name + '_scale', scale) + setattr(m, param_meta.q_attr_name + '_zero_point', zero_point) + out = LinearQuantizeSTE.apply(param_fp, scale, zero_point, True, False) + return out + + def relu_replace_fn(module, name, qbits_map): + bits_acts = qbits_map[name].acts + if bits_acts is None: + return module + return nn.Sequential(module, FakeLinearQuantization(bits_acts, mode, ema_decay, dequantize=True, + inplace=module.inplace)) + + self.param_quantization_fn = linear_quantize_param + + self.replacement_factory[nn.ReLU] = relu_replace_fn + + def _prepare_model_impl(self): + super(QuantAwareTrainRangeLinearQuantizer, self)._prepare_model_impl() + + if self.quantize_inputs: + if isinstance(self.model, nn.DataParallel): + m = self.model.module + else: + m = self.model + + m.inputs_quant = 
FakeLinearQuantization(self.num_bits_inputs, self.mode, self.decay, + dequantize=True, inplace=False) + m.__class__.original_forward = m.__class__.forward + m.__class__.forward = inputs_quantize_wrapped_forward + + # Prepare scale and zero point buffers in modules where parameters are being quantized + # We're calculating "dummy" scale and zero point just to get their dimensions + for ptq in self.params_to_quantize: + m = ptq.module + param_fp = getattr(m, ptq.fp_attr_name) + perch = self.per_channel_wts and param_fp.dim() in [2, 4] + with torch.no_grad(): + scale, zero_point = _get_tensor_quantization_params(param_fp, ptq.num_bits, self.mode, + per_channel=perch) + m.register_buffer(ptq.q_attr_name + '_scale', torch.ones_like(scale)) + m.register_buffer(ptq.q_attr_name + '_zero_point', torch.zeros_like(zero_point)) diff --git a/docs-src/docs/algo_quantization.md b/docs-src/docs/algo_quantization.md index d2fbc065266dfd9284354b9db509bc9fd80f110d..2c22aaffbfa20e1da62d8d1ce15e778f62122911 100644 --- a/docs-src/docs/algo_quantization.md +++ b/docs-src/docs/algo_quantization.md @@ -1,6 +1,104 @@ # Quantization Algorithms -The following quantization methods are currently implemented in Distiller: +**Note:** +For any of the methods below that require quantization-aware training, please see [here](schedule.md#quantization) for details on how to invoke it using Distiller's scheduling mechanism. + +## Range-Based Linear Quantization + +Let's break down the terminology we use here: + +- **Linear:** Means a float value is quantized by multiplying with a numeric constant (the **scale factor**). +- **Range-Based:**: Means that in order to calculate the scale factor, we look at the actual range of the tensor's values. In the most naive implementation, we use the actual min/max values of the tensor. Alternatively, we use some derivation based on the tensor's range / distribution to come up with a narrower min/max range, in order to remove possible outliers. This is in contrast to the other methods described here, which we could call **clipping-based**, as they impose an explicit clipping function on the tensors (using either a hard-coded value or a learned value). + +### Asymmetric vs. Symmetric + +In this method we can use two modes - **asymmetric** and **symmetric**. + +#### Asymmetric Mode + +<p align="center"> + <img src="../imgs/quant_asym.png"/> +</p> + +In **asymmetric** mode, we map the min/max in the float range to the min/max of the integer range. This is done by using a **zero-point** (also called *quantization bias*, or *offset*) in addition to the scale factor. + +Let us denote the original floating-point tensor by \(x_f\), the quantized tensor by \(x_q\), the scale factor by \(q_x\), the zero-point by \(zp_x\) and the number of bits used for quantization by \(n\). Then, we get: + +\[x_q = round\left ((x_f - min_{x_f})\underbrace{\frac{2^n - 1}{max_{x_f} - min_{x_f}}}_{q_x} \right) = round(q_x x_f - \underbrace{min_{x_f}q_x)}_{zp_x} = round(q_x x_f - zp_x)\] + +In practice, we actually use \(zp_x = round(min_{x_f}q_x)\). This means that zero is exactly representable by an integer in the quantized range. This is important, for example, for layers that have zero-padding. By rounding the zero-point, we effectively "nudge" the min/max values in the float range a little bit, in order to gain this exact quantization of zero. + +Note that in the derivation above we use unsigned integer to represent the quantized range. That is, \(x_q \in [0, 2^n-1]\). 
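As a small sanity check of the formulas above, here is a sketch using the helper functions this change adds to `distiller.quantization.q_utils` (the tensor values are arbitrary, chosen only for illustration):

```python
import torch
from distiller.quantization.q_utils import (asymmetric_linear_quantization_params,
                                            linear_quantize, linear_dequantize)

x_f = torch.tensor([-0.5, 0.0, 0.3, 1.5])   # arbitrary float tensor
n = 8                                       # number of bits

# scale = (2^n - 1) / (max - min); zero_point = round(scale * min)
scale, zp = asymmetric_linear_quantization_params(n, x_f.min(), x_f.max(), signed=False)

x_q = linear_quantize(x_f, scale, zp)        # round(scale * x_f - zp), values land in [0, 255]
x_f_hat = linear_dequantize(x_q, scale, zp)  # (x_q + zp) / scale, approximately recovers x_f
```

With these values, \(min_{x_f}\) maps to 0, \(max_{x_f}\) maps to 255, and zero itself lands exactly on an integer grid point, as discussed above.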
One could use a signed integer if necessary (perhaps due to HW considerations). This can be achieved by subtracting \(2^{n-1}\). + +Let's see how a **convolution** or **fully-connected (FC)** layer is quantized in asymmetric mode: (we denote input, output, weights and bias with \(x, y, w\) and \(b\) respectively) + +\[y_f = \sum{x_f w_f} + b_f = \sum{\frac{x_q + zp_x}{q_x} \frac{w_q + zp_w}{q_w}} + \frac{b_q + zp_b}{q_b} =\] +\[ = \frac{1}{q_x q_w} \left( \sum { (x_q + zp_x) (w_q + zp_w) + \frac{q_x q_w}{q_b}(b_q + zp_b) } \right)\] + +Therefore: + +\[y_q = round(q_y y_f) = round\left(\frac{q_y}{q_x q_w} \left( \sum { (x_q+zp_x) (w_q+zp_w) + \frac{q_x q_w}{q_b}(b_q+zp_b) } \right) \right) \] + +Notes: + +- We can see that the bias has to be re-scaled to match the scale of the summation. +- In a proper integer-only HW pipeline, we would like our main accumulation term to simply be \(\sum{x_q w_q}\). In order to achieve this, one needs to further develop the expression we derived above. For further details please refer to the [gemmlowp documentation](https://github.com/google/gemmlowp/blob/master/doc/quantization.md#implementation-of-quantized-matrix-multiplication). + +#### Symmetric Mode + +<p align="center"> + <img src="../imgs/quant_sym.png"/> +</p> + +In **symmetric** mode, instead of mapping the exact min/max of the float range to the quantized range, we choose the maximum absolute value between min/max. In addition, we don't use a zero-point. So, the floating-point range we're effectively quantizing is symmetric with respect to zero, and so is the quantized range. + +Using the same notations as above, we get: + +\[x_q = round\left (x_f \underbrace{\frac{2^{n-1} - 1}{\max|x_f|}}_{q_x} \right) = round(q_x x_f)\] + +Again, let's see how a **convolution** or **fully-connected (FC)** layer is quantized, this time in symmetric mode: + +\[y_f = \sum{x_f w_f} + b_f = \sum{\frac{x_q}{q_x} \frac{w_q}{q_w}} + \frac{b_q}{q_b} = \frac{1}{q_x q_w} \left( \sum { x_q w_q + \frac{q_x q_w}{q_b}b_q } \right)\] + +Therefore: + +\[y_q = round(q_y y_f) = round\left(\frac{q_y}{q_x q_w} \left( \sum { x_q w_q + \frac{q_x q_w}{q_b}b_q } \right) \right) \] + +#### Comparing the Two Modes + +The main trade-off between these two modes is simplicity vs. utilization of the quantized range. + +- When using asymmetric quantization, the quantized range is fully utilized. That is because we exactly map the min/max values from the float range to the min/max of the quantized range. In symmetric mode, if the float range is biased towards one side, the result is a quantized range where significant dynamic range is dedicated to values that we'll never see. The most extreme example of this is after ReLU, where the entire tensor is positive. Quantizing it in symmetric mode means we're effectively losing 1 bit. +- On the other hand, if we look at the derivations for convolution / FC layers above, we can see that the actual implementation of symmetric mode is much simpler. In asymmetric mode, the zero-points require additional logic in HW. The cost of this extra logic in terms of latency and/or power and/or area will of course depend on the exact implementation. + +### Other Features + +- **Removing Outliers:** As discussed [here](quantization.md#outliers-removal), in some cases the float range of activations contains outliers. Spending dynamic range on these outliers hurts our ability to represent the values we actually care about accurately. 
+ <p align="center"> + <img src="../imgs/quant_clipped.png"/> + </p> + Currently, Distiller supports clipping of activations with averaging during post-training quantization. That is - for each batch, instead of calculating global min/max values, we use an average of the min/max values of each sample in the batch. +- **Scale factor scope:** For weight tensors, Distiller supports per-channel quantization (per output channel). + +### Implementation in Distiller + +#### Post-Training + +For post-training quantization, currently **convolution** and **FC** are supported using this method. + +- They are implemented by wrapping the existing PyTorch layers with quantization and de-quantization operations. That is - the computation is done on floating-point tensors, but the values themselves are restricted to integer values. The wrapper is implemented in the `RangeLinearQuantParamLayerWrapper` class. +- All other layers are unaffected and are executed using their original FP32 implementation. +- To automatically transform an existing model to a quantized model using this method, use the `PostTrainLinearQuantizer` class. For an example of how to do this, see the [`compress_classifier.py`](https://github.com/NervanaSystems/distiller/blob/master/examples/classifier_compression/compress_classifier.py). This sample also exposes command line arguments to invoke post-training quantization. For details see [here](usage.md#post-training-quantization). +- For weights and bias, the scale factor and zero-point are determined once at quantization setup ("offline"), while for activations they are determined dynamically at runtime ("online"). The calculated quantization parameters are stored as buffers within the module, so they are automatically serialized when the model checkpoint is saved. +- As this is post-training, using it with fewer than 8 bits is likely to lead to severe accuracy degradation for any non-trivial workload. + +#### Quantization-Aware Training + +To apply range-based linear quantization in training, use the `QuantAwareTrainRangeLinearQuantizer` class. As it is now, it will apply weights quantization to convolution and FC modules. For activations quantization, it will insert instances of the `FakeLinearQuantization` module after ReLUs. This module follows the methodology described in [Benoit et al., 2018](http://openaccess.thecvf.com/content_cvpr_2018/html/Jacob_Quantization_and_Training_CVPR_2018_paper.html) and uses exponential moving averages to track activation ranges. + +Similarly to post-training, the calculated quantization parameters (scale factors, zero-points, tracked activation ranges) are stored as buffers within their respective modules, so they're saved when a checkpoint is created. + +Note that converting from a quantization-aware training model to a post-training quantization model is not yet supported. Once supported, such a conversion will use the activation ranges tracked during training, so additional offline or online calculation of quantization parameters will not be required. ## DoReFa @@ -22,7 +120,7 @@ Now we can use \(quantize_k\) to get quantized weight values, as follows: \[w_q = 2 quantize_k \left( f(w_f) \right) - 1\] -This method requires training the model with quantization, as discussed [here](quantization.md#training-with-quantization). Use the `DorefaQuantizer` class to transform an existing model to a model suitable for training with quantization using DoReFa. 
+This method requires training the model with quantization-aware training, as discussed [here](quantization.md#quantization-aware-training). Use the `DorefaQuantizer` class to transform an existing model to a model suitable for training with quantization using DoReFa. ### Notes: @@ -35,7 +133,7 @@ This method requires training the model with quantization, as discussed [here](q This method is similar to DoReFa, but the upper clipping values, \(\alpha\), of the activation functions are learned parameters instead of hard coded to 1. Note that per the paper's recommendation, \(\alpha\) is shared per layer. -This method requires training the model with quantization, as discussed [here](quantization/#training-with-quantization). Use the `PACTQuantizer` class to transform an existing model to a model suitable for training with quantization using PACT. +This method requires training the model with quantization-aware training, as discussed [here](quantization/#quantization-aware-training). Use the `PACTQuantizer` class to transform an existing model to a model suitable for training with quantization using PACT. ## WRPN @@ -51,33 +149,9 @@ Weights are clipped to \([-1, 1]\) and quantized as follows: Note that \(k-1\) bits are used to quantize weights, leaving one bit for sign. -This method requires training the model with quantization, as discussed [here](quantization/#training-with-quantization). Use the `WRPNQuantizer` class to transform an existing model to a model suitable for training with quantization using WRPN. +This method requires training the model with quantization-aware training, as discussed [here](quantization/#quantization-aware-training). Use the `WRPNQuantizer` class to transform an existing model to a model suitable for training with quantization using WRPN. ### Notes: - The paper proposed widening of layers as a means to reduce accuracy loss. This isn't implemented as part of `WRPNQuantizer` at the moment. To experiment with this, modify your model implementation to have wider layers. -- The paper defines special handling for binary weights which isn't supported in Distiller yet. - -## Symmetric Linear Quantization - -In this method, a float value is quantized by multiplying with a numeric constant (the **scale factor**), hence it is **Linear**. We use a signed integer to represent the quantized range, with no quantization bias (or "offset") used. As a result, the floating-point range considered for quantization is **symmetric** with respect to zero. -In the current implementation the scale factor is chosen so that the entire range of the floating-point tensor is quantized (we do not attempt to remove outliers). -Let us denote the original floating-point tensor by \(x_f\), the quantized tensor by \(x_q\), the scale factor by \(q_x\) and the number of bits used for quantization by \(n\). Then, we get: -\[q_x = \frac{2^{n-1}-1}{\max|x|}\] -\[x_q = round(q_x x_f)\] -(The \(round\) operation is round-to-nearest-integer) - -Let's see how a **convolution** or **fully-connected (FC)** layer is quantized using this method: (we denote input, output, weights and bias with \(x, y, w\) and \(b\) respectively) -\[y_f = \sum{x_f w_f} + b_f = \sum{\frac{x_q}{q_x} \frac{w_q}{q_w}} + \frac{b_q}{q_b} = \frac{1}{q_x q_w} \left( \sum { x_q w_q + \frac{q_x q_w}{q_b}b_q } \right)\] -\[y_q = round(q_y y_f) = round\left(\frac{q_y}{q_x q_w} \left( \sum { x_q w_q + \frac{q_x q_w}{q_b}b_q } \right) \right) \] -Note how the bias has to be re-scaled to match the scale of the summation. 
- -### Implementation - -We've implemented **convolution** and **FC** using this method. - -- They are implemented by wrapping the existing PyTorch layers with quantization and de-quantization operations. That is - the computation is done on floating-point tensors, but the values themselves are restricted to integer values. The wrapper is implemented in the `RangeLinearQuantParamLayerWrapper` class. -- All other layers are unaffected and are executed using their original FP32 implementation. -- To automatically transform an existing model to a quantized model using this method, use the `SymmetricLinearQuantizer` class. -- For weights and bias the scale factor is determined once at quantization setup ("offline"), and for activations it is determined dynamically at runtime ("online"). -- **Important note:** Currently, this method is implemented as **inference only**, with no back-propagation functionality. Hence, it can only be used to quantize a pre-trained FP32 model, with no re-training. As such, using it with \(n < 8\) is likely to lead to severe accuracy degradation for any non-trivial workload. \ No newline at end of file +- The paper defines special handling for binary weights which isn't supported in Distiller yet. \ No newline at end of file diff --git a/docs-src/docs/design.md b/docs-src/docs/design.md index f6f2be507fe625a0cf83fea9ad28f9b3fb299f5a..3849caa715b790bcdc887f39b06a7a43a9d7f2ae 100755 --- a/docs-src/docs/design.md +++ b/docs-src/docs/design.md @@ -74,11 +74,11 @@ To execute the model transformation, call the `prepare_model` function of the `Q The `Quantizer` class also provides an API to quantize the weights of all layers at once. To use it, the `param_quantization_fn` attribute needs to point to a function that accepts a tensor and the number of bits. During model transformation, the `Quantizer` class will build a list of all model parameters that need to be quantized along with their bit-width. Then, the `quantize_params` function can be called, which will iterate over all parameters and quantize them using `params_quantization_fn`. -### Training with Quantization +### Quantization-Aware Training -The `Quantizer` class supports training with quantization in the loop. This requires handling of a couple of flows / scenarios: +The `Quantizer` class supports quantization-aware training, that is - training with quantization in the loop. This requires handling of a couple of flows / scenarios: -1. Maintaining a full precision copy of the weights, as described [here](quantization.md#training-with-quantization). This is enabled by setting `train_with_fp_copy=True` in the `Quantizer` constructor. At model transformation, in each module that has parameters that should be quantized, a new `torch.nn.Parameter` is added, which will maintain the required full precision copy of the parameters. Note that this is done in-place - a new module **is not** created. We preferred not to sub-class the existing PyTorch modules for this purpose. In order to this in-place, and also guarantee proper back-propagation through the weights quantization function, we employ the following "hack": +1. Maintaining a full precision copy of the weights, as described [here](quantization.md#quantization-aware-training). This is enabled by setting `train_with_fp_copy=True` in the `Quantizer` constructor. At model transformation, in each module that has parameters that should be quantized, a new `torch.nn.Parameter` is added, which will maintain the required full precision copy of the parameters. 
Note that this is done in-place - a new module **is not** created. We preferred not to sub-class the existing PyTorch modules for this purpose. In order to do this in-place, and also guarantee proper back-propagation through the weights quantization function, we employ the following "hack": 1. The existing `torch.nn.Parameter`, e.g. `weights`, is replaced by a `torch.nn.Parameter` named `float_weight`. 2. To maintain the existing functionality of the module, we then register a `buffer` in the module with the original name - `weights`. diff --git a/docs-src/docs/imgs/quant_asym.png b/docs-src/docs/imgs/quant_asym.png new file mode 100644 index 0000000000000000000000000000000000000000..9dc43817a595687119a5d6e3fee1a905be4c308e Binary files /dev/null and b/docs-src/docs/imgs/quant_asym.png differ diff --git a/docs-src/docs/imgs/quant_clipped.png b/docs-src/docs/imgs/quant_clipped.png new file mode 100644 index 0000000000000000000000000000000000000000..a5cfa215a0deb90217e777a4a485ee8b577f25d1 Binary files /dev/null and b/docs-src/docs/imgs/quant_clipped.png differ diff --git a/docs-src/docs/imgs/quant_sym.png b/docs-src/docs/imgs/quant_sym.png new file mode 100644 index 0000000000000000000000000000000000000000..7a1d390c2abb8d537de6763b2fbb271778c36605 Binary files /dev/null and b/docs-src/docs/imgs/quant_sym.png differ diff --git a/docs-src/docs/quantization.md index 6db573c25ebe715a810049123e04bd8717f3a879..80c12f550ff54ac1ff12e295742bee2dfe39e982 100644 --- a/docs-src/docs/quantization.md +++ b/docs-src/docs/quantization.md @@ -40,15 +40,18 @@ As mentioned above, a scale factor is used to adapt the dynamic range of the ten - **Offline** means gathering activations statistics before deploying the model, either during training or by running a few "calibration" batches on the trained FP32 model. Based on these gathered statistics, the scaled factors are calculated and are fixed once the model is deployed. This method has the risk of encountering values outside the previously observed ranges at runtime. These values will be clipped, which might lead to accuracy degradation. - **Online** means calculating the min/max values for each tensor dynamically during runtime. In this method clipping cannot occur, however the added computation resources required to calculate the min/max values at runtime might be prohibitive. -It is important to note, however, that the full float range of an activations tensor usually includes elements which are statistically outliers. These values can be discarded by using a narrower min/max range, effectively allowing some clipping to occur in favor of increasing the resolution provided to the part of the distribution containing most of the information. Statistical measures can be used to intelligently select where to clip the original range in order to preserve as much information as possible ([Migacz, 2017](#migacz-2017)). +<div id="outliers-removal"></div> +It is important to note, however, that the full float range of an activations tensor usually includes elements which are statistically outliers. These values can be discarded by using a narrower min/max range, effectively allowing some clipping to occur in favor of increasing the resolution provided to the part of the distribution containing most of the information. A simple method which can yield nice results is to use an average of the observed min/max values instead of the actual values.
Alternatively, statistical measures can be used to intelligently select where to clip the original range in order to preserve as much information as possible ([Migacz, 2017](#migacz-2017)). Going further, [Banner et al., 2018](#banner-et-al-2018) have proposed a method for analytically computing the clipping value under certain conditions. Another possible optimization point is **scale-factor scope**. The most common way is use a single scale-factor per-layer, but it is also possible to calculate a scale-factor per-channel. This can be beneficial if the weight distributions vary greatly between channels. +When used to directly quantize a model without re-training, as described so far, this method is commonly referred to as **post-training quantization**. However, recent publications have shown that there are cases where post-training quantization to INT8 doesn't preserve accuracy ([Benoit et al., 2018](#benoit-et-al-2018), [Krishnamoorthi, 2018](#krishnamoorthi-2018)). Namely, smaller models such as MobileNet seem to not respond as well to post-training quantization, presumably due to their smaller representational capacity. In such cases, [quantization-aware training](#quantization-aware-training) is used. + ## "Aggressive" Quantization: INT4 and Lower Naively quantizing a FP32 model to INT4 and lower usually incurs significant accuracy degradation. Many works have tried to mitigate this effect. They usually employ one or more of the following concepts in order to improve model accuracy: -- **Training / Re-Training**: For INT4 and lower, training is required in order to obtain reasonable accuracy. The training loop is modified to take quantization into account. See details in the [next section](#training-with-quantization). +- **Training / Re-Training**: For INT4 and lower, training is required in order to obtain reasonable accuracy. The training loop is modified to take quantization into account. See details in the [next section](#quantization-aware-training). [Zhou S et al., 2016](#zhou-et-al-2016) have shown that bootstrapping the quantized model with trained FP32 weights leads to higher accuracy, as opposed to training from scratch. Other methods *require* a trained FP32 model, either as a starting point ([Zhou A et al., 2017](#zhou-et-al-2017)), or as a teacher network in a knowledge distillation training setup (see [here](knowledge_distillation.md#combining)). - **Replacing the activation function**: The most common activation function in vision models is ReLU, which is unbounded. That is - its dynamic range is not limited for positive inputs. This is very problematic for INT4 and below due to the very limited range and resolution. Therefore, most methods replace ReLU with another function which is bounded. In some cases a clipping function with hard coded values is used ([Zhou S et al., 2016](#zhou-et-al-2016), [Mishra et al., 2018](#mishra-et-al-2018)). Another method learns the clipping value per layer, with better results ([Choi et al., 2018](#choi-et-al-2018)). Once the clipping value is set, the scale factor used for quantization is also set, and no further calibration steps are required (as opposed to INT8 methods described above). - **Modifying network structure**: [Mishra et al., 2018](#mishra-et-al-2018) try to compensate for the loss of information due to quantization by using wider layers (more channels).
[Lin et al., 2017](#lin-et-al-2017) proposed a binary quantization method in which a single FP32 convolution is replaced with multiple binary convolutions, each scaled to represent a different "base", covering a larger dynamic range overall. @@ -56,17 +59,17 @@ Naively quantizing a FP32 model to INT4 and lower usually incurs significant acc - **Iterative quantization**: Most methods quantize the entire model at once. [Zhou A et al., 2017](#zhou-et-al-2017) employ an iterative method, which starts with a trained FP32 baseline, and quantizes only a portion of the model at the time followed by several epochs of re-training to recover the accuracy loss from quantization. - **Mixed Weights and Activations Precision**: It has been observed that activations are more sensitive to quantization than weights ([Zhou S et al., 2016](#zhou-et-al-2016)). Hence it is not uncommon to see experiments with activations quantized to a higher precision compared to weights. Some works have focused solely on quantizing weights, keeping the activations at FP32 ([Li et al., 2016](#li-et-al-2016), [Zhu et al., 2016](#zhu-et-al-2016)). -## Training with Quantization +## Quantization-Aware Training -As mentioned above, in order to minimize the loss of accuracy from "aggressive" quantization, many methods that target INT4 and lower involve training the model in a way that considers the quantization. This means training with quantization of weights and activations "baked" into the training procedure. The training graph usually looks like this: +As mentioned above, in order to minimize the loss of accuracy from "aggressive" quantization, many methods that target INT4 and lower (and in some cases for INT8 as well) involve training the model in a way that considers the quantization. This means training with quantization of weights and activations "baked" into the training procedure. The training graph usually looks like this: - + A full precision copy of the weights is maintained throughout the training process ("weights_fp" in the diagram). Its purpose is to accumulate the small changes from the gradients without loss of precision (Note that the quantization of the weights is an integral part of the training graph, meaning that we back-propagate through it as well). Once the model is trained, only the quantized weights are used for inference. In the diagram we show "layer N" as the conv + batch-norm + activation combination, but the same applies to fully-connected layers, element-wise operations, etc. During training, the operations within "layer N" can still run in full precision, with the "quantize" operations in the boundaries ensuring discrete-valued weights and activations. This is sometimes called "simulated quantization". ### Straight-Through Estimator -An important question in this context is how to back-propagate through the quantization functions. These functions are discrete-valued, hence their derivative is 0 almost everywhere. So, using their gradients as-is would severly hinder the learning process. An approximation commonly used to overcome this issue is the "straight-through estimator" (STE) ([Hinton et al., 2012](#hinton-et-al-2012), [Bengio, 2013](#bengio-et-al-2013)), which simply passes the gradient through these functions as-is. +An important question in this context is how to back-propagate through the quantization functions. These functions are discrete-valued, hence their derivative is 0 almost everywhere. So, using their gradients as-is would severely hinder the learning process. 
An approximation commonly used to overcome this issue is the "straight-through estimator" (STE) ([Hinton et al., 2012](#hinton-et-al-2012), [Bengio, 2013](#bengio-et-al-2013)), which simply passes the gradient through these functions as-is. ## References <div id="dally-2015"></div> @@ -94,7 +97,7 @@ An important question in this context is how to back-propagate through the quant **Asit Mishra, Eriko Nurvitadhi, Jeffrey J Cook and Debbie Marr**. WRPN: Wide Reduced-Precision Networks. [ICLR, 2018](https://openreview.net/forum?id=B1ZvaaeAZ) <div id="choi-et-al-2018"></div> -**Jungwook Choi, Zhuo Wang, Swagath Venkataramani, Pierce I-Jen Chuang, Vijayalakshmi Srinivasan and Kailash Gopalakrishnan**. PACT: Parameterized Clipping Activation for Quantized Neural Networks. [2018](https://openreview.net/forum?id=By5ugjyCb) +**Jungwook Choi, Zhuo Wang, Swagath Venkataramani, Pierce I-Jen Chuang, Vijayalakshmi Srinivasan and Kailash Gopalakrishnan**. PACT: Parameterized Clipping Activation for Quantized Neural Networks. [arxiv:1805.06085](https://arxiv.org/abs/1805.06085) <div id="lin-et-al-2017"></div> **Xiaofan Lin, Cong Zhao and Wei Pan**. Towards Accurate Binary Convolutional Neural Network. [NIPS, 2017](http://papers.nips.cc/paper/6638-towards-accurate-binary-convolutional-neural-network) @@ -109,7 +112,16 @@ An important question in this context is how to back-propagate through the quant **Chenzhuo Zhu, Song Han, Huizi Mao and William J. Dally**. Trained Ternary Quantization. [arxiv:1612.01064](https://arxiv.org/abs/1612.01064) <div id="bengio-et-al-2013"></div> -**Yoshua Bengio, Nicholas Leonard and Aaron Courville**. Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation. [arxiv:1308.3432, 2013](https://arxiv.org/abs/1308.3432) +**Yoshua Bengio, Nicholas Leonard and Aaron Courville**. Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation. [arxiv:1308.3432](https://arxiv.org/abs/1308.3432) <div id="hinton-et-al-2012"></div> **Geoffrey Hinton, Nitish Srivastava, Kevin Swersky, Tijmen Tieleman and Abdelrahman Mohamed**. Neural Networks for Machine Learning. [Coursera, video lectures, 2012](https://www.coursera.org/learn/neural-networks) + +<div id="benoit-et-al-2018"></div> +**Benoit Jacob, Skirmantas Kligys, Bo Chen, Menglong Zhu, Matthew Tang, Andrew Howard, Hartwig Adam and Dmitry Kalenichenko**. Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference. [ECCV, 2018](http://openaccess.thecvf.com/content_cvpr_2018/html/Jacob_Quantization_and_Training_CVPR_2018_paper.html) + +<div id="krishnamoorthi-2018"></div> +**Raghuraman Krishnamoorthi**. Quantizing deep convolutional networks for efficient inference: A whitepaper [arxiv:1806.08342](https://arxiv.org/abs/1806.08342) + +<div id="banner-et-al-2018"></div> +**Ron Banner, Yury Nahshan, Elad Hoffer and Daniel Soudry**. ACIQ: Analytical Clipping for Integer Quantization of neural networks [arxiv:1810.05723](https://arxiv.org/abs/1810.05723) diff --git a/docs-src/docs/schedule.md b/docs-src/docs/schedule.md index f82e136c2a131a2aab4237c0ca95326d3b53a7c8..aeb1fc71cc0c3473ba8a41a657d8e7bb7e8ce013 100755 --- a/docs-src/docs/schedule.md +++ b/docs-src/docs/schedule.md @@ -240,8 +240,8 @@ policies: ## Quantization -Similarly to pruners and regularizers, specifying a quantizer in the scheduler YAML follows the constructor arguments of the `Quantizer` class (see details [here](design.md#quantization)). 
-**Notes**: Only a single quantizer instance may be defined. +Similarly to pruners and regularizers, specifying a quantizer in the scheduler YAML follows the constructor arguments of the `Quantizer` class (see details [here](design.md#quantization)). **Note** that only a single quantizer instance may be defined per YAML. + Let's see an example: ``` @@ -299,7 +299,7 @@ bits_overrides: - **Important Note**: The patterns are evaluated eagerly - first match wins. So, to properly quantize a model using "broad" patterns and more "specific" patterns as just shown, make sure the specific pattern is listed **before** the broad one. -The `QuantizationPolicy`, which controls the quantization procedure during training, is actually quite simplistic. All it does is call the `prepare_model()` function of the `Quantizer` when it's initialized, followed by the first call to `quantize_params()`. Then, at the end of each epoch, after the float copy of the weights has been updated, it calls the `quantize_params()` function again. +The `QuantizationPolicy`, which controls the quantization procedure during training, is actually quite simplistic. All it does is call the `prepare_model()` function of the `Quantizer` when it's initialized, followed by the first call to `quantize_params()`. Then, at the end of each epoch, after the float copy of the weights has been updated, it calls the `quantize_params()` function again. ``` policies: @@ -310,7 +310,7 @@ policies: frequency: 1 ``` -**Important Note**: As mentioned [here](design.md#training-with-quantization), since the quantizer modifies the model's parameters (assuming training with quantization in the loop is used), the call to `prepare_model()` must be performed before an optimizer is called. Therefore, currently, the starting epoch for a quantization policy must be 0, otherwise the quantization process will not work as expected. If one wishes to do a "warm-startup" (or "boot-strapping"), training for a few epochs with full precision and only then starting to quantize, the only way to do this right now is to execute a separate run to generate the boot-strapped weights, and execute a second which will resume the checkpoint with the boot-strapped weights. +**Important Note**: As mentioned [here](design.md#quantization-aware-training), since the quantizer modifies the model's parameters (assuming training with quantization in the loop is used), the call to `prepare_model()` must be performed before an optimizer is called. Therefore, currently, the starting epoch for a quantization policy must be 0, otherwise the quantization process will not work as expected. If one wishes to do a "warm-startup" (or "boot-strapping"), training for a few epochs with full precision and only then starting to quantize, the only way to do this right now is to execute a separate run to generate the boot-strapped weights, and execute a second which will resume the checkpoint with the boot-strapped weights. ## Knowledge Distillation diff --git a/docs-src/docs/usage.md b/docs-src/docs/usage.md index 530a545fbebc59d583550f117f9807f3f2cc2ce3..f1dc9bbc7e2dda5bc624bcdca57dabd449e5cdfc 100755 --- a/docs-src/docs/usage.md +++ b/docs-src/docs/usage.md @@ -141,39 +141,45 @@ The ```sense``` command-line argument can be set to either ```element``` or ```f There is also a [Jupyter notebook](http://localhost:8888/notebooks/sensitivity_analysis.ipynb) with example invocations, outputs and explanations. 
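To make the ordering constraint from the quantization scheduling note above concrete, here is a minimal sketch of what the quantization flow boils down to when driven by hand. The import path, the `DorefaQuantizer` constructor arguments and the exact method signatures are assumptions for illustration only - they are not taken from this document, so consult the Distiller sources for the actual API.

```python
import torch
import torch.nn as nn
from distiller.quantization import DorefaQuantizer  # assumed import path

# A toy model standing in for the real network
model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.ReLU(), nn.Conv2d(8, 4, kernel_size=3))

# Constructor arguments are assumed for illustration
quantizer = DorefaQuantizer(model, bits_activations=8, bits_weights=3)
quantizer.prepare_model()    # transforms the model in-place - this must happen first
quantizer.quantize_params()  # first quantization pass over the (transformed) parameters

# Only now is the optimizer created, so it tracks the transformed parameters
# (including the full-precision copies) - this is why the policy must start at epoch 0.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for epoch in range(3):
    # ... the usual forward / backward / optimizer.step() loop goes here ...
    quantizer.quantize_params()  # what QuantizationPolicy does at the end of each epoch
```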
-## "Direct" Quantization Without Training -Distiller supports 8-bit quantization of trained modules without re-training (using [Symmetric Linear Quantization](algo_quantization.md#symmetric-linear-quantization)). So, any model (whether pruned or not) can be quantized. -Use the ```--quantize``` command-line flag, together with ```--evaluate``` to evaluate the accuracy of your model after quantization. The following example qunatizes ResNet18 for ImageNet: -``` -$ python3 compress_classifier.py -a resnet18 ../../../data.imagenet --pretrained --quantize --evaluate +## Post-Training Quantization + +Distiller supports post-training quantization of trained modules without re-training (using [Range-Based Linear Quantization](algo_quantization.md#range-based-linear-quantization)). So, any model (whether pruned or not) can be quantized. To invoke post-training quantization, use `--quantize-eval` along with `--evaluate`. Additional arguments are available to control parameters of the quantization: + ``` -Generates: +Arguments controlling quantization at evaluation time("post-training quantization"): + --quantize-eval, --qe + Apply linear quantization to model before evaluation. + Applicable only if --evaluate is also set + --qe-mode QE_MODE, --qem QE_MODE + Linear quantization mode. Choices: asym_s | asym_u | + sym + --qe-bits-acts NUM_BITS, --qeba NUM_BITS + Number of bits for quantization of activations + --qe-bits-wts NUM_BITS, --qebw NUM_BITS + Number of bits for quantization of weights + --qe-bits-accum NUM_BITS + Number of bits for quantization of the accumulator + --qe-clip-acts, --qeca + Enable clipping of activations using min/max values + averaging over batch + --qe-no-clip-layers LAYER_NAME [LAYER_NAME ...], --qencl LAYER_NAME [LAYER_NAME ...] + List of fully-qualified layer names for which not to + clip activations. 
Applicable only if --qe-clip-acts is + also set + --qe-per-channel, --qepc + Enable per-channel quantization of weights (per output channel) + ``` -Preparing model for quantization ---- test --------------------- -50000 samples (256 per mini-batch) -Test: [ 10/ 195] Loss 0.856354 Top1 79.257812 Top5 92.500000 -Test: [ 20/ 195] Loss 0.923131 Top1 76.953125 Top5 92.246094 -Test: [ 30/ 195] Loss 0.885186 Top1 77.955729 Top5 92.486979 -Test: [ 40/ 195] Loss 0.930263 Top1 76.181641 Top5 92.597656 -Test: [ 50/ 195] Loss 0.931062 Top1 75.726562 Top5 92.906250 -Test: [ 60/ 195] Loss 0.932019 Top1 75.651042 Top5 93.151042 -Test: [ 70/ 195] Loss 0.921287 Top1 76.060268 Top5 93.270089 -Test: [ 80/ 195] Loss 0.932539 Top1 75.986328 Top5 93.100586 -Test: [ 90/ 195] Loss 0.996000 Top1 74.700521 Top5 92.330729 -Test: [ 100/ 195] Loss 1.066699 Top1 73.289062 Top5 91.437500 -Test: [ 110/ 195] Loss 1.100970 Top1 72.574574 Top5 91.001420 -Test: [ 120/ 195] Loss 1.122376 Top1 72.268880 Top5 90.696615 -Test: [ 130/ 195] Loss 1.171726 Top1 71.198918 Top5 90.120192 -Test: [ 140/ 195] Loss 1.191500 Top1 70.797991 Top5 89.902344 -Test: [ 150/ 195] Loss 1.219954 Top1 70.210938 Top5 89.453125 -Test: [ 160/ 195] Loss 1.240942 Top1 69.855957 Top5 89.162598 -Test: [ 170/ 195] Loss 1.265741 Top1 69.342831 Top5 88.807445 -Test: [ 180/ 195] Loss 1.281185 Top1 69.051649 Top5 88.589410 -Test: [ 190/ 195] Loss 1.279682 Top1 69.019326 Top5 88.632812 -==> Top1: 69.130 Top5: 88.732 Loss: 1.276 + +The following example quantizes ResNet18 for ImageNet: + +```bash +$ python3 compress_classifier.py -a resnet18 ../../../data.imagenet --pretrained --quantize-eval --evaluate ``` +A checkpoint with the quantized model will be dumped in the run directory. It will contain the quantized model parameters (the data type will still be FP32, but the values will be integers). The calculated quantization parameters (scale and zero-point) are stored as well in each quantized layer. + +For more examples of post-training quantization see [here](https://github.com/NervanaSystems/distiller/blob/master/examples/quantization/post_training_quant.md). ## Summaries You can use the sample compression application to generate model summary reports, such as the attributes and compute summary report (see screen capture below).
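The post-training quantization arguments listed in the help text above can be combined with the basic invocation. The following is an illustrative combination only - the flag spellings are taken from the help text above, but this particular setup is an assumption rather than an example from the Distiller documentation. It requests symmetric mode, 8-bit weights and activations, per-channel weights quantization and averaging-based clipping of activations:

```bash
$ python3 compress_classifier.py -a resnet18 ../../../data.imagenet --pretrained --quantize-eval --evaluate \
    --qe-mode sym --qe-bits-acts 8 --qe-bits-wts 8 --qe-per-channel --qe-clip-acts
```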
diff --git a/docs/algo_quantization/index.html b/docs/algo_quantization/index.html index 55c748c0704501bb567193eed498de5ad8bd2d8e..32efe69af5d540a813018f4f169e96a4a20d5e07 100644 --- a/docs/algo_quantization/index.html +++ b/docs/algo_quantization/index.html @@ -112,14 +112,14 @@ <ul> + <li><a class="toctree-l4" href="#range-based-linear-quantization">Range-Based Linear Quantization</a></li> + <li><a class="toctree-l4" href="#dorefa">DoReFa</a></li> <li><a class="toctree-l4" href="#pact">PACT</a></li> <li><a class="toctree-l4" href="#wrpn">WRPN</a></li> - <li><a class="toctree-l4" href="#symmetric-linear-quantization">Symmetric Linear Quantization</a></li> - </ul> @@ -184,7 +184,89 @@ <div class="section"> <h1 id="quantization-algorithms">Quantization Algorithms</h1> -<p>The following quantization methods are currently implemented in Distiller:</p> +<p><strong>Note:</strong><br /> +For any of the methods below that require quantization-aware training, please see <a href="../schedule/index.html#quantization">here</a> for details on how to invoke it using Distiller's scheduling mechanism.</p> +<h2 id="range-based-linear-quantization">Range-Based Linear Quantization</h2> +<p>Let's break down the terminology we use here:</p> +<ul> +<li><strong>Linear:</strong> Means a float value is quantized by multiplying with a numeric constant (the <strong>scale factor</strong>).</li> +<li><strong>Range-Based:</strong> Means that in order to calculate the scale factor, we look at the actual range of the tensor's values. In the most naive implementation, we use the actual min/max values of the tensor. Alternatively, we use some derivation based on the tensor's range / distribution to come up with a narrower min/max range, in order to remove possible outliers. This is in contrast to the other methods described here, which we could call <strong>clipping-based</strong>, as they impose an explicit clipping function on the tensors (using either a hard-coded value or a learned value).</li> +</ul> +<h3 id="asymmetric-vs-symmetric">Asymmetric vs. Symmetric</h3> +<p>In this method we can use two modes - <strong>asymmetric</strong> and <strong>symmetric</strong>.</p> +<h4 id="asymmetric-mode">Asymmetric Mode</h4> +<p align="center"> + <img src="../imgs/quant_asym.png"/> +</p> + +<p>In <strong>asymmetric</strong> mode, we map the min/max in the float range to the min/max of the integer range. This is done by using a <strong>zero-point</strong> (also called <em>quantization bias</em>, or <em>offset</em>) in addition to the scale factor.</p> +<p>Let us denote the original floating-point tensor by <script type="math/tex">x_f</script>, the quantized tensor by <script type="math/tex">x_q</script>, the scale factor by <script type="math/tex">q_x</script>, the zero-point by <script type="math/tex">zp_x</script> and the number of bits used for quantization by <script type="math/tex">n</script>. Then, we get:</p> +<p> +<script type="math/tex; mode=display">x_q = round\left ((x_f - min_{x_f})\underbrace{\frac{2^n - 1}{max_{x_f} - min_{x_f}}}_{q_x} \right) = round(q_x x_f - \underbrace{min_{x_f}q_x}_{zp_x}) = round(q_x x_f - zp_x)</script> +</p> +<p>In practice, we actually use <script type="math/tex">zp_x = round(min_{x_f}q_x)</script>. This means that zero is exactly representable by an integer in the quantized range. This is important, for example, for layers that have zero-padding.
By rounding the zero-point, we effectively "nudge" the min/max values in the float range a little bit, in order to gain this exact quantization of zero.</p> +<p>Note that in the derivation above we use unsigned integer to represent the quantized range. That is, <script type="math/tex">x_q \in [0, 2^n-1]</script>. One could use signed integer if necessary (perhaps due to HW considerations). This can be achieved by subtracting <script type="math/tex">2^{n-1}</script>.</p> +<p>Let's see how a <strong>convolution</strong> or <strong>fully-connected (FC)</strong> layer is quantized in asymmetric mode: (we denote input, output, weights and bias with <script type="math/tex">x, y, w</script> and <script type="math/tex">b</script> respectively)</p> +<p> +<script type="math/tex; mode=display">y_f = \sum{x_f w_f} + b_f = \sum{\frac{x_q + zp_x}{q_x} \frac{w_q + zp_w}{q_w}} + \frac{b_q + zp_b}{q_b} =</script> +<script type="math/tex; mode=display"> = \frac{1}{q_x q_w} \left( \sum { (x_q + zp_x) (w_q + zp_w) + \frac{q_x q_w}{q_b}(b_q + zp_b) } \right)</script> +</p> +<p>Therefore:</p> +<p> +<script type="math/tex; mode=display">y_q = round(q_y y_f) = round\left(\frac{q_y}{q_x q_w} \left( \sum { (x_q+zp_x) (w_q+zp_w) + \frac{q_x q_w}{q_b}(b_q+zp_b) } \right) \right) </script> +</p> +<p>Notes:</p> +<ul> +<li>We can see that the bias has to be re-scaled to match the scale of the summation.</li> +<li>In a proper integer-only HW pipeline, we would like our main accumulation term to simply be <script type="math/tex">\sum{x_q w_q}</script>. In order to achieve this, one needs to further develop the expression we derived above. For further details please refer to the <a href="https://github.com/google/gemmlowp/blob/master/doc/quantization.md#implementation-of-quantized-matrix-multiplication">gemmlowp documentation</a></li> +</ul> +<h4 id="symmetric-mode">Symmetric Mode</h4> +<p align="center"> + <img src="../imgs/quant_sym.png"/> +</p> + +<p>In <strong>symmetric</strong> mode, instead of mapping the exact min/max of the float range to the quantized range, we choose the maximum absolute value between min/max. In addition, we don't use a zero-point. So, the floating-point range we're effectively quantizing is symmetric with respect to zero, and so is the quantized range.</p> +<p>Using the same notations as above, we get:</p> +<p> +<script type="math/tex; mode=display">x_q = round\left (x_f \underbrace{\frac{2^{n-1} - 1}{\max|x_f|}}_{q_x} \right) = round(q_x x_f)</script> +</p> +<p>Again, let's see how a <strong>convolution</strong> or <strong>fully-connected (FC)</strong> layer is quantized, this time in symmetric mode:</p> +<p> +<script type="math/tex; mode=display">y_f = \sum{x_f w_f} + b_f = \sum{\frac{x_q}{q_x} \frac{w_q}{q_w}} + \frac{b_q}{q_b} = \frac{1}{q_x q_w} \left( \sum { x_q w_q + \frac{q_x q_w}{q_b}b_q } \right)</script> +</p> +<p>Therefore:</p> +<p> +<script type="math/tex; mode=display">y_q = round(q_y y_f) = round\left(\frac{q_y}{q_x q_w} \left( \sum { x_q w_q + \frac{q_x q_w}{q_b}b_q } \right) \right) </script> +</p> +<h4 id="comparing-the-two-modes">Comparing the Two Modes</h4> +<p>The main trade-off between these two modes is simplicity vs. utilization of the quantized range.</p> +<ul> +<li>When using asymmetric quantization, the quantized range is fully utilized. That is because we exactly map the min/max values from the float range to the min/max of the quantized range. 
Using symmetric mode, if the float range is biased towards one side, we could end up with a quantized range where significant dynamic range is dedicated to values that we'll never see. The most extreme example of this is after ReLU, where the entire tensor is positive. Quantizing it in symmetric mode means we're effectively losing 1 bit.</li> +<li>On the other hand, if we look at the derivations for convolution / FC layers above, we can see that the actual implementation of symmetric mode is much simpler. In asymmetric mode, the zero-points require additional logic in HW. The cost of this extra logic in terms of latency and/or power and/or area will of course depend on the exact implementation.</li> +</ul> +<h3 id="other-features">Other Features</h3> +<ul> +<li><strong>Removing Outliers:</strong> As discussed <a href="../quantization/index.html#outliers-removal">here</a>, in some cases the float range of activations contains outliers. Spending dynamic range on these outliers hurts our ability to represent the values we actually care about accurately. + <p align="center"> + <img src="../imgs/quant_clipped.png"/> + </p> + Currently, Distiller supports clipping of activations with averaging during post-training quantization. That is - for each batch, instead of calculating global min/max values, we use an average of the min/max values of each sample in the batch.</li> +<li><strong>Scale factor scope:</strong> For weight tensors, Distiller supports per-channel quantization (per output channel).</li> +</ul> +<h3 id="implementation-in-distiller">Implementation in Distiller</h3> +<h4 id="post-training">Post-Training</h4> +<p>For post-training quantization, currently <strong>convolution</strong> and <strong>FC</strong> are supported using this method. </p> +<ul> +<li>They are implemented by wrapping the existing PyTorch layers with quantization and de-quantization operations. That is - the computation is done on floating-point tensors, but the values themselves are restricted to integer values. The wrapper is implemented in the <code>RangeLinearQuantParamLayerWrapper</code> class. </li> +<li>All other layers are unaffected and are executed using their original FP32 implementation. </li> +<li>To automatically transform an existing model to a quantized model using this method, use the <code>PostTrainLinearQuantizer</code> class. For an example of how to do this, see <a href="https://github.com/NervanaSystems/distiller/blob/master/examples/classifier_compression/compress_classifier.py"><code>compress_classifier.py</code></a>. This sample also exposes command line arguments to invoke post-training quantization. For details see <a href="../usage/index.html#post-training-quantization">here</a>.</li> +<li>For weights and bias the scale factor and zero-point are determined once at quantization setup ("offline"), and for activations they are determined dynamically at runtime ("online"). The calculated quantization parameters are stored as buffers within the module, so they are automatically serialized when the model checkpoint is saved.</li> +<li>As this is post-training, using it with fewer than 8 bits is likely to lead to severe accuracy degradation for any non-trivial workload.</li> +</ul> +<h4 id="quantization-aware-training">Quantization-Aware Training</h4> +<p>To apply range-based linear quantization in training, use the <code>QuantAwareTrainRangeLinearQuantizer</code> class. As it is now, it will apply weights quantization to convolution and FC modules.
For activations quantization, it will insert instances of the <code>FakeLinearQuantization</code> module after ReLUs. This module follows the methodology described in <a href="http://openaccess.thecvf.com/content_cvpr_2018/html/Jacob_Quantization_and_Training_CVPR_2018_paper.html">Benoit et al., 2018</a> and uses exponential moving averages to track activation ranges.</p> +<p>Similarly to post-training, the calculated quantization parameters (scale factors, zero-points, tracked activation ranges) are stored as buffers within their respective modules, so they're saved when a checkpoint is created.</p> +<p>Note that converting from a quantization-aware training model to a post-training quantization model is not yet supported. Once supported, such a conversion will use the activation ranges tracked during training, so additional offline or online calculation of quantization parameters will not be required.</p> <h2 id="dorefa">DoReFa</h2> <p>(As proposed in <a href="https://arxiv.org/abs/1606.06160">DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients</a>) </p> <p>In this method, we first define the quantization function <script type="math/tex">quantize_k</script>, which takes a real value <script type="math/tex">a_f \in [0, 1]</script> and outputs a discrete-valued <script type="math/tex">a_q \in \left\{ \frac{0}{2^k-1}, \frac{1}{2^k-1}, ... , \frac{2^k-1}{2^k-1} \right\}</script>, where <script type="math/tex">k</script> is the number of bits used for quantization.</p> @@ -203,7 +285,7 @@ <p> <script type="math/tex; mode=display">w_q = 2 quantize_k \left( f(w_f) \right) - 1</script> </p> -<p>This method requires training the model with quantization, as discussed <a href="../quantization/index.html#training-with-quantization">here</a>. Use the <code>DorefaQuantizer</code> class to transform an existing model to a model suitable for training with quantization using DoReFa.</p> +<p>This method requires training the model with quantization-aware training, as discussed <a href="../quantization/index.html#quantization-aware-training">here</a>. Use the <code>DorefaQuantizer</code> class to transform an existing model to a model suitable for training with quantization using DoReFa.</p> <h3 id="notes">Notes:</h3> <ul> <li>Gradients quantization as proposed in the paper is not supported yet.</li> @@ -212,7 +294,7 @@ <h2 id="pact">PACT</h2> <p>(As proposed in <a href="https://arxiv.org/abs/1805.06085">PACT: Parameterized Clipping Activation for Quantized Neural Networks</a>)</p> <p>This method is similar to DoReFa, but the upper clipping values, <script type="math/tex">\alpha</script>, of the activation functions are learned parameters instead of hard coded to 1. Note that per the paper's recommendation, <script type="math/tex">\alpha</script> is shared per layer.</p> -<p>This method requires training the model with quantization, as discussed <a href="../quantization/#training-with-quantization">here</a>. Use the <code>PACTQuantizer</code> class to transform an existing model to a model suitable for training with quantization using PACT.</p> +<p>This method requires training the model with quantization-aware training, as discussed <a href="../quantization/#quantization-aware-training">here</a>.
Use the <code>PACTQuantizer</code> class to transform an existing model to a model suitable for training with quantization using PACT.</p> <h2 id="wrpn">WRPN</h2> <p>(As proposed in <a href="https://arxiv.org/abs/1709.01134">WRPN: Wide Reduced-Precision Networks</a>) </p> <p>In this method, activations are clipped to <script type="math/tex">[0, 1]</script> and quantized as follows (<script type="math/tex">k</script> is the number of bits used for quantization):</p> @@ -224,31 +306,11 @@ <script type="math/tex; mode=display">w_q = \frac{1}{2^{k-1}-1} round \left( \left(2^{k-1} - 1 \right)w_f \right)</script> </p> <p>Note that <script type="math/tex">k-1</script> bits are used to quantize weights, leaving one bit for sign.</p> -<p>This method requires training the model with quantization, as discussed <a href="../quantization/#training-with-quantization">here</a>. Use the <code>WRPNQuantizer</code> class to transform an existing model to a model suitable for training with quantization using WRPN.</p> +<p>This method requires training the model with quantization-aware training, as discussed <a href="../quantization/#quantization-aware-training">here</a>. Use the <code>WRPNQuantizer</code> class to transform an existing model to a model suitable for training with quantization using WRPN.</p> <h3 id="notes_1">Notes:</h3> <ul> <li>The paper proposed widening of layers as a means to reduce accuracy loss. This isn't implemented as part of <code>WRPNQuantizer</code> at the moment. To experiment with this, modify your model implementation to have wider layers.</li> <li>The paper defines special handling for binary weights which isn't supported in Distiller yet.</li> -</ul> -<h2 id="symmetric-linear-quantization">Symmetric Linear Quantization</h2> -<p>In this method, a float value is quantized by multiplying with a numeric constant (the <strong>scale factor</strong>), hence it is <strong>Linear</strong>. We use a signed integer to represent the quantized range, with no quantization bias (or "offset") used. As a result, the floating-point range considered for quantization is <strong>symmetric</strong> with respect to zero.<br /> -In the current implementation the scale factor is chosen so that the entire range of the floating-point tensor is quantized (we do not attempt to remove outliers).<br /> -Let us denote the original floating-point tensor by <script type="math/tex">x_f</script>, the quantized tensor by <script type="math/tex">x_q</script>, the scale factor by <script type="math/tex">q_x</script> and the number of bits used for quantization by <script type="math/tex">n</script>. 
Then, we get: -<script type="math/tex; mode=display">q_x = \frac{2^{n-1}-1}{\max|x|}</script> -<script type="math/tex; mode=display">x_q = round(q_x x_f)</script> -(The <script type="math/tex">round</script> operation is round-to-nearest-integer) </p> -<p>Let's see how a <strong>convolution</strong> or <strong>fully-connected (FC)</strong> layer is quantized using this method: (we denote input, output, weights and bias with <script type="math/tex">x, y, w</script> and <script type="math/tex">b</script> respectively) -<script type="math/tex; mode=display">y_f = \sum{x_f w_f} + b_f = \sum{\frac{x_q}{q_x} \frac{w_q}{q_w}} + \frac{b_q}{q_b} = \frac{1}{q_x q_w} \left( \sum { x_q w_q + \frac{q_x q_w}{q_b}b_q } \right)</script> -<script type="math/tex; mode=display">y_q = round(q_y y_f) = round\left(\frac{q_y}{q_x q_w} \left( \sum { x_q w_q + \frac{q_x q_w}{q_b}b_q } \right) \right) </script> -Note how the bias has to be re-scaled to match the scale of the summation.</p> -<h3 id="implementation">Implementation</h3> -<p>We've implemented <strong>convolution</strong> and <strong>FC</strong> using this method. </p> -<ul> -<li>They are implemented by wrapping the existing PyTorch layers with quantization and de-quantization operations. That is - the computation is done on floating-point tensors, but the values themselves are restricted to integer values. The wrapper is implemented in the <code>RangeLinearQuantParamLayerWrapper</code> class. </li> -<li>All other layers are unaffected and are executed using their original FP32 implementation. </li> -<li>To automatically transform an existing model to a quantized model using this method, use the <code>SymmetricLinearQuantizer</code> class.</li> -<li>For weights and bias the scale factor is determined once at quantization setup ("offline"), and for activations it is determined dynamically at runtime ("online"). </li> -<li><strong>Important note:</strong> Currently, this method is implemented as <strong>inference only</strong>, with no back-propagation functionality. Hence, it can only be used to quantize a pre-trained FP32 model, with no re-training. As such, using it with <script type="math/tex">n < 8</script> is likely to lead to severe accuracy degradation for any non-trivial workload.</li> </ul> </div> diff --git a/docs/design/index.html b/docs/design/index.html index 4db0ce874a32e136bc9961b839a7c1553d5456a5..573e2c1217b6101da133e635f13bb599970dbd7e 100644 --- a/docs/design/index.html +++ b/docs/design/index.html @@ -242,11 +242,11 @@ To execute the model transformation, call the <code>prepare_model</code> functio </ul> <h3 id="weights-quantization">Weights Quantization</h3> <p>The <code>Quantizer</code> class also provides an API to quantize the weights of all layers at once. To use it, the <code>param_quantization_fn</code> attribute needs to point to a function that accepts a tensor and the number of bits. During model transformation, the <code>Quantizer</code> class will build a list of all model parameters that need to be quantized along with their bit-width. Then, the <code>quantize_params</code> function can be called, which will iterate over all parameters and quantize them using <code>params_quantization_fn</code>.</p> -<h3 id="training-with-quantization">Training with Quantization</h3> -<p>The <code>Quantizer</code> class supports training with quantization in the loop. 
This requires handling of a couple of flows / scenarios:</p> +<h3 id="quantization-aware-training">Quantization-Aware Training</h3> +<p>The <code>Quantizer</code> class supports quantization-aware training, that is - training with quantization in the loop. This requires handling of a couple of flows / scenarios:</p> <ol> <li> -<p>Maintaining a full precision copy of the weights, as described <a href="../quantization/index.html#training-with-quantization">here</a>. This is enabled by setting <code>train_with_fp_copy=True</code> in the <code>Quantizer</code> constructor. At model transformation, in each module that has parameters that should be quantized, a new <code>torch.nn.Parameter</code> is added, which will maintain the required full precision copy of the parameters. Note that this is done in-place - a new module <strong>is not</strong> created. We preferred not to sub-class the existing PyTorch modules for this purpose. In order to this in-place, and also guarantee proper back-propagation through the weights quantization function, we employ the following "hack": </p> +<p>Maintaining a full precision copy of the weights, as described <a href="../quantization/index.html#quantization-aware-training">here</a>. This is enabled by setting <code>train_with_fp_copy=True</code> in the <code>Quantizer</code> constructor. At model transformation, in each module that has parameters that should be quantized, a new <code>torch.nn.Parameter</code> is added, which will maintain the required full precision copy of the parameters. Note that this is done in-place - a new module <strong>is not</strong> created. We preferred not to sub-class the existing PyTorch modules for this purpose. In order to do this in-place, and also guarantee proper back-propagation through the weights quantization function, we employ the following "hack": </p> <ol> <li>The existing <code>torch.nn.Parameter</code>, e.g.
<code>weights</code>, is replaced by a <code>torch.nn.Parameter</code> named <code>float_weight</code>.</li> <li>To maintain the existing functionality of the module, we then register a <code>buffer</code> in the module with the original name - <code>weights</code>.</li> diff --git a/docs/imgs/quant_asym.png b/docs/imgs/quant_asym.png new file mode 100644 index 0000000000000000000000000000000000000000..9dc43817a595687119a5d6e3fee1a905be4c308e Binary files /dev/null and b/docs/imgs/quant_asym.png differ diff --git a/docs/imgs/quant_clipped.png b/docs/imgs/quant_clipped.png new file mode 100644 index 0000000000000000000000000000000000000000..a5cfa215a0deb90217e777a4a485ee8b577f25d1 Binary files /dev/null and b/docs/imgs/quant_clipped.png differ diff --git a/docs/imgs/quant_sym.png b/docs/imgs/quant_sym.png new file mode 100644 index 0000000000000000000000000000000000000000..7a1d390c2abb8d537de6763b2fbb271778c36605 Binary files /dev/null and b/docs/imgs/quant_sym.png differ diff --git a/docs/index.html b/docs/index.html index 856e0fa710e1ebd776d70d05aebaa24c065f70ed..ccc1498994f5abb04558f9d48992cf0b51de0259 100644 --- a/docs/index.html +++ b/docs/index.html @@ -258,5 +258,5 @@ And of course, if we used a sparse or compressed representation, then we are red <!-- MkDocs version : 0.17.2 -Build Date UTC : 2018-11-25 09:04:15 +Build Date UTC : 2018-12-04 15:19:48 --> diff --git a/docs/quantization/index.html b/docs/quantization/index.html index 0fd0d831968baa6b1963d9541da21df21ec8e8ec..8b9dcb887a0a7a70365d5016c59edf4702ace4ed 100644 --- a/docs/quantization/index.html +++ b/docs/quantization/index.html @@ -97,7 +97,7 @@ <li><a class="toctree-l4" href="#aggressive-quantization-int4-and-lower">"Aggressive" Quantization: INT4 and Lower</a></li> - <li><a class="toctree-l4" href="#training-with-quantization">Training with Quantization</a></li> + <li><a class="toctree-l4" href="#quantization-aware-training">Quantization-Aware Training</a></li> <li><a class="toctree-l4" href="#references">References</a></li> @@ -233,12 +233,15 @@ As mentioned above, a scale factor is used to adapt the dynamic range of the ten <li><strong>Offline</strong> means gathering activations statistics before deploying the model, either during training or by running a few "calibration" batches on the trained FP32 model. Based on these gathered statistics, the scaled factors are calculated and are fixed once the model is deployed. This method has the risk of encountering values outside the previously observed ranges at runtime. These values will be clipped, which might lead to accuracy degradation.</li> <li><strong>Online</strong> means calculating the min/max values for each tensor dynamically during runtime. In this method clipping cannot occur, however the added computation resources required to calculate the min/max values at runtime might be prohibitive.</li> </ul> -<p>It is important to note, however, that the full float range of an activations tensor usually includes elements which are statistically outliers. These values can be discarded by using a narrower min/max range, effectively allowing some clipping to occur in favor of increasing the resolution provided to the part of the distribution containing most of the information. Statistical measures can be used to intelligently select where to clip the original range in order to preserve as much information as possible (<a href="#migacz-2017">Migacz, 2017</a>). 
</p> +<div id="outliers-removal"></div> + +<p>It is important to note, however, that the full float range of an activations tensor usually includes elements which are statistically outliers. These values can be discarded by using a narrower min/max range, effectively allowing some clipping to occur in favor of increasing the resolution provided to the part of the distribution containing most of the information. A simple method which can yield nice results is to use an average of the observed min/max values instead of the actual values. Alternatively, statistical measures can be used to intelligently select where to clip the original range in order to preserve as much information as possible (<a href="#migacz-2017">Migacz, 2017</a>). Going further, <a href="#banner-et-al-2018">Banner et al., 2018</a> have proposed a method for analytically computing the clipping value under certain conditions.</p> <p>Another possible optimization point is <strong>scale-factor scope</strong>. The most common way is use a single scale-factor per-layer, but it is also possible to calculate a scale-factor per-channel. This can be beneficial if the weight distributions vary greatly between channels.</p> +<p>When used to directly quantize a model without re-training, as described so far, this method is commonly referred to as <strong>post-training quantization</strong>. However, recent publications have shown that there are cases where post-training quantization to INT8 doesn't preserve accuracy (<a href="#benoit-et-al-2018">Benoit et al., 2018</a>, <a href="#krishnamoorthi-2018">Krishnamoorthi, 2018</a>). Namely, smaller models such as MobileNet seem to not respond as well to post-training quantization, presumably due to their smaller representational capacity. In such cases, <a href="#quantization-aware-training">quantization-aware training</a> is used.</p> <h2 id="aggressive-quantization-int4-and-lower">"Aggressive" Quantization: INT4 and Lower</h2> <p>Naively quantizing a FP32 model to INT4 and lower usually incurs significant accuracy degradation. Many works have tried to mitigate this effect. They usually employ one or more of the following concepts in order to improve model accuracy:</p> <ul> -<li><strong>Training / Re-Training</strong>: For INT4 and lower, training is required in order to obtain reasonable accuracy. The training loop is modified to take quantization into account. See details in the <a href="#training-with-quantization">next section</a>.<br /> +<li><strong>Training / Re-Training</strong>: For INT4 and lower, training is required in order to obtain reasonable accuracy. The training loop is modified to take quantization into account. See details in the <a href="#quantization-aware-training">next section</a>.<br /> <a href="#zhou-et-al-2016">Zhou S et al., 2016</a> have shown that bootstrapping the quantized model with trained FP32 weights leads to higher accuracy, as opposed to training from scratch. Other methods <em>require</em> a trained FP32 model, either as a starting point (<a href="#zhou-et-al-2017">Zhou A et al., 2017</a>), or as a teacher network in a knowledge distillation training setup (see <a href="../knowledge_distillation/index.html#combining">here</a>).</li> <li><strong>Replacing the activation function</strong>: The most common activation function in vision models is ReLU, which is unbounded. That is - its dynamic range is not limited for positive inputs. This is very problematic for INT4 and below due to the very limited range and resolution.
Therefore, most methods replace ReLU with another function which is bounded. In some cases a clipping function with hard coded values is used (<a href="#zhou-et-al-2016">Zhou S et al., 2016</a>, <a href="#mishra-et-al-2018">Mishra et al., 2018</a>). Another method learns the clipping value per layer, with better results (<a href="#choi-et-al-2018">Choi et al., 2018</a>). Once the clipping value is set, the scale factor used for quantization is also set, and no further calibration steps are required (as opposed to INT8 methods described above).</li> <li><strong>Modifying network structure</strong>: <a href="#mishra-et-al-2018">Mishra et al., 2018</a> try to compensate for the loss of information due to quantization by using wider layers (more channels). <a href="#lin-et-al-2017">Lin et al., 2017</a> proposed a binary quantization method in which a single FP32 convolution is replaced with multiple binary convolutions, each scaled to represent a different "base", covering a larger dynamic range overall.</li> @@ -246,13 +249,13 @@ As mentioned above, a scale factor is used to adapt the dynamic range of the ten <li><strong>Iterative quantization</strong>: Most methods quantize the entire model at once. <a href="#zhou-et-al-2017">Zhou A et al., 2017</a> employ an iterative method, which starts with a trained FP32 baseline, and quantizes only a portion of the model at the time followed by several epochs of re-training to recover the accuracy loss from quantization.</li> <li><strong>Mixed Weights and Activations Precision</strong>: It has been observed that activations are more sensitive to quantization than weights (<a href="#zhou-et-al-2016">Zhou S et al., 2016</a>). Hence it is not uncommon to see experiments with activations quantized to a higher precision compared to weights. Some works have focused solely on quantizing weights, keeping the activations at FP32 (<a href="#li-et-al-2016">Li et al., 2016</a>, <a href="#zhu-et-al-2016">Zhu et al., 2016</a>).</li> </ul> -<h2 id="training-with-quantization">Training with Quantization</h2> -<p>As mentioned above, in order to minimize the loss of accuracy from "aggressive" quantization, many methods that target INT4 and lower involve training the model in a way that considers the quantization. This means training with quantization of weights and activations "baked" into the training procedure. The training graph usually looks like this:</p> -<p><img alt="Training with Quantization" src="../imgs/training_quant_flow.png" /></p> +<h2 id="quantization-aware-training">Quantization-Aware Training</h2> +<p>As mentioned above, in order to minimize the loss of accuracy from "aggressive" quantization, many methods that target INT4 and lower (and in some cases for INT8 as well) involve training the model in a way that considers the quantization. This means training with quantization of weights and activations "baked" into the training procedure. The training graph usually looks like this:</p> +<p><img alt="Quantization-Aware Training" src="../imgs/training_quant_flow.png" /></p> <p>A full precision copy of the weights is maintained throughout the training process ("weights_fp" in the diagram). Its purpose is to accumulate the small changes from the gradients without loss of precision (Note that the quantization of the weights is an integral part of the training graph, meaning that we back-propagate through it as well). 
Once the model is trained, only the quantized weights are used for inference.<br /> In the diagram we show "layer N" as the conv + batch-norm + activation combination, but the same applies to fully-connected layers, element-wise operations, etc. During training, the operations within "layer N" can still run in full precision, with the "quantize" operations in the boundaries ensuring discrete-valued weights and activations. This is sometimes called "simulated quantization". </p> <h3 id="straight-through-estimator">Straight-Through Estimator</h3> -<p>An important question in this context is how to back-propagate through the quantization functions. These functions are discrete-valued, hence their derivative is 0 almost everywhere. So, using their gradients as-is would severly hinder the learning process. An approximation commonly used to overcome this issue is the "straight-through estimator" (STE) (<a href="#hinton-et-al-2012">Hinton et al., 2012</a>, <a href="#bengio-et-al-2013">Bengio, 2013</a>), which simply passes the gradient through these functions as-is. </p> +<p>An important question in this context is how to back-propagate through the quantization functions. These functions are discrete-valued, hence their derivative is 0 almost everywhere. So, using their gradients as-is would severely hinder the learning process. An approximation commonly used to overcome this issue is the "straight-through estimator" (STE) (<a href="#hinton-et-al-2012">Hinton et al., 2012</a>, <a href="#bengio-et-al-2013">Bengio, 2013</a>), which simply passes the gradient through these functions as-is. </p> <h2 id="references">References</h2> <p><div id="dally-2015"></div> <strong>William Dally</strong>. High-Performance Hardware for Machine Learning. <a href="https://media.nips.cc/Conferences/2015/tutorialslides/Dally-NIPS-Tutorial-2015.pdf">Tutorial, NIPS, 2015</a></p> @@ -279,7 +282,7 @@ In the diagram we show "layer N" as the conv + batch-norm + activation combinati <p><strong>Asit Mishra, Eriko Nurvitadhi, Jeffrey J Cook and Debbie Marr</strong>. WRPN: Wide Reduced-Precision Networks. <a href="https://openreview.net/forum?id=B1ZvaaeAZ">ICLR, 2018</a></p> <div id="choi-et-al-2018"></div> -<p><strong>Jungwook Choi, Zhuo Wang, Swagath Venkataramani, Pierce I-Jen Chuang, Vijayalakshmi Srinivasan and Kailash Gopalakrishnan</strong>. PACT: Parameterized Clipping Activation for Quantized Neural Networks. <a href="https://openreview.net/forum?id=By5ugjyCb">2018</a></p> +<p><strong>Jungwook Choi, Zhuo Wang, Swagath Venkataramani, Pierce I-Jen Chuang, Vijayalakshmi Srinivasan and Kailash Gopalakrishnan</strong>. PACT: Parameterized Clipping Activation for Quantized Neural Networks. <a href="https://arxiv.org/abs/1805.06085">arxiv:1805.06085</a></p> <div id="lin-et-al-2017"></div> <p><strong>Xiaofan Lin, Cong Zhao and Wei Pan</strong>. Towards Accurate Binary Convolutional Neural Network. <a href="http://papers.nips.cc/paper/6638-towards-accurate-binary-convolutional-neural-network">NIPS, 2017</a></p> @@ -294,10 +297,19 @@ In the diagram we show "layer N" as the conv + batch-norm + activation combinati <p><strong>Chenzhuo Zhu, Song Han, Huizi Mao and William J. Dally</strong>. Trained Ternary Quantization. <a href="https://arxiv.org/abs/1612.01064">arxiv:1612.01064</a></p> <div id="bengio-et-al-2013"></div> -<p><strong>Yoshua Bengio, Nicholas Leonard and Aaron Courville</strong>. Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation. 
<a href="https://arxiv.org/abs/1308.3432">arxiv:1308.3432, 2013</a></p> +<p><strong>Yoshua Bengio, Nicholas Leonard and Aaron Courville</strong>. Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation. <a href="https://arxiv.org/abs/1308.3432">arxiv:1308.3432</a></p> <div id="hinton-et-al-2012"></div> <p><strong>Geoffrey Hinton, Nitish Srivastava, Kevin Swersky, Tijmen Tieleman and Abdelrahman Mohamed</strong>. Neural Networks for Machine Learning. <a href="https://www.coursera.org/learn/neural-networks">Coursera, video lectures, 2012</a></p> +<div id="benoit-et-al-2018"></div> + +<p><strong>Benoit Jacob, Skirmantas Kligys, Bo Chen, Menglong Zhu, Matthew Tang, Andrew Howard, Hartwig Adam and Dmitry Kalenichenko</strong>. Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference. <a href="http://openaccess.thecvf.com/content_cvpr_2018/html/Jacob_Quantization_and_Training_CVPR_2018_paper.html">CVPR, 2018</a></p> +<div id="krishnamoorthi-2018"></div> + +<p><strong>Raghuraman Krishnamoorthi</strong>. Quantizing deep convolutional networks for efficient inference: A whitepaper. <a href="https://arxiv.org/abs/1806.08342">arxiv:1806.08342</a></p> +<div id="banner-et-al-2018"></div> + +<p><strong>Ron Banner, Yury Nahshan, Elad Hoffer and Daniel Soudry</strong>. ACIQ: Analytical Clipping for Integer Quantization of neural networks. <a href="https://arxiv.org/abs/1810.05723">arxiv:1810.05723</a></p> </div> </div> diff --git a/docs/schedule/index.html b/docs/schedule/index.html index e08da6cafb69e587fd0e4f076f6bd0bb83a41695..73d74302101f492718973f9183c24c403aa33797 100644 --- a/docs/schedule/index.html +++ b/docs/schedule/index.html @@ -415,9 +415,8 @@ policies: </code></pre> <h2 id="quantization">Quantization</h2> -<p>Similarly to pruners and regularizers, specifying a quantizer in the scheduler YAML follows the constructor arguments of the <code>Quantizer</code> class (see details <a href="../design/index.html#quantization">here</a>).<br /> -<strong>Notes</strong>: Only a single quantizer instance may be defined.<br /> -Let's see an example:</p> +<p>Similarly to pruners and regularizers, specifying a quantizer in the scheduler YAML follows the constructor arguments of the <code>Quantizer</code> class (see details <a href="../design/index.html#quantization">here</a>). <strong>Note</strong> that only a single quantizer instance may be defined per YAML.</p> +<p>Let's see an example:</p> <pre><code>quantizers: dorefa_quantizer: class: DorefaQuantizer @@ -469,7 +468,7 @@ Let's see an example:</p> <ul> <li><strong>Important Note</strong>: The patterns are evaluated eagerly - first match wins. So, to properly quantize a model using "broad" patterns and more "specific" patterns as just shown, make sure the specific pattern is listed <strong>before</strong> the broad one.</li> </ul> -<p>The <code>QuantizationPolicy</code>, which controls the quantization procedure during training, is actually quite simplistic. All it does is call the <code>prepare_model()</code> function of the <code>Quantizer</code> when it's initialized, followed by the first call to <code>quantize_params()</code>. Then, at the end of each epoch, after the float copy of the weights has been updated, it calls the <code>quantize_params()</code> function again. </p> +<p>The <code>QuantizationPolicy</code>, which controls the quantization procedure during training, is actually quite simplistic.
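<p>In rough outline, its behaviour over a training run can be sketched as follows (a simplified illustration only, not the actual implementation; <code>DummyQuantizer</code> and <code>train_one_epoch</code> are stand-ins, while <code>prepare_model()</code> and <code>quantize_params()</code> are the two <code>Quantizer</code> methods described next):</p>
<pre><code>class DummyQuantizer:
    # Stand-in for a Quantizer, exposing only the two methods used by the policy.
    def prepare_model(self):
        print("transform the model for quantization")

    def quantize_params(self):
        print("quantize the parameters from their float copy")


def train_one_epoch(model):
    pass  # placeholder for the usual training loop


model, quantizer = None, DummyQuantizer()

quantizer.prepare_model()        # once, when the policy is initialized
quantizer.quantize_params()      # first quantization of the parameters
for epoch in range(3):           # the policy's starting epoch must currently be 0
    train_one_epoch(model)       # the float copy of the weights is updated
    quantizer.quantize_params()  # re-quantize at the end of each epoch
</code></pre>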
All it does is call the <code>prepare_model()</code> function of the <code>Quantizer</code> when it's initialized, followed by the first call to <code>quantize_params()</code>. Then, at the end of each epoch, after the float copy of the weights has been updated, it calls the <code>quantize_params()</code> function again.</p> <pre><code>policies: - quantizer: instance_name: dorefa_quantizer @@ -478,7 +477,7 @@ Let's see an example:</p> frequency: 1 </code></pre> -<p><strong>Important Note</strong>: As mentioned <a href="../design/index.html#training-with-quantization">here</a>, since the quantizer modifies the model's parameters (assuming training with quantization in the loop is used), the call to <code>prepare_model()</code> must be performed before an optimizer is called. Therefore, currently, the starting epoch for a quantization policy must be 0, otherwise the quantization process will not work as expected. If one wishes to do a "warm-startup" (or "boot-strapping"), training for a few epochs with full precision and only then starting to quantize, the only way to do this right now is to execute a separate run to generate the boot-strapped weights, and execute a second which will resume the checkpoint with the boot-strapped weights.</p> +<p><strong>Important Note</strong>: As mentioned <a href="../design/index.html#quantization-aware-training">here</a>, since the quantizer modifies the model's parameters (assuming training with quantization in the loop is used), the call to <code>prepare_model()</code> must be performed before an optimizer is called. Therefore, currently, the starting epoch for a quantization policy must be 0, otherwise the quantization process will not work as expected. If one wishes to do a "warm-startup" (or "boot-strapping"), training for a few epochs with full precision and only then starting to quantize, the only way to do this right now is to execute a separate run to generate the boot-strapped weights, and execute a second which will resume the checkpoint with the boot-strapped weights.</p> <h2 id="knowledge-distillation">Knowledge Distillation</h2> <p>Knowledge distillation (see <a href="../knowledge_distillation/index.html">here</a>) is also implemented as a <code>Policy</code>, which should be added to the scheduler. However, with the current implementation, it cannot be defined within the YAML file like the rest of the policies described above.</p> <p>To make the integration of this method into applications a bit easier, a helper function can be used that will add a set of command-line arguments related to knowledge distillation:</p> diff --git a/docs/search/search_index.json b/docs/search/search_index.json index f58f13230966266c9f3cdefa20e1bd667eabbe82..9891fb20c74c622b45ecd918be4d038fcf06ab7d 100644 --- a/docs/search/search_index.json +++ b/docs/search/search_index.json @@ -1,673 +1,708 @@ { "docs": [ { - "location": "/index.html", - "text": "Distiller Documentation\n\n\nWhat is Distiller\n\n\nDistiller\n is an open-source Python package for neural network compression research.\n\n\nNetwork compression can reduce the footprint of a neural network, increase its inference speed and save energy. 
Distiller provides a \nPyTorch\n environment for prototyping and analyzing compression algorithms, such as sparsity-inducing methods and low precision arithmetic.\n\n\nDistiller contains:\n\n\n\n\nA framework for integrating pruning, regularization and quantization algorithms.\n\n\nA set of tools for analyzing and evaluating compression performance.\n\n\nExample implementations of state-of-the-art compression algorithms.\n\n\n\n\nMotivation\n\n\nA sparse tensor is any tensor that contains some zeros, but sparse tensors are usually only interesting if they contain a significant number of zeros. A sparse neural network performs computations using some sparse tensors (preferably many). These tensors can be parameters (weights and biases) or activations (feature maps).\n\n\nWhy do we care about sparsity?\n\nPresent day neural networks tend to be deep, with millions of weights and activations. Refer to GoogLeNet or ResNet50, for a couple of examples.\nThese large models are compute-intensive which means that even with dedicated acceleration hardware, the inference pass (network evaluation) will take time. You might think that latency is an issue only in certain cases, such as autonomous driving systems, but in fact, whenever we humans interact with our phones and computers, we are sensitive to the latency of the interaction. We don't like to wait for search results or for an application or web-page to load, and we are especially sensitive in realtime interactions such as speech recognition. So inference latency is often something we want to minimize.\n\n\nLarge models are also memory-intensive with millions of parameters. Moving around all of the data required to compute inference results consumes energy, which is a problem on a mobile device as well as in a server environment. Data center server-racks are limited by their power-envelope and their ToC (total cost of ownership) is correlated to their power consumption and thermal characteristics. In the mobile device environment, we are obviously always aware of the implications of power consumption on the device battery.\nInference performance in the data center is often measured using a KPI (key performance indicator) which folds latency and power considerations: inferences per second, per Watt (inferences/sec/watt).\n\n\nThe storage and transfer of large neural networks is also a challenge in mobile device environments, because of limitations on application sizes and long application download times.\n\n\nFor these reasons, we wish to compress the network as much as possible, to reduce the amount of bandwidth and compute required. Inducing sparseness, through regularization or pruning, in neural-network models, is one way to compress the network (quantization is another method).\nSparse neural networks hold the promise of speed, small size, and energy efficiency. \n\n\nSmaller\n\n\nSparse NN model representations can be compressed by taking advantage of the fact that the tensor elements are dominated by zeros. The compression format, if any, is very HW and SW specific, and the optimal format may be different per tensor (an obvious example: largely dense tensors should not be compressed). The compute hardware needs to support the compressions formats, for representation compression to be meaningful. Compression representation decisions might interact with algorithms such as the use of tiles for memory accesses. Data such as a parameter tensor is read/written from/to main system memory compressed, but the computation can be dense or sparse. 
In dense compute we use dense operators, so the compressed data eventually needs to be decompressed into its full, dense size. The best we can do is bring the compressed representation as close as possible to the compute engine.\n\nSparse compute, on the other hand, operates on the sparse representation which never requires decompression (we therefore distinguish between sparse representation and compressed representation). This is not a simple matter to implement in HW, and often means lower utilization of the vectorized compute engines. Therefore, there is a third class of representations, which take advantage of specific hardware characteristics. For example, for a vectorized compute engine we can remove an entire zero-weights vector and skip its computation (this uses structured pruning or regularization).\n\n\nFaster\n\n\nMany of the layers in modern neural-networks are bandwidth-bound, which means that the execution latency is dominated by the available bandwidth. In essence, the hardware spends more time bringing data close to the compute engines, than actually performing the computations. Fully-connected layers, RNNs and LSTMs are some examples of bandwidth-dominated operations.\n\nReducing the bandwidth required by these layers, will immediately speed them up.\n\nSome pruning algorithms prune entire kernels, filters and even layers from the network without adversely impacting the final accuracy. Depending on the hardware implementation, these methods can be leveraged to skip computations, thus reducing latency and power.\n\n\nMore energy efficient\n\n\nBecause we pay two orders-of-magnitude more energy to access off-chip memory (e.g. DDR) compared to on-chip memory (e.g. SRAM or cache), many hardware designs employ a multi-layered cache hierarchy. Fitting the parameters and activations of a network in these on-chip caches can make a big difference on the required bandwidth, the total inference latency, and off course reduce power consumption.\n\nAnd of course, if we used a sparse or compressed representation, then we are reducing the data throughput and therefore the energy consumption.", + "location": "/index.html", + "text": "Distiller Documentation\n\n\nWhat is Distiller\n\n\nDistiller\n is an open-source Python package for neural network compression research.\n\n\nNetwork compression can reduce the footprint of a neural network, increase its inference speed and save energy. Distiller provides a \nPyTorch\n environment for prototyping and analyzing compression algorithms, such as sparsity-inducing methods and low precision arithmetic.\n\n\nDistiller contains:\n\n\n\n\nA framework for integrating pruning, regularization and quantization algorithms.\n\n\nA set of tools for analyzing and evaluating compression performance.\n\n\nExample implementations of state-of-the-art compression algorithms.\n\n\n\n\nMotivation\n\n\nA sparse tensor is any tensor that contains some zeros, but sparse tensors are usually only interesting if they contain a significant number of zeros. A sparse neural network performs computations using some sparse tensors (preferably many). These tensors can be parameters (weights and biases) or activations (feature maps).\n\n\nWhy do we care about sparsity?\n\nPresent day neural networks tend to be deep, with millions of weights and activations. Refer to GoogLeNet or ResNet50, for a couple of examples.\nThese large models are compute-intensive which means that even with dedicated acceleration hardware, the inference pass (network evaluation) will take time. 
You might think that latency is an issue only in certain cases, such as autonomous driving systems, but in fact, whenever we humans interact with our phones and computers, we are sensitive to the latency of the interaction. We don't like to wait for search results or for an application or web-page to load, and we are especially sensitive in realtime interactions such as speech recognition. So inference latency is often something we want to minimize.\n\n\nLarge models are also memory-intensive with millions of parameters. Moving around all of the data required to compute inference results consumes energy, which is a problem on a mobile device as well as in a server environment. Data center server-racks are limited by their power-envelope and their ToC (total cost of ownership) is correlated to their power consumption and thermal characteristics. In the mobile device environment, we are obviously always aware of the implications of power consumption on the device battery.\nInference performance in the data center is often measured using a KPI (key performance indicator) which folds latency and power considerations: inferences per second, per Watt (inferences/sec/watt).\n\n\nThe storage and transfer of large neural networks is also a challenge in mobile device environments, because of limitations on application sizes and long application download times.\n\n\nFor these reasons, we wish to compress the network as much as possible, to reduce the amount of bandwidth and compute required. Inducing sparseness, through regularization or pruning, in neural-network models, is one way to compress the network (quantization is another method).\nSparse neural networks hold the promise of speed, small size, and energy efficiency. \n\n\nSmaller\n\n\nSparse NN model representations can be compressed by taking advantage of the fact that the tensor elements are dominated by zeros. The compression format, if any, is very HW and SW specific, and the optimal format may be different per tensor (an obvious example: largely dense tensors should not be compressed). The compute hardware needs to support the compressions formats, for representation compression to be meaningful. Compression representation decisions might interact with algorithms such as the use of tiles for memory accesses. Data such as a parameter tensor is read/written from/to main system memory compressed, but the computation can be dense or sparse. In dense compute we use dense operators, so the compressed data eventually needs to be decompressed into its full, dense size. The best we can do is bring the compressed representation as close as possible to the compute engine.\n\nSparse compute, on the other hand, operates on the sparse representation which never requires decompression (we therefore distinguish between sparse representation and compressed representation). This is not a simple matter to implement in HW, and often means lower utilization of the vectorized compute engines. Therefore, there is a third class of representations, which take advantage of specific hardware characteristics. For example, for a vectorized compute engine we can remove an entire zero-weights vector and skip its computation (this uses structured pruning or regularization).\n\n\nFaster\n\n\nMany of the layers in modern neural-networks are bandwidth-bound, which means that the execution latency is dominated by the available bandwidth. In essence, the hardware spends more time bringing data close to the compute engines, than actually performing the computations. 
Fully-connected layers, RNNs and LSTMs are some examples of bandwidth-dominated operations.\n\nReducing the bandwidth required by these layers, will immediately speed them up.\n\nSome pruning algorithms prune entire kernels, filters and even layers from the network without adversely impacting the final accuracy. Depending on the hardware implementation, these methods can be leveraged to skip computations, thus reducing latency and power.\n\n\nMore energy efficient\n\n\nBecause we pay two orders-of-magnitude more energy to access off-chip memory (e.g. DDR) compared to on-chip memory (e.g. SRAM or cache), many hardware designs employ a multi-layered cache hierarchy. Fitting the parameters and activations of a network in these on-chip caches can make a big difference on the required bandwidth, the total inference latency, and off course reduce power consumption.\n\nAnd of course, if we used a sparse or compressed representation, then we are reducing the data throughput and therefore the energy consumption.", "title": "Home" - }, + }, { - "location": "/index.html#distiller-documentation", - "text": "", + "location": "/index.html#distiller-documentation", + "text": "", "title": "Distiller Documentation" - }, + }, { - "location": "/index.html#what-is-distiller", - "text": "Distiller is an open-source Python package for neural network compression research. Network compression can reduce the footprint of a neural network, increase its inference speed and save energy. Distiller provides a PyTorch environment for prototyping and analyzing compression algorithms, such as sparsity-inducing methods and low precision arithmetic. Distiller contains: A framework for integrating pruning, regularization and quantization algorithms. A set of tools for analyzing and evaluating compression performance. Example implementations of state-of-the-art compression algorithms.", + "location": "/index.html#what-is-distiller", + "text": "Distiller is an open-source Python package for neural network compression research. Network compression can reduce the footprint of a neural network, increase its inference speed and save energy. Distiller provides a PyTorch environment for prototyping and analyzing compression algorithms, such as sparsity-inducing methods and low precision arithmetic. Distiller contains: A framework for integrating pruning, regularization and quantization algorithms. A set of tools for analyzing and evaluating compression performance. Example implementations of state-of-the-art compression algorithms.", "title": "What is Distiller" - }, + }, { - "location": "/index.html#motivation", - "text": "A sparse tensor is any tensor that contains some zeros, but sparse tensors are usually only interesting if they contain a significant number of zeros. A sparse neural network performs computations using some sparse tensors (preferably many). These tensors can be parameters (weights and biases) or activations (feature maps). Why do we care about sparsity? \nPresent day neural networks tend to be deep, with millions of weights and activations. Refer to GoogLeNet or ResNet50, for a couple of examples.\nThese large models are compute-intensive which means that even with dedicated acceleration hardware, the inference pass (network evaluation) will take time. You might think that latency is an issue only in certain cases, such as autonomous driving systems, but in fact, whenever we humans interact with our phones and computers, we are sensitive to the latency of the interaction. 
We don't like to wait for search results or for an application or web-page to load, and we are especially sensitive in realtime interactions such as speech recognition. So inference latency is often something we want to minimize. \nLarge models are also memory-intensive with millions of parameters. Moving around all of the data required to compute inference results consumes energy, which is a problem on a mobile device as well as in a server environment. Data center server-racks are limited by their power-envelope and their ToC (total cost of ownership) is correlated to their power consumption and thermal characteristics. In the mobile device environment, we are obviously always aware of the implications of power consumption on the device battery.\nInference performance in the data center is often measured using a KPI (key performance indicator) which folds latency and power considerations: inferences per second, per Watt (inferences/sec/watt). \nThe storage and transfer of large neural networks is also a challenge in mobile device environments, because of limitations on application sizes and long application download times. \nFor these reasons, we wish to compress the network as much as possible, to reduce the amount of bandwidth and compute required. Inducing sparseness, through regularization or pruning, in neural-network models, is one way to compress the network (quantization is another method).\nSparse neural networks hold the promise of speed, small size, and energy efficiency.", + "location": "/index.html#motivation", + "text": "A sparse tensor is any tensor that contains some zeros, but sparse tensors are usually only interesting if they contain a significant number of zeros. A sparse neural network performs computations using some sparse tensors (preferably many). These tensors can be parameters (weights and biases) or activations (feature maps). Why do we care about sparsity? \nPresent day neural networks tend to be deep, with millions of weights and activations. Refer to GoogLeNet or ResNet50, for a couple of examples.\nThese large models are compute-intensive which means that even with dedicated acceleration hardware, the inference pass (network evaluation) will take time. You might think that latency is an issue only in certain cases, such as autonomous driving systems, but in fact, whenever we humans interact with our phones and computers, we are sensitive to the latency of the interaction. We don't like to wait for search results or for an application or web-page to load, and we are especially sensitive in realtime interactions such as speech recognition. So inference latency is often something we want to minimize. \nLarge models are also memory-intensive with millions of parameters. Moving around all of the data required to compute inference results consumes energy, which is a problem on a mobile device as well as in a server environment. Data center server-racks are limited by their power-envelope and their ToC (total cost of ownership) is correlated to their power consumption and thermal characteristics. In the mobile device environment, we are obviously always aware of the implications of power consumption on the device battery.\nInference performance in the data center is often measured using a KPI (key performance indicator) which folds latency and power considerations: inferences per second, per Watt (inferences/sec/watt). 
\nThe storage and transfer of large neural networks is also a challenge in mobile device environments, because of limitations on application sizes and long application download times. \nFor these reasons, we wish to compress the network as much as possible, to reduce the amount of bandwidth and compute required. Inducing sparseness, through regularization or pruning, in neural-network models, is one way to compress the network (quantization is another method).\nSparse neural networks hold the promise of speed, small size, and energy efficiency.", "title": "Motivation" - }, + }, { - "location": "/index.html#smaller", - "text": "Sparse NN model representations can be compressed by taking advantage of the fact that the tensor elements are dominated by zeros. The compression format, if any, is very HW and SW specific, and the optimal format may be different per tensor (an obvious example: largely dense tensors should not be compressed). The compute hardware needs to support the compressions formats, for representation compression to be meaningful. Compression representation decisions might interact with algorithms such as the use of tiles for memory accesses. Data such as a parameter tensor is read/written from/to main system memory compressed, but the computation can be dense or sparse. In dense compute we use dense operators, so the compressed data eventually needs to be decompressed into its full, dense size. The best we can do is bring the compressed representation as close as possible to the compute engine. \nSparse compute, on the other hand, operates on the sparse representation which never requires decompression (we therefore distinguish between sparse representation and compressed representation). This is not a simple matter to implement in HW, and often means lower utilization of the vectorized compute engines. Therefore, there is a third class of representations, which take advantage of specific hardware characteristics. For example, for a vectorized compute engine we can remove an entire zero-weights vector and skip its computation (this uses structured pruning or regularization).", + "location": "/index.html#smaller", + "text": "Sparse NN model representations can be compressed by taking advantage of the fact that the tensor elements are dominated by zeros. The compression format, if any, is very HW and SW specific, and the optimal format may be different per tensor (an obvious example: largely dense tensors should not be compressed). The compute hardware needs to support the compressions formats, for representation compression to be meaningful. Compression representation decisions might interact with algorithms such as the use of tiles for memory accesses. Data such as a parameter tensor is read/written from/to main system memory compressed, but the computation can be dense or sparse. In dense compute we use dense operators, so the compressed data eventually needs to be decompressed into its full, dense size. The best we can do is bring the compressed representation as close as possible to the compute engine. \nSparse compute, on the other hand, operates on the sparse representation which never requires decompression (we therefore distinguish between sparse representation and compressed representation). This is not a simple matter to implement in HW, and often means lower utilization of the vectorized compute engines. Therefore, there is a third class of representations, which take advantage of specific hardware characteristics. 
For example, for a vectorized compute engine we can remove an entire zero-weights vector and skip its computation (this uses structured pruning or regularization).", "title": "Smaller" - }, + }, { - "location": "/index.html#faster", - "text": "Many of the layers in modern neural-networks are bandwidth-bound, which means that the execution latency is dominated by the available bandwidth. In essence, the hardware spends more time bringing data close to the compute engines, than actually performing the computations. Fully-connected layers, RNNs and LSTMs are some examples of bandwidth-dominated operations. \nReducing the bandwidth required by these layers, will immediately speed them up. \nSome pruning algorithms prune entire kernels, filters and even layers from the network without adversely impacting the final accuracy. Depending on the hardware implementation, these methods can be leveraged to skip computations, thus reducing latency and power.", + "location": "/index.html#faster", + "text": "Many of the layers in modern neural-networks are bandwidth-bound, which means that the execution latency is dominated by the available bandwidth. In essence, the hardware spends more time bringing data close to the compute engines, than actually performing the computations. Fully-connected layers, RNNs and LSTMs are some examples of bandwidth-dominated operations. \nReducing the bandwidth required by these layers, will immediately speed them up. \nSome pruning algorithms prune entire kernels, filters and even layers from the network without adversely impacting the final accuracy. Depending on the hardware implementation, these methods can be leveraged to skip computations, thus reducing latency and power.", "title": "Faster" - }, + }, { - "location": "/index.html#more-energy-efficient", - "text": "Because we pay two orders-of-magnitude more energy to access off-chip memory (e.g. DDR) compared to on-chip memory (e.g. SRAM or cache), many hardware designs employ a multi-layered cache hierarchy. Fitting the parameters and activations of a network in these on-chip caches can make a big difference on the required bandwidth, the total inference latency, and off course reduce power consumption. \nAnd of course, if we used a sparse or compressed representation, then we are reducing the data throughput and therefore the energy consumption.", + "location": "/index.html#more-energy-efficient", + "text": "Because we pay two orders-of-magnitude more energy to access off-chip memory (e.g. DDR) compared to on-chip memory (e.g. SRAM or cache), many hardware designs employ a multi-layered cache hierarchy. Fitting the parameters and activations of a network in these on-chip caches can make a big difference on the required bandwidth, the total inference latency, and off course reduce power consumption. 
\nAnd of course, if we used a sparse or compressed representation, then we are reducing the data throughput and therefore the energy consumption.", "title": "More energy efficient" - }, + }, { - "location": "/install/index.html", - "text": "Distiller Installation\n\n\nThese instructions will help get Distiller up and running on your local machine.\n\n\nYou may also want to refer to these resources:\n\n\n\n\nDataset installation\n instructions.\n\n\nJupyter installation\n instructions.\n\n\n\n\nNotes:\n- Distiller has only been tested on Ubuntu 16.04 LTS, and with Python 3.5.\n- If you are not using a GPU, you might need to make small adjustments to the code.\n\n\nClone Distiller\n\n\nClone the Distiller code repository from github:\n\n\n$ git clone https://github.com/NervanaSystems/distiller.git\n\n\n\n\nThe rest of the documentation that follows, assumes that you have cloned your repository to a directory called \ndistiller\n. \n\n\nCreate a Python virtual environment\n\n\nWe recommend using a \nPython virtual environment\n, but that of course, is up to you.\nThere's nothing special about using Distiller in a virtual environment, but we provide some instructions, for completeness.\n\nBefore creating the virtual environment, make sure you are located in directory \ndistiller\n. After creating the environment, you should see a directory called \ndistiller/env\n.\n\n\n\nUsing virtualenv\n\n\nIf you don't have virtualenv installed, you can find the installation instructions \nhere\n.\n\n\nTo create the environment, execute:\n\n\n$ python3 -m virtualenv env\n\n\n\n\nThis creates a subdirectory named \nenv\n where the python virtual environment is stored, and configures the current shell to use it as the default python environment.\n\n\nUsing venv\n\n\nIf you prefer to use \nvenv\n, then begin by installing it:\n\n\n$ sudo apt-get install python3-venv\n\n\n\n\nThen create the environment:\n\n\n$ python3 -m venv env\n\n\n\n\nAs with virtualenv, this creates a directory called \ndistiller/env\n.\n\n\nActivate the environment\n\n\nThe environment activation and deactivation commands for \nvenv\n and \nvirtualenv\n are the same.\n\n\n!NOTE: Make sure to activate the environment, before proceeding with the installation of the dependency packages:\n\n\n$ source env/bin/activate\n\n\n\n\nInstall dependencies\n\n\nFinally, install Distiller's dependency packages using \npip3\n:\n\n\n$ pip3 install -r requirements.txt\n\n\n\n\nPyTorch is included in the \nrequirements.txt\n file, and will currently download PyTorch version 3.1 for CUDA 8.0. This is the setup we've used for testing Distiller.", + "location": "/install/index.html", + "text": "Distiller Installation\n\n\nThese instructions will help get Distiller up and running on your local machine.\n\n\nYou may also want to refer to these resources:\n\n\n\n\nDataset installation\n instructions.\n\n\nJupyter installation\n instructions.\n\n\n\n\nNotes:\n- Distiller has only been tested on Ubuntu 16.04 LTS, and with Python 3.5.\n- If you are not using a GPU, you might need to make small adjustments to the code.\n\n\nClone Distiller\n\n\nClone the Distiller code repository from github:\n\n\n$ git clone https://github.com/NervanaSystems/distiller.git\n\n\n\n\nThe rest of the documentation that follows, assumes that you have cloned your repository to a directory called \ndistiller\n. 
\n\n\nCreate a Python virtual environment\n\n\nWe recommend using a \nPython virtual environment\n, but that of course, is up to you.\nThere's nothing special about using Distiller in a virtual environment, but we provide some instructions, for completeness.\n\nBefore creating the virtual environment, make sure you are located in directory \ndistiller\n. After creating the environment, you should see a directory called \ndistiller/env\n.\n\n\n\nUsing virtualenv\n\n\nIf you don't have virtualenv installed, you can find the installation instructions \nhere\n.\n\n\nTo create the environment, execute:\n\n\n$ python3 -m virtualenv env\n\n\n\n\nThis creates a subdirectory named \nenv\n where the python virtual environment is stored, and configures the current shell to use it as the default python environment.\n\n\nUsing venv\n\n\nIf you prefer to use \nvenv\n, then begin by installing it:\n\n\n$ sudo apt-get install python3-venv\n\n\n\n\nThen create the environment:\n\n\n$ python3 -m venv env\n\n\n\n\nAs with virtualenv, this creates a directory called \ndistiller/env\n.\n\n\nActivate the environment\n\n\nThe environment activation and deactivation commands for \nvenv\n and \nvirtualenv\n are the same.\n\n\n!NOTE: Make sure to activate the environment, before proceeding with the installation of the dependency packages:\n\n\n$ source env/bin/activate\n\n\n\n\nInstall dependencies\n\n\nFinally, install Distiller's dependency packages using \npip3\n:\n\n\n$ pip3 install -r requirements.txt\n\n\n\n\nPyTorch is included in the \nrequirements.txt\n file, and will currently download PyTorch version 3.1 for CUDA 8.0. This is the setup we've used for testing Distiller.", "title": "Installation" - }, + }, { - "location": "/install/index.html#distiller-installation", - "text": "These instructions will help get Distiller up and running on your local machine. You may also want to refer to these resources: Dataset installation instructions. Jupyter installation instructions. Notes:\n- Distiller has only been tested on Ubuntu 16.04 LTS, and with Python 3.5.\n- If you are not using a GPU, you might need to make small adjustments to the code.", + "location": "/install/index.html#distiller-installation", + "text": "These instructions will help get Distiller up and running on your local machine. You may also want to refer to these resources: Dataset installation instructions. Jupyter installation instructions. 
Notes:\n- Distiller has only been tested on Ubuntu 16.04 LTS, and with Python 3.5.\n- If you are not using a GPU, you might need to make small adjustments to the code.", "title": "Distiller Installation" - }, + }, { - "location": "/install/index.html#clone-distiller", - "text": "Clone the Distiller code repository from github: $ git clone https://github.com/NervanaSystems/distiller.git The rest of the documentation that follows, assumes that you have cloned your repository to a directory called distiller .", + "location": "/install/index.html#clone-distiller", + "text": "Clone the Distiller code repository from github: $ git clone https://github.com/NervanaSystems/distiller.git The rest of the documentation that follows, assumes that you have cloned your repository to a directory called distiller .", "title": "Clone Distiller" - }, + }, { - "location": "/install/index.html#create-a-python-virtual-environment", - "text": "We recommend using a Python virtual environment , but that of course, is up to you.\nThere's nothing special about using Distiller in a virtual environment, but we provide some instructions, for completeness. \nBefore creating the virtual environment, make sure you are located in directory distiller . After creating the environment, you should see a directory called distiller/env .", + "location": "/install/index.html#create-a-python-virtual-environment", + "text": "We recommend using a Python virtual environment , but that of course, is up to you.\nThere's nothing special about using Distiller in a virtual environment, but we provide some instructions, for completeness. \nBefore creating the virtual environment, make sure you are located in directory distiller . After creating the environment, you should see a directory called distiller/env .", "title": "Create a Python virtual environment" - }, + }, { - "location": "/install/index.html#using-virtualenv", - "text": "If you don't have virtualenv installed, you can find the installation instructions here . To create the environment, execute: $ python3 -m virtualenv env This creates a subdirectory named env where the python virtual environment is stored, and configures the current shell to use it as the default python environment.", + "location": "/install/index.html#using-virtualenv", + "text": "If you don't have virtualenv installed, you can find the installation instructions here . To create the environment, execute: $ python3 -m virtualenv env This creates a subdirectory named env where the python virtual environment is stored, and configures the current shell to use it as the default python environment.", "title": "Using virtualenv" - }, + }, { - "location": "/install/index.html#using-venv", - "text": "If you prefer to use venv , then begin by installing it: $ sudo apt-get install python3-venv Then create the environment: $ python3 -m venv env As with virtualenv, this creates a directory called distiller/env .", + "location": "/install/index.html#using-venv", + "text": "If you prefer to use venv , then begin by installing it: $ sudo apt-get install python3-venv Then create the environment: $ python3 -m venv env As with virtualenv, this creates a directory called distiller/env .", "title": "Using venv" - }, + }, { - "location": "/install/index.html#activate-the-environment", - "text": "The environment activation and deactivation commands for venv and virtualenv are the same. 
!NOTE: Make sure to activate the environment, before proceeding with the installation of the dependency packages: $ source env/bin/activate", + "location": "/install/index.html#activate-the-environment", + "text": "The environment activation and deactivation commands for venv and virtualenv are the same. !NOTE: Make sure to activate the environment, before proceeding with the installation of the dependency packages: $ source env/bin/activate", "title": "Activate the environment" - }, + }, { - "location": "/install/index.html#install-dependencies", - "text": "Finally, install Distiller's dependency packages using pip3 : $ pip3 install -r requirements.txt PyTorch is included in the requirements.txt file, and will currently download PyTorch version 3.1 for CUDA 8.0. This is the setup we've used for testing Distiller.", + "location": "/install/index.html#install-dependencies", + "text": "Finally, install Distiller's dependency packages using pip3 : $ pip3 install -r requirements.txt PyTorch is included in the requirements.txt file, and will currently download PyTorch version 3.1 for CUDA 8.0. This is the setup we've used for testing Distiller.", "title": "Install dependencies" - }, + }, { - "location": "/usage/index.html", - "text": "Using the sample application\n\n\nThe Distiller repository contains a sample application, \ndistiller/examples/classifier_compression/compress_classifier.py\n, and a set of scheduling files which demonstrate Distiller's features. Following is a brief discussion of how to use this application and the accompanying schedules.\n\n\nYou might also want to refer to the following resources:\n\n\n\n\nAn \nexplanation\n of the scheduler file format.\n\n\nAn in-depth \ndiscussion\n of how we used these schedule files to implement several state-of-the-art DNN compression research papers.\n\n\n\n\nThe sample application supports various features for compression of image classification DNNs, and gives an example of how to integrate distiller in your own application. 
The code is documented and should be considered the best source of documentation, but we provide some elaboration here.\n\n\nThis diagram shows how where \ncompress_classifier.py\n fits in the compression workflow, and how we integrate the Jupyter notebooks as part of our research work.\n\n\n\nCommand line arguments\n\n\nTo get help on the command line arguments, invoke:\n\n\n$ python3 compress_classifier.py --help\n\n\n\n\nFor example:\n\n\n$ time python3 compress_classifier.py -a alexnet --lr 0.005 -p 50 ../../../data.imagenet -j 44 --epochs 90 --pretrained --compress=../sensitivity-pruning/alexnet.schedule_sensitivity.yaml\n\nParameters:\n +----+---------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+\n | | Name | Shape | NNZ (dense) | NNZ (sparse) | Cols (%) | Rows (%) | Ch (%) | 2D (%) | 3D (%) | Fine (%) | Std | Mean | Abs-Mean |\n |----+---------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|\n | 0 | features.module.0.weight | (64, 3, 11, 11) | 23232 | 13411 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 42.27359 | 0.14391 | -0.00002 | 0.08805 |\n | 1 | features.module.3.weight | (192, 64, 5, 5) | 307200 | 115560 | 0.00000 | 0.00000 | 0.00000 | 1.91243 | 0.00000 | 62.38281 | 0.04703 | -0.00250 | 0.02289 |\n | 2 | features.module.6.weight | (384, 192, 3, 3) | 663552 | 256565 | 0.00000 | 0.00000 | 0.00000 | 6.18490 | 0.00000 | 61.33445 | 0.03354 | -0.00184 | 0.01803 |\n | 3 | features.module.8.weight | (256, 384, 3, 3) | 884736 | 315065 | 0.00000 | 0.00000 | 0.00000 | 6.96411 | 0.00000 | 64.38881 | 0.02646 | -0.00168 | 0.01422 |\n | 4 | features.module.10.weight | (256, 256, 3, 3) | 589824 | 186938 | 0.00000 | 0.00000 | 0.00000 | 15.49225 | 0.00000 | 68.30614 | 0.02714 | -0.00246 | 0.01409 |\n | 5 | classifier.1.weight | (4096, 9216) | 37748736 | 3398881 | 0.00000 | 0.21973 | 0.00000 | 0.21973 | 0.00000 | 90.99604 | 0.00589 | -0.00020 | 0.00168 |\n | 6 | classifier.4.weight | (4096, 4096) | 16777216 | 1782769 | 0.21973 | 3.46680 | 0.00000 | 3.46680 | 0.00000 | 89.37387 | 0.00849 | -0.00066 | 0.00263 |\n | 7 | classifier.6.weight | (1000, 4096) | 4096000 | 994738 | 3.36914 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 75.71440 | 0.01718 | 0.00030 | 0.00778 |\n | 8 | Total sparsity: | - | 61090496 | 7063928 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 88.43694 | 0.00000 | 0.00000 | 0.00000 |\n +----+---------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+\n 2018-04-04 21:30:52,499 - Total sparsity: 88.44\n\n 2018-04-04 21:30:52,499 - --- validate (epoch=89)-----------\n 2018-04-04 21:30:52,499 - 128116 samples (256 per mini-batch)\n 2018-04-04 21:31:04,646 - Epoch: [89][ 50/ 500] Loss 2.175988 Top1 51.289063 Top5 74.023438\n 2018-04-04 21:31:06,427 - Epoch: [89][ 100/ 500] Loss 2.171564 Top1 51.175781 Top5 74.308594\n 2018-04-04 21:31:11,432 - Epoch: [89][ 150/ 500] Loss 2.159347 Top1 51.546875 Top5 74.473958\n 2018-04-04 21:31:14,364 - Epoch: [89][ 200/ 500] Loss 2.156857 Top1 51.585938 Top5 74.568359\n 2018-04-04 21:31:18,381 - Epoch: [89][ 250/ 500] Loss 2.152790 Top1 51.707813 Top5 74.681250\n 2018-04-04 21:31:22,195 - Epoch: [89][ 300/ 500] Loss 2.149962 Top1 51.791667 Top5 74.755208\n 
2018-04-04 21:31:25,508 - Epoch: [89][ 350/ 500] Loss 2.150936 Top1 51.827009 Top5 74.767857\n 2018-04-04 21:31:29,538 - Epoch: [89][ 400/ 500] Loss 2.150853 Top1 51.781250 Top5 74.763672\n 2018-04-04 21:31:32,842 - Epoch: [89][ 450/ 500] Loss 2.150156 Top1 51.828125 Top5 74.821181\n 2018-04-04 21:31:35,338 - Epoch: [89][ 500/ 500] Loss 2.150417 Top1 51.833594 Top5 74.817187\n 2018-04-04 21:31:35,357 - ==> Top1: 51.838 Top5: 74.817 Loss: 2.150\n\n 2018-04-04 21:31:35,364 - Saving checkpoint\n 2018-04-04 21:31:39,251 - --- test ---------------------\n 2018-04-04 21:31:39,252 - 50000 samples (256 per mini-batch)\n 2018-04-04 21:31:51,512 - Test: [ 50/ 195] Loss 1.487607 Top1 63.273438 Top5 85.695312\n 2018-04-04 21:31:55,015 - Test: [ 100/ 195] Loss 1.638043 Top1 60.636719 Top5 83.664062\n 2018-04-04 21:31:58,732 - Test: [ 150/ 195] Loss 1.833214 Top1 57.619792 Top5 80.447917\n 2018-04-04 21:32:01,274 - ==> Top1: 56.606 Top5: 79.446 Loss: 1.893\n\n\n\n\nLet's look at the command line again:\n\n\n$ time python3 compress_classifier.py -a alexnet --lr 0.005 -p 50 ../../../data.imagenet -j 44 --epochs 90 --pretrained --compress=../sensitivity-pruning/alexnet.schedule_sensitivity.yaml\n\n\n\n\nIn this example, we prune a TorchVision pre-trained AlexNet network, using the following configuration:\n\n\n\n\nLearning-rate of 0.005\n\n\nPrint progress every 50 mini-batches.\n\n\nUse 44 worker threads to load data (make sure to use something suitable for your machine).\n\n\nRun for 90 epochs. Torchvision's pre-trained models did not store the epoch metadata, so pruning starts at epoch 0. When you train and prune your own networks, the last training epoch is saved as a metadata with the model. Therefore, when you load such models, the first epoch is not 0, but it is the last training epoch.\n\n\nThe pruning schedule is provided in \nalexnet.schedule_sensitivity.yaml\n\n\nLog files are written to directory \nlogs\n.\n\n\n\n\nExamples\n\n\nDistiller comes with several example schedules which can be used together with \ncompress_classifier.py\n.\nThese example schedules (YAML) files, contain the command line that is used in order to invoke the schedule (so that you can easily recreate the results in your environment), together with the results of the pruning or regularization. 
The results usually contain a table showing the sparsity of each of the model parameters, together with the validation and test top1, top5 and loss scores.\n\n\nFor more details on the example schedules, you can refer to the coverage of the \nModel Zoo\n.\n\n\n\n\nexamples/agp-pruning\n:\n\n\nAutomated Gradual Pruning (AGP) on MobileNet and ResNet18 (ImageNet dataset)\n\n\n\n\n\n\n\nexamples/hybrid\n:\n\n\nAlexNet AGP with 2D (kernel) regularization (ImageNet dataset)\n\n\nAlexNet sensitivity pruning with 2D regularization\n\n\n\n\n\n\n\nexamples/network_slimming\n:\n\n\nResNet20 Network Slimming (this is work-in-progress)\n\n\n\n\n\n\n\nexamples/pruning_filters_for_efficient_convnets\n:\n\n\nResNet56 baseline training (CIFAR10 dataset)\n\n\nResNet56 filter removal using filter ranking\n\n\n\n\n\n\n\nexamples/sensitivity_analysis\n:\n\n\nElement-wise pruning sensitivity-analysis:\n\n\nAlexNet (ImageNet)\n\n\nMobileNet (ImageNet)\n\n\nResNet18 (ImageNet)\n\n\nResNet20 (CIFAR10)\n\n\nResNet34 (ImageNet)\n\n\nFilter-wise pruning sensitivity-analysis:\n\n\nResNet20 (CIFAR10)\n\n\nResNet56 (CIFAR10)\n\n\n\n\n\n\n\nexamples/sensitivity-pruning\n:\n\n\nAlexNet sensitivity pruning with Iterative Pruning\n\n\nAlexNet sensitivity pruning with One-Shot Pruning\n\n\n\n\n\n\n\nexamples/ssl\n:\n\n\nResNet20 baseline training (CIFAR10 dataset)\n\n\nStructured Sparsity Learning (SSL) with layer removal on ResNet20\n\n\nSSL with channels removal on ResNet20\n\n\n\n\n\n\n\nexamples/quantization\n:\n\n\nAlexNet w. Batch-Norm (base FP32 + DoReFa)\n\n\nPre-activation ResNet20 on CIFAR10 (base FP32 + DoReFa)\n\n\nPre-activation ResNet18 on ImageNEt (base FP32 + DoReFa)\n\n\n\n\n\n\n\n\nExperiment reproducibility\n\n\nExperiment reproducibility is sometimes important. Pete Warden recently expounded about this in his \nblog\n.\n\nPyTorch's support for deterministic execution requires us to use only one thread for loading data (other wise the multi-threaded execution of the data loaders can create random order and change the results), and to set the seed of the CPU and GPU PRNGs. Using the \n--deterministic\n command-line flag and setting \nj=1\n will produce reproducible results (for the same PyTorch version).\n\n\nPerforming pruning sensitivity analysis\n\n\nDistiller supports element-wise and filter-wise pruning sensitivity analysis. In both cases, L1-norm is used to rank which elements or filters to prune. For example, when running filter-pruning sensitivity analysis, the L1-norm of the filters of each layer's weights tensor are calculated, and the bottom x% are set to zero. \n\nThe analysis process is quite long, because currently we use the entire test dataset to assess the accuracy performance at each pruning level of each weights tensor. Using a small dataset for this would save much time and we plan on assessing if this will provide sufficient results.\n\nResults are output as a CSV file (\nsensitivity.csv\n) and PNG file (\nsensitivity.png\n). 
The implementation is in \ndistiller/sensitivity.py\n and it contains further details about process and the format of the CSV file.\n\n\nThe example below performs element-wise pruning sensitivity analysis on ResNet20 for CIFAR10:\n\n\n$ python3 compress_classifier.py -a resnet20_cifar ../../../data.cifar10/ -j=1 --resume=../cifar10/resnet20/checkpoint_trained_dense.pth.tar --sense=element\n\n\n\n\nThe \nsense\n command-line argument can be set to either \nelement\n or \nfilter\n, depending on the type of analysis you want done.\n\n\nThere is also a \nJupyter notebook\n with example invocations, outputs and explanations.\n\n\n\"Direct\" Quantization Without Training\n\n\nDistiller supports 8-bit quantization of trained modules without re-training (using \nSymmetric Linear Quantization\n). So, any model (whether pruned or not) can be quantized.\n\nUse the \n--quantize\n command-line flag, together with \n--evaluate\n to evaluate the accuracy of your model after quantization. The following example qunatizes ResNet18 for ImageNet:\n\n\n$ python3 compress_classifier.py -a resnet18 ../../../data.imagenet --pretrained --quantize --evaluate\n\n\n\n\nGenerates:\n\n\nPreparing model for quantization\n--- test ---------------------\n50000 samples (256 per mini-batch)\nTest: [ 10/ 195] Loss 0.856354 Top1 79.257812 Top5 92.500000\nTest: [ 20/ 195] Loss 0.923131 Top1 76.953125 Top5 92.246094\nTest: [ 30/ 195] Loss 0.885186 Top1 77.955729 Top5 92.486979\nTest: [ 40/ 195] Loss 0.930263 Top1 76.181641 Top5 92.597656\nTest: [ 50/ 195] Loss 0.931062 Top1 75.726562 Top5 92.906250\nTest: [ 60/ 195] Loss 0.932019 Top1 75.651042 Top5 93.151042\nTest: [ 70/ 195] Loss 0.921287 Top1 76.060268 Top5 93.270089\nTest: [ 80/ 195] Loss 0.932539 Top1 75.986328 Top5 93.100586\nTest: [ 90/ 195] Loss 0.996000 Top1 74.700521 Top5 92.330729\nTest: [ 100/ 195] Loss 1.066699 Top1 73.289062 Top5 91.437500\nTest: [ 110/ 195] Loss 1.100970 Top1 72.574574 Top5 91.001420\nTest: [ 120/ 195] Loss 1.122376 Top1 72.268880 Top5 90.696615\nTest: [ 130/ 195] Loss 1.171726 Top1 71.198918 Top5 90.120192\nTest: [ 140/ 195] Loss 1.191500 Top1 70.797991 Top5 89.902344\nTest: [ 150/ 195] Loss 1.219954 Top1 70.210938 Top5 89.453125\nTest: [ 160/ 195] Loss 1.240942 Top1 69.855957 Top5 89.162598\nTest: [ 170/ 195] Loss 1.265741 Top1 69.342831 Top5 88.807445\nTest: [ 180/ 195] Loss 1.281185 Top1 69.051649 Top5 88.589410\nTest: [ 190/ 195] Loss 1.279682 Top1 69.019326 Top5 88.632812\n==> Top1: 69.130 Top5: 88.732 Loss: 1.276\n\n\n\n\nSummaries\n\n\nYou can use the sample compression application to generate model summary reports, such as the attributes and compute summary report (see screen capture below).\nYou can log sparsity statistics (written to console and CSV file), performance, optimizer and model information, and also create a PNG image of the DNN.\nCreating a PNG image is an experimental feature (it relies on features which are not available on PyTorch 3.1 and that we hope will be available in PyTorch's next release), so to use it you will need to compile the PyTorch master branch, and hope for the best ;-).\n\n\n$ python3 compress_classifier.py --resume=../ssl/checkpoints/checkpoint_trained_ch_regularized_dense.pth.tar -a=resnet20_cifar ../../../data.cifar10 --summary=compute\n\n\n\n\nGenerates:\n\n\n+----+------------------------------+--------+----------+-----------------+--------------+-----------------+--------------+------------------+---------+\n| | Name | Type | Attrs | IFM | IFM volume | OFM | OFM volume | Weights volume | MACs 
|\n|----+------------------------------+--------+----------+-----------------+--------------+-----------------+--------------+------------------+---------|\n| 0 | module.conv1 | Conv2d | k=(3, 3) | (1, 3, 32, 32) | 3072 | (1, 16, 32, 32) | 16384 | 432 | 442368 |\n| 1 | module.layer1.0.conv1 | Conv2d | k=(3, 3) | (1, 16, 32, 32) | 16384 | (1, 16, 32, 32) | 16384 | 2304 | 2359296 |\n| 2 | module.layer1.0.conv2 | Conv2d | k=(3, 3) | (1, 16, 32, 32) | 16384 | (1, 16, 32, 32) | 16384 | 2304 | 2359296 |\n| 3 | module.layer1.1.conv1 | Conv2d | k=(3, 3) | (1, 16, 32, 32) | 16384 | (1, 16, 32, 32) | 16384 | 2304 | 2359296 |\n| 4 | module.layer1.1.conv2 | Conv2d | k=(3, 3) | (1, 16, 32, 32) | 16384 | (1, 16, 32, 32) | 16384 | 2304 | 2359296 |\n| 5 | module.layer1.2.conv1 | Conv2d | k=(3, 3) | (1, 16, 32, 32) | 16384 | (1, 16, 32, 32) | 16384 | 2304 | 2359296 |\n| 6 | module.layer1.2.conv2 | Conv2d | k=(3, 3) | (1, 16, 32, 32) | 16384 | (1, 16, 32, 32) | 16384 | 2304 | 2359296 |\n| 7 | module.layer2.0.conv1 | Conv2d | k=(3, 3) | (1, 16, 32, 32) | 16384 | (1, 32, 16, 16) | 8192 | 4608 | 1179648 |\n| 8 | module.layer2.0.conv2 | Conv2d | k=(3, 3) | (1, 32, 16, 16) | 8192 | (1, 32, 16, 16) | 8192 | 9216 | 2359296 |\n| 9 | module.layer2.0.downsample.0 | Conv2d | k=(1, 1) | (1, 16, 32, 32) | 16384 | (1, 32, 16, 16) | 8192 | 512 | 131072 |\n| 10 | module.layer2.1.conv1 | Conv2d | k=(3, 3) | (1, 32, 16, 16) | 8192 | (1, 32, 16, 16) | 8192 | 9216 | 2359296 |\n| 11 | module.layer2.1.conv2 | Conv2d | k=(3, 3) | (1, 32, 16, 16) | 8192 | (1, 32, 16, 16) | 8192 | 9216 | 2359296 |\n| 12 | module.layer2.2.conv1 | Conv2d | k=(3, 3) | (1, 32, 16, 16) | 8192 | (1, 32, 16, 16) | 8192 | 9216 | 2359296 |\n| 13 | module.layer2.2.conv2 | Conv2d | k=(3, 3) | (1, 32, 16, 16) | 8192 | (1, 32, 16, 16) | 8192 | 9216 | 2359296 |\n| 14 | module.layer3.0.conv1 | Conv2d | k=(3, 3) | (1, 32, 16, 16) | 8192 | (1, 64, 8, 8) | 4096 | 18432 | 1179648 |\n| 15 | module.layer3.0.conv2 | Conv2d | k=(3, 3) | (1, 64, 8, 8) | 4096 | (1, 64, 8, 8) | 4096 | 36864 | 2359296 |\n| 16 | module.layer3.0.downsample.0 | Conv2d | k=(1, 1) | (1, 32, 16, 16) | 8192 | (1, 64, 8, 8) | 4096 | 2048 | 131072 |\n| 17 | module.layer3.1.conv1 | Conv2d | k=(3, 3) | (1, 64, 8, 8) | 4096 | (1, 64, 8, 8) | 4096 | 36864 | 2359296 |\n| 18 | module.layer3.1.conv2 | Conv2d | k=(3, 3) | (1, 64, 8, 8) | 4096 | (1, 64, 8, 8) | 4096 | 36864 | 2359296 |\n| 19 | module.layer3.2.conv1 | Conv2d | k=(3, 3) | (1, 64, 8, 8) | 4096 | (1, 64, 8, 8) | 4096 | 36864 | 2359296 |\n| 20 | module.layer3.2.conv2 | Conv2d | k=(3, 3) | (1, 64, 8, 8) | 4096 | (1, 64, 8, 8) | 4096 | 36864 | 2359296 |\n| 21 | module.fc | Linear | | (1, 64) | 64 | (1, 10) | 10 | 640 | 640 |\n+----+------------------------------+--------+----------+-----------------+--------------+-----------------+--------------+------------------+---------+\nTotal MACs: 40,813,184\n\n\n\n\nUsing TensorBoard\n\n\nGoogle's \nTensorBoard\n is an excellent tool for visualizing the progress of DNN training. Distiller's logger supports writing performance indicators and parameter statistics in a file format that can be read by TensorBoard (Distiller uses TensorFlow's APIs in order to do this, which is why Distiller requires the installation of TensorFlow).\n\nTo view the graphs, invoke the TensorBoard server. For example:\n\n\n$ tensorboard --logdir=logs\n\n\n\n\nDistillers's setup (requirements.txt) installs TensorFlow for CPU. 
If you want a different installation, please follow the \nTensorFlow installation instructions\n.\n\n\nCollecting activations statistics\n\n\nIn CNNs with ReLU layers, ReLU activations (feature-maps) also exhibit a nice level of sparsity (50-60% sparsity is typical). \n\nYou can collect activation statistics using the \n--act_stats\n command-line flag.\n\nFor example:\n\n\n$ python3 compress_classifier.py -a=resnet56_cifar -p=50 ../../../data.cifar10 --resume=checkpoint.resnet56_cifar_baseline.pth.tar --act-stats=test -e\n\n\n\n\nThe \ntest\n parameter indicates that, in this example, we want to collect activation statistics during the \ntest\n phase. Note that we also used the \n-e\n command-line argument to indicate that we want to run a \ntest\n phase. The other two legal parameter values are \ntrain\n and \nvalid\n which collect activation statistics during the \ntraining\n and \nvalidation\n phases, respectively. \n\n\nCollectors and their collaterals\n\n\nAn instance of a subclass of \nActivationStatsCollector\n can be used to collect activation statistics. Currently, \nActivationStatsCollector\n has two types of subclasses: \nSummaryActivationStatsCollector\n and \nRecordsActivationStatsCollector\n.\n\nInstances of \nSummaryActivationStatsCollector\n compute the mean of some statistic of the activation. It is rather\nlight-weight and quicker than collecting a record per activation. The statistic function is configured in the constructor.\n\nIn the sample compression application, \ncompress_classifier.py\n, we create a dictionary of collectors. For example:\n\n\nSummaryActivationStatsCollector(model,\n \"sparsity\",\n lambda t: 100 * distiller.utils.sparsity(t))\n\n\n\n\nThe lambda expression is invoked per activation encountered during forward passes, and the value it returns (in this case, the sparsity of the activation tensors, multiplied by 100) is stored in \nmodule.sparsity\n (\n\"sparsity\"\n is this collector's name). To access the statistics, you can invoke \ncollector.value()\n, or you can access each module's data directly.\n\n\nAnother type of collector is \nRecordsActivationStatsCollector\n which computes a hard-coded set of activations statistics and collects a\n\nrecord per activation\n. For obvious reasons, this is slower than instances of \nSummaryActivationStatsCollector\n.\nActivationStatsCollector\n default to collecting activations statistics only on the output activations of ReLU layers, but we can choose any layer type we want. In the example below we collect statistics from outputs of \ntorch.nn.Conv2d\n layers.\n\n\nRecordsActivationStatsCollector(model, classes=[torch.nn.Conv2d])\n\n\n\n\nCollectors can write their data to Excel workbooks (which are named using the collector's name), by invoking \ncollector.to_xlsx(path_to_workbook)\n. In \ncompress_classifier.py\n we currently create four different collectors which you can selectively disable. 
You can also add other statistics collectors and use a different function to compute your new statistic.\n\n\ncollectors = missingdict({\n \"sparsity\": SummaryActivationStatsCollector(model, \"sparsity\",\n lambda t: 100 * distiller.utils.sparsity(t)),\n \"l1_channels\": SummaryActivationStatsCollector(model, \"l1_channels\",\n distiller.utils.activation_channels_l1),\n \"apoz_channels\": SummaryActivationStatsCollector(model, \"apoz_channels\",\n distiller.utils.activation_channels_apoz),\n \"records\": RecordsActivationStatsCollector(model, classes=[torch.nn.Conv2d])})\n\n\n\n\nBy default, these Collectors write their data to files in the active log directory.\n\n\nYou can use a utility function, \ndistiller.log_activation_statsitics\n, to log the data of an \nActivationStatsCollector\n instance to one of the backend-loggers. For an example, the code below logs the \n\"sparsity\"\n collector to a TensorBoard log file.\n\n\ndistiller.log_activation_statsitics(epoch, \"train\", loggers=[tflogger],\n collector=collectors[\"sparsity\"])\n\n\n\n\nCaveats\n\n\nDistiller collects activations statistics using PyTorch's forward-hooks mechanism. Collectors iteratively register the modules' forward-hooks, and collectors are called during the forward traversal and get exposed to activation data. Registering for forward callbacks is performed like this:\n\n\nmodule.register_forward_hook\n\n\n\n\nThis makes apparent two limitations of this mechanism:\n\n\n\n\nWe can only register on PyTorch modules. This means that we can't register on the forward hook of a functionals such as \ntorch.nn.functional.relu\n and \ntorch.nn.functional.max_pool2d\n.\n\n Therefore, you may need to replace functionals with their module alternative. For example: \n\n\n\n\nclass MadeUpNet(nn.Module):\n def __init__(self):\n super().__init__()\n self.conv1 = nn.Conv2d(3, 6, 5)\n\n def forward(self, x):\n x = F.relu(self.conv1(x))\n return x\n\n\n\n\nCan be changed to: \n\n\nclass MadeUpNet(nn.Module):\n def __init__(self):\n super().__init__()\n self.conv1 = nn.Conv2d(3, 6, 5)\n self.relu = nn.ReLU(inplace=True)\n\n def forward(self, x):\n x = self.relu(self.conv1(x))\n return x\n\n\n\n\n\n\nWe can only use a module instance once in our models. 
If we use the same module several times, then we can't determine which node in the graph has invoked the callback, because the PyTorch callback signature \ndef hook(module, input, output)\n doesn't provide enough contextual information.\n\nTorchVision's \nResNet\n is an example of a model that uses the same instance of nn.ReLU multiple times: \n\n\n\n\nclass BasicBlock(nn.Module):\n expansion = 1\n def __init__(self, inplanes, planes, stride=1, downsample=None):\n super(BasicBlock, self).__init__()\n self.conv1 = conv3x3(inplanes, planes, stride)\n self.bn1 = nn.BatchNorm2d(planes)\n self.relu = nn.ReLU(inplace=True)\n self.conv2 = conv3x3(planes, planes)\n self.bn2 = nn.BatchNorm2d(planes)\n self.downsample = downsample\n self.stride = stride\n\n def forward(self, x):\n residual = x\n out = self.conv1(x)\n out = self.bn1(out)\n out = self.relu(out) # <================\n out = self.conv2(out)\n out = self.bn2(out)\n if self.downsample is not None:\n residual = self.downsample(x)\n out += residual\n out = self.relu(out) # <================\n return out\n\n\n\n\nIn Distiller we changed \nResNet\n to use multiple instances of nn.ReLU, and each instance is used only once: \n\n\nclass BasicBlock(nn.Module):\n expansion = 1\n def __init__(self, inplanes, planes, stride=1, downsample=None):\n super(BasicBlock, self).__init__()\n self.conv1 = conv3x3(inplanes, planes, stride)\n self.bn1 = nn.BatchNorm2d(planes)\n self.relu1 = nn.ReLU(inplace=True)\n self.conv2 = conv3x3(planes, planes)\n self.bn2 = nn.BatchNorm2d(planes)\n self.relu2 = nn.ReLU(inplace=True)\n self.downsample = downsample\n self.stride = stride\n\n def forward(self, x):\n residual = x\n out = self.conv1(x)\n out = self.bn1(out)\n out = self.relu1(out) # <================\n out = self.conv2(out)\n out = self.bn2(out)\n if self.downsample is not None:\n residual = self.downsample(x)\n out += residual\n out = self.relu2(out) # <================\n return out\n\n\n\n\nUsing the Jupyter notebooks\n\n\nThe Jupyter notebooks contain many examples of how to use the statistics summaries generated by Distiller. They are explained in a separate page.\n\n\nGenerating this documentation\n\n\nInstall mkdocs and the required packages by executing:\n\n\n$ pip3 install -r doc-requirements.txt\n\n\n\n\nTo build the project documentation run:\n\n\n$ cd distiller/docs-src\n$ mkdocs build --clean\n\n\n\n\nThis will create a folder named 'site' which contains the documentation website.\nOpen distiller/docs/site/index.html to view the documentation home page.", + "location": "/usage/index.html", + "text": "Using the sample application\n\n\nThe Distiller repository contains a sample application, \ndistiller/examples/classifier_compression/compress_classifier.py\n, and a set of scheduling files which demonstrate Distiller's features. Following is a brief discussion of how to use this application and the accompanying schedules.\n\n\nYou might also want to refer to the following resources:\n\n\n\n\nAn \nexplanation\n of the scheduler file format.\n\n\nAn in-depth \ndiscussion\n of how we used these schedule files to implement several state-of-the-art DNN compression research papers.\n\n\n\n\nThe sample application supports various features for compression of image classification DNNs, and gives an example of how to integrate distiller in your own application. 
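At a very high level, that integration amounts to wrapping an ordinary PyTorch training loop with a compression scheduler built from a YAML schedule. The sketch below is only a rough outline of that idea; the helper and callback names (`distiller.file_config`, `on_epoch_begin`, `before_backward_pass`, and so on) are assumptions about the scheduling API and may not match the current code exactly.

```python
import distiller


def train_with_compression(model, criterion, optimizer, train_loader, epochs, yaml_path):
    # Assumption: a compression scheduler is built from the YAML schedule file.
    compression_scheduler = distiller.file_config(model, optimizer, yaml_path)

    steps_per_epoch = len(train_loader)
    for epoch in range(epochs):
        compression_scheduler.on_epoch_begin(epoch)
        for step, (inputs, targets) in enumerate(train_loader):
            compression_scheduler.on_minibatch_begin(epoch, step, steps_per_epoch)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            # Assumption: the scheduler can fold regularization terms into the loss here.
            loss = compression_scheduler.before_backward_pass(epoch, step, steps_per_epoch, loss)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            compression_scheduler.on_minibatch_end(epoch, step, steps_per_epoch)
        compression_scheduler.on_epoch_end(epoch)
```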
The code is documented and should be considered the best source of documentation, but we provide some elaboration here.\n\n\nThis diagram shows how where \ncompress_classifier.py\n fits in the compression workflow, and how we integrate the Jupyter notebooks as part of our research work.\n\n\n\nCommand line arguments\n\n\nTo get help on the command line arguments, invoke:\n\n\n$ python3 compress_classifier.py --help\n\n\n\n\nFor example:\n\n\n$ time python3 compress_classifier.py -a alexnet --lr 0.005 -p 50 ../../../data.imagenet -j 44 --epochs 90 --pretrained --compress=../sensitivity-pruning/alexnet.schedule_sensitivity.yaml\n\nParameters:\n +----+---------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+\n | | Name | Shape | NNZ (dense) | NNZ (sparse) | Cols (%) | Rows (%) | Ch (%) | 2D (%) | 3D (%) | Fine (%) | Std | Mean | Abs-Mean |\n |----+---------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|\n | 0 | features.module.0.weight | (64, 3, 11, 11) | 23232 | 13411 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 42.27359 | 0.14391 | -0.00002 | 0.08805 |\n | 1 | features.module.3.weight | (192, 64, 5, 5) | 307200 | 115560 | 0.00000 | 0.00000 | 0.00000 | 1.91243 | 0.00000 | 62.38281 | 0.04703 | -0.00250 | 0.02289 |\n | 2 | features.module.6.weight | (384, 192, 3, 3) | 663552 | 256565 | 0.00000 | 0.00000 | 0.00000 | 6.18490 | 0.00000 | 61.33445 | 0.03354 | -0.00184 | 0.01803 |\n | 3 | features.module.8.weight | (256, 384, 3, 3) | 884736 | 315065 | 0.00000 | 0.00000 | 0.00000 | 6.96411 | 0.00000 | 64.38881 | 0.02646 | -0.00168 | 0.01422 |\n | 4 | features.module.10.weight | (256, 256, 3, 3) | 589824 | 186938 | 0.00000 | 0.00000 | 0.00000 | 15.49225 | 0.00000 | 68.30614 | 0.02714 | -0.00246 | 0.01409 |\n | 5 | classifier.1.weight | (4096, 9216) | 37748736 | 3398881 | 0.00000 | 0.21973 | 0.00000 | 0.21973 | 0.00000 | 90.99604 | 0.00589 | -0.00020 | 0.00168 |\n | 6 | classifier.4.weight | (4096, 4096) | 16777216 | 1782769 | 0.21973 | 3.46680 | 0.00000 | 3.46680 | 0.00000 | 89.37387 | 0.00849 | -0.00066 | 0.00263 |\n | 7 | classifier.6.weight | (1000, 4096) | 4096000 | 994738 | 3.36914 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 75.71440 | 0.01718 | 0.00030 | 0.00778 |\n | 8 | Total sparsity: | - | 61090496 | 7063928 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 88.43694 | 0.00000 | 0.00000 | 0.00000 |\n +----+---------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+\n 2018-04-04 21:30:52,499 - Total sparsity: 88.44\n\n 2018-04-04 21:30:52,499 - --- validate (epoch=89)-----------\n 2018-04-04 21:30:52,499 - 128116 samples (256 per mini-batch)\n 2018-04-04 21:31:04,646 - Epoch: [89][ 50/ 500] Loss 2.175988 Top1 51.289063 Top5 74.023438\n 2018-04-04 21:31:06,427 - Epoch: [89][ 100/ 500] Loss 2.171564 Top1 51.175781 Top5 74.308594\n 2018-04-04 21:31:11,432 - Epoch: [89][ 150/ 500] Loss 2.159347 Top1 51.546875 Top5 74.473958\n 2018-04-04 21:31:14,364 - Epoch: [89][ 200/ 500] Loss 2.156857 Top1 51.585938 Top5 74.568359\n 2018-04-04 21:31:18,381 - Epoch: [89][ 250/ 500] Loss 2.152790 Top1 51.707813 Top5 74.681250\n 2018-04-04 21:31:22,195 - Epoch: [89][ 300/ 500] Loss 2.149962 Top1 51.791667 Top5 74.755208\n 
2018-04-04 21:31:25,508 - Epoch: [89][ 350/ 500] Loss 2.150936 Top1 51.827009 Top5 74.767857\n 2018-04-04 21:31:29,538 - Epoch: [89][ 400/ 500] Loss 2.150853 Top1 51.781250 Top5 74.763672\n 2018-04-04 21:31:32,842 - Epoch: [89][ 450/ 500] Loss 2.150156 Top1 51.828125 Top5 74.821181\n 2018-04-04 21:31:35,338 - Epoch: [89][ 500/ 500] Loss 2.150417 Top1 51.833594 Top5 74.817187\n 2018-04-04 21:31:35,357 - ==\n Top1: 51.838 Top5: 74.817 Loss: 2.150\n\n 2018-04-04 21:31:35,364 - Saving checkpoint\n 2018-04-04 21:31:39,251 - --- test ---------------------\n 2018-04-04 21:31:39,252 - 50000 samples (256 per mini-batch)\n 2018-04-04 21:31:51,512 - Test: [ 50/ 195] Loss 1.487607 Top1 63.273438 Top5 85.695312\n 2018-04-04 21:31:55,015 - Test: [ 100/ 195] Loss 1.638043 Top1 60.636719 Top5 83.664062\n 2018-04-04 21:31:58,732 - Test: [ 150/ 195] Loss 1.833214 Top1 57.619792 Top5 80.447917\n 2018-04-04 21:32:01,274 - ==\n Top1: 56.606 Top5: 79.446 Loss: 1.893\n\n\n\n\nLet's look at the command line again:\n\n\n$ time python3 compress_classifier.py -a alexnet --lr 0.005 -p 50 ../../../data.imagenet -j 44 --epochs 90 --pretrained --compress=../sensitivity-pruning/alexnet.schedule_sensitivity.yaml\n\n\n\n\nIn this example, we prune a TorchVision pre-trained AlexNet network, using the following configuration:\n\n\n\n\nLearning-rate of 0.005\n\n\nPrint progress every 50 mini-batches.\n\n\nUse 44 worker threads to load data (make sure to use something suitable for your machine).\n\n\nRun for 90 epochs. Torchvision's pre-trained models did not store the epoch metadata, so pruning starts at epoch 0. When you train and prune your own networks, the last training epoch is saved as a metadata with the model. Therefore, when you load such models, the first epoch is not 0, but it is the last training epoch.\n\n\nThe pruning schedule is provided in \nalexnet.schedule_sensitivity.yaml\n\n\nLog files are written to directory \nlogs\n.\n\n\n\n\nExamples\n\n\nDistiller comes with several example schedules which can be used together with \ncompress_classifier.py\n.\nThese example schedules (YAML) files, contain the command line that is used in order to invoke the schedule (so that you can easily recreate the results in your environment), together with the results of the pruning or regularization. 
The results usually contain a table showing the sparsity of each of the model parameters, together with the validation and test top1, top5 and loss scores.\n\n\nFor more details on the example schedules, you can refer to the coverage of the \nModel Zoo\n.\n\n\n\n\nexamples/agp-pruning\n:\n\n\nAutomated Gradual Pruning (AGP) on MobileNet and ResNet18 (ImageNet dataset)\n\n\n\n\n\n\n\nexamples/hybrid\n:\n\n\nAlexNet AGP with 2D (kernel) regularization (ImageNet dataset)\n\n\nAlexNet sensitivity pruning with 2D regularization\n\n\n\n\n\n\n\nexamples/network_slimming\n:\n\n\nResNet20 Network Slimming (this is work-in-progress)\n\n\n\n\n\n\n\nexamples/pruning_filters_for_efficient_convnets\n:\n\n\nResNet56 baseline training (CIFAR10 dataset)\n\n\nResNet56 filter removal using filter ranking\n\n\n\n\n\n\n\nexamples/sensitivity_analysis\n:\n\n\nElement-wise pruning sensitivity-analysis:\n\n\nAlexNet (ImageNet)\n\n\nMobileNet (ImageNet)\n\n\nResNet18 (ImageNet)\n\n\nResNet20 (CIFAR10)\n\n\nResNet34 (ImageNet)\n\n\nFilter-wise pruning sensitivity-analysis:\n\n\nResNet20 (CIFAR10)\n\n\nResNet56 (CIFAR10)\n\n\n\n\n\n\n\nexamples/sensitivity-pruning\n:\n\n\nAlexNet sensitivity pruning with Iterative Pruning\n\n\nAlexNet sensitivity pruning with One-Shot Pruning\n\n\n\n\n\n\n\nexamples/ssl\n:\n\n\nResNet20 baseline training (CIFAR10 dataset)\n\n\nStructured Sparsity Learning (SSL) with layer removal on ResNet20\n\n\nSSL with channels removal on ResNet20\n\n\n\n\n\n\n\nexamples/quantization\n:\n\n\nAlexNet w. Batch-Norm (base FP32 + DoReFa)\n\n\nPre-activation ResNet20 on CIFAR10 (base FP32 + DoReFa)\n\n\nPre-activation ResNet18 on ImageNEt (base FP32 + DoReFa)\n\n\n\n\n\n\n\n\nExperiment reproducibility\n\n\nExperiment reproducibility is sometimes important. Pete Warden recently expounded about this in his \nblog\n.\n\nPyTorch's support for deterministic execution requires us to use only one thread for loading data (other wise the multi-threaded execution of the data loaders can create random order and change the results), and to set the seed of the CPU and GPU PRNGs. Using the \n--deterministic\n command-line flag and setting \nj=1\n will produce reproducible results (for the same PyTorch version).\n\n\nPerforming pruning sensitivity analysis\n\n\nDistiller supports element-wise and filter-wise pruning sensitivity analysis. In both cases, L1-norm is used to rank which elements or filters to prune. For example, when running filter-pruning sensitivity analysis, the L1-norm of the filters of each layer's weights tensor are calculated, and the bottom x% are set to zero. \n\nThe analysis process is quite long, because currently we use the entire test dataset to assess the accuracy performance at each pruning level of each weights tensor. Using a small dataset for this would save much time and we plan on assessing if this will provide sufficient results.\n\nResults are output as a CSV file (\nsensitivity.csv\n) and PNG file (\nsensitivity.png\n). 
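If you prefer to post-process the raw numbers yourself, the CSV can be loaded with a few lines of Python. This is only a sketch: the column names and the sparsity scale used below are assumptions, not the documented format.

```python
import pandas as pd

# Hypothetical column names and a sparsity fraction in [0, 1] -- both are assumptions.
df = pd.read_csv('sensitivity.csv', names=['parameter', 'sparsity', 'top1', 'top5', 'loss'])

# For each parameter, report the highest sparsity level that keeps Top1 within
# one point of the accuracy measured at the lowest sparsity level.
for name, group in df.groupby('parameter'):
    baseline_top1 = group.loc[group['sparsity'].idxmin(), 'top1']
    tolerable = group[group['top1'] >= baseline_top1 - 1.0]
    print(f"{name}: prune up to {tolerable['sparsity'].max():.2f}")
```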
The implementation is in \ndistiller/sensitivity.py\n and it contains further details about the process and the format of the CSV file.\n\n\nThe example below performs element-wise pruning sensitivity analysis on ResNet20 for CIFAR10:\n\n\n$ python3 compress_classifier.py -a resnet20_cifar ../../../data.cifar10/ -j=1 --resume=../cifar10/resnet20/checkpoint_trained_dense.pth.tar --sense=element\n\n\n\n\nThe \nsense\n command-line argument can be set to either \nelement\n or \nfilter\n, depending on the type of analysis you want done.\n\n\nThere is also a \nJupyter notebook\n with example invocations, outputs and explanations.\n\n\nPost-Training Quantization\n\n\nDistiller supports post-training quantization of trained modules without re-training (using \nRange-Based Linear Quantization\n). So, any model (whether pruned or not) can be quantized. To invoke post-training quantization, use \n--quantize-eval\n along with \n--evaluate\n. Additional arguments are available to control the quantization parameters:\n\n\nArguments controlling quantization at evaluation time (\npost-training quantization\n):\n --quantize-eval, --qe\n Apply linear quantization to model before evaluation.\n Applicable only if --evaluate is also set\n --qe-mode QE_MODE, --qem QE_MODE\n Linear quantization mode. Choices: asym_s | asym_u |\n sym\n --qe-bits-acts NUM_BITS, --qeba NUM_BITS\n Number of bits for quantization of activations\n --qe-bits-wts NUM_BITS, --qebw NUM_BITS\n Number of bits for quantization of weights\n --qe-bits-accum NUM_BITS\n Number of bits for quantization of the accumulator\n --qe-clip-acts, --qeca\n Enable clipping of activations using min/max values\n averaging over batch\n --qe-no-clip-layers LAYER_NAME [LAYER_NAME ...], --qencl LAYER_NAME [LAYER_NAME ...]\n List of fully-qualified layer names for which not to\n clip activations. Applicable only if --qe-clip-acts is\n also set\n --qe-per-channel, --qepc\n Enable per-channel quantization of weights (per output channel)\n\n\n\n\n\nThe following example quantizes ResNet18 for ImageNet:\n\n\n$ python3 compress_classifier.py -a resnet18 ../../../data.imagenet --pretrained --quantize-eval --evaluate\n\n\n\n\nA checkpoint with the quantized model will be dumped in the run directory. It will contain the quantized model parameters (the data type will still be FP32, but the values will be integers). 
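One quick sanity check is to load the dumped checkpoint and verify that the weight tensors, although stored as floats, hold only integer values. This is a rough sketch; the checkpoint filename and the 'state_dict' key are assumptions about the checkpoint layout.

```python
import torch

# Assumed filename and layout of the checkpoint dumped by --quantize-eval.
ckpt = torch.load('quantized_checkpoint.pth.tar', map_location='cpu')
state_dict = ckpt.get('state_dict', ckpt)

for name, tensor in state_dict.items():
    if torch.is_floating_point(tensor) and 'weight' in name:
        # Quantized weights are FP32 tensors whose values are whole numbers.
        integer_valued = bool(torch.all(tensor == tensor.round()))
        print(f'{name}: dtype={tensor.dtype}, integer-valued={integer_valued}')
```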
The calculated quantization parameters (scale and zero-point) are stored as well in each quantized layer.\n\n\nFor more examples of post-training quantization see \nhere\n\n\nSummaries\n\n\nYou can use the sample compression application to generate model summary reports, such as the attributes and compute summary report (see screen capture below).\nYou can log sparsity statistics (written to console and CSV file), performance, optimizer and model information, and also create a PNG image of the DNN.\nCreating a PNG image is an experimental feature (it relies on features which are not available on PyTorch 3.1 and that we hope will be available in PyTorch's next release), so to use it you will need to compile the PyTorch master branch, and hope for the best ;-).\n\n\n$ python3 compress_classifier.py --resume=../ssl/checkpoints/checkpoint_trained_ch_regularized_dense.pth.tar -a=resnet20_cifar ../../../data.cifar10 --summary=compute\n\n\n\n\nGenerates:\n\n\n+----+------------------------------+--------+----------+-----------------+--------------+-----------------+--------------+------------------+---------+\n| | Name | Type | Attrs | IFM | IFM volume | OFM | OFM volume | Weights volume | MACs |\n|----+------------------------------+--------+----------+-----------------+--------------+-----------------+--------------+------------------+---------|\n| 0 | module.conv1 | Conv2d | k=(3, 3) | (1, 3, 32, 32) | 3072 | (1, 16, 32, 32) | 16384 | 432 | 442368 |\n| 1 | module.layer1.0.conv1 | Conv2d | k=(3, 3) | (1, 16, 32, 32) | 16384 | (1, 16, 32, 32) | 16384 | 2304 | 2359296 |\n| 2 | module.layer1.0.conv2 | Conv2d | k=(3, 3) | (1, 16, 32, 32) | 16384 | (1, 16, 32, 32) | 16384 | 2304 | 2359296 |\n| 3 | module.layer1.1.conv1 | Conv2d | k=(3, 3) | (1, 16, 32, 32) | 16384 | (1, 16, 32, 32) | 16384 | 2304 | 2359296 |\n| 4 | module.layer1.1.conv2 | Conv2d | k=(3, 3) | (1, 16, 32, 32) | 16384 | (1, 16, 32, 32) | 16384 | 2304 | 2359296 |\n| 5 | module.layer1.2.conv1 | Conv2d | k=(3, 3) | (1, 16, 32, 32) | 16384 | (1, 16, 32, 32) | 16384 | 2304 | 2359296 |\n| 6 | module.layer1.2.conv2 | Conv2d | k=(3, 3) | (1, 16, 32, 32) | 16384 | (1, 16, 32, 32) | 16384 | 2304 | 2359296 |\n| 7 | module.layer2.0.conv1 | Conv2d | k=(3, 3) | (1, 16, 32, 32) | 16384 | (1, 32, 16, 16) | 8192 | 4608 | 1179648 |\n| 8 | module.layer2.0.conv2 | Conv2d | k=(3, 3) | (1, 32, 16, 16) | 8192 | (1, 32, 16, 16) | 8192 | 9216 | 2359296 |\n| 9 | module.layer2.0.downsample.0 | Conv2d | k=(1, 1) | (1, 16, 32, 32) | 16384 | (1, 32, 16, 16) | 8192 | 512 | 131072 |\n| 10 | module.layer2.1.conv1 | Conv2d | k=(3, 3) | (1, 32, 16, 16) | 8192 | (1, 32, 16, 16) | 8192 | 9216 | 2359296 |\n| 11 | module.layer2.1.conv2 | Conv2d | k=(3, 3) | (1, 32, 16, 16) | 8192 | (1, 32, 16, 16) | 8192 | 9216 | 2359296 |\n| 12 | module.layer2.2.conv1 | Conv2d | k=(3, 3) | (1, 32, 16, 16) | 8192 | (1, 32, 16, 16) | 8192 | 9216 | 2359296 |\n| 13 | module.layer2.2.conv2 | Conv2d | k=(3, 3) | (1, 32, 16, 16) | 8192 | (1, 32, 16, 16) | 8192 | 9216 | 2359296 |\n| 14 | module.layer3.0.conv1 | Conv2d | k=(3, 3) | (1, 32, 16, 16) | 8192 | (1, 64, 8, 8) | 4096 | 18432 | 1179648 |\n| 15 | module.layer3.0.conv2 | Conv2d | k=(3, 3) | (1, 64, 8, 8) | 4096 | (1, 64, 8, 8) | 4096 | 36864 | 2359296 |\n| 16 | module.layer3.0.downsample.0 | Conv2d | k=(1, 1) | (1, 32, 16, 16) | 8192 | (1, 64, 8, 8) | 4096 | 2048 | 131072 |\n| 17 | module.layer3.1.conv1 | Conv2d | k=(3, 3) | (1, 64, 8, 8) | 4096 | (1, 64, 8, 8) | 4096 | 36864 | 2359296 |\n| 18 | module.layer3.1.conv2 | Conv2d | k=(3, 3) | 
(1, 64, 8, 8) | 4096 | (1, 64, 8, 8) | 4096 | 36864 | 2359296 |\n| 19 | module.layer3.2.conv1 | Conv2d | k=(3, 3) | (1, 64, 8, 8) | 4096 | (1, 64, 8, 8) | 4096 | 36864 | 2359296 |\n| 20 | module.layer3.2.conv2 | Conv2d | k=(3, 3) | (1, 64, 8, 8) | 4096 | (1, 64, 8, 8) | 4096 | 36864 | 2359296 |\n| 21 | module.fc | Linear | | (1, 64) | 64 | (1, 10) | 10 | 640 | 640 |\n+----+------------------------------+--------+----------+-----------------+--------------+-----------------+--------------+------------------+---------+\nTotal MACs: 40,813,184\n\n\n\n\nUsing TensorBoard\n\n\nGoogle's \nTensorBoard\n is an excellent tool for visualizing the progress of DNN training. Distiller's logger supports writing performance indicators and parameter statistics in a file format that can be read by TensorBoard (Distiller uses TensorFlow's APIs in order to do this, which is why Distiller requires the installation of TensorFlow).\n\nTo view the graphs, invoke the TensorBoard server. For example:\n\n\n$ tensorboard --logdir=logs\n\n\n\n\nDistillers's setup (requirements.txt) installs TensorFlow for CPU. If you want a different installation, please follow the \nTensorFlow installation instructions\n.\n\n\nCollecting activations statistics\n\n\nIn CNNs with ReLU layers, ReLU activations (feature-maps) also exhibit a nice level of sparsity (50-60% sparsity is typical). \n\nYou can collect activation statistics using the \n--act_stats\n command-line flag.\n\nFor example:\n\n\n$ python3 compress_classifier.py -a=resnet56_cifar -p=50 ../../../data.cifar10 --resume=checkpoint.resnet56_cifar_baseline.pth.tar --act-stats=test -e\n\n\n\n\nThe \ntest\n parameter indicates that, in this example, we want to collect activation statistics during the \ntest\n phase. Note that we also used the \n-e\n command-line argument to indicate that we want to run a \ntest\n phase. The other two legal parameter values are \ntrain\n and \nvalid\n which collect activation statistics during the \ntraining\n and \nvalidation\n phases, respectively. \n\n\nCollectors and their collaterals\n\n\nAn instance of a subclass of \nActivationStatsCollector\n can be used to collect activation statistics. Currently, \nActivationStatsCollector\n has two types of subclasses: \nSummaryActivationStatsCollector\n and \nRecordsActivationStatsCollector\n.\n\nInstances of \nSummaryActivationStatsCollector\n compute the mean of some statistic of the activation. It is rather\nlight-weight and quicker than collecting a record per activation. The statistic function is configured in the constructor.\n\nIn the sample compression application, \ncompress_classifier.py\n, we create a dictionary of collectors. For example:\n\n\nSummaryActivationStatsCollector(model,\n \nsparsity\n,\n lambda t: 100 * distiller.utils.sparsity(t))\n\n\n\n\nThe lambda expression is invoked per activation encountered during forward passes, and the value it returns (in this case, the sparsity of the activation tensors, multiplied by 100) is stored in \nmodule.sparsity\n (\n\"sparsity\"\n is this collector's name). To access the statistics, you can invoke \ncollector.value()\n, or you can access each module's data directly.\n\n\nAnother type of collector is \nRecordsActivationStatsCollector\n which computes a hard-coded set of activations statistics and collects a\n\nrecord per activation\n. 
For obvious reasons, this is slower than instances of \nSummaryActivationStatsCollector\n.\nActivationStatsCollector\n default to collecting activations statistics only on the output activations of ReLU layers, but we can choose any layer type we want. In the example below we collect statistics from outputs of \ntorch.nn.Conv2d\n layers.\n\n\nRecordsActivationStatsCollector(model, classes=[torch.nn.Conv2d])\n\n\n\n\nCollectors can write their data to Excel workbooks (which are named using the collector's name), by invoking \ncollector.to_xlsx(path_to_workbook)\n. In \ncompress_classifier.py\n we currently create four different collectors which you can selectively disable. You can also add other statistics collectors and use a different function to compute your new statistic.\n\n\ncollectors = missingdict({\n \nsparsity\n: SummaryActivationStatsCollector(model, \nsparsity\n,\n lambda t: 100 * distiller.utils.sparsity(t)),\n \nl1_channels\n: SummaryActivationStatsCollector(model, \nl1_channels\n,\n distiller.utils.activation_channels_l1),\n \napoz_channels\n: SummaryActivationStatsCollector(model, \napoz_channels\n,\n distiller.utils.activation_channels_apoz),\n \nrecords\n: RecordsActivationStatsCollector(model, classes=[torch.nn.Conv2d])})\n\n\n\n\nBy default, these Collectors write their data to files in the active log directory.\n\n\nYou can use a utility function, \ndistiller.log_activation_statsitics\n, to log the data of an \nActivationStatsCollector\n instance to one of the backend-loggers. For an example, the code below logs the \n\"sparsity\"\n collector to a TensorBoard log file.\n\n\ndistiller.log_activation_statsitics(epoch, \ntrain\n, loggers=[tflogger],\n collector=collectors[\nsparsity\n])\n\n\n\n\nCaveats\n\n\nDistiller collects activations statistics using PyTorch's forward-hooks mechanism. Collectors iteratively register the modules' forward-hooks, and collectors are called during the forward traversal and get exposed to activation data. Registering for forward callbacks is performed like this:\n\n\nmodule.register_forward_hook\n\n\n\n\nThis makes apparent two limitations of this mechanism:\n\n\n\n\nWe can only register on PyTorch modules. This means that we can't register on the forward hook of a functionals such as \ntorch.nn.functional.relu\n and \ntorch.nn.functional.max_pool2d\n.\n\n Therefore, you may need to replace functionals with their module alternative. For example: \n\n\n\n\nclass MadeUpNet(nn.Module):\n def __init__(self):\n super().__init__()\n self.conv1 = nn.Conv2d(3, 6, 5)\n\n def forward(self, x):\n x = F.relu(self.conv1(x))\n return x\n\n\n\n\nCan be changed to: \n\n\nclass MadeUpNet(nn.Module):\n def __init__(self):\n super().__init__()\n self.conv1 = nn.Conv2d(3, 6, 5)\n self.relu = nn.ReLU(inplace=True)\n\n def forward(self, x):\n x = self.relu(self.conv1(x))\n return x\n\n\n\n\n\n\nWe can only use a module instance once in our models. 
If we use the same module several times, then we can't determine which node in the graph has invoked the callback, because the PyTorch callback signature \ndef hook(module, input, output)\n doesn't provide enough contextual information.\n\nTorchVision's \nResNet\n is an example of a model that uses the same instance of nn.ReLU multiple times: \n\n\n\n\nclass BasicBlock(nn.Module):\n expansion = 1\n def __init__(self, inplanes, planes, stride=1, downsample=None):\n super(BasicBlock, self).__init__()\n self.conv1 = conv3x3(inplanes, planes, stride)\n self.bn1 = nn.BatchNorm2d(planes)\n self.relu = nn.ReLU(inplace=True)\n self.conv2 = conv3x3(planes, planes)\n self.bn2 = nn.BatchNorm2d(planes)\n self.downsample = downsample\n self.stride = stride\n\n def forward(self, x):\n residual = x\n out = self.conv1(x)\n out = self.bn1(out)\n out = self.relu(out) # \n================\n out = self.conv2(out)\n out = self.bn2(out)\n if self.downsample is not None:\n residual = self.downsample(x)\n out += residual\n out = self.relu(out) # \n================\n return out\n\n\n\n\nIn Distiller we changed \nResNet\n to use multiple instances of nn.ReLU, and each instance is used only once: \n\n\nclass BasicBlock(nn.Module):\n expansion = 1\n def __init__(self, inplanes, planes, stride=1, downsample=None):\n super(BasicBlock, self).__init__()\n self.conv1 = conv3x3(inplanes, planes, stride)\n self.bn1 = nn.BatchNorm2d(planes)\n self.relu1 = nn.ReLU(inplace=True)\n self.conv2 = conv3x3(planes, planes)\n self.bn2 = nn.BatchNorm2d(planes)\n self.relu2 = nn.ReLU(inplace=True)\n self.downsample = downsample\n self.stride = stride\n\n def forward(self, x):\n residual = x\n out = self.conv1(x)\n out = self.bn1(out)\n out = self.relu1(out) # \n================\n out = self.conv2(out)\n out = self.bn2(out)\n if self.downsample is not None:\n residual = self.downsample(x)\n out += residual\n out = self.relu2(out) # \n================\n return out\n\n\n\n\nUsing the Jupyter notebooks\n\n\nThe Jupyter notebooks contain many examples of how to use the statistics summaries generated by Distiller. They are explained in a separate page.\n\n\nGenerating this documentation\n\n\nInstall mkdocs and the required packages by executing:\n\n\n$ pip3 install -r doc-requirements.txt\n\n\n\n\nTo build the project documentation run:\n\n\n$ cd distiller/docs-src\n$ mkdocs build --clean\n\n\n\n\nThis will create a folder named 'site' which contains the documentation website.\nOpen distiller/docs/site/index.html to view the documentation home page.", "title": "Usage" - }, + }, { - "location": "/usage/index.html#using-the-sample-application", - "text": "The Distiller repository contains a sample application, distiller/examples/classifier_compression/compress_classifier.py , and a set of scheduling files which demonstrate Distiller's features. Following is a brief discussion of how to use this application and the accompanying schedules. You might also want to refer to the following resources: An explanation of the scheduler file format. An in-depth discussion of how we used these schedule files to implement several state-of-the-art DNN compression research papers. The sample application supports various features for compression of image classification DNNs, and gives an example of how to integrate distiller in your own application. The code is documented and should be considered the best source of documentation, but we provide some elaboration here. 
This diagram shows how where compress_classifier.py fits in the compression workflow, and how we integrate the Jupyter notebooks as part of our research work.", + "location": "/usage/index.html#using-the-sample-application", + "text": "The Distiller repository contains a sample application, distiller/examples/classifier_compression/compress_classifier.py , and a set of scheduling files which demonstrate Distiller's features. Following is a brief discussion of how to use this application and the accompanying schedules. You might also want to refer to the following resources: An explanation of the scheduler file format. An in-depth discussion of how we used these schedule files to implement several state-of-the-art DNN compression research papers. The sample application supports various features for compression of image classification DNNs, and gives an example of how to integrate distiller in your own application. The code is documented and should be considered the best source of documentation, but we provide some elaboration here. This diagram shows how where compress_classifier.py fits in the compression workflow, and how we integrate the Jupyter notebooks as part of our research work.", "title": "Using the sample application" - }, + }, { - "location": "/usage/index.html#command-line-arguments", - "text": "To get help on the command line arguments, invoke: $ python3 compress_classifier.py --help For example: $ time python3 compress_classifier.py -a alexnet --lr 0.005 -p 50 ../../../data.imagenet -j 44 --epochs 90 --pretrained --compress=../sensitivity-pruning/alexnet.schedule_sensitivity.yaml\n\nParameters:\n +----+---------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+\n | | Name | Shape | NNZ (dense) | NNZ (sparse) | Cols (%) | Rows (%) | Ch (%) | 2D (%) | 3D (%) | Fine (%) | Std | Mean | Abs-Mean |\n |----+---------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|\n | 0 | features.module.0.weight | (64, 3, 11, 11) | 23232 | 13411 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 42.27359 | 0.14391 | -0.00002 | 0.08805 |\n | 1 | features.module.3.weight | (192, 64, 5, 5) | 307200 | 115560 | 0.00000 | 0.00000 | 0.00000 | 1.91243 | 0.00000 | 62.38281 | 0.04703 | -0.00250 | 0.02289 |\n | 2 | features.module.6.weight | (384, 192, 3, 3) | 663552 | 256565 | 0.00000 | 0.00000 | 0.00000 | 6.18490 | 0.00000 | 61.33445 | 0.03354 | -0.00184 | 0.01803 |\n | 3 | features.module.8.weight | (256, 384, 3, 3) | 884736 | 315065 | 0.00000 | 0.00000 | 0.00000 | 6.96411 | 0.00000 | 64.38881 | 0.02646 | -0.00168 | 0.01422 |\n | 4 | features.module.10.weight | (256, 256, 3, 3) | 589824 | 186938 | 0.00000 | 0.00000 | 0.00000 | 15.49225 | 0.00000 | 68.30614 | 0.02714 | -0.00246 | 0.01409 |\n | 5 | classifier.1.weight | (4096, 9216) | 37748736 | 3398881 | 0.00000 | 0.21973 | 0.00000 | 0.21973 | 0.00000 | 90.99604 | 0.00589 | -0.00020 | 0.00168 |\n | 6 | classifier.4.weight | (4096, 4096) | 16777216 | 1782769 | 0.21973 | 3.46680 | 0.00000 | 3.46680 | 0.00000 | 89.37387 | 0.00849 | -0.00066 | 0.00263 |\n | 7 | classifier.6.weight | (1000, 4096) | 4096000 | 994738 | 3.36914 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 75.71440 | 0.01718 | 0.00030 | 0.00778 |\n | 8 | Total sparsity: | - | 61090496 | 7063928 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 
88.43694 | 0.00000 | 0.00000 | 0.00000 |\n +----+---------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+\n 2018-04-04 21:30:52,499 - Total sparsity: 88.44\n\n 2018-04-04 21:30:52,499 - --- validate (epoch=89)-----------\n 2018-04-04 21:30:52,499 - 128116 samples (256 per mini-batch)\n 2018-04-04 21:31:04,646 - Epoch: [89][ 50/ 500] Loss 2.175988 Top1 51.289063 Top5 74.023438\n 2018-04-04 21:31:06,427 - Epoch: [89][ 100/ 500] Loss 2.171564 Top1 51.175781 Top5 74.308594\n 2018-04-04 21:31:11,432 - Epoch: [89][ 150/ 500] Loss 2.159347 Top1 51.546875 Top5 74.473958\n 2018-04-04 21:31:14,364 - Epoch: [89][ 200/ 500] Loss 2.156857 Top1 51.585938 Top5 74.568359\n 2018-04-04 21:31:18,381 - Epoch: [89][ 250/ 500] Loss 2.152790 Top1 51.707813 Top5 74.681250\n 2018-04-04 21:31:22,195 - Epoch: [89][ 300/ 500] Loss 2.149962 Top1 51.791667 Top5 74.755208\n 2018-04-04 21:31:25,508 - Epoch: [89][ 350/ 500] Loss 2.150936 Top1 51.827009 Top5 74.767857\n 2018-04-04 21:31:29,538 - Epoch: [89][ 400/ 500] Loss 2.150853 Top1 51.781250 Top5 74.763672\n 2018-04-04 21:31:32,842 - Epoch: [89][ 450/ 500] Loss 2.150156 Top1 51.828125 Top5 74.821181\n 2018-04-04 21:31:35,338 - Epoch: [89][ 500/ 500] Loss 2.150417 Top1 51.833594 Top5 74.817187\n 2018-04-04 21:31:35,357 - ==> Top1: 51.838 Top5: 74.817 Loss: 2.150\n\n 2018-04-04 21:31:35,364 - Saving checkpoint\n 2018-04-04 21:31:39,251 - --- test ---------------------\n 2018-04-04 21:31:39,252 - 50000 samples (256 per mini-batch)\n 2018-04-04 21:31:51,512 - Test: [ 50/ 195] Loss 1.487607 Top1 63.273438 Top5 85.695312\n 2018-04-04 21:31:55,015 - Test: [ 100/ 195] Loss 1.638043 Top1 60.636719 Top5 83.664062\n 2018-04-04 21:31:58,732 - Test: [ 150/ 195] Loss 1.833214 Top1 57.619792 Top5 80.447917\n 2018-04-04 21:32:01,274 - ==> Top1: 56.606 Top5: 79.446 Loss: 1.893 Let's look at the command line again: $ time python3 compress_classifier.py -a alexnet --lr 0.005 -p 50 ../../../data.imagenet -j 44 --epochs 90 --pretrained --compress=../sensitivity-pruning/alexnet.schedule_sensitivity.yaml In this example, we prune a TorchVision pre-trained AlexNet network, using the following configuration: Learning-rate of 0.005 Print progress every 50 mini-batches. Use 44 worker threads to load data (make sure to use something suitable for your machine). Run for 90 epochs. Torchvision's pre-trained models did not store the epoch metadata, so pruning starts at epoch 0. When you train and prune your own networks, the last training epoch is saved as a metadata with the model. Therefore, when you load such models, the first epoch is not 0, but it is the last training epoch. 
The pruning schedule is provided in alexnet.schedule_sensitivity.yaml Log files are written to directory logs .", + "location": "/usage/index.html#command-line-arguments", + "text": "To get help on the command line arguments, invoke: $ python3 compress_classifier.py --help For example: $ time python3 compress_classifier.py -a alexnet --lr 0.005 -p 50 ../../../data.imagenet -j 44 --epochs 90 --pretrained --compress=../sensitivity-pruning/alexnet.schedule_sensitivity.yaml\n\nParameters:\n +----+---------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+\n | | Name | Shape | NNZ (dense) | NNZ (sparse) | Cols (%) | Rows (%) | Ch (%) | 2D (%) | 3D (%) | Fine (%) | Std | Mean | Abs-Mean |\n |----+---------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|\n | 0 | features.module.0.weight | (64, 3, 11, 11) | 23232 | 13411 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 42.27359 | 0.14391 | -0.00002 | 0.08805 |\n | 1 | features.module.3.weight | (192, 64, 5, 5) | 307200 | 115560 | 0.00000 | 0.00000 | 0.00000 | 1.91243 | 0.00000 | 62.38281 | 0.04703 | -0.00250 | 0.02289 |\n | 2 | features.module.6.weight | (384, 192, 3, 3) | 663552 | 256565 | 0.00000 | 0.00000 | 0.00000 | 6.18490 | 0.00000 | 61.33445 | 0.03354 | -0.00184 | 0.01803 |\n | 3 | features.module.8.weight | (256, 384, 3, 3) | 884736 | 315065 | 0.00000 | 0.00000 | 0.00000 | 6.96411 | 0.00000 | 64.38881 | 0.02646 | -0.00168 | 0.01422 |\n | 4 | features.module.10.weight | (256, 256, 3, 3) | 589824 | 186938 | 0.00000 | 0.00000 | 0.00000 | 15.49225 | 0.00000 | 68.30614 | 0.02714 | -0.00246 | 0.01409 |\n | 5 | classifier.1.weight | (4096, 9216) | 37748736 | 3398881 | 0.00000 | 0.21973 | 0.00000 | 0.21973 | 0.00000 | 90.99604 | 0.00589 | -0.00020 | 0.00168 |\n | 6 | classifier.4.weight | (4096, 4096) | 16777216 | 1782769 | 0.21973 | 3.46680 | 0.00000 | 3.46680 | 0.00000 | 89.37387 | 0.00849 | -0.00066 | 0.00263 |\n | 7 | classifier.6.weight | (1000, 4096) | 4096000 | 994738 | 3.36914 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 75.71440 | 0.01718 | 0.00030 | 0.00778 |\n | 8 | Total sparsity: | - | 61090496 | 7063928 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 88.43694 | 0.00000 | 0.00000 | 0.00000 |\n +----+---------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+\n 2018-04-04 21:30:52,499 - Total sparsity: 88.44\n\n 2018-04-04 21:30:52,499 - --- validate (epoch=89)-----------\n 2018-04-04 21:30:52,499 - 128116 samples (256 per mini-batch)\n 2018-04-04 21:31:04,646 - Epoch: [89][ 50/ 500] Loss 2.175988 Top1 51.289063 Top5 74.023438\n 2018-04-04 21:31:06,427 - Epoch: [89][ 100/ 500] Loss 2.171564 Top1 51.175781 Top5 74.308594\n 2018-04-04 21:31:11,432 - Epoch: [89][ 150/ 500] Loss 2.159347 Top1 51.546875 Top5 74.473958\n 2018-04-04 21:31:14,364 - Epoch: [89][ 200/ 500] Loss 2.156857 Top1 51.585938 Top5 74.568359\n 2018-04-04 21:31:18,381 - Epoch: [89][ 250/ 500] Loss 2.152790 Top1 51.707813 Top5 74.681250\n 2018-04-04 21:31:22,195 - Epoch: [89][ 300/ 500] Loss 2.149962 Top1 51.791667 Top5 74.755208\n 2018-04-04 21:31:25,508 - Epoch: [89][ 350/ 500] Loss 2.150936 Top1 51.827009 Top5 74.767857\n 2018-04-04 21:31:29,538 - Epoch: [89][ 400/ 500] Loss 2.150853 Top1 
51.781250 Top5 74.763672\n 2018-04-04 21:31:32,842 - Epoch: [89][ 450/ 500] Loss 2.150156 Top1 51.828125 Top5 74.821181\n 2018-04-04 21:31:35,338 - Epoch: [89][ 500/ 500] Loss 2.150417 Top1 51.833594 Top5 74.817187\n 2018-04-04 21:31:35,357 - == Top1: 51.838 Top5: 74.817 Loss: 2.150\n\n 2018-04-04 21:31:35,364 - Saving checkpoint\n 2018-04-04 21:31:39,251 - --- test ---------------------\n 2018-04-04 21:31:39,252 - 50000 samples (256 per mini-batch)\n 2018-04-04 21:31:51,512 - Test: [ 50/ 195] Loss 1.487607 Top1 63.273438 Top5 85.695312\n 2018-04-04 21:31:55,015 - Test: [ 100/ 195] Loss 1.638043 Top1 60.636719 Top5 83.664062\n 2018-04-04 21:31:58,732 - Test: [ 150/ 195] Loss 1.833214 Top1 57.619792 Top5 80.447917\n 2018-04-04 21:32:01,274 - == Top1: 56.606 Top5: 79.446 Loss: 1.893 Let's look at the command line again: $ time python3 compress_classifier.py -a alexnet --lr 0.005 -p 50 ../../../data.imagenet -j 44 --epochs 90 --pretrained --compress=../sensitivity-pruning/alexnet.schedule_sensitivity.yaml In this example, we prune a TorchVision pre-trained AlexNet network, using the following configuration: Learning-rate of 0.005 Print progress every 50 mini-batches. Use 44 worker threads to load data (make sure to use something suitable for your machine). Run for 90 epochs. Torchvision's pre-trained models did not store the epoch metadata, so pruning starts at epoch 0. When you train and prune your own networks, the last training epoch is saved as a metadata with the model. Therefore, when you load such models, the first epoch is not 0, but it is the last training epoch. The pruning schedule is provided in alexnet.schedule_sensitivity.yaml Log files are written to directory logs .", "title": "Command line arguments" - }, + }, { - "location": "/usage/index.html#examples", - "text": "Distiller comes with several example schedules which can be used together with compress_classifier.py .\nThese example schedules (YAML) files, contain the command line that is used in order to invoke the schedule (so that you can easily recreate the results in your environment), together with the results of the pruning or regularization. The results usually contain a table showing the sparsity of each of the model parameters, together with the validation and test top1, top5 and loss scores. For more details on the example schedules, you can refer to the coverage of the Model Zoo . examples/agp-pruning : Automated Gradual Pruning (AGP) on MobileNet and ResNet18 (ImageNet dataset) examples/hybrid : AlexNet AGP with 2D (kernel) regularization (ImageNet dataset) AlexNet sensitivity pruning with 2D regularization examples/network_slimming : ResNet20 Network Slimming (this is work-in-progress) examples/pruning_filters_for_efficient_convnets : ResNet56 baseline training (CIFAR10 dataset) ResNet56 filter removal using filter ranking examples/sensitivity_analysis : Element-wise pruning sensitivity-analysis: AlexNet (ImageNet) MobileNet (ImageNet) ResNet18 (ImageNet) ResNet20 (CIFAR10) ResNet34 (ImageNet) Filter-wise pruning sensitivity-analysis: ResNet20 (CIFAR10) ResNet56 (CIFAR10) examples/sensitivity-pruning : AlexNet sensitivity pruning with Iterative Pruning AlexNet sensitivity pruning with One-Shot Pruning examples/ssl : ResNet20 baseline training (CIFAR10 dataset) Structured Sparsity Learning (SSL) with layer removal on ResNet20 SSL with channels removal on ResNet20 examples/quantization : AlexNet w. 
Batch-Norm (base FP32 + DoReFa) Pre-activation ResNet20 on CIFAR10 (base FP32 + DoReFa) Pre-activation ResNet18 on ImageNEt (base FP32 + DoReFa)", + "location": "/usage/index.html#examples", + "text": "Distiller comes with several example schedules which can be used together with compress_classifier.py .\nThese example schedules (YAML) files, contain the command line that is used in order to invoke the schedule (so that you can easily recreate the results in your environment), together with the results of the pruning or regularization. The results usually contain a table showing the sparsity of each of the model parameters, together with the validation and test top1, top5 and loss scores. For more details on the example schedules, you can refer to the coverage of the Model Zoo . examples/agp-pruning : Automated Gradual Pruning (AGP) on MobileNet and ResNet18 (ImageNet dataset) examples/hybrid : AlexNet AGP with 2D (kernel) regularization (ImageNet dataset) AlexNet sensitivity pruning with 2D regularization examples/network_slimming : ResNet20 Network Slimming (this is work-in-progress) examples/pruning_filters_for_efficient_convnets : ResNet56 baseline training (CIFAR10 dataset) ResNet56 filter removal using filter ranking examples/sensitivity_analysis : Element-wise pruning sensitivity-analysis: AlexNet (ImageNet) MobileNet (ImageNet) ResNet18 (ImageNet) ResNet20 (CIFAR10) ResNet34 (ImageNet) Filter-wise pruning sensitivity-analysis: ResNet20 (CIFAR10) ResNet56 (CIFAR10) examples/sensitivity-pruning : AlexNet sensitivity pruning with Iterative Pruning AlexNet sensitivity pruning with One-Shot Pruning examples/ssl : ResNet20 baseline training (CIFAR10 dataset) Structured Sparsity Learning (SSL) with layer removal on ResNet20 SSL with channels removal on ResNet20 examples/quantization : AlexNet w. Batch-Norm (base FP32 + DoReFa) Pre-activation ResNet20 on CIFAR10 (base FP32 + DoReFa) Pre-activation ResNet18 on ImageNEt (base FP32 + DoReFa)", "title": "Examples" - }, + }, { - "location": "/usage/index.html#experiment-reproducibility", - "text": "Experiment reproducibility is sometimes important. Pete Warden recently expounded about this in his blog . \nPyTorch's support for deterministic execution requires us to use only one thread for loading data (other wise the multi-threaded execution of the data loaders can create random order and change the results), and to set the seed of the CPU and GPU PRNGs. Using the --deterministic command-line flag and setting j=1 will produce reproducible results (for the same PyTorch version).", + "location": "/usage/index.html#experiment-reproducibility", + "text": "Experiment reproducibility is sometimes important. Pete Warden recently expounded about this in his blog . \nPyTorch's support for deterministic execution requires us to use only one thread for loading data (other wise the multi-threaded execution of the data loaders can create random order and change the results), and to set the seed of the CPU and GPU PRNGs. Using the --deterministic command-line flag and setting j=1 will produce reproducible results (for the same PyTorch version).", "title": "Experiment reproducibility" - }, + }, { - "location": "/usage/index.html#performing-pruning-sensitivity-analysis", - "text": "Distiller supports element-wise and filter-wise pruning sensitivity analysis. In both cases, L1-norm is used to rank which elements or filters to prune. 
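As a toy illustration of that ranking step (the idea only, not Distiller's own pruner code), the filters of a convolution weight tensor can be ranked by their L1-norm and the weakest fraction zeroed out:

```python
import torch


def zero_weakest_filters(weight, fraction):
    """Zero the `fraction` of output filters with the smallest L1-norm.

    `weight` is a Conv2d weight tensor of shape (out_channels, in_channels, kH, kW).
    Toy illustration only -- the real ranking lives in Distiller's pruners.
    """
    num_filters = weight.size(0)
    l1_norms = weight.abs().view(num_filters, -1).sum(dim=1)
    num_to_prune = int(fraction * num_filters)
    if num_to_prune == 0:
        return weight
    _, weakest = torch.topk(l1_norms, num_to_prune, largest=False)
    pruned = weight.clone()
    pruned[weakest] = 0
    return pruned


# Example: zero the bottom 30% of filters of a random 3x3 convolution weight.
w = torch.randn(16, 8, 3, 3)
print(zero_weakest_filters(w, 0.3).abs().view(16, -1).sum(dim=1))
```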
For example, when running filter-pruning sensitivity analysis, the L1-norm of the filters of each layer's weights tensor are calculated, and the bottom x% are set to zero. \nThe analysis process is quite long, because currently we use the entire test dataset to assess the accuracy performance at each pruning level of each weights tensor. Using a small dataset for this would save much time and we plan on assessing if this will provide sufficient results. \nResults are output as a CSV file ( sensitivity.csv ) and PNG file ( sensitivity.png ). The implementation is in distiller/sensitivity.py and it contains further details about process and the format of the CSV file. The example below performs element-wise pruning sensitivity analysis on ResNet20 for CIFAR10: $ python3 compress_classifier.py -a resnet20_cifar ../../../data.cifar10/ -j=1 --resume=../cifar10/resnet20/checkpoint_trained_dense.pth.tar --sense=element The sense command-line argument can be set to either element or filter , depending on the type of analysis you want done. There is also a Jupyter notebook with example invocations, outputs and explanations.", + "location": "/usage/index.html#performing-pruning-sensitivity-analysis", + "text": "Distiller supports element-wise and filter-wise pruning sensitivity analysis. In both cases, L1-norm is used to rank which elements or filters to prune. For example, when running filter-pruning sensitivity analysis, the L1-norm of the filters of each layer's weights tensor are calculated, and the bottom x% are set to zero. \nThe analysis process is quite long, because currently we use the entire test dataset to assess the accuracy performance at each pruning level of each weights tensor. Using a small dataset for this would save much time and we plan on assessing if this will provide sufficient results. \nResults are output as a CSV file ( sensitivity.csv ) and PNG file ( sensitivity.png ). The implementation is in distiller/sensitivity.py and it contains further details about process and the format of the CSV file. The example below performs element-wise pruning sensitivity analysis on ResNet20 for CIFAR10: $ python3 compress_classifier.py -a resnet20_cifar ../../../data.cifar10/ -j=1 --resume=../cifar10/resnet20/checkpoint_trained_dense.pth.tar --sense=element The sense command-line argument can be set to either element or filter , depending on the type of analysis you want done. There is also a Jupyter notebook with example invocations, outputs and explanations.", "title": "Performing pruning sensitivity analysis" - }, + }, { - "location": "/usage/index.html#direct-quantization-without-training", - "text": "Distiller supports 8-bit quantization of trained modules without re-training (using Symmetric Linear Quantization ). So, any model (whether pruned or not) can be quantized. \nUse the --quantize command-line flag, together with --evaluate to evaluate the accuracy of your model after quantization. 
The following example qunatizes ResNet18 for ImageNet: $ python3 compress_classifier.py -a resnet18 ../../../data.imagenet --pretrained --quantize --evaluate Generates: Preparing model for quantization\n--- test ---------------------\n50000 samples (256 per mini-batch)\nTest: [ 10/ 195] Loss 0.856354 Top1 79.257812 Top5 92.500000\nTest: [ 20/ 195] Loss 0.923131 Top1 76.953125 Top5 92.246094\nTest: [ 30/ 195] Loss 0.885186 Top1 77.955729 Top5 92.486979\nTest: [ 40/ 195] Loss 0.930263 Top1 76.181641 Top5 92.597656\nTest: [ 50/ 195] Loss 0.931062 Top1 75.726562 Top5 92.906250\nTest: [ 60/ 195] Loss 0.932019 Top1 75.651042 Top5 93.151042\nTest: [ 70/ 195] Loss 0.921287 Top1 76.060268 Top5 93.270089\nTest: [ 80/ 195] Loss 0.932539 Top1 75.986328 Top5 93.100586\nTest: [ 90/ 195] Loss 0.996000 Top1 74.700521 Top5 92.330729\nTest: [ 100/ 195] Loss 1.066699 Top1 73.289062 Top5 91.437500\nTest: [ 110/ 195] Loss 1.100970 Top1 72.574574 Top5 91.001420\nTest: [ 120/ 195] Loss 1.122376 Top1 72.268880 Top5 90.696615\nTest: [ 130/ 195] Loss 1.171726 Top1 71.198918 Top5 90.120192\nTest: [ 140/ 195] Loss 1.191500 Top1 70.797991 Top5 89.902344\nTest: [ 150/ 195] Loss 1.219954 Top1 70.210938 Top5 89.453125\nTest: [ 160/ 195] Loss 1.240942 Top1 69.855957 Top5 89.162598\nTest: [ 170/ 195] Loss 1.265741 Top1 69.342831 Top5 88.807445\nTest: [ 180/ 195] Loss 1.281185 Top1 69.051649 Top5 88.589410\nTest: [ 190/ 195] Loss 1.279682 Top1 69.019326 Top5 88.632812\n==> Top1: 69.130 Top5: 88.732 Loss: 1.276", - "title": "\"Direct\" Quantization Without Training" - }, + "location": "/usage/index.html#post-training-quantization", + "text": "Distiller supports post-training quantization of trained modules without re-training (using Range-Based Linear Quantization ). So, any model (whether pruned or not) can be quantized. To invoke post-training quantization, use --quantize-eval along with --evaluate . Additional arguments are available to control parameters of the quantization: Arguments controlling quantization at evaluation time( post-training quantization ):\n --quantize-eval, --qe\n Apply linear quantization to model before evaluation.\n Applicable only if --evaluate is also set\n --qe-mode QE_MODE, --qem QE_MODE\n Linear quantization mode. Choices: asym_s | asym_u |\n sym\n --qe-bits-acts NUM_BITS, --qeba NUM_BITS\n Number of bits for quantization of activations\n --qe-bits-wts NUM_BITS, --qebw NUM_BITS\n Number of bits for quantization of weights\n --qe-bits-accum NUM_BITS\n Number of bits for quantization of the accumulator\n --qe-clip-acts, --qeca\n Enable clipping of activations using min/max values\n averaging over batch\n --qe-no-clip-layers LAYER_NAME [LAYER_NAME ...], --qencl LAYER_NAME [LAYER_NAME ...]\n List of fully-qualified layer names for which not to\n clip activations. Applicable only if --qe-clip-acts is\n also set\n --qe-per-channel, --qepc\n Enable per-channel quantization of weights (per output channel) The following example qunatizes ResNet18 for ImageNet: $ python3 compress_classifier.py -a resnet18 ../../../data.imagenet --pretrained --quantize-eval --evaluate A checkpoint with the quantized model will be dumped in the run directory. It will contain the quantized model parameters (the data type will still be FP32, but the values will be integers). The calculated quantization parameters (scale and zero-point) are stored as well in each quantized layer. 
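For intuition, asymmetric range-based linear quantization derives those two parameters from the observed min/max range of a tensor, roughly as sketched below (a generic illustration of the idea, not Distiller's exact implementation):

```python
import torch


def asymmetric_quant_params(tensor, num_bits=8):
    """Derive scale and zero-point from a tensor's min/max range (generic sketch)."""
    qmin, qmax = 0, 2 ** num_bits - 1
    t_min, t_max = tensor.min().item(), tensor.max().item()
    t_min, t_max = min(t_min, 0.0), max(t_max, 0.0)  # ensure zero is representable
    scale = (qmax - qmin) / (t_max - t_min)
    zero_point = round(qmin - t_min * scale)
    return scale, zero_point


def quantize(tensor, scale, zero_point, num_bits=8):
    q = torch.round(tensor * scale + zero_point)
    return q.clamp(0, 2 ** num_bits - 1)  # integer values, still stored as floats


x = torch.randn(4, 4)
scale, zp = asymmetric_quant_params(x)
print(quantize(x, scale, zp))
```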
For more examples of post-training quantization see here", + "title": "Post-Training Quantization" + }, { - "location": "/usage/index.html#summaries", - "text": "You can use the sample compression application to generate model summary reports, such as the attributes and compute summary report (see screen capture below).\nYou can log sparsity statistics (written to console and CSV file), performance, optimizer and model information, and also create a PNG image of the DNN.\nCreating a PNG image is an experimental feature (it relies on features which are not available on PyTorch 3.1 and that we hope will be available in PyTorch's next release), so to use it you will need to compile the PyTorch master branch, and hope for the best ;-). $ python3 compress_classifier.py --resume=../ssl/checkpoints/checkpoint_trained_ch_regularized_dense.pth.tar -a=resnet20_cifar ../../../data.cifar10 --summary=compute Generates: +----+------------------------------+--------+----------+-----------------+--------------+-----------------+--------------+------------------+---------+\n| | Name | Type | Attrs | IFM | IFM volume | OFM | OFM volume | Weights volume | MACs |\n|----+------------------------------+--------+----------+-----------------+--------------+-----------------+--------------+------------------+---------|\n| 0 | module.conv1 | Conv2d | k=(3, 3) | (1, 3, 32, 32) | 3072 | (1, 16, 32, 32) | 16384 | 432 | 442368 |\n| 1 | module.layer1.0.conv1 | Conv2d | k=(3, 3) | (1, 16, 32, 32) | 16384 | (1, 16, 32, 32) | 16384 | 2304 | 2359296 |\n| 2 | module.layer1.0.conv2 | Conv2d | k=(3, 3) | (1, 16, 32, 32) | 16384 | (1, 16, 32, 32) | 16384 | 2304 | 2359296 |\n| 3 | module.layer1.1.conv1 | Conv2d | k=(3, 3) | (1, 16, 32, 32) | 16384 | (1, 16, 32, 32) | 16384 | 2304 | 2359296 |\n| 4 | module.layer1.1.conv2 | Conv2d | k=(3, 3) | (1, 16, 32, 32) | 16384 | (1, 16, 32, 32) | 16384 | 2304 | 2359296 |\n| 5 | module.layer1.2.conv1 | Conv2d | k=(3, 3) | (1, 16, 32, 32) | 16384 | (1, 16, 32, 32) | 16384 | 2304 | 2359296 |\n| 6 | module.layer1.2.conv2 | Conv2d | k=(3, 3) | (1, 16, 32, 32) | 16384 | (1, 16, 32, 32) | 16384 | 2304 | 2359296 |\n| 7 | module.layer2.0.conv1 | Conv2d | k=(3, 3) | (1, 16, 32, 32) | 16384 | (1, 32, 16, 16) | 8192 | 4608 | 1179648 |\n| 8 | module.layer2.0.conv2 | Conv2d | k=(3, 3) | (1, 32, 16, 16) | 8192 | (1, 32, 16, 16) | 8192 | 9216 | 2359296 |\n| 9 | module.layer2.0.downsample.0 | Conv2d | k=(1, 1) | (1, 16, 32, 32) | 16384 | (1, 32, 16, 16) | 8192 | 512 | 131072 |\n| 10 | module.layer2.1.conv1 | Conv2d | k=(3, 3) | (1, 32, 16, 16) | 8192 | (1, 32, 16, 16) | 8192 | 9216 | 2359296 |\n| 11 | module.layer2.1.conv2 | Conv2d | k=(3, 3) | (1, 32, 16, 16) | 8192 | (1, 32, 16, 16) | 8192 | 9216 | 2359296 |\n| 12 | module.layer2.2.conv1 | Conv2d | k=(3, 3) | (1, 32, 16, 16) | 8192 | (1, 32, 16, 16) | 8192 | 9216 | 2359296 |\n| 13 | module.layer2.2.conv2 | Conv2d | k=(3, 3) | (1, 32, 16, 16) | 8192 | (1, 32, 16, 16) | 8192 | 9216 | 2359296 |\n| 14 | module.layer3.0.conv1 | Conv2d | k=(3, 3) | (1, 32, 16, 16) | 8192 | (1, 64, 8, 8) | 4096 | 18432 | 1179648 |\n| 15 | module.layer3.0.conv2 | Conv2d | k=(3, 3) | (1, 64, 8, 8) | 4096 | (1, 64, 8, 8) | 4096 | 36864 | 2359296 |\n| 16 | module.layer3.0.downsample.0 | Conv2d | k=(1, 1) | (1, 32, 16, 16) | 8192 | (1, 64, 8, 8) | 4096 | 2048 | 131072 |\n| 17 | module.layer3.1.conv1 | Conv2d | k=(3, 3) | (1, 64, 8, 8) | 4096 | (1, 64, 8, 8) | 4096 | 36864 | 2359296 |\n| 18 | module.layer3.1.conv2 | Conv2d | k=(3, 3) | (1, 64, 8, 8) | 4096 | (1, 64, 8, 8) | 4096 | 
36864 | 2359296 |\n| 19 | module.layer3.2.conv1 | Conv2d | k=(3, 3) | (1, 64, 8, 8) | 4096 | (1, 64, 8, 8) | 4096 | 36864 | 2359296 |\n| 20 | module.layer3.2.conv2 | Conv2d | k=(3, 3) | (1, 64, 8, 8) | 4096 | (1, 64, 8, 8) | 4096 | 36864 | 2359296 |\n| 21 | module.fc | Linear | | (1, 64) | 64 | (1, 10) | 10 | 640 | 640 |\n+----+------------------------------+--------+----------+-----------------+--------------+-----------------+--------------+------------------+---------+\nTotal MACs: 40,813,184", + "location": "/usage/index.html#summaries", + "text": "You can use the sample compression application to generate model summary reports, such as the attributes and compute summary report (see screen capture below).\nYou can log sparsity statistics (written to console and CSV file), performance, optimizer and model information, and also create a PNG image of the DNN.\nCreating a PNG image is an experimental feature (it relies on features which are not available on PyTorch 3.1 and that we hope will be available in PyTorch's next release), so to use it you will need to compile the PyTorch master branch, and hope for the best ;-). $ python3 compress_classifier.py --resume=../ssl/checkpoints/checkpoint_trained_ch_regularized_dense.pth.tar -a=resnet20_cifar ../../../data.cifar10 --summary=compute Generates: +----+------------------------------+--------+----------+-----------------+--------------+-----------------+--------------+------------------+---------+\n| | Name | Type | Attrs | IFM | IFM volume | OFM | OFM volume | Weights volume | MACs |\n|----+------------------------------+--------+----------+-----------------+--------------+-----------------+--------------+------------------+---------|\n| 0 | module.conv1 | Conv2d | k=(3, 3) | (1, 3, 32, 32) | 3072 | (1, 16, 32, 32) | 16384 | 432 | 442368 |\n| 1 | module.layer1.0.conv1 | Conv2d | k=(3, 3) | (1, 16, 32, 32) | 16384 | (1, 16, 32, 32) | 16384 | 2304 | 2359296 |\n| 2 | module.layer1.0.conv2 | Conv2d | k=(3, 3) | (1, 16, 32, 32) | 16384 | (1, 16, 32, 32) | 16384 | 2304 | 2359296 |\n| 3 | module.layer1.1.conv1 | Conv2d | k=(3, 3) | (1, 16, 32, 32) | 16384 | (1, 16, 32, 32) | 16384 | 2304 | 2359296 |\n| 4 | module.layer1.1.conv2 | Conv2d | k=(3, 3) | (1, 16, 32, 32) | 16384 | (1, 16, 32, 32) | 16384 | 2304 | 2359296 |\n| 5 | module.layer1.2.conv1 | Conv2d | k=(3, 3) | (1, 16, 32, 32) | 16384 | (1, 16, 32, 32) | 16384 | 2304 | 2359296 |\n| 6 | module.layer1.2.conv2 | Conv2d | k=(3, 3) | (1, 16, 32, 32) | 16384 | (1, 16, 32, 32) | 16384 | 2304 | 2359296 |\n| 7 | module.layer2.0.conv1 | Conv2d | k=(3, 3) | (1, 16, 32, 32) | 16384 | (1, 32, 16, 16) | 8192 | 4608 | 1179648 |\n| 8 | module.layer2.0.conv2 | Conv2d | k=(3, 3) | (1, 32, 16, 16) | 8192 | (1, 32, 16, 16) | 8192 | 9216 | 2359296 |\n| 9 | module.layer2.0.downsample.0 | Conv2d | k=(1, 1) | (1, 16, 32, 32) | 16384 | (1, 32, 16, 16) | 8192 | 512 | 131072 |\n| 10 | module.layer2.1.conv1 | Conv2d | k=(3, 3) | (1, 32, 16, 16) | 8192 | (1, 32, 16, 16) | 8192 | 9216 | 2359296 |\n| 11 | module.layer2.1.conv2 | Conv2d | k=(3, 3) | (1, 32, 16, 16) | 8192 | (1, 32, 16, 16) | 8192 | 9216 | 2359296 |\n| 12 | module.layer2.2.conv1 | Conv2d | k=(3, 3) | (1, 32, 16, 16) | 8192 | (1, 32, 16, 16) | 8192 | 9216 | 2359296 |\n| 13 | module.layer2.2.conv2 | Conv2d | k=(3, 3) | (1, 32, 16, 16) | 8192 | (1, 32, 16, 16) | 8192 | 9216 | 2359296 |\n| 14 | module.layer3.0.conv1 | Conv2d | k=(3, 3) | (1, 32, 16, 16) | 8192 | (1, 64, 8, 8) | 4096 | 18432 | 1179648 |\n| 15 | module.layer3.0.conv2 | Conv2d | k=(3, 3) | (1, 64, 
8, 8) | 4096 | (1, 64, 8, 8) | 4096 | 36864 | 2359296 |\n| 16 | module.layer3.0.downsample.0 | Conv2d | k=(1, 1) | (1, 32, 16, 16) | 8192 | (1, 64, 8, 8) | 4096 | 2048 | 131072 |\n| 17 | module.layer3.1.conv1 | Conv2d | k=(3, 3) | (1, 64, 8, 8) | 4096 | (1, 64, 8, 8) | 4096 | 36864 | 2359296 |\n| 18 | module.layer3.1.conv2 | Conv2d | k=(3, 3) | (1, 64, 8, 8) | 4096 | (1, 64, 8, 8) | 4096 | 36864 | 2359296 |\n| 19 | module.layer3.2.conv1 | Conv2d | k=(3, 3) | (1, 64, 8, 8) | 4096 | (1, 64, 8, 8) | 4096 | 36864 | 2359296 |\n| 20 | module.layer3.2.conv2 | Conv2d | k=(3, 3) | (1, 64, 8, 8) | 4096 | (1, 64, 8, 8) | 4096 | 36864 | 2359296 |\n| 21 | module.fc | Linear | | (1, 64) | 64 | (1, 10) | 10 | 640 | 640 |\n+----+------------------------------+--------+----------+-----------------+--------------+-----------------+--------------+------------------+---------+\nTotal MACs: 40,813,184", "title": "Summaries" - }, + }, { - "location": "/usage/index.html#using-tensorboard", - "text": "Google's TensorBoard is an excellent tool for visualizing the progress of DNN training. Distiller's logger supports writing performance indicators and parameter statistics in a file format that can be read by TensorBoard (Distiller uses TensorFlow's APIs in order to do this, which is why Distiller requires the installation of TensorFlow). \nTo view the graphs, invoke the TensorBoard server. For example: $ tensorboard --logdir=logs Distillers's setup (requirements.txt) installs TensorFlow for CPU. If you want a different installation, please follow the TensorFlow installation instructions .", + "location": "/usage/index.html#using-tensorboard", + "text": "Google's TensorBoard is an excellent tool for visualizing the progress of DNN training. Distiller's logger supports writing performance indicators and parameter statistics in a file format that can be read by TensorBoard (Distiller uses TensorFlow's APIs in order to do this, which is why Distiller requires the installation of TensorFlow). \nTo view the graphs, invoke the TensorBoard server. For example: $ tensorboard --logdir=logs Distillers's setup (requirements.txt) installs TensorFlow for CPU. If you want a different installation, please follow the TensorFlow installation instructions .", "title": "Using TensorBoard" - }, + }, { - "location": "/usage/index.html#collecting-activations-statistics", - "text": "In CNNs with ReLU layers, ReLU activations (feature-maps) also exhibit a nice level of sparsity (50-60% sparsity is typical). \nYou can collect activation statistics using the --act_stats command-line flag. \nFor example: $ python3 compress_classifier.py -a=resnet56_cifar -p=50 ../../../data.cifar10 --resume=checkpoint.resnet56_cifar_baseline.pth.tar --act-stats=test -e The test parameter indicates that, in this example, we want to collect activation statistics during the test phase. Note that we also used the -e command-line argument to indicate that we want to run a test phase. The other two legal parameter values are train and valid which collect activation statistics during the training and validation phases, respectively.", + "location": "/usage/index.html#collecting-activations-statistics", + "text": "In CNNs with ReLU layers, ReLU activations (feature-maps) also exhibit a nice level of sparsity (50-60% sparsity is typical). \nYou can collect activation statistics using the --act_stats command-line flag. 
\nFor example: $ python3 compress_classifier.py -a=resnet56_cifar -p=50 ../../../data.cifar10 --resume=checkpoint.resnet56_cifar_baseline.pth.tar --act-stats=test -e The test parameter indicates that, in this example, we want to collect activation statistics during the test phase. Note that we also used the -e command-line argument to indicate that we want to run a test phase. The other two legal parameter values are train and valid which collect activation statistics during the training and validation phases, respectively.", "title": "Collecting activations statistics" - }, + }, { - "location": "/usage/index.html#collectors-and-their-collaterals", - "text": "An instance of a subclass of ActivationStatsCollector can be used to collect activation statistics. Currently, ActivationStatsCollector has two types of subclasses: SummaryActivationStatsCollector and RecordsActivationStatsCollector . \nInstances of SummaryActivationStatsCollector compute the mean of some statistic of the activation. It is rather\nlight-weight and quicker than collecting a record per activation. The statistic function is configured in the constructor. \nIn the sample compression application, compress_classifier.py , we create a dictionary of collectors. For example: SummaryActivationStatsCollector(model,\n \"sparsity\",\n lambda t: 100 * distiller.utils.sparsity(t)) The lambda expression is invoked per activation encountered during forward passes, and the value it returns (in this case, the sparsity of the activation tensors, multiplied by 100) is stored in module.sparsity ( \"sparsity\" is this collector's name). To access the statistics, you can invoke collector.value() , or you can access each module's data directly. Another type of collector is RecordsActivationStatsCollector which computes a hard-coded set of activations statistics and collects a record per activation . For obvious reasons, this is slower than instances of SummaryActivationStatsCollector . ActivationStatsCollector default to collecting activations statistics only on the output activations of ReLU layers, but we can choose any layer type we want. In the example below we collect statistics from outputs of torch.nn.Conv2d layers. RecordsActivationStatsCollector(model, classes=[torch.nn.Conv2d]) Collectors can write their data to Excel workbooks (which are named using the collector's name), by invoking collector.to_xlsx(path_to_workbook) . In compress_classifier.py we currently create four different collectors which you can selectively disable. You can also add other statistics collectors and use a different function to compute your new statistic. collectors = missingdict({\n \"sparsity\": SummaryActivationStatsCollector(model, \"sparsity\",\n lambda t: 100 * distiller.utils.sparsity(t)),\n \"l1_channels\": SummaryActivationStatsCollector(model, \"l1_channels\",\n distiller.utils.activation_channels_l1),\n \"apoz_channels\": SummaryActivationStatsCollector(model, \"apoz_channels\",\n distiller.utils.activation_channels_apoz),\n \"records\": RecordsActivationStatsCollector(model, classes=[torch.nn.Conv2d])}) By default, these Collectors write their data to files in the active log directory. You can use a utility function, distiller.log_activation_statsitics , to log the data of an ActivationStatsCollector instance to one of the backend-loggers. For an example, the code below logs the \"sparsity\" collector to a TensorBoard log file. 
distiller.log_activation_statsitics(epoch, \"train\", loggers=[tflogger],\n collector=collectors[\"sparsity\"])", + "location": "/usage/index.html#collectors-and-their-collaterals", + "text": "An instance of a subclass of ActivationStatsCollector can be used to collect activation statistics. Currently, ActivationStatsCollector has two types of subclasses: SummaryActivationStatsCollector and RecordsActivationStatsCollector . \nInstances of SummaryActivationStatsCollector compute the mean of some statistic of the activation. It is rather\nlight-weight and quicker than collecting a record per activation. The statistic function is configured in the constructor. \nIn the sample compression application, compress_classifier.py , we create a dictionary of collectors. For example: SummaryActivationStatsCollector(model,\n sparsity ,\n lambda t: 100 * distiller.utils.sparsity(t)) The lambda expression is invoked per activation encountered during forward passes, and the value it returns (in this case, the sparsity of the activation tensors, multiplied by 100) is stored in module.sparsity ( \"sparsity\" is this collector's name). To access the statistics, you can invoke collector.value() , or you can access each module's data directly. Another type of collector is RecordsActivationStatsCollector which computes a hard-coded set of activations statistics and collects a record per activation . For obvious reasons, this is slower than instances of SummaryActivationStatsCollector . ActivationStatsCollector default to collecting activations statistics only on the output activations of ReLU layers, but we can choose any layer type we want. In the example below we collect statistics from outputs of torch.nn.Conv2d layers. RecordsActivationStatsCollector(model, classes=[torch.nn.Conv2d]) Collectors can write their data to Excel workbooks (which are named using the collector's name), by invoking collector.to_xlsx(path_to_workbook) . In compress_classifier.py we currently create four different collectors which you can selectively disable. You can also add other statistics collectors and use a different function to compute your new statistic. collectors = missingdict({\n sparsity : SummaryActivationStatsCollector(model, sparsity ,\n lambda t: 100 * distiller.utils.sparsity(t)),\n l1_channels : SummaryActivationStatsCollector(model, l1_channels ,\n distiller.utils.activation_channels_l1),\n apoz_channels : SummaryActivationStatsCollector(model, apoz_channels ,\n distiller.utils.activation_channels_apoz),\n records : RecordsActivationStatsCollector(model, classes=[torch.nn.Conv2d])}) By default, these Collectors write their data to files in the active log directory. You can use a utility function, distiller.log_activation_statsitics , to log the data of an ActivationStatsCollector instance to one of the backend-loggers. For an example, the code below logs the \"sparsity\" collector to a TensorBoard log file. distiller.log_activation_statsitics(epoch, train , loggers=[tflogger],\n collector=collectors[ sparsity ])", "title": "Collectors and their collaterals" - }, + }, { - "location": "/usage/index.html#caveats", - "text": "Distiller collects activations statistics using PyTorch's forward-hooks mechanism. Collectors iteratively register the modules' forward-hooks, and collectors are called during the forward traversal and get exposed to activation data. 
Registering for forward callbacks is performed like this: module.register_forward_hook This makes apparent two limitations of this mechanism: We can only register on PyTorch modules. This means that we can't register on the forward hook of a functionals such as torch.nn.functional.relu and torch.nn.functional.max_pool2d . \n Therefore, you may need to replace functionals with their module alternative. For example: class MadeUpNet(nn.Module):\n def __init__(self):\n super().__init__()\n self.conv1 = nn.Conv2d(3, 6, 5)\n\n def forward(self, x):\n x = F.relu(self.conv1(x))\n return x Can be changed to: class MadeUpNet(nn.Module):\n def __init__(self):\n super().__init__()\n self.conv1 = nn.Conv2d(3, 6, 5)\n self.relu = nn.ReLU(inplace=True)\n\n def forward(self, x):\n x = self.relu(self.conv1(x))\n return x We can only use a module instance once in our models. If we use the same module several times, then we can't determine which node in the graph has invoked the callback, because the PyTorch callback signature def hook(module, input, output) doesn't provide enough contextual information. \nTorchVision's ResNet is an example of a model that uses the same instance of nn.ReLU multiple times: class BasicBlock(nn.Module):\n expansion = 1\n def __init__(self, inplanes, planes, stride=1, downsample=None):\n super(BasicBlock, self).__init__()\n self.conv1 = conv3x3(inplanes, planes, stride)\n self.bn1 = nn.BatchNorm2d(planes)\n self.relu = nn.ReLU(inplace=True)\n self.conv2 = conv3x3(planes, planes)\n self.bn2 = nn.BatchNorm2d(planes)\n self.downsample = downsample\n self.stride = stride\n\n def forward(self, x):\n residual = x\n out = self.conv1(x)\n out = self.bn1(out)\n out = self.relu(out) # <================\n out = self.conv2(out)\n out = self.bn2(out)\n if self.downsample is not None:\n residual = self.downsample(x)\n out += residual\n out = self.relu(out) # <================\n return out In Distiller we changed ResNet to use multiple instances of nn.ReLU, and each instance is used only once: class BasicBlock(nn.Module):\n expansion = 1\n def __init__(self, inplanes, planes, stride=1, downsample=None):\n super(BasicBlock, self).__init__()\n self.conv1 = conv3x3(inplanes, planes, stride)\n self.bn1 = nn.BatchNorm2d(planes)\n self.relu1 = nn.ReLU(inplace=True)\n self.conv2 = conv3x3(planes, planes)\n self.bn2 = nn.BatchNorm2d(planes)\n self.relu2 = nn.ReLU(inplace=True)\n self.downsample = downsample\n self.stride = stride\n\n def forward(self, x):\n residual = x\n out = self.conv1(x)\n out = self.bn1(out)\n out = self.relu1(out) # <================\n out = self.conv2(out)\n out = self.bn2(out)\n if self.downsample is not None:\n residual = self.downsample(x)\n out += residual\n out = self.relu2(out) # <================\n return out", + "location": "/usage/index.html#caveats", + "text": "Distiller collects activations statistics using PyTorch's forward-hooks mechanism. Collectors iteratively register the modules' forward-hooks, and collectors are called during the forward traversal and get exposed to activation data. Registering for forward callbacks is performed like this: module.register_forward_hook This makes apparent two limitations of this mechanism: We can only register on PyTorch modules. This means that we can't register on the forward hook of a functionals such as torch.nn.functional.relu and torch.nn.functional.max_pool2d . \n Therefore, you may need to replace functionals with their module alternative. 
For example: class MadeUpNet(nn.Module):\n def __init__(self):\n super().__init__()\n self.conv1 = nn.Conv2d(3, 6, 5)\n\n def forward(self, x):\n x = F.relu(self.conv1(x))\n return x Can be changed to: class MadeUpNet(nn.Module):\n def __init__(self):\n super().__init__()\n self.conv1 = nn.Conv2d(3, 6, 5)\n self.relu = nn.ReLU(inplace=True)\n\n def forward(self, x):\n x = self.relu(self.conv1(x))\n return x We can only use a module instance once in our models. If we use the same module several times, then we can't determine which node in the graph has invoked the callback, because the PyTorch callback signature def hook(module, input, output) doesn't provide enough contextual information. \nTorchVision's ResNet is an example of a model that uses the same instance of nn.ReLU multiple times: class BasicBlock(nn.Module):\n expansion = 1\n def __init__(self, inplanes, planes, stride=1, downsample=None):\n super(BasicBlock, self).__init__()\n self.conv1 = conv3x3(inplanes, planes, stride)\n self.bn1 = nn.BatchNorm2d(planes)\n self.relu = nn.ReLU(inplace=True)\n self.conv2 = conv3x3(planes, planes)\n self.bn2 = nn.BatchNorm2d(planes)\n self.downsample = downsample\n self.stride = stride\n\n def forward(self, x):\n residual = x\n out = self.conv1(x)\n out = self.bn1(out)\n out = self.relu(out) # ================\n out = self.conv2(out)\n out = self.bn2(out)\n if self.downsample is not None:\n residual = self.downsample(x)\n out += residual\n out = self.relu(out) # ================\n return out In Distiller we changed ResNet to use multiple instances of nn.ReLU, and each instance is used only once: class BasicBlock(nn.Module):\n expansion = 1\n def __init__(self, inplanes, planes, stride=1, downsample=None):\n super(BasicBlock, self).__init__()\n self.conv1 = conv3x3(inplanes, planes, stride)\n self.bn1 = nn.BatchNorm2d(planes)\n self.relu1 = nn.ReLU(inplace=True)\n self.conv2 = conv3x3(planes, planes)\n self.bn2 = nn.BatchNorm2d(planes)\n self.relu2 = nn.ReLU(inplace=True)\n self.downsample = downsample\n self.stride = stride\n\n def forward(self, x):\n residual = x\n out = self.conv1(x)\n out = self.bn1(out)\n out = self.relu1(out) # ================\n out = self.conv2(out)\n out = self.bn2(out)\n if self.downsample is not None:\n residual = self.downsample(x)\n out += residual\n out = self.relu2(out) # ================\n return out", "title": "Caveats" - }, + }, { - "location": "/usage/index.html#using-the-jupyter-notebooks", - "text": "The Jupyter notebooks contain many examples of how to use the statistics summaries generated by Distiller. They are explained in a separate page.", + "location": "/usage/index.html#using-the-jupyter-notebooks", + "text": "The Jupyter notebooks contain many examples of how to use the statistics summaries generated by Distiller. 
They are explained in a separate page.", "title": "Using the Jupyter notebooks" - }, + }, { - "location": "/usage/index.html#generating-this-documentation", - "text": "Install mkdocs and the required packages by executing: $ pip3 install -r doc-requirements.txt To build the project documentation run: $ cd distiller/docs-src\n$ mkdocs build --clean This will create a folder named 'site' which contains the documentation website.\nOpen distiller/docs/site/index.html to view the documentation home page.", + "location": "/usage/index.html#generating-this-documentation", + "text": "Install mkdocs and the required packages by executing: $ pip3 install -r doc-requirements.txt To build the project documentation run: $ cd distiller/docs-src\n$ mkdocs build --clean This will create a folder named 'site' which contains the documentation website.\nOpen distiller/docs/site/index.html to view the documentation home page.", "title": "Generating this documentation" - }, + }, { - "location": "/schedule/index.html", - "text": "Compression scheduler\n\n\nIn iterative pruning, we create some kind of pruning regimen that specifies how to prune, and what to prune at every stage of the pruning and training stages. This motivated the design of \nCompressionScheduler\n: it needed to be part of the training loop, and to be able to make and implement pruning, regularization and quantization decisions. We wanted to be able to change the particulars of the compression schedule, w/o touching the code, and settled on using YAML as a container for this specification. We found that when we make many experiments on the same code base, it is easier to maintain all of these experiments if we decouple the differences from the code-base. Therefore, we added to the scheduler support for learning-rate decay scheduling because, again, we wanted the freedom to change the LR-decay policy without changing code. \n\n\nHigh level overview\n\n\nLet's briefly discuss the main mechanisms and abstractions: A schedule specification is composed of a list of sections defining instances of Pruners, Regularizers, Quantizers, LR-scheduler and Policies.\n\n\n\n\nPruners, Regularizers and Quantizers are very similar: They implement either a Pruning/Regularization/Quantization algorithm, respectively. \n\n\nAn LR-scheduler specifies the LR-decay algorithm. \n\n\n\n\nThese define the \nwhat\n part of the schedule. \n\n\nThe Policies define the \nwhen\n part of the schedule: at which epoch to start applying the Pruner/Regularizer/Quantizer/LR-decay, the epoch to end, and how often to invoke the policy (frequency of application). 
A policy also defines the instance of Pruner/Regularizer/Quantizer/LR-decay it is managing.\n\nThe \nCompressionScheduler\n is configured from a YAML file or from a dictionary, but you can also manually create Policies, Pruners, Regularizers and Quantizers from code.\n\n\nSyntax through example\n\n\nWe'll use \nalexnet.schedule_agp.yaml\n to explain some of the YAML syntax for configuring Sensitivity Pruning of Alexnet.\n\n\nversion: 1\npruners:\n my_pruner:\n class: 'SensitivityPruner'\n sensitivities:\n 'features.module.0.weight': 0.25\n 'features.module.3.weight': 0.35\n 'features.module.6.weight': 0.40\n 'features.module.8.weight': 0.45\n 'features.module.10.weight': 0.55\n 'classifier.1.weight': 0.875\n 'classifier.4.weight': 0.875\n 'classifier.6.weight': 0.625\n\nlr_schedulers:\n pruning_lr:\n class: ExponentialLR\n gamma: 0.9\n\npolicies:\n - pruner:\n instance_name : 'my_pruner'\n starting_epoch: 0\n ending_epoch: 38\n frequency: 2\n\n - lr_scheduler:\n instance_name: pruning_lr\n starting_epoch: 24\n ending_epoch: 200\n frequency: 1\n\n\n\n\nThere is only one version of the YAML syntax, and the version number is not verified at the moment. However, to be future-proof it is probably better to let the YAML parser know that you are using version-1 syntax, in case there is ever a version 2.\n\n\nversion: 1\n\n\n\n\nIn the \npruners\n section, we define the instances of pruners we want the scheduler to instantiate and use.\n\nWe define a single pruner instance, named \nmy_pruner\n, of algorithm \nSensitivityPruner\n. We will refer to this instance in the \nPolicies\n section.\n\nThen we list the sensitivity multipliers, \\(s\\), of each of the weight tensors.\n\nYou may list as many Pruners as you want in this section, as long as each has a unique name. You can several types of pruners in one schedule.\n\n\npruners:\n my_pruner:\n class: 'SensitivityPruner'\n sensitivities:\n 'features.module.0.weight': 0.25\n 'features.module.3.weight': 0.35\n 'features.module.6.weight': 0.40\n 'features.module.8.weight': 0.45\n 'features.module.10.weight': 0.55\n 'classifier.1.weight': 0.875\n 'classifier.4.weight': 0.875\n 'classifier.6.weight': 0.6\n\n\n\n\nNext, we want to specify the learning-rate decay scheduling in the \nlr_schedulers\n section. We assign a name to this instance: \npruning_lr\n. As in the \npruners\n section, you may use any name, as long as all LR-schedulers have a unique name. At the moment, only one instance of LR-scheduler is allowed. The LR-scheduler must be a subclass of PyTorch's \n_LRScheduler\n. You can use any of the schedulers defined in \ntorch.optim.lr_scheduler\n (see \nhere\n). In addition, we've implemented some additional schedulers in Distiller (see \nhere\n). The keyword arguments (kwargs) are passed directly to the LR-scheduler's constructor, so that as new LR-schedulers are added to \ntorch.optim.lr_scheduler\n, they can be used without changing the application code.\n\n\nlr_schedulers:\n pruning_lr:\n class: ExponentialLR\n gamma: 0.9\n\n\n\n\nFinally, we define the \npolicies\n section which defines the actual scheduling. A \nPolicy\n manages an instance of a \nPruner\n, \nRegularizer\n, \nQuantizer\n, or \nLRScheduler\n, by naming the instance. In the example below, a \nPruningPolicy\n uses the pruner instance named \nmy_pruner\n: it activates it at a frequency of 2 epochs (i.e. every other epoch), starting at epoch 0, and ending at epoch 38. 
\n\n\npolicies:\n - pruner:\n instance_name : 'my_pruner'\n starting_epoch: 0\n ending_epoch: 38\n frequency: 2\n\n - lr_scheduler:\n instance_name: pruning_lr\n starting_epoch: 24\n ending_epoch: 200\n frequency: 1\n\n\n\n\nThis is \niterative pruning\n:\n\n\n\n\n\n\nTrain Connectivity\n\n\n\n\n\n\nPrune Connections\n\n\n\n\n\n\nRetrain Weights\n\n\n\n\n\n\nGoto 2\n\n\n\n\n\n\nIt is described in \nLearning both Weights and Connections for Efficient Neural Networks\n:\n\n\n\n\n\"Our method prunes redundant connections using a three-step method. First, we train the network to learn which connections are important. Next, we prune the unimportant connections. Finally, we retrain the network to fine tune the weights of the remaining connections...After an initial training phase, we remove all connections whose weight is lower than a threshold. This pruning converts a dense, fully-connected layer to a sparse layer. This first phase learns the topology of the networks \u2014 learning which connections are important and removing the unimportant connections. We then retrain the sparse network so the remaining connections can compensate for the connections that have been removed. The phases of pruning and retraining may be repeated iteratively to further reduce network complexity.\"\n\n\n\n\nRegularization\n\n\nYou can also define and schedule regularization.\n\n\nL1 regularization\n\n\nFormat (this is an informal specification, not a valid \nABNF\n specification):\n\n\nregularizers:\n <REGULARIZER_NAME_STR>:\n class: L1Regularizer\n reg_regims:\n <PYTORCH_PARAM_NAME_STR>: <STRENGTH_FLOAT>\n ...\n <PYTORCH_PARAM_NAME_STR>: <STRENGTH_FLOAT>\n threshold_criteria: [Mean_Abs | Max]\n\n\n\n\nFor example:\n\n\nversion: 1\n\nregularizers:\n my_L1_reg:\n class: L1Regularizer\n reg_regims:\n 'module.layer3.1.conv1.weight': 0.000002\n 'module.layer3.1.conv2.weight': 0.000002\n 'module.layer3.1.conv3.weight': 0.000002\n 'module.layer3.2.conv1.weight': 0.000002\n threshold_criteria: Mean_Abs\n\npolicies:\n - regularizer:\n instance_name: my_L1_reg\n starting_epoch: 0\n ending_epoch: 60\n frequency: 1\n\n\n\n\nGroup regularization\n\n\nFormat (informal specification):\n\n\nFormat:\n regularizers:\n <REGULARIZER_NAME_STR>:\n class: L1Regularizer\n reg_regims:\n <PYTORCH_PARAM_NAME_STR>: [<STRENGTH_FLOAT>, <'2D' | '3D' | '4D' | 'Channels' | 'Cols' | 'Rows'>]\n <PYTORCH_PARAM_NAME_STR>: [<STRENGTH_FLOAT>, <'2D' | '3D' | '4D' | 'Channels' | 'Cols' | 'Rows'>]\n threshold_criteria: [Mean_Abs | Max]\n\n\n\n\nFor example:\n\n\nversion: 1\n\nregularizers:\n my_filter_regularizer:\n class: GroupLassoRegularizer\n reg_regims:\n 'module.layer3.1.conv1.weight': [0.00005, '3D']\n 'module.layer3.1.conv2.weight': [0.00005, '3D']\n 'module.layer3.1.conv3.weight': [0.00005, '3D']\n 'module.layer3.2.conv1.weight': [0.00005, '3D']\n threshold_criteria: Mean_Abs\n\npolicies:\n - regularizer:\n instance_name: my_filter_regularizer\n starting_epoch: 0\n ending_epoch: 60\n frequency: 1\n\n\n\n\nMixing it up\n\n\nYou can mix pruning and regularization.\n\n\nversion: 1\npruners:\n my_pruner:\n class: 'SensitivityPruner'\n sensitivities:\n 'features.module.0.weight': 0.25\n 'features.module.3.weight': 0.35\n 'features.module.6.weight': 0.40\n 'features.module.8.weight': 0.45\n 'features.module.10.weight': 0.55\n 'classifier.1.weight': 0.875\n 'classifier.4.weight': 0.875\n 'classifier.6.weight': 0.625\n\nregularizers:\n 2d_groups_regularizer:\n class: GroupLassoRegularizer\n reg_regims:\n 'features.module.0.weight': [0.000012, '2D']\n 
'features.module.3.weight': [0.000012, '2D']\n 'features.module.6.weight': [0.000012, '2D']\n 'features.module.8.weight': [0.000012, '2D']\n 'features.module.10.weight': [0.000012, '2D']\n\n\nlr_schedulers:\n # Learning rate decay scheduler\n pruning_lr:\n class: ExponentialLR\n gamma: 0.9\n\npolicies:\n - pruner:\n instance_name : 'my_pruner'\n starting_epoch: 0\n ending_epoch: 38\n frequency: 2\n\n - regularizer:\n instance_name: '2d_groups_regularizer'\n starting_epoch: 0\n ending_epoch: 38\n frequency: 1\n\n - lr_scheduler:\n instance_name: pruning_lr\n starting_epoch: 24\n ending_epoch: 200\n frequency: 1\n\n\n\n\n\nQuantization\n\n\nSimilarly to pruners and regularizers, specifying a quantizer in the scheduler YAML follows the constructor arguments of the \nQuantizer\n class (see details \nhere\n).\n\n\nNotes\n: Only a single quantizer instance may be defined.\n\nLet's see an example:\n\n\nquantizers:\n dorefa_quantizer:\n class: DorefaQuantizer\n bits_activations: 8\n bits_weights: 4\n bits_overrides:\n conv1:\n wts: null\n acts: null\n relu1:\n wts: null\n acts: null\n final_relu:\n wts: null\n acts: null\n fc:\n wts: null\n acts: null\n\n\n\n\n\n\nThe specific quantization method we're instantiating here is \nDorefaQuantizer\n.\n\n\nThen we define the default bit-widths for activations and weights, in this case 8 and 4-bits, respectively. \n\n\nThen, we define the \nbits_overrides\n mapping. In the example above, we choose not to quantize the first and last layer of the model. In the case of \nDorefaQuantizer\n, the weights are quantized as part of the convolution / FC layers, but the activations are quantized in separate layers, which replace the ReLU layers in the original model (remember - even though we replaced the ReLU modules with our own quantization modules, the name of the modules isn't changed). So, in all, we need to reference the first layer with parameters \nconv1\n, the first activation layer \nrelu1\n, the last activation layer \nfinal_relu\n and the last layer with parameters \nfc\n.\n\n\nSpecifying \nnull\n means \"do not quantize\".\n\n\nNote that for quantizers, we reference names of modules, not names of parameters as we do for pruners and regularizers.\n\n\n\n\nDefining overrides for \ngroups of layers\n using regular expressions\n\n\nSuppose we have a sub-module in our model named \nblock1\n, which contains multiple convolution layers which we would like to quantize to, say, 2-bits. The convolution layers are named \nconv1\n, \nconv2\n and so on. In that case we would define the following:\n\n\nbits_overrides:\n 'block1\\.conv*':\n wts: 2\n acts: null\n\n\n\n\n\n\nRegEx Note\n: Remember that the dot (\n.\n) is a meta-character (i.e. a reserved character) in regular expressions. So, to match the actual dot characters which separate sub-modules in PyTorch module names, we need to escape it: \n\\.\n\n\n\n\nOverlapping patterns\n are also possible, which allows to define some override for a groups of layers and also \"single-out\" specific layers for different overrides. For example, let's take the last example and configure a different override for \nblock1.conv1\n:\n\n\nbits_overrides:\n 'block1\\.conv1':\n wts: 4\n acts: null\n 'block1\\.conv*':\n wts: 2\n acts: null\n\n\n\n\n\n\nImportant Note\n: The patterns are evaluated eagerly - first match wins. 
So, to properly quantize a model using \"broad\" patterns and more \"specific\" patterns as just shown, make sure the specific pattern is listed \nbefore\n the broad one.\n\n\n\n\nThe \nQuantizationPolicy\n, which controls the quantization procedure during training, is actually quite simplistic. All it does is call the \nprepare_model()\n function of the \nQuantizer\n when it's initialized, followed by the first call to \nquantize_params()\n. Then, at the end of each epoch, after the float copy of the weights has been updated, it calls the \nquantize_params()\n function again. \n\n\npolicies:\n - quantizer:\n instance_name: dorefa_quantizer\n starting_epoch: 0\n ending_epoch: 200\n frequency: 1\n\n\n\n\nImportant Note\n: As mentioned \nhere\n, since the quantizer modifies the model's parameters (assuming training with quantization in the loop is used), the call to \nprepare_model()\n must be performed before an optimizer is called. Therefore, currently, the starting epoch for a quantization policy must be 0, otherwise the quantization process will not work as expected. If one wishes to do a \"warm-startup\" (or \"boot-strapping\"), training for a few epochs with full precision and only then starting to quantize, the only way to do this right now is to execute a separate run to generate the boot-strapped weights, and execute a second which will resume the checkpoint with the boot-strapped weights.\n\n\nKnowledge Distillation\n\n\nKnowledge distillation (see \nhere\n) is also implemented as a \nPolicy\n, which should be added to the scheduler. However, with the current implementation, it cannot be defined within the YAML file like the rest of the policies described above.\n\n\nTo make the integration of this method into applications a bit easier, a helper function can be used that will add a set of command-line arguments related to knowledge distillation:\n\n\nimport argparse\nimport distiller\n\nparser = argparse.ArgumentParser()\ndistiller.knowledge_distillation.add_distillation_args(parser)\n\n\n\n\n(The \nadd_distillation_args\n function accepts some optional arguments, see its implementation at \ndistiller/knowledge_distillation.py\n for details)\n\n\nThese are the command line arguments exposed by this function:\n\n\nKnowledge Distillation Training Arguments:\n --kd-teacher ARCH Model architecture for teacher model\n --kd-pretrained Use pre-trained model for teacher\n --kd-resume PATH Path to checkpoint from which to load teacher weights\n --kd-temperature TEMP, --kd-temp TEMP\n Knowledge distillation softmax temperature\n --kd-distill-wt WEIGHT, --kd-dw WEIGHT\n Weight for distillation loss (student vs. teacher soft\n targets)\n --kd-student-wt WEIGHT, --kd-sw WEIGHT\n Weight for student vs. labels loss\n --kd-teacher-wt WEIGHT, --kd-tw WEIGHT\n Weight for teacher vs. 
labels loss\n --kd-start-epoch EPOCH_NUM\n Epoch from which to enable distillation\n\n\n\n\n\nOnce arguments have been parsed, some initialization code is required, similar to the following:\n\n\n# Assuming:\n# \"args\" variable holds command line arguments\n# \"model\" variable holds the model we're going to train, that is - the student model\n# \"compression_scheduler\" variable holds a CompressionScheduler instance\n\nargs.kd_policy = None\nif args.kd_teacher:\n # Create teacher model - replace this with your model creation code\n teacher = create_model(args.kd_pretrained, args.dataset, args.kd_teacher, device_ids=args.gpus)\n if args.kd_resume:\n teacher, _, _ = apputils.load_checkpoint(teacher, chkpt_file=args.kd_resume)\n\n # Create policy and add to scheduler\n dlw = distiller.DistillationLossWeights(args.kd_distill_wt, args.kd_student_wt, args.kd_teacher_wt)\n args.kd_policy = distiller.KnowledgeDistillationPolicy(model, teacher, args.kd_temp, dlw)\n compression_scheduler.add_policy(args.kd_policy, starting_epoch=args.kd_start_epoch, ending_epoch=args.epochs,\n frequency=1)\n\n\n\n\nFinally, during the training loop, we need to perform forward propagation through the teacher model as well. The \nKnowledgeDistillationPolicy\n class keeps a reference to both the student and teacher models, and exposes a \nforward\n function that performs forward propagation on both of them. Since this is not one of the standard policy callbacks, we need to call this function manually from our training loop, as follows:\n\n\nif args.kd_policy is None:\n # Revert to a \"normal\" forward-prop call if no knowledge distillation policy is present\n output = model(input_var)\nelse:\n output = args.kd_policy.forward(input_var)\n\n\n\n\nTo see this integration in action, take a look at the image classification sample at \nexamples/classifier_compression/compress_classifier.py\n.", + "location": "/schedule/index.html", + "text": "Compression scheduler\n\n\nIn iterative pruning, we create some kind of pruning regimen that specifies how to prune, and what to prune at every stage of the pruning and training stages. This motivated the design of \nCompressionScheduler\n: it needed to be part of the training loop, and to be able to make and implement pruning, regularization and quantization decisions. We wanted to be able to change the particulars of the compression schedule, w/o touching the code, and settled on using YAML as a container for this specification. We found that when we make many experiments on the same code base, it is easier to maintain all of these experiments if we decouple the differences from the code-base. Therefore, we added to the scheduler support for learning-rate decay scheduling because, again, we wanted the freedom to change the LR-decay policy without changing code. \n\n\nHigh level overview\n\n\nLet's briefly discuss the main mechanisms and abstractions: A schedule specification is composed of a list of sections defining instances of Pruners, Regularizers, Quantizers, LR-scheduler and Policies.\n\n\n\n\nPruners, Regularizers and Quantizers are very similar: They implement either a Pruning/Regularization/Quantization algorithm, respectively. \n\n\nAn LR-scheduler specifies the LR-decay algorithm. \n\n\n\n\nThese define the \nwhat\n part of the schedule. \n\n\nThe Policies define the \nwhen\n part of the schedule: at which epoch to start applying the Pruner/Regularizer/Quantizer/LR-decay, the epoch to end, and how often to invoke the policy (frequency of application). 
A policy also defines the instance of Pruner/Regularizer/Quantizer/LR-decay it is managing.\n\nThe \nCompressionScheduler\n is configured from a YAML file or from a dictionary, but you can also manually create Policies, Pruners, Regularizers and Quantizers from code.\n\n\nSyntax through example\n\n\nWe'll use \nalexnet.schedule_agp.yaml\n to explain some of the YAML syntax for configuring Sensitivity Pruning of Alexnet.\n\n\nversion: 1\npruners:\n my_pruner:\n class: 'SensitivityPruner'\n sensitivities:\n 'features.module.0.weight': 0.25\n 'features.module.3.weight': 0.35\n 'features.module.6.weight': 0.40\n 'features.module.8.weight': 0.45\n 'features.module.10.weight': 0.55\n 'classifier.1.weight': 0.875\n 'classifier.4.weight': 0.875\n 'classifier.6.weight': 0.625\n\nlr_schedulers:\n pruning_lr:\n class: ExponentialLR\n gamma: 0.9\n\npolicies:\n - pruner:\n instance_name : 'my_pruner'\n starting_epoch: 0\n ending_epoch: 38\n frequency: 2\n\n - lr_scheduler:\n instance_name: pruning_lr\n starting_epoch: 24\n ending_epoch: 200\n frequency: 1\n\n\n\n\nThere is only one version of the YAML syntax, and the version number is not verified at the moment. However, to be future-proof it is probably better to let the YAML parser know that you are using version-1 syntax, in case there is ever a version 2.\n\n\nversion: 1\n\n\n\n\nIn the \npruners\n section, we define the instances of pruners we want the scheduler to instantiate and use.\n\nWe define a single pruner instance, named \nmy_pruner\n, of algorithm \nSensitivityPruner\n. We will refer to this instance in the \nPolicies\n section.\n\nThen we list the sensitivity multipliers, \\(s\\), of each of the weight tensors.\n\nYou may list as many Pruners as you want in this section, as long as each has a unique name. You can several types of pruners in one schedule.\n\n\npruners:\n my_pruner:\n class: 'SensitivityPruner'\n sensitivities:\n 'features.module.0.weight': 0.25\n 'features.module.3.weight': 0.35\n 'features.module.6.weight': 0.40\n 'features.module.8.weight': 0.45\n 'features.module.10.weight': 0.55\n 'classifier.1.weight': 0.875\n 'classifier.4.weight': 0.875\n 'classifier.6.weight': 0.6\n\n\n\n\nNext, we want to specify the learning-rate decay scheduling in the \nlr_schedulers\n section. We assign a name to this instance: \npruning_lr\n. As in the \npruners\n section, you may use any name, as long as all LR-schedulers have a unique name. At the moment, only one instance of LR-scheduler is allowed. The LR-scheduler must be a subclass of PyTorch's \n_LRScheduler\n. You can use any of the schedulers defined in \ntorch.optim.lr_scheduler\n (see \nhere\n). In addition, we've implemented some additional schedulers in Distiller (see \nhere\n). The keyword arguments (kwargs) are passed directly to the LR-scheduler's constructor, so that as new LR-schedulers are added to \ntorch.optim.lr_scheduler\n, they can be used without changing the application code.\n\n\nlr_schedulers:\n pruning_lr:\n class: ExponentialLR\n gamma: 0.9\n\n\n\n\nFinally, we define the \npolicies\n section which defines the actual scheduling. A \nPolicy\n manages an instance of a \nPruner\n, \nRegularizer\n, \nQuantizer\n, or \nLRScheduler\n, by naming the instance. In the example below, a \nPruningPolicy\n uses the pruner instance named \nmy_pruner\n: it activates it at a frequency of 2 epochs (i.e. every other epoch), starting at epoch 0, and ending at epoch 38. 
\n\n\npolicies:\n - pruner:\n instance_name : 'my_pruner'\n starting_epoch: 0\n ending_epoch: 38\n frequency: 2\n\n - lr_scheduler:\n instance_name: pruning_lr\n starting_epoch: 24\n ending_epoch: 200\n frequency: 1\n\n\n\n\nThis is \niterative pruning\n:\n\n\n\n\n\n\nTrain Connectivity\n\n\n\n\n\n\nPrune Connections\n\n\n\n\n\n\nRetrain Weights\n\n\n\n\n\n\nGoto 2\n\n\n\n\n\n\nIt is described in \nLearning both Weights and Connections for Efficient Neural Networks\n:\n\n\n\n\n\"Our method prunes redundant connections using a three-step method. First, we train the network to learn which connections are important. Next, we prune the unimportant connections. Finally, we retrain the network to fine tune the weights of the remaining connections...After an initial training phase, we remove all connections whose weight is lower than a threshold. This pruning converts a dense, fully-connected layer to a sparse layer. This first phase learns the topology of the networks \u2014 learning which connections are important and removing the unimportant connections. We then retrain the sparse network so the remaining connections can compensate for the connections that have been removed. The phases of pruning and retraining may be repeated iteratively to further reduce network complexity.\"\n\n\n\n\nRegularization\n\n\nYou can also define and schedule regularization.\n\n\nL1 regularization\n\n\nFormat (this is an informal specification, not a valid \nABNF\n specification):\n\n\nregularizers:\n \nREGULARIZER_NAME_STR\n:\n class: L1Regularizer\n reg_regims:\n \nPYTORCH_PARAM_NAME_STR\n: \nSTRENGTH_FLOAT\n\n ...\n \nPYTORCH_PARAM_NAME_STR\n: \nSTRENGTH_FLOAT\n\n threshold_criteria: [Mean_Abs | Max]\n\n\n\n\nFor example:\n\n\nversion: 1\n\nregularizers:\n my_L1_reg:\n class: L1Regularizer\n reg_regims:\n 'module.layer3.1.conv1.weight': 0.000002\n 'module.layer3.1.conv2.weight': 0.000002\n 'module.layer3.1.conv3.weight': 0.000002\n 'module.layer3.2.conv1.weight': 0.000002\n threshold_criteria: Mean_Abs\n\npolicies:\n - regularizer:\n instance_name: my_L1_reg\n starting_epoch: 0\n ending_epoch: 60\n frequency: 1\n\n\n\n\nGroup regularization\n\n\nFormat (informal specification):\n\n\nFormat:\n regularizers:\n \nREGULARIZER_NAME_STR\n:\n class: L1Regularizer\n reg_regims:\n \nPYTORCH_PARAM_NAME_STR\n: [\nSTRENGTH_FLOAT\n, \n'2D' | '3D' | '4D' | 'Channels' | 'Cols' | 'Rows'\n]\n \nPYTORCH_PARAM_NAME_STR\n: [\nSTRENGTH_FLOAT\n, \n'2D' | '3D' | '4D' | 'Channels' | 'Cols' | 'Rows'\n]\n threshold_criteria: [Mean_Abs | Max]\n\n\n\n\nFor example:\n\n\nversion: 1\n\nregularizers:\n my_filter_regularizer:\n class: GroupLassoRegularizer\n reg_regims:\n 'module.layer3.1.conv1.weight': [0.00005, '3D']\n 'module.layer3.1.conv2.weight': [0.00005, '3D']\n 'module.layer3.1.conv3.weight': [0.00005, '3D']\n 'module.layer3.2.conv1.weight': [0.00005, '3D']\n threshold_criteria: Mean_Abs\n\npolicies:\n - regularizer:\n instance_name: my_filter_regularizer\n starting_epoch: 0\n ending_epoch: 60\n frequency: 1\n\n\n\n\nMixing it up\n\n\nYou can mix pruning and regularization.\n\n\nversion: 1\npruners:\n my_pruner:\n class: 'SensitivityPruner'\n sensitivities:\n 'features.module.0.weight': 0.25\n 'features.module.3.weight': 0.35\n 'features.module.6.weight': 0.40\n 'features.module.8.weight': 0.45\n 'features.module.10.weight': 0.55\n 'classifier.1.weight': 0.875\n 'classifier.4.weight': 0.875\n 'classifier.6.weight': 0.625\n\nregularizers:\n 2d_groups_regularizer:\n class: GroupLassoRegularizer\n reg_regims:\n 
'features.module.0.weight': [0.000012, '2D']\n 'features.module.3.weight': [0.000012, '2D']\n 'features.module.6.weight': [0.000012, '2D']\n 'features.module.8.weight': [0.000012, '2D']\n 'features.module.10.weight': [0.000012, '2D']\n\n\nlr_schedulers:\n # Learning rate decay scheduler\n pruning_lr:\n class: ExponentialLR\n gamma: 0.9\n\npolicies:\n - pruner:\n instance_name : 'my_pruner'\n starting_epoch: 0\n ending_epoch: 38\n frequency: 2\n\n - regularizer:\n instance_name: '2d_groups_regularizer'\n starting_epoch: 0\n ending_epoch: 38\n frequency: 1\n\n - lr_scheduler:\n instance_name: pruning_lr\n starting_epoch: 24\n ending_epoch: 200\n frequency: 1\n\n\n\n\n\nQuantization\n\n\nSimilarly to pruners and regularizers, specifying a quantizer in the scheduler YAML follows the constructor arguments of the \nQuantizer\n class (see details \nhere\n). \nNote\n that only a single quantizer instance may be defined per YAML.\n\n\nLet's see an example:\n\n\nquantizers:\n dorefa_quantizer:\n class: DorefaQuantizer\n bits_activations: 8\n bits_weights: 4\n bits_overrides:\n conv1:\n wts: null\n acts: null\n relu1:\n wts: null\n acts: null\n final_relu:\n wts: null\n acts: null\n fc:\n wts: null\n acts: null\n\n\n\n\n\n\nThe specific quantization method we're instantiating here is \nDorefaQuantizer\n.\n\n\nThen we define the default bit-widths for activations and weights, in this case 8 and 4-bits, respectively. \n\n\nThen, we define the \nbits_overrides\n mapping. In the example above, we choose not to quantize the first and last layer of the model. In the case of \nDorefaQuantizer\n, the weights are quantized as part of the convolution / FC layers, but the activations are quantized in separate layers, which replace the ReLU layers in the original model (remember - even though we replaced the ReLU modules with our own quantization modules, the name of the modules isn't changed). So, in all, we need to reference the first layer with parameters \nconv1\n, the first activation layer \nrelu1\n, the last activation layer \nfinal_relu\n and the last layer with parameters \nfc\n.\n\n\nSpecifying \nnull\n means \"do not quantize\".\n\n\nNote that for quantizers, we reference names of modules, not names of parameters as we do for pruners and regularizers.\n\n\n\n\nDefining overrides for \ngroups of layers\n using regular expressions\n\n\nSuppose we have a sub-module in our model named \nblock1\n, which contains multiple convolution layers which we would like to quantize to, say, 2-bits. The convolution layers are named \nconv1\n, \nconv2\n and so on. In that case we would define the following:\n\n\nbits_overrides:\n 'block1\\.conv*':\n wts: 2\n acts: null\n\n\n\n\n\n\nRegEx Note\n: Remember that the dot (\n.\n) is a meta-character (i.e. a reserved character) in regular expressions. So, to match the actual dot characters which separate sub-modules in PyTorch module names, we need to escape it: \n\\.\n\n\n\n\nOverlapping patterns\n are also possible, which allows to define some override for a groups of layers and also \"single-out\" specific layers for different overrides. For example, let's take the last example and configure a different override for \nblock1.conv1\n:\n\n\nbits_overrides:\n 'block1\\.conv1':\n wts: 4\n acts: null\n 'block1\\.conv*':\n wts: 2\n acts: null\n\n\n\n\n\n\nImportant Note\n: The patterns are evaluated eagerly - first match wins. 
So, to properly quantize a model using \"broad\" patterns and more \"specific\" patterns as just shown, make sure the specific pattern is listed \nbefore\n the broad one.\n\n\n\n\nThe \nQuantizationPolicy\n, which controls the quantization procedure during training, is actually quite simplistic. All it does is call the \nprepare_model()\n function of the \nQuantizer\n when it's initialized, followed by the first call to \nquantize_params()\n. Then, at the end of each epoch, after the float copy of the weights has been updated, it calls the \nquantize_params()\n function again.\n\n\npolicies:\n - quantizer:\n instance_name: dorefa_quantizer\n starting_epoch: 0\n ending_epoch: 200\n frequency: 1\n\n\n\n\nImportant Note\n: As mentioned \nhere\n, since the quantizer modifies the model's parameters (assuming training with quantization in the loop is used), the call to \nprepare_model()\n must be performed before an optimizer is called. Therefore, currently, the starting epoch for a quantization policy must be 0, otherwise the quantization process will not work as expected. If one wishes to do a \"warm-startup\" (or \"boot-strapping\"), training for a few epochs with full precision and only then starting to quantize, the only way to do this right now is to execute a separate run to generate the boot-strapped weights, and execute a second which will resume the checkpoint with the boot-strapped weights.\n\n\nKnowledge Distillation\n\n\nKnowledge distillation (see \nhere\n) is also implemented as a \nPolicy\n, which should be added to the scheduler. However, with the current implementation, it cannot be defined within the YAML file like the rest of the policies described above.\n\n\nTo make the integration of this method into applications a bit easier, a helper function can be used that will add a set of command-line arguments related to knowledge distillation:\n\n\nimport argparse\nimport distiller\n\nparser = argparse.ArgumentParser()\ndistiller.knowledge_distillation.add_distillation_args(parser)\n\n\n\n\n(The \nadd_distillation_args\n function accepts some optional arguments, see its implementation at \ndistiller/knowledge_distillation.py\n for details)\n\n\nThese are the command line arguments exposed by this function:\n\n\nKnowledge Distillation Training Arguments:\n --kd-teacher ARCH Model architecture for teacher model\n --kd-pretrained Use pre-trained model for teacher\n --kd-resume PATH Path to checkpoint from which to load teacher weights\n --kd-temperature TEMP, --kd-temp TEMP\n Knowledge distillation softmax temperature\n --kd-distill-wt WEIGHT, --kd-dw WEIGHT\n Weight for distillation loss (student vs. teacher soft\n targets)\n --kd-student-wt WEIGHT, --kd-sw WEIGHT\n Weight for student vs. labels loss\n --kd-teacher-wt WEIGHT, --kd-tw WEIGHT\n Weight for teacher vs. 
labels loss\n --kd-start-epoch EPOCH_NUM\n Epoch from which to enable distillation\n\n\n\n\n\nOnce arguments have been parsed, some initialization code is required, similar to the following:\n\n\n# Assuming:\n# \nargs\n variable holds command line arguments\n# \nmodel\n variable holds the model we're going to train, that is - the student model\n# \ncompression_scheduler\n variable holds a CompressionScheduler instance\n\nargs.kd_policy = None\nif args.kd_teacher:\n # Create teacher model - replace this with your model creation code\n teacher = create_model(args.kd_pretrained, args.dataset, args.kd_teacher, device_ids=args.gpus)\n if args.kd_resume:\n teacher, _, _ = apputils.load_checkpoint(teacher, chkpt_file=args.kd_resume)\n\n # Create policy and add to scheduler\n dlw = distiller.DistillationLossWeights(args.kd_distill_wt, args.kd_student_wt, args.kd_teacher_wt)\n args.kd_policy = distiller.KnowledgeDistillationPolicy(model, teacher, args.kd_temp, dlw)\n compression_scheduler.add_policy(args.kd_policy, starting_epoch=args.kd_start_epoch, ending_epoch=args.epochs,\n frequency=1)\n\n\n\n\nFinally, during the training loop, we need to perform forward propagation through the teacher model as well. The \nKnowledgeDistillationPolicy\n class keeps a reference to both the student and teacher models, and exposes a \nforward\n function that performs forward propagation on both of them. Since this is not one of the standard policy callbacks, we need to call this function manually from our training loop, as follows:\n\n\nif args.kd_policy is None:\n # Revert to a \nnormal\n forward-prop call if no knowledge distillation policy is present\n output = model(input_var)\nelse:\n output = args.kd_policy.forward(input_var)\n\n\n\n\nTo see this integration in action, take a look at the image classification sample at \nexamples/classifier_compression/compress_classifier.py\n.", "title": "Compression scheduling" - }, + }, { - "location": "/schedule/index.html#compression-scheduler", - "text": "In iterative pruning, we create some kind of pruning regimen that specifies how to prune, and what to prune at every stage of the pruning and training stages. This motivated the design of CompressionScheduler : it needed to be part of the training loop, and to be able to make and implement pruning, regularization and quantization decisions. We wanted to be able to change the particulars of the compression schedule, w/o touching the code, and settled on using YAML as a container for this specification. We found that when we make many experiments on the same code base, it is easier to maintain all of these experiments if we decouple the differences from the code-base. Therefore, we added to the scheduler support for learning-rate decay scheduling because, again, we wanted the freedom to change the LR-decay policy without changing code.", + "location": "/schedule/index.html#compression-scheduler", + "text": "In iterative pruning, we create some kind of pruning regimen that specifies how to prune, and what to prune at every stage of the pruning and training stages. This motivated the design of CompressionScheduler : it needed to be part of the training loop, and to be able to make and implement pruning, regularization and quantization decisions. We wanted to be able to change the particulars of the compression schedule, w/o touching the code, and settled on using YAML as a container for this specification. 
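As a rough sketch of how such a YAML specification typically gets wired into a training loop, the snippet below is modeled on the image classification sample; the helper distiller.file_config, the scheduler callback names, and the file name 'schedule.yaml' are assumptions that should be checked against the Distiller version you are using.

```python
# Sketch only, not a drop-in recipe: names follow the pattern used in
# examples/classifier_compression/compress_classifier.py; verify them against
# your Distiller version. The model, optimizer and data below are dummies.
import torch
import torch.nn as nn
import distiller

model = nn.Sequential(nn.Linear(32, 16), nn.ReLU(), nn.Linear(16, 10))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
train_loader = [(torch.randn(4, 32), torch.randint(0, 10, (4,)))
                for _ in range(2)]  # stands in for a real DataLoader

# Build the CompressionScheduler (pruners, regularizers, policies) from the YAML file.
scheduler = distiller.file_config(model, optimizer, 'schedule.yaml')

for epoch in range(10):
    scheduler.on_epoch_begin(epoch)
    for step, (inputs, target) in enumerate(train_loader):
        scheduler.on_minibatch_begin(epoch, step, len(train_loader))
        loss = criterion(model(inputs), target)
        # A full loop would also invoke the scheduler's backward-pass callback so
        # regularizers can add their loss term; see the sample application.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.on_minibatch_end(epoch, step, len(train_loader))
    scheduler.on_epoch_end(epoch)
```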
We found that when we make many experiments on the same code base, it is easier to maintain all of these experiments if we decouple the differences from the code-base. Therefore, we added to the scheduler support for learning-rate decay scheduling because, again, we wanted the freedom to change the LR-decay policy without changing code.", "title": "Compression scheduler" - }, + }, { - "location": "/schedule/index.html#high-level-overview", - "text": "Let's briefly discuss the main mechanisms and abstractions: A schedule specification is composed of a list of sections defining instances of Pruners, Regularizers, Quantizers, LR-scheduler and Policies. Pruners, Regularizers and Quantizers are very similar: They implement either a Pruning/Regularization/Quantization algorithm, respectively. An LR-scheduler specifies the LR-decay algorithm. These define the what part of the schedule. The Policies define the when part of the schedule: at which epoch to start applying the Pruner/Regularizer/Quantizer/LR-decay, the epoch to end, and how often to invoke the policy (frequency of application). A policy also defines the instance of Pruner/Regularizer/Quantizer/LR-decay it is managing. \nThe CompressionScheduler is configured from a YAML file or from a dictionary, but you can also manually create Policies, Pruners, Regularizers and Quantizers from code.", + "location": "/schedule/index.html#high-level-overview", + "text": "Let's briefly discuss the main mechanisms and abstractions: A schedule specification is composed of a list of sections defining instances of Pruners, Regularizers, Quantizers, LR-scheduler and Policies. Pruners, Regularizers and Quantizers are very similar: They implement either a Pruning/Regularization/Quantization algorithm, respectively. An LR-scheduler specifies the LR-decay algorithm. These define the what part of the schedule. The Policies define the when part of the schedule: at which epoch to start applying the Pruner/Regularizer/Quantizer/LR-decay, the epoch to end, and how often to invoke the policy (frequency of application). A policy also defines the instance of Pruner/Regularizer/Quantizer/LR-decay it is managing. \nThe CompressionScheduler is configured from a YAML file or from a dictionary, but you can also manually create Policies, Pruners, Regularizers and Quantizers from code.", "title": "High level overview" - }, + }, { - "location": "/schedule/index.html#syntax-through-example", - "text": "We'll use alexnet.schedule_agp.yaml to explain some of the YAML syntax for configuring Sensitivity Pruning of Alexnet. version: 1\npruners:\n my_pruner:\n class: 'SensitivityPruner'\n sensitivities:\n 'features.module.0.weight': 0.25\n 'features.module.3.weight': 0.35\n 'features.module.6.weight': 0.40\n 'features.module.8.weight': 0.45\n 'features.module.10.weight': 0.55\n 'classifier.1.weight': 0.875\n 'classifier.4.weight': 0.875\n 'classifier.6.weight': 0.625\n\nlr_schedulers:\n pruning_lr:\n class: ExponentialLR\n gamma: 0.9\n\npolicies:\n - pruner:\n instance_name : 'my_pruner'\n starting_epoch: 0\n ending_epoch: 38\n frequency: 2\n\n - lr_scheduler:\n instance_name: pruning_lr\n starting_epoch: 24\n ending_epoch: 200\n frequency: 1 There is only one version of the YAML syntax, and the version number is not verified at the moment. However, to be future-proof it is probably better to let the YAML parser know that you are using version-1 syntax, in case there is ever a version 2. 
version: 1 In the pruners section, we define the instances of pruners we want the scheduler to instantiate and use. \nWe define a single pruner instance, named my_pruner , of algorithm SensitivityPruner . We will refer to this instance in the Policies section. \nThen we list the sensitivity multipliers, \\(s\\), of each of the weight tensors. \nYou may list as many Pruners as you want in this section, as long as each has a unique name. You can several types of pruners in one schedule. pruners:\n my_pruner:\n class: 'SensitivityPruner'\n sensitivities:\n 'features.module.0.weight': 0.25\n 'features.module.3.weight': 0.35\n 'features.module.6.weight': 0.40\n 'features.module.8.weight': 0.45\n 'features.module.10.weight': 0.55\n 'classifier.1.weight': 0.875\n 'classifier.4.weight': 0.875\n 'classifier.6.weight': 0.6 Next, we want to specify the learning-rate decay scheduling in the lr_schedulers section. We assign a name to this instance: pruning_lr . As in the pruners section, you may use any name, as long as all LR-schedulers have a unique name. At the moment, only one instance of LR-scheduler is allowed. The LR-scheduler must be a subclass of PyTorch's _LRScheduler . You can use any of the schedulers defined in torch.optim.lr_scheduler (see here ). In addition, we've implemented some additional schedulers in Distiller (see here ). The keyword arguments (kwargs) are passed directly to the LR-scheduler's constructor, so that as new LR-schedulers are added to torch.optim.lr_scheduler , they can be used without changing the application code. lr_schedulers:\n pruning_lr:\n class: ExponentialLR\n gamma: 0.9 Finally, we define the policies section which defines the actual scheduling. A Policy manages an instance of a Pruner , Regularizer , Quantizer , or LRScheduler , by naming the instance. In the example below, a PruningPolicy uses the pruner instance named my_pruner : it activates it at a frequency of 2 epochs (i.e. every other epoch), starting at epoch 0, and ending at epoch 38. policies:\n - pruner:\n instance_name : 'my_pruner'\n starting_epoch: 0\n ending_epoch: 38\n frequency: 2\n\n - lr_scheduler:\n instance_name: pruning_lr\n starting_epoch: 24\n ending_epoch: 200\n frequency: 1 This is iterative pruning : Train Connectivity Prune Connections Retrain Weights Goto 2 It is described in Learning both Weights and Connections for Efficient Neural Networks : \"Our method prunes redundant connections using a three-step method. First, we train the network to learn which connections are important. Next, we prune the unimportant connections. Finally, we retrain the network to fine tune the weights of the remaining connections...After an initial training phase, we remove all connections whose weight is lower than a threshold. This pruning converts a dense, fully-connected layer to a sparse layer. This first phase learns the topology of the networks \u2014 learning which connections are important and removing the unimportant connections. We then retrain the sparse network so the remaining connections can compensate for the connections that have been removed. The phases of pruning and retraining may be repeated iteratively to further reduce network complexity.\"", + "location": "/schedule/index.html#syntax-through-example", + "text": "We'll use alexnet.schedule_agp.yaml to explain some of the YAML syntax for configuring Sensitivity Pruning of Alexnet. 
version: 1\npruners:\n my_pruner:\n class: 'SensitivityPruner'\n sensitivities:\n 'features.module.0.weight': 0.25\n 'features.module.3.weight': 0.35\n 'features.module.6.weight': 0.40\n 'features.module.8.weight': 0.45\n 'features.module.10.weight': 0.55\n 'classifier.1.weight': 0.875\n 'classifier.4.weight': 0.875\n 'classifier.6.weight': 0.625\n\nlr_schedulers:\n pruning_lr:\n class: ExponentialLR\n gamma: 0.9\n\npolicies:\n - pruner:\n instance_name : 'my_pruner'\n starting_epoch: 0\n ending_epoch: 38\n frequency: 2\n\n - lr_scheduler:\n instance_name: pruning_lr\n starting_epoch: 24\n ending_epoch: 200\n frequency: 1 There is only one version of the YAML syntax, and the version number is not verified at the moment. However, to be future-proof it is probably better to let the YAML parser know that you are using version-1 syntax, in case there is ever a version 2. version: 1 In the pruners section, we define the instances of pruners we want the scheduler to instantiate and use. \nWe define a single pruner instance, named my_pruner , of algorithm SensitivityPruner . We will refer to this instance in the Policies section. \nThen we list the sensitivity multipliers, \\(s\\), of each of the weight tensors. \nYou may list as many Pruners as you want in this section, as long as each has a unique name. You can several types of pruners in one schedule. pruners:\n my_pruner:\n class: 'SensitivityPruner'\n sensitivities:\n 'features.module.0.weight': 0.25\n 'features.module.3.weight': 0.35\n 'features.module.6.weight': 0.40\n 'features.module.8.weight': 0.45\n 'features.module.10.weight': 0.55\n 'classifier.1.weight': 0.875\n 'classifier.4.weight': 0.875\n 'classifier.6.weight': 0.6 Next, we want to specify the learning-rate decay scheduling in the lr_schedulers section. We assign a name to this instance: pruning_lr . As in the pruners section, you may use any name, as long as all LR-schedulers have a unique name. At the moment, only one instance of LR-scheduler is allowed. The LR-scheduler must be a subclass of PyTorch's _LRScheduler . You can use any of the schedulers defined in torch.optim.lr_scheduler (see here ). In addition, we've implemented some additional schedulers in Distiller (see here ). The keyword arguments (kwargs) are passed directly to the LR-scheduler's constructor, so that as new LR-schedulers are added to torch.optim.lr_scheduler , they can be used without changing the application code. lr_schedulers:\n pruning_lr:\n class: ExponentialLR\n gamma: 0.9 Finally, we define the policies section which defines the actual scheduling. A Policy manages an instance of a Pruner , Regularizer , Quantizer , or LRScheduler , by naming the instance. In the example below, a PruningPolicy uses the pruner instance named my_pruner : it activates it at a frequency of 2 epochs (i.e. every other epoch), starting at epoch 0, and ending at epoch 38. policies:\n - pruner:\n instance_name : 'my_pruner'\n starting_epoch: 0\n ending_epoch: 38\n frequency: 2\n\n - lr_scheduler:\n instance_name: pruning_lr\n starting_epoch: 24\n ending_epoch: 200\n frequency: 1 This is iterative pruning : Train Connectivity Prune Connections Retrain Weights Goto 2 It is described in Learning both Weights and Connections for Efficient Neural Networks : \"Our method prunes redundant connections using a three-step method. First, we train the network to learn which connections are important. Next, we prune the unimportant connections. 
Finally, we retrain the network to fine tune the weights of the remaining connections...After an initial training phase, we remove all connections whose weight is lower than a threshold. This pruning converts a dense, fully-connected layer to a sparse layer. This first phase learns the topology of the networks \u2014 learning which connections are important and removing the unimportant connections. We then retrain the sparse network so the remaining connections can compensate for the connections that have been removed. The phases of pruning and retraining may be repeated iteratively to further reduce network complexity.\"", "title": "Syntax through example" - }, + }, { - "location": "/schedule/index.html#regularization", - "text": "You can also define and schedule regularization.", + "location": "/schedule/index.html#regularization", + "text": "You can also define and schedule regularization.", "title": "Regularization" - }, + }, { - "location": "/schedule/index.html#l1-regularization", - "text": "Format (this is an informal specification, not a valid ABNF specification): regularizers:\n <REGULARIZER_NAME_STR>:\n class: L1Regularizer\n reg_regims:\n <PYTORCH_PARAM_NAME_STR>: <STRENGTH_FLOAT>\n ...\n <PYTORCH_PARAM_NAME_STR>: <STRENGTH_FLOAT>\n threshold_criteria: [Mean_Abs | Max] For example: version: 1\n\nregularizers:\n my_L1_reg:\n class: L1Regularizer\n reg_regims:\n 'module.layer3.1.conv1.weight': 0.000002\n 'module.layer3.1.conv2.weight': 0.000002\n 'module.layer3.1.conv3.weight': 0.000002\n 'module.layer3.2.conv1.weight': 0.000002\n threshold_criteria: Mean_Abs\n\npolicies:\n - regularizer:\n instance_name: my_L1_reg\n starting_epoch: 0\n ending_epoch: 60\n frequency: 1", + "location": "/schedule/index.html#l1-regularization", + "text": "Format (this is an informal specification, not a valid ABNF specification): regularizers:\n REGULARIZER_NAME_STR :\n class: L1Regularizer\n reg_regims:\n PYTORCH_PARAM_NAME_STR : STRENGTH_FLOAT \n ...\n PYTORCH_PARAM_NAME_STR : STRENGTH_FLOAT \n threshold_criteria: [Mean_Abs | Max] For example: version: 1\n\nregularizers:\n my_L1_reg:\n class: L1Regularizer\n reg_regims:\n 'module.layer3.1.conv1.weight': 0.000002\n 'module.layer3.1.conv2.weight': 0.000002\n 'module.layer3.1.conv3.weight': 0.000002\n 'module.layer3.2.conv1.weight': 0.000002\n threshold_criteria: Mean_Abs\n\npolicies:\n - regularizer:\n instance_name: my_L1_reg\n starting_epoch: 0\n ending_epoch: 60\n frequency: 1", "title": "L1 regularization" - }, + }, { - "location": "/schedule/index.html#group-regularization", - "text": "Format (informal specification): Format:\n regularizers:\n <REGULARIZER_NAME_STR>:\n class: L1Regularizer\n reg_regims:\n <PYTORCH_PARAM_NAME_STR>: [<STRENGTH_FLOAT>, <'2D' | '3D' | '4D' | 'Channels' | 'Cols' | 'Rows'>]\n <PYTORCH_PARAM_NAME_STR>: [<STRENGTH_FLOAT>, <'2D' | '3D' | '4D' | 'Channels' | 'Cols' | 'Rows'>]\n threshold_criteria: [Mean_Abs | Max] For example: version: 1\n\nregularizers:\n my_filter_regularizer:\n class: GroupLassoRegularizer\n reg_regims:\n 'module.layer3.1.conv1.weight': [0.00005, '3D']\n 'module.layer3.1.conv2.weight': [0.00005, '3D']\n 'module.layer3.1.conv3.weight': [0.00005, '3D']\n 'module.layer3.2.conv1.weight': [0.00005, '3D']\n threshold_criteria: Mean_Abs\n\npolicies:\n - regularizer:\n instance_name: my_filter_regularizer\n starting_epoch: 0\n ending_epoch: 60\n frequency: 1", + "location": "/schedule/index.html#group-regularization", + "text": "Format (informal specification): Format:\n regularizers:\n REGULARIZER_NAME_STR 
:\n class: L1Regularizer\n reg_regims:\n PYTORCH_PARAM_NAME_STR : [ STRENGTH_FLOAT , '2D' | '3D' | '4D' | 'Channels' | 'Cols' | 'Rows' ]\n PYTORCH_PARAM_NAME_STR : [ STRENGTH_FLOAT , '2D' | '3D' | '4D' | 'Channels' | 'Cols' | 'Rows' ]\n threshold_criteria: [Mean_Abs | Max] For example: version: 1\n\nregularizers:\n my_filter_regularizer:\n class: GroupLassoRegularizer\n reg_regims:\n 'module.layer3.1.conv1.weight': [0.00005, '3D']\n 'module.layer3.1.conv2.weight': [0.00005, '3D']\n 'module.layer3.1.conv3.weight': [0.00005, '3D']\n 'module.layer3.2.conv1.weight': [0.00005, '3D']\n threshold_criteria: Mean_Abs\n\npolicies:\n - regularizer:\n instance_name: my_filter_regularizer\n starting_epoch: 0\n ending_epoch: 60\n frequency: 1", "title": "Group regularization" - }, + }, { - "location": "/schedule/index.html#mixing-it-up", - "text": "You can mix pruning and regularization. version: 1\npruners:\n my_pruner:\n class: 'SensitivityPruner'\n sensitivities:\n 'features.module.0.weight': 0.25\n 'features.module.3.weight': 0.35\n 'features.module.6.weight': 0.40\n 'features.module.8.weight': 0.45\n 'features.module.10.weight': 0.55\n 'classifier.1.weight': 0.875\n 'classifier.4.weight': 0.875\n 'classifier.6.weight': 0.625\n\nregularizers:\n 2d_groups_regularizer:\n class: GroupLassoRegularizer\n reg_regims:\n 'features.module.0.weight': [0.000012, '2D']\n 'features.module.3.weight': [0.000012, '2D']\n 'features.module.6.weight': [0.000012, '2D']\n 'features.module.8.weight': [0.000012, '2D']\n 'features.module.10.weight': [0.000012, '2D']\n\n\nlr_schedulers:\n # Learning rate decay scheduler\n pruning_lr:\n class: ExponentialLR\n gamma: 0.9\n\npolicies:\n - pruner:\n instance_name : 'my_pruner'\n starting_epoch: 0\n ending_epoch: 38\n frequency: 2\n\n - regularizer:\n instance_name: '2d_groups_regularizer'\n starting_epoch: 0\n ending_epoch: 38\n frequency: 1\n\n - lr_scheduler:\n instance_name: pruning_lr\n starting_epoch: 24\n ending_epoch: 200\n frequency: 1", + "location": "/schedule/index.html#mixing-it-up", + "text": "You can mix pruning and regularization. version: 1\npruners:\n my_pruner:\n class: 'SensitivityPruner'\n sensitivities:\n 'features.module.0.weight': 0.25\n 'features.module.3.weight': 0.35\n 'features.module.6.weight': 0.40\n 'features.module.8.weight': 0.45\n 'features.module.10.weight': 0.55\n 'classifier.1.weight': 0.875\n 'classifier.4.weight': 0.875\n 'classifier.6.weight': 0.625\n\nregularizers:\n 2d_groups_regularizer:\n class: GroupLassoRegularizer\n reg_regims:\n 'features.module.0.weight': [0.000012, '2D']\n 'features.module.3.weight': [0.000012, '2D']\n 'features.module.6.weight': [0.000012, '2D']\n 'features.module.8.weight': [0.000012, '2D']\n 'features.module.10.weight': [0.000012, '2D']\n\n\nlr_schedulers:\n # Learning rate decay scheduler\n pruning_lr:\n class: ExponentialLR\n gamma: 0.9\n\npolicies:\n - pruner:\n instance_name : 'my_pruner'\n starting_epoch: 0\n ending_epoch: 38\n frequency: 2\n\n - regularizer:\n instance_name: '2d_groups_regularizer'\n starting_epoch: 0\n ending_epoch: 38\n frequency: 1\n\n - lr_scheduler:\n instance_name: pruning_lr\n starting_epoch: 24\n ending_epoch: 200\n frequency: 1", "title": "Mixing it up" - }, + }, { - "location": "/schedule/index.html#quantization", - "text": "Similarly to pruners and regularizers, specifying a quantizer in the scheduler YAML follows the constructor arguments of the Quantizer class (see details here ). Notes : Only a single quantizer instance may be defined. 
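As a brief aside on the regularizer examples above: the reg_regims mapping simply assigns a penalty strength to each named parameter. A minimal, illustrative sketch of the element-wise L1 penalty such a mapping describes (not Distiller's L1Regularizer implementation):

```
def l1_regularization(model, reg_regims):
    # reg_regims maps parameter names to strengths,
    # e.g. {'module.layer3.1.conv1.weight': 0.000002, ...}
    named_params = dict(model.named_parameters())
    reg_loss = 0.0
    for param_name, strength in reg_regims.items():
        reg_loss = reg_loss + strength * named_params[param_name].norm(p=1)
    return reg_loss

# During training the penalty is simply added to the data loss:
# loss = criterion(output, target) + l1_regularization(model, reg_regims)
```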
\nLet's see an example: quantizers:\n dorefa_quantizer:\n class: DorefaQuantizer\n bits_activations: 8\n bits_weights: 4\n bits_overrides:\n conv1:\n wts: null\n acts: null\n relu1:\n wts: null\n acts: null\n final_relu:\n wts: null\n acts: null\n fc:\n wts: null\n acts: null The specific quantization method we're instantiating here is DorefaQuantizer . Then we define the default bit-widths for activations and weights, in this case 8 and 4-bits, respectively. Then, we define the bits_overrides mapping. In the example above, we choose not to quantize the first and last layer of the model. In the case of DorefaQuantizer , the weights are quantized as part of the convolution / FC layers, but the activations are quantized in separate layers, which replace the ReLU layers in the original model (remember - even though we replaced the ReLU modules with our own quantization modules, the name of the modules isn't changed). So, in all, we need to reference the first layer with parameters conv1 , the first activation layer relu1 , the last activation layer final_relu and the last layer with parameters fc . Specifying null means \"do not quantize\". Note that for quantizers, we reference names of modules, not names of parameters as we do for pruners and regularizers.", + "location": "/schedule/index.html#quantization", + "text": "Similarly to pruners and regularizers, specifying a quantizer in the scheduler YAML follows the constructor arguments of the Quantizer class (see details here ). Note that only a single quantizer instance may be defined per YAML. Let's see an example: quantizers:\n dorefa_quantizer:\n class: DorefaQuantizer\n bits_activations: 8\n bits_weights: 4\n bits_overrides:\n conv1:\n wts: null\n acts: null\n relu1:\n wts: null\n acts: null\n final_relu:\n wts: null\n acts: null\n fc:\n wts: null\n acts: null The specific quantization method we're instantiating here is DorefaQuantizer . Then we define the default bit-widths for activations and weights, in this case 8 and 4-bits, respectively. Then, we define the bits_overrides mapping. In the example above, we choose not to quantize the first and last layer of the model. In the case of DorefaQuantizer , the weights are quantized as part of the convolution / FC layers, but the activations are quantized in separate layers, which replace the ReLU layers in the original model (remember - even though we replaced the ReLU modules with our own quantization modules, the name of the modules isn't changed). So, in all, we need to reference the first layer with parameters conv1 , the first activation layer relu1 , the last activation layer final_relu and the last layer with parameters fc . Specifying null means \"do not quantize\". Note that for quantizers, we reference names of modules, not names of parameters as we do for pruners and regularizers.", "title": "Quantization" - }, + }, { - "location": "/schedule/index.html#defining-overrides-for-groups-of-layers-using-regular-expressions", - "text": "Suppose we have a sub-module in our model named block1 , which contains multiple convolution layers which we would like to quantize to, say, 2-bits. The convolution layers are named conv1 , conv2 and so on. In that case we would define the following: bits_overrides:\n 'block1\\.conv*':\n wts: 2\n acts: null RegEx Note : Remember that the dot ( . ) is a meta-character (i.e. a reserved character) in regular expressions. So, to match the actual dot characters which separate sub-modules in PyTorch module names, we need to escape it: \\. 
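To make the matching semantics concrete, here is a small, self-contained illustration (plain re plus a hypothetical lookup helper, not Distiller's internal code) of how override patterns with the escaped dot can be resolved against module names, including the first-match-wins ordering discussed just below:

```
import re
from collections import OrderedDict

# Patterns are kept in order; the first pattern that matches a module name wins.
bits_overrides = OrderedDict([
    (r'block1\.conv1', {'wts': 4, 'acts': None}),  # specific pattern listed first
    (r'block1\.conv*', {'wts': 2, 'acts': None}),  # broader pattern listed after it
])

def lookup_override(module_name, overrides):
    for pattern, bits in overrides.items():
        if re.match(pattern, module_name):
            return bits
    return None  # no override: fall back to the default bit-widths

print(lookup_override('block1.conv1', bits_overrides))  # {'wts': 4, 'acts': None}
print(lookup_override('block1.conv2', bits_overrides))  # {'wts': 2, 'acts': None}
print(lookup_override('block2.conv1', bits_overrides))  # None
```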
Overlapping patterns are also possible, which allows to define some override for a groups of layers and also \"single-out\" specific layers for different overrides. For example, let's take the last example and configure a different override for block1.conv1 : bits_overrides:\n 'block1\\.conv1':\n wts: 4\n acts: null\n 'block1\\.conv*':\n wts: 2\n acts: null Important Note : The patterns are evaluated eagerly - first match wins. So, to properly quantize a model using \"broad\" patterns and more \"specific\" patterns as just shown, make sure the specific pattern is listed before the broad one. The QuantizationPolicy , which controls the quantization procedure during training, is actually quite simplistic. All it does is call the prepare_model() function of the Quantizer when it's initialized, followed by the first call to quantize_params() . Then, at the end of each epoch, after the float copy of the weights has been updated, it calls the quantize_params() function again. policies:\n - quantizer:\n instance_name: dorefa_quantizer\n starting_epoch: 0\n ending_epoch: 200\n frequency: 1 Important Note : As mentioned here , since the quantizer modifies the model's parameters (assuming training with quantization in the loop is used), the call to prepare_model() must be performed before an optimizer is called. Therefore, currently, the starting epoch for a quantization policy must be 0, otherwise the quantization process will not work as expected. If one wishes to do a \"warm-startup\" (or \"boot-strapping\"), training for a few epochs with full precision and only then starting to quantize, the only way to do this right now is to execute a separate run to generate the boot-strapped weights, and execute a second which will resume the checkpoint with the boot-strapped weights.", + "location": "/schedule/index.html#defining-overrides-for-groups-of-layers-using-regular-expressions", + "text": "Suppose we have a sub-module in our model named block1 , which contains multiple convolution layers which we would like to quantize to, say, 2-bits. The convolution layers are named conv1 , conv2 and so on. In that case we would define the following: bits_overrides:\n 'block1\\.conv*':\n wts: 2\n acts: null RegEx Note : Remember that the dot ( . ) is a meta-character (i.e. a reserved character) in regular expressions. So, to match the actual dot characters which separate sub-modules in PyTorch module names, we need to escape it: \\. Overlapping patterns are also possible, which allows to define some override for a groups of layers and also \"single-out\" specific layers for different overrides. For example, let's take the last example and configure a different override for block1.conv1 : bits_overrides:\n 'block1\\.conv1':\n wts: 4\n acts: null\n 'block1\\.conv*':\n wts: 2\n acts: null Important Note : The patterns are evaluated eagerly - first match wins. So, to properly quantize a model using \"broad\" patterns and more \"specific\" patterns as just shown, make sure the specific pattern is listed before the broad one. The QuantizationPolicy , which controls the quantization procedure during training, is actually quite simplistic. All it does is call the prepare_model() function of the Quantizer when it's initialized, followed by the first call to quantize_params() . Then, at the end of each epoch, after the float copy of the weights has been updated, it calls the quantize_params() function again. 
policies:\n - quantizer:\n instance_name: dorefa_quantizer\n starting_epoch: 0\n ending_epoch: 200\n frequency: 1 Important Note : As mentioned here , since the quantizer modifies the model's parameters (assuming training with quantization in the loop is used), the call to prepare_model() must be performed before an optimizer is called. Therefore, currently, the starting epoch for a quantization policy must be 0, otherwise the quantization process will not work as expected. If one wishes to do a \"warm-startup\" (or \"boot-strapping\"), training for a few epochs with full precision and only then starting to quantize, the only way to do this right now is to execute a separate run to generate the boot-strapped weights, and execute a second which will resume the checkpoint with the boot-strapped weights.", "title": "Defining overrides for groups of layers using regular expressions" - }, + }, { - "location": "/schedule/index.html#knowledge-distillation", - "text": "Knowledge distillation (see here ) is also implemented as a Policy , which should be added to the scheduler. However, with the current implementation, it cannot be defined within the YAML file like the rest of the policies described above. To make the integration of this method into applications a bit easier, a helper function can be used that will add a set of command-line arguments related to knowledge distillation: import argparse\nimport distiller\n\nparser = argparse.ArgumentParser()\ndistiller.knowledge_distillation.add_distillation_args(parser) (The add_distillation_args function accepts some optional arguments, see its implementation at distiller/knowledge_distillation.py for details) These are the command line arguments exposed by this function: Knowledge Distillation Training Arguments:\n --kd-teacher ARCH Model architecture for teacher model\n --kd-pretrained Use pre-trained model for teacher\n --kd-resume PATH Path to checkpoint from which to load teacher weights\n --kd-temperature TEMP, --kd-temp TEMP\n Knowledge distillation softmax temperature\n --kd-distill-wt WEIGHT, --kd-dw WEIGHT\n Weight for distillation loss (student vs. teacher soft\n targets)\n --kd-student-wt WEIGHT, --kd-sw WEIGHT\n Weight for student vs. labels loss\n --kd-teacher-wt WEIGHT, --kd-tw WEIGHT\n Weight for teacher vs. labels loss\n --kd-start-epoch EPOCH_NUM\n Epoch from which to enable distillation Once arguments have been parsed, some initialization code is required, similar to the following: # Assuming:\n# \"args\" variable holds command line arguments\n# \"model\" variable holds the model we're going to train, that is - the student model\n# \"compression_scheduler\" variable holds a CompressionScheduler instance\n\nargs.kd_policy = None\nif args.kd_teacher:\n # Create teacher model - replace this with your model creation code\n teacher = create_model(args.kd_pretrained, args.dataset, args.kd_teacher, device_ids=args.gpus)\n if args.kd_resume:\n teacher, _, _ = apputils.load_checkpoint(teacher, chkpt_file=args.kd_resume)\n\n # Create policy and add to scheduler\n dlw = distiller.DistillationLossWeights(args.kd_distill_wt, args.kd_student_wt, args.kd_teacher_wt)\n args.kd_policy = distiller.KnowledgeDistillationPolicy(model, teacher, args.kd_temp, dlw)\n compression_scheduler.add_policy(args.kd_policy, starting_epoch=args.kd_start_epoch, ending_epoch=args.epochs,\n frequency=1) Finally, during the training loop, we need to perform forward propagation through the teacher model as well. 
The KnowledgeDistillationPolicy class keeps a reference to both the student and teacher models, and exposes a forward function that performs forward propagation on both of them. Since this is not one of the standard policy callbacks, we need to call this function manually from our training loop, as follows: if args.kd_policy is None:\n # Revert to a \"normal\" forward-prop call if no knowledge distillation policy is present\n output = model(input_var)\nelse:\n output = args.kd_policy.forward(input_var) To see this integration in action, take a look at the image classification sample at examples/classifier_compression/compress_classifier.py .", + "location": "/schedule/index.html#knowledge-distillation", + "text": "Knowledge distillation (see here ) is also implemented as a Policy , which should be added to the scheduler. However, with the current implementation, it cannot be defined within the YAML file like the rest of the policies described above. To make the integration of this method into applications a bit easier, a helper function can be used that will add a set of command-line arguments related to knowledge distillation: import argparse\nimport distiller\n\nparser = argparse.ArgumentParser()\ndistiller.knowledge_distillation.add_distillation_args(parser) (The add_distillation_args function accepts some optional arguments, see its implementation at distiller/knowledge_distillation.py for details) These are the command line arguments exposed by this function: Knowledge Distillation Training Arguments:\n --kd-teacher ARCH Model architecture for teacher model\n --kd-pretrained Use pre-trained model for teacher\n --kd-resume PATH Path to checkpoint from which to load teacher weights\n --kd-temperature TEMP, --kd-temp TEMP\n Knowledge distillation softmax temperature\n --kd-distill-wt WEIGHT, --kd-dw WEIGHT\n Weight for distillation loss (student vs. teacher soft\n targets)\n --kd-student-wt WEIGHT, --kd-sw WEIGHT\n Weight for student vs. labels loss\n --kd-teacher-wt WEIGHT, --kd-tw WEIGHT\n Weight for teacher vs. labels loss\n --kd-start-epoch EPOCH_NUM\n Epoch from which to enable distillation Once arguments have been parsed, some initialization code is required, similar to the following: # Assuming:\n# args variable holds command line arguments\n# model variable holds the model we're going to train, that is - the student model\n# compression_scheduler variable holds a CompressionScheduler instance\n\nargs.kd_policy = None\nif args.kd_teacher:\n # Create teacher model - replace this with your model creation code\n teacher = create_model(args.kd_pretrained, args.dataset, args.kd_teacher, device_ids=args.gpus)\n if args.kd_resume:\n teacher, _, _ = apputils.load_checkpoint(teacher, chkpt_file=args.kd_resume)\n\n # Create policy and add to scheduler\n dlw = distiller.DistillationLossWeights(args.kd_distill_wt, args.kd_student_wt, args.kd_teacher_wt)\n args.kd_policy = distiller.KnowledgeDistillationPolicy(model, teacher, args.kd_temp, dlw)\n compression_scheduler.add_policy(args.kd_policy, starting_epoch=args.kd_start_epoch, ending_epoch=args.epochs,\n frequency=1) Finally, during the training loop, we need to perform forward propagation through the teacher model as well. The KnowledgeDistillationPolicy class keeps a reference to both the student and teacher models, and exposes a forward function that performs forward propagation on both of them. 
Since this is not one of the standard policy callbacks, we need to call this function manually from our training loop, as follows: if args.kd_policy is None:\n # Revert to a normal forward-prop call if no knowledge distillation policy is present\n output = model(input_var)\nelse:\n output = args.kd_policy.forward(input_var) To see this integration in action, take a look at the image classification sample at examples/classifier_compression/compress_classifier.py .", "title": "Knowledge Distillation" - }, + }, { - "location": "/pruning/index.html", - "text": "Pruning\n\n\nA common methodology for inducing sparsity in weights and activations is called \npruning\n. Pruning is the application of a binary criteria to decide which weights to prune: weights which match the pruning criteria are assigned a value of zero. Pruned elements are \"trimmed\" from the model: we zero their values and also make sure they don't take part in the back-propagation process.\n\n\nWe can prune weights, biases, and activations. Biases are few and their contribution to a layer's output is relatively large, so there is little incentive to prune them. We usually see sparse activations following a ReLU layer, because ReLU quenches negative activations to exact zero (\\(ReLU(x): max(0,x)\\)). Sparsity in weights is less common, as weights tend to be very small, but are often not exact zeros.\n\n\n\nLet's define sparsity\n\n\nSparsity is a a measure of how many elements in a tensor are exact zeros, relative to the tensor size. A tensor is considered sparse if \"most\" of its elements are zero. How much is \"most\", is not strictly defined, but when you see a sparse tensor you know it ;-)\n\nThe \n\\(l_0\\)-\"norm\" function\n measures how many zero-elements are in a tensor \nx\n:\n\\[\\lVert x \\rVert_0\\;=\\;|x_1|^0 + |x_2|^0 + ... + |x_n|^0 \\]\nIn other words, an element contributes either a value of 1 or 0 to \\(l_0\\). Anything but an exact zero contributes a value of 1 - that's pretty cool.\n\nSometimes it helps to think about density, the number of non-zero elements (NNZ) and sparsity's complement:\n\\[\ndensity = 1 - sparsity\n\\]\nYou can use \ndistiller.sparsity\n and \ndistiller.density\n to query a PyTorch tensor's sparsity and density.\n\n\nWhat is weights pruning?\n\n\nWeights pruning, or model pruning, is a set of methods to increase the sparsity (amount of zero-valued elements in a tensor) of a network's weights. In general, the term 'parameters' refers to both weights and bias tensors of a model. Biases are rarely, if ever, pruned because there are very few bias elements compared to weights elements, and it is just not worth the trouble.\n\n\nPruning requires a criteria for choosing which elements to prune - this is called the \npruning criteria\n. The most common pruning criteria is the absolute value of each element: the element's absolute value is compared to some threshold value, and if it is below the threshold the element is set to zero (i.e. pruned) . This is implemented by the \ndistiller.MagnitudeParameterPruner\n class. The idea behind this method, is that weights with small \\(l_1\\)-norms (absolute value) contribute little to the final result (low saliency), so they are less important and can be removed.\n\n\nA related idea motivating pruning, is that models are over-parametrized and contain redundant logic and features. 
Therefore, some of these redundancies can be removed by setting their weights to zero.\n\n\nAnd yet another way to think of pruning is to phrase it as a search for a set of weights with as many zeros as possible, which still produces acceptable inference accuracies compared to the dense-model (non-pruned model). Another way to look at it, is to imagine that because of the very high-dimensionality of the parameter space, the immediate space around the dense-model's solution likely contains some sparse solutions, and we want to use find these sparse solutions. \n\n\n\n\nPruning schedule\n\n\nThe most straight-forward to prune is to take a trained model and prune it once; also called \none-shot pruning\n. In \nLearning both Weights and Connections for Efficient Neural Networks\n Song Han et. al show that this is surprisingly effective, but also leaves a lot of potential sparsity untapped. The surprise is what they call the \"free lunch\" effect: \n\"reducing 2x the connections without losing accuracy even without retraining.\"\n\nHowever, they also note that when employing a pruning-followed-by-retraining regimen, they can achieve much better results (higher sparsity at no accuracy loss). This is called \niterative pruning\n, and the retraining that follows pruning is often referred to as \nfine-tuning\n. How the pruning criteria changes between iterations, how many iterations we perform and how often, and which tensors are pruned - this is collectively called the \npruning schedule\n.\n\n\nWe can think of iterative pruning as repeatedly learning which weights are important, removing the least important ones based on some importance criteria, and then retraining the model to let it \"recover\" from the pruning by adjusting the remaining weights. At each iteration, we prune more weights.\n\nThe decision of when to stop pruning is also expressed in the schedule, and it depends on the pruning algorithm. For example, if we are trying to achieve a specific sparsity level, then we stop when the pruning achieves that level. And if we are pruning weights structures in order to reduce the required compute budget, then we stop the pruning when this compute reduction is achieved.\n\n\nDistiller supports expressing the pruning schedule as a YAML file (which is then executed by an instance of a PruningScheduler).\n\n\nPruning granularity\n\n\nPruning individual weight elements is called \nelement-wise pruning\n, and it is also sometimes referred to as \nfine-grained\n pruning.\n\n\nCoarse-grained pruning\n - also referred to as \nstructured pruning\n, \ngroup pruning\n, or \nblock pruning\n - is pruning entire groups of elements which have some significance. Groups come in various shapes and sizes, but an easy to visualize group-pruning is filter-pruning, in which entire filters are removed.\n\n\nSensitivity analysis\n\n\nThe hard part about inducing sparsity via pruning is determining what threshold, or sparsity level, to use for each layer's tensors. Sensitivity analysis is a method that tries to help us rank the tensors by their sensitivity to pruning. \n\nThe idea is to set the pruning level (percentage) of a specific layer, and then to prune once, run an evaluation on the test dataset and record the accuracy score. We do this for all of the parameterized layers, and for each layer we examine several sparsity levels. 
This should teach us about the \"sensitivity\" of each of the layers to pruning.\n\n\nThe evaluated model should be trained to maximum accuracy before running the analysis, because we aim to understand the behavior of the trained model's performance in relation to pruning of a specific weights tensor.\n\n\nMuch as we can prune structures, we can also perform sensitivity analysis on structures. Distiller implements element-wise pruning sensitivity analysis using the \\(l_1\\)-norm of individual elements; and filter-wise pruning sensitivity analysis using the mean \\(l_1\\)-norm of filters.\n\n\n\nThe authors of \nPruning Filters for Efficient ConvNets\n describe how they do sensitivity analysis:\n\n\n\n\n\"To understand the sensitivity of each layer, we prune each layer independently and evaluate the resulting pruned network\u2019s accuracy on the validation set. Figure 2(b) shows that layers that maintain their accuracy as filters are pruned away correspond to layers with larger slopes in Figure 2(a). On the contrary, layers with relatively flat slopes are more sensitive to pruning. We empirically determine the number of filters to prune for each layer based on their sensitivity to pruning. For deep networks such as VGG-16 or ResNets, we observe that layers in the same stage (with the same feature map size) have a similar sensitivity to pruning. To avoid introducing layer-wise meta-parameters, we use the same pruning ratio for all layers in the same stage. For layers that are sensitive to pruning, we prune a smaller percentage of these layers or completely skip pruning them.\"\n\n\n\n\nThe diagram below shows the results of running an element-wise sensitivity analysis on Alexnet, using Distillers's \nperform_sensitivity_analysis\n utility function.\n\n\nAs reported by Song Han, and exhibited in the diagram, in Alexnet the feature detecting layers (convolution layers) are more sensitive to pruning, and their sensitivity drops, the deeper they are. The fully-connected layers are much less sensitive, which is great, because that's where most of the parameters are.\n\n\n\n\nReferences\n\n\n \nSong Han, Jeff Pool, John Tran, William J. Dally\n.\n \nLearning both Weights and Connections for Efficient Neural Networks\n,\n arXiv:1607.04381v2,\n 2015.\n\n\n\n\n\nHao Li, Asim Kadav, Igor Durdanovic, Hanan Samet, Hans Peter Graf\n.\n \nPruning Filters for Efficient ConvNets\n,\n arXiv:1608.08710v3,\n 2017.", + "location": "/pruning/index.html", + "text": "Pruning\n\n\nA common methodology for inducing sparsity in weights and activations is called \npruning\n. Pruning is the application of a binary criteria to decide which weights to prune: weights which match the pruning criteria are assigned a value of zero. Pruned elements are \"trimmed\" from the model: we zero their values and also make sure they don't take part in the back-propagation process.\n\n\nWe can prune weights, biases, and activations. Biases are few and their contribution to a layer's output is relatively large, so there is little incentive to prune them. We usually see sparse activations following a ReLU layer, because ReLU quenches negative activations to exact zero (\\(ReLU(x): max(0,x)\\)). Sparsity in weights is less common, as weights tend to be very small, but are often not exact zeros.\n\n\n\nLet's define sparsity\n\n\nSparsity is a a measure of how many elements in a tensor are exact zeros, relative to the tensor size. A tensor is considered sparse if \"most\" of its elements are zero. 
How much is \"most\", is not strictly defined, but when you see a sparse tensor you know it ;-)\n\nThe \n\\(l_0\\)-\"norm\" function\n measures how many zero-elements are in a tensor \nx\n:\n\\[\\lVert x \\rVert_0\\;=\\;|x_1|^0 + |x_2|^0 + ... + |x_n|^0 \\]\nIn other words, an element contributes either a value of 1 or 0 to \\(l_0\\). Anything but an exact zero contributes a value of 1 - that's pretty cool.\n\nSometimes it helps to think about density, the number of non-zero elements (NNZ) and sparsity's complement:\n\\[\ndensity = 1 - sparsity\n\\]\nYou can use \ndistiller.sparsity\n and \ndistiller.density\n to query a PyTorch tensor's sparsity and density.\n\n\nWhat is weights pruning?\n\n\nWeights pruning, or model pruning, is a set of methods to increase the sparsity (amount of zero-valued elements in a tensor) of a network's weights. In general, the term 'parameters' refers to both weights and bias tensors of a model. Biases are rarely, if ever, pruned because there are very few bias elements compared to weights elements, and it is just not worth the trouble.\n\n\nPruning requires a criteria for choosing which elements to prune - this is called the \npruning criteria\n. The most common pruning criteria is the absolute value of each element: the element's absolute value is compared to some threshold value, and if it is below the threshold the element is set to zero (i.e. pruned) . This is implemented by the \ndistiller.MagnitudeParameterPruner\n class. The idea behind this method, is that weights with small \\(l_1\\)-norms (absolute value) contribute little to the final result (low saliency), so they are less important and can be removed.\n\n\nA related idea motivating pruning, is that models are over-parametrized and contain redundant logic and features. Therefore, some of these redundancies can be removed by setting their weights to zero.\n\n\nAnd yet another way to think of pruning is to phrase it as a search for a set of weights with as many zeros as possible, which still produces acceptable inference accuracies compared to the dense-model (non-pruned model). Another way to look at it, is to imagine that because of the very high-dimensionality of the parameter space, the immediate space around the dense-model's solution likely contains some sparse solutions, and we want to use find these sparse solutions. \n\n\n\n\nPruning schedule\n\n\nThe most straight-forward to prune is to take a trained model and prune it once; also called \none-shot pruning\n. In \nLearning both Weights and Connections for Efficient Neural Networks\n Song Han et. al show that this is surprisingly effective, but also leaves a lot of potential sparsity untapped. The surprise is what they call the \"free lunch\" effect: \n\"reducing 2x the connections without losing accuracy even without retraining.\"\n\nHowever, they also note that when employing a pruning-followed-by-retraining regimen, they can achieve much better results (higher sparsity at no accuracy loss). This is called \niterative pruning\n, and the retraining that follows pruning is often referred to as \nfine-tuning\n. 
How the pruning criteria changes between iterations, how many iterations we perform and how often, and which tensors are pruned - this is collectively called the \npruning schedule\n.\n\n\nWe can think of iterative pruning as repeatedly learning which weights are important, removing the least important ones based on some importance criteria, and then retraining the model to let it \"recover\" from the pruning by adjusting the remaining weights. At each iteration, we prune more weights.\n\nThe decision of when to stop pruning is also expressed in the schedule, and it depends on the pruning algorithm. For example, if we are trying to achieve a specific sparsity level, then we stop when the pruning achieves that level. And if we are pruning weights structures in order to reduce the required compute budget, then we stop the pruning when this compute reduction is achieved.\n\n\nDistiller supports expressing the pruning schedule as a YAML file (which is then executed by an instance of a PruningScheduler).\n\n\nPruning granularity\n\n\nPruning individual weight elements is called \nelement-wise pruning\n, and it is also sometimes referred to as \nfine-grained\n pruning.\n\n\nCoarse-grained pruning\n - also referred to as \nstructured pruning\n, \ngroup pruning\n, or \nblock pruning\n - is pruning entire groups of elements which have some significance. Groups come in various shapes and sizes, but an easy to visualize group-pruning is filter-pruning, in which entire filters are removed.\n\n\nSensitivity analysis\n\n\nThe hard part about inducing sparsity via pruning is determining what threshold, or sparsity level, to use for each layer's tensors. Sensitivity analysis is a method that tries to help us rank the tensors by their sensitivity to pruning. \n\nThe idea is to set the pruning level (percentage) of a specific layer, and then to prune once, run an evaluation on the test dataset and record the accuracy score. We do this for all of the parameterized layers, and for each layer we examine several sparsity levels. This should teach us about the \"sensitivity\" of each of the layers to pruning.\n\n\nThe evaluated model should be trained to maximum accuracy before running the analysis, because we aim to understand the behavior of the trained model's performance in relation to pruning of a specific weights tensor.\n\n\nMuch as we can prune structures, we can also perform sensitivity analysis on structures. Distiller implements element-wise pruning sensitivity analysis using the \\(l_1\\)-norm of individual elements; and filter-wise pruning sensitivity analysis using the mean \\(l_1\\)-norm of filters.\n\n\n\nThe authors of \nPruning Filters for Efficient ConvNets\n describe how they do sensitivity analysis:\n\n\n\n\n\"To understand the sensitivity of each layer, we prune each layer independently and evaluate the resulting pruned network\u2019s accuracy on the validation set. Figure 2(b) shows that layers that maintain their accuracy as filters are pruned away correspond to layers with larger slopes in Figure 2(a). On the contrary, layers with relatively flat slopes are more sensitive to pruning. We empirically determine the number of filters to prune for each layer based on their sensitivity to pruning. For deep networks such as VGG-16 or ResNets, we observe that layers in the same stage (with the same feature map size) have a similar sensitivity to pruning. To avoid introducing layer-wise meta-parameters, we use the same pruning ratio for all layers in the same stage. 
For layers that are sensitive to pruning, we prune a smaller percentage of these layers or completely skip pruning them.\"\n\n\n\n\nThe diagram below shows the results of running an element-wise sensitivity analysis on Alexnet, using Distillers's \nperform_sensitivity_analysis\n utility function.\n\n\nAs reported by Song Han, and exhibited in the diagram, in Alexnet the feature detecting layers (convolution layers) are more sensitive to pruning, and their sensitivity drops, the deeper they are. The fully-connected layers are much less sensitive, which is great, because that's where most of the parameters are.\n\n\n\n\nReferences\n\n\n \nSong Han, Jeff Pool, John Tran, William J. Dally\n.\n \nLearning both Weights and Connections for Efficient Neural Networks\n,\n arXiv:1607.04381v2,\n 2015.\n\n\n\n\n\nHao Li, Asim Kadav, Igor Durdanovic, Hanan Samet, Hans Peter Graf\n.\n \nPruning Filters for Efficient ConvNets\n,\n arXiv:1608.08710v3,\n 2017.", "title": "Pruning" - }, + }, { - "location": "/pruning/index.html#pruning", - "text": "A common methodology for inducing sparsity in weights and activations is called pruning . Pruning is the application of a binary criteria to decide which weights to prune: weights which match the pruning criteria are assigned a value of zero. Pruned elements are \"trimmed\" from the model: we zero their values and also make sure they don't take part in the back-propagation process. We can prune weights, biases, and activations. Biases are few and their contribution to a layer's output is relatively large, so there is little incentive to prune them. We usually see sparse activations following a ReLU layer, because ReLU quenches negative activations to exact zero (\\(ReLU(x): max(0,x)\\)). Sparsity in weights is less common, as weights tend to be very small, but are often not exact zeros.", + "location": "/pruning/index.html#pruning", + "text": "A common methodology for inducing sparsity in weights and activations is called pruning . Pruning is the application of a binary criteria to decide which weights to prune: weights which match the pruning criteria are assigned a value of zero. Pruned elements are \"trimmed\" from the model: we zero their values and also make sure they don't take part in the back-propagation process. We can prune weights, biases, and activations. Biases are few and their contribution to a layer's output is relatively large, so there is little incentive to prune them. We usually see sparse activations following a ReLU layer, because ReLU quenches negative activations to exact zero (\\(ReLU(x): max(0,x)\\)). Sparsity in weights is less common, as weights tend to be very small, but are often not exact zeros.", "title": "Pruning" - }, + }, { - "location": "/pruning/index.html#lets-define-sparsity", - "text": "Sparsity is a a measure of how many elements in a tensor are exact zeros, relative to the tensor size. A tensor is considered sparse if \"most\" of its elements are zero. How much is \"most\", is not strictly defined, but when you see a sparse tensor you know it ;-) \nThe \\(l_0\\)-\"norm\" function measures how many zero-elements are in a tensor x :\n\\[\\lVert x \\rVert_0\\;=\\;|x_1|^0 + |x_2|^0 + ... + |x_n|^0 \\]\nIn other words, an element contributes either a value of 1 or 0 to \\(l_0\\). Anything but an exact zero contributes a value of 1 - that's pretty cool. 
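The sparsity and density definitions above boil down to a couple of lines of PyTorch; distiller.sparsity and distiller.density expose this, and a minimal equivalent (illustrative, not the library's code) is:

```
import torch

def sparsity(t):
    # Fraction of exact-zero elements in the tensor.
    return float((t == 0).sum()) / t.numel()

def density(t):
    # density = 1 - sparsity (fraction of non-zero elements).
    return 1.0 - sparsity(t)

w = torch.tensor([0.0, 0.5, 0.0, -1.2])
print(sparsity(w), density(w))  # 0.5 0.5
```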
\nSometimes it helps to think about density, the number of non-zero elements (NNZ) and sparsity's complement:\n\\[\ndensity = 1 - sparsity\n\\]\nYou can use distiller.sparsity and distiller.density to query a PyTorch tensor's sparsity and density.", + "location": "/pruning/index.html#lets-define-sparsity", + "text": "Sparsity is a a measure of how many elements in a tensor are exact zeros, relative to the tensor size. A tensor is considered sparse if \"most\" of its elements are zero. How much is \"most\", is not strictly defined, but when you see a sparse tensor you know it ;-) \nThe \\(l_0\\)-\"norm\" function measures how many zero-elements are in a tensor x :\n\\[\\lVert x \\rVert_0\\;=\\;|x_1|^0 + |x_2|^0 + ... + |x_n|^0 \\]\nIn other words, an element contributes either a value of 1 or 0 to \\(l_0\\). Anything but an exact zero contributes a value of 1 - that's pretty cool. \nSometimes it helps to think about density, the number of non-zero elements (NNZ) and sparsity's complement:\n\\[\ndensity = 1 - sparsity\n\\]\nYou can use distiller.sparsity and distiller.density to query a PyTorch tensor's sparsity and density.", "title": "Let's define sparsity" - }, + }, { - "location": "/pruning/index.html#what-is-weights-pruning", - "text": "Weights pruning, or model pruning, is a set of methods to increase the sparsity (amount of zero-valued elements in a tensor) of a network's weights. In general, the term 'parameters' refers to both weights and bias tensors of a model. Biases are rarely, if ever, pruned because there are very few bias elements compared to weights elements, and it is just not worth the trouble. \nPruning requires a criteria for choosing which elements to prune - this is called the pruning criteria . The most common pruning criteria is the absolute value of each element: the element's absolute value is compared to some threshold value, and if it is below the threshold the element is set to zero (i.e. pruned) . This is implemented by the distiller.MagnitudeParameterPruner class. The idea behind this method, is that weights with small \\(l_1\\)-norms (absolute value) contribute little to the final result (low saliency), so they are less important and can be removed. \nA related idea motivating pruning, is that models are over-parametrized and contain redundant logic and features. Therefore, some of these redundancies can be removed by setting their weights to zero. \nAnd yet another way to think of pruning is to phrase it as a search for a set of weights with as many zeros as possible, which still produces acceptable inference accuracies compared to the dense-model (non-pruned model). Another way to look at it, is to imagine that because of the very high-dimensionality of the parameter space, the immediate space around the dense-model's solution likely contains some sparse solutions, and we want to use find these sparse solutions.", + "location": "/pruning/index.html#what-is-weights-pruning", + "text": "Weights pruning, or model pruning, is a set of methods to increase the sparsity (amount of zero-valued elements in a tensor) of a network's weights. In general, the term 'parameters' refers to both weights and bias tensors of a model. Biases are rarely, if ever, pruned because there are very few bias elements compared to weights elements, and it is just not worth the trouble. \nPruning requires a criteria for choosing which elements to prune - this is called the pruning criteria . 
The most common pruning criteria is the absolute value of each element: the element's absolute value is compared to some threshold value, and if it is below the threshold the element is set to zero (i.e. pruned) . This is implemented by the distiller.MagnitudeParameterPruner class. The idea behind this method, is that weights with small \\(l_1\\)-norms (absolute value) contribute little to the final result (low saliency), so they are less important and can be removed. \nA related idea motivating pruning, is that models are over-parametrized and contain redundant logic and features. Therefore, some of these redundancies can be removed by setting their weights to zero. \nAnd yet another way to think of pruning is to phrase it as a search for a set of weights with as many zeros as possible, which still produces acceptable inference accuracies compared to the dense-model (non-pruned model). Another way to look at it, is to imagine that because of the very high-dimensionality of the parameter space, the immediate space around the dense-model's solution likely contains some sparse solutions, and we want to use find these sparse solutions.", "title": "What is weights pruning?" - }, + }, { - "location": "/pruning/index.html#pruning-schedule", - "text": "The most straight-forward to prune is to take a trained model and prune it once; also called one-shot pruning . In Learning both Weights and Connections for Efficient Neural Networks Song Han et. al show that this is surprisingly effective, but also leaves a lot of potential sparsity untapped. The surprise is what they call the \"free lunch\" effect: \"reducing 2x the connections without losing accuracy even without retraining.\" \nHowever, they also note that when employing a pruning-followed-by-retraining regimen, they can achieve much better results (higher sparsity at no accuracy loss). This is called iterative pruning , and the retraining that follows pruning is often referred to as fine-tuning . How the pruning criteria changes between iterations, how many iterations we perform and how often, and which tensors are pruned - this is collectively called the pruning schedule . \nWe can think of iterative pruning as repeatedly learning which weights are important, removing the least important ones based on some importance criteria, and then retraining the model to let it \"recover\" from the pruning by adjusting the remaining weights. At each iteration, we prune more weights. \nThe decision of when to stop pruning is also expressed in the schedule, and it depends on the pruning algorithm. For example, if we are trying to achieve a specific sparsity level, then we stop when the pruning achieves that level. And if we are pruning weights structures in order to reduce the required compute budget, then we stop the pruning when this compute reduction is achieved. \nDistiller supports expressing the pruning schedule as a YAML file (which is then executed by an instance of a PruningScheduler).", + "location": "/pruning/index.html#pruning-schedule", + "text": "The most straight-forward to prune is to take a trained model and prune it once; also called one-shot pruning . In Learning both Weights and Connections for Efficient Neural Networks Song Han et. al show that this is surprisingly effective, but also leaves a lot of potential sparsity untapped. 
The surprise is what they call the \"free lunch\" effect: \"reducing 2x the connections without losing accuracy even without retraining.\" \nHowever, they also note that when employing a pruning-followed-by-retraining regimen, they can achieve much better results (higher sparsity at no accuracy loss). This is called iterative pruning , and the retraining that follows pruning is often referred to as fine-tuning . How the pruning criteria changes between iterations, how many iterations we perform and how often, and which tensors are pruned - this is collectively called the pruning schedule . \nWe can think of iterative pruning as repeatedly learning which weights are important, removing the least important ones based on some importance criteria, and then retraining the model to let it \"recover\" from the pruning by adjusting the remaining weights. At each iteration, we prune more weights. \nThe decision of when to stop pruning is also expressed in the schedule, and it depends on the pruning algorithm. For example, if we are trying to achieve a specific sparsity level, then we stop when the pruning achieves that level. And if we are pruning weights structures in order to reduce the required compute budget, then we stop the pruning when this compute reduction is achieved. \nDistiller supports expressing the pruning schedule as a YAML file (which is then executed by an instance of a PruningScheduler).", "title": "Pruning schedule" - }, + }, { - "location": "/pruning/index.html#pruning-granularity", - "text": "Pruning individual weight elements is called element-wise pruning , and it is also sometimes referred to as fine-grained pruning. Coarse-grained pruning - also referred to as structured pruning , group pruning , or block pruning - is pruning entire groups of elements which have some significance. Groups come in various shapes and sizes, but an easy to visualize group-pruning is filter-pruning, in which entire filters are removed.", + "location": "/pruning/index.html#pruning-granularity", + "text": "Pruning individual weight elements is called element-wise pruning , and it is also sometimes referred to as fine-grained pruning. Coarse-grained pruning - also referred to as structured pruning , group pruning , or block pruning - is pruning entire groups of elements which have some significance. Groups come in various shapes and sizes, but an easy to visualize group-pruning is filter-pruning, in which entire filters are removed.", "title": "Pruning granularity" - }, + }, { - "location": "/pruning/index.html#sensitivity-analysis", - "text": "The hard part about inducing sparsity via pruning is determining what threshold, or sparsity level, to use for each layer's tensors. Sensitivity analysis is a method that tries to help us rank the tensors by their sensitivity to pruning. \nThe idea is to set the pruning level (percentage) of a specific layer, and then to prune once, run an evaluation on the test dataset and record the accuracy score. We do this for all of the parameterized layers, and for each layer we examine several sparsity levels. This should teach us about the \"sensitivity\" of each of the layers to pruning. \nThe evaluated model should be trained to maximum accuracy before running the analysis, because we aim to understand the behavior of the trained model's performance in relation to pruning of a specific weights tensor. \nMuch as we can prune structures, we can also perform sensitivity analysis on structures. 
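The procedure just described - prune a single layer to a given level, evaluate, record, and repeat over several levels and layers - is simple to sketch. The snippet below is only the gist of it (set_sparsity_level and evaluate are hypothetical helpers; Distiller's actual utility is the perform_sensitivity_analysis function mentioned in this section):

```
import copy

levels = [0.1, 0.3, 0.5, 0.7, 0.9]  # sparsity levels to examine per layer
sensitivities = {}

# Assuming: "model" is the fully-trained model under analysis.
for name, param in model.named_parameters():
    if 'weight' not in name:
        continue
    sensitivities[name] = {}
    for level in levels:
        pruned = copy.deepcopy(model)            # always start from the trained model
        set_sparsity_level(pruned, name, level)  # hypothetical: magnitude-prune only this tensor
        sensitivities[name][level] = evaluate(pruned)  # hypothetical: top-1 accuracy on the test set
```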
Distiller implements element-wise pruning sensitivity analysis using the \\(l_1\\)-norm of individual elements; and filter-wise pruning sensitivity analysis using the mean \\(l_1\\)-norm of filters. The authors of Pruning Filters for Efficient ConvNets describe how they do sensitivity analysis: \"To understand the sensitivity of each layer, we prune each layer independently and evaluate the resulting pruned network\u2019s accuracy on the validation set. Figure 2(b) shows that layers that maintain their accuracy as filters are pruned away correspond to layers with larger slopes in Figure 2(a). On the contrary, layers with relatively flat slopes are more sensitive to pruning. We empirically determine the number of filters to prune for each layer based on their sensitivity to pruning. For deep networks such as VGG-16 or ResNets, we observe that layers in the same stage (with the same feature map size) have a similar sensitivity to pruning. To avoid introducing layer-wise meta-parameters, we use the same pruning ratio for all layers in the same stage. For layers that are sensitive to pruning, we prune a smaller percentage of these layers or completely skip pruning them.\" The diagram below shows the results of running an element-wise sensitivity analysis on Alexnet, using Distillers's perform_sensitivity_analysis utility function. \nAs reported by Song Han, and exhibited in the diagram, in Alexnet the feature detecting layers (convolution layers) are more sensitive to pruning, and their sensitivity drops, the deeper they are. The fully-connected layers are much less sensitive, which is great, because that's where most of the parameters are.", + "location": "/pruning/index.html#sensitivity-analysis", + "text": "The hard part about inducing sparsity via pruning is determining what threshold, or sparsity level, to use for each layer's tensors. Sensitivity analysis is a method that tries to help us rank the tensors by their sensitivity to pruning. \nThe idea is to set the pruning level (percentage) of a specific layer, and then to prune once, run an evaluation on the test dataset and record the accuracy score. We do this for all of the parameterized layers, and for each layer we examine several sparsity levels. This should teach us about the \"sensitivity\" of each of the layers to pruning. \nThe evaluated model should be trained to maximum accuracy before running the analysis, because we aim to understand the behavior of the trained model's performance in relation to pruning of a specific weights tensor. \nMuch as we can prune structures, we can also perform sensitivity analysis on structures. Distiller implements element-wise pruning sensitivity analysis using the \\(l_1\\)-norm of individual elements; and filter-wise pruning sensitivity analysis using the mean \\(l_1\\)-norm of filters. The authors of Pruning Filters for Efficient ConvNets describe how they do sensitivity analysis: \"To understand the sensitivity of each layer, we prune each layer independently and evaluate the resulting pruned network\u2019s accuracy on the validation set. Figure 2(b) shows that layers that maintain their accuracy as filters are pruned away correspond to layers with larger slopes in Figure 2(a). On the contrary, layers with relatively flat slopes are more sensitive to pruning. We empirically determine the number of filters to prune for each layer based on their sensitivity to pruning. 
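A hedged sketch of the per-layer sensitivity scan described above: prune one weight tensor of a copy of the trained model to several sparsity levels and record the evaluation score each time. The `evaluate` callback is an assumption (any function returning, say, top-1 accuracy on the test set); this is not the signature of Distiller's perform_sensitivity_analysis utility, only the same procedure.

```
import copy
import torch
import torch.nn as nn

def sensitivity_scan(model: nn.Module, evaluate, sparsities=(0.1, 0.3, 0.5, 0.7, 0.9)):
    """For each weight tensor, prune a copy of the trained model to several
    sparsity levels and record the evaluation score at each level.

    `evaluate` is an assumed callback taking a model and returning a scalar
    score (e.g. top-1 accuracy on the test set).
    """
    results = {}
    weight_names = [n for n, _ in model.named_parameters() if n.endswith("weight")]
    for name in weight_names:
        results[name] = {}
        for sparsity in sparsities:
            pruned = copy.deepcopy(model)               # prune one layer at a time
            param = dict(pruned.named_parameters())[name]
            with torch.no_grad():
                k = int(sparsity * param.numel())
                if k > 0:
                    thresh = param.abs().flatten().kthvalue(k).values
                    param.mul_(param.abs().gt(thresh))
            results[name][sparsity] = evaluate(pruned)  # record the accuracy at this level
    return results
```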
For deep networks such as VGG-16 or ResNets, we observe that layers in the same stage (with the same feature map size) have a similar sensitivity to pruning. To avoid introducing layer-wise meta-parameters, we use the same pruning ratio for all layers in the same stage. For layers that are sensitive to pruning, we prune a smaller percentage of these layers or completely skip pruning them.\" The diagram below shows the results of running an element-wise sensitivity analysis on Alexnet, using Distillers's perform_sensitivity_analysis utility function. \nAs reported by Song Han, and exhibited in the diagram, in Alexnet the feature detecting layers (convolution layers) are more sensitive to pruning, and their sensitivity drops, the deeper they are. The fully-connected layers are much less sensitive, which is great, because that's where most of the parameters are.", "title": "Sensitivity analysis" - }, + }, { - "location": "/pruning/index.html#references", - "text": "Song Han, Jeff Pool, John Tran, William J. Dally .\n Learning both Weights and Connections for Efficient Neural Networks ,\n arXiv:1607.04381v2,\n 2015. Hao Li, Asim Kadav, Igor Durdanovic, Hanan Samet, Hans Peter Graf .\n Pruning Filters for Efficient ConvNets ,\n arXiv:1608.08710v3,\n 2017.", + "location": "/pruning/index.html#references", + "text": "Song Han, Jeff Pool, John Tran, William J. Dally .\n Learning both Weights and Connections for Efficient Neural Networks ,\n arXiv:1607.04381v2,\n 2015. Hao Li, Asim Kadav, Igor Durdanovic, Hanan Samet, Hans Peter Graf .\n Pruning Filters for Efficient ConvNets ,\n arXiv:1608.08710v3,\n 2017.", "title": "References" - }, + }, { - "location": "/regularization/index.html", - "text": "Regularization\n\n\nIn their book \nDeep Learning\n Ian Goodfellow et al. define regularization as\n\n\n\n\n\"any modification we make to a learning algorithm that is intended to reduce its generalization error, but not its training error.\"\n\n\n\n\nPyTorch's \noptimizers\n use \\(l_2\\) parameter regularization to limit the capacity of models (i.e. reduce the variance).\n\n\nIn general, we can write this as:\n\\[\nloss(W;x;y) = loss_D(W;x;y) + \\lambda_R R(W)\n\\]\nAnd specifically,\n\\[\nloss(W;x;y) = loss_D(W;x;y) + \\lambda_R \\lVert W \\rVert_2^2\n\\]\nWhere W is the collection of all weight elements in the network (i.e. this is model.parameters()), \\(loss(W;x;y)\\) is the total training loss, and \\(loss_D(W)\\) is the data loss (i.e. the error of the objective function, also called the loss function, or \ncriterion\n in the Distiller sample image classifier compression application).\n\n\noptimizer = optim.SGD(model.parameters(), lr = 0.01, momentum=0.9, weight_decay=0.0001)\ncriterion = nn.CrossEntropyLoss()\n...\nfor input, target in dataset:\n optimizer.zero_grad()\n output = model(input)\n loss = criterion(output, target)\n loss.backward()\n optimizer.step()\n\n\n\n\n\\(\\lambda_R\\) is a scalar called the \nregularization strength\n, and it balances the data error and the regularization error. 
In PyTorch, this is the \nweight_decay\n argument.\n\n\n\\(\\lVert W \\rVert_2^2\\) is the square of the \\(l_2\\)-norm of W, and as such it is a \nmagnitude\n, or sizing, of the weights tensor.\n\\[\n\\lVert W \\rVert_2^2 = \\sum_{l=1}^{L} \\sum_{i=1}^{n} |w_{l,i}|^2 \\;\\;where \\;n = torch.numel(w_l)\n\\]\n\n\n\\(L\\) is the number of layers in the network; and the notation about used 1-based numbering to simplify the notation.\n\n\nThe qualitative differences between the \\(l_2\\)-norm, and the squared \\(l_2\\)-norm is explained in \nDeep Learning\n.\n\n\nSparsity and Regularization\n\n\nWe mention regularization because there is an interesting interaction between regularization and some DNN sparsity-inducing methods.\n\n\nIn \nDense-Sparse-Dense (DSD)\n, Song Han et al. use pruning as a regularizer to improve a model's accuracy:\n\n\n\n\n\"Sparsity is a powerful form of regularization. Our intuition is that, once the network arrives at a local minimum given the sparsity constraint, relaxing the constraint gives the network more freedom to escape the saddle point and arrive at a higher-accuracy local minimum.\"\n\n\n\n\nRegularization can also be used to induce sparsity. To induce element-wise sparsity we can use the \\(l_1\\)-norm, \\(\\lVert W \\rVert_1\\).\n\\[\n\\lVert W \\rVert_1 = l_1(W) = \\sum_{i=1}^{|W|} |w_i|\n\\]\n\n\n\\(l_2\\)-norm regularization reduces overfitting and improves a model's accuracy by shrinking large parameters, but it does not force these parameters to absolute zero. \\(l_1\\)-norm regularization sets some of the parameter elements to zero, therefore limiting the model's capacity while making the model simpler. This is sometimes referred to as \nfeature selection\n and gives us another interpretation of pruning.\n\n\nOne\n of Distiller's Jupyter notebooks explains how the \\(l_1\\)-norm regularizer induces sparsity, and how it interacts with \\(l_2\\)-norm regularization.\n\n\nIf we configure \nweight_decay\n to zero and use \\(l_1\\)-norm regularization, then we have:\n\\[\nloss(W;x;y) = loss_D(W;x;y) + \\lambda_R \\lVert W \\rVert_1\n\\]\nIf we use both regularizers, we have:\n\\[\nloss(W;x;y) = loss_D(W;x;y) + \\lambda_{R_2} \\lVert W \\rVert_2^2 + \\lambda_{R_1} \\lVert W \\rVert_1\n\\]\n\n\nClass \ndistiller.L1Regularizer\n implements \\(l_1\\)-norm regularization, and of course, you can also schedule regularization.\n\n\nl1_regularizer = distiller.s(model.parameters())\n...\nloss = criterion(output, target) + lambda * l1_regularizer()\n\n\n\n\nGroup Regularization\n\n\nIn Group Regularization, we penalize entire groups of parameter elements, instead of individual elements. Therefore, entire groups are either sparsified (i.e. all of the group elements have a value of zero) or not. The group structures have to be pre-defined.\n\n\nTo the data loss, and the element-wise regularization (if any), we can add group-wise regularization penalty. We represent all of the parameter groups in layer \\(l\\) as \\( W_l^{(G)} \\), and we add the penalty of all groups for all layers. 
It gets a bit messy, but not overly complicated:\n\\[\nloss(W;x;y) = loss_D(W;x;y) + \\lambda_R R(W) + \\lambda_g \\sum_{l=1}^{L} R_g(W_l^{(G)})\n\\]\n\n\nLet's denote all of the weight elements in group \\(g\\) as \\(w^{(g)}\\).\n\n\n\\[\nR_g(w^{(g)}) = \\sum_{g=1}^{G} \\lVert w^{(g)} \\rVert_g = \\sum_{g=1}^{G} \\sum_{i=1}^{|w^{(g)}|} {(w_i^{(g)})}^2\n\\]\nwhere \\(w^{(g)} \\in w^{(l)} \\) and \\( |w^{(g)}| \\) is the number of elements in \\( w^{(g)} \\).\n\n\n\\( \\lambda_g \\sum_{l=1}^{L} R_g(W_l^{(G)}) \\) is called the Group Lasso regularizer. Much as in \\(l_1\\)-norm regularization we sum the magnitudes of all tensor elements, in Group Lasso we sum the magnitudes of element structures (i.e. groups).\n\n\n\nGroup Regularization is also called Block Regularization, Structured Regularization, or coarse-grained sparsity (remember that element-wise sparsity is sometimes referred to as fine-grained sparsity). Group sparsity exhibits regularity (i.e. its shape is regular), and therefore\nit can be beneficial to improve inference speed.\n\n\nHuizi-et-al-2017\n provides an overview of some of the different groups: kernel, channel, filter, layers. Fiber structures such as matrix columns and rows, as well as various shaped structures (block sparsity), and even \nintra kernel strided sparsity\n can also be used.\n\n\ndistiller.GroupLassoRegularizer\n currently implements most of these groups, and you can easily add new groups.\n\n\nReferences\n\n\n \nIan Goodfellow and Yoshua Bengio and Aaron Courville\n.\n \nDeep Learning\n,\n arXiv:1607.04381v2,\n 2017.\n\n\n\n\n\nSong Han, Jeff Pool, Sharan Narang, Huizi Mao, Enhao Gong, Shijian Tang, Erich Elsen, Peter Vajda, Manohar Paluri, John Tran, Bryan Catanzaro, William J. Dally\n.\n \nDSD: Dense-Sparse-Dense Training for Deep Neural Networks\n,\n arXiv:1607.04381v2,\n 2017.\n\n\n\n\n\nHuizi Mao, Song Han, Jeff Pool, Wenshuo Li, Xingyu Liu, Yu Wang, William J. Dally\n.\n \nExploring the Regularity of Sparse Structure in Convolutional Neural Networks\n,\n arXiv:1705.08922v3,\n 2017.\n\n\n\n\n\nSajid Anwar, Kyuyeon Hwang, and Wonyong Sung\n.\n \nStructured pruning of deep convolutional neural networks\n,\n arXiv:1512.08571,\n 2015", + "location": "/regularization/index.html", + "text": "Regularization\n\n\nIn their book \nDeep Learning\n Ian Goodfellow et al. define regularization as\n\n\n\n\n\"any modification we make to a learning algorithm that is intended to reduce its generalization error, but not its training error.\"\n\n\n\n\nPyTorch's \noptimizers\n use \\(l_2\\) parameter regularization to limit the capacity of models (i.e. reduce the variance).\n\n\nIn general, we can write this as:\n\\[\nloss(W;x;y) = loss_D(W;x;y) + \\lambda_R R(W)\n\\]\nAnd specifically,\n\\[\nloss(W;x;y) = loss_D(W;x;y) + \\lambda_R \\lVert W \\rVert_2^2\n\\]\nWhere W is the collection of all weight elements in the network (i.e. this is model.parameters()), \\(loss(W;x;y)\\) is the total training loss, and \\(loss_D(W)\\) is the data loss (i.e. 
the error of the objective function, also called the loss function, or \ncriterion\n in the Distiller sample image classifier compression application).\n\n\noptimizer = optim.SGD(model.parameters(), lr = 0.01, momentum=0.9, weight_decay=0.0001)\ncriterion = nn.CrossEntropyLoss()\n...\nfor input, target in dataset:\n optimizer.zero_grad()\n output = model(input)\n loss = criterion(output, target)\n loss.backward()\n optimizer.step()\n\n\n\n\n\\(\\lambda_R\\) is a scalar called the \nregularization strength\n, and it balances the data error and the regularization error. In PyTorch, this is the \nweight_decay\n argument.\n\n\n\\(\\lVert W \\rVert_2^2\\) is the square of the \\(l_2\\)-norm of W, and as such it is a \nmagnitude\n, or sizing, of the weights tensor.\n\\[\n\\lVert W \\rVert_2^2 = \\sum_{l=1}^{L} \\sum_{i=1}^{n} |w_{l,i}|^2 \\;\\;where \\;n = torch.numel(w_l)\n\\]\n\n\n\\(L\\) is the number of layers in the network; and the notation about used 1-based numbering to simplify the notation.\n\n\nThe qualitative differences between the \\(l_2\\)-norm, and the squared \\(l_2\\)-norm is explained in \nDeep Learning\n.\n\n\nSparsity and Regularization\n\n\nWe mention regularization because there is an interesting interaction between regularization and some DNN sparsity-inducing methods.\n\n\nIn \nDense-Sparse-Dense (DSD)\n, Song Han et al. use pruning as a regularizer to improve a model's accuracy:\n\n\n\n\n\"Sparsity is a powerful form of regularization. Our intuition is that, once the network arrives at a local minimum given the sparsity constraint, relaxing the constraint gives the network more freedom to escape the saddle point and arrive at a higher-accuracy local minimum.\"\n\n\n\n\nRegularization can also be used to induce sparsity. To induce element-wise sparsity we can use the \\(l_1\\)-norm, \\(\\lVert W \\rVert_1\\).\n\\[\n\\lVert W \\rVert_1 = l_1(W) = \\sum_{i=1}^{|W|} |w_i|\n\\]\n\n\n\\(l_2\\)-norm regularization reduces overfitting and improves a model's accuracy by shrinking large parameters, but it does not force these parameters to absolute zero. \\(l_1\\)-norm regularization sets some of the parameter elements to zero, therefore limiting the model's capacity while making the model simpler. This is sometimes referred to as \nfeature selection\n and gives us another interpretation of pruning.\n\n\nOne\n of Distiller's Jupyter notebooks explains how the \\(l_1\\)-norm regularizer induces sparsity, and how it interacts with \\(l_2\\)-norm regularization.\n\n\nIf we configure \nweight_decay\n to zero and use \\(l_1\\)-norm regularization, then we have:\n\\[\nloss(W;x;y) = loss_D(W;x;y) + \\lambda_R \\lVert W \\rVert_1\n\\]\nIf we use both regularizers, we have:\n\\[\nloss(W;x;y) = loss_D(W;x;y) + \\lambda_{R_2} \\lVert W \\rVert_2^2 + \\lambda_{R_1} \\lVert W \\rVert_1\n\\]\n\n\nClass \ndistiller.L1Regularizer\n implements \\(l_1\\)-norm regularization, and of course, you can also schedule regularization.\n\n\nl1_regularizer = distiller.s(model.parameters())\n...\nloss = criterion(output, target) + lambda * l1_regularizer()\n\n\n\n\nGroup Regularization\n\n\nIn Group Regularization, we penalize entire groups of parameter elements, instead of individual elements. Therefore, entire groups are either sparsified (i.e. all of the group elements have a value of zero) or not. The group structures have to be pre-defined.\n\n\nTo the data loss, and the element-wise regularization (if any), we can add group-wise regularization penalty. 
We represent all of the parameter groups in layer \\(l\\) as \\( W_l^{(G)} \\), and we add the penalty of all groups for all layers. It gets a bit messy, but not overly complicated:\n\\[\nloss(W;x;y) = loss_D(W;x;y) + \\lambda_R R(W) + \\lambda_g \\sum_{l=1}^{L} R_g(W_l^{(G)})\n\\]\n\n\nLet's denote all of the weight elements in group \\(g\\) as \\(w^{(g)}\\).\n\n\n\\[\nR_g(w^{(g)}) = \\sum_{g=1}^{G} \\lVert w^{(g)} \\rVert_g = \\sum_{g=1}^{G} \\sum_{i=1}^{|w^{(g)}|} {(w_i^{(g)})}^2\n\\]\nwhere \\(w^{(g)} \\in w^{(l)} \\) and \\( |w^{(g)}| \\) is the number of elements in \\( w^{(g)} \\).\n\n\n\\( \\lambda_g \\sum_{l=1}^{L} R_g(W_l^{(G)}) \\) is called the Group Lasso regularizer. Much as in \\(l_1\\)-norm regularization we sum the magnitudes of all tensor elements, in Group Lasso we sum the magnitudes of element structures (i.e. groups).\n\n\n\nGroup Regularization is also called Block Regularization, Structured Regularization, or coarse-grained sparsity (remember that element-wise sparsity is sometimes referred to as fine-grained sparsity). Group sparsity exhibits regularity (i.e. its shape is regular), and therefore\nit can be beneficial to improve inference speed.\n\n\nHuizi-et-al-2017\n provides an overview of some of the different groups: kernel, channel, filter, layers. Fiber structures such as matrix columns and rows, as well as various shaped structures (block sparsity), and even \nintra kernel strided sparsity\n can also be used.\n\n\ndistiller.GroupLassoRegularizer\n currently implements most of these groups, and you can easily add new groups.\n\n\nReferences\n\n\n \nIan Goodfellow and Yoshua Bengio and Aaron Courville\n.\n \nDeep Learning\n,\n arXiv:1607.04381v2,\n 2017.\n\n\n\n\n\nSong Han, Jeff Pool, Sharan Narang, Huizi Mao, Enhao Gong, Shijian Tang, Erich Elsen, Peter Vajda, Manohar Paluri, John Tran, Bryan Catanzaro, William J. Dally\n.\n \nDSD: Dense-Sparse-Dense Training for Deep Neural Networks\n,\n arXiv:1607.04381v2,\n 2017.\n\n\n\n\n\nHuizi Mao, Song Han, Jeff Pool, Wenshuo Li, Xingyu Liu, Yu Wang, William J. Dally\n.\n \nExploring the Regularity of Sparse Structure in Convolutional Neural Networks\n,\n arXiv:1705.08922v3,\n 2017.\n\n\n\n\n\nSajid Anwar, Kyuyeon Hwang, and Wonyong Sung\n.\n \nStructured pruning of deep convolutional neural networks\n,\n arXiv:1512.08571,\n 2015", "title": "Regularization" - }, + }, { - "location": "/regularization/index.html#regularization", - "text": "In their book Deep Learning Ian Goodfellow et al. define regularization as \"any modification we make to a learning algorithm that is intended to reduce its generalization error, but not its training error.\" PyTorch's optimizers use \\(l_2\\) parameter regularization to limit the capacity of models (i.e. reduce the variance). In general, we can write this as:\n\\[\nloss(W;x;y) = loss_D(W;x;y) + \\lambda_R R(W)\n\\]\nAnd specifically,\n\\[\nloss(W;x;y) = loss_D(W;x;y) + \\lambda_R \\lVert W \\rVert_2^2\n\\]\nWhere W is the collection of all weight elements in the network (i.e. this is model.parameters()), \\(loss(W;x;y)\\) is the total training loss, and \\(loss_D(W)\\) is the data loss (i.e. the error of the objective function, also called the loss function, or criterion in the Distiller sample image classifier compression application). 
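The `l1_regularizer = distiller.s(model.parameters())` snippet embedded in the entries above looks garbled (and `lambda` is a reserved word in Python), so here is a hedged plain-PyTorch sketch of adding an \(l_1\) penalty to the data loss. The model, data and `l1_strength` value are illustrative; `distiller.L1Regularizer`, named in the text, implements the same penalty.

```
import torch
import torch.nn as nn
import torch.optim as optim

# Toy setup so the snippet is self-contained; model, data and strength are illustrative.
model = nn.Linear(20, 5)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0)  # l2 term off
l1_strength = 1e-4   # plays the role of \lambda_{R_1} above ("lambda" itself is reserved)

inputs, targets = torch.randn(8, 20), torch.randint(0, 5, (8,))

optimizer.zero_grad()
output = model(inputs)
l1_penalty = sum(p.abs().sum() for p in model.parameters())    # \lVert W \rVert_1
loss = criterion(output, targets) + l1_strength * l1_penalty   # data loss + l1 term
loss.backward()
optimizer.step()
```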
optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum=0.9, weight_decay=0.0001)\ncriterion = nn.CrossEntropyLoss()\n...\nfor input, target in dataset:\n optimizer.zero_grad()\n output = model(input)\n loss = criterion(output, target)\n loss.backward()\n optimizer.step() \\(\\lambda_R\\) is a scalar called the regularization strength , and it balances the data error and the regularization error. In PyTorch, this is the weight_decay argument. \\(\\lVert W \\rVert_2^2\\) is the square of the \\(l_2\\)-norm of W, and as such it is a magnitude , or sizing, of the weights tensor.\n\\[\n\\lVert W \\rVert_2^2 = \\sum_{l=1}^{L} \\sum_{i=1}^{n} |w_{l,i}|^2 \\;\\;where \\;n = torch.numel(w_l)\n\\] \\(L\\) is the number of layers in the network; and the notation about used 1-based numbering to simplify the notation. The qualitative differences between the \\(l_2\\)-norm, and the squared \\(l_2\\)-norm is explained in Deep Learning .", + "location": "/regularization/index.html#regularization", + "text": "In their book Deep Learning Ian Goodfellow et al. define regularization as \"any modification we make to a learning algorithm that is intended to reduce its generalization error, but not its training error.\" PyTorch's optimizers use \\(l_2\\) parameter regularization to limit the capacity of models (i.e. reduce the variance). In general, we can write this as:\n\\[\nloss(W;x;y) = loss_D(W;x;y) + \\lambda_R R(W)\n\\]\nAnd specifically,\n\\[\nloss(W;x;y) = loss_D(W;x;y) + \\lambda_R \\lVert W \\rVert_2^2\n\\]\nWhere W is the collection of all weight elements in the network (i.e. this is model.parameters()), \\(loss(W;x;y)\\) is the total training loss, and \\(loss_D(W)\\) is the data loss (i.e. the error of the objective function, also called the loss function, or criterion in the Distiller sample image classifier compression application). optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum=0.9, weight_decay=0.0001)\ncriterion = nn.CrossEntropyLoss()\n...\nfor input, target in dataset:\n optimizer.zero_grad()\n output = model(input)\n loss = criterion(output, target)\n loss.backward()\n optimizer.step() \\(\\lambda_R\\) is a scalar called the regularization strength , and it balances the data error and the regularization error. In PyTorch, this is the weight_decay argument. \\(\\lVert W \\rVert_2^2\\) is the square of the \\(l_2\\)-norm of W, and as such it is a magnitude , or sizing, of the weights tensor.\n\\[\n\\lVert W \\rVert_2^2 = \\sum_{l=1}^{L} \\sum_{i=1}^{n} |w_{l,i}|^2 \\;\\;where \\;n = torch.numel(w_l)\n\\] \\(L\\) is the number of layers in the network; and the notation about used 1-based numbering to simplify the notation. The qualitative differences between the \\(l_2\\)-norm, and the squared \\(l_2\\)-norm is explained in Deep Learning .", "title": "Regularization" - }, + }, { - "location": "/regularization/index.html#sparsity-and-regularization", - "text": "We mention regularization because there is an interesting interaction between regularization and some DNN sparsity-inducing methods. In Dense-Sparse-Dense (DSD) , Song Han et al. use pruning as a regularizer to improve a model's accuracy: \"Sparsity is a powerful form of regularization. Our intuition is that, once the network arrives at a local minimum given the sparsity constraint, relaxing the constraint gives the network more freedom to escape the saddle point and arrive at a higher-accuracy local minimum.\" Regularization can also be used to induce sparsity. 
To induce element-wise sparsity we can use the \\(l_1\\)-norm, \\(\\lVert W \\rVert_1\\).\n\\[\n\\lVert W \\rVert_1 = l_1(W) = \\sum_{i=1}^{|W|} |w_i|\n\\] \\(l_2\\)-norm regularization reduces overfitting and improves a model's accuracy by shrinking large parameters, but it does not force these parameters to absolute zero. \\(l_1\\)-norm regularization sets some of the parameter elements to zero, therefore limiting the model's capacity while making the model simpler. This is sometimes referred to as feature selection and gives us another interpretation of pruning. One of Distiller's Jupyter notebooks explains how the \\(l_1\\)-norm regularizer induces sparsity, and how it interacts with \\(l_2\\)-norm regularization. If we configure weight_decay to zero and use \\(l_1\\)-norm regularization, then we have:\n\\[\nloss(W;x;y) = loss_D(W;x;y) + \\lambda_R \\lVert W \\rVert_1\n\\]\nIf we use both regularizers, we have:\n\\[\nloss(W;x;y) = loss_D(W;x;y) + \\lambda_{R_2} \\lVert W \\rVert_2^2 + \\lambda_{R_1} \\lVert W \\rVert_1\n\\] Class distiller.L1Regularizer implements \\(l_1\\)-norm regularization, and of course, you can also schedule regularization. l1_regularizer = distiller.s(model.parameters())\n...\nloss = criterion(output, target) + lambda * l1_regularizer()", + "location": "/regularization/index.html#sparsity-and-regularization", + "text": "We mention regularization because there is an interesting interaction between regularization and some DNN sparsity-inducing methods. In Dense-Sparse-Dense (DSD) , Song Han et al. use pruning as a regularizer to improve a model's accuracy: \"Sparsity is a powerful form of regularization. Our intuition is that, once the network arrives at a local minimum given the sparsity constraint, relaxing the constraint gives the network more freedom to escape the saddle point and arrive at a higher-accuracy local minimum.\" Regularization can also be used to induce sparsity. To induce element-wise sparsity we can use the \\(l_1\\)-norm, \\(\\lVert W \\rVert_1\\).\n\\[\n\\lVert W \\rVert_1 = l_1(W) = \\sum_{i=1}^{|W|} |w_i|\n\\] \\(l_2\\)-norm regularization reduces overfitting and improves a model's accuracy by shrinking large parameters, but it does not force these parameters to absolute zero. \\(l_1\\)-norm regularization sets some of the parameter elements to zero, therefore limiting the model's capacity while making the model simpler. This is sometimes referred to as feature selection and gives us another interpretation of pruning. One of Distiller's Jupyter notebooks explains how the \\(l_1\\)-norm regularizer induces sparsity, and how it interacts with \\(l_2\\)-norm regularization. If we configure weight_decay to zero and use \\(l_1\\)-norm regularization, then we have:\n\\[\nloss(W;x;y) = loss_D(W;x;y) + \\lambda_R \\lVert W \\rVert_1\n\\]\nIf we use both regularizers, we have:\n\\[\nloss(W;x;y) = loss_D(W;x;y) + \\lambda_{R_2} \\lVert W \\rVert_2^2 + \\lambda_{R_1} \\lVert W \\rVert_1\n\\] Class distiller.L1Regularizer implements \\(l_1\\)-norm regularization, and of course, you can also schedule regularization. l1_regularizer = distiller.s(model.parameters())\n...\nloss = criterion(output, target) + lambda * l1_regularizer()", "title": "Sparsity and Regularization" - }, + }, { - "location": "/regularization/index.html#group-regularization", - "text": "In Group Regularization, we penalize entire groups of parameter elements, instead of individual elements. Therefore, entire groups are either sparsified (i.e. 
all of the group elements have a value of zero) or not. The group structures have to be pre-defined. To the data loss, and the element-wise regularization (if any), we can add group-wise regularization penalty. We represent all of the parameter groups in layer \\(l\\) as \\( W_l^{(G)} \\), and we add the penalty of all groups for all layers. It gets a bit messy, but not overly complicated:\n\\[\nloss(W;x;y) = loss_D(W;x;y) + \\lambda_R R(W) + \\lambda_g \\sum_{l=1}^{L} R_g(W_l^{(G)})\n\\] Let's denote all of the weight elements in group \\(g\\) as \\(w^{(g)}\\). \\[\nR_g(w^{(g)}) = \\sum_{g=1}^{G} \\lVert w^{(g)} \\rVert_g = \\sum_{g=1}^{G} \\sum_{i=1}^{|w^{(g)}|} {(w_i^{(g)})}^2\n\\]\nwhere \\(w^{(g)} \\in w^{(l)} \\) and \\( |w^{(g)}| \\) is the number of elements in \\( w^{(g)} \\). \\( \\lambda_g \\sum_{l=1}^{L} R_g(W_l^{(G)}) \\) is called the Group Lasso regularizer. Much as in \\(l_1\\)-norm regularization we sum the magnitudes of all tensor elements, in Group Lasso we sum the magnitudes of element structures (i.e. groups). \nGroup Regularization is also called Block Regularization, Structured Regularization, or coarse-grained sparsity (remember that element-wise sparsity is sometimes referred to as fine-grained sparsity). Group sparsity exhibits regularity (i.e. its shape is regular), and therefore\nit can be beneficial to improve inference speed. Huizi-et-al-2017 provides an overview of some of the different groups: kernel, channel, filter, layers. Fiber structures such as matrix columns and rows, as well as various shaped structures (block sparsity), and even intra kernel strided sparsity can also be used. distiller.GroupLassoRegularizer currently implements most of these groups, and you can easily add new groups.", + "location": "/regularization/index.html#group-regularization", + "text": "In Group Regularization, we penalize entire groups of parameter elements, instead of individual elements. Therefore, entire groups are either sparsified (i.e. all of the group elements have a value of zero) or not. The group structures have to be pre-defined. To the data loss, and the element-wise regularization (if any), we can add group-wise regularization penalty. We represent all of the parameter groups in layer \\(l\\) as \\( W_l^{(G)} \\), and we add the penalty of all groups for all layers. It gets a bit messy, but not overly complicated:\n\\[\nloss(W;x;y) = loss_D(W;x;y) + \\lambda_R R(W) + \\lambda_g \\sum_{l=1}^{L} R_g(W_l^{(G)})\n\\] Let's denote all of the weight elements in group \\(g\\) as \\(w^{(g)}\\). \\[\nR_g(w^{(g)}) = \\sum_{g=1}^{G} \\lVert w^{(g)} \\rVert_g = \\sum_{g=1}^{G} \\sum_{i=1}^{|w^{(g)}|} {(w_i^{(g)})}^2\n\\]\nwhere \\(w^{(g)} \\in w^{(l)} \\) and \\( |w^{(g)}| \\) is the number of elements in \\( w^{(g)} \\). \\( \\lambda_g \\sum_{l=1}^{L} R_g(W_l^{(G)}) \\) is called the Group Lasso regularizer. Much as in \\(l_1\\)-norm regularization we sum the magnitudes of all tensor elements, in Group Lasso we sum the magnitudes of element structures (i.e. groups). \nGroup Regularization is also called Block Regularization, Structured Regularization, or coarse-grained sparsity (remember that element-wise sparsity is sometimes referred to as fine-grained sparsity). Group sparsity exhibits regularity (i.e. its shape is regular), and therefore\nit can be beneficial to improve inference speed. Huizi-et-al-2017 provides an overview of some of the different groups: kernel, channel, filter, layers. 
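For the filter group shape, a minimal sketch of the Group Lasso term above, assuming the conventional form in which the (unsquared) \(l_2\) norm of each group is summed, which is what drives whole groups to zero together. The grouping by output filter and the \(\lambda_g\) value in the usage comment are illustrative; `distiller.GroupLassoRegularizer` covers this and the other group shapes.

```
import torch
import torch.nn as nn

def filter_group_lasso(model: nn.Module) -> torch.Tensor:
    """Sum, over every Conv2d layer, the l2 norm of each output filter.

    A filter-wise sketch of the Group Lasso penalty; other group shapes
    (channels, kernels, rows/columns) only change the reshape.
    """
    penalty = torch.zeros(())
    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            w = module.weight                       # shape: (out_ch, in_ch, k, k)
            groups = w.view(w.size(0), -1)          # one group per output filter
            penalty = penalty + groups.norm(p=2, dim=1).sum()
    return penalty

# Typical use (lambda_g is illustrative): loss = criterion(output, target) + lambda_g * filter_group_lasso(model)
model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), nn.Conv2d(16, 8, 3))
print(filter_group_lasso(model))
```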
Fiber structures such as matrix columns and rows, as well as various shaped structures (block sparsity), and even intra kernel strided sparsity can also be used. distiller.GroupLassoRegularizer currently implements most of these groups, and you can easily add new groups.", "title": "Group Regularization" - }, + }, { - "location": "/regularization/index.html#references", - "text": "Ian Goodfellow and Yoshua Bengio and Aaron Courville .\n Deep Learning ,\n arXiv:1607.04381v2,\n 2017. Song Han, Jeff Pool, Sharan Narang, Huizi Mao, Enhao Gong, Shijian Tang, Erich Elsen, Peter Vajda, Manohar Paluri, John Tran, Bryan Catanzaro, William J. Dally .\n DSD: Dense-Sparse-Dense Training for Deep Neural Networks ,\n arXiv:1607.04381v2,\n 2017. Huizi Mao, Song Han, Jeff Pool, Wenshuo Li, Xingyu Liu, Yu Wang, William J. Dally .\n Exploring the Regularity of Sparse Structure in Convolutional Neural Networks ,\n arXiv:1705.08922v3,\n 2017. Sajid Anwar, Kyuyeon Hwang, and Wonyong Sung .\n Structured pruning of deep convolutional neural networks ,\n arXiv:1512.08571,\n 2015", + "location": "/regularization/index.html#references", + "text": "Ian Goodfellow and Yoshua Bengio and Aaron Courville .\n Deep Learning ,\n arXiv:1607.04381v2,\n 2017. Song Han, Jeff Pool, Sharan Narang, Huizi Mao, Enhao Gong, Shijian Tang, Erich Elsen, Peter Vajda, Manohar Paluri, John Tran, Bryan Catanzaro, William J. Dally .\n DSD: Dense-Sparse-Dense Training for Deep Neural Networks ,\n arXiv:1607.04381v2,\n 2017. Huizi Mao, Song Han, Jeff Pool, Wenshuo Li, Xingyu Liu, Yu Wang, William J. Dally .\n Exploring the Regularity of Sparse Structure in Convolutional Neural Networks ,\n arXiv:1705.08922v3,\n 2017. Sajid Anwar, Kyuyeon Hwang, and Wonyong Sung .\n Structured pruning of deep convolutional neural networks ,\n arXiv:1512.08571,\n 2015", "title": "References" - }, + }, { - "location": "/quantization/index.html", - "text": "Quantization\n\n\nQuantization refers to the process of reducing the number of bits that represent a number. In the context of deep learning, the predominant numerical format used for research and for deployment has so far been 32-bit floating point, or FP32. However, the desire for reduced bandwidth and compute requirements of deep learning models has driven research into using lower-precision numerical formats. It has been extensively demonstrated that weights and activations can be represented using 8-bit integers (or INT8) without incurring significant loss in accuracy. The use of even lower bit-widths, such as 4/2/1-bits, is an active field of research that has also shown great progress.\n\n\nNote that this discussion is on quantization only in the context of more efficient inference. Using lower-precision numerics for more efficient training is currently out of scope.\n\n\nMotivation: Overall Efficiency\n\n\nThe more obvious benefit from quantization is \nsignificantly reduced bandwidth and storage\n. For instance, using INT8 for weights and activations consumes 4x less overall bandwidth compared to FP32.\n\nAdditionally integer compute is \nfaster\n than floating point compute. It is also much more \narea and energy efficient\n: \n\n\n\n\n\n\n\n\nINT8 Operation\n\n\nEnergy Saving vs FP32\n\n\nArea Saving vs FP32\n\n\n\n\n\n\n\n\n\n\nAdd\n\n\n30x\n\n\n116x\n\n\n\n\n\n\nMultiply\n\n\n18.5x\n\n\n27x\n\n\n\n\n\n\n\n\n(\nDally, 2015\n)\n\n\nNote that very aggressive quantization can yield even more efficiency. 
If weights are binary (-1, 1) or ternary (-1, 0, 1 using 2-bits), then convolution and fully-connected layers can be computed with additions and subtractions only, removing multiplications completely. If activations are binary as well, then additions can also be removed, in favor of bitwise operations (\nRastegari et al., 2016\n).\n\n\nInteger vs. FP32\n\n\nThere are two main attributes when discussing a numerical format. The first is \ndynamic range\n, which refers to the range of representable numbers. The second one is how many values can be represented within the dynamic range, which in turn determines the \nprecision / resolution\n of the format (the distance between two numbers).\n\nFor all integer formats, the dynamic range is \n[-2^{n-1} .. 2^{n-1}-1]\n, where \nn\n is the number of bits. So for INT8 the range is \n[-128 .. 127]\n, and for INT4 it is \n[-16 .. 15]\n (we're limiting ourselves to signed integers for now). The number of representable values is \n2^n\n.\nContrast that with FP32, where the dynamic range is \n\\pm 3.4\\ x\\ 10^{38}\n, and approximately \n4.2\\ x\\ 10^9\n values can be represented.\n\nWe can immediately see that FP32 is much more \nversatile\n, in that it is able to represent a wide range of distributions accurately. This is a nice property for deep learning models, where the distributions of weights and activations are usually very different (at least in dynamic range). In addition the dynamic range can differ between layers in the model.\n\nIn order to be able to represent these different distributions with an integer format, a \nscale factor\n is used to map the dynamic range of the tensor to the integer format range. But still we remain with the issue of having a significantly lower number of representable values, that is - much lower resolution.\n\nNote that this scale factor is, in most cases, a floating-point number. Hence, even when using integer numerics, some floating-point computations remain. \nCourbariaux et al., 2014\n scale using only shifts, eliminating the floating point operation. In \nGEMMLWOP\n, the FP32 scale factor is approximated using an integer or fixed-point multiplication followed by a shift operation. In many cases the effect of this approximation on accuracy is negligible.\n\n\nAvoiding Overflows\n\n\nConvolution and fully connected layers involve the storing of intermediate results in accumulators. Due to the limited dynamic range of integer formats, if we would use the same bit-width for the weights and activation, \nand\n for the accumulators, we would likely overflow very quickly. Therefore, accumulators are usually implemented with higher bit-widths.\n\nThe result of multiplying two \nn\n-bit integers is, at most, a \n2n\n-bit number. In convolution layers, such multiplications are accumulated \nc\\cdot k^2\n times, where \nc\n is the number of input channels and \nk\n is the kernel width (assuming a square kernel). Hence, to avoid overflowing, the accumulator should be \n2n + M\n-bits wide, where M is at least \nlog_2(c\\cdot k^2)\n. In many cases 32-bit accumulators are used, however for INT4 and lower it might be possible to use less than 32 -bits, depending on the expected use cases and layer widths.\n\n\n\"Conservative\" Quantization: INT8\n\n\nIn many cases, taking a model trained for FP32 and directly quantizing it to INT8, without any re-training, can result in a relatively low loss of accuracy (which may or may not be acceptable, depending on the use case). 
Some fine-tuning can further improve the accuracy (\nGysel at al., 2018\n).\n\nAs mentioned above, a scale factor is used to adapt the dynamic range of the tensor at hand to that of the integer format. This scale factor needs to be calculated per-layer per-tensor. The simplest way is to map the min/max values of the float tensor to the min/max of the integer format. For weights and biases this is easy, as they are set once training is complete. For activations, the min/max float values can be obtained \"online\" during inference, or \"offline\".\n\n\n\n\nOffline\n means gathering activations statistics before deploying the model, either during training or by running a few \"calibration\" batches on the trained FP32 model. Based on these gathered statistics, the scaled factors are calculated and are fixed once the model is deployed. This method has the risk of encountering values outside the previously observed ranges at runtime. These values will be clipped, which might lead to accuracy degradation.\n\n\nOnline\n means calculating the min/max values for each tensor dynamically during runtime. In this method clipping cannot occur, however the added computation resources required to calculate the min/max values at runtime might be prohibitive.\n\n\n\n\nIt is important to note, however, that the full float range of an activations tensor usually includes elements which are statistically outliers. These values can be discarded by using a narrower min/max range, effectively allowing some clipping to occur in favor of increasing the resolution provided to the part of the distribution containing most of the information. Statistical measures can be used to intelligently select where to clip the original range in order to preserve as much information as possible (\nMigacz, 2017\n). \n\n\nAnother possible optimization point is \nscale-factor scope\n. The most common way is use a single scale-factor per-layer, but it is also possible to calculate a scale-factor per-channel. This can be beneficial if the weight distributions vary greatly between channels.\n\n\n\"Aggressive\" Quantization: INT4 and Lower\n\n\nNaively quantizing a FP32 model to INT4 and lower usually incurs significant accuracy degradation. Many works have tried to mitigate this effect. They usually employ one or more of the following concepts in order to improve model accuracy:\n\n\n\n\nTraining / Re-Training\n: For INT4 and lower, training is required in order to obtain reasonable accuracy. The training loop is modified to take quantization into account. See details in the \nnext section\n.\n\n\nZhou S et al., 2016\n have shown that bootstrapping the quantized model with trained FP32 weights leads to higher accuracy, as opposed to training from scratch. Other methods \nrequire\n a trained FP32 model, either as a starting point (\nZhou A et al., 2017\n), or as a teacher network in a knowledge distillation training setup (see \nhere\n).\n\n\nReplacing the activation function\n: The most common activation function in vision models is ReLU, which is unbounded. That is - its dynamic range is not limited for positive inputs. This is very problematic for INT4 and below due to the very limited range and resolution. Therefore, most methods replace ReLU with another function which is bounded. In some cases a clipping function with hard coded values is used (\nZhou S et al., 2016\n, \nMishra et al., 2018\n). Another method learns the clipping value per layer, with better results (\nChoi et al., 2018\n). 
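A minimal sketch of the range-based mapping described above: derive a scale factor (and zero-point) from a tensor's min/max, quantize onto a signed integer grid, and dequantize to observe the rounding error. The asymmetric scheme, the crude `clip_ratio` stand-in for smarter clipping, and the toy tensor are illustrative; this is not Distiller's quantizer.

```
import torch

def linear_quantize(x: torch.Tensor, num_bits: int = 8, clip_ratio: float = 1.0):
    """Asymmetric range-based linear quantization of a float tensor.

    The scale factor maps the tensor's [min, max] range (optionally narrowed
    by `clip_ratio`, a crude stand-in for the clipping methods discussed
    above) onto the signed integer range.  Returns (q, scale, zero_point)
    so that x ~= (q - zero_point) * scale.
    """
    qmin, qmax = -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1
    x_min, x_max = x.min() * clip_ratio, x.max() * clip_ratio
    scale = (x_max - x_min) / (qmax - qmin)
    zero_point = torch.round(qmin - x_min / scale)
    q = torch.clamp(torch.round(x / scale + zero_point), qmin, qmax)  # out-of-range values are clipped
    return q, scale, zero_point

# Quantize a toy weight tensor and measure the round-trip error
w = torch.randn(64, 64)
q, scale, zp = linear_quantize(w, num_bits=8)
w_hat = (q - zp) * scale
print("max abs error:", (w - w_hat).abs().max().item())
```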
Once the clipping value is set, the scale factor used for quantization is also set, and no further calibration steps are required (as opposed to INT8 methods described above).\n\n\nModifying network structure\n: \nMishra et al., 2018\n try to compensate for the loss of information due to quantization by using wider layers (more channels). \nLin et al., 2017\n proposed a binary quantization method in which a single FP32 convolution is replaced with multiple binary convolutions, each scaled to represent a different \"base\", covering a larger dynamic range overall.\n\n\nFirst and last layer\n: Many methods do not quantize the first and last layer of the model. It has been observed by \nHan et al., 2015\n that the first convolutional layer is more sensitive to weights pruning, and some quantization works cite the same reason and show it empirically (\nZhou S et al., 2016\n, \nChoi et al., 2018\n). Some works also note that these layers usually constitute a very small portion of the overall computation within the model, further reducing the motivation to quantize them (\nRastegari et al., 2016\n). Most methods keep the first and last layers at FP32. However, \nChoi et al., 2018\n showed that \"conservative\" quantization of these layers, e.g. to INT8, does not reduce accuracy.\n\n\nIterative quantization\n: Most methods quantize the entire model at once. \nZhou A et al., 2017\n employ an iterative method, which starts with a trained FP32 baseline, and quantizes only a portion of the model at the time followed by several epochs of re-training to recover the accuracy loss from quantization.\n\n\nMixed Weights and Activations Precision\n: It has been observed that activations are more sensitive to quantization than weights (\nZhou S et al., 2016\n). Hence it is not uncommon to see experiments with activations quantized to a higher precision compared to weights. Some works have focused solely on quantizing weights, keeping the activations at FP32 (\nLi et al., 2016\n, \nZhu et al., 2016\n).\n\n\n\n\nTraining with Quantization\n\n\nAs mentioned above, in order to minimize the loss of accuracy from \"aggressive\" quantization, many methods that target INT4 and lower involve training the model in a way that considers the quantization. This means training with quantization of weights and activations \"baked\" into the training procedure. The training graph usually looks like this:\n\n\n\n\nA full precision copy of the weights is maintained throughout the training process (\"weights_fp\" in the diagram). Its purpose is to accumulate the small changes from the gradients without loss of precision (Note that the quantization of the weights is an integral part of the training graph, meaning that we back-propagate through it as well). Once the model is trained, only the quantized weights are used for inference.\n\nIn the diagram we show \"layer N\" as the conv + batch-norm + activation combination, but the same applies to fully-connected layers, element-wise operations, etc. During training, the operations within \"layer N\" can still run in full precision, with the \"quantize\" operations in the boundaries ensuring discrete-valued weights and activations. This is sometimes called \"simulated quantization\". \n\n\nStraight-Through Estimator\n\n\nAn important question in this context is how to back-propagate through the quantization functions. These functions are discrete-valued, hence their derivative is 0 almost everywhere. So, using their gradients as-is would severly hinder the learning process. 
An approximation commonly used to overcome this issue is the \"straight-through estimator\" (STE) (\nHinton et al., 2012\n, \nBengio, 2013\n), which simply passes the gradient through these functions as-is. \n\n\nReferences\n\n\n\n\nWilliam Dally\n. High-Performance Hardware for Machine Learning. \nTutorial, NIPS, 2015\n\n\n\n\n\nMohammad Rastegari, Vicente Ordone, Joseph Redmon and Ali Farhadi\n. XNOR-Net: ImageNet Classification Using Binary Convolutional Neural Networks. \nECCV, 2016\n\n\n\n\n\nMatthieu Courbariaux, Yoshua Bengio and Jean-Pierre David\n. Training deep neural networks with low precision multiplications. \narxiv:1412.7024\n\n\n\n\n\nPhilipp Gysel, Jon Pimentel, Mohammad Motamedi and Soheil Ghiasi\n. Ristretto: A Framework for Empirical Study of Resource-Efficient Inference in Convolutional Neural Networks. \nIEEE Transactions on Neural Networks and Learning Systems, 2018\n\n\n\n\n\nSzymon Migacz\n. 8-bit Inference with TensorRT. \nGTC San Jose, 2017\n\n\n\n\n\nShuchang Zhou, Zekun Ni, Xinyu Zhou, He Wen, Yuxin Wu and Yuheng Zou\n. DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients. \narxiv:1606.06160\n\n\n\n\n\nAojun Zhou, Anbang Yao, Yiwen Guo, Lin Xu and Yurong Chen\n. Incremental Network Quantization: Towards Lossless CNNs with Low-precision Weights. \nICLR, 2017\n\n\n\n\n\nAsit Mishra, Eriko Nurvitadhi, Jeffrey J Cook and Debbie Marr\n. WRPN: Wide Reduced-Precision Networks. \nICLR, 2018\n\n\n\n\n\nJungwook Choi, Zhuo Wang, Swagath Venkataramani, Pierce I-Jen Chuang, Vijayalakshmi Srinivasan and Kailash Gopalakrishnan\n. PACT: Parameterized Clipping Activation for Quantized Neural Networks. \n2018\n\n\n\n\n\nXiaofan Lin, Cong Zhao and Wei Pan\n. Towards Accurate Binary Convolutional Neural Network. \nNIPS, 2017\n\n\n\n\n\nSong Han, Jeff Pool, John Tran and William Dally\n. Learning both Weights and Connections for Efficient Neural Network. \nNIPS, 2015\n\n\n\n\n\nFengfu Li, Bo Zhang and Bin Liu\n. Ternary Weight Networks. \narxiv:1605.04711\n\n\n\n\n\nChenzhuo Zhu, Song Han, Huizi Mao and William J. Dally\n. Trained Ternary Quantization. \narxiv:1612.01064\n\n\n\n\n\nYoshua Bengio, Nicholas Leonard and Aaron Courville\n. Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation. \narxiv:1308.3432, 2013\n\n\n\n\n\nGeoffrey Hinton, Nitish Srivastava, Kevin Swersky, Tijmen Tieleman and Abdelrahman Mohamed\n. Neural Networks for Machine Learning. \nCoursera, video lectures, 2012", + "location": "/quantization/index.html", + "text": "Quantization\n\n\nQuantization refers to the process of reducing the number of bits that represent a number. In the context of deep learning, the predominant numerical format used for research and for deployment has so far been 32-bit floating point, or FP32. However, the desire for reduced bandwidth and compute requirements of deep learning models has driven research into using lower-precision numerical formats. It has been extensively demonstrated that weights and activations can be represented using 8-bit integers (or INT8) without incurring significant loss in accuracy. The use of even lower bit-widths, such as 4/2/1-bits, is an active field of research that has also shown great progress.\n\n\nNote that this discussion is on quantization only in the context of more efficient inference. 
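A minimal sketch of simulated quantization with the straight-through estimator described above: the forward pass sees discrete values, while the backward pass lets the gradient flow through the rounding unchanged, so the full-precision copy of the weights keeps accumulating small updates. The 4-bit symmetric per-tensor scheme and the names below are illustrative.

```
import torch

class RoundSTE(torch.autograd.Function):
    """Round in the forward pass, pass the gradient straight through in backward."""

    @staticmethod
    def forward(ctx, x):
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output            # straight-through estimator: treat d(round)/dx as 1

def fake_quant(x: torch.Tensor, num_bits: int = 4) -> torch.Tensor:
    """Simulated (fake) quantization: discrete values forward, STE gradient backward."""
    qmax = 2 ** (num_bits - 1) - 1
    scale = x.detach().abs().max() / qmax             # symmetric, per-tensor scale
    return RoundSTE.apply(x / scale).clamp(-qmax - 1, qmax) * scale

# The full-precision weights ("weights_fp" in the diagram above) accumulate the
# small gradient updates, while the forward pass only sees their quantized version.
w_fp = torch.randn(16, 16, requires_grad=True)
loss = fake_quant(w_fp, num_bits=4).pow(2).sum()
loss.backward()
print(w_fp.grad.abs().mean())                         # nonzero thanks to the STE
```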
Using lower-precision numerics for more efficient training is currently out of scope.\n\n\nMotivation: Overall Efficiency\n\n\nThe more obvious benefit from quantization is \nsignificantly reduced bandwidth and storage\n. For instance, using INT8 for weights and activations consumes 4x less overall bandwidth compared to FP32.\n\nAdditionally integer compute is \nfaster\n than floating point compute. It is also much more \narea and energy efficient\n: \n\n\n\n\n\n\n\n\nINT8 Operation\n\n\nEnergy Saving vs FP32\n\n\nArea Saving vs FP32\n\n\n\n\n\n\n\n\n\n\nAdd\n\n\n30x\n\n\n116x\n\n\n\n\n\n\nMultiply\n\n\n18.5x\n\n\n27x\n\n\n\n\n\n\n\n\n(\nDally, 2015\n)\n\n\nNote that very aggressive quantization can yield even more efficiency. If weights are binary (-1, 1) or ternary (-1, 0, 1 using 2-bits), then convolution and fully-connected layers can be computed with additions and subtractions only, removing multiplications completely. If activations are binary as well, then additions can also be removed, in favor of bitwise operations (\nRastegari et al., 2016\n).\n\n\nInteger vs. FP32\n\n\nThere are two main attributes when discussing a numerical format. The first is \ndynamic range\n, which refers to the range of representable numbers. The second one is how many values can be represented within the dynamic range, which in turn determines the \nprecision / resolution\n of the format (the distance between two numbers).\n\nFor all integer formats, the dynamic range is \n[-2^{n-1} .. 2^{n-1}-1]\n, where \nn\n is the number of bits. So for INT8 the range is \n[-128 .. 127]\n, and for INT4 it is \n[-16 .. 15]\n (we're limiting ourselves to signed integers for now). The number of representable values is \n2^n\n.\nContrast that with FP32, where the dynamic range is \n\\pm 3.4\\ x\\ 10^{38}\n, and approximately \n4.2\\ x\\ 10^9\n values can be represented.\n\nWe can immediately see that FP32 is much more \nversatile\n, in that it is able to represent a wide range of distributions accurately. This is a nice property for deep learning models, where the distributions of weights and activations are usually very different (at least in dynamic range). In addition the dynamic range can differ between layers in the model.\n\nIn order to be able to represent these different distributions with an integer format, a \nscale factor\n is used to map the dynamic range of the tensor to the integer format range. But still we remain with the issue of having a significantly lower number of representable values, that is - much lower resolution.\n\nNote that this scale factor is, in most cases, a floating-point number. Hence, even when using integer numerics, some floating-point computations remain. \nCourbariaux et al., 2014\n scale using only shifts, eliminating the floating point operation. In \nGEMMLWOP\n, the FP32 scale factor is approximated using an integer or fixed-point multiplication followed by a shift operation. In many cases the effect of this approximation on accuracy is negligible.\n\n\nAvoiding Overflows\n\n\nConvolution and fully connected layers involve the storing of intermediate results in accumulators. Due to the limited dynamic range of integer formats, if we would use the same bit-width for the weights and activation, \nand\n for the accumulators, we would likely overflow very quickly. Therefore, accumulators are usually implemented with higher bit-widths.\n\nThe result of multiplying two \nn\n-bit integers is, at most, a \n2n\n-bit number. 
In convolution layers, such multiplications are accumulated \nc\\cdot k^2\n times, where \nc\n is the number of input channels and \nk\n is the kernel width (assuming a square kernel). Hence, to avoid overflowing, the accumulator should be \n2n + M\n-bits wide, where M is at least \nlog_2(c\\cdot k^2)\n. In many cases 32-bit accumulators are used, however for INT4 and lower it might be possible to use less than 32 -bits, depending on the expected use cases and layer widths.\n\n\n\"Conservative\" Quantization: INT8\n\n\nIn many cases, taking a model trained for FP32 and directly quantizing it to INT8, without any re-training, can result in a relatively low loss of accuracy (which may or may not be acceptable, depending on the use case). Some fine-tuning can further improve the accuracy (\nGysel at al., 2018\n).\n\nAs mentioned above, a scale factor is used to adapt the dynamic range of the tensor at hand to that of the integer format. This scale factor needs to be calculated per-layer per-tensor. The simplest way is to map the min/max values of the float tensor to the min/max of the integer format. For weights and biases this is easy, as they are set once training is complete. For activations, the min/max float values can be obtained \"online\" during inference, or \"offline\".\n\n\n\n\nOffline\n means gathering activations statistics before deploying the model, either during training or by running a few \"calibration\" batches on the trained FP32 model. Based on these gathered statistics, the scaled factors are calculated and are fixed once the model is deployed. This method has the risk of encountering values outside the previously observed ranges at runtime. These values will be clipped, which might lead to accuracy degradation.\n\n\nOnline\n means calculating the min/max values for each tensor dynamically during runtime. In this method clipping cannot occur, however the added computation resources required to calculate the min/max values at runtime might be prohibitive.\n\n\n\n\n\n\n\nIt is important to note, however, that the full float range of an activations tensor usually includes elements which are statistically outliers. These values can be discarded by using a narrower min/max range, effectively allowing some clipping to occur in favor of increasing the resolution provided to the part of the distribution containing most of the information. A simple method which can yield nice results is to simply use an average of the observed min/max values instead of the actual values. Alternatively, statistical measures can be used to intelligently select where to clip the original range in order to preserve as much information as possible (\nMigacz, 2017\n). Going further, \nBanner et al., 2018\n have proposed a method for analytically computing the clipping value under certain conditions.\n\n\nAnother possible optimization point is \nscale-factor scope\n. The most common way is use a single scale-factor per-layer, but it is also possible to calculate a scale-factor per-channel. This can be beneficial if the weight distributions vary greatly between channels.\n\n\nWhen used to directly quantize a model without re-training, as described so far, this method is commonly referred to as \npost-training quantization\n. However, recent publications have shown that there are cases where post-training quantization to INT8 doesn't preserve accuracy (\nBenoit et al., 2018\n, \nKrishnamoorthi, 2018\n). 
Namely, smaller models such as MobileNet seem to not respond as well to post-training quantization, presumabley due to their smaller representational capacity. In such cases, \nquantization-aware training\n is used.\n\n\n\"Aggressive\" Quantization: INT4 and Lower\n\n\nNaively quantizing a FP32 model to INT4 and lower usually incurs significant accuracy degradation. Many works have tried to mitigate this effect. They usually employ one or more of the following concepts in order to improve model accuracy:\n\n\n\n\nTraining / Re-Training\n: For INT4 and lower, training is required in order to obtain reasonable accuracy. The training loop is modified to take quantization into account. See details in the \nnext section\n.\n\n\nZhou S et al., 2016\n have shown that bootstrapping the quantized model with trained FP32 weights leads to higher accuracy, as opposed to training from scratch. Other methods \nrequire\n a trained FP32 model, either as a starting point (\nZhou A et al., 2017\n), or as a teacher network in a knowledge distillation training setup (see \nhere\n).\n\n\nReplacing the activation function\n: The most common activation function in vision models is ReLU, which is unbounded. That is - its dynamic range is not limited for positive inputs. This is very problematic for INT4 and below due to the very limited range and resolution. Therefore, most methods replace ReLU with another function which is bounded. In some cases a clipping function with hard coded values is used (\nZhou S et al., 2016\n, \nMishra et al., 2018\n). Another method learns the clipping value per layer, with better results (\nChoi et al., 2018\n). Once the clipping value is set, the scale factor used for quantization is also set, and no further calibration steps are required (as opposed to INT8 methods described above).\n\n\nModifying network structure\n: \nMishra et al., 2018\n try to compensate for the loss of information due to quantization by using wider layers (more channels). \nLin et al., 2017\n proposed a binary quantization method in which a single FP32 convolution is replaced with multiple binary convolutions, each scaled to represent a different \"base\", covering a larger dynamic range overall.\n\n\nFirst and last layer\n: Many methods do not quantize the first and last layer of the model. It has been observed by \nHan et al., 2015\n that the first convolutional layer is more sensitive to weights pruning, and some quantization works cite the same reason and show it empirically (\nZhou S et al., 2016\n, \nChoi et al., 2018\n). Some works also note that these layers usually constitute a very small portion of the overall computation within the model, further reducing the motivation to quantize them (\nRastegari et al., 2016\n). Most methods keep the first and last layers at FP32. However, \nChoi et al., 2018\n showed that \"conservative\" quantization of these layers, e.g. to INT8, does not reduce accuracy.\n\n\nIterative quantization\n: Most methods quantize the entire model at once. \nZhou A et al., 2017\n employ an iterative method, which starts with a trained FP32 baseline, and quantizes only a portion of the model at the time followed by several epochs of re-training to recover the accuracy loss from quantization.\n\n\nMixed Weights and Activations Precision\n: It has been observed that activations are more sensitive to quantization than weights (\nZhou S et al., 2016\n). Hence it is not uncommon to see experiments with activations quantized to a higher precision compared to weights. 
Some works have focused solely on quantizing weights, keeping the activations at FP32 (\nLi et al., 2016\n, \nZhu et al., 2016\n).\n\n\n\n\nQuantization-Aware Training\n\n\nAs mentioned above, in order to minimize the loss of accuracy from \"aggressive\" quantization, many methods that target INT4 and lower (and in some cases for INT8 as well) involve training the model in a way that considers the quantization. This means training with quantization of weights and activations \"baked\" into the training procedure. The training graph usually looks like this:\n\n\n\n\nA full precision copy of the weights is maintained throughout the training process (\"weights_fp\" in the diagram). Its purpose is to accumulate the small changes from the gradients without loss of precision (Note that the quantization of the weights is an integral part of the training graph, meaning that we back-propagate through it as well). Once the model is trained, only the quantized weights are used for inference.\n\nIn the diagram we show \"layer N\" as the conv + batch-norm + activation combination, but the same applies to fully-connected layers, element-wise operations, etc. During training, the operations within \"layer N\" can still run in full precision, with the \"quantize\" operations in the boundaries ensuring discrete-valued weights and activations. This is sometimes called \"simulated quantization\". \n\n\nStraight-Through Estimator\n\n\nAn important question in this context is how to back-propagate through the quantization functions. These functions are discrete-valued, hence their derivative is 0 almost everywhere. So, using their gradients as-is would severely hinder the learning process. An approximation commonly used to overcome this issue is the \"straight-through estimator\" (STE) (\nHinton et al., 2012\n, \nBengio, 2013\n), which simply passes the gradient through these functions as-is. \n\n\nReferences\n\n\n\n\nWilliam Dally\n. High-Performance Hardware for Machine Learning. \nTutorial, NIPS, 2015\n\n\n\n\n\nMohammad Rastegari, Vicente Ordone, Joseph Redmon and Ali Farhadi\n. XNOR-Net: ImageNet Classification Using Binary Convolutional Neural Networks. \nECCV, 2016\n\n\n\n\n\nMatthieu Courbariaux, Yoshua Bengio and Jean-Pierre David\n. Training deep neural networks with low precision multiplications. \narxiv:1412.7024\n\n\n\n\n\nPhilipp Gysel, Jon Pimentel, Mohammad Motamedi and Soheil Ghiasi\n. Ristretto: A Framework for Empirical Study of Resource-Efficient Inference in Convolutional Neural Networks. \nIEEE Transactions on Neural Networks and Learning Systems, 2018\n\n\n\n\n\nSzymon Migacz\n. 8-bit Inference with TensorRT. \nGTC San Jose, 2017\n\n\n\n\n\nShuchang Zhou, Zekun Ni, Xinyu Zhou, He Wen, Yuxin Wu and Yuheng Zou\n. DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients. \narxiv:1606.06160\n\n\n\n\n\nAojun Zhou, Anbang Yao, Yiwen Guo, Lin Xu and Yurong Chen\n. Incremental Network Quantization: Towards Lossless CNNs with Low-precision Weights. \nICLR, 2017\n\n\n\n\n\nAsit Mishra, Eriko Nurvitadhi, Jeffrey J Cook and Debbie Marr\n. WRPN: Wide Reduced-Precision Networks. \nICLR, 2018\n\n\n\n\n\nJungwook Choi, Zhuo Wang, Swagath Venkataramani, Pierce I-Jen Chuang, Vijayalakshmi Srinivasan and Kailash Gopalakrishnan\n. PACT: Parameterized Clipping Activation for Quantized Neural Networks. \narxiv:1805.06085\n\n\n\n\n\nXiaofan Lin, Cong Zhao and Wei Pan\n. Towards Accurate Binary Convolutional Neural Network. 
\nNIPS, 2017\n\n\n\n\n\nSong Han, Jeff Pool, John Tran and William Dally\n. Learning both Weights and Connections for Efficient Neural Network. \nNIPS, 2015\n\n\n\n\n\nFengfu Li, Bo Zhang and Bin Liu\n. Ternary Weight Networks. \narxiv:1605.04711\n\n\n\n\n\nChenzhuo Zhu, Song Han, Huizi Mao and William J. Dally\n. Trained Ternary Quantization. \narxiv:1612.01064\n\n\n\n\n\nYoshua Bengio, Nicholas Leonard and Aaron Courville\n. Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation. \narxiv:1308.3432\n\n\n\n\n\nGeoffrey Hinton, Nitish Srivastava, Kevin Swersky, Tijmen Tieleman and Abdelrahman Mohamed\n. Neural Networks for Machine Learning. \nCoursera, video lectures, 2012\n\n\n\n\n\nBenoit Jacob, Skirmantas Kligys, Bo Chen, Menglong Zhu, Matthew Tang, Andrew Howard, Hartwig Adam and Dmitry Kalenichenko\n. Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference. \nECCV, 2018\n\n\n\n\n\nRaghuraman Krishnamoorthi\n. Quantizing deep convolutional networks for efficient inference: A whitepaper \narxiv:1806.08342\n\n\n\n\n\nRon Banner, Yury Nahshan, Elad Hoffer and Daniel Soudry\n. ACIQ: Analytical Clipping for Integer Quantization of neural networks \narxiv:1810.05723", "title": "Quantization" - }, + }, { - "location": "/quantization/index.html#quantization", - "text": "Quantization refers to the process of reducing the number of bits that represent a number. In the context of deep learning, the predominant numerical format used for research and for deployment has so far been 32-bit floating point, or FP32. However, the desire for reduced bandwidth and compute requirements of deep learning models has driven research into using lower-precision numerical formats. It has been extensively demonstrated that weights and activations can be represented using 8-bit integers (or INT8) without incurring significant loss in accuracy. The use of even lower bit-widths, such as 4/2/1-bits, is an active field of research that has also shown great progress. Note that this discussion is on quantization only in the context of more efficient inference. Using lower-precision numerics for more efficient training is currently out of scope.", + "location": "/quantization/index.html#quantization", + "text": "Quantization refers to the process of reducing the number of bits that represent a number. In the context of deep learning, the predominant numerical format used for research and for deployment has so far been 32-bit floating point, or FP32. However, the desire for reduced bandwidth and compute requirements of deep learning models has driven research into using lower-precision numerical formats. It has been extensively demonstrated that weights and activations can be represented using 8-bit integers (or INT8) without incurring significant loss in accuracy. The use of even lower bit-widths, such as 4/2/1-bits, is an active field of research that has also shown great progress. Note that this discussion is on quantization only in the context of more efficient inference. Using lower-precision numerics for more efficient training is currently out of scope.", "title": "Quantization" - }, + }, { - "location": "/quantization/index.html#motivation-overall-efficiency", - "text": "The more obvious benefit from quantization is significantly reduced bandwidth and storage . For instance, using INT8 for weights and activations consumes 4x less overall bandwidth compared to FP32. \nAdditionally integer compute is faster than floating point compute. 
It is also much more area and energy efficient : INT8 Operation Energy Saving vs FP32 Area Saving vs FP32 Add 30x 116x Multiply 18.5x 27x ( Dally, 2015 ) Note that very aggressive quantization can yield even more efficiency. If weights are binary (-1, 1) or ternary (-1, 0, 1 using 2-bits), then convolution and fully-connected layers can be computed with additions and subtractions only, removing multiplications completely. If activations are binary as well, then additions can also be removed, in favor of bitwise operations ( Rastegari et al., 2016 ).", + "location": "/quantization/index.html#motivation-overall-efficiency", + "text": "The more obvious benefit from quantization is significantly reduced bandwidth and storage . For instance, using INT8 for weights and activations consumes 4x less overall bandwidth compared to FP32. \nAdditionally integer compute is faster than floating point compute. It is also much more area and energy efficient : INT8 Operation Energy Saving vs FP32 Area Saving vs FP32 Add 30x 116x Multiply 18.5x 27x ( Dally, 2015 ) Note that very aggressive quantization can yield even more efficiency. If weights are binary (-1, 1) or ternary (-1, 0, 1 using 2-bits), then convolution and fully-connected layers can be computed with additions and subtractions only, removing multiplications completely. If activations are binary as well, then additions can also be removed, in favor of bitwise operations ( Rastegari et al., 2016 ).", "title": "Motivation: Overall Efficiency" - }, + }, { - "location": "/quantization/index.html#integer-vs-fp32", - "text": "There are two main attributes when discussing a numerical format. The first is dynamic range , which refers to the range of representable numbers. The second one is how many values can be represented within the dynamic range, which in turn determines the precision / resolution of the format (the distance between two numbers). \nFor all integer formats, the dynamic range is [-2^{n-1} .. 2^{n-1}-1] , where n is the number of bits. So for INT8 the range is [-128 .. 127] , and for INT4 it is [-16 .. 15] (we're limiting ourselves to signed integers for now). The number of representable values is 2^n .\nContrast that with FP32, where the dynamic range is \\pm 3.4\\ x\\ 10^{38} , and approximately 4.2\\ x\\ 10^9 values can be represented. \nWe can immediately see that FP32 is much more versatile , in that it is able to represent a wide range of distributions accurately. This is a nice property for deep learning models, where the distributions of weights and activations are usually very different (at least in dynamic range). In addition the dynamic range can differ between layers in the model. \nIn order to be able to represent these different distributions with an integer format, a scale factor is used to map the dynamic range of the tensor to the integer format range. But still we remain with the issue of having a significantly lower number of representable values, that is - much lower resolution. \nNote that this scale factor is, in most cases, a floating-point number. Hence, even when using integer numerics, some floating-point computations remain. Courbariaux et al., 2014 scale using only shifts, eliminating the floating point operation. In GEMMLWOP , the FP32 scale factor is approximated using an integer or fixed-point multiplication followed by a shift operation. 
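To make the multiplier-and-shift idea concrete, here is a small Python sketch of approximating an FP32 scale factor by an integer multiplication followed by a right shift. It is only a simplified illustration of the general approach described above (the function names are made up, and this is not the library's actual code):

```
def approximate_scale(scale, frac_bits=15):
    """Approximate a positive scale factor (assumed < 1) as multiplier * 2^(-shift)."""
    shift = 0
    while scale < 0.5:                              # normalize the scale into [0.5, 1)
        scale *= 2
        shift += 1
    multiplier = round(scale * (1 << frac_bits))    # fixed-point integer multiplier
    return multiplier, shift + frac_bits

def rescale(acc, multiplier, shift):
    """Apply the approximated scale to an integer accumulator using integer-only ops."""
    return (acc * multiplier) >> shift

mult, shift = approximate_scale(0.0123)
print(rescale(100000, mult, shift))   # ~1230, close to 0.0123 * 100000
```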
In many cases the effect of this approximation on accuracy is negligible.", + "location": "/quantization/index.html#integer-vs-fp32", + "text": "There are two main attributes when discussing a numerical format. The first is dynamic range , which refers to the range of representable numbers. The second one is how many values can be represented within the dynamic range, which in turn determines the precision / resolution of the format (the distance between two numbers). \nFor all integer formats, the dynamic range is [-2^{n-1} .. 2^{n-1}-1] , where n is the number of bits. So for INT8 the range is [-128 .. 127] , and for INT4 it is [-16 .. 15] (we're limiting ourselves to signed integers for now). The number of representable values is 2^n .\nContrast that with FP32, where the dynamic range is \\pm 3.4\\ x\\ 10^{38} , and approximately 4.2\\ x\\ 10^9 values can be represented. \nWe can immediately see that FP32 is much more versatile , in that it is able to represent a wide range of distributions accurately. This is a nice property for deep learning models, where the distributions of weights and activations are usually very different (at least in dynamic range). In addition the dynamic range can differ between layers in the model. \nIn order to be able to represent these different distributions with an integer format, a scale factor is used to map the dynamic range of the tensor to the integer format range. But still we remain with the issue of having a significantly lower number of representable values, that is - much lower resolution. \nNote that this scale factor is, in most cases, a floating-point number. Hence, even when using integer numerics, some floating-point computations remain. Courbariaux et al., 2014 scale using only shifts, eliminating the floating point operation. In GEMMLWOP , the FP32 scale factor is approximated using an integer or fixed-point multiplication followed by a shift operation. In many cases the effect of this approximation on accuracy is negligible.", "title": "Integer vs. FP32" - }, + }, { - "location": "/quantization/index.html#avoiding-overflows", - "text": "Convolution and fully connected layers involve the storing of intermediate results in accumulators. Due to the limited dynamic range of integer formats, if we would use the same bit-width for the weights and activation, and for the accumulators, we would likely overflow very quickly. Therefore, accumulators are usually implemented with higher bit-widths. \nThe result of multiplying two n -bit integers is, at most, a 2n -bit number. In convolution layers, such multiplications are accumulated c\\cdot k^2 times, where c is the number of input channels and k is the kernel width (assuming a square kernel). Hence, to avoid overflowing, the accumulator should be 2n + M -bits wide, where M is at least log_2(c\\cdot k^2) . In many cases 32-bit accumulators are used, however for INT4 and lower it might be possible to use less than 32 -bits, depending on the expected use cases and layer widths.", + "location": "/quantization/index.html#avoiding-overflows", + "text": "Convolution and fully connected layers involve the storing of intermediate results in accumulators. Due to the limited dynamic range of integer formats, if we would use the same bit-width for the weights and activation, and for the accumulators, we would likely overflow very quickly. Therefore, accumulators are usually implemented with higher bit-widths. \nThe result of multiplying two n -bit integers is, at most, a 2n -bit number. 
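As a quick worked illustration of the accumulator-width rule above (an accumulator of 2n + M bits, with M at least log2(c·k²)), the snippet below computes the minimum width for a couple of example layer shapes; the layer sizes are arbitrary:

```
import math

def accumulator_bits(n_bits, in_channels, kernel_size):
    """Minimum accumulator width so that summing c * k^2 products of two n-bit integers cannot overflow."""
    m = math.ceil(math.log2(in_channels * kernel_size ** 2))
    return 2 * n_bits + m

print(accumulator_bits(8, 128, 3))   # INT8, 3x3 conv, 128 input channels: 16 + 11 = 27 bits
print(accumulator_bits(4, 256, 3))   # INT4, 3x3 conv, 256 input channels:  8 + 12 = 20 bits
```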
In convolution layers, such multiplications are accumulated c\\cdot k^2 times, where c is the number of input channels and k is the kernel width (assuming a square kernel). Hence, to avoid overflowing, the accumulator should be 2n + M -bits wide, where M is at least log_2(c\\cdot k^2) . In many cases 32-bit accumulators are used, however for INT4 and lower it might be possible to use less than 32 -bits, depending on the expected use cases and layer widths.", "title": "Avoiding Overflows" - }, + }, { - "location": "/quantization/index.html#conservative-quantization-int8", - "text": "In many cases, taking a model trained for FP32 and directly quantizing it to INT8, without any re-training, can result in a relatively low loss of accuracy (which may or may not be acceptable, depending on the use case). Some fine-tuning can further improve the accuracy ( Gysel at al., 2018 ). \nAs mentioned above, a scale factor is used to adapt the dynamic range of the tensor at hand to that of the integer format. This scale factor needs to be calculated per-layer per-tensor. The simplest way is to map the min/max values of the float tensor to the min/max of the integer format. For weights and biases this is easy, as they are set once training is complete. For activations, the min/max float values can be obtained \"online\" during inference, or \"offline\". Offline means gathering activations statistics before deploying the model, either during training or by running a few \"calibration\" batches on the trained FP32 model. Based on these gathered statistics, the scaled factors are calculated and are fixed once the model is deployed. This method has the risk of encountering values outside the previously observed ranges at runtime. These values will be clipped, which might lead to accuracy degradation. Online means calculating the min/max values for each tensor dynamically during runtime. In this method clipping cannot occur, however the added computation resources required to calculate the min/max values at runtime might be prohibitive. It is important to note, however, that the full float range of an activations tensor usually includes elements which are statistically outliers. These values can be discarded by using a narrower min/max range, effectively allowing some clipping to occur in favor of increasing the resolution provided to the part of the distribution containing most of the information. Statistical measures can be used to intelligently select where to clip the original range in order to preserve as much information as possible ( Migacz, 2017 ). Another possible optimization point is scale-factor scope . The most common way is use a single scale-factor per-layer, but it is also possible to calculate a scale-factor per-channel. This can be beneficial if the weight distributions vary greatly between channels.", + "location": "/quantization/index.html#conservative-quantization-int8", + "text": "In many cases, taking a model trained for FP32 and directly quantizing it to INT8, without any re-training, can result in a relatively low loss of accuracy (which may or may not be acceptable, depending on the use case). Some fine-tuning can further improve the accuracy ( Gysel at al., 2018 ). \nAs mentioned above, a scale factor is used to adapt the dynamic range of the tensor at hand to that of the integer format. This scale factor needs to be calculated per-layer per-tensor. The simplest way is to map the min/max values of the float tensor to the min/max of the integer format. 
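To make the min/max mapping concrete, here is a minimal NumPy sketch of asymmetric, per-tensor post-training quantization. The function names are illustrative only and do not correspond to Distiller's API:

```
import numpy as np

def linear_quantize(x, num_bits=8):
    """Map the float tensor's [min, max] range onto an unsigned num_bits integer grid."""
    qmin, qmax = 0, 2 ** num_bits - 1
    x_min, x_max = float(x.min()), float(x.max())
    scale = (x_max - x_min) / (qmax - qmin)          # FP32 scale factor
    zero_point = int(round(qmin - x_min / scale))    # integer that represents float 0.0
    q = np.clip(np.round(x / scale) + zero_point, qmin, qmax).astype(np.uint8)
    return q, scale, zero_point

def linear_dequantize(q, scale, zero_point):
    return scale * (q.astype(np.float32) - zero_point)

x = np.random.randn(64, 128).astype(np.float32)
q, scale, zp = linear_quantize(x)
err = np.abs(x - linear_dequantize(q, scale, zp)).max()
print("max reconstruction error:", err, "quantization step:", scale)
# A per-channel variant would compute x_min/x_max per output channel instead of per tensor.
```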
For weights and biases this is easy, as they are set once training is complete. For activations, the min/max float values can be obtained \"online\" during inference, or \"offline\". Offline means gathering activations statistics before deploying the model, either during training or by running a few \"calibration\" batches on the trained FP32 model. Based on these gathered statistics, the scale factors are calculated and are fixed once the model is deployed. This method has the risk of encountering values outside the previously observed ranges at runtime. These values will be clipped, which might lead to accuracy degradation. Online means calculating the min/max values for each tensor dynamically during runtime. In this method clipping cannot occur, however the added computation resources required to calculate the min/max values at runtime might be prohibitive. It is important to note, however, that the full float range of an activations tensor usually includes elements which are statistically outliers. These values can be discarded by using a narrower min/max range, effectively allowing some clipping to occur in favor of increasing the resolution provided to the part of the distribution containing most of the information. A simple method which can yield nice results is to simply use an average of the observed min/max values instead of the actual values. Alternatively, statistical measures can be used to intelligently select where to clip the original range in order to preserve as much information as possible ( Migacz, 2017 ). Going further, Banner et al., 2018 have proposed a method for analytically computing the clipping value under certain conditions. Another possible optimization point is scale-factor scope . The most common way is to use a single scale-factor per-layer, but it is also possible to calculate a scale-factor per-channel. This can be beneficial if the weight distributions vary greatly between channels. When used to directly quantize a model without re-training, as described so far, this method is commonly referred to as post-training quantization . However, recent publications have shown that there are cases where post-training quantization to INT8 doesn't preserve accuracy ( Benoit et al., 2018 , Krishnamoorthi, 2018 ). Namely, smaller models such as MobileNet seem to not respond as well to post-training quantization, presumably due to their smaller representational capacity. In such cases, quantization-aware training is used.", "title": "\"Conservative\" Quantization: INT8" - }, + }, { - "location": "/quantization/index.html#aggressive-quantization-int4-and-lower", - "text": "Naively quantizing a FP32 model to INT4 and lower usually incurs significant accuracy degradation. Many works have tried to mitigate this effect. They usually employ one or more of the following concepts in order to improve model accuracy: Training / Re-Training : For INT4 and lower, training is required in order to obtain reasonable accuracy. The training loop is modified to take quantization into account. See details in the next section . Zhou S et al., 2016 have shown that bootstrapping the quantized model with trained FP32 weights leads to higher accuracy, as opposed to training from scratch. Other methods require a trained FP32 model, either as a starting point ( Zhou A et al., 2017 ), or as a teacher network in a knowledge distillation training setup (see here ). Replacing the activation function : The most common activation function in vision models is ReLU, which is unbounded. 
That is - its dynamic range is not limited for positive inputs. This is very problematic for INT4 and below due to the very limited range and resolution. Therefore, most methods replace ReLU with another function which is bounded. In some cases a clipping function with hard coded values is used ( Zhou S et al., 2016 , Mishra et al., 2018 ). Another method learns the clipping value per layer, with better results ( Choi et al., 2018 ). Once the clipping value is set, the scale factor used for quantization is also set, and no further calibration steps are required (as opposed to INT8 methods described above). Modifying network structure : Mishra et al., 2018 try to compensate for the loss of information due to quantization by using wider layers (more channels). Lin et al., 2017 proposed a binary quantization method in which a single FP32 convolution is replaced with multiple binary convolutions, each scaled to represent a different \"base\", covering a larger dynamic range overall. First and last layer : Many methods do not quantize the first and last layer of the model. It has been observed by Han et al., 2015 that the first convolutional layer is more sensitive to weights pruning, and some quantization works cite the same reason and show it empirically ( Zhou S et al., 2016 , Choi et al., 2018 ). Some works also note that these layers usually constitute a very small portion of the overall computation within the model, further reducing the motivation to quantize them ( Rastegari et al., 2016 ). Most methods keep the first and last layers at FP32. However, Choi et al., 2018 showed that \"conservative\" quantization of these layers, e.g. to INT8, does not reduce accuracy. Iterative quantization : Most methods quantize the entire model at once. Zhou A et al., 2017 employ an iterative method, which starts with a trained FP32 baseline, and quantizes only a portion of the model at the time followed by several epochs of re-training to recover the accuracy loss from quantization. Mixed Weights and Activations Precision : It has been observed that activations are more sensitive to quantization than weights ( Zhou S et al., 2016 ). Hence it is not uncommon to see experiments with activations quantized to a higher precision compared to weights. Some works have focused solely on quantizing weights, keeping the activations at FP32 ( Li et al., 2016 , Zhu et al., 2016 ).", + "location": "/quantization/index.html#aggressive-quantization-int4-and-lower", + "text": "Naively quantizing a FP32 model to INT4 and lower usually incurs significant accuracy degradation. Many works have tried to mitigate this effect. They usually employ one or more of the following concepts in order to improve model accuracy: Training / Re-Training : For INT4 and lower, training is required in order to obtain reasonable accuracy. The training loop is modified to take quantization into account. See details in the next section . Zhou S et al., 2016 have shown that bootstrapping the quantized model with trained FP32 weights leads to higher accuracy, as opposed to training from scratch. Other methods require a trained FP32 model, either as a starting point ( Zhou A et al., 2017 ), or as a teacher network in a knowledge distillation training setup (see here ). Replacing the activation function : The most common activation function in vision models is ReLU, which is unbounded. That is - its dynamic range is not limited for positive inputs. This is very problematic for INT4 and below due to the very limited range and resolution. 
Therefore, most methods replace ReLU with another function which is bounded. In some cases a clipping function with hard coded values is used ( Zhou S et al., 2016 , Mishra et al., 2018 ). Another method learns the clipping value per layer, with better results ( Choi et al., 2018 ). Once the clipping value is set, the scale factor used for quantization is also set, and no further calibration steps are required (as opposed to INT8 methods described above). Modifying network structure : Mishra et al., 2018 try to compensate for the loss of information due to quantization by using wider layers (more channels). Lin et al., 2017 proposed a binary quantization method in which a single FP32 convolution is replaced with multiple binary convolutions, each scaled to represent a different \"base\", covering a larger dynamic range overall. First and last layer : Many methods do not quantize the first and last layer of the model. It has been observed by Han et al., 2015 that the first convolutional layer is more sensitive to weights pruning, and some quantization works cite the same reason and show it empirically ( Zhou S et al., 2016 , Choi et al., 2018 ). Some works also note that these layers usually constitute a very small portion of the overall computation within the model, further reducing the motivation to quantize them ( Rastegari et al., 2016 ). Most methods keep the first and last layers at FP32. However, Choi et al., 2018 showed that \"conservative\" quantization of these layers, e.g. to INT8, does not reduce accuracy. Iterative quantization : Most methods quantize the entire model at once. Zhou A et al., 2017 employ an iterative method, which starts with a trained FP32 baseline, and quantizes only a portion of the model at the time followed by several epochs of re-training to recover the accuracy loss from quantization. Mixed Weights and Activations Precision : It has been observed that activations are more sensitive to quantization than weights ( Zhou S et al., 2016 ). Hence it is not uncommon to see experiments with activations quantized to a higher precision compared to weights. Some works have focused solely on quantizing weights, keeping the activations at FP32 ( Li et al., 2016 , Zhu et al., 2016 ).", "title": "\"Aggressive\" Quantization: INT4 and Lower" - }, + }, { - "location": "/quantization/index.html#training-with-quantization", - "text": "As mentioned above, in order to minimize the loss of accuracy from \"aggressive\" quantization, many methods that target INT4 and lower involve training the model in a way that considers the quantization. This means training with quantization of weights and activations \"baked\" into the training procedure. The training graph usually looks like this: A full precision copy of the weights is maintained throughout the training process (\"weights_fp\" in the diagram). Its purpose is to accumulate the small changes from the gradients without loss of precision (Note that the quantization of the weights is an integral part of the training graph, meaning that we back-propagate through it as well). Once the model is trained, only the quantized weights are used for inference. \nIn the diagram we show \"layer N\" as the conv + batch-norm + activation combination, but the same applies to fully-connected layers, element-wise operations, etc. During training, the operations within \"layer N\" can still run in full precision, with the \"quantize\" operations in the boundaries ensuring discrete-valued weights and activations. 
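The following is a minimal PyTorch sketch of such a simulated ("fake") quantization operation, using the straight-through estimator discussed in the next section for the backward pass. It is an illustration of the concept only, not Distiller's quantizer:

```
import torch

class FakeQuantSTE(torch.autograd.Function):
    """Discretize in the forward pass; pass gradients through unchanged in the backward pass."""

    @staticmethod
    def forward(ctx, x, num_bits):
        qmax = 2 ** (num_bits - 1) - 1
        scale = x.abs().max() / qmax                          # symmetric per-tensor scale
        return torch.clamp(torch.round(x / scale), -qmax - 1, qmax) * scale

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None                              # straight-through estimator

# Usage inside a layer's forward pass: quantize the FP32 copy of the weights,
# compute with the quantized values, and let gradients flow back to the FP32 copy.
w_fp = torch.randn(16, 8, requires_grad=True)
w_q = FakeQuantSTE.apply(w_fp, 4)
loss = (w_q ** 2).sum()
loss.backward()
print(w_fp.grad.shape)
```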
This is sometimes called \"simulated quantization\".", - "title": "Training with Quantization" - }, + "location": "/quantization/index.html#quantization-aware-training", + "text": "As mentioned above, in order to minimize the loss of accuracy from \"aggressive\" quantization, many methods that target INT4 and lower (and in some cases for INT8 as well) involve training the model in a way that considers the quantization. This means training with quantization of weights and activations \"baked\" into the training procedure. The training graph usually looks like this: A full precision copy of the weights is maintained throughout the training process (\"weights_fp\" in the diagram). Its purpose is to accumulate the small changes from the gradients without loss of precision (Note that the quantization of the weights is an integral part of the training graph, meaning that we back-propagate through it as well). Once the model is trained, only the quantized weights are used for inference. \nIn the diagram we show \"layer N\" as the conv + batch-norm + activation combination, but the same applies to fully-connected layers, element-wise operations, etc. During training, the operations within \"layer N\" can still run in full precision, with the \"quantize\" operations in the boundaries ensuring discrete-valued weights and activations. This is sometimes called \"simulated quantization\".", + "title": "Quantization-Aware Training" + }, { - "location": "/quantization/index.html#straight-through-estimator", - "text": "An important question in this context is how to back-propagate through the quantization functions. These functions are discrete-valued, hence their derivative is 0 almost everywhere. So, using their gradients as-is would severly hinder the learning process. An approximation commonly used to overcome this issue is the \"straight-through estimator\" (STE) ( Hinton et al., 2012 , Bengio, 2013 ), which simply passes the gradient through these functions as-is.", + "location": "/quantization/index.html#straight-through-estimator", + "text": "An important question in this context is how to back-propagate through the quantization functions. These functions are discrete-valued, hence their derivative is 0 almost everywhere. So, using their gradients as-is would severely hinder the learning process. An approximation commonly used to overcome this issue is the \"straight-through estimator\" (STE) ( Hinton et al., 2012 , Bengio, 2013 ), which simply passes the gradient through these functions as-is.", "title": "Straight-Through Estimator" - }, + }, { - "location": "/quantization/index.html#references", - "text": "William Dally . High-Performance Hardware for Machine Learning. Tutorial, NIPS, 2015 Mohammad Rastegari, Vicente Ordone, Joseph Redmon and Ali Farhadi . XNOR-Net: ImageNet Classification Using Binary Convolutional Neural Networks. ECCV, 2016 Matthieu Courbariaux, Yoshua Bengio and Jean-Pierre David . Training deep neural networks with low precision multiplications. arxiv:1412.7024 Philipp Gysel, Jon Pimentel, Mohammad Motamedi and Soheil Ghiasi . Ristretto: A Framework for Empirical Study of Resource-Efficient Inference in Convolutional Neural Networks. IEEE Transactions on Neural Networks and Learning Systems, 2018 Szymon Migacz . 8-bit Inference with TensorRT. GTC San Jose, 2017 Shuchang Zhou, Zekun Ni, Xinyu Zhou, He Wen, Yuxin Wu and Yuheng Zou . DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients. 
arxiv:1606.06160 Aojun Zhou, Anbang Yao, Yiwen Guo, Lin Xu and Yurong Chen . Incremental Network Quantization: Towards Lossless CNNs with Low-precision Weights. ICLR, 2017 Asit Mishra, Eriko Nurvitadhi, Jeffrey J Cook and Debbie Marr . WRPN: Wide Reduced-Precision Networks. ICLR, 2018 Jungwook Choi, Zhuo Wang, Swagath Venkataramani, Pierce I-Jen Chuang, Vijayalakshmi Srinivasan and Kailash Gopalakrishnan . PACT: Parameterized Clipping Activation for Quantized Neural Networks. 2018 Xiaofan Lin, Cong Zhao and Wei Pan . Towards Accurate Binary Convolutional Neural Network. NIPS, 2017 Song Han, Jeff Pool, John Tran and William Dally . Learning both Weights and Connections for Efficient Neural Network. NIPS, 2015 Fengfu Li, Bo Zhang and Bin Liu . Ternary Weight Networks. arxiv:1605.04711 Chenzhuo Zhu, Song Han, Huizi Mao and William J. Dally . Trained Ternary Quantization. arxiv:1612.01064 Yoshua Bengio, Nicholas Leonard and Aaron Courville . Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation. arxiv:1308.3432, 2013 Geoffrey Hinton, Nitish Srivastava, Kevin Swersky, Tijmen Tieleman and Abdelrahman Mohamed . Neural Networks for Machine Learning. Coursera, video lectures, 2012", + "location": "/quantization/index.html#references", + "text": "William Dally . High-Performance Hardware for Machine Learning. Tutorial, NIPS, 2015 Mohammad Rastegari, Vicente Ordone, Joseph Redmon and Ali Farhadi . XNOR-Net: ImageNet Classification Using Binary Convolutional Neural Networks. ECCV, 2016 Matthieu Courbariaux, Yoshua Bengio and Jean-Pierre David . Training deep neural networks with low precision multiplications. arxiv:1412.7024 Philipp Gysel, Jon Pimentel, Mohammad Motamedi and Soheil Ghiasi . Ristretto: A Framework for Empirical Study of Resource-Efficient Inference in Convolutional Neural Networks. IEEE Transactions on Neural Networks and Learning Systems, 2018 Szymon Migacz . 8-bit Inference with TensorRT. GTC San Jose, 2017 Shuchang Zhou, Zekun Ni, Xinyu Zhou, He Wen, Yuxin Wu and Yuheng Zou . DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients. arxiv:1606.06160 Aojun Zhou, Anbang Yao, Yiwen Guo, Lin Xu and Yurong Chen . Incremental Network Quantization: Towards Lossless CNNs with Low-precision Weights. ICLR, 2017 Asit Mishra, Eriko Nurvitadhi, Jeffrey J Cook and Debbie Marr . WRPN: Wide Reduced-Precision Networks. ICLR, 2018 Jungwook Choi, Zhuo Wang, Swagath Venkataramani, Pierce I-Jen Chuang, Vijayalakshmi Srinivasan and Kailash Gopalakrishnan . PACT: Parameterized Clipping Activation for Quantized Neural Networks. arxiv:1805.06085 Xiaofan Lin, Cong Zhao and Wei Pan . Towards Accurate Binary Convolutional Neural Network. NIPS, 2017 Song Han, Jeff Pool, John Tran and William Dally . Learning both Weights and Connections for Efficient Neural Network. NIPS, 2015 Fengfu Li, Bo Zhang and Bin Liu . Ternary Weight Networks. arxiv:1605.04711 Chenzhuo Zhu, Song Han, Huizi Mao and William J. Dally . Trained Ternary Quantization. arxiv:1612.01064 Yoshua Bengio, Nicholas Leonard and Aaron Courville . Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation. arxiv:1308.3432 Geoffrey Hinton, Nitish Srivastava, Kevin Swersky, Tijmen Tieleman and Abdelrahman Mohamed . Neural Networks for Machine Learning. Coursera, video lectures, 2012 Benoit Jacob, Skirmantas Kligys, Bo Chen, Menglong Zhu, Matthew Tang, Andrew Howard, Hartwig Adam and Dmitry Kalenichenko . 
Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference. ECCV, 2018 Raghuraman Krishnamoorthi . Quantizing deep convolutional networks for efficient inference: A whitepaper arxiv:1806.08342 Ron Banner, Yury Nahshan, Elad Hoffer and Daniel Soudry . ACIQ: Analytical Clipping for Integer Quantization of neural networks arxiv:1810.05723", "title": "References" - }, + }, { - "location": "/knowledge_distillation/index.html", - "text": "Knowledge Distillation\n\n\n(For details on how to train a model with knowledge distillation in Distiller, see \nhere\n)\n\n\nKnowledge distillation is model compression method in which a small model is trained to mimic a pre-trained, larger model (or ensemble of models). This training setting is sometimes referred to as \"teacher-student\", where the large model is the teacher and the small model is the student (we'll be using these terms interchangeably).\n\n\nThe method was first proposed by \nBucila et al., 2006\n and generalized by \nHinton et al., 2015\n. The implementation in Distiller is based on the latter publication. Here we'll provide a summary of the method. For more information the reader may refer to the paper (a \nvideo lecture\n with \nslides\n is also available).\n\n\nIn distillation, knowledge is transferred from the teacher model to the student by minimizing a loss function in which the target is the distribution of class probabilities predicted by the teacher model. That is - the output of a softmax function on the teacher model's logits. However, in many cases, this probability distribution has the correct class at a very high probability, with all other class probabilities very close to 0. As such, it doesn't provide much information beyond the ground truth labels already provided in the dataset. To tackle this issue, \nHinton et al., 2015\n introduced the concept of \"softmax temperature\". The probability \np_i\n of class \ni\n is calculated from the logits \nz\n as:\n\n\n\n\np_i = \\frac{exp\\left(\\frac{z_i}{T}\\right)}{\\sum_{j} \\exp\\left(\\frac{z_j}{T}\\right)}\n\n\n\n\nwhere \nT\n is the temperature parameter. When \nT=1\n we get the standard softmax function. As \nT\n grows, the probability distribution generated by the softmax function becomes softer, providing more information as to which classes the teacher found more similar to the predicted class. Hinton calls this the \"dark knowledge\" embedded in the teacher model, and it is this dark knowledge that we are transferring to the student model in the distillation process. When computing the loss function vs. the teacher's soft targets, we use the same value of \nT\n to compute the softmax on the student's logits. We call this loss the \"distillation loss\".\n\n\nHinton et al., 2015\n found that it is also beneficial to train the distilled model to produce the correct labels (based on the ground truth) in addition to the teacher's soft-labels. Hence, we also calculate the \"standard\" loss between the student's predicted class probabilities and the ground-truth labels (also called \"hard labels/targets\"). We dub this loss the \"student loss\". When calculating the class probabilities for the student loss we use \nT = 1\n. 
\n\n\nThe overall loss function, incorporating both distillation and student losses, is calculated as:\n\n\n\n\n\\mathcal{L}(x;W) = \\alpha * \\mathcal{H}(y, \\sigma(z_s; T=1)) + \\beta * \\mathcal{H}(\\sigma(z_t; T=\\tau), \\sigma(z_s, T=\\tau))\n\n\n\n\nwhere \nx\n is the input, \nW\n are the student model parameters, \ny\n is the ground truth label, \n\\mathcal{H}\n is the cross-entropy loss function, \n\\sigma\n is the softmax function parameterized by the temperature \nT\n, and \n\\alpha\n and \n\\beta\n are coefficients. \nz_s\n and \nz_t\n are the logits of the student and teacher respectively.\n\n\n\n\nNew Hyper-Parameters\n\n\nIn general \n\\tau\n, \n\\alpha\n and \n\\beta\n are hyper parameters.\n\n\nIn their experiments, \nHinton et al., 2015\n use temperature values ranging from 1 to 20. They note that empirically, when the student model is very small compared to the teacher model, lower temperatures work better. This makes sense if we consider that as we raise the temperature, the resulting soft-labels distribution becomes richer in information, and a very small model might not be able to capture all of this information. However, there's no clear way to predict up front what kind of capacity for information the student model will have.\n\n\nWith regards to \n\\alpha\n and \n\\beta\n, \nHinton et al., 2015\n use a weighted average between the distillation loss and the student loss. That is, \n\\beta = 1 - \\alpha\n. They note that in general, they obtained the best results when setting \n\\alpha\n to be much smaller than \n\\beta\n (although in one of their experiments they use \n\\alpha = \\beta = 0.5\n). Other works which utilize knowledge distillation don't use a weighted average. Some set \n\\alpha = 1\n while leaving \n\\beta\n tunable, while others don't set any constraints.\n\n\nCombining with Other Model Compression Techniques\n\n\nIn the \"basic\" scenario, the smaller (student) model is a pre-defined architecture which just has a smaller number of parameters compared to the teacher model. For example, we could train ResNet-18 by distilling knowledge from ResNet-34. But, a model with smaller capacity can also be obtained by other model compression techniques - sparsification and/or quantization. So, for example, we could train a 4-bit ResNet-18 model with some method using quantization-aware training, and use a distillation loss function as described above. In that case, the teacher model can even be a FP32 ResNet-18 model. Same goes for pruning and regularization.\n\n\nTann et al., 2017\n, \nMishra and Marr, 2018\n and \nPolino et al., 2018\n are some works that combine knowledge distillation with \nquantization\n. \nTheis et al., 2018\n and \nAshok et al., 2018\n combine distillation with \npruning\n.\n\n\nReferences\n\n\n\n\nCristian Bucila, Rich Caruana, and Alexandru Niculescu-Mizil\n. Model Compression. \nKDD, 2006\n\n\n\n\n\nGeoffrey Hinton, Oriol Vinyals and Jeff Dean\n. Distilling the Knowledge in a Neural Network. \narxiv:1503.02531\n\n\n\n\n\nHokchhay Tann, Soheil Hashemi, Iris Bahar and Sherief Reda\n. Hardware-Software Codesign of Accurate, Multiplier-free Deep Neural Networks. \nDAC, 2017\n\n\n\n\n\nAsit Mishra and Debbie Marr\n. Apprentice: Using Knowledge Distillation Techniques To Improve Low-Precision Network Accuracy. \nICLR, 2018\n\n\n\n\n\nAntonio Polino, Razvan Pascanu and Dan Alistarh\n. Model compression via distillation and quantization. \nICLR, 2018\n\n\n\n\n\nAnubhav Ashok, Nicholas Rhinehart, Fares Beainy and Kris M. Kitani\n. 
N2N learning: Network to Network Compression via Policy Gradient Reinforcement Learning. \nICLR, 2018\n\n\n\n\n\nLucas Theis, Iryna Korshunova, Alykhan Tejani and Ferenc Husz\u00e1r\n. Faster gaze prediction with dense networks and Fisher pruning. \narxiv:1801.05787", + "location": "/knowledge_distillation/index.html", + "text": "Knowledge Distillation\n\n\n(For details on how to train a model with knowledge distillation in Distiller, see \nhere\n)\n\n\nKnowledge distillation is model compression method in which a small model is trained to mimic a pre-trained, larger model (or ensemble of models). This training setting is sometimes referred to as \"teacher-student\", where the large model is the teacher and the small model is the student (we'll be using these terms interchangeably).\n\n\nThe method was first proposed by \nBucila et al., 2006\n and generalized by \nHinton et al., 2015\n. The implementation in Distiller is based on the latter publication. Here we'll provide a summary of the method. For more information the reader may refer to the paper (a \nvideo lecture\n with \nslides\n is also available).\n\n\nIn distillation, knowledge is transferred from the teacher model to the student by minimizing a loss function in which the target is the distribution of class probabilities predicted by the teacher model. That is - the output of a softmax function on the teacher model's logits. However, in many cases, this probability distribution has the correct class at a very high probability, with all other class probabilities very close to 0. As such, it doesn't provide much information beyond the ground truth labels already provided in the dataset. To tackle this issue, \nHinton et al., 2015\n introduced the concept of \"softmax temperature\". The probability \np_i\n of class \ni\n is calculated from the logits \nz\n as:\n\n\n\n\np_i = \\frac{exp\\left(\\frac{z_i}{T}\\right)}{\\sum_{j} \\exp\\left(\\frac{z_j}{T}\\right)}\n\n\n\n\nwhere \nT\n is the temperature parameter. When \nT=1\n we get the standard softmax function. As \nT\n grows, the probability distribution generated by the softmax function becomes softer, providing more information as to which classes the teacher found more similar to the predicted class. Hinton calls this the \"dark knowledge\" embedded in the teacher model, and it is this dark knowledge that we are transferring to the student model in the distillation process. When computing the loss function vs. the teacher's soft targets, we use the same value of \nT\n to compute the softmax on the student's logits. We call this loss the \"distillation loss\".\n\n\nHinton et al., 2015\n found that it is also beneficial to train the distilled model to produce the correct labels (based on the ground truth) in addition to the teacher's soft-labels. Hence, we also calculate the \"standard\" loss between the student's predicted class probabilities and the ground-truth labels (also called \"hard labels/targets\"). We dub this loss the \"student loss\". When calculating the class probabilities for the student loss we use \nT = 1\n. 
\n\n\nThe overall loss function, incorporating both distillation and student losses, is calculated as:\n\n\n\n\n\\mathcal{L}(x;W) = \\alpha * \\mathcal{H}(y, \\sigma(z_s; T=1)) + \\beta * \\mathcal{H}(\\sigma(z_t; T=\\tau), \\sigma(z_s, T=\\tau))\n\n\n\n\nwhere \nx\n is the input, \nW\n are the student model parameters, \ny\n is the ground truth label, \n\\mathcal{H}\n is the cross-entropy loss function, \n\\sigma\n is the softmax function parameterized by the temperature \nT\n, and \n\\alpha\n and \n\\beta\n are coefficients. \nz_s\n and \nz_t\n are the logits of the student and teacher respectively.\n\n\n\n\nNew Hyper-Parameters\n\n\nIn general \n\\tau\n, \n\\alpha\n and \n\\beta\n are hyper parameters.\n\n\nIn their experiments, \nHinton et al., 2015\n use temperature values ranging from 1 to 20. They note that empirically, when the student model is very small compared to the teacher model, lower temperatures work better. This makes sense if we consider that as we raise the temperature, the resulting soft-labels distribution becomes richer in information, and a very small model might not be able to capture all of this information. However, there's no clear way to predict up front what kind of capacity for information the student model will have.\n\n\nWith regards to \n\\alpha\n and \n\\beta\n, \nHinton et al., 2015\n use a weighted average between the distillation loss and the student loss. That is, \n\\beta = 1 - \\alpha\n. They note that in general, they obtained the best results when setting \n\\alpha\n to be much smaller than \n\\beta\n (although in one of their experiments they use \n\\alpha = \\beta = 0.5\n). Other works which utilize knowledge distillation don't use a weighted average. Some set \n\\alpha = 1\n while leaving \n\\beta\n tunable, while others don't set any constraints.\n\n\nCombining with Other Model Compression Techniques\n\n\nIn the \"basic\" scenario, the smaller (student) model is a pre-defined architecture which just has a smaller number of parameters compared to the teacher model. For example, we could train ResNet-18 by distilling knowledge from ResNet-34. But, a model with smaller capacity can also be obtained by other model compression techniques - sparsification and/or quantization. So, for example, we could train a 4-bit ResNet-18 model with some method using quantization-aware training, and use a distillation loss function as described above. In that case, the teacher model can even be a FP32 ResNet-18 model. Same goes for pruning and regularization.\n\n\nTann et al., 2017\n, \nMishra and Marr, 2018\n and \nPolino et al., 2018\n are some works that combine knowledge distillation with \nquantization\n. \nTheis et al., 2018\n and \nAshok et al., 2018\n combine distillation with \npruning\n.\n\n\nReferences\n\n\n\n\nCristian Bucila, Rich Caruana, and Alexandru Niculescu-Mizil\n. Model Compression. \nKDD, 2006\n\n\n\n\n\nGeoffrey Hinton, Oriol Vinyals and Jeff Dean\n. Distilling the Knowledge in a Neural Network. \narxiv:1503.02531\n\n\n\n\n\nHokchhay Tann, Soheil Hashemi, Iris Bahar and Sherief Reda\n. Hardware-Software Codesign of Accurate, Multiplier-free Deep Neural Networks. \nDAC, 2017\n\n\n\n\n\nAsit Mishra and Debbie Marr\n. Apprentice: Using Knowledge Distillation Techniques To Improve Low-Precision Network Accuracy. \nICLR, 2018\n\n\n\n\n\nAntonio Polino, Razvan Pascanu and Dan Alistarh\n. Model compression via distillation and quantization. \nICLR, 2018\n\n\n\n\n\nAnubhav Ashok, Nicholas Rhinehart, Fares Beainy and Kris M. Kitani\n. 
N2N learning: Network to Network Compression via Policy Gradient Reinforcement Learning. \nICLR, 2018\n\n\n\n\n\nLucas Theis, Iryna Korshunova, Alykhan Tejani and Ferenc Husz\u00e1r\n. Faster gaze prediction with dense networks and Fisher pruning. \narxiv:1801.05787", "title": "Knowledge Distillation" - }, + }, { - "location": "/knowledge_distillation/index.html#knowledge-distillation", - "text": "(For details on how to train a model with knowledge distillation in Distiller, see here ) Knowledge distillation is model compression method in which a small model is trained to mimic a pre-trained, larger model (or ensemble of models). This training setting is sometimes referred to as \"teacher-student\", where the large model is the teacher and the small model is the student (we'll be using these terms interchangeably). The method was first proposed by Bucila et al., 2006 and generalized by Hinton et al., 2015 . The implementation in Distiller is based on the latter publication. Here we'll provide a summary of the method. For more information the reader may refer to the paper (a video lecture with slides is also available). In distillation, knowledge is transferred from the teacher model to the student by minimizing a loss function in which the target is the distribution of class probabilities predicted by the teacher model. That is - the output of a softmax function on the teacher model's logits. However, in many cases, this probability distribution has the correct class at a very high probability, with all other class probabilities very close to 0. As such, it doesn't provide much information beyond the ground truth labels already provided in the dataset. To tackle this issue, Hinton et al., 2015 introduced the concept of \"softmax temperature\". The probability p_i of class i is calculated from the logits z as: p_i = \\frac{exp\\left(\\frac{z_i}{T}\\right)}{\\sum_{j} \\exp\\left(\\frac{z_j}{T}\\right)} where T is the temperature parameter. When T=1 we get the standard softmax function. As T grows, the probability distribution generated by the softmax function becomes softer, providing more information as to which classes the teacher found more similar to the predicted class. Hinton calls this the \"dark knowledge\" embedded in the teacher model, and it is this dark knowledge that we are transferring to the student model in the distillation process. When computing the loss function vs. the teacher's soft targets, we use the same value of T to compute the softmax on the student's logits. We call this loss the \"distillation loss\". Hinton et al., 2015 found that it is also beneficial to train the distilled model to produce the correct labels (based on the ground truth) in addition to the teacher's soft-labels. Hence, we also calculate the \"standard\" loss between the student's predicted class probabilities and the ground-truth labels (also called \"hard labels/targets\"). We dub this loss the \"student loss\". When calculating the class probabilities for the student loss we use T = 1 . The overall loss function, incorporating both distillation and student losses, is calculated as: \\mathcal{L}(x;W) = \\alpha * \\mathcal{H}(y, \\sigma(z_s; T=1)) + \\beta * \\mathcal{H}(\\sigma(z_t; T=\\tau), \\sigma(z_s, T=\\tau)) where x is the input, W are the student model parameters, y is the ground truth label, \\mathcal{H} is the cross-entropy loss function, \\sigma is the softmax function parameterized by the temperature T , and \\alpha and \\beta are coefficients. 
z_s and z_t are the logits of the student and teacher respectively.", + "location": "/knowledge_distillation/index.html#knowledge-distillation", + "text": "(For details on how to train a model with knowledge distillation in Distiller, see here ) Knowledge distillation is a model compression method in which a small model is trained to mimic a pre-trained, larger model (or ensemble of models). This training setting is sometimes referred to as \"teacher-student\", where the large model is the teacher and the small model is the student (we'll be using these terms interchangeably). The method was first proposed by Bucila et al., 2006 and generalized by Hinton et al., 2015 . The implementation in Distiller is based on the latter publication. Here we'll provide a summary of the method. For more information the reader may refer to the paper (a video lecture with slides is also available). In distillation, knowledge is transferred from the teacher model to the student by minimizing a loss function in which the target is the distribution of class probabilities predicted by the teacher model. That is - the output of a softmax function on the teacher model's logits. However, in many cases, this probability distribution has the correct class at a very high probability, with all other class probabilities very close to 0. As such, it doesn't provide much information beyond the ground truth labels already provided in the dataset. To tackle this issue, Hinton et al., 2015 introduced the concept of \"softmax temperature\". The probability p_i of class i is calculated from the logits z as: p_i = \\frac{exp\\left(\\frac{z_i}{T}\\right)}{\\sum_{j} \\exp\\left(\\frac{z_j}{T}\\right)} where T is the temperature parameter. When T=1 we get the standard softmax function. As T grows, the probability distribution generated by the softmax function becomes softer, providing more information as to which classes the teacher found more similar to the predicted class. Hinton calls this the \"dark knowledge\" embedded in the teacher model, and it is this dark knowledge that we are transferring to the student model in the distillation process. When computing the loss function vs. the teacher's soft targets, we use the same value of T to compute the softmax on the student's logits. We call this loss the \"distillation loss\". Hinton et al., 2015 found that it is also beneficial to train the distilled model to produce the correct labels (based on the ground truth) in addition to the teacher's soft-labels. Hence, we also calculate the \"standard\" loss between the student's predicted class probabilities and the ground-truth labels (also called \"hard labels/targets\"). We dub this loss the \"student loss\". When calculating the class probabilities for the student loss we use T = 1 . The overall loss function, incorporating both distillation and student losses, is calculated as: \\mathcal{L}(x;W) = \\alpha * \\mathcal{H}(y, \\sigma(z_s; T=1)) + \\beta * \\mathcal{H}(\\sigma(z_t; T=\\tau), \\sigma(z_s, T=\\tau)) where x is the input, W are the student model parameters, y is the ground truth label, \\mathcal{H} is the cross-entropy loss function, \\sigma is the softmax function parameterized by the temperature T , and \\alpha and \\beta are coefficients. z_s and z_t are the logits of the student and teacher respectively.", "title": "Knowledge Distillation" - }, + }, { - "location": "/knowledge_distillation/index.html#new-hyper-parameters", - "text": "In general \\tau , \\alpha and \\beta are hyper parameters. 
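Here is a minimal PyTorch sketch of the combined loss defined above. The temperature and the alpha/beta weights are arbitrary illustrative values, and the function is not Distiller's API:

```
import torch
import torch.nn.functional as F

def knowledge_distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.1, beta=0.9):
    """alpha * H(y, softmax(z_s, T=1)) + beta * H(softmax(z_t, T), softmax(z_s, T))."""
    student_loss = F.cross_entropy(student_logits, labels)        # hard-label ("student") loss, T = 1
    soft_targets = F.softmax(teacher_logits / T, dim=1)           # teacher's softened distribution
    soft_log_probs = F.log_softmax(student_logits / T, dim=1)     # student's softened log-probabilities
    distill_loss = -(soft_targets * soft_log_probs).sum(dim=1).mean()   # cross-entropy vs. soft targets
    # Hinton et al., 2015 additionally scale the soft term by T^2 to keep gradient magnitudes comparable.
    return alpha * student_loss + beta * distill_loss

# Toy usage with random logits and labels
z_s = torch.randn(32, 10, requires_grad=True)   # student logits
z_t = torch.randn(32, 10)                       # teacher logits (treated as constants)
y = torch.randint(0, 10, (32,))
loss = knowledge_distillation_loss(z_s, z_t, y)
loss.backward()
```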
In their experiments, Hinton et al., 2015 use temperature values ranging from 1 to 20. They note that empirically, when the student model is very small compared to the teacher model, lower temperatures work better. This makes sense if we consider that as we raise the temperature, the resulting soft-labels distribution becomes richer in information, and a very small model might not be able to capture all of this information. However, there's no clear way to predict up front what kind of capacity for information the student model will have. With regards to \\alpha and \\beta , Hinton et al., 2015 use a weighted average between the distillation loss and the student loss. That is, \\beta = 1 - \\alpha . They note that in general, they obtained the best results when setting \\alpha to be much smaller than \\beta (although in one of their experiments they use \\alpha = \\beta = 0.5 ). Other works which utilize knowledge distillation don't use a weighted average. Some set \\alpha = 1 while leaving \\beta tunable, while others don't set any constraints.", + "location": "/knowledge_distillation/index.html#new-hyper-parameters", + "text": "In general \\tau , \\alpha and \\beta are hyper parameters. In their experiments, Hinton et al., 2015 use temperature values ranging from 1 to 20. They note that empirically, when the student model is very small compared to the teacher model, lower temperatures work better. This makes sense if we consider that as we raise the temperature, the resulting soft-labels distribution becomes richer in information, and a very small model might not be able to capture all of this information. However, there's no clear way to predict up front what kind of capacity for information the student model will have. With regards to \\alpha and \\beta , Hinton et al., 2015 use a weighted average between the distillation loss and the student loss. That is, \\beta = 1 - \\alpha . They note that in general, they obtained the best results when setting \\alpha to be much smaller than \\beta (although in one of their experiments they use \\alpha = \\beta = 0.5 ). Other works which utilize knowledge distillation don't use a weighted average. Some set \\alpha = 1 while leaving \\beta tunable, while others don't set any constraints.", "title": "New Hyper-Parameters" - }, + }, { - "location": "/knowledge_distillation/index.html#references", - "text": "Cristian Bucila, Rich Caruana, and Alexandru Niculescu-Mizil . Model Compression. KDD, 2006 Geoffrey Hinton, Oriol Vinyals and Jeff Dean . Distilling the Knowledge in a Neural Network. arxiv:1503.02531 Hokchhay Tann, Soheil Hashemi, Iris Bahar and Sherief Reda . Hardware-Software Codesign of Accurate, Multiplier-free Deep Neural Networks. DAC, 2017 Asit Mishra and Debbie Marr . Apprentice: Using Knowledge Distillation Techniques To Improve Low-Precision Network Accuracy. ICLR, 2018 Antonio Polino, Razvan Pascanu and Dan Alistarh . Model compression via distillation and quantization. ICLR, 2018 Anubhav Ashok, Nicholas Rhinehart, Fares Beainy and Kris M. Kitani . N2N learning: Network to Network Compression via Policy Gradient Reinforcement Learning. ICLR, 2018 Lucas Theis, Iryna Korshunova, Alykhan Tejani and Ferenc Husz\u00e1r . Faster gaze prediction with dense networks and Fisher pruning. arxiv:1801.05787", + "location": "/knowledge_distillation/index.html#references", + "text": "Cristian Bucila, Rich Caruana, and Alexandru Niculescu-Mizil . Model Compression. KDD, 2006 Geoffrey Hinton, Oriol Vinyals and Jeff Dean . 
Distilling the Knowledge in a Neural Network. arxiv:1503.02531 Hokchhay Tann, Soheil Hashemi, Iris Bahar and Sherief Reda . Hardware-Software Codesign of Accurate, Multiplier-free Deep Neural Networks. DAC, 2017 Asit Mishra and Debbie Marr . Apprentice: Using Knowledge Distillation Techniques To Improve Low-Precision Network Accuracy. ICLR, 2018 Antonio Polino, Razvan Pascanu and Dan Alistarh . Model compression via distillation and quantization. ICLR, 2018 Anubhav Ashok, Nicholas Rhinehart, Fares Beainy and Kris M. Kitani . N2N learning: Network to Network Compression via Policy Gradient Reinforcement Learning. ICLR, 2018 Lucas Theis, Iryna Korshunova, Alykhan Tejani and Ferenc Husz\u00e1r . Faster gaze prediction with dense networks and Fisher pruning. arxiv:1801.05787", "title": "References" - }, + }, { - "location": "/conditional_computation/index.html", - "text": "Conditional Computation\n\n\nConditional Computation refers to a class of algorithms in which each input sample uses a different part of the model, such that on average the compute, latency or power (depending on our objective) is reduced.\nTo quote \nBengio et. al\n\n\n\n\n\"Conditional computation refers to activating only some of the units in a network, in an input-dependent fashion. For example, if we think we\u2019re looking at a car, we only need to compute the activations of the vehicle detecting units, not of all features that a network could possible compute. The immediate effect of activating fewer units is that propagating information through the network will be faster, both at training as well as at test time. However, one needs to be able to decide in an intelligent fashion which units to turn on and off, depending on the input data. This is typically achieved with some form of gating structure, learned in parallel with the original network.\"\n\n\n\n\nAs usual, there are several approaches to implement Conditional Computation:\n\n\n\n\nSun et. al\n use several expert CNN, each trained on a different task, and combine them to one large network.\n\n\nZheng et. al\n use cascading, an idea which may be familiar to you from Viola-Jones face detection.\n\n\nTheodorakopoulos et. al\n add small layers that learn which filters to use per input sample, and then enforce that during inference (LKAM module).\n\n\nIoannou et. al\n introduce Conditional Networks: that \"can be thought of as: i) decision trees augmented with data transformation\noperators, or ii) CNNs, with block-diagonal sparse weight matrices, and explicit data routing functions\"\n\n\nBolukbasi et. al\n \"learn a system to adaptively choose the components of a deep network to be evaluated for each example. By allowing examples correctly classified using early layers of the system to exit, we avoid the computational time associated with full evaluation of the network. We extend this to learn a network selection system that adaptively selects the network to be evaluated for each example.\"\n\n\n\n\nConditional Computation is especially useful for real-time, latency-sensitive applicative.\n\nIn Distiller we currently have implemented a variant of Early Exit.\n\n\nReferences\n\n\n \nEmmanuel Bengio, Pierre-Luc Bacon, Joelle Pineau, Doina Precup.\n\n \nConditional Deep Learning for Energy-Efficient and Enhanced Pattern Recognition\n, arXiv:1511.06297v2, 2016.\n\n\n\n\n\nY. Sun, X.Wang, and X. Tang.\n\n \nDeep Convolutional Network Cascade for Facial Point Detection\n. In Proc. IEEE Conf. Computer Vision and Pattern Recognition (CVPR), 2014\n\n\n\n\n\nX. 
Zheng, W.Ouyang, and X.Wang.\n \nMulti-Stage Contextual Deep Learning for Pedestrian Detection.\n In Proc. IEEE Intl Conf. on Computer Vision (ICCV), 2014.\n\n\n\n\n\nI. Theodorakopoulos, V. Pothos, D. Kastaniotis and N. Fragoulis1.\n \nParsimonious Inference on Convolutional Neural Networks: Learning and applying on-line kernel activation rules.\n Irida Labs S.A, January 2017\n\n\n\n\n\nTolga Bolukbasi, Joseph Wang, Ofer Dekel, Venkatesh Saligrama\n \nAdaptive Neural Networks for Efficient Inference\n. Proceedings of the 34th International Conference on Machine Learning, PMLR 70:527-536, 2017.\n\n\n\n\n\nYani Ioannou, Duncan Robertson, Darko Zikic, Peter Kontschieder, Jamie Shotton, Matthew Brown, Antonio Criminisi\n.\n \nDecision Forests, Convolutional Networks and the Models in-Between\n, arXiv:1511.06297v2, 2016.", + "location": "/conditional_computation/index.html", + "text": "Conditional Computation\n\n\nConditional Computation refers to a class of algorithms in which each input sample uses a different part of the model, such that on average the compute, latency or power (depending on our objective) is reduced.\nTo quote \nBengio et. al\n\n\n\n\n\"Conditional computation refers to activating only some of the units in a network, in an input-dependent fashion. For example, if we think we\u2019re looking at a car, we only need to compute the activations of the vehicle detecting units, not of all features that a network could possible compute. The immediate effect of activating fewer units is that propagating information through the network will be faster, both at training as well as at test time. However, one needs to be able to decide in an intelligent fashion which units to turn on and off, depending on the input data. This is typically achieved with some form of gating structure, learned in parallel with the original network.\"\n\n\n\n\nAs usual, there are several approaches to implement Conditional Computation:\n\n\n\n\nSun et. al\n use several expert CNN, each trained on a different task, and combine them to one large network.\n\n\nZheng et. al\n use cascading, an idea which may be familiar to you from Viola-Jones face detection.\n\n\nTheodorakopoulos et. al\n add small layers that learn which filters to use per input sample, and then enforce that during inference (LKAM module).\n\n\nIoannou et. al\n introduce Conditional Networks: that \"can be thought of as: i) decision trees augmented with data transformation\noperators, or ii) CNNs, with block-diagonal sparse weight matrices, and explicit data routing functions\"\n\n\nBolukbasi et. al\n \"learn a system to adaptively choose the components of a deep network to be evaluated for each example. By allowing examples correctly classified using early layers of the system to exit, we avoid the computational time associated with full evaluation of the network. We extend this to learn a network selection system that adaptively selects the network to be evaluated for each example.\"\n\n\n\n\nConditional Computation is especially useful for real-time, latency-sensitive applicative.\n\nIn Distiller we currently have implemented a variant of Early Exit.\n\n\nReferences\n\n\n \nEmmanuel Bengio, Pierre-Luc Bacon, Joelle Pineau, Doina Precup.\n\n \nConditional Deep Learning for Energy-Efficient and Enhanced Pattern Recognition\n, arXiv:1511.06297v2, 2016.\n\n\n\n\n\nY. Sun, X.Wang, and X. Tang.\n\n \nDeep Convolutional Network Cascade for Facial Point Detection\n. In Proc. IEEE Conf. 
Computer Vision and Pattern Recognition (CVPR), 2014\n\n\n\n\n\nX. Zheng, W.Ouyang, and X.Wang.\n \nMulti-Stage Contextual Deep Learning for Pedestrian Detection.\n In Proc. IEEE Intl Conf. on Computer Vision (ICCV), 2014.\n\n\n\n\n\nI. Theodorakopoulos, V. Pothos, D. Kastaniotis and N. Fragoulis1.\n \nParsimonious Inference on Convolutional Neural Networks: Learning and applying on-line kernel activation rules.\n Irida Labs S.A, January 2017\n\n\n\n\n\nTolga Bolukbasi, Joseph Wang, Ofer Dekel, Venkatesh Saligrama\n \nAdaptive Neural Networks for Efficient Inference\n. Proceedings of the 34th International Conference on Machine Learning, PMLR 70:527-536, 2017.\n\n\n\n\n\nYani Ioannou, Duncan Robertson, Darko Zikic, Peter Kontschieder, Jamie Shotton, Matthew Brown, Antonio Criminisi\n.\n \nDecision Forests, Convolutional Networks and the Models in-Between\n, arXiv:1511.06297v2, 2016.", "title": "Conditional Computation" - }, + }, { - "location": "/conditional_computation/index.html#conditional-computation", - "text": "Conditional Computation refers to a class of algorithms in which each input sample uses a different part of the model, such that on average the compute, latency or power (depending on our objective) is reduced.\nTo quote Bengio et. al \"Conditional computation refers to activating only some of the units in a network, in an input-dependent fashion. For example, if we think we\u2019re looking at a car, we only need to compute the activations of the vehicle detecting units, not of all features that a network could possible compute. The immediate effect of activating fewer units is that propagating information through the network will be faster, both at training as well as at test time. However, one needs to be able to decide in an intelligent fashion which units to turn on and off, depending on the input data. This is typically achieved with some form of gating structure, learned in parallel with the original network.\" As usual, there are several approaches to implement Conditional Computation: Sun et. al use several expert CNN, each trained on a different task, and combine them to one large network. Zheng et. al use cascading, an idea which may be familiar to you from Viola-Jones face detection. Theodorakopoulos et. al add small layers that learn which filters to use per input sample, and then enforce that during inference (LKAM module). Ioannou et. al introduce Conditional Networks: that \"can be thought of as: i) decision trees augmented with data transformation\noperators, or ii) CNNs, with block-diagonal sparse weight matrices, and explicit data routing functions\" Bolukbasi et. al \"learn a system to adaptively choose the components of a deep network to be evaluated for each example. By allowing examples correctly classified using early layers of the system to exit, we avoid the computational time associated with full evaluation of the network. We extend this to learn a network selection system that adaptively selects the network to be evaluated for each example.\" Conditional Computation is especially useful for real-time, latency-sensitive applicative. \nIn Distiller we currently have implemented a variant of Early Exit.", + "location": "/conditional_computation/index.html#conditional-computation", + "text": "Conditional Computation refers to a class of algorithms in which each input sample uses a different part of the model, such that on average the compute, latency or power (depending on our objective) is reduced.\nTo quote Bengio et. 
al \"Conditional computation refers to activating only some of the units in a network, in an input-dependent fashion. For example, if we think we\u2019re looking at a car, we only need to compute the activations of the vehicle detecting units, not of all features that a network could possible compute. The immediate effect of activating fewer units is that propagating information through the network will be faster, both at training as well as at test time. However, one needs to be able to decide in an intelligent fashion which units to turn on and off, depending on the input data. This is typically achieved with some form of gating structure, learned in parallel with the original network.\" As usual, there are several approaches to implement Conditional Computation: Sun et. al use several expert CNN, each trained on a different task, and combine them to one large network. Zheng et. al use cascading, an idea which may be familiar to you from Viola-Jones face detection. Theodorakopoulos et. al add small layers that learn which filters to use per input sample, and then enforce that during inference (LKAM module). Ioannou et. al introduce Conditional Networks: that \"can be thought of as: i) decision trees augmented with data transformation\noperators, or ii) CNNs, with block-diagonal sparse weight matrices, and explicit data routing functions\" Bolukbasi et. al \"learn a system to adaptively choose the components of a deep network to be evaluated for each example. By allowing examples correctly classified using early layers of the system to exit, we avoid the computational time associated with full evaluation of the network. We extend this to learn a network selection system that adaptively selects the network to be evaluated for each example.\" Conditional Computation is especially useful for real-time, latency-sensitive applicative. \nIn Distiller we currently have implemented a variant of Early Exit.", "title": "Conditional Computation" - }, + }, { - "location": "/conditional_computation/index.html#references", - "text": "Emmanuel Bengio, Pierre-Luc Bacon, Joelle Pineau, Doina Precup. \n Conditional Deep Learning for Energy-Efficient and Enhanced Pattern Recognition , arXiv:1511.06297v2, 2016. Y. Sun, X.Wang, and X. Tang. \n Deep Convolutional Network Cascade for Facial Point Detection . In Proc. IEEE Conf. Computer Vision and Pattern Recognition (CVPR), 2014 X. Zheng, W.Ouyang, and X.Wang. Multi-Stage Contextual Deep Learning for Pedestrian Detection. In Proc. IEEE Intl Conf. on Computer Vision (ICCV), 2014. I. Theodorakopoulos, V. Pothos, D. Kastaniotis and N. Fragoulis1. Parsimonious Inference on Convolutional Neural Networks: Learning and applying on-line kernel activation rules. Irida Labs S.A, January 2017 Tolga Bolukbasi, Joseph Wang, Ofer Dekel, Venkatesh Saligrama Adaptive Neural Networks for Efficient Inference . Proceedings of the 34th International Conference on Machine Learning, PMLR 70:527-536, 2017. Yani Ioannou, Duncan Robertson, Darko Zikic, Peter Kontschieder, Jamie Shotton, Matthew Brown, Antonio Criminisi .\n Decision Forests, Convolutional Networks and the Models in-Between , arXiv:1511.06297v2, 2016.", + "location": "/conditional_computation/index.html#references", + "text": "Emmanuel Bengio, Pierre-Luc Bacon, Joelle Pineau, Doina Precup. \n Conditional Deep Learning for Energy-Efficient and Enhanced Pattern Recognition , arXiv:1511.06297v2, 2016. Y. Sun, X.Wang, and X. Tang. \n Deep Convolutional Network Cascade for Facial Point Detection . In Proc. IEEE Conf. 
Computer Vision and Pattern Recognition (CVPR), 2014 X. Zheng, W.Ouyang, and X.Wang. Multi-Stage Contextual Deep Learning for Pedestrian Detection. In Proc. IEEE Intl Conf. on Computer Vision (ICCV), 2014. I. Theodorakopoulos, V. Pothos, D. Kastaniotis and N. Fragoulis1. Parsimonious Inference on Convolutional Neural Networks: Learning and applying on-line kernel activation rules. Irida Labs S.A, January 2017 Tolga Bolukbasi, Joseph Wang, Ofer Dekel, Venkatesh Saligrama Adaptive Neural Networks for Efficient Inference . Proceedings of the 34th International Conference on Machine Learning, PMLR 70:527-536, 2017. Yani Ioannou, Duncan Robertson, Darko Zikic, Peter Kontschieder, Jamie Shotton, Matthew Brown, Antonio Criminisi .\n Decision Forests, Convolutional Networks and the Models in-Between , arXiv:1511.06297v2, 2016.", "title": "References" - }, + }, { - "location": "/algo_pruning/index.html", - "text": "Weights pruning algorithms\n\n\n\n\nMagnitude pruner\n\n\nThis is the most basic pruner: it applies a thresholding function, \\(thresh(.)\\), on each element, \\(w_i\\), of a weights tensor. A different threshold can be used for each layer's weights tensor.\n\nBecause the threshold is applied on individual elements, this pruner belongs to the element-wise pruning algorithm family.\n\n\n\\[ thresh(w_i)=\\left\\lbrace\n\\matrix{{{w_i: \\; if \\;|w_i| \\; \\gt}\\;\\lambda}\\cr {0: \\; if \\; |w_i| \\leq \\lambda} }\n\\right\\rbrace \\]\n\n\nSensitivity pruner\n\n\nFinding a threshold magnitude per layer is daunting, especially since each layer's elements have different average absolute values. We can take advantage of the fact that the weights of convolutional and fully connected layers exhibit a Gaussian distribution with a mean value roughly zero, to avoid using a direct threshold based on the values of each specific tensor.\n\n\nThe diagram below shows the distribution the weights tensor of the first convolutional layer, and first fully-connected layer in TorchVision's pre-trained Alexnet model. You can see that they have an approximate Gaussian distribution.\n\n\n \n\n\nThe distributions of Alexnet conv1 and fc1 layers\n\n\nWe use the standard deviation of the weights tensor as a sort of normalizing factor between the different weights tensors. For example, if a tensor is Normally distributed, then about 68% of the elements have an absolute value less than the standard deviation (\\(\\sigma\\)) of the tensor. Thus, if we set the threshold to \\(s*\\sigma\\), then basically we are thresholding \\(s * 68\\%\\) of the tensor elements. \n\n\n\\[ thresh(w_i)=\\left\\lbrace\n\\matrix{{{w_i: \\; if \\;|w_i| \\; \\gt}\\;\\lambda}\\cr {0: \\; if \\; |w_i| \\leq \\lambda} }\n\\right\\rbrace \\]\n\n\n\\[\n\\lambda = s * \\sigma_l \\;\\;\\; where\\; \\sigma_l\\; is \\;the \\;std \\;of \\;layer \\;l \\;as \\;measured \\;on \\;the \\;dense \\;model\n\\]\n\n\nHow do we choose this \\(s\\) multiplier?\n\n\nIn \nLearning both Weights and Connections for Efficient Neural Networks\n the authors write:\n\n\n\n\n\"We used the sensitivity results to find each layer\u2019s threshold: for example, the smallest threshold was applied to the most sensitive layer, which is the first convolutional layer... The pruning threshold is chosen as a quality parameter multiplied by the standard deviation of a layer\u2019s weights\n\n\n\n\nSo the results of executing pruning sensitivity analysis on the tensor, gives us a good starting guess at \\(s\\). 
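In code, the thresholding rule above amounts to a couple of lines. The following PyTorch sketch is illustrative only (it is not the SensitivityPruner implementation), and the layer and the value of s in the usage comment are hypothetical:

```python
import torch

def sensitivity_prune_mask(weights: torch.Tensor, s: float) -> torch.Tensor:
    # lambda = s * sigma, with sigma measured on the dense weights tensor.
    threshold = s * weights.std()
    # Keep elements whose magnitude exceeds the threshold; zero out the rest.
    return (weights.abs() > threshold).float()

# Hypothetical usage on some layer, with sensitivity s = 0.25:
# mask = sensitivity_prune_mask(layer.weight.data, s=0.25)
# layer.weight.data.mul_(mask)
```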
Sensitivity analysis is an empirical method, and we still have to spend time to hone in on the exact multiplier value.\n\n\nMethod of operation\n\n\n\n\nStart by running a pruning sensitivity analysis on the model. \n\n\nThen use the results to set and tune the threshold of each layer, but instead of using a direct threshold use a sensitivity parameter which is multiplied by the standard-deviation of the initial weight-tensor's distribution.\n\n\n\n\nSchedule\n\n\nIn their \npaper\n Song Han et al. use iterative pruning and change the value of the \\(s\\) multiplier at each pruning step. Distiller's \nSensitivityPruner\n works differently: the value \\(s\\) is set once based on a one-time calculation of the standard-deviation of the tensor (the first time we prune), and relies on the fact that as the tensor is pruned, more elements are \"pulled\" toward the center of the distribution and thus more elements gets pruned.\n\n\nThis actually works quite well as we can see in the diagram below. This is a TensorBoard screen-capture from Alexnet training, which shows how this method starts off pruning very aggressively, but then slowly reduces the pruning rate.\n\n\n\nWe use a simple iterative-pruning schedule such as: \nPrune every second epoch starting at epoch 0, and ending at epoch 38.\n This excerpt from \nalexnet.schedule_sensitivity.yaml\n shows how this iterative schedule is conveyed in Distiller scheduling configuration YAML:\n\n\npruners:\n my_pruner:\n class: 'SensitivityPruner'\n sensitivities:\n 'features.module.0.weight': 0.25\n 'features.module.3.weight': 0.35\n 'features.module.6.weight': 0.40\n 'features.module.8.weight': 0.45\n 'features.module.10.weight': 0.55\n 'classifier.1.weight': 0.875\n 'classifier.4.weight': 0.875\n 'classifier.6.weight': 0.625\n\npolicies:\n - pruner:\n instance_name : 'my_pruner'\n starting_epoch: 0\n ending_epoch: 38\n frequency: 2\n\n\n\n\nLevel pruner\n\n\nClass \nSparsityLevelParameterPruner\n uses a similar method to go around specifying specific thresholding magnitudes.\nInstead of specifying a threshold magnitude, you specify a target sparsity level (expressed as a fraction, so 0.5 means 50% sparsity). Essentially this pruner also uses a pruning criteria based on the magnitude of each tensor element, but it has the advantage that you can aim for an exact and specific sparsity level.\n\nThis pruner is much more stable compared to \nSensitivityPruner\n because the target sparsity level is not coupled to the actual magnitudes of the elements. Distiller's \nSensitivityPruner\n is unstable because the final sparsity level depends on the convergence pattern of the tensor distribution. Song Han's methodology of using several different values for the multiplier \\(s\\), and the recalculation of the standard-deviation at each pruning phase, probably gives it stability, but requires much more hyper-parameters (this is the reason we have not implemented it thus far). \n\n\nTo set the target sparsity levels, you can once again use pruning sensitivity analysis to make better guesses at the correct sparsity level of each\n\n\nMethod of operation\n\n\n\n\nSort the weights in the specified layer by their absolute values. 
\n\n\nMask to zero the smallest magnitude weights until the desired sparsity level is reached.\n\n\n\n\nAutomated gradual pruner (AGP)\n\n\nIn \nTo prune, or not to prune: exploring the efficacy of pruning for model compression\n, authors Michael Zhu and Suyog Gupta provide an algorithm to schedule a Level Pruner which Distiller implements in \nAutomatedGradualPruner\n.\n\n\n\n\n\n\"We introduce a new automated gradual pruning algorithm in which the sparsity is increased from an initial sparsity value \\(s_i\\) (usually 0) to a \ufb01nal sparsity value \\(s_f\\) over a span of n pruning steps.\nThe intuition behind this sparsity function in equation (1) is to prune the network rapidly in the initial phase when the redundant connections are\nabundant and gradually reduce the number of weights being pruned each time as there are fewer and fewer weights remaining in the network.\"\"\n\n\n\n\n\n\nYou can play with the scheduling parameters in the \nagp_schedule.ipynb notebook\n.\n\n\nThe authors describe AGP:\n\n\n\n\n\n\nOur automated gradual pruning algorithm prunes the smallest magnitude weights to achieve a preset level of network sparsity.\n\n\nDoesn't require much hyper-parameter tuning\n\n\nShown to perform well across different models\n\n\nDoes not make any assumptions about the structure of the network or its constituent layers, and is therefore more generally applicable.\n\n\n\n\n\n\nRNN pruner\n\n\nThe authors of \nExploring Sparsity in Recurrent Neural Networks\n, Sharan Narang, Erich Elsen, Gregory Diamos, and Shubho Sengupta, \"propose a technique to reduce the parameters of a network by pruning weights during the initial training of the network.\" They use a gradual pruning schedule which is reminiscent of the schedule used in AGP, for element-wise pruning of RNNs, which they also employ during training. They show pruning of RNN, GRU, LSTM and embedding layers.\n\n\nDistiller's distiller.pruning.BaiduRNNPruner class implements this pruning algorithm.\n\n\n\n\nStructure pruners\n\n\nElement-wise pruning can create very sparse models which can be compressed to consume less memory footprint and bandwidth, but without specialized hardware that can compute using the sparse representation of the tensors, we don't gain any speedup of the computation. Structure pruners, remove entire \"structures\", such as kernels, filters, and even entire feature-maps.\n\n\nRanked structure pruner\n\n\nThe \nL1RankedStructureParameterPruner\n pruner calculates the magnitude of some \"structure\", orders all of the structures based on some magnitude function and the \nm\n lowest ranking structures are pruned away. Currently this pruner only performs ranking of filters (3D structures) and it uses the mean of the absolute value of the tensor as the representative of the filter magnitude. The absolute mean does not depend on the size of the filter, so it is easier to use compared to just using the \\(L_1\\)-norm of the structure, and at the same time it is a good proxy of the \\(L_1\\)-norm.\n\n\nIn \nPruning Filters for Efficient ConvNets\n the authors use filter ranking, with \none-shot pruning\n followed by fine-tuning. The authors of \nExploiting Sparseness in Deep Neural Networks for Large Vocabulary Speech Recognition\n also use a one-shot pruning schedule, for fully-connected layers, and they provide an explanation:\n\n\n\n\nFirst, after sweeping through the full training set several times the weights become relatively stable \u2014 they tend to remain either large or small magnitudes. 
Second, in a stabilized model, the importance of the connection is approximated well by the magnitudes of the weights (times the magnitudes of the corresponding input values, but these are relatively uniform within each layer since on the input layer, features are normalized to zero-mean and unit-variance, and hidden-layer values are probabilities)\n\n\n\n\nActivation-influenced pruner\n\n\nThe motivation for this pruner, is that if a feature-map produces very small activations, then this feature-map is not very important, and can be pruned away.\n- \nStatus: not implemented", + "location": "/algo_pruning/index.html", + "text": "Weights pruning algorithms\n\n\n\n\nMagnitude pruner\n\n\nThis is the most basic pruner: it applies a thresholding function, \\(thresh(.)\\), on each element, \\(w_i\\), of a weights tensor. A different threshold can be used for each layer's weights tensor.\n\nBecause the threshold is applied on individual elements, this pruner belongs to the element-wise pruning algorithm family.\n\n\n\\[ thresh(w_i)=\\left\\lbrace\n\\matrix{{{w_i: \\; if \\;|w_i| \\; \\gt}\\;\\lambda}\\cr {0: \\; if \\; |w_i| \\leq \\lambda} }\n\\right\\rbrace \\]\n\n\nSensitivity pruner\n\n\nFinding a threshold magnitude per layer is daunting, especially since each layer's elements have different average absolute values. We can take advantage of the fact that the weights of convolutional and fully connected layers exhibit a Gaussian distribution with a mean value roughly zero, to avoid using a direct threshold based on the values of each specific tensor.\n\n\nThe diagram below shows the distribution the weights tensor of the first convolutional layer, and first fully-connected layer in TorchVision's pre-trained Alexnet model. You can see that they have an approximate Gaussian distribution.\n\n\n \n\n\nThe distributions of Alexnet conv1 and fc1 layers\n\n\nWe use the standard deviation of the weights tensor as a sort of normalizing factor between the different weights tensors. For example, if a tensor is Normally distributed, then about 68% of the elements have an absolute value less than the standard deviation (\\(\\sigma\\)) of the tensor. Thus, if we set the threshold to \\(s*\\sigma\\), then basically we are thresholding \\(s * 68\\%\\) of the tensor elements. \n\n\n\\[ thresh(w_i)=\\left\\lbrace\n\\matrix{{{w_i: \\; if \\;|w_i| \\; \\gt}\\;\\lambda}\\cr {0: \\; if \\; |w_i| \\leq \\lambda} }\n\\right\\rbrace \\]\n\n\n\\[\n\\lambda = s * \\sigma_l \\;\\;\\; where\\; \\sigma_l\\; is \\;the \\;std \\;of \\;layer \\;l \\;as \\;measured \\;on \\;the \\;dense \\;model\n\\]\n\n\nHow do we choose this \\(s\\) multiplier?\n\n\nIn \nLearning both Weights and Connections for Efficient Neural Networks\n the authors write:\n\n\n\n\n\"We used the sensitivity results to find each layer\u2019s threshold: for example, the smallest threshold was applied to the most sensitive layer, which is the first convolutional layer... The pruning threshold is chosen as a quality parameter multiplied by the standard deviation of a layer\u2019s weights\n\n\n\n\nSo the results of executing pruning sensitivity analysis on the tensor, gives us a good starting guess at \\(s\\). Sensitivity analysis is an empirical method, and we still have to spend time to hone in on the exact multiplier value.\n\n\nMethod of operation\n\n\n\n\nStart by running a pruning sensitivity analysis on the model. 
\n\n\nThen use the results to set and tune the threshold of each layer, but instead of using a direct threshold use a sensitivity parameter which is multiplied by the standard-deviation of the initial weight-tensor's distribution.\n\n\n\n\nSchedule\n\n\nIn their \npaper\n Song Han et al. use iterative pruning and change the value of the \\(s\\) multiplier at each pruning step. Distiller's \nSensitivityPruner\n works differently: the value \\(s\\) is set once based on a one-time calculation of the standard-deviation of the tensor (the first time we prune), and relies on the fact that as the tensor is pruned, more elements are \"pulled\" toward the center of the distribution and thus more elements gets pruned.\n\n\nThis actually works quite well as we can see in the diagram below. This is a TensorBoard screen-capture from Alexnet training, which shows how this method starts off pruning very aggressively, but then slowly reduces the pruning rate.\n\n\n\nWe use a simple iterative-pruning schedule such as: \nPrune every second epoch starting at epoch 0, and ending at epoch 38.\n This excerpt from \nalexnet.schedule_sensitivity.yaml\n shows how this iterative schedule is conveyed in Distiller scheduling configuration YAML:\n\n\npruners:\n my_pruner:\n class: 'SensitivityPruner'\n sensitivities:\n 'features.module.0.weight': 0.25\n 'features.module.3.weight': 0.35\n 'features.module.6.weight': 0.40\n 'features.module.8.weight': 0.45\n 'features.module.10.weight': 0.55\n 'classifier.1.weight': 0.875\n 'classifier.4.weight': 0.875\n 'classifier.6.weight': 0.625\n\npolicies:\n - pruner:\n instance_name : 'my_pruner'\n starting_epoch: 0\n ending_epoch: 38\n frequency: 2\n\n\n\n\nLevel pruner\n\n\nClass \nSparsityLevelParameterPruner\n uses a similar method to go around specifying specific thresholding magnitudes.\nInstead of specifying a threshold magnitude, you specify a target sparsity level (expressed as a fraction, so 0.5 means 50% sparsity). Essentially this pruner also uses a pruning criteria based on the magnitude of each tensor element, but it has the advantage that you can aim for an exact and specific sparsity level.\n\nThis pruner is much more stable compared to \nSensitivityPruner\n because the target sparsity level is not coupled to the actual magnitudes of the elements. Distiller's \nSensitivityPruner\n is unstable because the final sparsity level depends on the convergence pattern of the tensor distribution. Song Han's methodology of using several different values for the multiplier \\(s\\), and the recalculation of the standard-deviation at each pruning phase, probably gives it stability, but requires much more hyper-parameters (this is the reason we have not implemented it thus far). \n\n\nTo set the target sparsity levels, you can once again use pruning sensitivity analysis to make better guesses at the correct sparsity level of each\n\n\nMethod of operation\n\n\n\n\nSort the weights in the specified layer by their absolute values. 
\n\n\nMask to zero the smallest magnitude weights until the desired sparsity level is reached.\n\n\n\n\nAutomated gradual pruner (AGP)\n\n\nIn \nTo prune, or not to prune: exploring the efficacy of pruning for model compression\n, authors Michael Zhu and Suyog Gupta provide an algorithm to schedule a Level Pruner which Distiller implements in \nAutomatedGradualPruner\n.\n\n\n\n\n\n\"We introduce a new automated gradual pruning algorithm in which the sparsity is increased from an initial sparsity value \\(s_i\\) (usually 0) to a \ufb01nal sparsity value \\(s_f\\) over a span of n pruning steps.\nThe intuition behind this sparsity function in equation (1) is to prune the network rapidly in the initial phase when the redundant connections are\nabundant and gradually reduce the number of weights being pruned each time as there are fewer and fewer weights remaining in the network.\"\"\n\n\n\n\n\n\nYou can play with the scheduling parameters in the \nagp_schedule.ipynb notebook\n.\n\n\nThe authors describe AGP:\n\n\n\n\n\n\nOur automated gradual pruning algorithm prunes the smallest magnitude weights to achieve a preset level of network sparsity.\n\n\nDoesn't require much hyper-parameter tuning\n\n\nShown to perform well across different models\n\n\nDoes not make any assumptions about the structure of the network or its constituent layers, and is therefore more generally applicable.\n\n\n\n\n\n\nRNN pruner\n\n\nThe authors of \nExploring Sparsity in Recurrent Neural Networks\n, Sharan Narang, Erich Elsen, Gregory Diamos, and Shubho Sengupta, \"propose a technique to reduce the parameters of a network by pruning weights during the initial training of the network.\" They use a gradual pruning schedule which is reminiscent of the schedule used in AGP, for element-wise pruning of RNNs, which they also employ during training. They show pruning of RNN, GRU, LSTM and embedding layers.\n\n\nDistiller's distiller.pruning.BaiduRNNPruner class implements this pruning algorithm.\n\n\n\n\nStructure pruners\n\n\nElement-wise pruning can create very sparse models which can be compressed to consume less memory footprint and bandwidth, but without specialized hardware that can compute using the sparse representation of the tensors, we don't gain any speedup of the computation. Structure pruners, remove entire \"structures\", such as kernels, filters, and even entire feature-maps.\n\n\nRanked structure pruner\n\n\nThe \nL1RankedStructureParameterPruner\n pruner calculates the magnitude of some \"structure\", orders all of the structures based on some magnitude function and the \nm\n lowest ranking structures are pruned away. Currently this pruner only performs ranking of filters (3D structures) and it uses the mean of the absolute value of the tensor as the representative of the filter magnitude. The absolute mean does not depend on the size of the filter, so it is easier to use compared to just using the \\(L_1\\)-norm of the structure, and at the same time it is a good proxy of the \\(L_1\\)-norm.\n\n\nIn \nPruning Filters for Efficient ConvNets\n the authors use filter ranking, with \none-shot pruning\n followed by fine-tuning. The authors of \nExploiting Sparseness in Deep Neural Networks for Large Vocabulary Speech Recognition\n also use a one-shot pruning schedule, for fully-connected layers, and they provide an explanation:\n\n\n\n\nFirst, after sweeping through the full training set several times the weights become relatively stable \u2014 they tend to remain either large or small magnitudes. 
Second, in a stabilized model, the importance of the connection is approximated well by the magnitudes of the weights (times the magnitudes of the corresponding input values, but these are relatively uniform within each layer since on the input layer, features are normalized to zero-mean and unit-variance, and hidden-layer values are probabilities)\n\n\n\n\nActivation-influenced pruner\n\n\nThe motivation for this pruner, is that if a feature-map produces very small activations, then this feature-map is not very important, and can be pruned away.\n- \nStatus: not implemented", "title": "Pruning" - }, + }, { - "location": "/algo_pruning/index.html#weights-pruning-algorithms", - "text": "", + "location": "/algo_pruning/index.html#weights-pruning-algorithms", + "text": "", "title": "Weights pruning algorithms" - }, + }, { - "location": "/algo_pruning/index.html#magnitude-pruner", - "text": "This is the most basic pruner: it applies a thresholding function, \\(thresh(.)\\), on each element, \\(w_i\\), of a weights tensor. A different threshold can be used for each layer's weights tensor. \nBecause the threshold is applied on individual elements, this pruner belongs to the element-wise pruning algorithm family. \\[ thresh(w_i)=\\left\\lbrace\n\\matrix{{{w_i: \\; if \\;|w_i| \\; \\gt}\\;\\lambda}\\cr {0: \\; if \\; |w_i| \\leq \\lambda} }\n\\right\\rbrace \\]", + "location": "/algo_pruning/index.html#magnitude-pruner", + "text": "This is the most basic pruner: it applies a thresholding function, \\(thresh(.)\\), on each element, \\(w_i\\), of a weights tensor. A different threshold can be used for each layer's weights tensor. \nBecause the threshold is applied on individual elements, this pruner belongs to the element-wise pruning algorithm family. \\[ thresh(w_i)=\\left\\lbrace\n\\matrix{{{w_i: \\; if \\;|w_i| \\; \\gt}\\;\\lambda}\\cr {0: \\; if \\; |w_i| \\leq \\lambda} }\n\\right\\rbrace \\]", "title": "Magnitude pruner" - }, + }, { - "location": "/algo_pruning/index.html#sensitivity-pruner", - "text": "Finding a threshold magnitude per layer is daunting, especially since each layer's elements have different average absolute values. We can take advantage of the fact that the weights of convolutional and fully connected layers exhibit a Gaussian distribution with a mean value roughly zero, to avoid using a direct threshold based on the values of each specific tensor. \nThe diagram below shows the distribution the weights tensor of the first convolutional layer, and first fully-connected layer in TorchVision's pre-trained Alexnet model. You can see that they have an approximate Gaussian distribution. The distributions of Alexnet conv1 and fc1 layers We use the standard deviation of the weights tensor as a sort of normalizing factor between the different weights tensors. For example, if a tensor is Normally distributed, then about 68% of the elements have an absolute value less than the standard deviation (\\(\\sigma\\)) of the tensor. Thus, if we set the threshold to \\(s*\\sigma\\), then basically we are thresholding \\(s * 68\\%\\) of the tensor elements. \\[ thresh(w_i)=\\left\\lbrace\n\\matrix{{{w_i: \\; if \\;|w_i| \\; \\gt}\\;\\lambda}\\cr {0: \\; if \\; |w_i| \\leq \\lambda} }\n\\right\\rbrace \\] \\[\n\\lambda = s * \\sigma_l \\;\\;\\; where\\; \\sigma_l\\; is \\;the \\;std \\;of \\;layer \\;l \\;as \\;measured \\;on \\;the \\;dense \\;model\n\\] How do we choose this \\(s\\) multiplier? 
In Learning both Weights and Connections for Efficient Neural Networks the authors write: \"We used the sensitivity results to find each layer\u2019s threshold: for example, the smallest threshold was applied to the most sensitive layer, which is the first convolutional layer... The pruning threshold is chosen as a quality parameter multiplied by the standard deviation of a layer\u2019s weights So the results of executing pruning sensitivity analysis on the tensor, gives us a good starting guess at \\(s\\). Sensitivity analysis is an empirical method, and we still have to spend time to hone in on the exact multiplier value.", + "location": "/algo_pruning/index.html#sensitivity-pruner", + "text": "Finding a threshold magnitude per layer is daunting, especially since each layer's elements have different average absolute values. We can take advantage of the fact that the weights of convolutional and fully connected layers exhibit a Gaussian distribution with a mean value roughly zero, to avoid using a direct threshold based on the values of each specific tensor. \nThe diagram below shows the distribution the weights tensor of the first convolutional layer, and first fully-connected layer in TorchVision's pre-trained Alexnet model. You can see that they have an approximate Gaussian distribution. The distributions of Alexnet conv1 and fc1 layers We use the standard deviation of the weights tensor as a sort of normalizing factor between the different weights tensors. For example, if a tensor is Normally distributed, then about 68% of the elements have an absolute value less than the standard deviation (\\(\\sigma\\)) of the tensor. Thus, if we set the threshold to \\(s*\\sigma\\), then basically we are thresholding \\(s * 68\\%\\) of the tensor elements. \\[ thresh(w_i)=\\left\\lbrace\n\\matrix{{{w_i: \\; if \\;|w_i| \\; \\gt}\\;\\lambda}\\cr {0: \\; if \\; |w_i| \\leq \\lambda} }\n\\right\\rbrace \\] \\[\n\\lambda = s * \\sigma_l \\;\\;\\; where\\; \\sigma_l\\; is \\;the \\;std \\;of \\;layer \\;l \\;as \\;measured \\;on \\;the \\;dense \\;model\n\\] How do we choose this \\(s\\) multiplier? In Learning both Weights and Connections for Efficient Neural Networks the authors write: \"We used the sensitivity results to find each layer\u2019s threshold: for example, the smallest threshold was applied to the most sensitive layer, which is the first convolutional layer... The pruning threshold is chosen as a quality parameter multiplied by the standard deviation of a layer\u2019s weights So the results of executing pruning sensitivity analysis on the tensor, gives us a good starting guess at \\(s\\). Sensitivity analysis is an empirical method, and we still have to spend time to hone in on the exact multiplier value.", "title": "Sensitivity pruner" - }, + }, { - "location": "/algo_pruning/index.html#method-of-operation", - "text": "Start by running a pruning sensitivity analysis on the model. Then use the results to set and tune the threshold of each layer, but instead of using a direct threshold use a sensitivity parameter which is multiplied by the standard-deviation of the initial weight-tensor's distribution.", + "location": "/algo_pruning/index.html#method-of-operation", + "text": "Start by running a pruning sensitivity analysis on the model. 
Then use the results to set and tune the threshold of each layer, but instead of using a direct threshold use a sensitivity parameter which is multiplied by the standard-deviation of the initial weight-tensor's distribution.", "title": "Method of operation" - }, + }, { - "location": "/algo_pruning/index.html#schedule", - "text": "In their paper Song Han et al. use iterative pruning and change the value of the \\(s\\) multiplier at each pruning step. Distiller's SensitivityPruner works differently: the value \\(s\\) is set once based on a one-time calculation of the standard-deviation of the tensor (the first time we prune), and relies on the fact that as the tensor is pruned, more elements are \"pulled\" toward the center of the distribution and thus more elements gets pruned. This actually works quite well as we can see in the diagram below. This is a TensorBoard screen-capture from Alexnet training, which shows how this method starts off pruning very aggressively, but then slowly reduces the pruning rate. We use a simple iterative-pruning schedule such as: Prune every second epoch starting at epoch 0, and ending at epoch 38. This excerpt from alexnet.schedule_sensitivity.yaml shows how this iterative schedule is conveyed in Distiller scheduling configuration YAML: pruners:\n my_pruner:\n class: 'SensitivityPruner'\n sensitivities:\n 'features.module.0.weight': 0.25\n 'features.module.3.weight': 0.35\n 'features.module.6.weight': 0.40\n 'features.module.8.weight': 0.45\n 'features.module.10.weight': 0.55\n 'classifier.1.weight': 0.875\n 'classifier.4.weight': 0.875\n 'classifier.6.weight': 0.625\n\npolicies:\n - pruner:\n instance_name : 'my_pruner'\n starting_epoch: 0\n ending_epoch: 38\n frequency: 2", + "location": "/algo_pruning/index.html#schedule", + "text": "In their paper Song Han et al. use iterative pruning and change the value of the \\(s\\) multiplier at each pruning step. Distiller's SensitivityPruner works differently: the value \\(s\\) is set once based on a one-time calculation of the standard-deviation of the tensor (the first time we prune), and relies on the fact that as the tensor is pruned, more elements are \"pulled\" toward the center of the distribution and thus more elements gets pruned. This actually works quite well as we can see in the diagram below. This is a TensorBoard screen-capture from Alexnet training, which shows how this method starts off pruning very aggressively, but then slowly reduces the pruning rate. We use a simple iterative-pruning schedule such as: Prune every second epoch starting at epoch 0, and ending at epoch 38. This excerpt from alexnet.schedule_sensitivity.yaml shows how this iterative schedule is conveyed in Distiller scheduling configuration YAML: pruners:\n my_pruner:\n class: 'SensitivityPruner'\n sensitivities:\n 'features.module.0.weight': 0.25\n 'features.module.3.weight': 0.35\n 'features.module.6.weight': 0.40\n 'features.module.8.weight': 0.45\n 'features.module.10.weight': 0.55\n 'classifier.1.weight': 0.875\n 'classifier.4.weight': 0.875\n 'classifier.6.weight': 0.625\n\npolicies:\n - pruner:\n instance_name : 'my_pruner'\n starting_epoch: 0\n ending_epoch: 38\n frequency: 2", "title": "Schedule" - }, + }, { - "location": "/algo_pruning/index.html#level-pruner", - "text": "Class SparsityLevelParameterPruner uses a similar method to go around specifying specific thresholding magnitudes.\nInstead of specifying a threshold magnitude, you specify a target sparsity level (expressed as a fraction, so 0.5 means 50% sparsity). 
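The "sort and mask" method of operation listed for this pruner translates into a short routine. This is a sketch of the idea only, not the SparsityLevelParameterPruner code:

```python
import torch

def level_prune_mask(weights: torch.Tensor, sparsity: float) -> torch.Tensor:
    # Step 1: rank the weights by absolute value; the k-th smallest magnitude
    # (k = desired number of zeros) serves as the pruning threshold.
    # Step 2: mask to zero everything at or below that threshold.
    num_to_prune = int(sparsity * weights.numel())
    if num_to_prune == 0:
        return torch.ones_like(weights)
    threshold = weights.abs().view(-1).kthvalue(num_to_prune).values
    return (weights.abs() > threshold).float()

# e.g., a 50% sparsity mask for a hypothetical fully-connected layer:
# mask = level_prune_mask(fc.weight.data, sparsity=0.5)
```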
Essentially this pruner also uses a pruning criteria based on the magnitude of each tensor element, but it has the advantage that you can aim for an exact and specific sparsity level. \nThis pruner is much more stable compared to SensitivityPruner because the target sparsity level is not coupled to the actual magnitudes of the elements. Distiller's SensitivityPruner is unstable because the final sparsity level depends on the convergence pattern of the tensor distribution. Song Han's methodology of using several different values for the multiplier \\(s\\), and the recalculation of the standard-deviation at each pruning phase, probably gives it stability, but requires much more hyper-parameters (this is the reason we have not implemented it thus far). To set the target sparsity levels, you can once again use pruning sensitivity analysis to make better guesses at the correct sparsity level of each", + "location": "/algo_pruning/index.html#level-pruner", + "text": "Class SparsityLevelParameterPruner uses a similar method to go around specifying specific thresholding magnitudes.\nInstead of specifying a threshold magnitude, you specify a target sparsity level (expressed as a fraction, so 0.5 means 50% sparsity). Essentially this pruner also uses a pruning criteria based on the magnitude of each tensor element, but it has the advantage that you can aim for an exact and specific sparsity level. \nThis pruner is much more stable compared to SensitivityPruner because the target sparsity level is not coupled to the actual magnitudes of the elements. Distiller's SensitivityPruner is unstable because the final sparsity level depends on the convergence pattern of the tensor distribution. Song Han's methodology of using several different values for the multiplier \\(s\\), and the recalculation of the standard-deviation at each pruning phase, probably gives it stability, but requires much more hyper-parameters (this is the reason we have not implemented it thus far). To set the target sparsity levels, you can once again use pruning sensitivity analysis to make better guesses at the correct sparsity level of each", "title": "Level pruner" - }, + }, { - "location": "/algo_pruning/index.html#method-of-operation_1", - "text": "Sort the weights in the specified layer by their absolute values. Mask to zero the smallest magnitude weights until the desired sparsity level is reached.", + "location": "/algo_pruning/index.html#method-of-operation_1", + "text": "Sort the weights in the specified layer by their absolute values. Mask to zero the smallest magnitude weights until the desired sparsity level is reached.", "title": "Method of operation" - }, + }, { - "location": "/algo_pruning/index.html#automated-gradual-pruner-agp", - "text": "In To prune, or not to prune: exploring the efficacy of pruning for model compression , authors Michael Zhu and Suyog Gupta provide an algorithm to schedule a Level Pruner which Distiller implements in AutomatedGradualPruner . 
\"We introduce a new automated gradual pruning algorithm in which the sparsity is increased from an initial sparsity value \\(s_i\\) (usually 0) to a \ufb01nal sparsity value \\(s_f\\) over a span of n pruning steps.\nThe intuition behind this sparsity function in equation (1) is to prune the network rapidly in the initial phase when the redundant connections are\nabundant and gradually reduce the number of weights being pruned each time as there are fewer and fewer weights remaining in the network.\"\" You can play with the scheduling parameters in the agp_schedule.ipynb notebook . The authors describe AGP: Our automated gradual pruning algorithm prunes the smallest magnitude weights to achieve a preset level of network sparsity. Doesn't require much hyper-parameter tuning Shown to perform well across different models Does not make any assumptions about the structure of the network or its constituent layers, and is therefore more generally applicable.", + "location": "/algo_pruning/index.html#automated-gradual-pruner-agp", + "text": "In To prune, or not to prune: exploring the efficacy of pruning for model compression , authors Michael Zhu and Suyog Gupta provide an algorithm to schedule a Level Pruner which Distiller implements in AutomatedGradualPruner . \"We introduce a new automated gradual pruning algorithm in which the sparsity is increased from an initial sparsity value \\(s_i\\) (usually 0) to a \ufb01nal sparsity value \\(s_f\\) over a span of n pruning steps.\nThe intuition behind this sparsity function in equation (1) is to prune the network rapidly in the initial phase when the redundant connections are\nabundant and gradually reduce the number of weights being pruned each time as there are fewer and fewer weights remaining in the network.\"\" You can play with the scheduling parameters in the agp_schedule.ipynb notebook . The authors describe AGP: Our automated gradual pruning algorithm prunes the smallest magnitude weights to achieve a preset level of network sparsity. Doesn't require much hyper-parameter tuning Shown to perform well across different models Does not make any assumptions about the structure of the network or its constituent layers, and is therefore more generally applicable.", "title": "Automated gradual pruner (AGP)" - }, + }, { - "location": "/algo_pruning/index.html#rnn-pruner", - "text": "The authors of Exploring Sparsity in Recurrent Neural Networks , Sharan Narang, Erich Elsen, Gregory Diamos, and Shubho Sengupta, \"propose a technique to reduce the parameters of a network by pruning weights during the initial training of the network.\" They use a gradual pruning schedule which is reminiscent of the schedule used in AGP, for element-wise pruning of RNNs, which they also employ during training. They show pruning of RNN, GRU, LSTM and embedding layers. Distiller's distiller.pruning.BaiduRNNPruner class implements this pruning algorithm.", + "location": "/algo_pruning/index.html#rnn-pruner", + "text": "The authors of Exploring Sparsity in Recurrent Neural Networks , Sharan Narang, Erich Elsen, Gregory Diamos, and Shubho Sengupta, \"propose a technique to reduce the parameters of a network by pruning weights during the initial training of the network.\" They use a gradual pruning schedule which is reminiscent of the schedule used in AGP, for element-wise pruning of RNNs, which they also employ during training. They show pruning of RNN, GRU, LSTM and embedding layers. 
Distiller's distiller.pruning.BaiduRNNPruner class implements this pruning algorithm.", "title": "RNN pruner" - }, + }, { - "location": "/algo_pruning/index.html#structure-pruners", - "text": "Element-wise pruning can create very sparse models which can be compressed to consume less memory footprint and bandwidth, but without specialized hardware that can compute using the sparse representation of the tensors, we don't gain any speedup of the computation. Structure pruners, remove entire \"structures\", such as kernels, filters, and even entire feature-maps.", + "location": "/algo_pruning/index.html#structure-pruners", + "text": "Element-wise pruning can create very sparse models which can be compressed to consume less memory footprint and bandwidth, but without specialized hardware that can compute using the sparse representation of the tensors, we don't gain any speedup of the computation. Structure pruners, remove entire \"structures\", such as kernels, filters, and even entire feature-maps.", "title": "Structure pruners" - }, + }, { - "location": "/algo_pruning/index.html#ranked-structure-pruner", - "text": "The L1RankedStructureParameterPruner pruner calculates the magnitude of some \"structure\", orders all of the structures based on some magnitude function and the m lowest ranking structures are pruned away. Currently this pruner only performs ranking of filters (3D structures) and it uses the mean of the absolute value of the tensor as the representative of the filter magnitude. The absolute mean does not depend on the size of the filter, so it is easier to use compared to just using the \\(L_1\\)-norm of the structure, and at the same time it is a good proxy of the \\(L_1\\)-norm. In Pruning Filters for Efficient ConvNets the authors use filter ranking, with one-shot pruning followed by fine-tuning. The authors of Exploiting Sparseness in Deep Neural Networks for Large Vocabulary Speech Recognition also use a one-shot pruning schedule, for fully-connected layers, and they provide an explanation: First, after sweeping through the full training set several times the weights become relatively stable \u2014 they tend to remain either large or small magnitudes. Second, in a stabilized model, the importance of the connection is approximated well by the magnitudes of the weights (times the magnitudes of the corresponding input values, but these are relatively uniform within each layer since on the input layer, features are normalized to zero-mean and unit-variance, and hidden-layer values are probabilities)", + "location": "/algo_pruning/index.html#ranked-structure-pruner", + "text": "The L1RankedStructureParameterPruner pruner calculates the magnitude of some \"structure\", orders all of the structures based on some magnitude function and the m lowest ranking structures are pruned away. Currently this pruner only performs ranking of filters (3D structures) and it uses the mean of the absolute value of the tensor as the representative of the filter magnitude. The absolute mean does not depend on the size of the filter, so it is easier to use compared to just using the \\(L_1\\)-norm of the structure, and at the same time it is a good proxy of the \\(L_1\\)-norm. In Pruning Filters for Efficient ConvNets the authors use filter ranking, with one-shot pruning followed by fine-tuning. 
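The ranking criterion described above (mean absolute value per filter, prune the m lowest-ranking filters) is easy to express directly on a 4-D convolution weight tensor. The sketch below is an illustration of the idea, not the L1RankedStructureParameterPruner code:

```python
import torch

def ranked_filter_mask(conv_weights: torch.Tensor, num_filters_to_prune: int) -> torch.Tensor:
    # conv_weights has shape [out_channels, in_channels, kH, kW]; score each
    # 3-D filter by the mean of its absolute values.
    filter_scores = conv_weights.abs().mean(dim=(1, 2, 3))
    _, lowest = torch.topk(filter_scores, num_filters_to_prune, largest=False)
    mask = torch.ones_like(conv_weights)
    mask[lowest] = 0.0  # zero out entire filters, not individual elements
    return mask
```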
The authors of Exploiting Sparseness in Deep Neural Networks for Large Vocabulary Speech Recognition also use a one-shot pruning schedule, for fully-connected layers, and they provide an explanation: First, after sweeping through the full training set several times the weights become relatively stable \u2014 they tend to remain either large or small magnitudes. Second, in a stabilized model, the importance of the connection is approximated well by the magnitudes of the weights (times the magnitudes of the corresponding input values, but these are relatively uniform within each layer since on the input layer, features are normalized to zero-mean and unit-variance, and hidden-layer values are probabilities)", "title": "Ranked structure pruner" - }, + }, { - "location": "/algo_pruning/index.html#activation-influenced-pruner", - "text": "The motivation for this pruner, is that if a feature-map produces very small activations, then this feature-map is not very important, and can be pruned away.\n- Status: not implemented", + "location": "/algo_pruning/index.html#activation-influenced-pruner", + "text": "The motivation for this pruner, is that if a feature-map produces very small activations, then this feature-map is not very important, and can be pruned away.\n- Status: not implemented", "title": "Activation-influenced pruner" - }, + }, { - "location": "/algo_quantization/index.html", - "text": "Quantization Algorithms\n\n\nThe following quantization methods are currently implemented in Distiller:\n\n\nDoReFa\n\n\n(As proposed in \nDoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients\n) \n\n\nIn this method, we first define the quantization function \nquantize_k\n, which takes a real value \na_f \\in [0, 1]\n and outputs a discrete-valued \na_q \\in \\left\\{ \\frac{0}{2^k-1}, \\frac{1}{2^k-1}, ... , \\frac{2^k-1}{2^k-1} \\right\\}\n, where \nk\n is the number of bits used for quantization.\n\n\n\n\na_q = quantize_k(a_f) = \\frac{1}{2^k-1} round \\left( \\left(2^k - 1 \\right) a_f \\right)\n\n\n\n\nActivations are clipped to the \n[0, 1]\n range and then quantized as follows:\n\n\n\n\nx_q = quantize_k(x_f)\n\n\n\n\nFor weights, we define the following function \nf\n, which takes an unbounded real valued input and outputs a real value in \n[0, 1]\n:\n\n\n\n\nf(w) = \\frac{tanh(w)}{2 max(|tanh(w)|)} + \\frac{1}{2} \n\n\n\n\nNow we can use \nquantize_k\n to get quantized weight values, as follows:\n\n\n\n\nw_q = 2 quantize_k \\left( f(w_f) \\right) - 1\n\n\n\n\nThis method requires training the model with quantization, as discussed \nhere\n. Use the \nDorefaQuantizer\n class to transform an existing model to a model suitable for training with quantization using DoReFa.\n\n\nNotes:\n\n\n\n\nGradients quantization as proposed in the paper is not supported yet.\n\n\nThe paper defines special handling for binary weights which isn't supported in Distiller yet.\n\n\n\n\nPACT\n\n\n(As proposed in \nPACT: Parameterized Clipping Activation for Quantized Neural Networks\n)\n\n\nThis method is similar to DoReFa, but the upper clipping values, \n\\alpha\n, of the activation functions are learned parameters instead of hard coded to 1. Note that per the paper's recommendation, \n\\alpha\n is shared per layer.\n\n\nThis method requires training the model with quantization, as discussed \nhere\n. 
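Putting the DoReFa equations above into code is direct. The sketch below covers the forward quantization only (the straight-through gradient estimation needed for training with quantization is omitted), and it is not the DorefaQuantizer implementation itself:

```python
import torch

def quantize_k(x: torch.Tensor, k: int) -> torch.Tensor:
    # Maps values in [0, 1] onto a uniform grid with 2^k - 1 steps.
    n = float(2 ** k - 1)
    return torch.round(x * n) / n

def dorefa_quantize_activations(x: torch.Tensor, k: int) -> torch.Tensor:
    # Activations are clipped to [0, 1] and then quantized.
    return quantize_k(torch.clamp(x, 0.0, 1.0), k)

def dorefa_quantize_weights(w: torch.Tensor, k: int) -> torch.Tensor:
    # f(w) squashes the unbounded weights into [0, 1]; the quantized result
    # is then mapped back to [-1, 1].
    f_w = torch.tanh(w) / (2.0 * torch.tanh(w).abs().max()) + 0.5
    return 2.0 * quantize_k(f_w, k) - 1.0
```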
Use the \nPACTQuantizer\n class to transform an existing model to a model suitable for training with quantization using PACT.\n\n\nWRPN\n\n\n(As proposed in \nWRPN: Wide Reduced-Precision Networks\n) \n\n\nIn this method, activations are clipped to \n[0, 1]\n and quantized as follows (\nk\n is the number of bits used for quantization):\n\n\n\n\nx_q = \\frac{1}{2^k-1} round \\left( \\left(2^k - 1 \\right) x_f \\right)\n\n\n\n\nWeights are clipped to \n[-1, 1]\n and quantized as follows:\n\n\n\n\nw_q = \\frac{1}{2^{k-1}-1} round \\left( \\left(2^{k-1} - 1 \\right)w_f \\right)\n\n\n\n\nNote that \nk-1\n bits are used to quantize weights, leaving one bit for sign.\n\n\nThis method requires training the model with quantization, as discussed \nhere\n. Use the \nWRPNQuantizer\n class to transform an existing model to a model suitable for training with quantization using WRPN.\n\n\nNotes:\n\n\n\n\nThe paper proposed widening of layers as a means to reduce accuracy loss. This isn't implemented as part of \nWRPNQuantizer\n at the moment. To experiment with this, modify your model implementation to have wider layers.\n\n\nThe paper defines special handling for binary weights which isn't supported in Distiller yet.\n\n\n\n\nSymmetric Linear Quantization\n\n\nIn this method, a float value is quantized by multiplying with a numeric constant (the \nscale factor\n), hence it is \nLinear\n. We use a signed integer to represent the quantized range, with no quantization bias (or \"offset\") used. As a result, the floating-point range considered for quantization is \nsymmetric\n with respect to zero.\n\nIn the current implementation the scale factor is chosen so that the entire range of the floating-point tensor is quantized (we do not attempt to remove outliers).\n\nLet us denote the original floating-point tensor by \nx_f\n, the quantized tensor by \nx_q\n, the scale factor by \nq_x\n and the number of bits used for quantization by \nn\n. Then, we get:\n\nq_x = \\frac{2^{n-1}-1}{\\max|x|}\n\n\nx_q = round(q_x x_f)\n\n(The \nround\n operation is round-to-nearest-integer) \n\n\nLet's see how a \nconvolution\n or \nfully-connected (FC)\n layer is quantized using this method: (we denote input, output, weights and bias with \nx, y, w\n and \nb\n respectively)\n\ny_f = \\sum{x_f w_f} + b_f = \\sum{\\frac{x_q}{q_x} \\frac{w_q}{q_w}} + \\frac{b_q}{q_b} = \\frac{1}{q_x q_w} \\left( \\sum { x_q w_q + \\frac{q_x q_w}{q_b}b_q } \\right)\n\n\ny_q = round(q_y y_f) = round\\left(\\frac{q_y}{q_x q_w} \\left( \\sum { x_q w_q + \\frac{q_x q_w}{q_b}b_q } \\right) \\right) \n\nNote how the bias has to be re-scaled to match the scale of the summation.\n\n\nImplementation\n\n\nWe've implemented \nconvolution\n and \nFC\n using this method. \n\n\n\n\nThey are implemented by wrapping the existing PyTorch layers with quantization and de-quantization operations. That is - the computation is done on floating-point tensors, but the values themselves are restricted to integer values. The wrapper is implemented in the \nRangeLinearQuantParamLayerWrapper\n class. \n\n\nAll other layers are unaffected and are executed using their original FP32 implementation. \n\n\nTo automatically transform an existing model to a quantized model using this method, use the \nSymmetricLinearQuantizer\n class.\n\n\nFor weights and bias the scale factor is determined once at quantization setup (\"offline\"), and for activations it is determined dynamically at runtime (\"online\"). 
\n\n\nImportant note:\n Currently, this method is implemented as \ninference only\n, with no back-propagation functionality. Hence, it can only be used to quantize a pre-trained FP32 model, with no re-training. As such, using it with \nn < 8\n is likely to lead to severe accuracy degradation for any non-trivial workload.", + "location": "/algo_quantization/index.html", + "text": "Quantization Algorithms\n\n\nNote:\n\nFor any of the methods below that require quantization-aware training, please see \nhere\n for details on how to invoke it using Distiller's scheduling mechanism.\n\n\nRange-Based Linear Quantization\n\n\nLet's break down the terminology we use here:\n\n\n\n\nLinear:\n Means a float value is quantized by multiplying with a numeric constant (the \nscale factor\n).\n\n\nRange-Based:\n: Means that in order to calculate the scale factor, we look at the actual range of the tensor's values. In the most naive implementation, we use the actual min/max values of the tensor. Alternatively, we use some derivation based on the tensor's range / distribution to come up with a narrower min/max range, in order to remove possible outliers. This is in contrast to the other methods described here, which we could call \nclipping-based\n, as they impose an explicit clipping function on the tensors (using either a hard-coded value or a learned value).\n\n\n\n\nAsymmetric vs. Symmetric\n\n\nIn this method we can use two modes - \nasymmetric\n and \nsymmetric\n.\n\n\nAsymmetric Mode\n\n\n\n \n\n\n\n\n\nIn \nasymmetric\n mode, we map the min/max in the float range to the min/max of the integer range. This is done by using a \nzero-point\n (also called \nquantization bias\n, or \noffset\n) in addition to the scale factor.\n\n\nLet us denote the original floating-point tensor by \nx_f\n, the quantized tensor by \nx_q\n, the scale factor by \nq_x\n, the zero-point by \nzp_x\n and the number of bits used for quantization by \nn\n. Then, we get:\n\n\n\n\nx_q = round\\left ((x_f - min_{x_f})\\underbrace{\\frac{2^n - 1}{max_{x_f} - min_{x_f}}}_{q_x} \\right) = round(q_x x_f - \\underbrace{min_{x_f}q_x)}_{zp_x} = round(q_x x_f - zp_x)\n\n\n\n\nIn practice, we actually use \nzp_x = round(min_{x_f}q_x)\n. This means that zero is exactly representable by an integer in the quantized range. This is important, for example, for layers that have zero-padding. By rounding the zero-point, we effectively \"nudge\" the min/max values in the float range a little bit, in order to gain this exact quantization of zero.\n\n\nNote that in the derivation above we use unsigned integer to represent the quantized range. That is, \nx_q \\in [0, 2^n-1]\n. One could use signed integer if necessary (perhaps due to HW considerations). 
This can be achieved by subtracting \n2^{n-1}\n.\n\n\nLet's see how a \nconvolution\n or \nfully-connected (FC)\n layer is quantized in asymmetric mode: (we denote input, output, weights and bias with \nx, y, w\n and \nb\n respectively)\n\n\n\n\ny_f = \\sum{x_f w_f} + b_f = \\sum{\\frac{x_q + zp_x}{q_x} \\frac{w_q + zp_w}{q_w}} + \\frac{b_q + zp_b}{q_b} =\n\n\n = \\frac{1}{q_x q_w} \\left( \\sum { (x_q + zp_x) (w_q + zp_w) + \\frac{q_x q_w}{q_b}(b_q + zp_b) } \\right)\n\n\n\n\nTherefore:\n\n\n\n\ny_q = round(q_y y_f) = round\\left(\\frac{q_y}{q_x q_w} \\left( \\sum { (x_q+zp_x) (w_q+zp_w) + \\frac{q_x q_w}{q_b}(b_q+zp_b) } \\right) \\right) \n\n\n\n\nNotes:\n\n\n\n\nWe can see that the bias has to be re-scaled to match the scale of the summation.\n\n\nIn a proper integer-only HW pipeline, we would like our main accumulation term to simply be \n\\sum{x_q w_q}\n. In order to achieve this, one needs to further develop the expression we derived above. For further details please refer to the \ngemmlowp documentation\n\n\n\n\nSymmetric Mode\n\n\n\n \n\n\n\n\n\nIn \nsymmetric\n mode, instead of mapping the exact min/max of the float range to the quantized range, we choose the maximum absolute value between min/max. In addition, we don't use a zero-point. So, the floating-point range we're effectively quantizing is symmetric with respect to zero, and so is the quantized range.\n\n\nUsing the same notations as above, we get:\n\n\n\n\nx_q = round\\left (x_f \\underbrace{\\frac{2^{n-1} - 1}{\\max|x_f|}}_{q_x} \\right) = round(q_x x_f)\n\n\n\n\nAgain, let's see how a \nconvolution\n or \nfully-connected (FC)\n layer is quantized, this time in symmetric mode:\n\n\n\n\ny_f = \\sum{x_f w_f} + b_f = \\sum{\\frac{x_q}{q_x} \\frac{w_q}{q_w}} + \\frac{b_q}{q_b} = \\frac{1}{q_x q_w} \\left( \\sum { x_q w_q + \\frac{q_x q_w}{q_b}b_q } \\right)\n\n\n\n\nTherefore:\n\n\n\n\ny_q = round(q_y y_f) = round\\left(\\frac{q_y}{q_x q_w} \\left( \\sum { x_q w_q + \\frac{q_x q_w}{q_b}b_q } \\right) \\right) \n\n\n\n\nComparing the Two Modes\n\n\nThe main trade-off between these two modes is simplicity vs. utilization of the quantized range.\n\n\n\n\nWhen using asymmetric quantization, the quantized range is fully utilized. That is because we exactly map the min/max values from the float range to the min/max of the quantized range. Using symmetric mode, if the float range is biased towards one side, this could result in a quantized range where significant dynamic range is dedicated to values that we'll never see. The most extreme example of this is after ReLU, where the entire tensor is positive. Quantizing it in symmetric mode means we're effectively losing 1 bit.\n\n\nOn the other hand, if we look at the derivations for convolution / FC layers above, we can see that the actual implementation of symmetric mode is much simpler. In asymmetric mode, the zero-points require additional logic in HW. The cost of this extra logic in terms of latency and/or power and/or area will of course depend on the exact implementation.\n\n\n\n\nOther Features\n\n\n\n\nRemoving Outliers:\n As discussed \nhere\n, in some cases the float range of activations contains outliers. Spending dynamic range on these outliers hurts our ability to represent the values we actually care about accurately.\n \n\n \n\n \n\n Currently, Distiller supports clipping of activations with averaging during post-training quantization. 
That is - for each batch, instead of calculating global min/max values, we use an average of the min/max values of each sample in the batch.\n\n\nScale factor scope:\n For weight tensors, Distiller supports per-channel quantization (per output channel).\n\n\n\n\nImplementation in Distiller\n\n\nPost-Training\n\n\nFor post-training quantization, currently \nconvolution\n and \nFC\n are supported using this method. \n\n\n\n\nThey are implemented by wrapping the existing PyTorch layers with quantization and de-quantization operations. That is - the computation is done on floating-point tensors, but the values themselves are restricted to integer values. The wrapper is implemented in the \nRangeLinearQuantParamLayerWrapper\n class. \n\n\nAll other layers are unaffected and are executed using their original FP32 implementation. \n\n\nTo automatically transform an existing model to a quantized model using this method, use the \nPostTrainLinearQuantizer\n class. For an example of how to do this, see the \ncompress_classifier.py\n. This sample also exposes command line arguments to invoke post-training quantization. For details see \nhere\n.\n\n\nFor weights and bias the scale factor and zero-point are determined once at quantization setup (\"offline\"), and for activations it is determined dynamically at runtime (\"online\"). The calculated quantization parameters are stored as buffers within the module, so they are automatically serialized when the model checkpoint is saved.\n\n\nAs this is post-training, using it with number of bits < 8 is likely to lead to severe accuracy degradation for any non-trivial workload.\n\n\n\n\nQuantization-Aware Training\n\n\nTo apply range-based linear quantization in training, use the \nQuantAwareTrainRangeLinearQuantizer\n class. As it is now, it will apply weights quantization to convolution and FC modules. For activations quantization, it will insert instances of the \nFakeLinearQuantization\n module after ReLUs. This module follows the methodology described in \nBenoit et al., 2018\n and uses exponential moving averages to track activation ranges.\n\n\nSimilarly to post-training, the calculated quantization parameters (scale factors, zero-points, tracked activation ranges) are stored as buffers within their respective modules, so they're saved when a checkpoint is created.\n\n\nNote that converting from a quantization-aware training model to a post-training quantization model is not yet supported. Such a conversion will use the activation ranges tracked during training, so additional offline or online calculation of quantization parameters will not be required.\n\n\nDoReFa\n\n\n(As proposed in \nDoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients\n) \n\n\nIn this method, we first define the quantization function \nquantize_k\n, which takes a real value \na_f \\in [0, 1]\n and outputs a discrete-valued \na_q \\in \\left\\{ \\frac{0}{2^k-1}, \\frac{1}{2^k-1}, ... 
, \\frac{2^k-1}{2^k-1} \\right\\}\n, where \nk\n is the number of bits used for quantization.\n\n\n\n\na_q = quantize_k(a_f) = \\frac{1}{2^k-1} round \\left( \\left(2^k - 1 \\right) a_f \\right)\n\n\n\n\nActivations are clipped to the \n[0, 1]\n range and then quantized as follows:\n\n\n\n\nx_q = quantize_k(x_f)\n\n\n\n\nFor weights, we define the following function \nf\n, which takes an unbounded real valued input and outputs a real value in \n[0, 1]\n:\n\n\n\n\nf(w) = \\frac{tanh(w)}{2 max(|tanh(w)|)} + \\frac{1}{2} \n\n\n\n\nNow we can use \nquantize_k\n to get quantized weight values, as follows:\n\n\n\n\nw_q = 2 quantize_k \\left( f(w_f) \\right) - 1\n\n\n\n\nThis method requires training the model with quantization-aware training, as discussed \nhere\n. Use the \nDorefaQuantizer\n class to transform an existing model to a model suitable for training with quantization using DoReFa.\n\n\nNotes:\n\n\n\n\nGradients quantization as proposed in the paper is not supported yet.\n\n\nThe paper defines special handling for binary weights which isn't supported in Distiller yet.\n\n\n\n\nPACT\n\n\n(As proposed in \nPACT: Parameterized Clipping Activation for Quantized Neural Networks\n)\n\n\nThis method is similar to DoReFa, but the upper clipping values, \n\\alpha\n, of the activation functions are learned parameters instead of hard coded to 1. Note that per the paper's recommendation, \n\\alpha\n is shared per layer.\n\n\nThis method requires training the model with quantization-aware training, as discussed \nhere\n. Use the \nPACTQuantizer\n class to transform an existing model to a model suitable for training with quantization using PACT.\n\n\nWRPN\n\n\n(As proposed in \nWRPN: Wide Reduced-Precision Networks\n) \n\n\nIn this method, activations are clipped to \n[0, 1]\n and quantized as follows (\nk\n is the number of bits used for quantization):\n\n\n\n\nx_q = \\frac{1}{2^k-1} round \\left( \\left(2^k - 1 \\right) x_f \\right)\n\n\n\n\nWeights are clipped to \n[-1, 1]\n and quantized as follows:\n\n\n\n\nw_q = \\frac{1}{2^{k-1}-1} round \\left( \\left(2^{k-1} - 1 \\right)w_f \\right)\n\n\n\n\nNote that \nk-1\n bits are used to quantize weights, leaving one bit for sign.\n\n\nThis method requires training the model with quantization-aware training, as discussed \nhere\n. Use the \nWRPNQuantizer\n class to transform an existing model to a model suitable for training with quantization using WRPN.\n\n\nNotes:\n\n\n\n\nThe paper proposed widening of layers as a means to reduce accuracy loss. This isn't implemented as part of \nWRPNQuantizer\n at the moment. 
To experiment with this, modify your model implementation to have wider layers.\n\n\nThe paper defines special handling for binary weights which isn't supported in Distiller yet.", "title": "Quantization" - }, + }, { - "location": "/algo_quantization/index.html#quantization-algorithms", - "text": "The following quantization methods are currently implemented in Distiller:", + "location": "/algo_quantization/index.html#quantization-algorithms", + "text": "Note: \nFor any of the methods below that require quantization-aware training, please see here for details on how to invoke it using Distiller's scheduling mechanism.", "title": "Quantization Algorithms" - }, - { - "location": "/algo_quantization/index.html#dorefa", - "text": "(As proposed in DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients ) In this method, we first define the quantization function quantize_k , which takes a real value a_f \\in [0, 1] and outputs a discrete-valued a_q \\in \\left\\{ \\frac{0}{2^k-1}, \\frac{1}{2^k-1}, ... , \\frac{2^k-1}{2^k-1} \\right\\} , where k is the number of bits used for quantization. a_q = quantize_k(a_f) = \\frac{1}{2^k-1} round \\left( \\left(2^k - 1 \\right) a_f \\right) Activations are clipped to the [0, 1] range and then quantized as follows: x_q = quantize_k(x_f) For weights, we define the following function f , which takes an unbounded real valued input and outputs a real value in [0, 1] : f(w) = \\frac{tanh(w)}{2 max(|tanh(w)|)} + \\frac{1}{2} Now we can use quantize_k to get quantized weight values, as follows: w_q = 2 quantize_k \\left( f(w_f) \\right) - 1 This method requires training the model with quantization, as discussed here . Use the DorefaQuantizer class to transform an existing model to a model suitable for training with quantization using DoReFa.", + }, + { + "location": "/algo_quantization/index.html#range-based-linear-quantization", + "text": "Let's break down the terminology we use here: Linear: Means a float value is quantized by multiplying with a numeric constant (the scale factor ). Range-Based: : Means that in order to calculate the scale factor, we look at the actual range of the tensor's values. In the most naive implementation, we use the actual min/max values of the tensor. Alternatively, we use some derivation based on the tensor's range / distribution to come up with a narrower min/max range, in order to remove possible outliers. This is in contrast to the other methods described here, which we could call clipping-based , as they impose an explicit clipping function on the tensors (using either a hard-coded value or a learned value).", + "title": "Range-Based Linear Quantization" + }, + { + "location": "/algo_quantization/index.html#asymmetric-vs-symmetric", + "text": "In this method we can use two modes - asymmetric and symmetric .", + "title": "Asymmetric vs. Symmetric" + }, + { + "location": "/algo_quantization/index.html#asymmetric-mode", + "text": "In asymmetric mode, we map the min/max in the float range to the min/max of the integer range. This is done by using a zero-point (also called quantization bias , or offset ) in addition to the scale factor. Let us denote the original floating-point tensor by x_f , the quantized tensor by x_q , the scale factor by q_x , the zero-point by zp_x and the number of bits used for quantization by n . 
Then, we get: x_q = round\\left ((x_f - min_{x_f})\\underbrace{\\frac{2^n - 1}{max_{x_f} - min_{x_f}}}_{q_x} \\right) = round(q_x x_f - \\underbrace{min_{x_f}q_x)}_{zp_x} = round(q_x x_f - zp_x) In practice, we actually use zp_x = round(min_{x_f}q_x) . This means that zero is exactly representable by an integer in the quantized range. This is important, for example, for layers that have zero-padding. By rounding the zero-point, we effectively \"nudge\" the min/max values in the float range a little bit, in order to gain this exact quantization of zero. Note that in the derivation above we use unsigned integer to represent the quantized range. That is, x_q \\in [0, 2^n-1] . One could use signed integer if necessary (perhaps due to HW considerations). This can be achieved by subtracting 2^{n-1} . Let's see how a convolution or fully-connected (FC) layer is quantized in asymmetric mode: (we denote input, output, weights and bias with x, y, w and b respectively) y_f = \\sum{x_f w_f} + b_f = \\sum{\\frac{x_q + zp_x}{q_x} \\frac{w_q + zp_w}{q_w}} + \\frac{b_q + zp_b}{q_b} = = \\frac{1}{q_x q_w} \\left( \\sum { (x_q + zp_x) (w_q + zp_w) + \\frac{q_x q_w}{q_b}(b_q + zp_b) } \\right) Therefore: y_q = round(q_y y_f) = round\\left(\\frac{q_y}{q_x q_w} \\left( \\sum { (x_q+zp_x) (w_q+zp_w) + \\frac{q_x q_w}{q_b}(b_q+zp_b) } \\right) \\right) Notes: We can see that the bias has to be re-scaled to match the scale of the summation. In a proper integer-only HW pipeline, we would like our main accumulation term to simply be \\sum{x_q w_q} . In order to achieve this, one needs to further develop the expression we derived above. For further details please refer to the gemmlowp documentation", + "title": "Asymmetric Mode" + }, + { + "location": "/algo_quantization/index.html#symmetric-mode", + "text": "In symmetric mode, instead of mapping the exact min/max of the float range to the quantized range, we choose the maximum absolute value between min/max. In addition, we don't use a zero-point. So, the floating-point range we're effectively quantizing is symmetric with respect to zero, and so is the quantized range. Using the same notations as above, we get: x_q = round\\left (x_f \\underbrace{\\frac{2^{n-1} - 1}{\\max|x_f|}}_{q_x} \\right) = round(q_x x_f) Again, let's see how a convolution or fully-connected (FC) layer is quantized, this time in symmetric mode: y_f = \\sum{x_f w_f} + b_f = \\sum{\\frac{x_q}{q_x} \\frac{w_q}{q_w}} + \\frac{b_q}{q_b} = \\frac{1}{q_x q_w} \\left( \\sum { x_q w_q + \\frac{q_x q_w}{q_b}b_q } \\right) Therefore: y_q = round(q_y y_f) = round\\left(\\frac{q_y}{q_x q_w} \\left( \\sum { x_q w_q + \\frac{q_x q_w}{q_b}b_q } \\right) \\right)", + "title": "Symmetric Mode" + }, + { + "location": "/algo_quantization/index.html#comparing-the-two-modes", + "text": "The main trade-off between these two modes is simplicity vs. utilization of the quantized range. When using asymmetric quantization, the quantized range is fully utilized. That is because we exactly map the min/max values from the float range to the min/max of the quantized range. Using symmetric mode, if the float range is biased towards one side, could result in a quantized range where significant dynamic range is dedicated to values that we'll never see. The most extreme example of this is after ReLU, where the entire tensor is positive. Quantizing it in symmetric mode means we're effectively losing 1 bit. 
On the other hand, if we look at the derviations for convolution / FC layers above, we can see that the actual implementation of symmetric mode is much simpler. In asymmetric mode, the zero-points require additional logic in HW. The cost of this extra logic in terms of latency and/or power and/or area will of course depend on the exact implementation.", + "title": "Comparing the Two Modes" + }, + { + "location": "/algo_quantization/index.html#other-features", + "text": "Removing Outliers: As discussed here , in some cases the float range of activations contains outliers. Spending dynamic range on these outliers hurts our ability ro represent the values we actually care about accurately.\n \n \n \n Currently, Distiller supports clipping of activations with averaging during post-training quantization. That is - for each batch, instead of calculating global min/max values, an average of the min/max values of each sample in the batch. Scale factor scope: For weight tensors, Distiller supports per-channel quantization (per output channel).", + "title": "Other Features" + }, + { + "location": "/algo_quantization/index.html#implementation-in-distiller", + "text": "", + "title": "Implementation in Distiller" + }, + { + "location": "/algo_quantization/index.html#post-training", + "text": "For post-training quantization, currently convolution and FC are supported using this method. They are implemented by wrapping the existing PyTorch layers with quantization and de-quantization operations. That is - the computation is done on floating-point tensors, but the values themselves are restricted to integer values. The wrapper is implemented in the RangeLinearQuantParamLayerWrapper class. All other layers are unaffected and are executed using their original FP32 implementation. To automatically transform an existing model to a quantized model using this method, use the PostTrainLinearQuantizer class. For an example of how to do this, see the compress_classifier.py . This sample also exposes command line arguments to invoke post-training quantization. For details see here . For weights and bias the scale factor and zero-point are determined once at quantization setup (\"offline\"), and for activations it is determined dynamically at runtime (\"online\"). The calculated quantization parameters are store as buffers within the module, so they are automatically serialized when the model checkpoint is saved. As this is post-training, using it with number of bits 8 is likely to lead to severe accuracy degradation for any non-trivial workload.", + "title": "Post-Training" + }, + { + "location": "/algo_quantization/index.html#quantization-aware-training", + "text": "To apply range-based linear quantization in training, use the QuantAwareTrainRangeLinearQuantizer class. As it is now, it will apply weights quantization to convolution and FC modules. For activations quantization, it will insert instances FakeLinearQuantization module after ReLUs. This module follows the methodology described in Benoit et al., 2018 and uses exponential moving averages to track activation ranges. Similarly to post-training, the calculated quantization parameters (scale factors, zero-points, tracked activation ranges) are stored as buffers within their respective modules, so they're saved when a checkpoint is created. Note that converting from a quantization-aware training model to a post-training quantization model is not yet supported. 
Such a conversion will use the activation ranges tracked during training, so additional offline or online calculation of quantization parameters will not be required.", + "title": "Quantization-Aware Training" + }, + { + "location": "/algo_quantization/index.html#dorefa", + "text": "(As proposed in DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients ) In this method, we first define the quantization function quantize_k , which takes a real value a_f \\in [0, 1] and outputs a discrete-valued a_q \\in \\left\\{ \\frac{0}{2^k-1}, \\frac{1}{2^k-1}, ... , \\frac{2^k-1}{2^k-1} \\right\\} , where k is the number of bits used for quantization. a_q = quantize_k(a_f) = \\frac{1}{2^k-1} round \\left( \\left(2^k - 1 \\right) a_f \\right) Activations are clipped to the [0, 1] range and then quantized as follows: x_q = quantize_k(x_f) For weights, we define the following function f , which takes an unbounded real valued input and outputs a real value in [0, 1] : f(w) = \\frac{tanh(w)}{2 max(|tanh(w)|)} + \\frac{1}{2} Now we can use quantize_k to get quantized weight values, as follows: w_q = 2 quantize_k \\left( f(w_f) \\right) - 1 This method requires training the model with quantization-aware training, as discussed here . Use the DorefaQuantizer class to transform an existing model to a model suitable for training with quantization using DoReFa.", "title": "DoReFa" - }, + }, { - "location": "/algo_quantization/index.html#notes", - "text": "Gradients quantization as proposed in the paper is not supported yet. The paper defines special handling for binary weights which isn't supported in Distiller yet.", + "location": "/algo_quantization/index.html#notes", + "text": "Gradients quantization as proposed in the paper is not supported yet. The paper defines special handling for binary weights which isn't supported in Distiller yet.", "title": "Notes:" - }, + }, { - "location": "/algo_quantization/index.html#pact", - "text": "(As proposed in PACT: Parameterized Clipping Activation for Quantized Neural Networks ) This method is similar to DoReFa, but the upper clipping values, \\alpha , of the activation functions are learned parameters instead of hard coded to 1. Note that per the paper's recommendation, \\alpha is shared per layer. This method requires training the model with quantization, as discussed here . Use the PACTQuantizer class to transform an existing model to a model suitable for training with quantization using PACT.", + "location": "/algo_quantization/index.html#pact", + "text": "(As proposed in PACT: Parameterized Clipping Activation for Quantized Neural Networks ) This method is similar to DoReFa, but the upper clipping values, \\alpha , of the activation functions are learned parameters instead of hard coded to 1. Note that per the paper's recommendation, \\alpha is shared per layer. This method requires training the model with quantization-aware training, as discussed here . 
Use the PACTQuantizer class to transform an existing model to a model suitable for training with quantization using PACT.", "title": "PACT" - }, + }, { - "location": "/algo_quantization/index.html#wrpn", - "text": "(As proposed in WRPN: Wide Reduced-Precision Networks ) In this method, activations are clipped to [0, 1] and quantized as follows ( k is the number of bits used for quantization): x_q = \\frac{1}{2^k-1} round \\left( \\left(2^k - 1 \\right) x_f \\right) Weights are clipped to [-1, 1] and quantized as follows: w_q = \\frac{1}{2^{k-1}-1} round \\left( \\left(2^{k-1} - 1 \\right)w_f \\right) Note that k-1 bits are used to quantize weights, leaving one bit for sign. This method requires training the model with quantization, as discussed here . Use the WRPNQuantizer class to transform an existing model to a model suitable for training with quantization using WRPN.", + "location": "/algo_quantization/index.html#wrpn", + "text": "(As proposed in WRPN: Wide Reduced-Precision Networks ) In this method, activations are clipped to [0, 1] and quantized as follows ( k is the number of bits used for quantization): x_q = \\frac{1}{2^k-1} round \\left( \\left(2^k - 1 \\right) x_f \\right) Weights are clipped to [-1, 1] and quantized as follows: w_q = \\frac{1}{2^{k-1}-1} round \\left( \\left(2^{k-1} - 1 \\right)w_f \\right) Note that k-1 bits are used to quantize weights, leaving one bit for sign. This method requires training the model with quantization-aware training, as discussed here . Use the WRPNQuantizer class to transform an existing model to a model suitable for training with quantization using WRPN.", "title": "WRPN" - }, + }, { - "location": "/algo_quantization/index.html#notes_1", - "text": "The paper proposed widening of layers as a means to reduce accuracy loss. This isn't implemented as part of WRPNQuantizer at the moment. To experiment with this, modify your model implementation to have wider layers. The paper defines special handling for binary weights which isn't supported in Distiller yet.", + "location": "/algo_quantization/index.html#notes_1", + "text": "The paper proposed widening of layers as a means to reduce accuracy loss. This isn't implemented as part of WRPNQuantizer at the moment. To experiment with this, modify your model implementation to have wider layers. The paper defines special handling for binary weights which isn't supported in Distiller yet.", "title": "Notes:" - }, - { - "location": "/algo_quantization/index.html#symmetric-linear-quantization", - "text": "In this method, a float value is quantized by multiplying with a numeric constant (the scale factor ), hence it is Linear . We use a signed integer to represent the quantized range, with no quantization bias (or \"offset\") used. As a result, the floating-point range considered for quantization is symmetric with respect to zero. \nIn the current implementation the scale factor is chosen so that the entire range of the floating-point tensor is quantized (we do not attempt to remove outliers). \nLet us denote the original floating-point tensor by x_f , the quantized tensor by x_q , the scale factor by q_x and the number of bits used for quantization by n . 
Then, we get: q_x = \\frac{2^{n-1}-1}{\\max|x|} x_q = round(q_x x_f) \n(The round operation is round-to-nearest-integer) Let's see how a convolution or fully-connected (FC) layer is quantized using this method: (we denote input, output, weights and bias with x, y, w and b respectively) y_f = \\sum{x_f w_f} + b_f = \\sum{\\frac{x_q}{q_x} \\frac{w_q}{q_w}} + \\frac{b_q}{q_b} = \\frac{1}{q_x q_w} \\left( \\sum { x_q w_q + \\frac{q_x q_w}{q_b}b_q } \\right) y_q = round(q_y y_f) = round\\left(\\frac{q_y}{q_x q_w} \\left( \\sum { x_q w_q + \\frac{q_x q_w}{q_b}b_q } \\right) \\right) \nNote how the bias has to be re-scaled to match the scale of the summation.", - "title": "Symmetric Linear Quantization" - }, - { - "location": "/algo_quantization/index.html#implementation", - "text": "We've implemented convolution and FC using this method. They are implemented by wrapping the existing PyTorch layers with quantization and de-quantization operations. That is - the computation is done on floating-point tensors, but the values themselves are restricted to integer values. The wrapper is implemented in the RangeLinearQuantParamLayerWrapper class. All other layers are unaffected and are executed using their original FP32 implementation. To automatically transform an existing model to a quantized model using this method, use the SymmetricLinearQuantizer class. For weights and bias the scale factor is determined once at quantization setup (\"offline\"), and for activations it is determined dynamically at runtime (\"online\"). Important note: Currently, this method is implemented as inference only , with no back-propagation functionality. Hence, it can only be used to quantize a pre-trained FP32 model, with no re-training. As such, using it with n < 8 is likely to lead to severe accuracy degradation for any non-trivial workload.", - "title": "Implementation" - }, + }, { - "location": "/algo_earlyexit/index.html", - "text": "Early Exit Inference\n\n\nWhile Deep Neural Networks benefit from a large number of layers, it's often the case that many data points in classification tasks can be classified accurately with much less work. There have been several studies recently regarding the idea of exiting before the normal endpoint of the neural network. Panda et al in \nConditional Deep Learning for Energy-Efficient and Enhanced Pattern Recognition\n points out that a lot of data points can be classified easily and require less processing than some more difficult points and they view this in terms of power savings. Surat et al in \nBranchyNet: Fast Inference via Early Exiting from Deep Neural Networks\n look at a selective approach to exit placement and criteria for exiting early.\n\n\nWhy Does Early Exit Work?\n\n\nEarly Exit is a strategy with a straightforward and easy to understand concept Figure #fig(boundaries) shows a simple example in a 2-D feature space. While deep networks can represent more complex and expressive boundaries between classes (assuming we\u2019re confident of avoiding over-fitting the data), it\u2019s also clear that much of the data can be properly classified with even the simplest of classification boundaries.\n\n\n\n\nData points far from the boundary can be considered \"easy to classify\" and achieve a high degree of confidence quicker than do data points close to the boundary. 
In fact, we can think of the area between the outer straight lines as being the region that is \"difficult to classify\" and require the full expressiveness of the neural network to accurately classify it.\n\n\nExample code for Early Exit\n\n\nBoth CIFAR10 and ImageNet code comes directly from publically available examples from Pytorch. The only edits are the exits that are inserted in a methodology similar to BranchyNet work.\n\n\nDeeper networks can benefit from multiple exits. Our examples illustrate both a single and a pair of early exits for CIFAR10 and ImageNet, respectively.\n\n\nNote that this code does not actually take exits. What it does is to compute statistics of loss and accuracy assuming exits were taken when criteria are met. Actually implementing exits can be tricky and architecture dependent and we plan to address these issues.\n\n\nHeuristics\n\n\nThe insertion of the exits are ad-hoc, but there are some heuristic principals guiding their placement and parameters. The earlier exits are placed, the more agressive the exit as it essentially prunes the rest of the network at a very early stage, thus saving a lot of work. However, a diminishing percentage of data will be directed through the exit if we are to preserve accuracy.\n\n\nThere are other benefits to adding exits in that training the modified network now has backpropagation losses coming from the exits that affect the earlier layers more substantially than the last exit. This effect mitigates problems such as vanishing gradient.\n\n\nEarly Exit Hyperparameters\n\n\nThere are two parameters that are required to enable early exit. Leave them undefined if you are not enabling Early Exit:\n\n\n\n\n\n\n--earlyexit_thresholds\n defines the\nthresholds for each of the early exits. The cross entropy measure must be \nless than\n the specified threshold to take a specific exit, otherwise the data continues along the regular path. For example, you could specify \"--earlyexit_thresholds 0.9 1.2\" and this implies two early exits with corresponding thresholds of 0.9 and 1.2, respectively to take those exits.\n\n\n\n\n\n\n--earlyexit_lossweights\n provide the weights for the linear combination of losses during training to compute a signle, overall loss. We only specify weights for the early exits and assume that the sum of the weights (including final exit) are equal to 1.0. So an example of \"--earlyexit_lossweights 0.2 0.3\" implies two early exits weighted with values of 0.2 and 0.3, respectively and that the final exit has a value of 1.0-(0.2+0.3) = 0.5. Studies have shown that weighting the early exits more heavily will create more agressive early exits, but perhaps with a slight negative effect on accuracy.\n\n\n\n\n\n\nCIFAR10\n\n\nIn the case of CIFAR10, we have inserted a single exit after the first full layer grouping. The layers on the exit path itself includes a convolutional layer and a fully connected layer. If you move the exit, be sure to match the proper sizes for inputs and outputs to the exit layers.\n\n\nImageNet\n\n\nThis supports training and inference of the ImageNet dataset via several well known deep architectures. ResNet-50 is the architecture of interest in this study, however the exit is defined in the generic resnet code and could be used with other size resnets. There are two exits inserted in this example. 
Again, exit layers must have their sizes match properly.\n\n\nReferences\n\n\n \nPriyadarshini Panda, Abhronil Sengupta, Kaushik Roy\n.\n \nConditional Deep Learning for Energy-Efficient and Enhanced Pattern Recognition\n, arXiv:1509.08971v6, 2017.\n\n\n\n\n\nSurat Teerapittayanon, Bradley McDanel, H. T. Kung\n.\n \nBranchyNet: Fast Inference via Early Exiting from Deep Neural Networks\n, arXiv:1709.01686, 2017.", + "location": "/algo_earlyexit/index.html", + "text": "Early Exit Inference\n\n\nWhile Deep Neural Networks benefit from a large number of layers, it's often the case that many data points in classification tasks can be classified accurately with much less work. There have been several studies recently regarding the idea of exiting before the normal endpoint of the neural network. Panda et al in \nConditional Deep Learning for Energy-Efficient and Enhanced Pattern Recognition\n points out that a lot of data points can be classified easily and require less processing than some more difficult points and they view this in terms of power savings. Surat et al in \nBranchyNet: Fast Inference via Early Exiting from Deep Neural Networks\n look at a selective approach to exit placement and criteria for exiting early.\n\n\nWhy Does Early Exit Work?\n\n\nEarly Exit is a strategy with a straightforward and easy to understand concept Figure #fig(boundaries) shows a simple example in a 2-D feature space. While deep networks can represent more complex and expressive boundaries between classes (assuming we\u2019re confident of avoiding over-fitting the data), it\u2019s also clear that much of the data can be properly classified with even the simplest of classification boundaries.\n\n\n\n\nData points far from the boundary can be considered \"easy to classify\" and achieve a high degree of confidence quicker than do data points close to the boundary. In fact, we can think of the area between the outer straight lines as being the region that is \"difficult to classify\" and require the full expressiveness of the neural network to accurately classify it.\n\n\nExample code for Early Exit\n\n\nBoth CIFAR10 and ImageNet code comes directly from publically available examples from Pytorch. The only edits are the exits that are inserted in a methodology similar to BranchyNet work.\n\n\nDeeper networks can benefit from multiple exits. Our examples illustrate both a single and a pair of early exits for CIFAR10 and ImageNet, respectively.\n\n\nNote that this code does not actually take exits. What it does is to compute statistics of loss and accuracy assuming exits were taken when criteria are met. Actually implementing exits can be tricky and architecture dependent and we plan to address these issues.\n\n\nHeuristics\n\n\nThe insertion of the exits are ad-hoc, but there are some heuristic principals guiding their placement and parameters. The earlier exits are placed, the more agressive the exit as it essentially prunes the rest of the network at a very early stage, thus saving a lot of work. However, a diminishing percentage of data will be directed through the exit if we are to preserve accuracy.\n\n\nThere are other benefits to adding exits in that training the modified network now has backpropagation losses coming from the exits that affect the earlier layers more substantially than the last exit. This effect mitigates problems such as vanishing gradient.\n\n\nEarly Exit Hyperparameters\n\n\nThere are two parameters that are required to enable early exit. 
Leave them undefined if you are not enabling Early Exit:\n\n\n\n\n\n\n--earlyexit_thresholds\n defines the\nthresholds for each of the early exits. The cross entropy measure must be \nless than\n the specified threshold to take a specific exit, otherwise the data continues along the regular path. For example, you could specify \"--earlyexit_thresholds 0.9 1.2\" and this implies two early exits with corresponding thresholds of 0.9 and 1.2, respectively to take those exits.\n\n\n\n\n\n\n--earlyexit_lossweights\n provide the weights for the linear combination of losses during training to compute a signle, overall loss. We only specify weights for the early exits and assume that the sum of the weights (including final exit) are equal to 1.0. So an example of \"--earlyexit_lossweights 0.2 0.3\" implies two early exits weighted with values of 0.2 and 0.3, respectively and that the final exit has a value of 1.0-(0.2+0.3) = 0.5. Studies have shown that weighting the early exits more heavily will create more agressive early exits, but perhaps with a slight negative effect on accuracy.\n\n\n\n\n\n\nCIFAR10\n\n\nIn the case of CIFAR10, we have inserted a single exit after the first full layer grouping. The layers on the exit path itself includes a convolutional layer and a fully connected layer. If you move the exit, be sure to match the proper sizes for inputs and outputs to the exit layers.\n\n\nImageNet\n\n\nThis supports training and inference of the ImageNet dataset via several well known deep architectures. ResNet-50 is the architecture of interest in this study, however the exit is defined in the generic resnet code and could be used with other size resnets. There are two exits inserted in this example. Again, exit layers must have their sizes match properly.\n\n\nReferences\n\n\n \nPriyadarshini Panda, Abhronil Sengupta, Kaushik Roy\n.\n \nConditional Deep Learning for Energy-Efficient and Enhanced Pattern Recognition\n, arXiv:1509.08971v6, 2017.\n\n\n\n\n\nSurat Teerapittayanon, Bradley McDanel, H. T. Kung\n.\n \nBranchyNet: Fast Inference via Early Exiting from Deep Neural Networks\n, arXiv:1709.01686, 2017.", "title": "Early Exit" - }, + }, { - "location": "/algo_earlyexit/index.html#early-exit-inference", - "text": "While Deep Neural Networks benefit from a large number of layers, it's often the case that many data points in classification tasks can be classified accurately with much less work. There have been several studies recently regarding the idea of exiting before the normal endpoint of the neural network. Panda et al in Conditional Deep Learning for Energy-Efficient and Enhanced Pattern Recognition points out that a lot of data points can be classified easily and require less processing than some more difficult points and they view this in terms of power savings. Surat et al in BranchyNet: Fast Inference via Early Exiting from Deep Neural Networks look at a selective approach to exit placement and criteria for exiting early.", + "location": "/algo_earlyexit/index.html#early-exit-inference", + "text": "While Deep Neural Networks benefit from a large number of layers, it's often the case that many data points in classification tasks can be classified accurately with much less work. There have been several studies recently regarding the idea of exiting before the normal endpoint of the neural network. 
Panda et al in Conditional Deep Learning for Energy-Efficient and Enhanced Pattern Recognition points out that a lot of data points can be classified easily and require less processing than some more difficult points and they view this in terms of power savings. Surat et al in BranchyNet: Fast Inference via Early Exiting from Deep Neural Networks look at a selective approach to exit placement and criteria for exiting early.", "title": "Early Exit Inference" - }, + }, { - "location": "/algo_earlyexit/index.html#why-does-early-exit-work", - "text": "Early Exit is a strategy with a straightforward and easy to understand concept Figure #fig(boundaries) shows a simple example in a 2-D feature space. While deep networks can represent more complex and expressive boundaries between classes (assuming we\u2019re confident of avoiding over-fitting the data), it\u2019s also clear that much of the data can be properly classified with even the simplest of classification boundaries. Data points far from the boundary can be considered \"easy to classify\" and achieve a high degree of confidence quicker than do data points close to the boundary. In fact, we can think of the area between the outer straight lines as being the region that is \"difficult to classify\" and require the full expressiveness of the neural network to accurately classify it.", + "location": "/algo_earlyexit/index.html#why-does-early-exit-work", + "text": "Early Exit is a strategy with a straightforward and easy to understand concept Figure #fig(boundaries) shows a simple example in a 2-D feature space. While deep networks can represent more complex and expressive boundaries between classes (assuming we\u2019re confident of avoiding over-fitting the data), it\u2019s also clear that much of the data can be properly classified with even the simplest of classification boundaries. Data points far from the boundary can be considered \"easy to classify\" and achieve a high degree of confidence quicker than do data points close to the boundary. In fact, we can think of the area between the outer straight lines as being the region that is \"difficult to classify\" and require the full expressiveness of the neural network to accurately classify it.", "title": "Why Does Early Exit Work?" - }, + }, { - "location": "/algo_earlyexit/index.html#example-code-for-early-exit", - "text": "Both CIFAR10 and ImageNet code comes directly from publically available examples from Pytorch. The only edits are the exits that are inserted in a methodology similar to BranchyNet work. Deeper networks can benefit from multiple exits. Our examples illustrate both a single and a pair of early exits for CIFAR10 and ImageNet, respectively. Note that this code does not actually take exits. What it does is to compute statistics of loss and accuracy assuming exits were taken when criteria are met. Actually implementing exits can be tricky and architecture dependent and we plan to address these issues.", + "location": "/algo_earlyexit/index.html#example-code-for-early-exit", + "text": "Both CIFAR10 and ImageNet code comes directly from publically available examples from Pytorch. The only edits are the exits that are inserted in a methodology similar to BranchyNet work. Deeper networks can benefit from multiple exits. Our examples illustrate both a single and a pair of early exits for CIFAR10 and ImageNet, respectively. Note that this code does not actually take exits. What it does is to compute statistics of loss and accuracy assuming exits were taken when criteria are met. 
Actually implementing exits can be tricky and architecture dependent and we plan to address these issues.", "title": "Example code for Early Exit" - }, + }, { - "location": "/algo_earlyexit/index.html#heuristics", - "text": "The insertion of the exits are ad-hoc, but there are some heuristic principals guiding their placement and parameters. The earlier exits are placed, the more agressive the exit as it essentially prunes the rest of the network at a very early stage, thus saving a lot of work. However, a diminishing percentage of data will be directed through the exit if we are to preserve accuracy. There are other benefits to adding exits in that training the modified network now has backpropagation losses coming from the exits that affect the earlier layers more substantially than the last exit. This effect mitigates problems such as vanishing gradient.", + "location": "/algo_earlyexit/index.html#heuristics", + "text": "The insertion of the exits are ad-hoc, but there are some heuristic principals guiding their placement and parameters. The earlier exits are placed, the more agressive the exit as it essentially prunes the rest of the network at a very early stage, thus saving a lot of work. However, a diminishing percentage of data will be directed through the exit if we are to preserve accuracy. There are other benefits to adding exits in that training the modified network now has backpropagation losses coming from the exits that affect the earlier layers more substantially than the last exit. This effect mitigates problems such as vanishing gradient.", "title": "Heuristics" - }, + }, { - "location": "/algo_earlyexit/index.html#early-exit-hyperparameters", - "text": "There are two parameters that are required to enable early exit. Leave them undefined if you are not enabling Early Exit: --earlyexit_thresholds defines the\nthresholds for each of the early exits. The cross entropy measure must be less than the specified threshold to take a specific exit, otherwise the data continues along the regular path. For example, you could specify \"--earlyexit_thresholds 0.9 1.2\" and this implies two early exits with corresponding thresholds of 0.9 and 1.2, respectively to take those exits. --earlyexit_lossweights provide the weights for the linear combination of losses during training to compute a signle, overall loss. We only specify weights for the early exits and assume that the sum of the weights (including final exit) are equal to 1.0. So an example of \"--earlyexit_lossweights 0.2 0.3\" implies two early exits weighted with values of 0.2 and 0.3, respectively and that the final exit has a value of 1.0-(0.2+0.3) = 0.5. Studies have shown that weighting the early exits more heavily will create more agressive early exits, but perhaps with a slight negative effect on accuracy.", + "location": "/algo_earlyexit/index.html#early-exit-hyperparameters", + "text": "There are two parameters that are required to enable early exit. Leave them undefined if you are not enabling Early Exit: --earlyexit_thresholds defines the\nthresholds for each of the early exits. The cross entropy measure must be less than the specified threshold to take a specific exit, otherwise the data continues along the regular path. For example, you could specify \"--earlyexit_thresholds 0.9 1.2\" and this implies two early exits with corresponding thresholds of 0.9 and 1.2, respectively to take those exits. 
--earlyexit_lossweights provide the weights for the linear combination of losses during training to compute a signle, overall loss. We only specify weights for the early exits and assume that the sum of the weights (including final exit) are equal to 1.0. So an example of \"--earlyexit_lossweights 0.2 0.3\" implies two early exits weighted with values of 0.2 and 0.3, respectively and that the final exit has a value of 1.0-(0.2+0.3) = 0.5. Studies have shown that weighting the early exits more heavily will create more agressive early exits, but perhaps with a slight negative effect on accuracy.", "title": "Early Exit Hyperparameters" - }, + }, { - "location": "/algo_earlyexit/index.html#cifar10", - "text": "In the case of CIFAR10, we have inserted a single exit after the first full layer grouping. The layers on the exit path itself includes a convolutional layer and a fully connected layer. If you move the exit, be sure to match the proper sizes for inputs and outputs to the exit layers.", + "location": "/algo_earlyexit/index.html#cifar10", + "text": "In the case of CIFAR10, we have inserted a single exit after the first full layer grouping. The layers on the exit path itself includes a convolutional layer and a fully connected layer. If you move the exit, be sure to match the proper sizes for inputs and outputs to the exit layers.", "title": "CIFAR10" - }, + }, { - "location": "/algo_earlyexit/index.html#imagenet", - "text": "This supports training and inference of the ImageNet dataset via several well known deep architectures. ResNet-50 is the architecture of interest in this study, however the exit is defined in the generic resnet code and could be used with other size resnets. There are two exits inserted in this example. Again, exit layers must have their sizes match properly.", + "location": "/algo_earlyexit/index.html#imagenet", + "text": "This supports training and inference of the ImageNet dataset via several well known deep architectures. ResNet-50 is the architecture of interest in this study, however the exit is defined in the generic resnet code and could be used with other size resnets. There are two exits inserted in this example. Again, exit layers must have their sizes match properly.", "title": "ImageNet" - }, + }, { - "location": "/algo_earlyexit/index.html#references", - "text": "Priyadarshini Panda, Abhronil Sengupta, Kaushik Roy .\n Conditional Deep Learning for Energy-Efficient and Enhanced Pattern Recognition , arXiv:1509.08971v6, 2017. Surat Teerapittayanon, Bradley McDanel, H. T. Kung .\n BranchyNet: Fast Inference via Early Exiting from Deep Neural Networks , arXiv:1709.01686, 2017.", + "location": "/algo_earlyexit/index.html#references", + "text": "Priyadarshini Panda, Abhronil Sengupta, Kaushik Roy .\n Conditional Deep Learning for Energy-Efficient and Enhanced Pattern Recognition , arXiv:1509.08971v6, 2017. Surat Teerapittayanon, Bradley McDanel, H. T. Kung .\n BranchyNet: Fast Inference via Early Exiting from Deep Neural Networks , arXiv:1709.01686, 2017.", "title": "References" - }, + }, { - "location": "/model_zoo/index.html", - "text": "Distiller Model Zoo\n\n\nHow to contribute models to the Model Zoo\n\n\nWe encourage you to contribute new models to the Model Zoo. We welcome implementations of published papers or of your own work. 
To assure that models and algorithms shared with others are high-quality, please commit your models with the following:\n\n\n\n\nCommand-line arguments\n\n\nLog files\n\n\nPyTorch model\n\n\n\n\nContents\n\n\nThe Distiller model zoo is not a \"traditional\" model-zoo, because it does not necessarily contain best-in-class compressed models. Instead, the model-zoo contains a number of deep learning models that have been compressed using Distiller following some well-known research papers. These are meant to serve as examples of how Distiller can be used.\n\n\nEach model contains a Distiller schedule detailing how the model was compressed, a PyTorch checkpoint, text logs and TensorBoard logs.\n\n\n\n\ntable, th, td {\n border: 1px solid black;\n}\n\n\n\n\n \n\n \nPaper\n\n \nDataset\n\n \nNetwork\n\n \nMethod & Granularity\n\n \nSchedule\n\n \nFeatures\n\n \n\n \n\n \nLearning both Weights and Connections for Efficient Neural Networks\n\n \nImageNet\n\n \nAlexnet\n\n \nElement-wise pruning\n\n \nIterative; Manual\n\n \nMagnitude thresholding based on a sensitivity quantifier.\nElement-wise sparsity sensitivity analysis\n\n \n\n \n\n \nTo prune, or not to prune: exploring the efficacy of pruning for model compression\n\n \nImageNet\n\n \nMobileNet\n\n \nElement-wise pruning\n\n \nAutomated gradual; Iterative\n\n \nMagnitude thresholding based on target level\n\n \n\n \n\n \nLearning Structured Sparsity in Deep Neural Networks\n\n \nCIFAR10\n\n \nResNet20\n\n \nGroup regularization\n\n \n1.Train with group-lasso\n2.Remove zero groups and fine-tune\n\n \nGroup Lasso regularization. Groups: kernels (2D), channels, filters (3D), layers (4D), vectors (rows, cols)\n\n \n\n \n\n \nPruning Filters for Efficient ConvNets\n\n \nCIFAR10\n\n \nResNet56\n\n \nFilter ranking; guided by sensitivity analysis\n\n \n1.Rank filters\n2. Remove filters and channels\n3.Fine-tune\n\n \nOne-shot ranking and pruning of filters; with network thinning\n \n\n\n\n\nLearning both Weights and Connections for Efficient Neural Networks\n\n\nThis schedule is an example of \"Iterative Pruning\" for Alexnet/Imagent, as described in chapter 3 of Song Han's PhD dissertation: \nEfficient Methods and Hardware for Deep Learning\n and in his paper \nLearning both Weights and Connections for Efficient Neural Networks\n. \n\n\nThe Distiller schedule uses SensitivityPruner which is similar to MagnitudeParameterPruner, but instead of specifying \"raw\" thresholds, it uses a \"sensitivity parameter\". Song Han's paper says that \"the pruning threshold is chosen as a quality parameter multiplied by the standard deviation of a layers weights,\" and this is not explained much further. In Distiller, the \"quality parameter\" is referred to as \"sensitivity\" and\nis based on the values learned from performing sensitivity analysis. Using a parameter that is related to the standard deviation is very helpful: under the assumption that the weights tensors are distributed normally, the standard deviation acts as a threshold normalizer.\n\n\nNote that Distiller's implementation deviates slightly from the algorithm Song Han describes in his PhD dissertation, in that the threshold value is set only once. In his PhD dissertation, Song Han describes a growing threshold, at each iteration. This requires n+1 hyper-parameters (n being the number of pruning iterations we use): the threshold and the threshold increase (delta) at each pruning iteration. 
Distiller's implementation takes advantage of the fact that as pruning progresses, more weights are pulled toward zero, and therefore the threshold \"traps\" more weights. Thus, we can use less hyper-parameters and achieve the same results.\n\n\n\n\nDistiller schedule: \ndistiller/examples/sensitivity-pruning/alexnet.schedule_sensitivity.yaml\n\n\nCheckpoint file: \nalexnet.checkpoint.89.pth.tar\n\n\n\n\nResults\n\n\nOur reference is TorchVision's pretrained Alexnet model which has a Top1 accuracy of 56.55 and Top5=79.09. We prune away 88.44% of the parameters and achieve Top1=56.61 and Top5=79.45.\nSong Han prunes 89% of the parameters, which is slightly better than our results.\n\n\nParameters:\n+----+---------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+\n| | Name | Shape | NNZ (dense) | NNZ (sparse) | Cols (%) | Rows (%) | Ch (%) | 2D (%) | 3D (%) | Fine (%) | Std | Mean | Abs-Mean\n|----+---------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|\n| 0 | features.module.0.weight | (64, 3, 11, 11) | 23232 | 13411 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 42.27359 | 0.14391 | -0.00002 | 0.08805 |\n| 1 | features.module.3.weight | (192, 64, 5, 5) | 307200 | 115560 | 0.00000 | 0.00000 | 0.00000 | 1.91243 | 0.00000 | 62.38281 | 0.04703 | -0.00250 | 0.02289 |\n| 2 | features.module.6.weight | (384, 192, 3, 3) | 663552 | 256565 | 0.00000 | 0.00000 | 0.00000 | 6.18490 | 0.00000 | 61.33445 | 0.03354 | -0.00184 | 0.01803 |\n| 3 | features.module.8.weight | (256, 384, 3, 3) | 884736 | 315065 | 0.00000 | 0.00000 | 0.00000 | 6.96411 | 0.00000 | 64.38881 | 0.02646 | -0.00168 | 0.01422 |\n| 4 | features.module.10.weight | (256, 256, 3, 3) | 589824 | 186938 | 0.00000 | 0.00000 | 0.00000 | 15.49225 | 0.00000 | 68.30614 | 0.02714 | -0.00246 | 0.01409 |\n| 5 | classifier.1.weight | (4096, 9216) | 37748736 | 3398881 | 0.00000 | 0.21973 | 0.00000 | 0.21973 | 0.00000 | 90.99604 | 0.00589 | -0.00020 | 0.00168 |\n| 6 | classifier.4.weight | (4096, 4096) | 16777216 | 1782769 | 0.21973 | 3.46680 | 0.00000 | 3.46680 | 0.00000 | 89.37387 | 0.00849 | -0.00066 | 0.00263 |\n| 7 | classifier.6.weight | (1000, 4096) | 4096000 | 994738 | 3.36914 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 75.71440 | 0.01718 | 0.00030 | 0.00778 |\n| 8 | Total sparsity: | - | 61090496 | 7063928 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 88.43694 | 0.00000 | 0.00000 | 0.00000 |\n+----+---------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+\n 2018-04-04 21:30:52,499 - Total sparsity: 88.44\n\n 2018-04-04 21:30:52,499 - --- validate (epoch=89)-----------\n 2018-04-04 21:30:52,499 - 128116 samples (256 per mini-batch)\n 2018-04-04 21:31:35,357 - ==> Top1: 51.838 Top5: 74.817 Loss: 2.150\n\n 2018-04-04 21:31:39,251 - --- test ---------------------\n 2018-04-04 21:31:39,252 - 50000 samples (256 per mini-batch)\n 2018-04-04 21:32:01,274 - ==> Top1: 56.606 Top5: 79.446 Loss: 1.893\n\n\n\n\nTo prune, or not to prune: exploring the efficacy of pruning for model compression\n\n\nIn their paper Zhu and Gupta, \"compare the accuracy of large, but pruned models (large-sparse) and their\nsmaller, but dense (small-dense) counterparts with identical 
memory footprint.\"\nThey also \"propose a new gradual pruning technique that is simple and straightforward to apply across a variety of models/datasets with\nminimal tuning.\"\n\n\nThis pruning schedule is implemented by distiller.AutomatedGradualPruner, which increases the sparsity level (expressed as a percentage of zero-valued elements) gradually over several pruning steps. Distiller's implementation only prunes elements once in an epoch (the model is fine-tuned in between pruning events), which is a small deviation from Zhu and Gupta's paper. The research paper specifies the schedule in terms of mini-batches, while our implementation specifies the schedule in terms of epochs. We feel that using epochs performs well, and is more \"stable\", since the number of mini-batches will change, if you change the batch size.\n\n\nImageNet files:\n\n\n\n\nDistiller schedule: \ndistiller/examples/agp-pruning/mobilenet.imagenet.schedule_agp.yaml\n\n\nCheckpoint file: \ncheckpoint.pth.tar\n\n\n\n\nResNet18 files:\n\n\n\n\nDistiller schedule: \ndistiller/examples/agp-pruning/resnet18.schedule_agp.yaml\n\n\nCheckpoint file: \ncheckpoint.pth.tar\n\n\n\n\nResults\n\n\nAs our baseline we used a \npretrained PyTorch MobileNet model\n (width=1) which has Top1=68.848 and Top5=88.740.\n\nIn their paper, Zhu and Gupta prune 50% of the elements of MobileNet (width=1) with a 1.1% drop in accuracy. We pruned about 51.6% of the elements, with virtually no change in the accuracies (Top1: 68.808 and Top5: 88.656). We didn't try to prune more than this, but we do note that the baseline accuracy that we used is almost 2% lower than the accuracy published in the paper. \n\n\n+----+--------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+\n| | Name | Shape | NNZ (dense) | NNZ (sparse) | Cols (%) | Rows (%) | Ch (%) | 2D (%) | 3D (%) | Fine (%) | Std | Mean | Abs-Mean |\n|----+--------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|\n| 0 | module.model.0.0.weight | (32, 3, 3, 3) | 864 | 864 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.14466 | 0.00103 | 0.06508 |\n| 1 | module.model.1.0.weight | (32, 1, 3, 3) | 288 | 288 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.32146 | 0.01020 | 0.12932 |\n| 2 | module.model.1.3.weight | (64, 32, 1, 1) | 2048 | 2048 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.11942 | 0.00024 | 0.03627 |\n| 3 | module.model.2.0.weight | (64, 1, 3, 3) | 576 | 576 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.15809 | 0.00543 | 0.11513 |\n| 4 | module.model.2.3.weight | (128, 64, 1, 1) | 8192 | 8192 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.08442 | -0.00031 | 0.04182 |\n| 5 | module.model.3.0.weight | (128, 1, 3, 3) | 1152 | 1152 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.16780 | 0.00125 | 0.10545 |\n| 6 | module.model.3.3.weight | (128, 128, 1, 1) | 16384 | 16384 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.07126 | -0.00197 | 0.04123 |\n| 7 | module.model.4.0.weight | (128, 1, 3, 3) | 1152 | 1152 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.10182 | 0.00171 | 0.08719 |\n| 8 | module.model.4.3.weight | (256, 128, 1, 1) | 32768 | 13108 | 0.00000 | 0.00000 | 10.15625 | 59.99756 | 12.50000 | 
59.99756 | 0.05543 | -0.00002 | 0.02760 |\n| 9 | module.model.5.0.weight | (256, 1, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.12516 | -0.00288 | 0.08058 |\n| 10 | module.model.5.3.weight | (256, 256, 1, 1) | 65536 | 26215 | 0.00000 | 0.00000 | 12.50000 | 59.99908 | 23.82812 | 59.99908 | 0.04453 | 0.00002 | 0.02271 |\n| 11 | module.model.6.0.weight | (256, 1, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.08024 | 0.00252 | 0.06377 |\n| 12 | module.model.6.3.weight | (512, 256, 1, 1) | 131072 | 52429 | 0.00000 | 0.00000 | 23.82812 | 59.99985 | 14.25781 | 59.99985 | 0.03561 | -0.00057 | 0.01779 |\n| 13 | module.model.7.0.weight | (512, 1, 3, 3) | 4608 | 4608 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.11008 | -0.00018 | 0.06829 |\n| 14 | module.model.7.3.weight | (512, 512, 1, 1) | 262144 | 104858 | 0.00000 | 0.00000 | 14.25781 | 59.99985 | 21.28906 | 59.99985 | 0.02944 | -0.00060 | 0.01515 |\n| 15 | module.model.8.0.weight | (512, 1, 3, 3) | 4608 | 4608 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.08258 | 0.00370 | 0.04905 |\n| 16 | module.model.8.3.weight | (512, 512, 1, 1) | 262144 | 104858 | 0.00000 | 0.00000 | 21.28906 | 59.99985 | 28.51562 | 59.99985 | 0.02865 | -0.00046 | 0.01465 |\n| 17 | module.model.9.0.weight | (512, 1, 3, 3) | 4608 | 4608 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.07578 | 0.00468 | 0.04201 |\n| 18 | module.model.9.3.weight | (512, 512, 1, 1) | 262144 | 104858 | 0.00000 | 0.00000 | 28.51562 | 59.99985 | 23.43750 | 59.99985 | 0.02939 | -0.00044 | 0.01511 |\n| 19 | module.model.10.0.weight | (512, 1, 3, 3) | 4608 | 4608 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.07091 | 0.00014 | 0.04306 |\n| 20 | module.model.10.3.weight | (512, 512, 1, 1) | 262144 | 104858 | 0.00000 | 0.00000 | 24.60938 | 59.99985 | 20.89844 | 59.99985 | 0.03095 | -0.00059 | 0.01672 |\n| 21 | module.model.11.0.weight | (512, 1, 3, 3) | 4608 | 4608 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.05729 | -0.00518 | 0.04267 |\n| 22 | module.model.11.3.weight | (512, 512, 1, 1) | 262144 | 104858 | 0.00000 | 0.00000 | 20.89844 | 59.99985 | 17.57812 | 59.99985 | 0.03229 | -0.00044 | 0.01797 |\n| 23 | module.model.12.0.weight | (512, 1, 3, 3) | 4608 | 4608 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.04981 | -0.00136 | 0.03967 |\n| 24 | module.model.12.3.weight | (1024, 512, 1, 1) | 524288 | 209716 | 0.00000 | 0.00000 | 16.01562 | 59.99985 | 44.23828 | 59.99985 | 0.02514 | -0.00106 | 0.01278 |\n| 25 | module.model.13.0.weight | (1024, 1, 3, 3) | 9216 | 9216 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02396 | -0.00949 | 0.01549 |\n| 26 | module.model.13.3.weight | (1024, 1024, 1, 1) | 1048576 | 419431 | 0.00000 | 0.00000 | 44.72656 | 59.99994 | 1.46484 | 59.99994 | 0.01801 | -0.00017 | 0.00931 |\n| 27 | module.fc.weight | (1000, 1024) | 1024000 | 409600 | 1.46484 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 60.00000 | 0.05078 | 0.00271 | 0.02734 |\n| 28 | Total sparsity: | - | 4209088 | 1726917 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 58.97171 | 0.00000 | 0.00000 | 0.00000 |\n+----+--------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+\nTotal sparsity: 58.97\n\n--- validate (epoch=199)-----------\n128116 samples (256 per mini-batch)\n==> Top1: 65.337 Top5: 
84.984 Loss: 1.494\n\n--- test ---------------------\n50000 samples (256 per mini-batch)\n==> Top1: 68.810 Top5: 88.626 Loss: 1.282\n\n\n\n\n\nLearning Structured Sparsity in Deep Neural Networks\n\n\nThis research paper from the University of Pittsburgh, \"proposes a Structured Sparsity Learning (SSL) method to regularize the structures (i.e., filters, channels, filter shapes, and layer depth) of DNNs. SSL can: (1) learn a compact structure from a bigger DNN to reduce computation cost; (2) obtain a hardware-friendly structured sparsity of DNN to efficiently accelerate the DNN\u2019s evaluation.\"\n\n\nNote that this paper does not use pruning, but instead uses group regularization during the training to force weights towards zero, as a group. We used a schedule which thresholds the regularized elements at a magnitude equal to the regularization strength. At the end of the regularization phase, we save the final sparsity masks generated by the regularization, and exit. Then we load this regularized model, remove the layers corresponding to the zeroed weight tensors (all of a layer's elements have a zero value). \n\n\nBaseline training\n\n\nWe started by training the baseline ResNet20-Cifar dense network since we didn't have a pre-trained model.\n\n\n\n\nDistiller schedule: \ndistiller/examples/ssl/resnet20_cifar_baseline_training.yaml\n\n\nCheckpoint files: \ndistiller/examples/ssl/checkpoints/\n\n\n\n\n$ time python3 compress_classifier.py --arch resnet20_cifar ../data.cifar10 -p=50 --lr=0.3 --epochs=180 --compress=../cifar10/resnet20/baseline_training.yaml -j=1 --deterministic\n\n\n\n\nRegularization\n\n\nThen we started training from scratch again, but this time we used Group Lasso regularization on entire layers:\n\nDistiller schedule: \ndistiller/examples/ssl/ssl_4D-removal_4L_training.yaml\n\n\n$ time python3 compress_classifier.py --arch resnet20_cifar ../data.cifar10 -p=50 --lr=0.4 --epochs=180 --compress=../ssl/ssl_4D-removal_training.yaml -j=1 --deterministic\n\n\n\n\nThe diagram below shows the training of Resnet20/CIFAR10 using Group Lasso regularization on entire layers (in blue) vs. training Resnet20/CIFAR10 baseline (in red). You may notice several interesting things:\n1. The LR-decay policy is the same, but the two sessions start with different initial LR values.\n2. The data-loss of the regularized training follows the same shape as the un-regularized training (baseline), and eventually the two seem to merge.\n3. We see similar behavior in the validation Top1 and Top5 accuracy results, but the regularized training eventually performs better.\n4. In the top right corner we see the behavior of the regularization loss (\nReg Loss\n), which actually increases for some time, until the data-loss has a sharp drop (after ~16K mini-batches), at which point the regularization loss also starts dropping.\n\n\n\nThis \nregularization\n yields 5 layers with zeroed weight tensors. We load this model, remove the 5 layers, and start the fine tuning of the weights. This process of layer removal is specific to ResNet for CIFAR, which we altered by adding code to skip over layers during the forward path. When you export to ONNX, the removed layers do not participate in the forward path, so they don't get incarnated. \n\n\nWe managed to remove 5 of the 16 3x3 convolution layers which dominate the computation time. 
It's not bad, but we probably could have done better.\n\n\nFine-tuning\n\n\nDuring the \nfine-tuning\n process, because the removed layers do not participate in the forward path, they do not appear in the backward path and are not backpropogated: therefore they are completely disconnected from the network.\n\nWe copy the checkpoint file of the regularized model to \ncheckpoint_trained_4D_regularized_5Lremoved.pth.tar\n.\n\nDistiller schedule: \ndistiller/examples/ssl/ssl_4D-removal_finetuning.yaml\n\n\n$ time python3 compress_classifier.py --arch resnet20_cifar ../data.cifar10 -p=50 --lr=0.1 --epochs=250 --resume=../cifar10/resnet20/checkpoint_trained_4D_regularized_5Lremoved.pth.tar --compress=../ssl/ssl_4D-removal_finetuning.yaml -j=1 --deterministic\n\n\n\n\nResults\n\n\nOur baseline results for ResNet20 Cifar are: Top1=91.450 and Top5=99.750\n\n\nWe used Distiller's GroupLassoRegularizer to remove 5 layers from Resnet20 (CIFAR10) with no degradation of the accuracies.\n\nThe regularized model exhibits really poor classification abilities: \n\n\n$ time python3 compress_classifier.py --arch resnet20_cifar ../data.cifar10 -p=50 --resume=../cifar10/resnet20/checkpoint_trained_4D_regularized_5Lremoved.pth.tar --evaluate\n\n=> loading checkpoint ../cifar10/resnet20/checkpoint_trained_4D_regularized_5Lremoved.pth.tar\n best top@1: 90.620\nLoaded compression schedule from checkpoint (epoch 179)\nRemoving layer: module.layer1.0.conv1 [layer=0 block=0 conv=0]\nRemoving layer: module.layer1.0.conv2 [layer=0 block=0 conv=1]\nRemoving layer: module.layer1.1.conv1 [layer=0 block=1 conv=0]\nRemoving layer: module.layer1.1.conv2 [layer=0 block=1 conv=1]\nRemoving layer: module.layer2.2.conv2 [layer=1 block=2 conv=1]\nFiles already downloaded and verified\nFiles already downloaded and verified\nDataset sizes:\n training=45000\n validation=5000\n test=10000\n--- test ---------------------\n10000 samples (256 per mini-batch)\n==> Top1: 22.290 Top5: 68.940 Loss: 5.172\n\n\n\n\nHowever, after fine-tuning, we recovered most of the accuracies loss, but not quite all of it: Top1=91.020 and Top5=99.670\n\n\nWe didn't spend time trying to wrestle with this network, and therefore didn't achieve SSL's published results (which showed that they managed to remove 6 layers and at the same time increase accuracies).\n\n\nPruning Filters for Efficient ConvNets\n\n\nQuoting the authors directly:\n\n\n\n\nWe present an acceleration method for CNNs, where we prune filters from CNNs that are identified as having a small effect on the output accuracy. By removing whole filters in the network together with their connecting feature maps, the computation costs are reduced significantly.\nIn contrast to pruning weights, this approach does not result in sparse connectivity patterns. Hence, it does not need the support of sparse convolution libraries and can work with existing efficient BLAS libraries for dense matrix multiplications.\n\n\n\n\nThe implementation of the research by Hao et al. required us to add filter-pruning sensitivity analysis, and support for \"network thinning\".\n\n\nAfter performing filter-pruning sensitivity analysis to assess which layers are more sensitive to the pruning of filters, we execute distiller.L1RankedStructureParameterPruner once in order to rank the filters of each layer by their L1-norm values, and then we prune the schedule-prescribed sparsity level. 
\n\n\n\n\nDistiller schedule: \ndistiller/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank.yaml\n\n\nCheckpoint files: \ncheckpoint_finetuned.pth.tar\n\n\n\n\nThe excerpt from the schedule, displayed below, shows how we declare the L1RankedStructureParameterPruner. This class currently ranks filters only, but because in the future this class may support ranking of various structures, you need to specify for each parameter both the target sparsity level, and the structure type ('3D' is filter-wise pruning).\n\n\npruners:\n filter_pruner:\n class: 'L1RankedStructureParameterPruner'\n reg_regims:\n 'module.layer1.0.conv1.weight': [0.6, '3D']\n 'module.layer1.1.conv1.weight': [0.6, '3D']\n 'module.layer1.2.conv1.weight': [0.6, '3D']\n 'module.layer1.3.conv1.weight': [0.6, '3D']\n\n\n\n\nIn the policy, we specify that we want to invoke this pruner once, at epoch 180. Because we are starting from a network which was trained for 180 epochs (see Baseline training below), the filter ranking is performed right at the outset of this schedule.\n\n\npolicies:\n - pruner:\n instance_name: filter_pruner\n epochs: [180]\n\n\n\n\n\nFollowing the pruning, we want to \"physically\" remove the pruned filters from the network, which involves reconfiguring the Convolutional layers and the parameter tensors. When we remove filters from Convolution layer \nn\n we need to perform several changes to the network:\n1. Shrink layer \nn\n's weights tensor, leaving only the \"important\" filters.\n2. Configure layer \nn\n's \n.out_channels\n member to its new, smaller, value.\n3. If a BN layer follows layer \nn\n, then it also needs to be reconfigured and its scale and shift parameter vectors need to be shrunk.\n4. If a Convolution layer follows the BN layer, then it will have less input channels which requires reconfiguration and shrinking of its weights.\n\n\nAll of this is performed by distiller.ResnetCifarFilterRemover which is also scheduled at epoch 180. We call this process \"network thinning\".\n\n\nextensions:\n net_thinner:\n class: 'FilterRemover'\n thinning_func_str: remove_filters\n arch: 'resnet56_cifar'\n dataset: 'cifar10'\n\n\n\n\nNetwork thinning requires us to understand the layer connectivity and data-dependency of the DNN, and we are working on a robust method to perform this. On networks with topologies similar to ResNet (residuals) and GoogLeNet (inception), which have several inputs and outputs to/from Convolution layers, there is extra details to consider.\n\nOur current implementation is specific to certain layers in ResNet and is a bit fragile. We will continue to improve and generalize this.\n\n\nBaseline training\n\n\nWe started by training the baseline ResNet56-Cifar dense network (180 epochs) since we didn't have a pre-trained model.\n\n\n\n\nDistiller schedule: \ndistiller/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_baseline_training.yaml\n\n\nCheckpoint files: \ncheckpoint.resnet56_cifar_baseline.pth.tar\n\n\n\n\nResults\n\n\nWe trained a ResNet56-Cifar10 network and achieve accuracy results which are on-par with published results:\nTop1: 92.970 and Top5: 99.740.\n\n\nWe used Hao et al.'s algorithm to remove 37.3% of the original convolution MACs, while maintaining virtually the same accuracy as the baseline:\nTop1: 92.830 and Top5: 99.760", + "location": "/model_zoo/index.html", + "text": "Distiller Model Zoo\n\n\nHow to contribute models to the Model Zoo\n\n\nWe encourage you to contribute new models to the Model Zoo. 
We welcome implementations of published papers or of your own work. To assure that models and algorithms shared with others are high-quality, please commit your models with the following:\n\n\n\n\nCommand-line arguments\n\n\nLog files\n\n\nPyTorch model\n\n\n\n\nContents\n\n\nThe Distiller model zoo is not a \"traditional\" model-zoo, because it does not necessarily contain best-in-class compressed models. Instead, the model-zoo contains a number of deep learning models that have been compressed using Distiller following some well-known research papers. These are meant to serve as examples of how Distiller can be used.\n\n\nEach model contains a Distiller schedule detailing how the model was compressed, a PyTorch checkpoint, text logs and TensorBoard logs.\n\n\n\n\ntable, th, td {\n border: 1px solid black;\n}\n\n\n\n\n \n\n \nPaper\n\n \nDataset\n\n \nNetwork\n\n \nMethod \n Granularity\n\n \nSchedule\n\n \nFeatures\n\n \n\n \n\n \nLearning both Weights and Connections for Efficient Neural Networks\n\n \nImageNet\n\n \nAlexnet\n\n \nElement-wise pruning\n\n \nIterative; Manual\n\n \nMagnitude thresholding based on a sensitivity quantifier.\nElement-wise sparsity sensitivity analysis\n\n \n\n \n\n \nTo prune, or not to prune: exploring the efficacy of pruning for model compression\n\n \nImageNet\n\n \nMobileNet\n\n \nElement-wise pruning\n\n \nAutomated gradual; Iterative\n\n \nMagnitude thresholding based on target level\n\n \n\n \n\n \nLearning Structured Sparsity in Deep Neural Networks\n\n \nCIFAR10\n\n \nResNet20\n\n \nGroup regularization\n\n \n1.Train with group-lasso\n2.Remove zero groups and fine-tune\n\n \nGroup Lasso regularization. Groups: kernels (2D), channels, filters (3D), layers (4D), vectors (rows, cols)\n\n \n\n \n\n \nPruning Filters for Efficient ConvNets\n\n \nCIFAR10\n\n \nResNet56\n\n \nFilter ranking; guided by sensitivity analysis\n\n \n1.Rank filters\n2. Remove filters and channels\n3.Fine-tune\n\n \nOne-shot ranking and pruning of filters; with network thinning\n \n\n\n\n\nLearning both Weights and Connections for Efficient Neural Networks\n\n\nThis schedule is an example of \"Iterative Pruning\" for Alexnet/Imagent, as described in chapter 3 of Song Han's PhD dissertation: \nEfficient Methods and Hardware for Deep Learning\n and in his paper \nLearning both Weights and Connections for Efficient Neural Networks\n. \n\n\nThe Distiller schedule uses SensitivityPruner which is similar to MagnitudeParameterPruner, but instead of specifying \"raw\" thresholds, it uses a \"sensitivity parameter\". Song Han's paper says that \"the pruning threshold is chosen as a quality parameter multiplied by the standard deviation of a layers weights,\" and this is not explained much further. In Distiller, the \"quality parameter\" is referred to as \"sensitivity\" and\nis based on the values learned from performing sensitivity analysis. Using a parameter that is related to the standard deviation is very helpful: under the assumption that the weights tensors are distributed normally, the standard deviation acts as a threshold normalizer.\n\n\nNote that Distiller's implementation deviates slightly from the algorithm Song Han describes in his PhD dissertation, in that the threshold value is set only once. In his PhD dissertation, Song Han describes a growing threshold, at each iteration. This requires n+1 hyper-parameters (n being the number of pruning iterations we use): the threshold and the threshold increase (delta) at each pruning iteration. 
Distiller's implementation takes advantage of the fact that as pruning progresses, more weights are pulled toward zero, and therefore the threshold \"traps\" more weights. Thus, we can use less hyper-parameters and achieve the same results.\n\n\n\n\nDistiller schedule: \ndistiller/examples/sensitivity-pruning/alexnet.schedule_sensitivity.yaml\n\n\nCheckpoint file: \nalexnet.checkpoint.89.pth.tar\n\n\n\n\nResults\n\n\nOur reference is TorchVision's pretrained Alexnet model which has a Top1 accuracy of 56.55 and Top5=79.09. We prune away 88.44% of the parameters and achieve Top1=56.61 and Top5=79.45.\nSong Han prunes 89% of the parameters, which is slightly better than our results.\n\n\nParameters:\n+----+---------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+\n| | Name | Shape | NNZ (dense) | NNZ (sparse) | Cols (%) | Rows (%) | Ch (%) | 2D (%) | 3D (%) | Fine (%) | Std | Mean | Abs-Mean\n|----+---------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|\n| 0 | features.module.0.weight | (64, 3, 11, 11) | 23232 | 13411 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 42.27359 | 0.14391 | -0.00002 | 0.08805 |\n| 1 | features.module.3.weight | (192, 64, 5, 5) | 307200 | 115560 | 0.00000 | 0.00000 | 0.00000 | 1.91243 | 0.00000 | 62.38281 | 0.04703 | -0.00250 | 0.02289 |\n| 2 | features.module.6.weight | (384, 192, 3, 3) | 663552 | 256565 | 0.00000 | 0.00000 | 0.00000 | 6.18490 | 0.00000 | 61.33445 | 0.03354 | -0.00184 | 0.01803 |\n| 3 | features.module.8.weight | (256, 384, 3, 3) | 884736 | 315065 | 0.00000 | 0.00000 | 0.00000 | 6.96411 | 0.00000 | 64.38881 | 0.02646 | -0.00168 | 0.01422 |\n| 4 | features.module.10.weight | (256, 256, 3, 3) | 589824 | 186938 | 0.00000 | 0.00000 | 0.00000 | 15.49225 | 0.00000 | 68.30614 | 0.02714 | -0.00246 | 0.01409 |\n| 5 | classifier.1.weight | (4096, 9216) | 37748736 | 3398881 | 0.00000 | 0.21973 | 0.00000 | 0.21973 | 0.00000 | 90.99604 | 0.00589 | -0.00020 | 0.00168 |\n| 6 | classifier.4.weight | (4096, 4096) | 16777216 | 1782769 | 0.21973 | 3.46680 | 0.00000 | 3.46680 | 0.00000 | 89.37387 | 0.00849 | -0.00066 | 0.00263 |\n| 7 | classifier.6.weight | (1000, 4096) | 4096000 | 994738 | 3.36914 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 75.71440 | 0.01718 | 0.00030 | 0.00778 |\n| 8 | Total sparsity: | - | 61090496 | 7063928 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 88.43694 | 0.00000 | 0.00000 | 0.00000 |\n+----+---------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+\n 2018-04-04 21:30:52,499 - Total sparsity: 88.44\n\n 2018-04-04 21:30:52,499 - --- validate (epoch=89)-----------\n 2018-04-04 21:30:52,499 - 128116 samples (256 per mini-batch)\n 2018-04-04 21:31:35,357 - ==\n Top1: 51.838 Top5: 74.817 Loss: 2.150\n\n 2018-04-04 21:31:39,251 - --- test ---------------------\n 2018-04-04 21:31:39,252 - 50000 samples (256 per mini-batch)\n 2018-04-04 21:32:01,274 - ==\n Top1: 56.606 Top5: 79.446 Loss: 1.893\n\n\n\n\nTo prune, or not to prune: exploring the efficacy of pruning for model compression\n\n\nIn their paper Zhu and Gupta, \"compare the accuracy of large, but pruned models (large-sparse) and their\nsmaller, but dense (small-dense) counterparts with identical 
memory footprint.\"\nThey also \"propose a new gradual pruning technique that is simple and straightforward to apply across a variety of models/datasets with\nminimal tuning.\"\n\n\nThis pruning schedule is implemented by distiller.AutomatedGradualPruner, which increases the sparsity level (expressed as a percentage of zero-valued elements) gradually over several pruning steps. Distiller's implementation only prunes elements once in an epoch (the model is fine-tuned in between pruning events), which is a small deviation from Zhu and Gupta's paper. The research paper specifies the schedule in terms of mini-batches, while our implementation specifies the schedule in terms of epochs. We feel that using epochs performs well, and is more \"stable\", since the number of mini-batches will change, if you change the batch size.\n\n\nImageNet files:\n\n\n\n\nDistiller schedule: \ndistiller/examples/agp-pruning/mobilenet.imagenet.schedule_agp.yaml\n\n\nCheckpoint file: \ncheckpoint.pth.tar\n\n\n\n\nResNet18 files:\n\n\n\n\nDistiller schedule: \ndistiller/examples/agp-pruning/resnet18.schedule_agp.yaml\n\n\nCheckpoint file: \ncheckpoint.pth.tar\n\n\n\n\nResults\n\n\nAs our baseline we used a \npretrained PyTorch MobileNet model\n (width=1) which has Top1=68.848 and Top5=88.740.\n\nIn their paper, Zhu and Gupta prune 50% of the elements of MobileNet (width=1) with a 1.1% drop in accuracy. We pruned about 51.6% of the elements, with virtually no change in the accuracies (Top1: 68.808 and Top5: 88.656). We didn't try to prune more than this, but we do note that the baseline accuracy that we used is almost 2% lower than the accuracy published in the paper. \n\n\n+----+--------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+\n| | Name | Shape | NNZ (dense) | NNZ (sparse) | Cols (%) | Rows (%) | Ch (%) | 2D (%) | 3D (%) | Fine (%) | Std | Mean | Abs-Mean |\n|----+--------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|\n| 0 | module.model.0.0.weight | (32, 3, 3, 3) | 864 | 864 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.14466 | 0.00103 | 0.06508 |\n| 1 | module.model.1.0.weight | (32, 1, 3, 3) | 288 | 288 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.32146 | 0.01020 | 0.12932 |\n| 2 | module.model.1.3.weight | (64, 32, 1, 1) | 2048 | 2048 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.11942 | 0.00024 | 0.03627 |\n| 3 | module.model.2.0.weight | (64, 1, 3, 3) | 576 | 576 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.15809 | 0.00543 | 0.11513 |\n| 4 | module.model.2.3.weight | (128, 64, 1, 1) | 8192 | 8192 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.08442 | -0.00031 | 0.04182 |\n| 5 | module.model.3.0.weight | (128, 1, 3, 3) | 1152 | 1152 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.16780 | 0.00125 | 0.10545 |\n| 6 | module.model.3.3.weight | (128, 128, 1, 1) | 16384 | 16384 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.07126 | -0.00197 | 0.04123 |\n| 7 | module.model.4.0.weight | (128, 1, 3, 3) | 1152 | 1152 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.10182 | 0.00171 | 0.08719 |\n| 8 | module.model.4.3.weight | (256, 128, 1, 1) | 32768 | 13108 | 0.00000 | 0.00000 | 10.15625 | 59.99756 | 12.50000 | 
59.99756 | 0.05543 | -0.00002 | 0.02760 |\n| 9 | module.model.5.0.weight | (256, 1, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.12516 | -0.00288 | 0.08058 |\n| 10 | module.model.5.3.weight | (256, 256, 1, 1) | 65536 | 26215 | 0.00000 | 0.00000 | 12.50000 | 59.99908 | 23.82812 | 59.99908 | 0.04453 | 0.00002 | 0.02271 |\n| 11 | module.model.6.0.weight | (256, 1, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.08024 | 0.00252 | 0.06377 |\n| 12 | module.model.6.3.weight | (512, 256, 1, 1) | 131072 | 52429 | 0.00000 | 0.00000 | 23.82812 | 59.99985 | 14.25781 | 59.99985 | 0.03561 | -0.00057 | 0.01779 |\n| 13 | module.model.7.0.weight | (512, 1, 3, 3) | 4608 | 4608 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.11008 | -0.00018 | 0.06829 |\n| 14 | module.model.7.3.weight | (512, 512, 1, 1) | 262144 | 104858 | 0.00000 | 0.00000 | 14.25781 | 59.99985 | 21.28906 | 59.99985 | 0.02944 | -0.00060 | 0.01515 |\n| 15 | module.model.8.0.weight | (512, 1, 3, 3) | 4608 | 4608 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.08258 | 0.00370 | 0.04905 |\n| 16 | module.model.8.3.weight | (512, 512, 1, 1) | 262144 | 104858 | 0.00000 | 0.00000 | 21.28906 | 59.99985 | 28.51562 | 59.99985 | 0.02865 | -0.00046 | 0.01465 |\n| 17 | module.model.9.0.weight | (512, 1, 3, 3) | 4608 | 4608 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.07578 | 0.00468 | 0.04201 |\n| 18 | module.model.9.3.weight | (512, 512, 1, 1) | 262144 | 104858 | 0.00000 | 0.00000 | 28.51562 | 59.99985 | 23.43750 | 59.99985 | 0.02939 | -0.00044 | 0.01511 |\n| 19 | module.model.10.0.weight | (512, 1, 3, 3) | 4608 | 4608 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.07091 | 0.00014 | 0.04306 |\n| 20 | module.model.10.3.weight | (512, 512, 1, 1) | 262144 | 104858 | 0.00000 | 0.00000 | 24.60938 | 59.99985 | 20.89844 | 59.99985 | 0.03095 | -0.00059 | 0.01672 |\n| 21 | module.model.11.0.weight | (512, 1, 3, 3) | 4608 | 4608 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.05729 | -0.00518 | 0.04267 |\n| 22 | module.model.11.3.weight | (512, 512, 1, 1) | 262144 | 104858 | 0.00000 | 0.00000 | 20.89844 | 59.99985 | 17.57812 | 59.99985 | 0.03229 | -0.00044 | 0.01797 |\n| 23 | module.model.12.0.weight | (512, 1, 3, 3) | 4608 | 4608 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.04981 | -0.00136 | 0.03967 |\n| 24 | module.model.12.3.weight | (1024, 512, 1, 1) | 524288 | 209716 | 0.00000 | 0.00000 | 16.01562 | 59.99985 | 44.23828 | 59.99985 | 0.02514 | -0.00106 | 0.01278 |\n| 25 | module.model.13.0.weight | (1024, 1, 3, 3) | 9216 | 9216 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02396 | -0.00949 | 0.01549 |\n| 26 | module.model.13.3.weight | (1024, 1024, 1, 1) | 1048576 | 419431 | 0.00000 | 0.00000 | 44.72656 | 59.99994 | 1.46484 | 59.99994 | 0.01801 | -0.00017 | 0.00931 |\n| 27 | module.fc.weight | (1000, 1024) | 1024000 | 409600 | 1.46484 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 60.00000 | 0.05078 | 0.00271 | 0.02734 |\n| 28 | Total sparsity: | - | 4209088 | 1726917 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 58.97171 | 0.00000 | 0.00000 | 0.00000 |\n+----+--------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+\nTotal sparsity: 58.97\n\n--- validate (epoch=199)-----------\n128116 samples (256 per mini-batch)\n==\n Top1: 65.337 Top5: 
84.984 Loss: 1.494\n\n--- test ---------------------\n50000 samples (256 per mini-batch)\n==\n Top1: 68.810 Top5: 88.626 Loss: 1.282\n\n\n\n\n\nLearning Structured Sparsity in Deep Neural Networks\n\n\nThis research paper from the University of Pittsburgh, \"proposes a Structured Sparsity Learning (SSL) method to regularize the structures (i.e., filters, channels, filter shapes, and layer depth) of DNNs. SSL can: (1) learn a compact structure from a bigger DNN to reduce computation cost; (2) obtain a hardware-friendly structured sparsity of DNN to efficiently accelerate the DNN\u2019s evaluation.\"\n\n\nNote that this paper does not use pruning, but instead uses group regularization during the training to force weights towards zero, as a group. We used a schedule which thresholds the regularized elements at a magnitude equal to the regularization strength. At the end of the regularization phase, we save the final sparsity masks generated by the regularization, and exit. Then we load this regularized model, remove the layers corresponding to the zeroed weight tensors (all of a layer's elements have a zero value). \n\n\nBaseline training\n\n\nWe started by training the baseline ResNet20-Cifar dense network since we didn't have a pre-trained model.\n\n\n\n\nDistiller schedule: \ndistiller/examples/ssl/resnet20_cifar_baseline_training.yaml\n\n\nCheckpoint files: \ndistiller/examples/ssl/checkpoints/\n\n\n\n\n$ time python3 compress_classifier.py --arch resnet20_cifar ../data.cifar10 -p=50 --lr=0.3 --epochs=180 --compress=../cifar10/resnet20/baseline_training.yaml -j=1 --deterministic\n\n\n\n\nRegularization\n\n\nThen we started training from scratch again, but this time we used Group Lasso regularization on entire layers:\n\nDistiller schedule: \ndistiller/examples/ssl/ssl_4D-removal_4L_training.yaml\n\n\n$ time python3 compress_classifier.py --arch resnet20_cifar ../data.cifar10 -p=50 --lr=0.4 --epochs=180 --compress=../ssl/ssl_4D-removal_training.yaml -j=1 --deterministic\n\n\n\n\nThe diagram below shows the training of Resnet20/CIFAR10 using Group Lasso regularization on entire layers (in blue) vs. training Resnet20/CIFAR10 baseline (in red). You may notice several interesting things:\n1. The LR-decay policy is the same, but the two sessions start with different initial LR values.\n2. The data-loss of the regularized training follows the same shape as the un-regularized training (baseline), and eventually the two seem to merge.\n3. We see similar behavior in the validation Top1 and Top5 accuracy results, but the regularized training eventually performs better.\n4. In the top right corner we see the behavior of the regularization loss (\nReg Loss\n), which actually increases for some time, until the data-loss has a sharp drop (after ~16K mini-batches), at which point the regularization loss also starts dropping.\n\n\n\nThis \nregularization\n yields 5 layers with zeroed weight tensors. We load this model, remove the 5 layers, and start the fine tuning of the weights. This process of layer removal is specific to ResNet for CIFAR, which we altered by adding code to skip over layers during the forward path. When you export to ONNX, the removed layers do not participate in the forward path, so they don't get incarnated. \n\n\nWe managed to remove 5 of the 16 3x3 convolution layers which dominate the computation time. 
It's not bad, but we probably could have done better.\n\n\nFine-tuning\n\n\nDuring the \nfine-tuning\n process, because the removed layers do not participate in the forward path, they do not appear in the backward path and are not backpropogated: therefore they are completely disconnected from the network.\n\nWe copy the checkpoint file of the regularized model to \ncheckpoint_trained_4D_regularized_5Lremoved.pth.tar\n.\n\nDistiller schedule: \ndistiller/examples/ssl/ssl_4D-removal_finetuning.yaml\n\n\n$ time python3 compress_classifier.py --arch resnet20_cifar ../data.cifar10 -p=50 --lr=0.1 --epochs=250 --resume=../cifar10/resnet20/checkpoint_trained_4D_regularized_5Lremoved.pth.tar --compress=../ssl/ssl_4D-removal_finetuning.yaml -j=1 --deterministic\n\n\n\n\nResults\n\n\nOur baseline results for ResNet20 Cifar are: Top1=91.450 and Top5=99.750\n\n\nWe used Distiller's GroupLassoRegularizer to remove 5 layers from Resnet20 (CIFAR10) with no degradation of the accuracies.\n\nThe regularized model exhibits really poor classification abilities: \n\n\n$ time python3 compress_classifier.py --arch resnet20_cifar ../data.cifar10 -p=50 --resume=../cifar10/resnet20/checkpoint_trained_4D_regularized_5Lremoved.pth.tar --evaluate\n\n=\n loading checkpoint ../cifar10/resnet20/checkpoint_trained_4D_regularized_5Lremoved.pth.tar\n best top@1: 90.620\nLoaded compression schedule from checkpoint (epoch 179)\nRemoving layer: module.layer1.0.conv1 [layer=0 block=0 conv=0]\nRemoving layer: module.layer1.0.conv2 [layer=0 block=0 conv=1]\nRemoving layer: module.layer1.1.conv1 [layer=0 block=1 conv=0]\nRemoving layer: module.layer1.1.conv2 [layer=0 block=1 conv=1]\nRemoving layer: module.layer2.2.conv2 [layer=1 block=2 conv=1]\nFiles already downloaded and verified\nFiles already downloaded and verified\nDataset sizes:\n training=45000\n validation=5000\n test=10000\n--- test ---------------------\n10000 samples (256 per mini-batch)\n==\n Top1: 22.290 Top5: 68.940 Loss: 5.172\n\n\n\n\nHowever, after fine-tuning, we recovered most of the accuracies loss, but not quite all of it: Top1=91.020 and Top5=99.670\n\n\nWe didn't spend time trying to wrestle with this network, and therefore didn't achieve SSL's published results (which showed that they managed to remove 6 layers and at the same time increase accuracies).\n\n\nPruning Filters for Efficient ConvNets\n\n\nQuoting the authors directly:\n\n\n\n\nWe present an acceleration method for CNNs, where we prune filters from CNNs that are identified as having a small effect on the output accuracy. By removing whole filters in the network together with their connecting feature maps, the computation costs are reduced significantly.\nIn contrast to pruning weights, this approach does not result in sparse connectivity patterns. Hence, it does not need the support of sparse convolution libraries and can work with existing efficient BLAS libraries for dense matrix multiplications.\n\n\n\n\nThe implementation of the research by Hao et al. required us to add filter-pruning sensitivity analysis, and support for \"network thinning\".\n\n\nAfter performing filter-pruning sensitivity analysis to assess which layers are more sensitive to the pruning of filters, we execute distiller.L1RankedStructureParameterPruner once in order to rank the filters of each layer by their L1-norm values, and then we prune the schedule-prescribed sparsity level. 
\n\n\n\n\nDistiller schedule: \ndistiller/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank.yaml\n\n\nCheckpoint files: \ncheckpoint_finetuned.pth.tar\n\n\n\n\nThe excerpt from the schedule, displayed below, shows how we declare the L1RankedStructureParameterPruner. This class currently ranks filters only, but because in the future this class may support ranking of various structures, you need to specify for each parameter both the target sparsity level, and the structure type ('3D' is filter-wise pruning).\n\n\npruners:\n filter_pruner:\n class: 'L1RankedStructureParameterPruner'\n reg_regims:\n 'module.layer1.0.conv1.weight': [0.6, '3D']\n 'module.layer1.1.conv1.weight': [0.6, '3D']\n 'module.layer1.2.conv1.weight': [0.6, '3D']\n 'module.layer1.3.conv1.weight': [0.6, '3D']\n\n\n\n\nIn the policy, we specify that we want to invoke this pruner once, at epoch 180. Because we are starting from a network which was trained for 180 epochs (see Baseline training below), the filter ranking is performed right at the outset of this schedule.\n\n\npolicies:\n - pruner:\n instance_name: filter_pruner\n epochs: [180]\n\n\n\n\n\nFollowing the pruning, we want to \"physically\" remove the pruned filters from the network, which involves reconfiguring the Convolutional layers and the parameter tensors. When we remove filters from Convolution layer \nn\n we need to perform several changes to the network:\n1. Shrink layer \nn\n's weights tensor, leaving only the \"important\" filters.\n2. Configure layer \nn\n's \n.out_channels\n member to its new, smaller, value.\n3. If a BN layer follows layer \nn\n, then it also needs to be reconfigured and its scale and shift parameter vectors need to be shrunk.\n4. If a Convolution layer follows the BN layer, then it will have less input channels which requires reconfiguration and shrinking of its weights.\n\n\nAll of this is performed by distiller.ResnetCifarFilterRemover which is also scheduled at epoch 180. We call this process \"network thinning\".\n\n\nextensions:\n net_thinner:\n class: 'FilterRemover'\n thinning_func_str: remove_filters\n arch: 'resnet56_cifar'\n dataset: 'cifar10'\n\n\n\n\nNetwork thinning requires us to understand the layer connectivity and data-dependency of the DNN, and we are working on a robust method to perform this. On networks with topologies similar to ResNet (residuals) and GoogLeNet (inception), which have several inputs and outputs to/from Convolution layers, there is extra details to consider.\n\nOur current implementation is specific to certain layers in ResNet and is a bit fragile. 
We will continue to improve and generalize this.\n\n\nBaseline training\n\n\nWe started by training the baseline ResNet56-Cifar dense network (180 epochs) since we didn't have a pre-trained model.\n\n\n\n\nDistiller schedule: \ndistiller/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_baseline_training.yaml\n\n\nCheckpoint files: \ncheckpoint.resnet56_cifar_baseline.pth.tar\n\n\n\n\nResults\n\n\nWe trained a ResNet56-Cifar10 network and achieve accuracy results which are on-par with published results:\nTop1: 92.970 and Top5: 99.740.\n\n\nWe used Hao et al.'s algorithm to remove 37.3% of the original convolution MACs, while maintaining virtually the same accuracy as the baseline:\nTop1: 92.830 and Top5: 99.760", "title": "Model Zoo" - }, + }, { - "location": "/model_zoo/index.html#distiller-model-zoo", - "text": "", + "location": "/model_zoo/index.html#distiller-model-zoo", + "text": "", "title": "Distiller Model Zoo" - }, + }, { - "location": "/model_zoo/index.html#how-to-contribute-models-to-the-model-zoo", - "text": "We encourage you to contribute new models to the Model Zoo. We welcome implementations of published papers or of your own work. To assure that models and algorithms shared with others are high-quality, please commit your models with the following: Command-line arguments Log files PyTorch model", + "location": "/model_zoo/index.html#how-to-contribute-models-to-the-model-zoo", + "text": "We encourage you to contribute new models to the Model Zoo. We welcome implementations of published papers or of your own work. To assure that models and algorithms shared with others are high-quality, please commit your models with the following: Command-line arguments Log files PyTorch model", "title": "How to contribute models to the Model Zoo" - }, + }, { - "location": "/model_zoo/index.html#contents", - "text": "The Distiller model zoo is not a \"traditional\" model-zoo, because it does not necessarily contain best-in-class compressed models. Instead, the model-zoo contains a number of deep learning models that have been compressed using Distiller following some well-known research papers. These are meant to serve as examples of how Distiller can be used. Each model contains a Distiller schedule detailing how the model was compressed, a PyTorch checkpoint, text logs and TensorBoard logs. \ntable, th, td {\n border: 1px solid black;\n} \n \n Paper \n Dataset \n Network \n Method & Granularity \n Schedule \n Features \n \n \n Learning both Weights and Connections for Efficient Neural Networks \n ImageNet \n Alexnet \n Element-wise pruning \n Iterative; Manual \n Magnitude thresholding based on a sensitivity quantifier. Element-wise sparsity sensitivity analysis \n \n \n To prune, or not to prune: exploring the efficacy of pruning for model compression \n ImageNet \n MobileNet \n Element-wise pruning \n Automated gradual; Iterative \n Magnitude thresholding based on target level \n \n \n Learning Structured Sparsity in Deep Neural Networks \n CIFAR10 \n ResNet20 \n Group regularization \n 1.Train with group-lasso 2.Remove zero groups and fine-tune \n Group Lasso regularization. Groups: kernels (2D), channels, filters (3D), layers (4D), vectors (rows, cols) \n \n \n Pruning Filters for Efficient ConvNets \n CIFAR10 \n ResNet56 \n Filter ranking; guided by sensitivity analysis \n 1.Rank filters 2. 
Remove filters and channels 3.Fine-tune \n One-shot ranking and pruning of filters; with network thinning", + "location": "/model_zoo/index.html#contents", + "text": "The Distiller model zoo is not a \"traditional\" model-zoo, because it does not necessarily contain best-in-class compressed models. Instead, the model-zoo contains a number of deep learning models that have been compressed using Distiller following some well-known research papers. These are meant to serve as examples of how Distiller can be used. Each model contains a Distiller schedule detailing how the model was compressed, a PyTorch checkpoint, text logs and TensorBoard logs. \ntable, th, td {\n border: 1px solid black;\n} \n \n Paper \n Dataset \n Network \n Method Granularity \n Schedule \n Features \n \n \n Learning both Weights and Connections for Efficient Neural Networks \n ImageNet \n Alexnet \n Element-wise pruning \n Iterative; Manual \n Magnitude thresholding based on a sensitivity quantifier. Element-wise sparsity sensitivity analysis \n \n \n To prune, or not to prune: exploring the efficacy of pruning for model compression \n ImageNet \n MobileNet \n Element-wise pruning \n Automated gradual; Iterative \n Magnitude thresholding based on target level \n \n \n Learning Structured Sparsity in Deep Neural Networks \n CIFAR10 \n ResNet20 \n Group regularization \n 1.Train with group-lasso 2.Remove zero groups and fine-tune \n Group Lasso regularization. Groups: kernels (2D), channels, filters (3D), layers (4D), vectors (rows, cols) \n \n \n Pruning Filters for Efficient ConvNets \n CIFAR10 \n ResNet56 \n Filter ranking; guided by sensitivity analysis \n 1.Rank filters 2. Remove filters and channels 3.Fine-tune \n One-shot ranking and pruning of filters; with network thinning", "title": "Contents" - }, + }, { - "location": "/model_zoo/index.html#learning-both-weights-and-connections-for-efficient-neural-networks", - "text": "This schedule is an example of \"Iterative Pruning\" for Alexnet/Imagent, as described in chapter 3 of Song Han's PhD dissertation: Efficient Methods and Hardware for Deep Learning and in his paper Learning both Weights and Connections for Efficient Neural Networks . The Distiller schedule uses SensitivityPruner which is similar to MagnitudeParameterPruner, but instead of specifying \"raw\" thresholds, it uses a \"sensitivity parameter\". Song Han's paper says that \"the pruning threshold is chosen as a quality parameter multiplied by the standard deviation of a layers weights,\" and this is not explained much further. In Distiller, the \"quality parameter\" is referred to as \"sensitivity\" and\nis based on the values learned from performing sensitivity analysis. Using a parameter that is related to the standard deviation is very helpful: under the assumption that the weights tensors are distributed normally, the standard deviation acts as a threshold normalizer. Note that Distiller's implementation deviates slightly from the algorithm Song Han describes in his PhD dissertation, in that the threshold value is set only once. In his PhD dissertation, Song Han describes a growing threshold, at each iteration. This requires n+1 hyper-parameters (n being the number of pruning iterations we use): the threshold and the threshold increase (delta) at each pruning iteration. Distiller's implementation takes advantage of the fact that as pruning progresses, more weights are pulled toward zero, and therefore the threshold \"traps\" more weights. 
Thus, we can use less hyper-parameters and achieve the same results. Distiller schedule: distiller/examples/sensitivity-pruning/alexnet.schedule_sensitivity.yaml Checkpoint file: alexnet.checkpoint.89.pth.tar", + "location": "/model_zoo/index.html#learning-both-weights-and-connections-for-efficient-neural-networks", + "text": "This schedule is an example of \"Iterative Pruning\" for Alexnet/Imagent, as described in chapter 3 of Song Han's PhD dissertation: Efficient Methods and Hardware for Deep Learning and in his paper Learning both Weights and Connections for Efficient Neural Networks . The Distiller schedule uses SensitivityPruner which is similar to MagnitudeParameterPruner, but instead of specifying \"raw\" thresholds, it uses a \"sensitivity parameter\". Song Han's paper says that \"the pruning threshold is chosen as a quality parameter multiplied by the standard deviation of a layers weights,\" and this is not explained much further. In Distiller, the \"quality parameter\" is referred to as \"sensitivity\" and\nis based on the values learned from performing sensitivity analysis. Using a parameter that is related to the standard deviation is very helpful: under the assumption that the weights tensors are distributed normally, the standard deviation acts as a threshold normalizer. Note that Distiller's implementation deviates slightly from the algorithm Song Han describes in his PhD dissertation, in that the threshold value is set only once. In his PhD dissertation, Song Han describes a growing threshold, at each iteration. This requires n+1 hyper-parameters (n being the number of pruning iterations we use): the threshold and the threshold increase (delta) at each pruning iteration. Distiller's implementation takes advantage of the fact that as pruning progresses, more weights are pulled toward zero, and therefore the threshold \"traps\" more weights. Thus, we can use less hyper-parameters and achieve the same results. Distiller schedule: distiller/examples/sensitivity-pruning/alexnet.schedule_sensitivity.yaml Checkpoint file: alexnet.checkpoint.89.pth.tar", "title": "Learning both Weights and Connections for Efficient Neural Networks" - }, + }, { - "location": "/model_zoo/index.html#results", - "text": "Our reference is TorchVision's pretrained Alexnet model which has a Top1 accuracy of 56.55 and Top5=79.09. We prune away 88.44% of the parameters and achieve Top1=56.61 and Top5=79.45.\nSong Han prunes 89% of the parameters, which is slightly better than our results. 
Parameters:\n+----+---------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+\n| | Name | Shape | NNZ (dense) | NNZ (sparse) | Cols (%) | Rows (%) | Ch (%) | 2D (%) | 3D (%) | Fine (%) | Std | Mean | Abs-Mean\n|----+---------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|\n| 0 | features.module.0.weight | (64, 3, 11, 11) | 23232 | 13411 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 42.27359 | 0.14391 | -0.00002 | 0.08805 |\n| 1 | features.module.3.weight | (192, 64, 5, 5) | 307200 | 115560 | 0.00000 | 0.00000 | 0.00000 | 1.91243 | 0.00000 | 62.38281 | 0.04703 | -0.00250 | 0.02289 |\n| 2 | features.module.6.weight | (384, 192, 3, 3) | 663552 | 256565 | 0.00000 | 0.00000 | 0.00000 | 6.18490 | 0.00000 | 61.33445 | 0.03354 | -0.00184 | 0.01803 |\n| 3 | features.module.8.weight | (256, 384, 3, 3) | 884736 | 315065 | 0.00000 | 0.00000 | 0.00000 | 6.96411 | 0.00000 | 64.38881 | 0.02646 | -0.00168 | 0.01422 |\n| 4 | features.module.10.weight | (256, 256, 3, 3) | 589824 | 186938 | 0.00000 | 0.00000 | 0.00000 | 15.49225 | 0.00000 | 68.30614 | 0.02714 | -0.00246 | 0.01409 |\n| 5 | classifier.1.weight | (4096, 9216) | 37748736 | 3398881 | 0.00000 | 0.21973 | 0.00000 | 0.21973 | 0.00000 | 90.99604 | 0.00589 | -0.00020 | 0.00168 |\n| 6 | classifier.4.weight | (4096, 4096) | 16777216 | 1782769 | 0.21973 | 3.46680 | 0.00000 | 3.46680 | 0.00000 | 89.37387 | 0.00849 | -0.00066 | 0.00263 |\n| 7 | classifier.6.weight | (1000, 4096) | 4096000 | 994738 | 3.36914 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 75.71440 | 0.01718 | 0.00030 | 0.00778 |\n| 8 | Total sparsity: | - | 61090496 | 7063928 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 88.43694 | 0.00000 | 0.00000 | 0.00000 |\n+----+---------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+\n 2018-04-04 21:30:52,499 - Total sparsity: 88.44\n\n 2018-04-04 21:30:52,499 - --- validate (epoch=89)-----------\n 2018-04-04 21:30:52,499 - 128116 samples (256 per mini-batch)\n 2018-04-04 21:31:35,357 - ==> Top1: 51.838 Top5: 74.817 Loss: 2.150\n\n 2018-04-04 21:31:39,251 - --- test ---------------------\n 2018-04-04 21:31:39,252 - 50000 samples (256 per mini-batch)\n 2018-04-04 21:32:01,274 - ==> Top1: 56.606 Top5: 79.446 Loss: 1.893", + "location": "/model_zoo/index.html#results", + "text": "Our reference is TorchVision's pretrained Alexnet model which has a Top1 accuracy of 56.55 and Top5=79.09. We prune away 88.44% of the parameters and achieve Top1=56.61 and Top5=79.45.\nSong Han prunes 89% of the parameters, which is slightly better than our results. 
Parameters:\n+----+---------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+\n| | Name | Shape | NNZ (dense) | NNZ (sparse) | Cols (%) | Rows (%) | Ch (%) | 2D (%) | 3D (%) | Fine (%) | Std | Mean | Abs-Mean\n|----+---------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|\n| 0 | features.module.0.weight | (64, 3, 11, 11) | 23232 | 13411 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 42.27359 | 0.14391 | -0.00002 | 0.08805 |\n| 1 | features.module.3.weight | (192, 64, 5, 5) | 307200 | 115560 | 0.00000 | 0.00000 | 0.00000 | 1.91243 | 0.00000 | 62.38281 | 0.04703 | -0.00250 | 0.02289 |\n| 2 | features.module.6.weight | (384, 192, 3, 3) | 663552 | 256565 | 0.00000 | 0.00000 | 0.00000 | 6.18490 | 0.00000 | 61.33445 | 0.03354 | -0.00184 | 0.01803 |\n| 3 | features.module.8.weight | (256, 384, 3, 3) | 884736 | 315065 | 0.00000 | 0.00000 | 0.00000 | 6.96411 | 0.00000 | 64.38881 | 0.02646 | -0.00168 | 0.01422 |\n| 4 | features.module.10.weight | (256, 256, 3, 3) | 589824 | 186938 | 0.00000 | 0.00000 | 0.00000 | 15.49225 | 0.00000 | 68.30614 | 0.02714 | -0.00246 | 0.01409 |\n| 5 | classifier.1.weight | (4096, 9216) | 37748736 | 3398881 | 0.00000 | 0.21973 | 0.00000 | 0.21973 | 0.00000 | 90.99604 | 0.00589 | -0.00020 | 0.00168 |\n| 6 | classifier.4.weight | (4096, 4096) | 16777216 | 1782769 | 0.21973 | 3.46680 | 0.00000 | 3.46680 | 0.00000 | 89.37387 | 0.00849 | -0.00066 | 0.00263 |\n| 7 | classifier.6.weight | (1000, 4096) | 4096000 | 994738 | 3.36914 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 75.71440 | 0.01718 | 0.00030 | 0.00778 |\n| 8 | Total sparsity: | - | 61090496 | 7063928 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 88.43694 | 0.00000 | 0.00000 | 0.00000 |\n+----+---------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+\n 2018-04-04 21:30:52,499 - Total sparsity: 88.44\n\n 2018-04-04 21:30:52,499 - --- validate (epoch=89)-----------\n 2018-04-04 21:30:52,499 - 128116 samples (256 per mini-batch)\n 2018-04-04 21:31:35,357 - == Top1: 51.838 Top5: 74.817 Loss: 2.150\n\n 2018-04-04 21:31:39,251 - --- test ---------------------\n 2018-04-04 21:31:39,252 - 50000 samples (256 per mini-batch)\n 2018-04-04 21:32:01,274 - == Top1: 56.606 Top5: 79.446 Loss: 1.893", "title": "Results" - }, + }, { - "location": "/model_zoo/index.html#to-prune-or-not-to-prune-exploring-the-efficacy-of-pruning-for-model-compression", - "text": "In their paper Zhu and Gupta, \"compare the accuracy of large, but pruned models (large-sparse) and their\nsmaller, but dense (small-dense) counterparts with identical memory footprint.\"\nThey also \"propose a new gradual pruning technique that is simple and straightforward to apply across a variety of models/datasets with\nminimal tuning.\" This pruning schedule is implemented by distiller.AutomatedGradualPruner, which increases the sparsity level (expressed as a percentage of zero-valued elements) gradually over several pruning steps. Distiller's implementation only prunes elements once in an epoch (the model is fine-tuned in between pruning events), which is a small deviation from Zhu and Gupta's paper. 
The research paper specifies the schedule in terms of mini-batches, while our implementation specifies the schedule in terms of epochs. We feel that using epochs performs well, and is more \"stable\", since the number of mini-batches will change, if you change the batch size. ImageNet files: Distiller schedule: distiller/examples/agp-pruning/mobilenet.imagenet.schedule_agp.yaml Checkpoint file: checkpoint.pth.tar ResNet18 files: Distiller schedule: distiller/examples/agp-pruning/resnet18.schedule_agp.yaml Checkpoint file: checkpoint.pth.tar", + "location": "/model_zoo/index.html#to-prune-or-not-to-prune-exploring-the-efficacy-of-pruning-for-model-compression", + "text": "In their paper Zhu and Gupta, \"compare the accuracy of large, but pruned models (large-sparse) and their\nsmaller, but dense (small-dense) counterparts with identical memory footprint.\"\nThey also \"propose a new gradual pruning technique that is simple and straightforward to apply across a variety of models/datasets with\nminimal tuning.\" This pruning schedule is implemented by distiller.AutomatedGradualPruner, which increases the sparsity level (expressed as a percentage of zero-valued elements) gradually over several pruning steps. Distiller's implementation only prunes elements once in an epoch (the model is fine-tuned in between pruning events), which is a small deviation from Zhu and Gupta's paper. The research paper specifies the schedule in terms of mini-batches, while our implementation specifies the schedule in terms of epochs. We feel that using epochs performs well, and is more \"stable\", since the number of mini-batches will change, if you change the batch size. ImageNet files: Distiller schedule: distiller/examples/agp-pruning/mobilenet.imagenet.schedule_agp.yaml Checkpoint file: checkpoint.pth.tar ResNet18 files: Distiller schedule: distiller/examples/agp-pruning/resnet18.schedule_agp.yaml Checkpoint file: checkpoint.pth.tar", "title": "To prune, or not to prune: exploring the efficacy of pruning for model compression" - }, + }, { - "location": "/model_zoo/index.html#results_1", - "text": "As our baseline we used a pretrained PyTorch MobileNet model (width=1) which has Top1=68.848 and Top5=88.740. \nIn their paper, Zhu and Gupta prune 50% of the elements of MobileNet (width=1) with a 1.1% drop in accuracy. We pruned about 51.6% of the elements, with virtually no change in the accuracies (Top1: 68.808 and Top5: 88.656). We didn't try to prune more than this, but we do note that the baseline accuracy that we used is almost 2% lower than the accuracy published in the paper. 
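The gradual schedule described in the entries above (implemented by distiller.AutomatedGradualPruner) follows the cubic sparsity ramp proposed in Zhu and Gupta's paper. The sketch below is illustrative only, not Distiller code: it expresses the ramp in epochs, as Distiller schedules it, and the parameter names and example numbers are placeholders.

```python
# Illustrative sketch of the gradual sparsity ramp from Zhu & Gupta's paper,
# expressed in epochs (as Distiller's AutomatedGradualPruner schedules it).
# Not Distiller code; parameter names and numbers are placeholders.
def agp_sparsity(epoch, initial_sparsity, final_sparsity, start_epoch, end_epoch):
    """Target fraction of zero-valued elements at `epoch`."""
    if epoch <= start_epoch:
        return initial_sparsity
    if epoch >= end_epoch:
        return final_sparsity
    progress = (epoch - start_epoch) / (end_epoch - start_epoch)
    # Cubic ramp: fast pruning early on, slowing down as sparsity approaches the target.
    return final_sparsity + (initial_sparsity - final_sparsity) * (1 - progress) ** 3

# Example: ramp from 5% to 88% sparsity between epochs 0 and 30.
for e in (0, 10, 20, 30):
    print(e, round(agp_sparsity(e, 0.05, 0.88, 0, 30), 3))
```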
+----+--------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+\n| | Name | Shape | NNZ (dense) | NNZ (sparse) | Cols (%) | Rows (%) | Ch (%) | 2D (%) | 3D (%) | Fine (%) | Std | Mean | Abs-Mean |\n|----+--------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|\n| 0 | module.model.0.0.weight | (32, 3, 3, 3) | 864 | 864 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.14466 | 0.00103 | 0.06508 |\n| 1 | module.model.1.0.weight | (32, 1, 3, 3) | 288 | 288 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.32146 | 0.01020 | 0.12932 |\n| 2 | module.model.1.3.weight | (64, 32, 1, 1) | 2048 | 2048 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.11942 | 0.00024 | 0.03627 |\n| 3 | module.model.2.0.weight | (64, 1, 3, 3) | 576 | 576 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.15809 | 0.00543 | 0.11513 |\n| 4 | module.model.2.3.weight | (128, 64, 1, 1) | 8192 | 8192 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.08442 | -0.00031 | 0.04182 |\n| 5 | module.model.3.0.weight | (128, 1, 3, 3) | 1152 | 1152 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.16780 | 0.00125 | 0.10545 |\n| 6 | module.model.3.3.weight | (128, 128, 1, 1) | 16384 | 16384 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.07126 | -0.00197 | 0.04123 |\n| 7 | module.model.4.0.weight | (128, 1, 3, 3) | 1152 | 1152 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.10182 | 0.00171 | 0.08719 |\n| 8 | module.model.4.3.weight | (256, 128, 1, 1) | 32768 | 13108 | 0.00000 | 0.00000 | 10.15625 | 59.99756 | 12.50000 | 59.99756 | 0.05543 | -0.00002 | 0.02760 |\n| 9 | module.model.5.0.weight | (256, 1, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.12516 | -0.00288 | 0.08058 |\n| 10 | module.model.5.3.weight | (256, 256, 1, 1) | 65536 | 26215 | 0.00000 | 0.00000 | 12.50000 | 59.99908 | 23.82812 | 59.99908 | 0.04453 | 0.00002 | 0.02271 |\n| 11 | module.model.6.0.weight | (256, 1, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.08024 | 0.00252 | 0.06377 |\n| 12 | module.model.6.3.weight | (512, 256, 1, 1) | 131072 | 52429 | 0.00000 | 0.00000 | 23.82812 | 59.99985 | 14.25781 | 59.99985 | 0.03561 | -0.00057 | 0.01779 |\n| 13 | module.model.7.0.weight | (512, 1, 3, 3) | 4608 | 4608 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.11008 | -0.00018 | 0.06829 |\n| 14 | module.model.7.3.weight | (512, 512, 1, 1) | 262144 | 104858 | 0.00000 | 0.00000 | 14.25781 | 59.99985 | 21.28906 | 59.99985 | 0.02944 | -0.00060 | 0.01515 |\n| 15 | module.model.8.0.weight | (512, 1, 3, 3) | 4608 | 4608 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.08258 | 0.00370 | 0.04905 |\n| 16 | module.model.8.3.weight | (512, 512, 1, 1) | 262144 | 104858 | 0.00000 | 0.00000 | 21.28906 | 59.99985 | 28.51562 | 59.99985 | 0.02865 | -0.00046 | 0.01465 |\n| 17 | module.model.9.0.weight | (512, 1, 3, 3) | 4608 | 4608 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.07578 | 0.00468 | 0.04201 |\n| 18 | module.model.9.3.weight | (512, 512, 1, 1) | 262144 | 104858 | 0.00000 | 0.00000 | 28.51562 | 59.99985 | 23.43750 | 59.99985 | 0.02939 | -0.00044 | 0.01511 |\n| 19 | module.model.10.0.weight | 
(512, 1, 3, 3) | 4608 | 4608 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.07091 | 0.00014 | 0.04306 |\n| 20 | module.model.10.3.weight | (512, 512, 1, 1) | 262144 | 104858 | 0.00000 | 0.00000 | 24.60938 | 59.99985 | 20.89844 | 59.99985 | 0.03095 | -0.00059 | 0.01672 |\n| 21 | module.model.11.0.weight | (512, 1, 3, 3) | 4608 | 4608 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.05729 | -0.00518 | 0.04267 |\n| 22 | module.model.11.3.weight | (512, 512, 1, 1) | 262144 | 104858 | 0.00000 | 0.00000 | 20.89844 | 59.99985 | 17.57812 | 59.99985 | 0.03229 | -0.00044 | 0.01797 |\n| 23 | module.model.12.0.weight | (512, 1, 3, 3) | 4608 | 4608 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.04981 | -0.00136 | 0.03967 |\n| 24 | module.model.12.3.weight | (1024, 512, 1, 1) | 524288 | 209716 | 0.00000 | 0.00000 | 16.01562 | 59.99985 | 44.23828 | 59.99985 | 0.02514 | -0.00106 | 0.01278 |\n| 25 | module.model.13.0.weight | (1024, 1, 3, 3) | 9216 | 9216 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02396 | -0.00949 | 0.01549 |\n| 26 | module.model.13.3.weight | (1024, 1024, 1, 1) | 1048576 | 419431 | 0.00000 | 0.00000 | 44.72656 | 59.99994 | 1.46484 | 59.99994 | 0.01801 | -0.00017 | 0.00931 |\n| 27 | module.fc.weight | (1000, 1024) | 1024000 | 409600 | 1.46484 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 60.00000 | 0.05078 | 0.00271 | 0.02734 |\n| 28 | Total sparsity: | - | 4209088 | 1726917 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 58.97171 | 0.00000 | 0.00000 | 0.00000 |\n+----+--------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+\nTotal sparsity: 58.97\n\n--- validate (epoch=199)-----------\n128116 samples (256 per mini-batch)\n==> Top1: 65.337 Top5: 84.984 Loss: 1.494\n\n--- test ---------------------\n50000 samples (256 per mini-batch)\n==> Top1: 68.810 Top5: 88.626 Loss: 1.282", + "location": "/model_zoo/index.html#results_1", + "text": "As our baseline we used a pretrained PyTorch MobileNet model (width=1) which has Top1=68.848 and Top5=88.740. \nIn their paper, Zhu and Gupta prune 50% of the elements of MobileNet (width=1) with a 1.1% drop in accuracy. We pruned about 51.6% of the elements, with virtually no change in the accuracies (Top1: 68.808 and Top5: 88.656). We didn't try to prune more than this, but we do note that the baseline accuracy that we used is almost 2% lower than the accuracy published in the paper. 
+----+--------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+\n| | Name | Shape | NNZ (dense) | NNZ (sparse) | Cols (%) | Rows (%) | Ch (%) | 2D (%) | 3D (%) | Fine (%) | Std | Mean | Abs-Mean |\n|----+--------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|\n| 0 | module.model.0.0.weight | (32, 3, 3, 3) | 864 | 864 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.14466 | 0.00103 | 0.06508 |\n| 1 | module.model.1.0.weight | (32, 1, 3, 3) | 288 | 288 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.32146 | 0.01020 | 0.12932 |\n| 2 | module.model.1.3.weight | (64, 32, 1, 1) | 2048 | 2048 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.11942 | 0.00024 | 0.03627 |\n| 3 | module.model.2.0.weight | (64, 1, 3, 3) | 576 | 576 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.15809 | 0.00543 | 0.11513 |\n| 4 | module.model.2.3.weight | (128, 64, 1, 1) | 8192 | 8192 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.08442 | -0.00031 | 0.04182 |\n| 5 | module.model.3.0.weight | (128, 1, 3, 3) | 1152 | 1152 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.16780 | 0.00125 | 0.10545 |\n| 6 | module.model.3.3.weight | (128, 128, 1, 1) | 16384 | 16384 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.07126 | -0.00197 | 0.04123 |\n| 7 | module.model.4.0.weight | (128, 1, 3, 3) | 1152 | 1152 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.10182 | 0.00171 | 0.08719 |\n| 8 | module.model.4.3.weight | (256, 128, 1, 1) | 32768 | 13108 | 0.00000 | 0.00000 | 10.15625 | 59.99756 | 12.50000 | 59.99756 | 0.05543 | -0.00002 | 0.02760 |\n| 9 | module.model.5.0.weight | (256, 1, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.12516 | -0.00288 | 0.08058 |\n| 10 | module.model.5.3.weight | (256, 256, 1, 1) | 65536 | 26215 | 0.00000 | 0.00000 | 12.50000 | 59.99908 | 23.82812 | 59.99908 | 0.04453 | 0.00002 | 0.02271 |\n| 11 | module.model.6.0.weight | (256, 1, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.08024 | 0.00252 | 0.06377 |\n| 12 | module.model.6.3.weight | (512, 256, 1, 1) | 131072 | 52429 | 0.00000 | 0.00000 | 23.82812 | 59.99985 | 14.25781 | 59.99985 | 0.03561 | -0.00057 | 0.01779 |\n| 13 | module.model.7.0.weight | (512, 1, 3, 3) | 4608 | 4608 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.11008 | -0.00018 | 0.06829 |\n| 14 | module.model.7.3.weight | (512, 512, 1, 1) | 262144 | 104858 | 0.00000 | 0.00000 | 14.25781 | 59.99985 | 21.28906 | 59.99985 | 0.02944 | -0.00060 | 0.01515 |\n| 15 | module.model.8.0.weight | (512, 1, 3, 3) | 4608 | 4608 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.08258 | 0.00370 | 0.04905 |\n| 16 | module.model.8.3.weight | (512, 512, 1, 1) | 262144 | 104858 | 0.00000 | 0.00000 | 21.28906 | 59.99985 | 28.51562 | 59.99985 | 0.02865 | -0.00046 | 0.01465 |\n| 17 | module.model.9.0.weight | (512, 1, 3, 3) | 4608 | 4608 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.07578 | 0.00468 | 0.04201 |\n| 18 | module.model.9.3.weight | (512, 512, 1, 1) | 262144 | 104858 | 0.00000 | 0.00000 | 28.51562 | 59.99985 | 23.43750 | 59.99985 | 0.02939 | -0.00044 | 0.01511 |\n| 19 | module.model.10.0.weight | 
(512, 1, 3, 3) | 4608 | 4608 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.07091 | 0.00014 | 0.04306 |\n| 20 | module.model.10.3.weight | (512, 512, 1, 1) | 262144 | 104858 | 0.00000 | 0.00000 | 24.60938 | 59.99985 | 20.89844 | 59.99985 | 0.03095 | -0.00059 | 0.01672 |\n| 21 | module.model.11.0.weight | (512, 1, 3, 3) | 4608 | 4608 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.05729 | -0.00518 | 0.04267 |\n| 22 | module.model.11.3.weight | (512, 512, 1, 1) | 262144 | 104858 | 0.00000 | 0.00000 | 20.89844 | 59.99985 | 17.57812 | 59.99985 | 0.03229 | -0.00044 | 0.01797 |\n| 23 | module.model.12.0.weight | (512, 1, 3, 3) | 4608 | 4608 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.04981 | -0.00136 | 0.03967 |\n| 24 | module.model.12.3.weight | (1024, 512, 1, 1) | 524288 | 209716 | 0.00000 | 0.00000 | 16.01562 | 59.99985 | 44.23828 | 59.99985 | 0.02514 | -0.00106 | 0.01278 |\n| 25 | module.model.13.0.weight | (1024, 1, 3, 3) | 9216 | 9216 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02396 | -0.00949 | 0.01549 |\n| 26 | module.model.13.3.weight | (1024, 1024, 1, 1) | 1048576 | 419431 | 0.00000 | 0.00000 | 44.72656 | 59.99994 | 1.46484 | 59.99994 | 0.01801 | -0.00017 | 0.00931 |\n| 27 | module.fc.weight | (1000, 1024) | 1024000 | 409600 | 1.46484 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 60.00000 | 0.05078 | 0.00271 | 0.02734 |\n| 28 | Total sparsity: | - | 4209088 | 1726917 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 58.97171 | 0.00000 | 0.00000 | 0.00000 |\n+----+--------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+\nTotal sparsity: 58.97\n\n--- validate (epoch=199)-----------\n128116 samples (256 per mini-batch)\n== Top1: 65.337 Top5: 84.984 Loss: 1.494\n\n--- test ---------------------\n50000 samples (256 per mini-batch)\n== Top1: 68.810 Top5: 88.626 Loss: 1.282", "title": "Results" - }, + }, { - "location": "/model_zoo/index.html#learning-structured-sparsity-in-deep-neural-networks", - "text": "This research paper from the University of Pittsburgh, \"proposes a Structured Sparsity Learning (SSL) method to regularize the structures (i.e., filters, channels, filter shapes, and layer depth) of DNNs. SSL can: (1) learn a compact structure from a bigger DNN to reduce computation cost; (2) obtain a hardware-friendly structured sparsity of DNN to efficiently accelerate the DNN\u2019s evaluation.\" Note that this paper does not use pruning, but instead uses group regularization during the training to force weights towards zero, as a group. We used a schedule which thresholds the regularized elements at a magnitude equal to the regularization strength. At the end of the regularization phase, we save the final sparsity masks generated by the regularization, and exit. Then we load this regularized model, remove the layers corresponding to the zeroed weight tensors (all of a layer's elements have a zero value).", + "location": "/model_zoo/index.html#learning-structured-sparsity-in-deep-neural-networks", + "text": "This research paper from the University of Pittsburgh, \"proposes a Structured Sparsity Learning (SSL) method to regularize the structures (i.e., filters, channels, filter shapes, and layer depth) of DNNs. 
SSL can: (1) learn a compact structure from a bigger DNN to reduce computation cost; (2) obtain a hardware-friendly structured sparsity of DNN to efficiently accelerate the DNN\u2019s evaluation.\" Note that this paper does not use pruning, but instead uses group regularization during the training to force weights towards zero, as a group. We used a schedule which thresholds the regularized elements at a magnitude equal to the regularization strength. At the end of the regularization phase, we save the final sparsity masks generated by the regularization, and exit. Then we load this regularized model, remove the layers corresponding to the zeroed weight tensors (all of a layer's elements have a zero value).", "title": "Learning Structured Sparsity in Deep Neural Networks" - }, + }, { - "location": "/model_zoo/index.html#baseline-training", - "text": "We started by training the baseline ResNet20-Cifar dense network since we didn't have a pre-trained model. Distiller schedule: distiller/examples/ssl/resnet20_cifar_baseline_training.yaml Checkpoint files: distiller/examples/ssl/checkpoints/ $ time python3 compress_classifier.py --arch resnet20_cifar ../data.cifar10 -p=50 --lr=0.3 --epochs=180 --compress=../cifar10/resnet20/baseline_training.yaml -j=1 --deterministic", + "location": "/model_zoo/index.html#baseline-training", + "text": "We started by training the baseline ResNet20-Cifar dense network since we didn't have a pre-trained model. Distiller schedule: distiller/examples/ssl/resnet20_cifar_baseline_training.yaml Checkpoint files: distiller/examples/ssl/checkpoints/ $ time python3 compress_classifier.py --arch resnet20_cifar ../data.cifar10 -p=50 --lr=0.3 --epochs=180 --compress=../cifar10/resnet20/baseline_training.yaml -j=1 --deterministic", "title": "Baseline training" - }, + }, { - "location": "/model_zoo/index.html#regularization", - "text": "Then we started training from scratch again, but this time we used Group Lasso regularization on entire layers: \nDistiller schedule: distiller/examples/ssl/ssl_4D-removal_4L_training.yaml $ time python3 compress_classifier.py --arch resnet20_cifar ../data.cifar10 -p=50 --lr=0.4 --epochs=180 --compress=../ssl/ssl_4D-removal_training.yaml -j=1 --deterministic The diagram below shows the training of Resnet20/CIFAR10 using Group Lasso regularization on entire layers (in blue) vs. training Resnet20/CIFAR10 baseline (in red). You may notice several interesting things:\n1. The LR-decay policy is the same, but the two sessions start with different initial LR values.\n2. The data-loss of the regularized training follows the same shape as the un-regularized training (baseline), and eventually the two seem to merge.\n3. We see similar behavior in the validation Top1 and Top5 accuracy results, but the regularized training eventually performs better.\n4. In the top right corner we see the behavior of the regularization loss ( Reg Loss ), which actually increases for some time, until the data-loss has a sharp drop (after ~16K mini-batches), at which point the regularization loss also starts dropping. This regularization yields 5 layers with zeroed weight tensors. We load this model, remove the 5 layers, and start the fine tuning of the weights. This process of layer removal is specific to ResNet for CIFAR, which we altered by adding code to skip over layers during the forward path. When you export to ONNX, the removed layers do not participate in the forward path, so they don't get incarnated. 
We managed to remove 5 of the 16 3x3 convolution layers which dominate the computation time. It's not bad, but we probably could have done better.", + "location": "/model_zoo/index.html#regularization", + "text": "Then we started training from scratch again, but this time we used Group Lasso regularization on entire layers: \nDistiller schedule: distiller/examples/ssl/ssl_4D-removal_4L_training.yaml $ time python3 compress_classifier.py --arch resnet20_cifar ../data.cifar10 -p=50 --lr=0.4 --epochs=180 --compress=../ssl/ssl_4D-removal_training.yaml -j=1 --deterministic The diagram below shows the training of Resnet20/CIFAR10 using Group Lasso regularization on entire layers (in blue) vs. training Resnet20/CIFAR10 baseline (in red). You may notice several interesting things:\n1. The LR-decay policy is the same, but the two sessions start with different initial LR values.\n2. The data-loss of the regularized training follows the same shape as the un-regularized training (baseline), and eventually the two seem to merge.\n3. We see similar behavior in the validation Top1 and Top5 accuracy results, but the regularized training eventually performs better.\n4. In the top right corner we see the behavior of the regularization loss ( Reg Loss ), which actually increases for some time, until the data-loss has a sharp drop (after ~16K mini-batches), at which point the regularization loss also starts dropping. This regularization yields 5 layers with zeroed weight tensors. We load this model, remove the 5 layers, and start the fine tuning of the weights. This process of layer removal is specific to ResNet for CIFAR, which we altered by adding code to skip over layers during the forward path. When you export to ONNX, the removed layers do not participate in the forward path, so they don't get incarnated. We managed to remove 5 of the 16 3x3 convolution layers which dominate the computation time. It's not bad, but we probably could have done better.", "title": "Regularization" - }, + }, { - "location": "/model_zoo/index.html#fine-tuning", - "text": "During the fine-tuning process, because the removed layers do not participate in the forward path, they do not appear in the backward path and are not backpropogated: therefore they are completely disconnected from the network. \nWe copy the checkpoint file of the regularized model to checkpoint_trained_4D_regularized_5Lremoved.pth.tar . \nDistiller schedule: distiller/examples/ssl/ssl_4D-removal_finetuning.yaml $ time python3 compress_classifier.py --arch resnet20_cifar ../data.cifar10 -p=50 --lr=0.1 --epochs=250 --resume=../cifar10/resnet20/checkpoint_trained_4D_regularized_5Lremoved.pth.tar --compress=../ssl/ssl_4D-removal_finetuning.yaml -j=1 --deterministic", + "location": "/model_zoo/index.html#fine-tuning", + "text": "During the fine-tuning process, because the removed layers do not participate in the forward path, they do not appear in the backward path and are not backpropogated: therefore they are completely disconnected from the network. \nWe copy the checkpoint file of the regularized model to checkpoint_trained_4D_regularized_5Lremoved.pth.tar . 
\nDistiller schedule: distiller/examples/ssl/ssl_4D-removal_finetuning.yaml $ time python3 compress_classifier.py --arch resnet20_cifar ../data.cifar10 -p=50 --lr=0.1 --epochs=250 --resume=../cifar10/resnet20/checkpoint_trained_4D_regularized_5Lremoved.pth.tar --compress=../ssl/ssl_4D-removal_finetuning.yaml -j=1 --deterministic", "title": "Fine-tuning" - }, + }, { - "location": "/model_zoo/index.html#results_2", - "text": "Our baseline results for ResNet20 Cifar are: Top1=91.450 and Top5=99.750 We used Distiller's GroupLassoRegularizer to remove 5 layers from Resnet20 (CIFAR10) with no degradation of the accuracies. \nThe regularized model exhibits really poor classification abilities: $ time python3 compress_classifier.py --arch resnet20_cifar ../data.cifar10 -p=50 --resume=../cifar10/resnet20/checkpoint_trained_4D_regularized_5Lremoved.pth.tar --evaluate\n\n=> loading checkpoint ../cifar10/resnet20/checkpoint_trained_4D_regularized_5Lremoved.pth.tar\n best top@1: 90.620\nLoaded compression schedule from checkpoint (epoch 179)\nRemoving layer: module.layer1.0.conv1 [layer=0 block=0 conv=0]\nRemoving layer: module.layer1.0.conv2 [layer=0 block=0 conv=1]\nRemoving layer: module.layer1.1.conv1 [layer=0 block=1 conv=0]\nRemoving layer: module.layer1.1.conv2 [layer=0 block=1 conv=1]\nRemoving layer: module.layer2.2.conv2 [layer=1 block=2 conv=1]\nFiles already downloaded and verified\nFiles already downloaded and verified\nDataset sizes:\n training=45000\n validation=5000\n test=10000\n--- test ---------------------\n10000 samples (256 per mini-batch)\n==> Top1: 22.290 Top5: 68.940 Loss: 5.172 However, after fine-tuning, we recovered most of the accuracies loss, but not quite all of it: Top1=91.020 and Top5=99.670 We didn't spend time trying to wrestle with this network, and therefore didn't achieve SSL's published results (which showed that they managed to remove 6 layers and at the same time increase accuracies).", + "location": "/model_zoo/index.html#results_2", + "text": "Our baseline results for ResNet20 Cifar are: Top1=91.450 and Top5=99.750 We used Distiller's GroupLassoRegularizer to remove 5 layers from Resnet20 (CIFAR10) with no degradation of the accuracies. 
\nThe regularized model exhibits really poor classification abilities: $ time python3 compress_classifier.py --arch resnet20_cifar ../data.cifar10 -p=50 --resume=../cifar10/resnet20/checkpoint_trained_4D_regularized_5Lremoved.pth.tar --evaluate\n\n= loading checkpoint ../cifar10/resnet20/checkpoint_trained_4D_regularized_5Lremoved.pth.tar\n best top@1: 90.620\nLoaded compression schedule from checkpoint (epoch 179)\nRemoving layer: module.layer1.0.conv1 [layer=0 block=0 conv=0]\nRemoving layer: module.layer1.0.conv2 [layer=0 block=0 conv=1]\nRemoving layer: module.layer1.1.conv1 [layer=0 block=1 conv=0]\nRemoving layer: module.layer1.1.conv2 [layer=0 block=1 conv=1]\nRemoving layer: module.layer2.2.conv2 [layer=1 block=2 conv=1]\nFiles already downloaded and verified\nFiles already downloaded and verified\nDataset sizes:\n training=45000\n validation=5000\n test=10000\n--- test ---------------------\n10000 samples (256 per mini-batch)\n== Top1: 22.290 Top5: 68.940 Loss: 5.172 However, after fine-tuning, we recovered most of the accuracies loss, but not quite all of it: Top1=91.020 and Top5=99.670 We didn't spend time trying to wrestle with this network, and therefore didn't achieve SSL's published results (which showed that they managed to remove 6 layers and at the same time increase accuracies).", "title": "Results" - }, + }, { - "location": "/model_zoo/index.html#pruning-filters-for-efficient-convnets", - "text": "Quoting the authors directly: We present an acceleration method for CNNs, where we prune filters from CNNs that are identified as having a small effect on the output accuracy. By removing whole filters in the network together with their connecting feature maps, the computation costs are reduced significantly.\nIn contrast to pruning weights, this approach does not result in sparse connectivity patterns. Hence, it does not need the support of sparse convolution libraries and can work with existing efficient BLAS libraries for dense matrix multiplications. The implementation of the research by Hao et al. required us to add filter-pruning sensitivity analysis, and support for \"network thinning\". After performing filter-pruning sensitivity analysis to assess which layers are more sensitive to the pruning of filters, we execute distiller.L1RankedStructureParameterPruner once in order to rank the filters of each layer by their L1-norm values, and then we prune the schedule-prescribed sparsity level. Distiller schedule: distiller/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank.yaml Checkpoint files: checkpoint_finetuned.pth.tar The excerpt from the schedule, displayed below, shows how we declare the L1RankedStructureParameterPruner. This class currently ranks filters only, but because in the future this class may support ranking of various structures, you need to specify for each parameter both the target sparsity level, and the structure type ('3D' is filter-wise pruning). pruners:\n filter_pruner:\n class: 'L1RankedStructureParameterPruner'\n reg_regims:\n 'module.layer1.0.conv1.weight': [0.6, '3D']\n 'module.layer1.1.conv1.weight': [0.6, '3D']\n 'module.layer1.2.conv1.weight': [0.6, '3D']\n 'module.layer1.3.conv1.weight': [0.6, '3D'] In the policy, we specify that we want to invoke this pruner once, at epoch 180. Because we are starting from a network which was trained for 180 epochs (see Baseline training below), the filter ranking is performed right at the outset of this schedule. 
policies:\n - pruner:\n instance_name: filter_pruner\n epochs: [180] Following the pruning, we want to \"physically\" remove the pruned filters from the network, which involves reconfiguring the Convolutional layers and the parameter tensors. When we remove filters from Convolution layer n we need to perform several changes to the network:\n1. Shrink layer n 's weights tensor, leaving only the \"important\" filters.\n2. Configure layer n 's .out_channels member to its new, smaller, value.\n3. If a BN layer follows layer n , then it also needs to be reconfigured and its scale and shift parameter vectors need to be shrunk.\n4. If a Convolution layer follows the BN layer, then it will have less input channels which requires reconfiguration and shrinking of its weights. All of this is performed by distiller.ResnetCifarFilterRemover which is also scheduled at epoch 180. We call this process \"network thinning\". extensions:\n net_thinner:\n class: 'FilterRemover'\n thinning_func_str: remove_filters\n arch: 'resnet56_cifar'\n dataset: 'cifar10' Network thinning requires us to understand the layer connectivity and data-dependency of the DNN, and we are working on a robust method to perform this. On networks with topologies similar to ResNet (residuals) and GoogLeNet (inception), which have several inputs and outputs to/from Convolution layers, there is extra details to consider. \nOur current implementation is specific to certain layers in ResNet and is a bit fragile. We will continue to improve and generalize this.", + "location": "/model_zoo/index.html#pruning-filters-for-efficient-convnets", + "text": "Quoting the authors directly: We present an acceleration method for CNNs, where we prune filters from CNNs that are identified as having a small effect on the output accuracy. By removing whole filters in the network together with their connecting feature maps, the computation costs are reduced significantly.\nIn contrast to pruning weights, this approach does not result in sparse connectivity patterns. Hence, it does not need the support of sparse convolution libraries and can work with existing efficient BLAS libraries for dense matrix multiplications. The implementation of the research by Hao et al. required us to add filter-pruning sensitivity analysis, and support for \"network thinning\". After performing filter-pruning sensitivity analysis to assess which layers are more sensitive to the pruning of filters, we execute distiller.L1RankedStructureParameterPruner once in order to rank the filters of each layer by their L1-norm values, and then we prune the schedule-prescribed sparsity level. Distiller schedule: distiller/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank.yaml Checkpoint files: checkpoint_finetuned.pth.tar The excerpt from the schedule, displayed below, shows how we declare the L1RankedStructureParameterPruner. This class currently ranks filters only, but because in the future this class may support ranking of various structures, you need to specify for each parameter both the target sparsity level, and the structure type ('3D' is filter-wise pruning). pruners:\n filter_pruner:\n class: 'L1RankedStructureParameterPruner'\n reg_regims:\n 'module.layer1.0.conv1.weight': [0.6, '3D']\n 'module.layer1.1.conv1.weight': [0.6, '3D']\n 'module.layer1.2.conv1.weight': [0.6, '3D']\n 'module.layer1.3.conv1.weight': [0.6, '3D'] In the policy, we specify that we want to invoke this pruner once, at epoch 180. 
Because we are starting from a network which was trained for 180 epochs (see Baseline training below), the filter ranking is performed right at the outset of this schedule. policies:\n - pruner:\n instance_name: filter_pruner\n epochs: [180] Following the pruning, we want to \"physically\" remove the pruned filters from the network, which involves reconfiguring the Convolutional layers and the parameter tensors. When we remove filters from Convolution layer n we need to perform several changes to the network:\n1. Shrink layer n 's weights tensor, leaving only the \"important\" filters.\n2. Configure layer n 's .out_channels member to its new, smaller, value.\n3. If a BN layer follows layer n , then it also needs to be reconfigured and its scale and shift parameter vectors need to be shrunk.\n4. If a Convolution layer follows the BN layer, then it will have less input channels which requires reconfiguration and shrinking of its weights. All of this is performed by distiller.ResnetCifarFilterRemover which is also scheduled at epoch 180. We call this process \"network thinning\". extensions:\n net_thinner:\n class: 'FilterRemover'\n thinning_func_str: remove_filters\n arch: 'resnet56_cifar'\n dataset: 'cifar10' Network thinning requires us to understand the layer connectivity and data-dependency of the DNN, and we are working on a robust method to perform this. On networks with topologies similar to ResNet (residuals) and GoogLeNet (inception), which have several inputs and outputs to/from Convolution layers, there is extra details to consider. \nOur current implementation is specific to certain layers in ResNet and is a bit fragile. We will continue to improve and generalize this.", "title": "Pruning Filters for Efficient ConvNets" - }, + }, { - "location": "/model_zoo/index.html#baseline-training_1", - "text": "We started by training the baseline ResNet56-Cifar dense network (180 epochs) since we didn't have a pre-trained model. Distiller schedule: distiller/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_baseline_training.yaml Checkpoint files: checkpoint.resnet56_cifar_baseline.pth.tar", + "location": "/model_zoo/index.html#baseline-training_1", + "text": "We started by training the baseline ResNet56-Cifar dense network (180 epochs) since we didn't have a pre-trained model. Distiller schedule: distiller/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_baseline_training.yaml Checkpoint files: checkpoint.resnet56_cifar_baseline.pth.tar", "title": "Baseline training" - }, + }, { - "location": "/model_zoo/index.html#results_3", - "text": "We trained a ResNet56-Cifar10 network and achieve accuracy results which are on-par with published results:\nTop1: 92.970 and Top5: 99.740. We used Hao et al.'s algorithm to remove 37.3% of the original convolution MACs, while maintaining virtually the same accuracy as the baseline:\nTop1: 92.830 and Top5: 99.760", + "location": "/model_zoo/index.html#results_3", + "text": "We trained a ResNet56-Cifar10 network and achieve accuracy results which are on-par with published results:\nTop1: 92.970 and Top5: 99.740. 
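To make the two steps described above more concrete, here is a rough, self-contained PyTorch sketch (plain torch, not Distiller's L1RankedStructureParameterPruner or FilterRemover): first rank a convolution's filters by their L1 norms and keep the strongest ones, then "thin" a conv → BN → conv chain following the four reconfiguration steps listed above. The layer sizes are arbitrary; the 0.6 ('3D') sparsity level mirrors the schedule excerpt.

```python
# Rough sketch (plain PyTorch, not Distiller's pruner/thinner classes) of
# L1-norm filter ranking followed by "network thinning" of a conv->BN->conv chain.
import torch
import torch.nn as nn

conv1 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
bn1   = nn.BatchNorm2d(32)
conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)

# Rank conv1's filters by their L1 norms and keep the strongest 40%
# (i.e. a filter-wise sparsity level of 0.6).
l1_norms = conv1.weight.data.abs().sum(dim=(1, 2, 3))
n_keep   = int(round(0.4 * conv1.out_channels))
keep     = l1_norms.argsort(descending=True)[:n_keep]

# 1. Shrink conv1's weight tensor, leaving only the "important" filters.
conv1.weight = nn.Parameter(conv1.weight.data[keep].clone())
conv1.bias   = nn.Parameter(conv1.bias.data[keep].clone())
# 2. Update conv1's out_channels to its new, smaller value.
conv1.out_channels = n_keep
# 3. Reconfigure the following BN layer: shrink its scale/shift vectors and running stats.
bn1.weight       = nn.Parameter(bn1.weight.data[keep].clone())
bn1.bias         = nn.Parameter(bn1.bias.data[keep].clone())
bn1.running_mean = bn1.running_mean[keep].clone()
bn1.running_var  = bn1.running_var[keep].clone()
bn1.num_features = n_keep
# 4. The next convolution now has fewer input channels, so shrink its weights too.
conv2.weight = nn.Parameter(conv2.weight.data[:, keep].clone())
conv2.in_channels = n_keep

x = torch.randn(1, 16, 8, 8)
print(conv2(bn1(conv1(x))).shape)   # torch.Size([1, 64, 8, 8])
```

This sketch only handles a simple sequential chain; as the text notes, Distiller's thinning must also track the layer connectivity of residual/inception-style topologies, which is where the real complexity lies.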
We used Hao et al.'s algorithm to remove 37.3% of the original convolution MACs, while maintaining virtually the same accuracy as the baseline:\nTop1: 92.830 and Top5: 99.760", "title": "Results" - }, + }, { - "location": "/jupyter/index.html", - "text": "Jupyter environment\n\n\nThe Jupyter notebooks environment allows us to plan our compression session and load Distiller data summaries to study and analyze compression results.\n\n\nEach notebook has embedded instructions and explanations, so here we provide only a brief description of each notebook.\n\n\nInstallation\n\n\nJupyter and its dependencies are included as part of the main \nrequirements.txt\n file, so there is no need for a dedicated installation step.\n\nHowever, to use the ipywidgets extension, you will need to enable it:\n\n\n$ jupyter nbextension enable --py widgetsnbextension --sys-prefix\n\n\n\n\nYou may want to refer to the \nipywidgets extension installation documentation\n.\n\n\nAnother extension which requires special installation handling is \nQgrid\n. Qgrid is a Jupyter notebook widget that adds interactive features, such as sorting, to Panadas DataFrames rendering. To enable Qgrid:\n\n\n$ jupyter nbextension enable --py --sys-prefix qgrid\n\n\n\n\nLaunching the Jupyter server\n\n\nThere are all kinds of options to use when launching Jupyter which you can use. The example below tells the server to listen to connections from any IP address, and not to launch the browser window, but of course, you are free to launch Jupyter any way you want.\n\nConsult the \nuser's guide\n for more details.\n\n\n$ jupyter-notebook --ip=0.0.0.0 --no-browser\n\n\n\n\nUsing the Distiller notebooks\n\n\nThe Distiller Jupyter notebooks are located in the \ndistiller/jupyter\n directory.\n\nThey are provided as tools that you can use to prepare your compression experiments and study their results.\nWe welcome new ideas and implementations of Jupyter.\n\n\nRoughly, the notebooks can be divided into three categories.\n\n\nTheory\n\n\n\n\njupyter/L1-regularization.ipynb\n: Experience hands-on how L1 and L2 regularization affect the solution of a toy loss-minimization problem, to get a better grasp on the interaction between regularization and sparsity.\n\n\njupyter/alexnet_insights.ipynb\n: This notebook reviews and compares a couple of pruning sessions on Alexnet. We compare distributions, performance, statistics and show some visualizations of the weights tensors.\n\n\n\n\nPreparation for compression\n\n\n\n\njupyter/model_summary.ipynb\n: Begin by getting familiar with your model. Examine the sizes and properties of layers and connections. Study which layers are compute-bound, and which are bandwidth-bound, and decide how to prune or regularize the model.\n\n\njupyter/sensitivity_analysis.ipynb\n: If you performed pruning sensitivity analysis on your model, this notebook can help you load the results and graphically study how the layers behave.\n\n\njupyter/interactive_lr_scheduler.ipynb\n: The learning rate decay policy affects pruning results, perhaps as much as it affects training results. 
Graph a few LR-decay policies to see how they behave.\n\n\njupyter/jupyter/agp_schedule.ipynb\n: If you are using the Automated Gradual Pruner, this notebook can help you tune the schedule.\n\n\n\n\nReviewing experiment results\n\n\n\n\njupyter/compare_executions.ipynb\n: This is a simple notebook to help you graphically compare the results of executions of several experiments.\n\n\njupyter/compression_insights.ipynb\n: This notebook is packed with code, tables and graphs to us understand the results of a compression session. Distiller provides \nsummaries\n, which are Pandas dataframes, which contain statistical information about you model. We chose to use Pandas dataframes because they can be sliced, queried, summarized and graphed with a few lines of code.", + "location": "/jupyter/index.html", + "text": "Jupyter environment\n\n\nThe Jupyter notebooks environment allows us to plan our compression session and load Distiller data summaries to study and analyze compression results.\n\n\nEach notebook has embedded instructions and explanations, so here we provide only a brief description of each notebook.\n\n\nInstallation\n\n\nJupyter and its dependencies are included as part of the main \nrequirements.txt\n file, so there is no need for a dedicated installation step.\n\nHowever, to use the ipywidgets extension, you will need to enable it:\n\n\n$ jupyter nbextension enable --py widgetsnbextension --sys-prefix\n\n\n\n\nYou may want to refer to the \nipywidgets extension installation documentation\n.\n\n\nAnother extension which requires special installation handling is \nQgrid\n. Qgrid is a Jupyter notebook widget that adds interactive features, such as sorting, to Panadas DataFrames rendering. To enable Qgrid:\n\n\n$ jupyter nbextension enable --py --sys-prefix qgrid\n\n\n\n\nLaunching the Jupyter server\n\n\nThere are all kinds of options to use when launching Jupyter which you can use. The example below tells the server to listen to connections from any IP address, and not to launch the browser window, but of course, you are free to launch Jupyter any way you want.\n\nConsult the \nuser's guide\n for more details.\n\n\n$ jupyter-notebook --ip=0.0.0.0 --no-browser\n\n\n\n\nUsing the Distiller notebooks\n\n\nThe Distiller Jupyter notebooks are located in the \ndistiller/jupyter\n directory.\n\nThey are provided as tools that you can use to prepare your compression experiments and study their results.\nWe welcome new ideas and implementations of Jupyter.\n\n\nRoughly, the notebooks can be divided into three categories.\n\n\nTheory\n\n\n\n\njupyter/L1-regularization.ipynb\n: Experience hands-on how L1 and L2 regularization affect the solution of a toy loss-minimization problem, to get a better grasp on the interaction between regularization and sparsity.\n\n\njupyter/alexnet_insights.ipynb\n: This notebook reviews and compares a couple of pruning sessions on Alexnet. We compare distributions, performance, statistics and show some visualizations of the weights tensors.\n\n\n\n\nPreparation for compression\n\n\n\n\njupyter/model_summary.ipynb\n: Begin by getting familiar with your model. Examine the sizes and properties of layers and connections. 
Study which layers are compute-bound, and which are bandwidth-bound, and decide how to prune or regularize the model.\n\n\njupyter/sensitivity_analysis.ipynb\n: If you performed pruning sensitivity analysis on your model, this notebook can help you load the results and graphically study how the layers behave.\n\n\njupyter/interactive_lr_scheduler.ipynb\n: The learning rate decay policy affects pruning results, perhaps as much as it affects training results. Graph a few LR-decay policies to see how they behave.\n\n\njupyter/jupyter/agp_schedule.ipynb\n: If you are using the Automated Gradual Pruner, this notebook can help you tune the schedule.\n\n\n\n\nReviewing experiment results\n\n\n\n\njupyter/compare_executions.ipynb\n: This is a simple notebook to help you graphically compare the results of executions of several experiments.\n\n\njupyter/compression_insights.ipynb\n: This notebook is packed with code, tables and graphs to us understand the results of a compression session. Distiller provides \nsummaries\n, which are Pandas dataframes, which contain statistical information about you model. We chose to use Pandas dataframes because they can be sliced, queried, summarized and graphed with a few lines of code.", "title": "Jupyter notebooks" - }, + }, { - "location": "/jupyter/index.html#jupyter-environment", - "text": "The Jupyter notebooks environment allows us to plan our compression session and load Distiller data summaries to study and analyze compression results. Each notebook has embedded instructions and explanations, so here we provide only a brief description of each notebook.", + "location": "/jupyter/index.html#jupyter-environment", + "text": "The Jupyter notebooks environment allows us to plan our compression session and load Distiller data summaries to study and analyze compression results. Each notebook has embedded instructions and explanations, so here we provide only a brief description of each notebook.", "title": "Jupyter environment" - }, + }, { - "location": "/jupyter/index.html#installation", - "text": "Jupyter and its dependencies are included as part of the main requirements.txt file, so there is no need for a dedicated installation step. \nHowever, to use the ipywidgets extension, you will need to enable it: $ jupyter nbextension enable --py widgetsnbextension --sys-prefix You may want to refer to the ipywidgets extension installation documentation . Another extension which requires special installation handling is Qgrid . Qgrid is a Jupyter notebook widget that adds interactive features, such as sorting, to Panadas DataFrames rendering. To enable Qgrid: $ jupyter nbextension enable --py --sys-prefix qgrid", + "location": "/jupyter/index.html#installation", + "text": "Jupyter and its dependencies are included as part of the main requirements.txt file, so there is no need for a dedicated installation step. \nHowever, to use the ipywidgets extension, you will need to enable it: $ jupyter nbextension enable --py widgetsnbextension --sys-prefix You may want to refer to the ipywidgets extension installation documentation . Another extension which requires special installation handling is Qgrid . Qgrid is a Jupyter notebook widget that adds interactive features, such as sorting, to Panadas DataFrames rendering. 
To enable Qgrid: $ jupyter nbextension enable --py --sys-prefix qgrid", "title": "Installation" - }, + }, { - "location": "/jupyter/index.html#launching-the-jupyter-server", - "text": "There are all kinds of options to use when launching Jupyter which you can use. The example below tells the server to listen to connections from any IP address, and not to launch the browser window, but of course, you are free to launch Jupyter any way you want. \nConsult the user's guide for more details. $ jupyter-notebook --ip=0.0.0.0 --no-browser", + "location": "/jupyter/index.html#launching-the-jupyter-server", + "text": "There are all kinds of options to use when launching Jupyter which you can use. The example below tells the server to listen to connections from any IP address, and not to launch the browser window, but of course, you are free to launch Jupyter any way you want. \nConsult the user's guide for more details. $ jupyter-notebook --ip=0.0.0.0 --no-browser", "title": "Launching the Jupyter server" - }, + }, { - "location": "/jupyter/index.html#using-the-distiller-notebooks", - "text": "The Distiller Jupyter notebooks are located in the distiller/jupyter directory. \nThey are provided as tools that you can use to prepare your compression experiments and study their results.\nWe welcome new ideas and implementations of Jupyter. Roughly, the notebooks can be divided into three categories.", + "location": "/jupyter/index.html#using-the-distiller-notebooks", + "text": "The Distiller Jupyter notebooks are located in the distiller/jupyter directory. \nThey are provided as tools that you can use to prepare your compression experiments and study their results.\nWe welcome new ideas and implementations of Jupyter. Roughly, the notebooks can be divided into three categories.", "title": "Using the Distiller notebooks" - }, + }, { - "location": "/jupyter/index.html#theory", - "text": "jupyter/L1-regularization.ipynb : Experience hands-on how L1 and L2 regularization affect the solution of a toy loss-minimization problem, to get a better grasp on the interaction between regularization and sparsity. jupyter/alexnet_insights.ipynb : This notebook reviews and compares a couple of pruning sessions on Alexnet. We compare distributions, performance, statistics and show some visualizations of the weights tensors.", + "location": "/jupyter/index.html#theory", + "text": "jupyter/L1-regularization.ipynb : Experience hands-on how L1 and L2 regularization affect the solution of a toy loss-minimization problem, to get a better grasp on the interaction between regularization and sparsity. jupyter/alexnet_insights.ipynb : This notebook reviews and compares a couple of pruning sessions on Alexnet. We compare distributions, performance, statistics and show some visualizations of the weights tensors.", "title": "Theory" - }, + }, { - "location": "/jupyter/index.html#preparation-for-compression", - "text": "jupyter/model_summary.ipynb : Begin by getting familiar with your model. Examine the sizes and properties of layers and connections. Study which layers are compute-bound, and which are bandwidth-bound, and decide how to prune or regularize the model. jupyter/sensitivity_analysis.ipynb : If you performed pruning sensitivity analysis on your model, this notebook can help you load the results and graphically study how the layers behave. jupyter/interactive_lr_scheduler.ipynb : The learning rate decay policy affects pruning results, perhaps as much as it affects training results. 
Graph a few LR-decay policies to see how they behave. jupyter/jupyter/agp_schedule.ipynb : If you are using the Automated Gradual Pruner, this notebook can help you tune the schedule.", + "location": "/jupyter/index.html#preparation-for-compression", + "text": "jupyter/model_summary.ipynb : Begin by getting familiar with your model. Examine the sizes and properties of layers and connections. Study which layers are compute-bound, and which are bandwidth-bound, and decide how to prune or regularize the model. jupyter/sensitivity_analysis.ipynb : If you performed pruning sensitivity analysis on your model, this notebook can help you load the results and graphically study how the layers behave. jupyter/interactive_lr_scheduler.ipynb : The learning rate decay policy affects pruning results, perhaps as much as it affects training results. Graph a few LR-decay policies to see how they behave. jupyter/jupyter/agp_schedule.ipynb : If you are using the Automated Gradual Pruner, this notebook can help you tune the schedule.", "title": "Preparation for compression" - }, + }, { - "location": "/jupyter/index.html#reviewing-experiment-results", - "text": "jupyter/compare_executions.ipynb : This is a simple notebook to help you graphically compare the results of executions of several experiments. jupyter/compression_insights.ipynb : This notebook is packed with code, tables and graphs to us understand the results of a compression session. Distiller provides summaries , which are Pandas dataframes, which contain statistical information about you model. We chose to use Pandas dataframes because they can be sliced, queried, summarized and graphed with a few lines of code.", + "location": "/jupyter/index.html#reviewing-experiment-results", + "text": "jupyter/compare_executions.ipynb : This is a simple notebook to help you graphically compare the results of executions of several experiments. jupyter/compression_insights.ipynb : This notebook is packed with code, tables and graphs to us understand the results of a compression session. Distiller provides summaries , which are Pandas dataframes, which contain statistical information about you model. We chose to use Pandas dataframes because they can be sliced, queried, summarized and graphed with a few lines of code.", "title": "Reviewing experiment results" - }, + }, { - "location": "/design/index.html", - "text": "Distiller design\n\n\nDistiller is designed to be easily integrated into your own PyTorch research applications.\n\nIt is easiest to understand this integration by examining the code of the sample application for compressing image classification models (\ncompress_classifier.py\n).\n\n\nThe application borrows its main flow code from torchvision's ImageNet classification training sample application (https://github.com/pytorch/examples/tree/master/imagenet). We tried to keep it similar, in order to make it familiar and easy to understand.\n\n\nIntegrating compression is very simple: simply add invocations of the appropriate compression_scheduler callbacks, for each stage in the training. The training skeleton looks like the pseudo code below. 
The boiler-plate Pytorch classification training is speckled with invocations of CompressionScheduler.\n\n\nFor each epoch:\n compression_scheduler.on_epoch_begin(epoch)\n train()\n validate()\n save_checkpoint()\n compression_scheduler.on_epoch_end(epoch)\n\ntrain():\n For each training step:\n compression_scheduler.on_minibatch_begin(epoch)\n output = model(input_var)\n loss = criterion(output, target_var)\n compression_scheduler.before_backward_pass(epoch)\n loss.backward()\n optimizer.step()\n compression_scheduler.on_minibatch_end(epoch)\n\n\n\n\nThese callbacks can be seen in the diagram below, as the arrow pointing from the Training Loop and into Distiller's \nScheduler\n, which invokes the correct algorithm. The application also uses Distiller services to collect statistics in \nSummaries\n and logs files, which can be queried at a later time, from Jupyter notebooks or TensorBoard.\n\n\n\n\nSparsification and fine-tuning\n\n\n\n\nThe application sets up a model as normally done in PyTorch.\n\n\nAnd then instantiates a Scheduler and configures it:\n\n\nScheduler configuration is defined in a YAML file\n\n\nThe configuration specifies Policies. Each Policy is tied to a specific algorithm which controls some aspect of the training.\n\n\nSome types of algorithms control the actual sparsification of the model. Such types are \"pruner\" and \"regularizer\".\n\n\nSome algorithms control some parameter of the training process, such as the learning-rate decay scheduler (\nlr_scheduler\n).\n\n\nThe parameters of each algorithm are also specified in the configuration.\n\n\n\n\n\n\n\n\n\n\nIn addition to specifying the algorithm, each Policy specifies scheduling parameters which control when the algorithm is executed: start epoch, end epoch and frequency.\n\n\nThe Scheduler exposes callbacks for relevant training stages: epoch start/end, mini-batch start/end and pre-backward pass. Each scheduler callback activates the policies that were defined according the schedule that was defined.\n\n\nThese callbacks are placed the training loop.\n\n\n\n\nQuantization\n\n\nA quantized model is obtained by replacing existing operations with quantized versions. The quantized versions can be either complete replacements, or wrappers. A wrapper will use the existing modules internally and add quantization and de-quantization operations before/after as necessary.\n\n\nIn Distiller we will provide a set of quantized versions of common operations which will enable implementation of different quantization methods. The user can write a quantized model from scratch, using the quantized operations provided.\n\n\nWe also provide a mechanism which takes an existing model and automatically replaces required operations with quantized versions. This mechanism is exposed by the \nQuantizer\n class. \nQuantizer\n should be sub-classed for each quantization method.\n\n\nModel Transformation\n\n\nThe high-level flow is as follows:\n\n\n\n\nDefine a \nmapping\n between the module types to be replaced (e.g. Conv2D, Linear, etc.) to a function which generates the replacement module. The mapping is defined in the \nreplacement_factory\n attribute of the \nQuantizer\n class.\n\n\nIterate over the modules defined in the model. For each module, if its type is in the mapping, call the replacement generation function. We pass the existing module to this function to allow wrapping of it.\n\n\nReplace the existing module with the module returned by the function. 
It is important to note that the \nname\n of the module \ndoes not\n change, as that could break the \nforward\n function of the parent module.\n\n\n\n\nDifferent quantization methods may, obviously, use different quantized operations. In addition, different methods may employ different \"strategies\" of replacing / wrapping existing modules. For instance, some methods replace ReLU with another activation function, while others keep it. Hence, for each quantization method, a different \nmapping\n will likely be defined.\n\nEach sub-class of \nQuantizer\n should populate the \nreplacement_factory\n dictionary attribute with the appropriate mapping.\n\nTo execute the model transformation, call the \nprepare_model\n function of the \nQuantizer\n instance.\n\n\nFlexible Bit-Widths\n\n\n\n\nEach instance of \nQuantizer\n is parameterized by the number of bits to be used for quantization of different tensor types. The default ones are activations and weights. These are the \nbits_activations\n and \nbits_weights\n parameters in \nQuantizer\n's constructor. Sub-classes may define bit-widths for other tensor types as needed.\n\n\nWe also want to be able to override the default number of bits mentioned in the bullet above for certain layers. These could be very specific layers. However, many models are comprised of building blocks (\"container\" modules, such as Sequential) which contain several modules, and it is likely we'll want to override settings for entire blocks, or for a certain module across different blocks. When such building blocks are used, the names of the internal modules usually follow some pattern.\n\n\nSo, for this purpose, Quantizer also accepts a mapping of regular expressions to number of bits. This allows the user to override specific layers using they're exact name, or a group of layers via a regular expression. This mapping is passed via the \nbits_overrides\n parameter in the constructor.\n\n\nThe \nbits_overrides\n mapping is required to be an instance of \ncollections.OrderedDict\n (as opposed to just a simple Python \ndict\n). This is done in order to enable handling of overlapping name patterns.\n\n So, for example, one could define certain override parameters for a group of layers, e.g. 'conv*', but also define different parameters for specific layers in that group, e.g. 'conv1'.\n\n The patterns are evaluated eagerly - the first match wins. Therefore, the more specific patterns must come before the broad patterns.\n\n\n\n\nWeights Quantization\n\n\nThe \nQuantizer\n class also provides an API to quantize the weights of all layers at once. To use it, the \nparam_quantization_fn\n attribute needs to point to a function that accepts a tensor and the number of bits. During model transformation, the \nQuantizer\n class will build a list of all model parameters that need to be quantized along with their bit-width. Then, the \nquantize_params\n function can be called, which will iterate over all parameters and quantize them using \nparams_quantization_fn\n.\n\n\nTraining with Quantization\n\n\nThe \nQuantizer\n class supports training with quantization in the loop. This requires handling of a couple of flows / scenarios:\n\n\n\n\n\n\nMaintaining a full precision copy of the weights, as described \nhere\n. This is enabled by setting \ntrain_with_fp_copy=True\n in the \nQuantizer\n constructor. 
At model transformation, in each module that has parameters that should be quantized, a new \ntorch.nn.Parameter\n is added, which will maintain the required full precision copy of the parameters. Note that this is done in-place - a new module \nis not\n created. We preferred not to sub-class the existing PyTorch modules for this purpose. In order to this in-place, and also guarantee proper back-propagation through the weights quantization function, we employ the following \"hack\": \n\n\n\n\nThe existing \ntorch.nn.Parameter\n, e.g. \nweights\n, is replaced by a \ntorch.nn.Parameter\n named \nfloat_weight\n.\n\n\nTo maintain the existing functionality of the module, we then register a \nbuffer\n in the module with the original name - \nweights\n.\n\n\nDuring training, \nfloat_weight\n will be passed to \nparam_quantization_fn\n and the result will be stored in \nweight\n.\n\n\n\n\n\n\n\n\nIn addition, some quantization methods may introduce additional learned parameters to the model. For example, in the \nPACT\n method, acitvations are clipped to a value \n\\alpha\n, which is a learned parameter per-layer\n\n\n\n\n\n\nTo support these two cases, the \nQuantizer\n class also accepts an instance of a \ntorch.optim.Optimizer\n (normally this would be one an instance of its sub-classes). The quantizer will take care of modifying the optimizer according to the changes made to the parameters. \n\n\n\n\nOptimizing New Parameters\n\n\nIn cases where new parameters are required by the scheme, it is likely that they'll need to be optimized separately from the main model parameters. In that case, the sub-class for the speicifc method should override \nQuantizer._get_updated_optimizer_params_groups()\n, and return the proper groups plus any desired hyper-parameter overrides.\n\n\n\n\nExamples\n\n\nThe base \nQuantizer\n class is implemented in \ndistiller/quantization/quantizer.py\n.\n\nFor a simple sub-class implementing symmetric linear quantization, see \nSymmetricLinearQuantizer\n in \ndistiller/quantization/range_linear.py\n.\n\nIn \ndistiller/quantization/clipped_linear.py\n there are examples of lower-precision methods which use training with quantization. Specifically, see \nPACTQuantizer\n for an example of overriding \nQuantizer._get_updated_optimizer_params_groups()\n.", + "location": "/design/index.html", + "text": "Distiller design\n\n\nDistiller is designed to be easily integrated into your own PyTorch research applications.\n\nIt is easiest to understand this integration by examining the code of the sample application for compressing image classification models (\ncompress_classifier.py\n).\n\n\nThe application borrows its main flow code from torchvision's ImageNet classification training sample application (https://github.com/pytorch/examples/tree/master/imagenet). We tried to keep it similar, in order to make it familiar and easy to understand.\n\n\nIntegrating compression is very simple: simply add invocations of the appropriate compression_scheduler callbacks, for each stage in the training. The training skeleton looks like the pseudo code below. 
The boiler-plate Pytorch classification training is speckled with invocations of CompressionScheduler.\n\n\nFor each epoch:\n compression_scheduler.on_epoch_begin(epoch)\n train()\n validate()\n save_checkpoint()\n compression_scheduler.on_epoch_end(epoch)\n\ntrain():\n For each training step:\n compression_scheduler.on_minibatch_begin(epoch)\n output = model(input_var)\n loss = criterion(output, target_var)\n compression_scheduler.before_backward_pass(epoch)\n loss.backward()\n optimizer.step()\n compression_scheduler.on_minibatch_end(epoch)\n\n\n\n\nThese callbacks can be seen in the diagram below, as the arrow pointing from the Training Loop and into Distiller's \nScheduler\n, which invokes the correct algorithm. The application also uses Distiller services to collect statistics in \nSummaries\n and logs files, which can be queried at a later time, from Jupyter notebooks or TensorBoard.\n\n\n\n\nSparsification and fine-tuning\n\n\n\n\nThe application sets up a model as normally done in PyTorch.\n\n\nAnd then instantiates a Scheduler and configures it:\n\n\nScheduler configuration is defined in a YAML file\n\n\nThe configuration specifies Policies. Each Policy is tied to a specific algorithm which controls some aspect of the training.\n\n\nSome types of algorithms control the actual sparsification of the model. Such types are \"pruner\" and \"regularizer\".\n\n\nSome algorithms control some parameter of the training process, such as the learning-rate decay scheduler (\nlr_scheduler\n).\n\n\nThe parameters of each algorithm are also specified in the configuration.\n\n\n\n\n\n\n\n\n\n\nIn addition to specifying the algorithm, each Policy specifies scheduling parameters which control when the algorithm is executed: start epoch, end epoch and frequency.\n\n\nThe Scheduler exposes callbacks for relevant training stages: epoch start/end, mini-batch start/end and pre-backward pass. Each scheduler callback activates the policies that were defined according the schedule that was defined.\n\n\nThese callbacks are placed the training loop.\n\n\n\n\nQuantization\n\n\nA quantized model is obtained by replacing existing operations with quantized versions. The quantized versions can be either complete replacements, or wrappers. A wrapper will use the existing modules internally and add quantization and de-quantization operations before/after as necessary.\n\n\nIn Distiller we will provide a set of quantized versions of common operations which will enable implementation of different quantization methods. The user can write a quantized model from scratch, using the quantized operations provided.\n\n\nWe also provide a mechanism which takes an existing model and automatically replaces required operations with quantized versions. This mechanism is exposed by the \nQuantizer\n class. \nQuantizer\n should be sub-classed for each quantization method.\n\n\nModel Transformation\n\n\nThe high-level flow is as follows:\n\n\n\n\nDefine a \nmapping\n between the module types to be replaced (e.g. Conv2D, Linear, etc.) to a function which generates the replacement module. The mapping is defined in the \nreplacement_factory\n attribute of the \nQuantizer\n class.\n\n\nIterate over the modules defined in the model. For each module, if its type is in the mapping, call the replacement generation function. We pass the existing module to this function to allow wrapping of it.\n\n\nReplace the existing module with the module returned by the function. 
It is important to note that the \nname\n of the module \ndoes not\n change, as that could break the \nforward\n function of the parent module.\n\n\n\n\nDifferent quantization methods may, obviously, use different quantized operations. In addition, different methods may employ different \"strategies\" of replacing / wrapping existing modules. For instance, some methods replace ReLU with another activation function, while others keep it. Hence, for each quantization method, a different \nmapping\n will likely be defined.\n\nEach sub-class of \nQuantizer\n should populate the \nreplacement_factory\n dictionary attribute with the appropriate mapping.\n\nTo execute the model transformation, call the \nprepare_model\n function of the \nQuantizer\n instance.\n\n\nFlexible Bit-Widths\n\n\n\n\nEach instance of \nQuantizer\n is parameterized by the number of bits to be used for quantization of different tensor types. The default ones are activations and weights. These are the \nbits_activations\n and \nbits_weights\n parameters in \nQuantizer\n's constructor. Sub-classes may define bit-widths for other tensor types as needed.\n\n\nWe also want to be able to override the default number of bits mentioned in the bullet above for certain layers. These could be very specific layers. However, many models are comprised of building blocks (\"container\" modules, such as Sequential) which contain several modules, and it is likely we'll want to override settings for entire blocks, or for a certain module across different blocks. When such building blocks are used, the names of the internal modules usually follow some pattern.\n\n\nSo, for this purpose, Quantizer also accepts a mapping of regular expressions to number of bits. This allows the user to override specific layers using they're exact name, or a group of layers via a regular expression. This mapping is passed via the \nbits_overrides\n parameter in the constructor.\n\n\nThe \nbits_overrides\n mapping is required to be an instance of \ncollections.OrderedDict\n (as opposed to just a simple Python \ndict\n). This is done in order to enable handling of overlapping name patterns.\n\n So, for example, one could define certain override parameters for a group of layers, e.g. 'conv*', but also define different parameters for specific layers in that group, e.g. 'conv1'.\n\n The patterns are evaluated eagerly - the first match wins. Therefore, the more specific patterns must come before the broad patterns.\n\n\n\n\nWeights Quantization\n\n\nThe \nQuantizer\n class also provides an API to quantize the weights of all layers at once. To use it, the \nparam_quantization_fn\n attribute needs to point to a function that accepts a tensor and the number of bits. During model transformation, the \nQuantizer\n class will build a list of all model parameters that need to be quantized along with their bit-width. Then, the \nquantize_params\n function can be called, which will iterate over all parameters and quantize them using \nparams_quantization_fn\n.\n\n\nQuantization-Aware Training\n\n\nThe \nQuantizer\n class supports quantization-aware training, that is - training with quantization in the loop. This requires handling of a couple of flows / scenarios:\n\n\n\n\n\n\nMaintaining a full precision copy of the weights, as described \nhere\n. This is enabled by setting \ntrain_with_fp_copy=True\n in the \nQuantizer\n constructor. 
At model transformation, in each module that has parameters that should be quantized, a new \ntorch.nn.Parameter\n is added, which will maintain the required full precision copy of the parameters. Note that this is done in-place - a new module \nis not\n created. We preferred not to sub-class the existing PyTorch modules for this purpose. In order to this in-place, and also guarantee proper back-propagation through the weights quantization function, we employ the following \"hack\": \n\n\n\n\nThe existing \ntorch.nn.Parameter\n, e.g. \nweights\n, is replaced by a \ntorch.nn.Parameter\n named \nfloat_weight\n.\n\n\nTo maintain the existing functionality of the module, we then register a \nbuffer\n in the module with the original name - \nweights\n.\n\n\nDuring training, \nfloat_weight\n will be passed to \nparam_quantization_fn\n and the result will be stored in \nweight\n.\n\n\n\n\n\n\n\n\nIn addition, some quantization methods may introduce additional learned parameters to the model. For example, in the \nPACT\n method, acitvations are clipped to a value \n\\alpha\n, which is a learned parameter per-layer\n\n\n\n\n\n\nTo support these two cases, the \nQuantizer\n class also accepts an instance of a \ntorch.optim.Optimizer\n (normally this would be one an instance of its sub-classes). The quantizer will take care of modifying the optimizer according to the changes made to the parameters. \n\n\n\n\nOptimizing New Parameters\n\n\nIn cases where new parameters are required by the scheme, it is likely that they'll need to be optimized separately from the main model parameters. In that case, the sub-class for the speicifc method should override \nQuantizer._get_updated_optimizer_params_groups()\n, and return the proper groups plus any desired hyper-parameter overrides.\n\n\n\n\nExamples\n\n\nThe base \nQuantizer\n class is implemented in \ndistiller/quantization/quantizer.py\n.\n\nFor a simple sub-class implementing symmetric linear quantization, see \nSymmetricLinearQuantizer\n in \ndistiller/quantization/range_linear.py\n.\n\nIn \ndistiller/quantization/clipped_linear.py\n there are examples of lower-precision methods which use training with quantization. Specifically, see \nPACTQuantizer\n for an example of overriding \nQuantizer._get_updated_optimizer_params_groups()\n.", "title": "Design" - }, + }, { - "location": "/design/index.html#distiller-design", - "text": "Distiller is designed to be easily integrated into your own PyTorch research applications. \nIt is easiest to understand this integration by examining the code of the sample application for compressing image classification models ( compress_classifier.py ). The application borrows its main flow code from torchvision's ImageNet classification training sample application (https://github.com/pytorch/examples/tree/master/imagenet). We tried to keep it similar, in order to make it familiar and easy to understand. Integrating compression is very simple: simply add invocations of the appropriate compression_scheduler callbacks, for each stage in the training. The training skeleton looks like the pseudo code below. The boiler-plate Pytorch classification training is speckled with invocations of CompressionScheduler. 
For each epoch:\n compression_scheduler.on_epoch_begin(epoch)\n train()\n validate()\n save_checkpoint()\n compression_scheduler.on_epoch_end(epoch)\n\ntrain():\n For each training step:\n compression_scheduler.on_minibatch_begin(epoch)\n output = model(input_var)\n loss = criterion(output, target_var)\n compression_scheduler.before_backward_pass(epoch)\n loss.backward()\n optimizer.step()\n compression_scheduler.on_minibatch_end(epoch) These callbacks can be seen in the diagram below, as the arrow pointing from the Training Loop and into Distiller's Scheduler , which invokes the correct algorithm. The application also uses Distiller services to collect statistics in Summaries and logs files, which can be queried at a later time, from Jupyter notebooks or TensorBoard.", + "location": "/design/index.html#distiller-design", + "text": "Distiller is designed to be easily integrated into your own PyTorch research applications. \nIt is easiest to understand this integration by examining the code of the sample application for compressing image classification models ( compress_classifier.py ). The application borrows its main flow code from torchvision's ImageNet classification training sample application (https://github.com/pytorch/examples/tree/master/imagenet). We tried to keep it similar, in order to make it familiar and easy to understand. Integrating compression is very simple: simply add invocations of the appropriate compression_scheduler callbacks, for each stage in the training. The training skeleton looks like the pseudo code below. The boiler-plate Pytorch classification training is speckled with invocations of CompressionScheduler. For each epoch:\n compression_scheduler.on_epoch_begin(epoch)\n train()\n validate()\n save_checkpoint()\n compression_scheduler.on_epoch_end(epoch)\n\ntrain():\n For each training step:\n compression_scheduler.on_minibatch_begin(epoch)\n output = model(input_var)\n loss = criterion(output, target_var)\n compression_scheduler.before_backward_pass(epoch)\n loss.backward()\n optimizer.step()\n compression_scheduler.on_minibatch_end(epoch) These callbacks can be seen in the diagram below, as the arrow pointing from the Training Loop and into Distiller's Scheduler , which invokes the correct algorithm. The application also uses Distiller services to collect statistics in Summaries and logs files, which can be queried at a later time, from Jupyter notebooks or TensorBoard.", "title": "Distiller design" - }, + }, { - "location": "/design/index.html#sparsification-and-fine-tuning", - "text": "The application sets up a model as normally done in PyTorch. And then instantiates a Scheduler and configures it: Scheduler configuration is defined in a YAML file The configuration specifies Policies. Each Policy is tied to a specific algorithm which controls some aspect of the training. Some types of algorithms control the actual sparsification of the model. Such types are \"pruner\" and \"regularizer\". Some algorithms control some parameter of the training process, such as the learning-rate decay scheduler ( lr_scheduler ). The parameters of each algorithm are also specified in the configuration. In addition to specifying the algorithm, each Policy specifies scheduling parameters which control when the algorithm is executed: start epoch, end epoch and frequency. The Scheduler exposes callbacks for relevant training stages: epoch start/end, mini-batch start/end and pre-backward pass. 
Each scheduler callback activates the policies that were defined according the schedule that was defined. These callbacks are placed the training loop.", + "location": "/design/index.html#sparsification-and-fine-tuning", + "text": "The application sets up a model as normally done in PyTorch. And then instantiates a Scheduler and configures it: Scheduler configuration is defined in a YAML file The configuration specifies Policies. Each Policy is tied to a specific algorithm which controls some aspect of the training. Some types of algorithms control the actual sparsification of the model. Such types are \"pruner\" and \"regularizer\". Some algorithms control some parameter of the training process, such as the learning-rate decay scheduler ( lr_scheduler ). The parameters of each algorithm are also specified in the configuration. In addition to specifying the algorithm, each Policy specifies scheduling parameters which control when the algorithm is executed: start epoch, end epoch and frequency. The Scheduler exposes callbacks for relevant training stages: epoch start/end, mini-batch start/end and pre-backward pass. Each scheduler callback activates the policies that were defined according the schedule that was defined. These callbacks are placed the training loop.", "title": "Sparsification and fine-tuning" - }, + }, { - "location": "/design/index.html#quantization", - "text": "A quantized model is obtained by replacing existing operations with quantized versions. The quantized versions can be either complete replacements, or wrappers. A wrapper will use the existing modules internally and add quantization and de-quantization operations before/after as necessary. In Distiller we will provide a set of quantized versions of common operations which will enable implementation of different quantization methods. The user can write a quantized model from scratch, using the quantized operations provided. We also provide a mechanism which takes an existing model and automatically replaces required operations with quantized versions. This mechanism is exposed by the Quantizer class. Quantizer should be sub-classed for each quantization method.", + "location": "/design/index.html#quantization", + "text": "A quantized model is obtained by replacing existing operations with quantized versions. The quantized versions can be either complete replacements, or wrappers. A wrapper will use the existing modules internally and add quantization and de-quantization operations before/after as necessary. In Distiller we will provide a set of quantized versions of common operations which will enable implementation of different quantization methods. The user can write a quantized model from scratch, using the quantized operations provided. We also provide a mechanism which takes an existing model and automatically replaces required operations with quantized versions. This mechanism is exposed by the Quantizer class. Quantizer should be sub-classed for each quantization method.", "title": "Quantization" - }, + }, { - "location": "/design/index.html#model-transformation", - "text": "The high-level flow is as follows: Define a mapping between the module types to be replaced (e.g. Conv2D, Linear, etc.) to a function which generates the replacement module. The mapping is defined in the replacement_factory attribute of the Quantizer class. Iterate over the modules defined in the model. For each module, if its type is in the mapping, call the replacement generation function. 
We pass the existing module to this function to allow wrapping of it. Replace the existing module with the module returned by the function. It is important to note that the name of the module does not change, as that could break the forward function of the parent module. Different quantization methods may, obviously, use different quantized operations. In addition, different methods may employ different \"strategies\" of replacing / wrapping existing modules. For instance, some methods replace ReLU with another activation function, while others keep it. Hence, for each quantization method, a different mapping will likely be defined. \nEach sub-class of Quantizer should populate the replacement_factory dictionary attribute with the appropriate mapping. \nTo execute the model transformation, call the prepare_model function of the Quantizer instance.", + "location": "/design/index.html#model-transformation", + "text": "The high-level flow is as follows: Define a mapping between the module types to be replaced (e.g. Conv2D, Linear, etc.) to a function which generates the replacement module. The mapping is defined in the replacement_factory attribute of the Quantizer class. Iterate over the modules defined in the model. For each module, if its type is in the mapping, call the replacement generation function. We pass the existing module to this function to allow wrapping of it. Replace the existing module with the module returned by the function. It is important to note that the name of the module does not change, as that could break the forward function of the parent module. Different quantization methods may, obviously, use different quantized operations. In addition, different methods may employ different \"strategies\" of replacing / wrapping existing modules. For instance, some methods replace ReLU with another activation function, while others keep it. Hence, for each quantization method, a different mapping will likely be defined. \nEach sub-class of Quantizer should populate the replacement_factory dictionary attribute with the appropriate mapping. \nTo execute the model transformation, call the prepare_model function of the Quantizer instance.", "title": "Model Transformation" - }, + }, { - "location": "/design/index.html#flexible-bit-widths", - "text": "Each instance of Quantizer is parameterized by the number of bits to be used for quantization of different tensor types. The default ones are activations and weights. These are the bits_activations and bits_weights parameters in Quantizer 's constructor. Sub-classes may define bit-widths for other tensor types as needed. We also want to be able to override the default number of bits mentioned in the bullet above for certain layers. These could be very specific layers. However, many models are comprised of building blocks (\"container\" modules, such as Sequential) which contain several modules, and it is likely we'll want to override settings for entire blocks, or for a certain module across different blocks. When such building blocks are used, the names of the internal modules usually follow some pattern. So, for this purpose, Quantizer also accepts a mapping of regular expressions to number of bits. This allows the user to override specific layers using they're exact name, or a group of layers via a regular expression. This mapping is passed via the bits_overrides parameter in the constructor. The bits_overrides mapping is required to be an instance of collections.OrderedDict (as opposed to just a simple Python dict ). 
This is done in order to enable handling of overlapping name patterns. \n So, for example, one could define certain override parameters for a group of layers, e.g. 'conv*', but also define different parameters for specific layers in that group, e.g. 'conv1'. \n The patterns are evaluated eagerly - the first match wins. Therefore, the more specific patterns must come before the broad patterns.", + "location": "/design/index.html#flexible-bit-widths", + "text": "Each instance of Quantizer is parameterized by the number of bits to be used for quantization of different tensor types. The default ones are activations and weights. These are the bits_activations and bits_weights parameters in Quantizer 's constructor. Sub-classes may define bit-widths for other tensor types as needed. We also want to be able to override the default number of bits mentioned in the bullet above for certain layers. These could be very specific layers. However, many models are comprised of building blocks (\"container\" modules, such as Sequential) which contain several modules, and it is likely we'll want to override settings for entire blocks, or for a certain module across different blocks. When such building blocks are used, the names of the internal modules usually follow some pattern. So, for this purpose, Quantizer also accepts a mapping of regular expressions to number of bits. This allows the user to override specific layers using they're exact name, or a group of layers via a regular expression. This mapping is passed via the bits_overrides parameter in the constructor. The bits_overrides mapping is required to be an instance of collections.OrderedDict (as opposed to just a simple Python dict ). This is done in order to enable handling of overlapping name patterns. \n So, for example, one could define certain override parameters for a group of layers, e.g. 'conv*', but also define different parameters for specific layers in that group, e.g. 'conv1'. \n The patterns are evaluated eagerly - the first match wins. Therefore, the more specific patterns must come before the broad patterns.", "title": "Flexible Bit-Widths" - }, + }, { - "location": "/design/index.html#weights-quantization", - "text": "The Quantizer class also provides an API to quantize the weights of all layers at once. To use it, the param_quantization_fn attribute needs to point to a function that accepts a tensor and the number of bits. During model transformation, the Quantizer class will build a list of all model parameters that need to be quantized along with their bit-width. Then, the quantize_params function can be called, which will iterate over all parameters and quantize them using params_quantization_fn .", + "location": "/design/index.html#weights-quantization", + "text": "The Quantizer class also provides an API to quantize the weights of all layers at once. To use it, the param_quantization_fn attribute needs to point to a function that accepts a tensor and the number of bits. During model transformation, the Quantizer class will build a list of all model parameters that need to be quantized along with their bit-width. Then, the quantize_params function can be called, which will iterate over all parameters and quantize them using params_quantization_fn .", "title": "Weights Quantization" - }, + }, { - "location": "/design/index.html#training-with-quantization", - "text": "The Quantizer class supports training with quantization in the loop. 
This requires handling of a couple of flows / scenarios: Maintaining a full precision copy of the weights, as described here . This is enabled by setting train_with_fp_copy=True in the Quantizer constructor. At model transformation, in each module that has parameters that should be quantized, a new torch.nn.Parameter is added, which will maintain the required full precision copy of the parameters. Note that this is done in-place - a new module is not created. We preferred not to sub-class the existing PyTorch modules for this purpose. In order to this in-place, and also guarantee proper back-propagation through the weights quantization function, we employ the following \"hack\": The existing torch.nn.Parameter , e.g. weights , is replaced by a torch.nn.Parameter named float_weight . To maintain the existing functionality of the module, we then register a buffer in the module with the original name - weights . During training, float_weight will be passed to param_quantization_fn and the result will be stored in weight . In addition, some quantization methods may introduce additional learned parameters to the model. For example, in the PACT method, acitvations are clipped to a value \\alpha , which is a learned parameter per-layer To support these two cases, the Quantizer class also accepts an instance of a torch.optim.Optimizer (normally this would be one an instance of its sub-classes). The quantizer will take care of modifying the optimizer according to the changes made to the parameters. Optimizing New Parameters In cases where new parameters are required by the scheme, it is likely that they'll need to be optimized separately from the main model parameters. In that case, the sub-class for the speicifc method should override Quantizer._get_updated_optimizer_params_groups() , and return the proper groups plus any desired hyper-parameter overrides.", - "title": "Training with Quantization" - }, + "location": "/design/index.html#quantization-aware-training", + "text": "The Quantizer class supports quantization-aware training, that is - training with quantization in the loop. This requires handling of a couple of flows / scenarios: Maintaining a full precision copy of the weights, as described here . This is enabled by setting train_with_fp_copy=True in the Quantizer constructor. At model transformation, in each module that has parameters that should be quantized, a new torch.nn.Parameter is added, which will maintain the required full precision copy of the parameters. Note that this is done in-place - a new module is not created. We preferred not to sub-class the existing PyTorch modules for this purpose. In order to this in-place, and also guarantee proper back-propagation through the weights quantization function, we employ the following \"hack\": The existing torch.nn.Parameter , e.g. weights , is replaced by a torch.nn.Parameter named float_weight . To maintain the existing functionality of the module, we then register a buffer in the module with the original name - weights . During training, float_weight will be passed to param_quantization_fn and the result will be stored in weight . In addition, some quantization methods may introduce additional learned parameters to the model. For example, in the PACT method, acitvations are clipped to a value \\alpha , which is a learned parameter per-layer To support these two cases, the Quantizer class also accepts an instance of a torch.optim.Optimizer (normally this would be one an instance of its sub-classes). 
The quantizer will take care of modifying the optimizer according to the changes made to the parameters. Optimizing New Parameters In cases where new parameters are required by the scheme, it is likely that they'll need to be optimized separately from the main model parameters. In that case, the sub-class for the speicifc method should override Quantizer._get_updated_optimizer_params_groups() , and return the proper groups plus any desired hyper-parameter overrides.", + "title": "Quantization-Aware Training" + }, { - "location": "/design/index.html#examples", - "text": "The base Quantizer class is implemented in distiller/quantization/quantizer.py . \nFor a simple sub-class implementing symmetric linear quantization, see SymmetricLinearQuantizer in distiller/quantization/range_linear.py . \nIn distiller/quantization/clipped_linear.py there are examples of lower-precision methods which use training with quantization. Specifically, see PACTQuantizer for an example of overriding Quantizer._get_updated_optimizer_params_groups() .", + "location": "/design/index.html#examples", + "text": "The base Quantizer class is implemented in distiller/quantization/quantizer.py . \nFor a simple sub-class implementing symmetric linear quantization, see SymmetricLinearQuantizer in distiller/quantization/range_linear.py . \nIn distiller/quantization/clipped_linear.py there are examples of lower-precision methods which use training with quantization. Specifically, see PACTQuantizer for an example of overriding Quantizer._get_updated_optimizer_params_groups() .", "title": "Examples" } ] diff --git a/docs/sitemap.xml b/docs/sitemap.xml index 83779119a2842ae5a478419ebaa1890d6e617690..4581df54aeb8c6653ce9f73ba8a4b5a507844ec0 100644 --- a/docs/sitemap.xml +++ b/docs/sitemap.xml @@ -4,7 +4,7 @@ <url> <loc>/index.html</loc> - <lastmod>2018-11-25</lastmod> + <lastmod>2018-12-04</lastmod> <changefreq>daily</changefreq> </url> @@ -12,7 +12,7 @@ <url> <loc>/install/index.html</loc> - <lastmod>2018-11-25</lastmod> + <lastmod>2018-12-04</lastmod> <changefreq>daily</changefreq> </url> @@ -20,7 +20,7 @@ <url> <loc>/usage/index.html</loc> - <lastmod>2018-11-25</lastmod> + <lastmod>2018-12-04</lastmod> <changefreq>daily</changefreq> </url> @@ -28,7 +28,7 @@ <url> <loc>/schedule/index.html</loc> - <lastmod>2018-11-25</lastmod> + <lastmod>2018-12-04</lastmod> <changefreq>daily</changefreq> </url> @@ -37,31 +37,31 @@ <url> <loc>/pruning/index.html</loc> - <lastmod>2018-11-25</lastmod> + <lastmod>2018-12-04</lastmod> <changefreq>daily</changefreq> </url> <url> <loc>/regularization/index.html</loc> - <lastmod>2018-11-25</lastmod> + <lastmod>2018-12-04</lastmod> <changefreq>daily</changefreq> </url> <url> <loc>/quantization/index.html</loc> - <lastmod>2018-11-25</lastmod> + <lastmod>2018-12-04</lastmod> <changefreq>daily</changefreq> </url> <url> <loc>/knowledge_distillation/index.html</loc> - <lastmod>2018-11-25</lastmod> + <lastmod>2018-12-04</lastmod> <changefreq>daily</changefreq> </url> <url> <loc>/conditional_computation/index.html</loc> - <lastmod>2018-11-25</lastmod> + <lastmod>2018-12-04</lastmod> <changefreq>daily</changefreq> </url> @@ -71,19 +71,19 @@ <url> <loc>/algo_pruning/index.html</loc> - <lastmod>2018-11-25</lastmod> + <lastmod>2018-12-04</lastmod> <changefreq>daily</changefreq> </url> <url> <loc>/algo_quantization/index.html</loc> - <lastmod>2018-11-25</lastmod> + <lastmod>2018-12-04</lastmod> <changefreq>daily</changefreq> </url> <url> <loc>/algo_earlyexit/index.html</loc> - <lastmod>2018-11-25</lastmod> + 
<lastmod>2018-12-04</lastmod> <changefreq>daily</changefreq> </url> @@ -92,7 +92,7 @@ <url> <loc>/model_zoo/index.html</loc> - <lastmod>2018-11-25</lastmod> + <lastmod>2018-12-04</lastmod> <changefreq>daily</changefreq> </url> @@ -100,7 +100,7 @@ <url> <loc>/jupyter/index.html</loc> - <lastmod>2018-11-25</lastmod> + <lastmod>2018-12-04</lastmod> <changefreq>daily</changefreq> </url> @@ -108,7 +108,7 @@ <url> <loc>/design/index.html</loc> - <lastmod>2018-11-25</lastmod> + <lastmod>2018-12-04</lastmod> <changefreq>daily</changefreq> </url> diff --git a/docs/usage/index.html b/docs/usage/index.html index a9d53fdcde0a0748dd1c85e71cd71f0f3cfcc096..20434be505f9d37f1dcae28cf7450b337d10422c 100644 --- a/docs/usage/index.html +++ b/docs/usage/index.html @@ -75,7 +75,7 @@ <li><a class="toctree-l3" href="#performing-pruning-sensitivity-analysis">Performing pruning sensitivity analysis</a></li> - <li><a class="toctree-l3" href="#direct-quantization-without-training">"Direct" Quantization Without Training</a></li> + <li><a class="toctree-l3" href="#post-training-quantization">Post-Training Quantization</a></li> <li><a class="toctree-l3" href="#summaries">Summaries</a></li> @@ -336,38 +336,39 @@ Results are output as a CSV file (<code>sensitivity.csv</code>) and PNG file (<c <p>The <code>sense</code> command-line argument can be set to either <code>element</code> or <code>filter</code>, depending on the type of analysis you want done.<br></p> <p>There is also a <a href="http://localhost:8888/notebooks/sensitivity_analysis.ipynb">Jupyter notebook</a> with example invocations, outputs and explanations.</p> -<h2 id="direct-quantization-without-training">"Direct" Quantization Without Training</h2> -<p>Distiller supports 8-bit quantization of trained modules without re-training (using <a href="../algo_quantization/index.html#symmetric-linear-quantization">Symmetric Linear Quantization</a>). So, any model (whether pruned or not) can be quantized.<br /> -Use the <code>--quantize</code> command-line flag, together with <code>--evaluate</code> to evaluate the accuracy of your model after quantization. The following example qunatizes ResNet18 for ImageNet:</p> -<pre><code>$ python3 compress_classifier.py -a resnet18 ../../../data.imagenet --pretrained --quantize --evaluate +<h2 id="post-training-quantization">Post-Training Quantization</h2> +<p>Distiller supports post-training quantization of trained modules without re-training (using <a href="../algo_quantization/index.html#range-based-linear-quantization">Range-Based Linear Quantization</a>). So, any model (whether pruned or not) can be quantized. To invoke post-training quantization, use <code>--quantize-eval</code> along with <code>--evaluate</code>. Additional arguments are available to control parameters of the quantization:</p> +<pre><code>Arguments controlling quantization at evaluation time("post-training quantization"): + --quantize-eval, --qe + Apply linear quantization to model before evaluation. + Applicable only if --evaluate is also set + --qe-mode QE_MODE, --qem QE_MODE + Linear quantization mode. 
Choices: asym_s | asym_u | + sym + --qe-bits-acts NUM_BITS, --qeba NUM_BITS + Number of bits for quantization of activations + --qe-bits-wts NUM_BITS, --qebw NUM_BITS + Number of bits for quantization of weights + --qe-bits-accum NUM_BITS + Number of bits for quantization of the accumulator + --qe-clip-acts, --qeca + Enable clipping of activations using min/max values + averaging over batch + --qe-no-clip-layers LAYER_NAME [LAYER_NAME ...], --qencl LAYER_NAME [LAYER_NAME ...] + List of fully-qualified layer names for which not to + clip activations. Applicable only if --qe-clip-acts is + also set + --qe-per-channel, --qepc + Enable per-channel quantization of weights (per output channel) + </code></pre> -<p>Generates:</p> -<pre><code>Preparing model for quantization ---- test --------------------- -50000 samples (256 per mini-batch) -Test: [ 10/ 195] Loss 0.856354 Top1 79.257812 Top5 92.500000 -Test: [ 20/ 195] Loss 0.923131 Top1 76.953125 Top5 92.246094 -Test: [ 30/ 195] Loss 0.885186 Top1 77.955729 Top5 92.486979 -Test: [ 40/ 195] Loss 0.930263 Top1 76.181641 Top5 92.597656 -Test: [ 50/ 195] Loss 0.931062 Top1 75.726562 Top5 92.906250 -Test: [ 60/ 195] Loss 0.932019 Top1 75.651042 Top5 93.151042 -Test: [ 70/ 195] Loss 0.921287 Top1 76.060268 Top5 93.270089 -Test: [ 80/ 195] Loss 0.932539 Top1 75.986328 Top5 93.100586 -Test: [ 90/ 195] Loss 0.996000 Top1 74.700521 Top5 92.330729 -Test: [ 100/ 195] Loss 1.066699 Top1 73.289062 Top5 91.437500 -Test: [ 110/ 195] Loss 1.100970 Top1 72.574574 Top5 91.001420 -Test: [ 120/ 195] Loss 1.122376 Top1 72.268880 Top5 90.696615 -Test: [ 130/ 195] Loss 1.171726 Top1 71.198918 Top5 90.120192 -Test: [ 140/ 195] Loss 1.191500 Top1 70.797991 Top5 89.902344 -Test: [ 150/ 195] Loss 1.219954 Top1 70.210938 Top5 89.453125 -Test: [ 160/ 195] Loss 1.240942 Top1 69.855957 Top5 89.162598 -Test: [ 170/ 195] Loss 1.265741 Top1 69.342831 Top5 88.807445 -Test: [ 180/ 195] Loss 1.281185 Top1 69.051649 Top5 88.589410 -Test: [ 190/ 195] Loss 1.279682 Top1 69.019326 Top5 88.632812 -==> Top1: 69.130 Top5: 88.732 Loss: 1.276 +<p>The following example quantizes ResNet18 for ImageNet:</p> +<pre><code class="bash">$ python3 compress_classifier.py -a resnet18 ../../../data.imagenet --pretrained --quantize-eval --evaluate </code></pre> +<p>A checkpoint with the quantized model will be dumped in the run directory. It will contain the quantized model parameters (the data type will still be FP32, but the values will be integers). The calculated quantization parameters (scale and zero-point) are stored as well in each quantized layer.</p> +<p>For more examples of post-training quantization, see <a href="https://github.com/NervanaSystems/distiller/blob/master/examples/quantization/post_training_quant.md">here</a>.</p> <h2 id="summaries">Summaries</h2> <p>You can use the sample compression application to generate model summary reports, such as the attributes and compute summary report (see screen capture below). You can log sparsity statistics (written to console and CSV file), performance, optimizer and model information, and also create a PNG image of the DNN.
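For reference, the same post-training quantization flow can also be driven from code rather than through the sample's command line. The sketch below simply mirrors the positional `PostTrainLinearQuantizer` call made in `compress_classifier.py` (shown in the diff that follows); the torchvision ResNet-18 model and the random input are illustrative stand-ins, not part of the sample.

```python
# Sketch only: mirrors the PostTrainLinearQuantizer call in compress_classifier.py below.
# The torchvision model and the random input are illustrative assumptions.
import torch
import torchvision.models as models
from distiller import quantization

model = models.resnet18(pretrained=True)   # any trained FP32 model
model.cpu()                                # the sample quantizes on the CPU ...

quantizer = quantization.PostTrainLinearQuantizer(
    model,
    8,                                                  # bits for activations (--qe-bits-acts)
    8,                                                  # bits for weights (--qe-bits-wts)
    32,                                                 # bits for the accumulator (--qe-bits-accum)
    quantization.LinearQuantMode.ASYMMETRIC_UNSIGNED,   # --qe-mode asym_u
    False,                                              # clip activations (--qe-clip-acts)
    [],                                                 # layers excluded from clipping (--qe-no-clip-layers)
    True)                                               # per-channel weights (--qe-per-channel)
quantizer.prepare_model()                  # replace supported modules with quantized wrappers
model.cuda()                               # ... then move back to the GPU, as the sample does

model.eval()                               # the range-linear wrappers run in eval mode only
with torch.no_grad():                      # (the new tests show they raise RuntimeError in training mode)
    output = model(torch.randn(1, 3, 224, 224).cuda())
```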
diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py index 62f1af9e09e1faeb934e44ddf9ff59acfbc4249a..a4c9a98a806de0d7f2a508d98ce862adcf320be7 100755 --- a/examples/classifier_compression/compress_classifier.py +++ b/examples/classifier_compression/compress_classifier.py @@ -159,11 +159,26 @@ parser.add_argument('--num-best-scores', dest='num_best_scores', default=1, type parser.add_argument('--load-serialized', dest='load_serialized', action='store_true', default=False, help='Load a model without DataParallel wrapping it') +str_to_quant_mode_map = {'sym': quantization.LinearQuantMode.SYMMETRIC, + 'asym_s': quantization.LinearQuantMode.ASYMMETRIC_SIGNED, + 'asym_u': quantization.LinearQuantMode.ASYMMETRIC_UNSIGNED} + + +def linear_quant_mode_str(val_str): + try: + return str_to_quant_mode_map[val_str] + except KeyError: + raise argparse.ArgumentError('Must be one of {0} (received {1})'.format(list(str_to_quant_mode_map.keys()), + val_str)) + + quant_group = parser.add_argument_group('Arguments controlling quantization at evaluation time' '("post-training quantization)') quant_group.add_argument('--quantize-eval', '--qe', action='store_true', - help='Apply linear-symmetric quantization to model before evaluation. Applicable only if' + help='Apply linear quantization to model before evaluation. Applicable only if' '--evaluate is also set') +quant_group.add_argument('--qe-mode', '--qem', type=linear_quant_mode_str, default='sym', + help='Linear quantization mode. Choices: ' + ' | '.join(str_to_quant_mode_map.keys())) quant_group.add_argument('--qe-bits-acts', '--qeba', type=int, default=8, metavar='NUM_BITS', help='Number of bits for quantization of activations') quant_group.add_argument('--qe-bits-wts', '--qebw', type=int, default=8, metavar='NUM_BITS', @@ -171,10 +186,12 @@ quant_group.add_argument('--qe-bits-wts', '--qebw', type=int, default=8, metavar quant_group.add_argument('--qe-bits-accum', type=int, default=32, metavar='NUM_BITS', help='Number of bits for quantization of the accumulator') quant_group.add_argument('--qe-clip-acts', '--qeca', action='store_true', - help='Enable clipping of activations using max-abs-value averaging over batch') + help='Enable clipping of activations using min/max values averaging over batch') quant_group.add_argument('--qe-no-clip-layers', '--qencl', type=str, nargs='+', metavar='LAYER_NAME', default=[], - help='List of fully-qualified layer names for which not to clip activations. Applicable' - 'only if --qe-clip-acts is also set') + help='List of layer names for which not to clip activations. 
Applicable only if ' + '--qe-clip-acts is also set') +quant_group.add_argument('--qe-per-channel', '--qepc', action='store_true', + help='Enable per-channel quantization of weights (per output channel)') distiller.knowledge_distillation.add_distillation_args(parser, ALL_MODEL_NAMES, True) @@ -314,6 +331,7 @@ def main(): if args.resume: model, compression_scheduler, start_epoch = apputils.load_checkpoint( model, chkpt_file=args.resume) + model.cuda() # Define loss function (criterion) and optimizer criterion = nn.CrossEntropyLoss().cuda() @@ -716,9 +734,9 @@ def evaluate_model(model, criterion, test_loader, loggers, activations_collector if args.quantize_eval: model.cpu() - quantizer = quantization.SymmetricLinearQuantizer(model, args.qe_bits_acts, args.qe_bits_wts, - args.qe_bits_accum, args.qe_clip_acts, - args.qe_no_clip_layers) + quantizer = quantization.PostTrainLinearQuantizer(model, args.qe_bits_acts, args.qe_bits_wts, + args.qe_bits_accum, args.qe_mode, args.qe_clip_acts, + args.qe_no_clip_layers, args.qe_per_channel) quantizer.prepare_model() model.cuda() diff --git a/examples/quantization/post_training_quant.md b/examples/quantization/post_training_quant.md new file mode 100644 index 0000000000000000000000000000000000000000..413bbef56aafb6123fe10f441e16588fb0452198 --- /dev/null +++ b/examples/quantization/post_training_quant.md @@ -0,0 +1,51 @@ +# Post-Training Quantization Examples + +Following are a few examples of invoking post-training quantization on ResNet-50, using Distiller's image classification sample. Note that for post-training quantization we don't use a YAML schedule file; instead, we specify command line arguments. The available command line arguments are: + +| Long Form | Short | Description | Default | +|-----------------------|-----------|--------------------------------------------------------------------------|---------| +| `--quantize-eval` | `--qe` | Apply linear quantization to model before evaluation | Off | +| `--qe-mode` | `--qem` | Linear quantization mode. Choices: "sym", "asym_u", "asym_s" | "sym" | +| `--qe-bits-acts` | `--qeba` | # of bits for quantization of activations | 8 | +| `--qe-bits-wts` | `--qebw` | # of bits for quantization of weights | 8 | +| `--qe-bits-accum` | N/A | # of bits for quantization of the accumulator | 32 | +| `--qe-clip-acts` | `--qeca` | Enable clipping of activations using min/max values averaging over batch | Off | +| `--qe-no-clip-layers` | `--qencl` | List of layer names (space-separated) for which not to clip activations | '' | +| `--qe-per-channel` | `--qepc` | Enable per-channel quantization of weights (per output channel) | Off | + +This table summarizes the settings and results for each run. The command lines for all runs follow in the next table. + +| | Mode | # Bits Acts | # Bits Weights | Per-Channel | Clip Acts | Top-1 Accuracy | +|---|------------|-------------|----------------|-------------|-----------------------|----------------| +| 1 | FP32 | 32 | 32 | N/A | | 76.13% | +| 2 | Symmetric | 8 | 8 | No | No | 75.42% | +| 3 | Symmetric | 8 | 8 | Yes | No | 75.66% | +| 4 | Symmetric | 8 | 8 | Yes | Yes | 72.54% (See Note 1 below) | +| 5 | Symmetric | 8 | 8 | Yes | Yes (exc. last layer) | 75.94% | +| 6 | Asymmetric | 8 | 8 | No | No | 75.90% | +| 7 | Symmetric | 6 | 6 | No | No | 48.46% (See Note 2 below) | +| 8 | Asymmetric | 6 | 6 | No | No | 63.31% | +| 9 | Asymmetric | 6 | 6 | Yes | Yes (exc.
last layer) | 73.08% | + +Command lines: + +| | Command Line | +|---|--------------| +| 1 | `python compress_classifier.py -a resnet50 --pretrained ~/datasets/imagenet --evaluate` +| 2 | `python compress_classifier.py -a resnet50 --pretrained ~/datasets/imagenet --evaluate --quantize-eval` +| 3 | `python compress_classifier.py -a resnet50 --pretrained ~/datasets/imagenet --evaluate --quantize-eval --qe-per-channel` +| 4 | `python compress_classifier.py -a resnet50 --pretrained ~/datasets/imagenet --evaluate --quantize-eval --qe-per-channel --qe-clip-acts` +| 5 | `python compress_classifier.py -a resnet50 --pretrained ~/datasets/imagenet --evaluate --quantize-eval --qe-per-channel --qe-clip-acts --qe-no-clip-layers fc` +| 6 | `python compress_classifier.py -a resnet50 --pretrained ~/datasets/imagenet --evaluate --quantize-eval --qe-mode asym_u` +| 7 | `python compress_classifier.py -a resnet50 --pretrained ~/datasets/imagenet --evaluate --quantize-eval --qe-bits-acts 6 --qe-bits-wts 6` +| 8 | `python compress_classifier.py -a resnet50 --pretrained ~/datasets/imagenet --evaluate --quantize-eval --qe-bits-acts 6 --qe-bits-wts 6 --qe-mode asym_u` +| 9 | `python compress_classifier.py -a resnet50 --pretrained ~/datasets/imagenet --evaluate --quantize-eval --qe-bits-acts 6 --qe-bits-wts 6 --qe-mode asym_u --qe-per-channel --qe-clip-acts --qe-no-clip-layers fc` + +## Note 1: Accuracy Loss When Clipping Activations + +Notice the degradation in accuracy in run (4) - ~3% compared to per-channel without clipping. Let's recall that the output of the final layer of the model holds the "score" of each class (which, since we're using softmax, can be interpreted as the un-normalized log probability of each class). So if we clip the outputs of this layer, we're in fact "cutting-off" the highest (and lowest) scores. If the highest scores for some sample are close enough, this can result in a wrong classification of that sample. +We can provide Distiller with a list of layers for which not to clip activations. In this case we just want to skip the last layer, which in the case of the ResNet-50 model is called `fc`. This is what we do in run (5), and we regain most of the accuracy. + +## Note 2: Under 8-bits + +Runs (7) - (9) are examples of trying post-training quantization below 8-bits. Notice how with the most basic settings we get a massive accuracy loss of almost 28%. Even with asymmetric quantization and all other optimizations enabled, we still get a non-trivial degradation of just over 3% vs. FP32. Quantizing with less than 8-bits, in most cases, requires quantization-aware training. \ No newline at end of file diff --git a/examples/quantization/quant_aware_train_linear_quant.yaml b/examples/quantization/quant_aware_train_linear_quant.yaml new file mode 100644 index 0000000000000000000000000000000000000000..065809391704308748bd215de8dda92e04418605 --- /dev/null +++ b/examples/quantization/quant_aware_train_linear_quant.yaml @@ -0,0 +1,44 @@ +# Scheduler for training / re-training a model using quantization aware training, with a linear, range-based quantizer +# +# The setting here is 8-bit weights and activations. For vision models, this is usually applied to the entire model, +# without exceptions. Hence, this scheduler isn't model-specific as-is. It doesn't define any name-based overrides.
+# +# At the moment this quantizer will: +# * Quantize weights and biases for all convolution and FC layers +# * Quantize all ReLU activations +# +# Here's an example run for fine-tuning the ResNet-18 model from torchvision: +# +# python compress_classifier.py -a resnet18 -p 50 -b 256 ~/datasets/imagenet --epochs 10 --compress=../quantization/quant_aware_train_linear_quant.yaml --pretrained -j 22 --lr 0.0001 --vs 0 +# +# After 6 epochs we get: +# +# 2018-11-22 20:41:03,662 - --- validate (epoch=6)----------- +# 2018-11-22 20:41:03,663 - 50000 samples (256 per mini-batch) +# 2018-11-22 20:41:23,507 - Epoch: [6][ 50/ 195] Loss 0.896985 Top1 76.320312 Top5 93.460938 +# 2018-11-22 20:41:33,633 - Epoch: [6][ 100/ 195] Loss 1.026040 Top1 74.007812 Top5 91.984375 +# 2018-11-22 20:41:44,142 - Epoch: [6][ 150/ 195] Loss 1.168643 Top1 71.197917 Top5 90.041667 +# 2018-11-22 20:41:51,505 - ==> Top1: 70.188 Top5: 89.376 Loss: 1.223 +# +# This is an improvement compared to the pre-trained torchvision model: +# 2018-11-07 15:45:53,435 - ==> Top1: 69.758 Top5: 89.078 Loss: 1.251 +# +# (Note that the command line above is not using --deterministic, so results could vary a little bit) + +quantizers: + linear_quantizer: + class: QuantAwareTrainRangeLinearQuantizer + bits_activations: 8 + bits_weights: 8 + mode: 'ASYMMETRIC_UNSIGNED' # Can try "SYMMETRIC" as well + ema_decay: 0.999 # Decay value for exponential moving average tracking of activation ranges + per_channel_wts: True + +policies: + - quantizer: + instance_name: linear_quantizer + # For now putting a large range here, which should cover both training from scratch or resuming from some + # pre-trained checkpoint at some unknown epoch + starting_epoch: 0 + ending_epoch: 300 + frequency: 1 diff --git a/tests/full_flow_tests.py b/tests/full_flow_tests.py index f23c1c3241d0e7145bd64abe5eb1b64482c14a89..c78d0fc0916738212864d07f4ae8b93adf009412 100755 --- a/tests/full_flow_tests.py +++ b/tests/full_flow_tests.py @@ -118,7 +118,7 @@ test_configs = [ TestConfig('--arch simplenet_cifar --epochs 2', DS_CIFAR, accuracy_checker, [48.340, 92.630]), TestConfig('-a resnet20_cifar --resume {0} --quantize-eval --evaluate'. format(os.path.join(examples_root, 'ssl', 'checkpoints', 'checkpoint_trained_dense.pth.tar')), - DS_CIFAR, accuracy_checker, [91.620, 99.630]), + DS_CIFAR, accuracy_checker, [91.580, 99.620]), TestConfig('-a preact_resnet20_cifar --epochs 2 --compress {0}'. format(os.path.join('full_flow_tests', 'preact_resnet20_cifar_pact_test.yaml')), DS_CIFAR, accuracy_checker, [48.290, 94.460]), diff --git a/tests/test_post_train_quant.py b/tests/test_post_train_quant.py new file mode 100644 index 0000000000000000000000000000000000000000..6bae2a78c7b24bada3996a7d318b028af29ca4b2 --- /dev/null +++ b/tests/test_post_train_quant.py @@ -0,0 +1,134 @@ +# +# Copyright (c) 2018 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import os +import sys +import pytest +import torch +import torch.testing + +module_path = os.path.abspath(os.path.join('..')) +if module_path not in sys.path: + sys.path.append(module_path) +from distiller.quantization import RangeLinearQuantParamLayerWrapper, LinearQuantMode + + +@pytest.fixture() +def conv_input(): + return torch.cat((torch.tensor([[[[-7, 5], [2, -3]]]], dtype=torch.float64), + torch.tensor([[[[-15, 10], [-1, 5]]]], dtype=torch.float64)), 0) + + +@pytest.fixture() +def conv_weights(): + return torch.tensor([[[[-1, -0.5, 0], [0.5, 1, 1.5], [2, 2.5, 3]]], + [[[-0.3, -0.25, -0.2], [-0.15, -0.1, -0.05], [0, 0.05, 0.1]]]], dtype=torch.float64) + + +@pytest.mark.parametrize( + "mode, clip_acts, per_channel_wts, expected_output", + [ + (LinearQuantMode.ASYMMETRIC_UNSIGNED, False, False, + torch.cat((torch.tensor([[[[-3.648135333, -2.14596196], [0.858384784, 2.432090222]], + [[0.214596196, 0.500724457], [0.715320653, 0.786852719]]]], dtype=torch.float64), + torch.tensor([[[[12.51811144, 13.01883589], [14.0918168, 14.59254133]], + [[1.359109242, 1.645237503], [1.573705438, 1.645237503]]]], dtype=torch.float64)), + dim=0) + ), + (LinearQuantMode.ASYMMETRIC_UNSIGNED, True, False, + torch.cat((torch.tensor([[[[-1.089218234, -1.089218234], [1.055180164, 2.518817167]], + [[0.238266489, 0.476532978], [0.680761396, 0.782875606]]]], dtype=torch.float64), + torch.tensor([[[[7.59048957, 7.59048957], [7.59048957, 7.59048957]], + [[1.123256304, 1.259408583], [1.089218234, 1.089218234]]]], dtype=torch.float64)), + dim=0) + ), + (LinearQuantMode.ASYMMETRIC_UNSIGNED, False, True, + torch.cat((torch.tensor([[[[-3.648135333, -2.14596196], [0.858384784, 2.432090222]], + [[0.214596196, 0.429192392], [0.715320653, 0.858384784]]]], dtype=torch.float64), + torch.tensor([[[[12.51811144, 13.01883589], [14.09181687, 14.59254133]], + [[1.430641307, 1.502173372], [1.573705438, 1.645237503]]]], dtype=torch.float64)), + dim=0) + ), + (LinearQuantMode.ASYMMETRIC_UNSIGNED, True, True, + torch.cat((torch.tensor([[[[-1.089768056, -1.089768056], [1.055712804, 2.52008863]], + [[0.238386762, 0.408663021], [0.681105035, 0.817326042]]]], dtype=torch.float64), + torch.tensor([[[[7.59432114, 7.59432114], [7.59432114, 7.59432114]], + [[1.191933811, 1.15787856], [1.123823308, 1.089768056]]]], dtype=torch.float64)), + dim=0) + ) + ] +) +def test_conv_layer_wrapper(conv_input, conv_weights, mode, clip_acts, per_channel_wts, expected_output): + layer = torch.nn.Conv2d(conv_input.shape[1], expected_output.shape[1], conv_weights.shape[-1], + padding=1, bias=False) + layer.weight.data = conv_weights + + model = RangeLinearQuantParamLayerWrapper(layer, 8, 8, mode=mode, clip_acts=clip_acts, + per_channel_wts=per_channel_wts) + + with pytest.raises(RuntimeError): + model(conv_input) + + model.eval() + + output = model(conv_input) + + torch.testing.assert_allclose(output, expected_output) + + +@pytest.fixture() +def linear_input(): + return torch.tensor([[-7, 5, 2, -3]], dtype=torch.float64) + + +@pytest.fixture() +def linear_weights(): + return torch.tensor([[-1, 0.5, 0, 0.5], + [-0.05, 0, 0.05, 0.1], + [0.3, 0.6, -0.1, -0.2]], dtype=torch.float64) + + +@pytest.fixture() +def linear_bias(): + return torch.tensor([-0.3, 0.1, -0.5], dtype=torch.float64) + + +@pytest.mark.parametrize( + "mode, clip_acts, per_channel_wts, expected_output", + [ + (LinearQuantMode.ASYMMETRIC_UNSIGNED, False, False, + torch.tensor([[7.698556917, 0.262450804, 0.787352412]], dtype=torch.float64)), + (LinearQuantMode.ASYMMETRIC_UNSIGNED, False, True, 
+ torch.tensor([[7.71233218, 0.262920415, 0.788761246]], dtype=torch.float64)) + ] +) +def test_linear_layer_wrapper(linear_input, linear_weights, linear_bias, + mode, clip_acts, per_channel_wts, expected_output): + layer = torch.nn.Linear(linear_input.shape[1], expected_output.shape[1], bias=True) + layer.weight.data = linear_weights + layer.bias.data = linear_bias + + model = RangeLinearQuantParamLayerWrapper(layer, 8, 8, mode=mode, clip_acts=clip_acts, + per_channel_wts=per_channel_wts) + + with pytest.raises(RuntimeError): + model(linear_input) + + model.eval() + + output = model(linear_input) + + torch.testing.assert_allclose(output, expected_output) diff --git a/tests/test_quantizer.py b/tests/test_quantizer.py index 811e46c912283fb61deae0bad76a0cc60e4a6a96..8f3fa41ee41ebbb988a72589bbe7c723ed84fa2a 100644 --- a/tests/test_quantizer.py +++ b/tests/test_quantizer.py @@ -26,7 +26,7 @@ module_path = os.path.abspath(os.path.join('..')) if module_path not in sys.path: sys.path.append(module_path) from distiller.quantization import Quantizer -from distiller.quantization.quantizer import QBits +from distiller.quantization.quantizer import QBits, _ParamToQuant from distiller.quantization.quantizer import FP_BKP_PREFIX from distiller import has_children @@ -90,8 +90,8 @@ class DummyModel(nn.Sequential): # Dummy Quantizer ############################# -def dummy_quantize_params(param, num_bits): - return param + num_bits +def dummy_quantize_params(param, param_meta): + return param + param_meta.num_bits class DummyQuantizer(Quantizer): @@ -341,5 +341,6 @@ def test_param_quantization(model, optimizer, qbits, bits_overrides, explicit_ex else: quant_param = getattr(post_quant_module, param_name) - expected = dummy_quantize_params(pre_quant_param, num_bits) if quantizable else pre_quant_param + expected = dummy_quantize_params(pre_quant_param, + _ParamToQuant(None, None, None, None, num_bits)) if quantizable else pre_quant_param assert torch.equal(quant_param, expected)
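As a closing reference, the new unit tests above also show how the layer-level wrapper can be used directly, outside the sample application. Below is a condensed sketch based on those tests; the layer shape and the random input are illustrative only (the tests use fixed fixture values so that the outputs can be compared exactly).

```python
# Condensed from tests/test_post_train_quant.py: post-training quantization of a single layer
# via RangeLinearQuantParamLayerWrapper. Layer sizes and the random input are illustrative.
import torch
from distiller.quantization import RangeLinearQuantParamLayerWrapper, LinearQuantMode

layer = torch.nn.Conv2d(1, 2, kernel_size=3, padding=1, bias=False)
wrapper = RangeLinearQuantParamLayerWrapper(layer, 8, 8,
                                            mode=LinearQuantMode.ASYMMETRIC_UNSIGNED,
                                            clip_acts=False,
                                            per_channel_wts=True)

wrapper.eval()                 # eval mode is required; in training mode the wrapper raises RuntimeError
with torch.no_grad():
    out = wrapper(torch.randn(2, 1, 4, 4))   # quantized conv output, de-quantized back to float
```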