From 6b166cec59c9eb00882d485f0fbbd6376d2692dd Mon Sep 17 00:00:00 2001 From: Guy Jacob <guy.jacob@intel.com> Date: Tue, 17 Jul 2018 14:45:53 +0300 Subject: [PATCH] Quantizer tests, fixes and docs update * Add Quantizer unit tests * Require 'bits_overrides' to be OrderedDict to support overlapping patterns in a predictable manner + update documentation to reflect this * Quantizer class cleanup * Use "public" nn.Module APIs instead of protected attributes * Call the builtins set/get/delattr instead of the class special methods (__***__) * Fix issues reported in #24 * Bug in RangeLinearQuantParamLayerWrapper - add explicit override of forward accepting a single input (#15) * Add DoReFa test to full_flow_tests --- distiller/quantization/clipped_linear.py | 21 +- distiller/quantization/quantizer.py | 54 +-- distiller/quantization/range_linear.py | 3 + distiller/utils.py | 14 +- docs-src/docs/design.md | 3 + docs-src/docs/schedule.md | 26 +- docs/design/index.html | 3 + docs/index.html | 2 +- docs/schedule/index.html | 23 +- docs/search/search_index.json | 13 +- docs/sitemap.xml | 24 +- tests/full_flow_tests.py | 3 + .../preact_resnet20_cifar_dorefa_test.yaml | 38 ++ tests/test_quantizer.py | 339 ++++++++++++++++++ 14 files changed, 516 insertions(+), 50 deletions(-) create mode 100644 tests/full_flow_tests/preact_resnet20_cifar_dorefa_test.yaml create mode 100644 tests/test_quantizer.py diff --git a/distiller/quantization/clipped_linear.py b/distiller/quantization/clipped_linear.py index 9e218ef..5854ab8 100644 --- a/distiller/quantization/clipped_linear.py +++ b/distiller/quantization/clipped_linear.py @@ -1,3 +1,20 @@ +# +# Copyright (c) 2018 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import OrderedDict import torch.nn as nn from .quantizer import Quantizer @@ -55,7 +72,7 @@ class WRPNQuantizer(Quantizer): 1. This class does not take care of layer widening as described in the paper 2. The paper defines special handling for 1-bit weights which isn't supported here yet """ - def __init__(self, model, bits_activations=32, bits_weights=32, bits_overrides={}, quantize_bias=False): + def __init__(self, model, bits_activations=32, bits_weights=32, bits_overrides=OrderedDict(), quantize_bias=False): super(WRPNQuantizer, self).__init__(model, bits_activations=bits_activations, bits_weights=bits_weights, bits_overrides=bits_overrides, train_with_fp_copy=True, quantize_bias=quantize_bias) @@ -87,7 +104,7 @@ class DorefaQuantizer(Quantizer): 1. Gradients quantization not supported yet 2.
The paper defines special handling for 1-bit weights which isn't supported here yet """ - def __init__(self, model, bits_activations=32, bits_weights=32, bits_overrides={}, quantize_bias=False): + def __init__(self, model, bits_activations=32, bits_weights=32, bits_overrides=OrderedDict(), quantize_bias=False): super(DorefaQuantizer, self).__init__(model, bits_activations=bits_activations, bits_weights=bits_weights, bits_overrides=bits_overrides, train_with_fp_copy=True, quantize_bias=quantize_bias) diff --git a/distiller/quantization/quantizer.py b/distiller/quantization/quantizer.py index 854845b..77ef217 100644 --- a/distiller/quantization/quantizer.py +++ b/distiller/quantization/quantizer.py @@ -14,12 +14,13 @@ # limitations under the License. # -from collections import namedtuple +from collections import namedtuple, OrderedDict import re import copy import logging import torch import torch.nn as nn +import distiller msglogger = logging.getLogger() @@ -38,13 +39,14 @@ def hack_float_backup_parameter(module, name): except KeyError: raise ValueError('Module has no Parameter named ' + name) module.register_parameter(FP_BKP_PREFIX + name, nn.Parameter(data)) - module.__delattr__(name) + delattr(module, name) module.register_buffer(name, torch.zeros_like(data)) class _ParamToQuant(object): - def __init__(self, module, fp_attr_name, q_attr_name, num_bits): + def __init__(self, module, module_name, fp_attr_name, q_attr_name, num_bits): self.module = module + self.module_name = module_name self.fp_attr_name = fp_attr_name self.q_attr_name = q_attr_name self.num_bits = num_bits @@ -58,8 +60,13 @@ class Quantizer(object): model (torch.nn.Module): The model to be quantized bits_activations/weights (int): Default number of bits to use when quantizing each tensor type. Value of None means do not quantize. - bits_overrides (dict): Dictionary mapping regular expressions of layer name patterns to dictionary with + bits_overrides (OrderedDict): Dictionary mapping regular expressions of layer name patterns to dictionary with values for 'acts' and/or 'wts' to override the default values. + OrderedDict is used to enable handling of overlapping name patterns. So, for example, one could define + certain override parameters for a group of layers, e.g. 'conv*', but also define different parameters for + specific layers in that group, e.g. 'conv1'. + The patterns are evaluated eagerly - the first match wins. Therefore, the more specific patterns must + come before the broad patterns. quantize_bias (bool): Flag indicating whether to quantize bias (w. same number of bits as weights) or not. train_with_fp_copy (bool): If true, will modify layers with weights to keep both a quantized and floating-point copy, such that the following flow occurs in each training iteration: @@ -70,8 +77,11 @@ class Quantizer(object): 3.2 We also back-prop through the 'quantize' operation from step 1 4. 
Update fp_weights with gradients calculated in step 3.2 """ - def __init__(self, model, bits_activations=None, bits_weights=None, bits_overrides={}, quantize_bias=False, - train_with_fp_copy=False): + def __init__(self, model, bits_activations=None, bits_weights=None, bits_overrides=OrderedDict(), + quantize_bias=False, train_with_fp_copy=False): + if not isinstance(bits_overrides, OrderedDict): + raise TypeError('bits_overrides must be an instance of collections.OrderedDict') + self.default_qbits = QBits(acts=bits_activations, wts=bits_weights) self.quantize_bias = quantize_bias @@ -89,13 +99,11 @@ class Quantizer(object): bits_overrides[k] = qbits # Prepare explicit mapping from each layer to QBits based on default + overrides + patterns = [] regex = None if bits_overrides: - regex_str = '' - keys_list = list(bits_overrides.keys()) - for pattern in keys_list: - regex_str += '(^{0}$)|'.format(pattern) - regex_str = regex_str[:-1] # Remove trailing '|' + patterns = list(bits_overrides.keys()) + regex_str = '|'.join(['(^{0}$)'.format(pattern) for pattern in patterns]) regex = re.compile(regex_str) self.module_qbits_map = {} @@ -111,7 +119,7 @@ class Quantizer(object): groups = m.groups() while groups[group_idx] is None: group_idx += 1 - qbits = bits_overrides[keys_list[group_idx]] + qbits = bits_overrides[patterns[group_idx]] self._add_qbits_entry(module_full_name, type(module), qbits) # Mapping from module type to function generating a replacement module suited for quantization @@ -152,7 +160,7 @@ class Quantizer(object): if self.train_with_fp_copy: hack_float_backup_parameter(module, param_name) fp_attr_name = FP_BKP_PREFIX + param_name - self.params_to_quantize.append(_ParamToQuant(module, fp_attr_name, param_name, qbits.wts)) + self.params_to_quantize.append(_ParamToQuant(module, module_name, fp_attr_name, param_name, qbits.wts)) param_full_name = '.'.join([module_name, param_name]) msglogger.info( @@ -164,21 +172,23 @@ class Quantizer(object): # Iterate through model, insert quantization functions as appropriate for name, module in container.named_children(): full_name = prefix + name + current_qbits = self.module_qbits_map[full_name] + if current_qbits.acts is None and current_qbits.wts is None: + continue try: new_module = self.replacement_factory[type(module)](module, full_name, self.module_qbits_map) msglogger.debug('Module {0}: Replacing \n{1} with \n{2}'.format(full_name, module, new_module)) - container._modules[name] = new_module + setattr(container, name, new_module) # If a "leaf" module was replaced by a container, add the new layers to the QBits mapping - if len(module._modules) == 0 and len(new_module._modules) > 0: - current_qbits = self.module_qbits_map[full_name] - for sub_module_name, module in new_module.named_modules(): - self._add_qbits_entry(full_name + '.' + sub_module_name, type(module), current_qbits) + if not distiller.has_children(module) and distiller.has_children(new_module): + for sub_module_name, sub_module in new_module.named_modules(): + self._add_qbits_entry(full_name + '.' 
+ sub_module_name, type(sub_module), current_qbits) self.module_qbits_map[full_name] = QBits(acts=current_qbits.acts, wts=None) except KeyError: pass - if len(module._modules) > 0: + if distiller.has_children(module): # For container we call recursively self._pre_process_container(module, full_name + '.') @@ -188,8 +198,8 @@ class Quantizer(object): of bits for each parameter) """ for ptq in self.params_to_quantize: - q_param = self.param_quantization_fn(ptq.module.__getattr__(ptq.fp_attr_name), ptq.num_bits) + q_param = self.param_quantization_fn(getattr(ptq.module, ptq.fp_attr_name), ptq.num_bits) if self.train_with_fp_copy: - ptq.module.__setattr__(ptq.q_attr_name, q_param) + setattr(ptq.module, ptq.q_attr_name, q_param) else: - ptq.module.__getattr__(ptq.q_attr_name).data = q_param.data + getattr(ptq.module, ptq.q_attr_name).data = q_param.data diff --git a/distiller/quantization/range_linear.py b/distiller/quantization/range_linear.py index bca1cb9..e8187df 100644 --- a/distiller/quantization/range_linear.py +++ b/distiller/quantization/range_linear.py @@ -125,6 +125,9 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper): self.current_accum_scale = 1 + def forward(self, input): + return super(RangeLinearQuantParamLayerWrapper, self).forward(input) + def pre_quantized_forward(self, input): in_scale = symmetric_linear_quantization_scale_factor(self.num_bits_acts, get_tensor_max_abs(input)) self.current_accum_scale = in_scale * self.w_scale diff --git a/distiller/utils.py b/distiller/utils.py index 588b1f5..073357d 100755 --- a/distiller/utils.py +++ b/distiller/utils.py @@ -266,6 +266,14 @@ def log_weights_sparsity(model, epoch, loggers): logger.log_weights_sparsity(model, epoch) +def has_children(module): + try: + next(module.children()) + return True + except StopIteration: + return False + + class DoNothingModuleWrapper(nn.Module): """Implement a nn.Module which wraps another nn.Module. @@ -294,15 +302,15 @@ def make_non_parallel_copy(model): full_name = prefix + name if isinstance(module, nn.DataParallel): # msglogger.debug('Replacing module {}'.format(full_name)) - container._modules[name] = DoNothingModuleWrapper(module.module) - if len(module._modules) > 0: + setattr(container, name, DoNothingModuleWrapper(module.module)) + if has_children(module): # For a container we call recursively replace_data_parallel(module, full_name + '.') # Make a copy of the model, because we're going to change it new_model = deepcopy(model) if isinstance(new_model, nn.DataParallel): - #new_model = new_model.module # + # new_model = new_model.module # new_model = DoNothingModuleWrapper(new_model.module) replace_data_parallel(new_model) diff --git a/docs-src/docs/design.md b/docs-src/docs/design.md index 7ff309e..62f66d2 100755 --- a/docs-src/docs/design.md +++ b/docs-src/docs/design.md @@ -66,6 +66,9 @@ To execute the model transformation, call the `prepare_model` function of the `Q - Each instance of `Quantizer` is parameterized by the number of bits to be used for quantization of different tensor types. The default ones are activations and weights. These are the `bits_activations` and `bits_weights` parameters in `Quantizer`'s constructor. Sub-classes may define bit-widths for other tensor types as needed. - We also want to be able to override the default number of bits mentioned in the bullet above for certain layers. These could be very specific layers.
However, many models are comprised of building blocks ("container" modules, such as Sequential) which contain several modules, and it is likely we'll want to override settings for entire blocks, or for a certain module across different blocks. When such building blocks are used, the names of the internal modules usually follow some pattern. - So, for this purpose, Quantizer also accepts a mapping of regular expressions to number of bits. This allows the user to override specific layers using they're exact name, or a group of layers via a regular expression. This mapping is passed via the `bits_overrides` parameter in the constructor. + - The `bits_overrides` mapping is required to be an instance of [`collections.OrderedDict`](https://docs.python.org/3.5/library/collections.html#collections.OrderedDict) (as opposed to just a simple Python [`dict`](https://docs.python.org/3.5/library/stdtypes.html#dict)). This is done in order to enable handling of overlapping name patterns. + So, for example, one could define certain override parameters for a group of layers, e.g. 'conv*', but also define different parameters for specific layers in that group, e.g. 'conv1'. + The patterns are evaluated eagerly - the first match wins. Therefore, the more specific patterns must come before the broad patterns. ### Weights Quantization diff --git a/docs-src/docs/schedule.md b/docs-src/docs/schedule.md index 1f4c184..54d5704 100755 --- a/docs-src/docs/schedule.md +++ b/docs-src/docs/schedule.md @@ -267,17 +267,37 @@ quantizers: - The specific quantization method we're instantiating here is `DorefaQuantizer`. - Then we define the default bit-widths for activations and weights, in this case 8 and 4-bits, respectively. -- Then, we define the `bits_overrides` mapping. In this case, we choose not to quantize the first and last layer of the model. In the case of `DorefaQuantizer`, the weights are quantized as part of the convolution / FC layers, but the activations are quantized in separate layers, which replace the ReLU layers in the original model (remember - even though we replaced the ReLU modules with our own quantization modules, the name of the modules isn't changed). So, in all, we need to reference the first layer with parameters `conv1`, the first activation layer `relu1`, the last activation layer `final_relu` and the last layer with parameters `fc`. +- Then, we define the `bits_overrides` mapping. In the example above, we choose not to quantize the first and last layer of the model. In the case of `DorefaQuantizer`, the weights are quantized as part of the convolution / FC layers, but the activations are quantized in separate layers, which replace the ReLU layers in the original model (remember - even though we replaced the ReLU modules with our own quantization modules, the name of the modules isn't changed). So, in all, we need to reference the first layer with parameters `conv1`, the first activation layer `relu1`, the last activation layer `final_relu` and the last layer with parameters `fc`. - Specifying `null` means "do not quantize". - Note that for quantizers, we reference names of modules, not names of parameters as we do for pruners and regularizers. -- We can also reference **groups of layers** in the `bits_overrides` mapping. This is done using regular expressions. Suppose we have a sub-module in our model named `block1`, which contains multiple convolution layers which we would like to quantize to, say, 2-bits. The convolution layers are named `conv1`, `conv2` and so on. 
In that case we would define the following: + +``` +bits_overrides: + 'block1\.conv*': + wts: 2 + acts: null +``` + +- **RegEx Note**: Remember that the dot (`.`) is a meta-character (i.e. a reserved character) in regular expressions. So, to match the actual dot characters which separate sub-modules in PyTorch module names, we need to escape it: `\.` + +**Overlapping patterns** are also possible, which makes it possible to define an override for a group of layers and also to "single out" specific layers in that group for different overrides. For example, let's take the previous example and configure a different override for `block1.conv1`: ``` bits_overrides: - block1.conv*: + 'block1\.conv1': + wts: 4 + acts: null + 'block1\.conv*': wts: 2 acts: null ``` + +- **Important Note**: The patterns are evaluated eagerly - first match wins. So, to properly quantize a model using "broad" patterns and more "specific" patterns as just shown, make sure the specific pattern is listed **before** the broad one. + The `QuantizationPolicy`, which controls the quantization procedure during training, is actually quite simplistic. All it does is call the `prepare_model()` function of the `Quantizer` when it's initialized, followed by the first call to `quantize_params()`. Then, at the end of each epoch, after the float copy of the weights has been updated, it calls the `quantize_params()` function again. diff --git a/docs/design/index.html b/docs/design/index.html index 991b270..9d256ff 100644 --- a/docs/design/index.html +++ b/docs/design/index.html @@ -224,6 +224,9 @@ To execute the model transformation, call the <code>prepare_model</code> functio <li>Each instance of <code>Quantizer</code> is parameterized by the number of bits to be used for quantization of different tensor types. The default ones are activations and weights. These are the <code>bits_activations</code> and <code>bits_weights</code> parameters in <code>Quantizer</code>'s constructor. Sub-classes may define bit-widths for other tensor types as needed.</li> <li>We also want to be able to override the default number of bits mentioned in the bullet above for certain layers. These could be very specific layers. However, many models are comprised of building blocks ("container" modules, such as Sequential) which contain several modules, and it is likely we'll want to override settings for entire blocks, or for a certain module across different blocks. When such building blocks are used, the names of the internal modules usually follow some pattern.</li> <li>So, for this purpose, Quantizer also accepts a mapping of regular expressions to number of bits. This allows the user to override specific layers using they're exact name, or a group of layers via a regular expression. This mapping is passed via the <code>bits_overrides</code> parameter in the constructor.</li> +<li>The <code>bits_overrides</code> mapping is required to be an instance of <a href="https://docs.python.org/3.5/library/collections.html#collections.OrderedDict"><code>collections.OrderedDict</code></a> (as opposed to just a simple Python <a href="https://docs.python.org/3.5/library/stdtypes.html#dict"><code>dict</code></a>).
This is done in order to enable handling of overlapping name patterns.<br /> + So, for example, one could define certain override parameters for a group of layers, e.g. 'conv*', but also define different parameters for specific layers in that group, e.g. 'conv1'.<br /> + The patterns are evaluated eagerly - the first match wins. Therefore, the more specific patterns must come before the broad patterns.</li> </ul> <h3 id="weights-quantization">Weights Quantization</h3> <p>The <code>Quantizer</code> class also provides an API to quantize the weights of all layers at once. To use it, the <code>param_quantization_fn</code> attribute needs to point to a function that accepts a tensor and the number of bits. During model transformation, the <code>Quantizer</code> class will build a list of all model parameters that need to be quantized along with their bit-width. Then, the <code>quantize_params</code> function can be called, which will iterate over all parameters and quantize them using <code>params_quantization_fn</code>.</p> diff --git a/docs/index.html b/docs/index.html index 5c5cd23..2170e16 100644 --- a/docs/index.html +++ b/docs/index.html @@ -246,5 +246,5 @@ And of course, if we used a sparse or compressed representation, then we are red <!-- MkDocs version : 0.17.2 -Build Date UTC : 2018-07-01 07:53:34 +Build Date UTC : 2018-07-17 10:59:29 --> diff --git a/docs/schedule/index.html b/docs/schedule/index.html index ee6e523..4f15baa 100644 --- a/docs/schedule/index.html +++ b/docs/schedule/index.html @@ -427,17 +427,34 @@ Let's see an example:</p> <ul> <li>The specific quantization method we're instantiating here is <code>DorefaQuantizer</code>.</li> <li>Then we define the default bit-widths for activations and weights, in this case 8 and 4-bits, respectively. </li> -<li>Then, we define the <code>bits_overrides</code> mapping. In this case, we choose not to quantize the first and last layer of the model. In the case of <code>DorefaQuantizer</code>, the weights are quantized as part of the convolution / FC layers, but the activations are quantized in separate layers, which replace the ReLU layers in the original model (remember - even though we replaced the ReLU modules with our own quantization modules, the name of the modules isn't changed). So, in all, we need to reference the first layer with parameters <code>conv1</code>, the first activation layer <code>relu1</code>, the last activation layer <code>final_relu</code> and the last layer with parameters <code>fc</code>.</li> +<li>Then, we define the <code>bits_overrides</code> mapping. In the example above, we choose not to quantize the first and last layer of the model. In the case of <code>DorefaQuantizer</code>, the weights are quantized as part of the convolution / FC layers, but the activations are quantized in separate layers, which replace the ReLU layers in the original model (remember - even though we replaced the ReLU modules with our own quantization modules, the name of the modules isn't changed). So, in all, we need to reference the first layer with parameters <code>conv1</code>, the first activation layer <code>relu1</code>, the last activation layer <code>final_relu</code> and the last layer with parameters <code>fc</code>.</li> <li>Specifying <code>null</code> means "do not quantize".</li> <li>Note that for quantizers, we reference names of modules, not names of parameters as we do for pruners and regularizers.</li> -<li>We can also reference <strong>groups of layers</strong> in the <code>bits_overrides</code> mapping. 
This is done using regular expressions. Suppose we have a sub-module in our model named <code>block1</code>, which contains multiple convolution layers which we would like to quantize to, say, 2-bits. The convolution layers are named <code>conv1</code>, <code>conv2</code> and so on. In that case we would define the following:</li> </ul> +<h3 id="defining-overrides-for-groups-of-layers-using-regular-expressions">Defining overrides for <strong>groups of layers</strong> using regular expressions</h3> +<p>Suppose we have a sub-module in our model named <code>block1</code>, which contains multiple convolution layers which we would like to quantize to, say, 2-bits. The convolution layers are named <code>conv1</code>, <code>conv2</code> and so on. In that case we would define the following:</p> <pre><code>bits_overrides: - block1.conv*: + 'block1\.conv*': wts: 2 acts: null </code></pre> +<ul> +<li><strong>RegEx Note</strong>: Remember that the dot (<code>.</code>) is a meta-character (i.e. a reserved character) in regular expressions. So, to match the actual dot characters which separate sub-modules in PyTorch module names, we need to escape it: <code>\.</code></li> +</ul> +<p><strong>Overlapping patterns</strong> are also possible, which allows to define some override for a groups of layers and also "single-out" specific layers for different overrides. For example, let's take the last example and configure a different override for <code>block1.conv1</code>:</p> +<pre><code>bits_overrides: + 'block1\.conv1': + wts: 4 + acts: null + 'block1\.conv*': + wts: 2 + acts: null +</code></pre> + +<ul> +<li><strong>Important Note</strong>: The patterns are evaluated eagerly - first match wins. So, to properly quantize a model using "broad" patterns and more "specific" patterns as just shown, make sure the specific pattern is listed <strong>before</strong> the broad one.</li> +</ul> <p>The <code>QuantizationPolicy</code>, which controls the quantization procedure during training, is actually quite simplistic. All it does is call the <code>prepare_model()</code> function of the <code>Quantizer</code> when it's initialized, followed by the first call to <code>quantize_params()</code>. Then, at the end of each epoch, after the float copy of the weights has been updated, it calls the <code>quantize_params()</code> function again. </p> <pre><code>policies: - quantizer: diff --git a/docs/search/search_index.json b/docs/search/search_index.json index e6dba17..f242561 100644 --- a/docs/search/search_index.json +++ b/docs/search/search_index.json @@ -137,7 +137,7 @@ }, { "location": "/schedule/index.html", - "text": "Compression scheduler\n\n\nIn iterative pruning, we create some kind of pruning regimen that specifies how to prune, and what to prune at every stage of the pruning and training stages. This motivated the design of \nCompressionScheduler\n: it needed to be part of the training loop, and to be able to make and implement pruning, regularization and quantization decisions. We wanted to be able to change the particulars of the compression schedule, w/o touching the code, and settled on using YAML as a container for this specification. We found that when we make many experiments on the same code base, it is easier to maintain all of these experiments if we decouple the differences from the code-base. Therefore, we added to the scheduler support for learning-rate decay scheduling because, again, we wanted the freedom to change the LR-decay policy without changing code. 
\n\n\nHigh level overview\n\n\nLet's briefly discuss the main mechanisms and abstractions: A schedule specification is composed of a list of sections defining instances of Pruners, Regularizers, Quantizers, LR-scheduler and Policies.\n\n\n\n\nPruners, Regularizers and Quantizers are very similar: They implement either a Pruning/Regularization/Quantization algorithm, respectively. \n\n\nAn LR-scheduler specifies the LR-decay algorithm. \n\n\n\n\nThese define the \nwhat\n part of the schedule. \n\n\nThe Policies define the \nwhen\n part of the schedule: at which epoch to start applying the Pruner/Regularizer/Quantizer/LR-decay, the epoch to end, and how often to invoke the policy (frequency of application). A policy also defines the instance of Pruner/Regularizer/Quantizer/LR-decay it is managing.\n\nThe CompressionScheduler is configured from a YAML file or from a dictionary, but you can also manually create Policies, Pruners, Regularizers and Quantizers from code.\n\n\nSyntax through example\n\n\nWe'll use \nalexnet.schedule_agp.yaml\n to explain some of the YAML syntax for configuring Sensitivity Pruning of Alexnet.\n\n\nversion: 1\npruners:\n my_pruner:\n class: 'SensitivityPruner'\n sensitivities:\n 'features.module.0.weight': 0.25\n 'features.module.3.weight': 0.35\n 'features.module.6.weight': 0.40\n 'features.module.8.weight': 0.45\n 'features.module.10.weight': 0.55\n 'classifier.1.weight': 0.875\n 'classifier.4.weight': 0.875\n 'classifier.6.weight': 0.625\n\nlr_schedulers:\n pruning_lr:\n class: ExponentialLR\n gamma: 0.9\n\npolicies:\n - pruner:\n instance_name : 'my_pruner'\n starting_epoch: 0\n ending_epoch: 38\n frequency: 2\n\n - lr_scheduler:\n instance_name: pruning_lr\n starting_epoch: 24\n ending_epoch: 200\n frequency: 1\n\n\n\n\nThere is only one version of the YAML syntax, and the version number is not verified at the moment. However, to be future-proof it is probably better to let the YAML parser know that you are using version-1 syntax, in case there is ever a version 2.\n\n\nversion: 1\n\n\n\n\nIn the \npruners\n section, we define the instances of pruners we want the scheduler to instantiate and use.\n\nWe define a single pruner instance, named \nmy_pruner\n, of algorithm \nSensitivityPruner\n. We will refer to this instance in the \nPolicies\n section.\n\nThen we list the sensitivity multipliers, \\(s\\), of each of the weight tensors.\n\nYou may list as many Pruners as you want in this section, as long as each has a unique name. You can several types of pruners in one schedule.\n\n\npruners:\n my_pruner:\n class: 'SensitivityPruner'\n sensitivities:\n 'features.module.0.weight': 0.25\n 'features.module.3.weight': 0.35\n 'features.module.6.weight': 0.40\n 'features.module.8.weight': 0.45\n 'features.module.10.weight': 0.55\n 'classifier.1.weight': 0.875\n 'classifier.4.weight': 0.875\n 'classifier.6.weight': 0.6\n\n\n\n\nNext, we want to specify the learning-rate decay scheduling in the \nlr_schedulers\n section. We assign a name to this instance: \npruning_lr\n. As in the \npruners\n section, you may use any name, as long as all LR-schedulers have a unique name. At the moment, only one instance of LR-scheduler is allowed. The LR-scheduler must be a subclass of PyTorch's \n_LRScheduler\n. You can use any of the schedulers defined in \ntorch.optim.lr_scheduler\n (see \nhere\n). In addition, we've implemented some additional schedulers in Distiller (see \nhere\n). 
The keyword arguments (kwargs) are passed directly to the LR-scheduler's constructor, so that as new LR-schedulers are added to \ntorch.optim.lr_scheduler\n, they can be used without changing the application code.\n\n\nlr_schedulers:\n pruning_lr:\n class: ExponentialLR\n gamma: 0.9\n\n\n\n\nFinally, we define the \npolicies\n section which defines the actual scheduling. A \nPolicy\n manages an instance of a \nPruner\n, \nRegularizer\n, \nQuantizer\n, or \nLRScheduler\n, by naming the instance. In the example below, a \nPruningPolicy\n uses the pruner instance named \nmy_pruner\n: it activates it at a frequency of 2 epochs (i.e. every other epoch), starting at epoch 0, and ending at epoch 38. \n\n\npolicies:\n - pruner:\n instance_name : 'my_pruner'\n starting_epoch: 0\n ending_epoch: 38\n frequency: 2\n\n - lr_scheduler:\n instance_name: pruning_lr\n starting_epoch: 24\n ending_epoch: 200\n frequency: 1\n\n\n\n\nThis is \niterative pruning\n:\n\n\n\n\n\n\nTrain Connectivity\n\n\n\n\n\n\nPrune Connections\n\n\n\n\n\n\nRetrain Weights\n\n\n\n\n\n\nGoto 2\n\n\n\n\n\n\nIt is described in \nLearning both Weights and Connections for Efficient Neural Networks\n:\n\n\n\n\n\"Our method prunes redundant connections using a three-step method. First, we train the network to learn which connections are important. Next, we prune the unimportant connections. Finally, we retrain the network to fine tune the weights of the remaining connections...After an initial training phase, we remove all connections whose weight is lower than a threshold. This pruning converts a dense, fully-connected layer to a sparse layer. This first phase learns the topology of the networks \u2014 learning which connections are important and removing the unimportant connections. We then retrain the sparse network so the remaining connections can compensate for the connections that have been removed. 
The phases of pruning and retraining may be repeated iteratively to further reduce network complexity.\"\n\n\n\n\nRegularization\n\n\nYou can also define and schedule regularization.\n\n\nL1 regularization\n\n\nFormat (this is an informal specification, not a valid \nABNF\n specification):\n\n\nregularizers:\n \nREGULARIZER_NAME_STR\n:\n class: L1Regularizer\n reg_regims:\n \nPYTORCH_PARAM_NAME_STR\n: \nSTRENGTH_FLOAT\n\n ...\n \nPYTORCH_PARAM_NAME_STR\n: \nSTRENGTH_FLOAT\n\n threshold_criteria: [Mean_Abs | Max]\n\n\n\n\nFor example:\n\n\nversion: 1\n\nregularizers:\n my_L1_reg:\n class: L1Regularizer\n reg_regims:\n 'module.layer3.1.conv1.weight': 0.000002\n 'module.layer3.1.conv2.weight': 0.000002\n 'module.layer3.1.conv3.weight': 0.000002\n 'module.layer3.2.conv1.weight': 0.000002\n threshold_criteria: Mean_Abs\n\npolicies:\n - regularizer:\n instance_name: my_L1_reg\n starting_epoch: 0\n ending_epoch: 60\n frequency: 1\n\n\n\n\nGroup regularization\n\n\nFormat (informal specification):\n\n\nFormat:\n regularizers:\n \nREGULARIZER_NAME_STR\n:\n class: L1Regularizer\n reg_regims:\n \nPYTORCH_PARAM_NAME_STR\n: [\nSTRENGTH_FLOAT\n, \n'2D' | '3D' | '4D' | 'Channels' | 'Cols' | 'Rows'\n]\n \nPYTORCH_PARAM_NAME_STR\n: [\nSTRENGTH_FLOAT\n, \n'2D' | '3D' | '4D' | 'Channels' | 'Cols' | 'Rows'\n]\n threshold_criteria: [Mean_Abs | Max]\n\n\n\n\nFor example:\n\n\nversion: 1\n\nregularizers:\n my_filter_regularizer:\n class: GroupLassoRegularizer\n reg_regims:\n 'module.layer3.1.conv1.weight': [0.00005, '3D']\n 'module.layer3.1.conv2.weight': [0.00005, '3D']\n 'module.layer3.1.conv3.weight': [0.00005, '3D']\n 'module.layer3.2.conv1.weight': [0.00005, '3D']\n threshold_criteria: Mean_Abs\n\npolicies:\n - regularizer:\n instance_name: my_filter_regularizer\n starting_epoch: 0\n ending_epoch: 60\n frequency: 1\n\n\n\n\nMixing it up\n\n\nYou can mix pruning and regularization.\n\n\nversion: 1\npruners:\n my_pruner:\n class: 'SensitivityPruner'\n sensitivities:\n 'features.module.0.weight': 0.25\n 'features.module.3.weight': 0.35\n 'features.module.6.weight': 0.40\n 'features.module.8.weight': 0.45\n 'features.module.10.weight': 0.55\n 'classifier.1.weight': 0.875\n 'classifier.4.weight': 0.875\n 'classifier.6.weight': 0.625\n\nregularizers:\n 2d_groups_regularizer:\n class: GroupLassoRegularizer\n reg_regims:\n 'features.module.0.weight': [0.000012, '2D']\n 'features.module.3.weight': [0.000012, '2D']\n 'features.module.6.weight': [0.000012, '2D']\n 'features.module.8.weight': [0.000012, '2D']\n 'features.module.10.weight': [0.000012, '2D']\n\n\nlr_schedulers:\n # Learning rate decay scheduler\n pruning_lr:\n class: ExponentialLR\n gamma: 0.9\n\npolicies:\n - pruner:\n instance_name : 'my_pruner'\n starting_epoch: 0\n ending_epoch: 38\n frequency: 2\n\n - regularizer:\n instance_name: '2d_groups_regularizer'\n starting_epoch: 0\n ending_epoch: 38\n frequency: 1\n\n - lr_scheduler:\n instance_name: pruning_lr\n starting_epoch: 24\n ending_epoch: 200\n frequency: 1\n\n\n\n\n\nQuantization\n\n\nSimilarly to pruners and regularizers, specifying a quantizer in the scheduler YAML follows the constructor arguments of the \nQuantizer\n class (see details \nhere\n).\n\n\nNotes\n: Only a single quantizer instance may be defined.\n\nLet's see an example:\n\n\nquantizers:\n dorefa_quantizer:\n class: DorefaQuantizer\n bits_activations: 8\n bits_weights: 4\n bits_overrides:\n conv1:\n wts: null\n acts: null\n relu1:\n wts: null\n acts: null\n final_relu:\n wts: null\n acts: null\n fc:\n wts: null\n acts: 
null\n\n\n\n\n\n\nThe specific quantization method we're instantiating here is \nDorefaQuantizer\n.\n\n\nThen we define the default bit-widths for activations and weights, in this case 8 and 4-bits, respectively. \n\n\nThen, we define the \nbits_overrides\n mapping. In this case, we choose not to quantize the first and last layer of the model. In the case of \nDorefaQuantizer\n, the weights are quantized as part of the convolution / FC layers, but the activations are quantized in separate layers, which replace the ReLU layers in the original model (remember - even though we replaced the ReLU modules with our own quantization modules, the name of the modules isn't changed). So, in all, we need to reference the first layer with parameters \nconv1\n, the first activation layer \nrelu1\n, the last activation layer \nfinal_relu\n and the last layer with parameters \nfc\n.\n\n\nSpecifying \nnull\n means \"do not quantize\".\n\n\nNote that for quantizers, we reference names of modules, not names of parameters as we do for pruners and regularizers.\n\n\nWe can also reference \ngroups of layers\n in the \nbits_overrides\n mapping. This is done using regular expressions. Suppose we have a sub-module in our model named \nblock1\n, which contains multiple convolution layers which we would like to quantize to, say, 2-bits. The convolution layers are named \nconv1\n, \nconv2\n and so on. In that case we would define the following:\n\n\n\n\nbits_overrides:\n block1.conv*:\n wts: 2\n acts: null\n\n\n\n\nThe \nQuantizationPolicy\n, which controls the quantization procedure during training, is actually quite simplistic. All it does is call the \nprepare_model()\n function of the \nQuantizer\n when it's initialized, followed by the first call to \nquantize_params()\n. Then, at the end of each epoch, after the float copy of the weights has been updated, it calls the \nquantize_params()\n function again. \n\n\npolicies:\n - quantizer:\n instance_name: dorefa_quantizer\n starting_epoch: 0\n ending_epoch: 200\n frequency: 1\n\n\n\n\nImportant Note\n: As mentioned \nhere\n, since the quantizer modifies the model's parameters (assuming training with quantization in the loop is used), the call to \nprepare_model()\n must be performed before an optimizer is called. Therefore, currently, the starting epoch for a quantization policy must be 0, otherwise the quantization process will not work as expected. If one wishes to do a \"warm-startup\" (or \"boot-strapping\"), training for a few epochs with full precision and only then starting to quantize, the only way to do this right now is to execute a separate run to generate the boot-strapped weights, and execute a second which will resume the checkpoint with the boot-strapped weights.", + "text": "Compression scheduler\n\n\nIn iterative pruning, we create some kind of pruning regimen that specifies how to prune, and what to prune at every stage of the pruning and training stages. This motivated the design of \nCompressionScheduler\n: it needed to be part of the training loop, and to be able to make and implement pruning, regularization and quantization decisions. We wanted to be able to change the particulars of the compression schedule, w/o touching the code, and settled on using YAML as a container for this specification. We found that when we make many experiments on the same code base, it is easier to maintain all of these experiments if we decouple the differences from the code-base. 
Therefore, we added to the scheduler support for learning-rate decay scheduling because, again, we wanted the freedom to change the LR-decay policy without changing code. \n\n\nHigh level overview\n\n\nLet's briefly discuss the main mechanisms and abstractions: A schedule specification is composed of a list of sections defining instances of Pruners, Regularizers, Quantizers, LR-scheduler and Policies.\n\n\n\n\nPruners, Regularizers and Quantizers are very similar: They implement either a Pruning/Regularization/Quantization algorithm, respectively. \n\n\nAn LR-scheduler specifies the LR-decay algorithm. \n\n\n\n\nThese define the \nwhat\n part of the schedule. \n\n\nThe Policies define the \nwhen\n part of the schedule: at which epoch to start applying the Pruner/Regularizer/Quantizer/LR-decay, the epoch to end, and how often to invoke the policy (frequency of application). A policy also defines the instance of Pruner/Regularizer/Quantizer/LR-decay it is managing.\n\nThe CompressionScheduler is configured from a YAML file or from a dictionary, but you can also manually create Policies, Pruners, Regularizers and Quantizers from code.\n\n\nSyntax through example\n\n\nWe'll use \nalexnet.schedule_agp.yaml\n to explain some of the YAML syntax for configuring Sensitivity Pruning of Alexnet.\n\n\nversion: 1\npruners:\n my_pruner:\n class: 'SensitivityPruner'\n sensitivities:\n 'features.module.0.weight': 0.25\n 'features.module.3.weight': 0.35\n 'features.module.6.weight': 0.40\n 'features.module.8.weight': 0.45\n 'features.module.10.weight': 0.55\n 'classifier.1.weight': 0.875\n 'classifier.4.weight': 0.875\n 'classifier.6.weight': 0.625\n\nlr_schedulers:\n pruning_lr:\n class: ExponentialLR\n gamma: 0.9\n\npolicies:\n - pruner:\n instance_name : 'my_pruner'\n starting_epoch: 0\n ending_epoch: 38\n frequency: 2\n\n - lr_scheduler:\n instance_name: pruning_lr\n starting_epoch: 24\n ending_epoch: 200\n frequency: 1\n\n\n\n\nThere is only one version of the YAML syntax, and the version number is not verified at the moment. However, to be future-proof it is probably better to let the YAML parser know that you are using version-1 syntax, in case there is ever a version 2.\n\n\nversion: 1\n\n\n\n\nIn the \npruners\n section, we define the instances of pruners we want the scheduler to instantiate and use.\n\nWe define a single pruner instance, named \nmy_pruner\n, of algorithm \nSensitivityPruner\n. We will refer to this instance in the \nPolicies\n section.\n\nThen we list the sensitivity multipliers, \\(s\\), of each of the weight tensors.\n\nYou may list as many Pruners as you want in this section, as long as each has a unique name. You can several types of pruners in one schedule.\n\n\npruners:\n my_pruner:\n class: 'SensitivityPruner'\n sensitivities:\n 'features.module.0.weight': 0.25\n 'features.module.3.weight': 0.35\n 'features.module.6.weight': 0.40\n 'features.module.8.weight': 0.45\n 'features.module.10.weight': 0.55\n 'classifier.1.weight': 0.875\n 'classifier.4.weight': 0.875\n 'classifier.6.weight': 0.6\n\n\n\n\nNext, we want to specify the learning-rate decay scheduling in the \nlr_schedulers\n section. We assign a name to this instance: \npruning_lr\n. As in the \npruners\n section, you may use any name, as long as all LR-schedulers have a unique name. At the moment, only one instance of LR-scheduler is allowed. The LR-scheduler must be a subclass of PyTorch's \n_LRScheduler\n. You can use any of the schedulers defined in \ntorch.optim.lr_scheduler\n (see \nhere\n). 
In addition, we've implemented some additional schedulers in Distiller (see \nhere\n). The keyword arguments (kwargs) are passed directly to the LR-scheduler's constructor, so that as new LR-schedulers are added to \ntorch.optim.lr_scheduler\n, they can be used without changing the application code.\n\n\nlr_schedulers:\n pruning_lr:\n class: ExponentialLR\n gamma: 0.9\n\n\n\n\nFinally, we define the \npolicies\n section which defines the actual scheduling. A \nPolicy\n manages an instance of a \nPruner\n, \nRegularizer\n, \nQuantizer\n, or \nLRScheduler\n, by naming the instance. In the example below, a \nPruningPolicy\n uses the pruner instance named \nmy_pruner\n: it activates it at a frequency of 2 epochs (i.e. every other epoch), starting at epoch 0, and ending at epoch 38. \n\n\npolicies:\n - pruner:\n instance_name : 'my_pruner'\n starting_epoch: 0\n ending_epoch: 38\n frequency: 2\n\n - lr_scheduler:\n instance_name: pruning_lr\n starting_epoch: 24\n ending_epoch: 200\n frequency: 1\n\n\n\n\nThis is \niterative pruning\n:\n\n\n\n\n\n\nTrain Connectivity\n\n\n\n\n\n\nPrune Connections\n\n\n\n\n\n\nRetrain Weights\n\n\n\n\n\n\nGoto 2\n\n\n\n\n\n\nIt is described in \nLearning both Weights and Connections for Efficient Neural Networks\n:\n\n\n\n\n\"Our method prunes redundant connections using a three-step method. First, we train the network to learn which connections are important. Next, we prune the unimportant connections. Finally, we retrain the network to fine tune the weights of the remaining connections...After an initial training phase, we remove all connections whose weight is lower than a threshold. This pruning converts a dense, fully-connected layer to a sparse layer. This first phase learns the topology of the networks \u2014 learning which connections are important and removing the unimportant connections. We then retrain the sparse network so the remaining connections can compensate for the connections that have been removed. 
The phases of pruning and retraining may be repeated iteratively to further reduce network complexity.\"\n\n\n\n\nRegularization\n\n\nYou can also define and schedule regularization.\n\n\nL1 regularization\n\n\nFormat (this is an informal specification, not a valid \nABNF\n specification):\n\n\nregularizers:\n \nREGULARIZER_NAME_STR\n:\n class: L1Regularizer\n reg_regims:\n \nPYTORCH_PARAM_NAME_STR\n: \nSTRENGTH_FLOAT\n\n ...\n \nPYTORCH_PARAM_NAME_STR\n: \nSTRENGTH_FLOAT\n\n threshold_criteria: [Mean_Abs | Max]\n\n\n\n\nFor example:\n\n\nversion: 1\n\nregularizers:\n my_L1_reg:\n class: L1Regularizer\n reg_regims:\n 'module.layer3.1.conv1.weight': 0.000002\n 'module.layer3.1.conv2.weight': 0.000002\n 'module.layer3.1.conv3.weight': 0.000002\n 'module.layer3.2.conv1.weight': 0.000002\n threshold_criteria: Mean_Abs\n\npolicies:\n - regularizer:\n instance_name: my_L1_reg\n starting_epoch: 0\n ending_epoch: 60\n frequency: 1\n\n\n\n\nGroup regularization\n\n\nFormat (informal specification):\n\n\nFormat:\n regularizers:\n \nREGULARIZER_NAME_STR\n:\n class: L1Regularizer\n reg_regims:\n \nPYTORCH_PARAM_NAME_STR\n: [\nSTRENGTH_FLOAT\n, \n'2D' | '3D' | '4D' | 'Channels' | 'Cols' | 'Rows'\n]\n \nPYTORCH_PARAM_NAME_STR\n: [\nSTRENGTH_FLOAT\n, \n'2D' | '3D' | '4D' | 'Channels' | 'Cols' | 'Rows'\n]\n threshold_criteria: [Mean_Abs | Max]\n\n\n\n\nFor example:\n\n\nversion: 1\n\nregularizers:\n my_filter_regularizer:\n class: GroupLassoRegularizer\n reg_regims:\n 'module.layer3.1.conv1.weight': [0.00005, '3D']\n 'module.layer3.1.conv2.weight': [0.00005, '3D']\n 'module.layer3.1.conv3.weight': [0.00005, '3D']\n 'module.layer3.2.conv1.weight': [0.00005, '3D']\n threshold_criteria: Mean_Abs\n\npolicies:\n - regularizer:\n instance_name: my_filter_regularizer\n starting_epoch: 0\n ending_epoch: 60\n frequency: 1\n\n\n\n\nMixing it up\n\n\nYou can mix pruning and regularization.\n\n\nversion: 1\npruners:\n my_pruner:\n class: 'SensitivityPruner'\n sensitivities:\n 'features.module.0.weight': 0.25\n 'features.module.3.weight': 0.35\n 'features.module.6.weight': 0.40\n 'features.module.8.weight': 0.45\n 'features.module.10.weight': 0.55\n 'classifier.1.weight': 0.875\n 'classifier.4.weight': 0.875\n 'classifier.6.weight': 0.625\n\nregularizers:\n 2d_groups_regularizer:\n class: GroupLassoRegularizer\n reg_regims:\n 'features.module.0.weight': [0.000012, '2D']\n 'features.module.3.weight': [0.000012, '2D']\n 'features.module.6.weight': [0.000012, '2D']\n 'features.module.8.weight': [0.000012, '2D']\n 'features.module.10.weight': [0.000012, '2D']\n\n\nlr_schedulers:\n # Learning rate decay scheduler\n pruning_lr:\n class: ExponentialLR\n gamma: 0.9\n\npolicies:\n - pruner:\n instance_name : 'my_pruner'\n starting_epoch: 0\n ending_epoch: 38\n frequency: 2\n\n - regularizer:\n instance_name: '2d_groups_regularizer'\n starting_epoch: 0\n ending_epoch: 38\n frequency: 1\n\n - lr_scheduler:\n instance_name: pruning_lr\n starting_epoch: 24\n ending_epoch: 200\n frequency: 1\n\n\n\n\n\nQuantization\n\n\nSimilarly to pruners and regularizers, specifying a quantizer in the scheduler YAML follows the constructor arguments of the \nQuantizer\n class (see details \nhere\n).\n\n\nNotes\n: Only a single quantizer instance may be defined.\n\nLet's see an example:\n\n\nquantizers:\n dorefa_quantizer:\n class: DorefaQuantizer\n bits_activations: 8\n bits_weights: 4\n bits_overrides:\n conv1:\n wts: null\n acts: null\n relu1:\n wts: null\n acts: null\n final_relu:\n wts: null\n acts: null\n fc:\n wts: null\n acts: 
null\n\n\n\n\n\n\nThe specific quantization method we're instantiating here is \nDorefaQuantizer\n.\n\n\nThen we define the default bit-widths for activations and weights, in this case 8 and 4-bits, respectively. \n\n\nThen, we define the \nbits_overrides\n mapping. In the example above, we choose not to quantize the first and last layer of the model. In the case of \nDorefaQuantizer\n, the weights are quantized as part of the convolution / FC layers, but the activations are quantized in separate layers, which replace the ReLU layers in the original model (remember - even though we replaced the ReLU modules with our own quantization modules, the name of the modules isn't changed). So, in all, we need to reference the first layer with parameters \nconv1\n, the first activation layer \nrelu1\n, the last activation layer \nfinal_relu\n and the last layer with parameters \nfc\n.\n\n\nSpecifying \nnull\n means \"do not quantize\".\n\n\nNote that for quantizers, we reference names of modules, not names of parameters as we do for pruners and regularizers.\n\n\n\n\nDefining overrides for \ngroups of layers\n using regular expressions\n\n\nSuppose we have a sub-module in our model named \nblock1\n, which contains multiple convolution layers which we would like to quantize to, say, 2-bits. The convolution layers are named \nconv1\n, \nconv2\n and so on. In that case we would define the following:\n\n\nbits_overrides:\n 'block1\\.conv*':\n wts: 2\n acts: null\n\n\n\n\n\n\nRegEx Note\n: Remember that the dot (\n.\n) is a meta-character (i.e. a reserved character) in regular expressions. So, to match the actual dot characters which separate sub-modules in PyTorch module names, we need to escape it: \n\\.\n\n\n\n\nOverlapping patterns\n are also possible, which allows to define some override for a groups of layers and also \"single-out\" specific layers for different overrides. For example, let's take the last example and configure a different override for \nblock1.conv1\n:\n\n\nbits_overrides:\n 'block1\\.conv1':\n wts: 4\n acts: null\n 'block1\\.conv*':\n wts: 2\n acts: null\n\n\n\n\n\n\nImportant Note\n: The patterns are evaluated eagerly - first match wins. So, to properly quantize a model using \"broad\" patterns and more \"specific\" patterns as just shown, make sure the specific pattern is listed \nbefore\n the broad one.\n\n\n\n\nThe \nQuantizationPolicy\n, which controls the quantization procedure during training, is actually quite simplistic. All it does is call the \nprepare_model()\n function of the \nQuantizer\n when it's initialized, followed by the first call to \nquantize_params()\n. Then, at the end of each epoch, after the float copy of the weights has been updated, it calls the \nquantize_params()\n function again. \n\n\npolicies:\n - quantizer:\n instance_name: dorefa_quantizer\n starting_epoch: 0\n ending_epoch: 200\n frequency: 1\n\n\n\n\nImportant Note\n: As mentioned \nhere\n, since the quantizer modifies the model's parameters (assuming training with quantization in the loop is used), the call to \nprepare_model()\n must be performed before an optimizer is called. Therefore, currently, the starting epoch for a quantization policy must be 0, otherwise the quantization process will not work as expected. 
If one wishes to do a \"warm-startup\" (or \"boot-strapping\"), training for a few epochs with full precision and only then starting to quantize, the only way to do this right now is to execute a separate run to generate the boot-strapped weights, and execute a second which will resume the checkpoint with the boot-strapped weights.", "title": "Compression scheduling" }, { @@ -177,9 +177,14 @@ }, { "location": "/schedule/index.html#quantization", - "text": "Similarly to pruners and regularizers, specifying a quantizer in the scheduler YAML follows the constructor arguments of the Quantizer class (see details here ). Notes : Only a single quantizer instance may be defined. \nLet's see an example: quantizers:\n dorefa_quantizer:\n class: DorefaQuantizer\n bits_activations: 8\n bits_weights: 4\n bits_overrides:\n conv1:\n wts: null\n acts: null\n relu1:\n wts: null\n acts: null\n final_relu:\n wts: null\n acts: null\n fc:\n wts: null\n acts: null The specific quantization method we're instantiating here is DorefaQuantizer . Then we define the default bit-widths for activations and weights, in this case 8 and 4-bits, respectively. Then, we define the bits_overrides mapping. In this case, we choose not to quantize the first and last layer of the model. In the case of DorefaQuantizer , the weights are quantized as part of the convolution / FC layers, but the activations are quantized in separate layers, which replace the ReLU layers in the original model (remember - even though we replaced the ReLU modules with our own quantization modules, the name of the modules isn't changed). So, in all, we need to reference the first layer with parameters conv1 , the first activation layer relu1 , the last activation layer final_relu and the last layer with parameters fc . Specifying null means \"do not quantize\". Note that for quantizers, we reference names of modules, not names of parameters as we do for pruners and regularizers. We can also reference groups of layers in the bits_overrides mapping. This is done using regular expressions. Suppose we have a sub-module in our model named block1 , which contains multiple convolution layers which we would like to quantize to, say, 2-bits. The convolution layers are named conv1 , conv2 and so on. In that case we would define the following: bits_overrides:\n block1.conv*:\n wts: 2\n acts: null The QuantizationPolicy , which controls the quantization procedure during training, is actually quite simplistic. All it does is call the prepare_model() function of the Quantizer when it's initialized, followed by the first call to quantize_params() . Then, at the end of each epoch, after the float copy of the weights has been updated, it calls the quantize_params() function again. policies:\n - quantizer:\n instance_name: dorefa_quantizer\n starting_epoch: 0\n ending_epoch: 200\n frequency: 1 Important Note : As mentioned here , since the quantizer modifies the model's parameters (assuming training with quantization in the loop is used), the call to prepare_model() must be performed before an optimizer is called. Therefore, currently, the starting epoch for a quantization policy must be 0, otherwise the quantization process will not work as expected. 
If one wishes to do a \"warm-startup\" (or \"boot-strapping\"), training for a few epochs with full precision and only then starting to quantize, the only way to do this right now is to execute a separate run to generate the boot-strapped weights, and execute a second which will resume the checkpoint with the boot-strapped weights.", + "text": "Similarly to pruners and regularizers, specifying a quantizer in the scheduler YAML follows the constructor arguments of the Quantizer class (see details here ). Notes : Only a single quantizer instance may be defined. \nLet's see an example: quantizers:\n dorefa_quantizer:\n class: DorefaQuantizer\n bits_activations: 8\n bits_weights: 4\n bits_overrides:\n conv1:\n wts: null\n acts: null\n relu1:\n wts: null\n acts: null\n final_relu:\n wts: null\n acts: null\n fc:\n wts: null\n acts: null The specific quantization method we're instantiating here is DorefaQuantizer . Then we define the default bit-widths for activations and weights, in this case 8 and 4-bits, respectively. Then, we define the bits_overrides mapping. In the example above, we choose not to quantize the first and last layer of the model. In the case of DorefaQuantizer , the weights are quantized as part of the convolution / FC layers, but the activations are quantized in separate layers, which replace the ReLU layers in the original model (remember - even though we replaced the ReLU modules with our own quantization modules, the name of the modules isn't changed). So, in all, we need to reference the first layer with parameters conv1 , the first activation layer relu1 , the last activation layer final_relu and the last layer with parameters fc . Specifying null means \"do not quantize\". Note that for quantizers, we reference names of modules, not names of parameters as we do for pruners and regularizers.", "title": "Quantization" }, + { + "location": "/schedule/index.html#defining-overrides-for-groups-of-layers-using-regular-expressions", + "text": "Suppose we have a sub-module in our model named block1 , which contains multiple convolution layers which we would like to quantize to, say, 2-bits. The convolution layers are named conv1 , conv2 and so on. In that case we would define the following: bits_overrides:\n 'block1\\.conv*':\n wts: 2\n acts: null RegEx Note : Remember that the dot ( . ) is a meta-character (i.e. a reserved character) in regular expressions. So, to match the actual dot characters which separate sub-modules in PyTorch module names, we need to escape it: \\. Overlapping patterns are also possible, which allows to define some override for a groups of layers and also \"single-out\" specific layers for different overrides. For example, let's take the last example and configure a different override for block1.conv1 : bits_overrides:\n 'block1\\.conv1':\n wts: 4\n acts: null\n 'block1\\.conv*':\n wts: 2\n acts: null Important Note : The patterns are evaluated eagerly - first match wins. So, to properly quantize a model using \"broad\" patterns and more \"specific\" patterns as just shown, make sure the specific pattern is listed before the broad one. The QuantizationPolicy , which controls the quantization procedure during training, is actually quite simplistic. All it does is call the prepare_model() function of the Quantizer when it's initialized, followed by the first call to quantize_params() . Then, at the end of each epoch, after the float copy of the weights has been updated, it calls the quantize_params() function again. 
policies:\n - quantizer:\n instance_name: dorefa_quantizer\n starting_epoch: 0\n ending_epoch: 200\n frequency: 1 Important Note : As mentioned here , since the quantizer modifies the model's parameters (assuming training with quantization in the loop is used), the call to prepare_model() must be performed before an optimizer is called. Therefore, currently, the starting epoch for a quantization policy must be 0, otherwise the quantization process will not work as expected. If one wishes to do a \"warm-startup\" (or \"boot-strapping\"), training for a few epochs with full precision and only then starting to quantize, the only way to do this right now is to execute a separate run to generate the boot-strapped weights, and execute a second which will resume the checkpoint with the boot-strapped weights.", + "title": "Defining overrides for groups of layers using regular expressions" + }, { "location": "/pruning/index.html", "text": "Pruning\n\n\nA common methodology for inducing sparsity in weights and activations is called \npruning\n. Pruning is the application of a binary criteria to decide which weights to prune: weights which match the pruning criteria are assigned a value of zero. Pruned elements are \"trimmed\" from the model: we zero their values and also make sure they don't take part in the back-propagation process.\n\n\nWe can prune weights, biases, and activations. Biases are few and their contribution to a layer's output is relatively large, so there is little incentive to prune them. We usually see sparse activations following a ReLU layer, because ReLU quenches negative activations to exact zero (\\(ReLU(x): max(0,x)\\)). Sparsity in weights is less common, as weights tend to be very small, but are often not exact zeros.\n\n\n\nLet's define sparsity\n\n\nSparsity is a a measure of how many elements in a tensor are exact zeros, relative to the tensor size. A tensor is considered sparse if \"most\" of its elements are zero. How much is \"most\", is not strictly defined, but when you see a sparse tensor you know it ;-)\n\nThe \n\\(l_0\\)-\"norm\" function\n measures how many zero-elements are in a tensor \nx\n:\n\\[\\lVert x \\rVert_0\\;=\\;|x_1|^0 + |x_2|^0 + ... + |x_n|^0 \\]\nIn other words, an element contributes either a value of 1 or 0 to \\(l_0\\). Anything but an exact zero contributes a value of 1 - that's pretty cool.\n\nSometimes it helps to think about density, the number of non-zero elements (NNZ) and sparsity's complement:\n\\[\ndensity = 1 - sparsity\n\\]\nYou can use \ndistiller.sparsity\n and \ndistiller.density\n to query a PyTorch tensor's sparsity and density.\n\n\nWhat is weights pruning?\n\n\nWeights pruning, or model pruning, is a set of methods to increase the sparsity (amount of zero-valued elements in a tensor) of a network's weights. In general, the term 'parameters' refers to both weights and bias tensors of a model. Biases are rarely, if ever, pruned because there are very few bias elements compared to weights elements, and it is just not worth the trouble.\n\n\nPruning requires a criteria for choosing which elements to prune - this is called the \npruning criteria\n. The most common pruning criteria is the absolute value of each element: the element's absolute value is compared to some threshold value, and if it is below the threshold the element is set to zero (i.e. pruned) . This is implemented by the \ndistiller.MagnitudeParameterPruner\n class. 
The idea behind this method, is that weights with small \\(l_1\\)-norms (absolute value) contribute little to the final result (low saliency), so they are less important and can be removed.\n\n\nA related idea motivating pruning, is that models are over-parametrized and contain redundant logic and features. Therefore, some of these redundancies can be removed by setting their weights to zero.\n\n\nAnd yet another way to think of pruning is to phrase it as a search for a set of weights with as many zeros as possible, which still produces acceptable inference accuracies compared to the dense-model (non-pruned model). Another way to look at it, is to imagine that because of the very high-dimensionality of the parameter space, the immediate space around the dense-model's solution likely contains some sparse solutions, and we want to use find these sparse solutions. \n\n\n\n\nPruning schedule\n\n\nThe most straight-forward to prune is to take a trained model and prune it once; also called \none-shot pruning\n. In \nLearning both Weights and Connections for Efficient Neural Networks\n Song Han et. al show that this is surprisingly effective, but also leaves a lot of potential sparsity untapped. The surprise is what they call the \"free lunch\" effect: \n\"reducing 2x the connections without losing accuracy even without retraining.\"\n\nHowever, they also note that when employing a pruning-followed-by-retraining regimen, they can achieve much better results (higher sparsity at no accuracy loss). This is called \niterative pruning\n, and the retraining that follows pruning is often referred to as \nfine-tuning\n. How the pruning criteria changes between iterations, how many iterations we perform and how often, and which tensors are pruned - this is collectively called the \npruning schedule\n.\n\n\nWe can think of iterative pruning as repeatedly learning which weights are important, removing the least important ones based on some importance criteria, and then retraining the model to let it \"recover\" from the pruning by adjusting the remaining weights. At each iteration, we prune more weights.\n\nThe decision of when to stop pruning is also expressed in the schedule, and it depends on the pruning algorithm. For example, if we are trying to achieve a specific sparsity level, then we stop when the pruning achieves that level. And if we are pruning weights structures in order to reduce the required compute budget, then we stop the pruning when this compute reduction is achieved.\n\n\nDistiller supports expressing the pruning schedule as a YAML file (which is then executed by an instance of a PruningScheduler).\n\n\nPruning granularity\n\n\nPruning individual weight elements is called \nelement-wise pruning\n, and it is also sometimes referred to as \nfine-grained\n pruning.\n\n\nCoarse-grained pruning\n - also referred to as \nstructured pruning\n, \ngroup pruning\n, or \nblock pruning\n - is pruning entire groups of elements which have some significance. Groups come in various shapes and sizes, but an easy to visualize group-pruning is filter-pruning, in which entire filters are removed.\n\n\nSensitivity analysis\n\n\nThe hard part about inducing sparsity via pruning is determining what threshold, or sparsity level, to use for each layer's tensors. Sensitivity analysis is a method that tries to help us rank the tensors by their sensitivity to pruning. 
\n\nThe idea is to set the pruning level (percentage) of a specific layer, and then to prune once, run an evaluation on the test dataset and record the accuracy score. We do this for all of the parameterized layers, and for each layer we examine several sparsity levels. This should teach us about the \"sensitivity\" of each of the layers to pruning.\n\n\nThe evaluated model should be trained to maximum accuracy before running the analysis, because we aim to understand the behavior of the trained model's performance in relation to pruning of a specific weights tensor.\n\n\nMuch as we can prune structures, we can also perform sensitivity analysis on structures. Distiller implements element-wise pruning sensitivity analysis using the \\(l_1\\)-norm of individual elements; and filter-wise pruning sensitivity analysis using the mean \\(l_1\\)-norm of filters.\n\n\n\nThe authors of \nPruning Filters for Efficient ConvNets\n describe how they do sensitivity analysis:\n\n\n\n\n\"To understand the sensitivity of each layer, we prune each layer independently and evaluate the resulting pruned network\u2019s accuracy on the validation set. Figure 2(b) shows that layers that maintain their accuracy as filters are pruned away correspond to layers with larger slopes in Figure 2(a). On the contrary, layers with relatively flat slopes are more sensitive to pruning. We empirically determine the number of filters to prune for each layer based on their sensitivity to pruning. For deep networks such as VGG-16 or ResNets, we observe that layers in the same stage (with the same feature map size) have a similar sensitivity to pruning. To avoid introducing layer-wise meta-parameters, we use the same pruning ratio for all layers in the same stage. For layers that are sensitive to pruning, we prune a smaller percentage of these layers or completely skip pruning them.\"\n\n\n\n\nThe diagram below shows the results of running an element-wise sensitivity analysis on Alexnet, using Distillers's \nperform_sensitivity_analysis\n utility function.\n\n\nAs reported by Song Han, and exhibited in the diagram, in Alexnet the feature detecting layers (convolution layers) are more sensitive to pruning, and their sensitivity drops, the deeper they are. The fully-connected layers are much less sensitive, which is great, because that's where most of the parameters are.\n\n\n\n\nReferences\n\n\n \nSong Han, Jeff Pool, John Tran, William J. Dally\n.\n \nLearning both Weights and Connections for Efficient Neural Networks\n,\n arXiv:1607.04381v2,\n 2015.\n\n\n\n\n\nHao Li, Asim Kadav, Igor Durdanovic, Hanan Samet, Hans Peter Graf\n.\n \nPruning Filters for Efficient ConvNets\n,\n arXiv:1608.08710v3,\n 2017.", @@ -522,7 +527,7 @@ }, { "location": "/design/index.html", - "text": "Distiller design\n\n\nDistiller is designed to be easily integrated into your own PyTorch research applications.\n\nIt is easiest to understand this integration by examining the code of the sample application for compressing image classification models (\ncompress_classifier.py\n).\n\n\nThe application borrows its main flow code from torchvision's ImageNet classification training sample application (https://github.com/pytorch/examples/tree/master/imagenet). We tried to keep it similar, in order to make it familiar and easy to understand.\n\n\nIntegrating compression is very simple: simply add invocations of the appropriate compression_scheduler callbacks, for each stage in the training. The training skeleton looks like the pseudo code below. 
The boiler-plate Pytorch classification training is speckled with invocations of CompressionScheduler.\n\n\nFor each epoch:\n compression_scheduler.on_epoch_begin(epoch)\n train()\n validate()\n save_checkpoint()\n compression_scheduler.on_epoch_end(epoch)\n\ntrain():\n For each training step:\n compression_scheduler.on_minibatch_begin(epoch)\n output = model(input_var)\n loss = criterion(output, target_var)\n compression_scheduler.before_backward_pass(epoch)\n loss.backward()\n optimizer.step()\n compression_scheduler.on_minibatch_end(epoch)\n\n\n\n\nThese callbacks can be seen in the diagram below, as the arrow pointing from the Training Loop and into Distiller's \nScheduler\n, which invokes the correct algorithm. The application also uses Distiller services to collect statistics in \nSummaries\n and logs files, which can be queried at a later time, from Jupyter notebooks or TensorBoard.\n\n\n\n\nSparsification and fine-tuning\n\n\n\n\nThe application sets up a model as normally done in PyTorch.\n\n\nAnd then instantiates a Scheduler and configures it:\n\n\nScheduler configuration is defined in a YAML file\n\n\nThe configuration specifies Policies. Each Policy is tied to a specific algorithm which controls some aspect of the training.\n\n\nSome types of algorithms control the actual sparsification of the model. Such types are \"pruner\" and \"regularizer\".\n\n\nSome algorithms control some parameter of the training process, such as the learning-rate decay scheduler (\nlr_scheduler\n).\n\n\nThe parameters of each algorithm are also specified in the configuration.\n\n\n\n\n\n\n\n\n\n\nIn addition to specifying the algorithm, each Policy specifies scheduling parameters which control when the algorithm is executed: start epoch, end epoch and frequency.\n\n\nThe Scheduler exposes callbacks for relevant training stages: epoch start/end, mini-batch start/end and pre-backward pass. Each scheduler callback activates the policies that were defined according the schedule that was defined.\n\n\nThese callbacks are placed the training loop.\n\n\n\n\nQuantization\n\n\nA quantized model is obtained by replacing existing operations with quantized versions. The quantized versions can be either complete replacements, or wrappers. A wrapper will use the existing modules internally and add quantization and de-quantization operations before/after as necessary.\n\n\nIn Distiller we will provide a set of quantized versions of common operations which will enable implementation of different quantization methods. The user can write a quantized model from scratch, using the quantized operations provided.\n\n\nWe also provide a mechanism which takes an existing model and automatically replaces required operations with quantized versions. This mechanism is exposed by the \nQuantizer\n class. \nQuantizer\n should be sub-classed for each quantization method.\n\n\nModel Transformation\n\n\nThe high-level flow is as follows:\n\n\n\n\nDefine a \nmapping\n between the module types to be replaced (e.g. Conv2D, Linear, etc.) to a function which generates the replacement module. The mapping is defined in the \nreplacement_factory\n attribute of the \nQuantizer\n class.\n\n\nIterate over the modules defined in the model. For each module, if its type is in the mapping, call the replacement generation function. We pass the existing module to this function to allow wrapping of it.\n\n\nReplace the existing module with the module returned by the function. 
It is important to note that the \nname\n of the module \ndoes not\n change, as that could break the \nforward\n function of the parent module.\n\n\n\n\nDifferent quantization methods may, obviously, use different quantized operations. In addition, different methods may employ different \"strategies\" of replacing / wrapping existing modules. For instance, some methods replace ReLU with another activation function, while others keep it. Hence, for each quantization method, a different \nmapping\n will likely be defined.\n\nEach sub-class of \nQuantizer\n should populate the \nreplacement_factory\n dictionary attribute with the appropriate mapping.\n\nTo execute the model transformation, call the \nprepare_model\n function of the \nQuantizer\n instance.\n\n\nFlexible Bit-Widths\n\n\n\n\nEach instance of \nQuantizer\n is parameterized by the number of bits to be used for quantization of different tensor types. The default ones are activations and weights. These are the \nbits_activations\n and \nbits_weights\n parameters in \nQuantizer\n's constructor. Sub-classes may define bit-widths for other tensor types as needed.\n\n\nWe also want to be able to override the default number of bits mentioned in the bullet above for certain layers. These could be very specific layers. However, many models are comprised of building blocks (\"container\" modules, such as Sequential) which contain several modules, and it is likely we'll want to override settings for entire blocks, or for a certain module across different blocks. When such building blocks are used, the names of the internal modules usually follow some pattern.\n\n\nSo, for this purpose, Quantizer also accepts a mapping of regular expressions to number of bits. This allows the user to override specific layers using they're exact name, or a group of layers via a regular expression. This mapping is passed via the \nbits_overrides\n parameter in the constructor.\n\n\n\n\nWeights Quantization\n\n\nThe \nQuantizer\n class also provides an API to quantize the weights of all layers at once. To use it, the \nparam_quantization_fn\n attribute needs to point to a function that accepts a tensor and the number of bits. During model transformation, the \nQuantizer\n class will build a list of all model parameters that need to be quantized along with their bit-width. Then, the \nquantize_params\n function can be called, which will iterate over all parameters and quantize them using \nparams_quantization_fn\n.\n\n\nTraining with Quantization\n\n\nThe \nQuantizer\n class supports training with quantization in the loop, as described \nhere\n. This is enabled by setting \ntrain_with_fp_copy=True\n in the \nQuantizer\n constructor. At model transformation, in each module that has parameters that should be quantized, a new \ntorch.nn.Parameter\n is added, which will maintain the required full precision copy of the parameters. Note that this is done in-place - a new module \nis not\n created. We preferred not to sub-class the existing PyTorch modules for this purpose. In order to this in-place, and also guarantee proper back-propagation through the weights quantization function, we employ the following \"hack\":\n\n\n\n\nThe existing \ntorch.nn.Parameter\n, e.g. 
\nweights\n, is replaced by a \ntorch.nn.Parameter\n named \nfloat_weight\n.\n\n\nTo maintain the existing functionality of the module, we then register a \nbuffer\n in the module with the original name - \nweights\n.\n\n\nDuring training, \nfloat_weight\n will be passed to \nparam_quantization_fn\n and the result will be stored in \nweight\n.\n\n\n\n\nImportant Note\n: Since this process modifies the model's parameters, it must be done \nbefore\n an PyTorch \nOptimizer\n is created (this refers to any of the sub-classes defined \nhere\n).\n\n\nThe base \nQuantizer\n class is implemented in \ndistiller/quantization/quantizer.py\n.\n\nFor a simple sub-class implementing symmetric linear quantization, see \nSymmetricLinearQuantizer\n in \ndistiller/quantization/range_linear.py\n. For examples of lower-precision methods using training with quantization see \nDorefaQuantizer\n and \nWRPNQuantizer\n in \ndistiller/quantization/clipped_linear.py", + "text": "Distiller design\n\n\nDistiller is designed to be easily integrated into your own PyTorch research applications.\n\nIt is easiest to understand this integration by examining the code of the sample application for compressing image classification models (\ncompress_classifier.py\n).\n\n\nThe application borrows its main flow code from torchvision's ImageNet classification training sample application (https://github.com/pytorch/examples/tree/master/imagenet). We tried to keep it similar, in order to make it familiar and easy to understand.\n\n\nIntegrating compression is very simple: simply add invocations of the appropriate compression_scheduler callbacks, for each stage in the training. The training skeleton looks like the pseudo code below. The boiler-plate Pytorch classification training is speckled with invocations of CompressionScheduler.\n\n\nFor each epoch:\n compression_scheduler.on_epoch_begin(epoch)\n train()\n validate()\n save_checkpoint()\n compression_scheduler.on_epoch_end(epoch)\n\ntrain():\n For each training step:\n compression_scheduler.on_minibatch_begin(epoch)\n output = model(input_var)\n loss = criterion(output, target_var)\n compression_scheduler.before_backward_pass(epoch)\n loss.backward()\n optimizer.step()\n compression_scheduler.on_minibatch_end(epoch)\n\n\n\n\nThese callbacks can be seen in the diagram below, as the arrow pointing from the Training Loop and into Distiller's \nScheduler\n, which invokes the correct algorithm. The application also uses Distiller services to collect statistics in \nSummaries\n and logs files, which can be queried at a later time, from Jupyter notebooks or TensorBoard.\n\n\n\n\nSparsification and fine-tuning\n\n\n\n\nThe application sets up a model as normally done in PyTorch.\n\n\nAnd then instantiates a Scheduler and configures it:\n\n\nScheduler configuration is defined in a YAML file\n\n\nThe configuration specifies Policies. Each Policy is tied to a specific algorithm which controls some aspect of the training.\n\n\nSome types of algorithms control the actual sparsification of the model. 
Such types are \"pruner\" and \"regularizer\".\n\n\nSome algorithms control some parameter of the training process, such as the learning-rate decay scheduler (\nlr_scheduler\n).\n\n\nThe parameters of each algorithm are also specified in the configuration.\n\n\n\n\n\n\n\n\n\n\nIn addition to specifying the algorithm, each Policy specifies scheduling parameters which control when the algorithm is executed: start epoch, end epoch and frequency.\n\n\nThe Scheduler exposes callbacks for relevant training stages: epoch start/end, mini-batch start/end and pre-backward pass. Each scheduler callback activates the policies that were defined according the schedule that was defined.\n\n\nThese callbacks are placed the training loop.\n\n\n\n\nQuantization\n\n\nA quantized model is obtained by replacing existing operations with quantized versions. The quantized versions can be either complete replacements, or wrappers. A wrapper will use the existing modules internally and add quantization and de-quantization operations before/after as necessary.\n\n\nIn Distiller we will provide a set of quantized versions of common operations which will enable implementation of different quantization methods. The user can write a quantized model from scratch, using the quantized operations provided.\n\n\nWe also provide a mechanism which takes an existing model and automatically replaces required operations with quantized versions. This mechanism is exposed by the \nQuantizer\n class. \nQuantizer\n should be sub-classed for each quantization method.\n\n\nModel Transformation\n\n\nThe high-level flow is as follows:\n\n\n\n\nDefine a \nmapping\n between the module types to be replaced (e.g. Conv2D, Linear, etc.) to a function which generates the replacement module. The mapping is defined in the \nreplacement_factory\n attribute of the \nQuantizer\n class.\n\n\nIterate over the modules defined in the model. For each module, if its type is in the mapping, call the replacement generation function. We pass the existing module to this function to allow wrapping of it.\n\n\nReplace the existing module with the module returned by the function. It is important to note that the \nname\n of the module \ndoes not\n change, as that could break the \nforward\n function of the parent module.\n\n\n\n\nDifferent quantization methods may, obviously, use different quantized operations. In addition, different methods may employ different \"strategies\" of replacing / wrapping existing modules. For instance, some methods replace ReLU with another activation function, while others keep it. Hence, for each quantization method, a different \nmapping\n will likely be defined.\n\nEach sub-class of \nQuantizer\n should populate the \nreplacement_factory\n dictionary attribute with the appropriate mapping.\n\nTo execute the model transformation, call the \nprepare_model\n function of the \nQuantizer\n instance.\n\n\nFlexible Bit-Widths\n\n\n\n\nEach instance of \nQuantizer\n is parameterized by the number of bits to be used for quantization of different tensor types. The default ones are activations and weights. These are the \nbits_activations\n and \nbits_weights\n parameters in \nQuantizer\n's constructor. Sub-classes may define bit-widths for other tensor types as needed.\n\n\nWe also want to be able to override the default number of bits mentioned in the bullet above for certain layers. These could be very specific layers. 
However, many models are comprised of building blocks (\"container\" modules, such as Sequential) which contain several modules, and it is likely we'll want to override settings for entire blocks, or for a certain module across different blocks. When such building blocks are used, the names of the internal modules usually follow some pattern.\n\n\nSo, for this purpose, Quantizer also accepts a mapping of regular expressions to number of bits. This allows the user to override specific layers using their exact name, or a group of layers via a regular expression. This mapping is passed via the \nbits_overrides\n parameter in the constructor.\n\n\nThe \nbits_overrides\n mapping is required to be an instance of \ncollections.OrderedDict\n (as opposed to just a simple Python \ndict\n). This is done in order to enable handling of overlapping name patterns.\n\n So, for example, one could define certain override parameters for a group of layers, e.g. 'conv*', but also define different parameters for specific layers in that group, e.g. 'conv1'.\n\n The patterns are evaluated eagerly - the first match wins. Therefore, the more specific patterns must come before the broad patterns.\n\n\n\n\nWeights Quantization\n\n\nThe \nQuantizer\n class also provides an API to quantize the weights of all layers at once. To use it, the \nparam_quantization_fn\n attribute needs to point to a function that accepts a tensor and the number of bits. During model transformation, the \nQuantizer\n class will build a list of all model parameters that need to be quantized along with their bit-width. Then, the \nquantize_params\n function can be called, which will iterate over all parameters and quantize them using \nparams_quantization_fn\n.\n\n\nTraining with Quantization\n\n\nThe \nQuantizer\n class supports training with quantization in the loop, as described \nhere\n. This is enabled by setting \ntrain_with_fp_copy=True\n in the \nQuantizer\n constructor. At model transformation, in each module that has parameters that should be quantized, a new \ntorch.nn.Parameter\n is added, which will maintain the required full precision copy of the parameters. Note that this is done in-place - a new module \nis not\n created. We preferred not to sub-class the existing PyTorch modules for this purpose. In order to do this in-place, and also guarantee proper back-propagation through the weights quantization function, we employ the following \"hack\":\n\n\n\n\nThe existing \ntorch.nn.Parameter\n, e.g. \nweights\n, is replaced by a \ntorch.nn.Parameter\n named \nfloat_weight\n.\n\n\nTo maintain the existing functionality of the module, we then register a \nbuffer\n in the module with the original name - \nweights\n.\n\n\nDuring training, \nfloat_weight\n will be passed to \nparam_quantization_fn\n and the result will be stored in \nweight\n.\n\n\n\n\nImportant Note\n: Since this process modifies the model's parameters, it must be done \nbefore\n a PyTorch \nOptimizer\n is created (this refers to any of the sub-classes defined \nhere\n).\n\n\nThe base \nQuantizer\n class is implemented in \ndistiller/quantization/quantizer.py\n.\n\nFor a simple sub-class implementing symmetric linear quantization, see \nSymmetricLinearQuantizer\n in \ndistiller/quantization/range_linear.py\n. 
For examples of lower-precision methods using training with quantization see \nDorefaQuantizer\n and \nWRPNQuantizer\n in \ndistiller/quantization/clipped_linear.py", "title": "Design" }, { @@ -547,7 +552,7 @@ }, { "location": "/design/index.html#flexible-bit-widths", - "text": "Each instance of Quantizer is parameterized by the number of bits to be used for quantization of different tensor types. The default ones are activations and weights. These are the bits_activations and bits_weights parameters in Quantizer 's constructor. Sub-classes may define bit-widths for other tensor types as needed. We also want to be able to override the default number of bits mentioned in the bullet above for certain layers. These could be very specific layers. However, many models are comprised of building blocks (\"container\" modules, such as Sequential) which contain several modules, and it is likely we'll want to override settings for entire blocks, or for a certain module across different blocks. When such building blocks are used, the names of the internal modules usually follow some pattern. So, for this purpose, Quantizer also accepts a mapping of regular expressions to number of bits. This allows the user to override specific layers using they're exact name, or a group of layers via a regular expression. This mapping is passed via the bits_overrides parameter in the constructor.", + "text": "Each instance of Quantizer is parameterized by the number of bits to be used for quantization of different tensor types. The default ones are activations and weights. These are the bits_activations and bits_weights parameters in Quantizer 's constructor. Sub-classes may define bit-widths for other tensor types as needed. We also want to be able to override the default number of bits mentioned in the bullet above for certain layers. These could be very specific layers. However, many models are comprised of building blocks (\"container\" modules, such as Sequential) which contain several modules, and it is likely we'll want to override settings for entire blocks, or for a certain module across different blocks. When such building blocks are used, the names of the internal modules usually follow some pattern. So, for this purpose, Quantizer also accepts a mapping of regular expressions to number of bits. This allows the user to override specific layers using they're exact name, or a group of layers via a regular expression. This mapping is passed via the bits_overrides parameter in the constructor. The bits_overrides mapping is required to be an instance of collections.OrderedDict (as opposed to just a simple Python dict ). This is done in order to enable handling of overlapping name patterns. \n So, for example, one could define certain override parameters for a group of layers, e.g. 'conv*', but also define different parameters for specific layers in that group, e.g. 'conv1'. \n The patterns are evaluated eagerly - the first match wins. 
Therefore, the more specific patterns must come before the broad patterns.", "title": "Flexible Bit-Widths" }, { diff --git a/docs/sitemap.xml b/docs/sitemap.xml index ea7fb0b..8ca15b0 100644 --- a/docs/sitemap.xml +++ b/docs/sitemap.xml @@ -4,7 +4,7 @@ <url> <loc>/index.html</loc> - <lastmod>2018-07-01</lastmod> + <lastmod>2018-07-17</lastmod> <changefreq>daily</changefreq> </url> @@ -12,7 +12,7 @@ <url> <loc>/install/index.html</loc> - <lastmod>2018-07-01</lastmod> + <lastmod>2018-07-17</lastmod> <changefreq>daily</changefreq> </url> @@ -20,7 +20,7 @@ <url> <loc>/usage/index.html</loc> - <lastmod>2018-07-01</lastmod> + <lastmod>2018-07-17</lastmod> <changefreq>daily</changefreq> </url> @@ -28,7 +28,7 @@ <url> <loc>/schedule/index.html</loc> - <lastmod>2018-07-01</lastmod> + <lastmod>2018-07-17</lastmod> <changefreq>daily</changefreq> </url> @@ -37,19 +37,19 @@ <url> <loc>/pruning/index.html</loc> - <lastmod>2018-07-01</lastmod> + <lastmod>2018-07-17</lastmod> <changefreq>daily</changefreq> </url> <url> <loc>/regularization/index.html</loc> - <lastmod>2018-07-01</lastmod> + <lastmod>2018-07-17</lastmod> <changefreq>daily</changefreq> </url> <url> <loc>/quantization/index.html</loc> - <lastmod>2018-07-01</lastmod> + <lastmod>2018-07-17</lastmod> <changefreq>daily</changefreq> </url> @@ -59,13 +59,13 @@ <url> <loc>/algo_pruning/index.html</loc> - <lastmod>2018-07-01</lastmod> + <lastmod>2018-07-17</lastmod> <changefreq>daily</changefreq> </url> <url> <loc>/algo_quantization/index.html</loc> - <lastmod>2018-07-01</lastmod> + <lastmod>2018-07-17</lastmod> <changefreq>daily</changefreq> </url> @@ -74,7 +74,7 @@ <url> <loc>/model_zoo/index.html</loc> - <lastmod>2018-07-01</lastmod> + <lastmod>2018-07-17</lastmod> <changefreq>daily</changefreq> </url> @@ -82,7 +82,7 @@ <url> <loc>/jupyter/index.html</loc> - <lastmod>2018-07-01</lastmod> + <lastmod>2018-07-17</lastmod> <changefreq>daily</changefreq> </url> @@ -90,7 +90,7 @@ <url> <loc>/design/index.html</loc> - <lastmod>2018-07-01</lastmod> + <lastmod>2018-07-17</lastmod> <changefreq>daily</changefreq> </url> diff --git a/tests/full_flow_tests.py b/tests/full_flow_tests.py index 6f77e61..e8ea127 100644 --- a/tests/full_flow_tests.py +++ b/tests/full_flow_tests.py @@ -105,6 +105,9 @@ test_configs = [ TestConfig('-a resnet20_cifar --resume {0} --quantize --evaluate'. format(os.path.join(examples_root, 'ssl', 'checkpoints', 'checkpoint_trained_dense.pth.tar')), DS_CIFAR, accuracy_checker, [91.620, 99.630]), + TestConfig('-a preact_resnet20_cifar --epochs 2 --compress {0}'. 
+ format(os.path.join('full_flow_tests', 'preact_resnet20_cifar_dorefa_test.yaml')), + DS_CIFAR, accuracy_checker, [41.190, 91.160]) ] diff --git a/tests/full_flow_tests/preact_resnet20_cifar_dorefa_test.yaml b/tests/full_flow_tests/preact_resnet20_cifar_dorefa_test.yaml new file mode 100644 index 0000000..a987053 --- /dev/null +++ b/tests/full_flow_tests/preact_resnet20_cifar_dorefa_test.yaml @@ -0,0 +1,38 @@ +quantizers: + dorefa_quantizer: + class: DorefaQuantizer + bits_activations: 8 + bits_weights: 3 + bits_overrides: + # Don't quantize first and last layer + conv1: + wts: null + acts: null + layer1.0.pre_relu: + wts: null + acts: null + final_relu: + wts: null + acts: null + fc: + wts: null + acts: null + +lr_schedulers: + training_lr: + class: MultiStepMultiGammaLR + milestones: [80, 120, 160] + gammas: [0.1, 0.1, 0.2] + +policies: + - quantizer: + instance_name: dorefa_quantizer + starting_epoch: 0 + ending_epoch: 200 + frequency: 1 + + - lr_scheduler: + instance_name: training_lr + starting_epoch: 0 + ending_epoch: 161 + frequency: 1 diff --git a/tests/test_quantizer.py b/tests/test_quantizer.py new file mode 100644 index 0000000..fece996 --- /dev/null +++ b/tests/test_quantizer.py @@ -0,0 +1,339 @@ +# +# Copyright (c) 2018 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import torch +import os +import sys +import torch.nn as nn +from copy import deepcopy +from collections import OrderedDict +import pytest + +module_path = os.path.abspath(os.path.join('..')) +if module_path not in sys.path: + sys.path.append(module_path) +from distiller.quantization import Quantizer +from distiller.quantization.quantizer import QBits +from distiller.quantization.quantizer import FP_BKP_PREFIX +from distiller import has_children + + +############################# +# Dummy modules +############################# + +class DummyQuantLayer(nn.Module): + def __init__(self, qbits): + super(DummyQuantLayer, self).__init__() + self.qbits = qbits + + def forward(self, *input): + return input + + +class DummyWrapperLayer(nn.Module): + def __init__(self, module, qbits): + super(DummyWrapperLayer, self).__init__() + self.qbits = qbits + self.inner = module + + def forward(self, *input): + return input + + +class DummyModel(nn.Sequential): + def __init__(self): + super(DummyModel, self).__init__() + + self.add_module('conv1', nn.Conv2d(3, 16, 1)) + self.add_module('bn1', nn.BatchNorm2d(16)) + self.add_module('relu1', nn.ReLU()) + self.add_module('pool1', nn.MaxPool2d(2, 2)) + + def gen_sub_module(): + sub_m = nn.Sequential() + sub_m.add_module('conv1', nn.Conv2d(16, 32, 1)) + sub_m.add_module('bn1', nn.BatchNorm2d(32)) + sub_m.add_module('relu1', nn.ReLU()) + sub_m.add_module('pool1', nn.MaxPool2d(2, 2)) + sub_m.add_module('conv2', nn.Conv2d(32, 16, 1)) + sub_m.add_module('bn2', nn.BatchNorm2d(16)) + sub_m.add_module('relu2', nn.ReLU()) + sub_m.add_module('pool2', nn.MaxPool2d(2, 2)) + return sub_m + + self.add_module('sub1', gen_sub_module()) + self.add_module('sub2', gen_sub_module()) + + self.add_module('fc', nn.Linear(16, 10)) + self.add_module('last_relu', nn.ReLU(10)) + + # Use zeroed parameters to make it easier to validate our dummy quantization function + for p in self.parameters(): + p.data = torch.zeros_like(p) + + +############################# +# Dummy Quantizer +############################# + +def dummy_quantize_params(param, num_bits): + return param + num_bits + + +class DummyQuantizer(Quantizer): + def __init__(self, model, bits_activations=None, bits_weights=None, bits_overrides=OrderedDict(), + quantize_bias=False, train_with_fp_copy=False): + super(DummyQuantizer, self).__init__(model, bits_activations, bits_weights, bits_overrides, quantize_bias, + train_with_fp_copy) + + self.replacement_factory[nn.Conv2d] = lambda module, name, qbits_map: DummyWrapperLayer(module, qbits_map[name]) + self.replacement_factory[nn.ReLU] = lambda module, name, qbits_map: DummyQuantLayer(qbits_map[name]) + self.param_quantization_fn = dummy_quantize_params + + +############################# +# Other utils +############################# + +expected_type_replacements = {nn.Conv2d: DummyWrapperLayer, nn.ReLU: DummyQuantLayer} + + +def params_quantizable(module): + return isinstance(module, (nn.Conv2d, nn.Linear)) + + +def get_expected_qbits(model, qbits, expected_overrides): + expected_qbits = {} + post_prepare_changes = {} + prefix = 'module.' 
if isinstance(model, torch.nn.DataParallel) else '' + for orig_name, orig_module in model.named_modules(): + bits_a, bits_w = expected_overrides.get(orig_name.replace(prefix, '', 1), qbits) + if not params_quantizable(orig_module): + bits_w = None + expected_qbits[orig_name] = QBits(bits_a, bits_w) + + # We're testing replacement of module with container + if isinstance(orig_module, nn.Conv2d): + post_prepare_changes[orig_name] = QBits(bits_a, None) + post_prepare_changes[orig_name + '.inner'] = expected_qbits[orig_name] + + return expected_qbits, post_prepare_changes + + +############################# +# Fixtures +############################# + +@pytest.fixture(name='model') +def fixture_model(): + return DummyModel() + + +@pytest.fixture(name='train_with_fp_copy', params=[False, True], ids=['fp_copy_off', 'fp_copy_on']) +def fixture_train_with_fp_copy(request): + return request.param + + +@pytest.fixture(name='quantize_bias', params=[False, True], ids=['bias_off', 'bias_on']) +def fixture_quantize_bias(request): + return request.param + + +@pytest.fixture(name='parallel', params=[False, True], ids=['parallel_off', 'parallel_on']) +def fixture_parallel(request): + return request.param + + +############################# +# Tests +############################# + +def test_no_quantization(model): + m_orig = deepcopy(model) + + q = DummyQuantizer(model) + assert all(qbits.acts is None and qbits.wts is None for qbits in q.module_qbits_map.values()) + + q.prepare_model() + assert len(q.params_to_quantize) == 0 + assert all(type(q_module) == type(orig_module) for q_module, orig_module in zip(model.modules(), m_orig.modules())) + + q.quantize_params() + assert all(torch.equal(q_param, orig_param) for q_param, orig_param in zip(model.parameters(), m_orig.parameters())) + + +def test_overrides_ordered_dict(model): + with pytest.raises(TypeError, message='Expecting TypeError when bits_overrides is not an OrderedDict'): + DummyQuantizer(model, bits_overrides={}) + + +@pytest.mark.parametrize( + "qbits, bits_overrides, explicit_expected_overrides", + [ + (QBits(8, 4), OrderedDict(), {}), + (QBits(8, 4), + OrderedDict([('conv1', {'acts': None, 'wts': None}), ('relu1', {'acts': None, 'wts': None})]), + {'conv1': QBits(None, None), 'relu1': QBits(None, None)}), + (QBits(8, 8), + OrderedDict([('sub.*conv1', {'wts': 4}), ('sub.*conv2', {'acts': 4, 'wts': 4})]), + {'sub1.conv1': QBits(8, 4), 'sub1.conv2': QBits(4, 4), 'sub2.conv1': QBits(8, 4), 'sub2.conv2': QBits(4, 4)}), + (QBits(4, 4), + OrderedDict([('sub1\..*1', {'acts': 16, 'wts': 16}), ('sub1\..*', {'acts': 8, 'wts': 8})]), + {'sub1.conv1': QBits(16, 16), 'sub1.bn1': QBits(16, None), + 'sub1.relu1': QBits(16, None), 'sub1.pool1': QBits(16, None), + 'sub1.conv2': QBits(8, 8), 'sub1.bn2': QBits(8, None), + 'sub1.relu2': QBits(8, None), 'sub1.pool2': QBits(8, None)}), + (QBits(4, 4), + OrderedDict([('sub1\..*', {'acts': 8, 'wts': 8}), ('sub1\..*1', {'acts': 16, 'wts': 16})]), + {'sub1.conv1': QBits(8, 8), 'sub1.bn1': QBits(8, None), + 'sub1.relu1': QBits(8, None), 'sub1.pool1': QBits(8, None), + 'sub1.conv2': QBits(8, 8), 'sub1.bn2': QBits(8, None), + 'sub1.relu2': QBits(8, None), 'sub1.pool2': QBits(8, None)}), + ], + ids=[ + 'no_override', + 'simple_override', + 'pattern_override', + 'overlap_pattern_override_proper', # "proper" ==> Specific pattern before broader pattern + 'overlap_pattern_override_wrong' # "wrong" ==> Broad pattern before specific pattern, so specific pattern + # never actually matched + ] +) +def test_model_prep(model, qbits, 
bits_overrides, explicit_expected_overrides, + train_with_fp_copy, quantize_bias, parallel): + if parallel: + model = torch.nn.DataParallel(model) + m_orig = deepcopy(model) + + # Build expected QBits + expected_qbits, post_prepare_changes = get_expected_qbits(model, qbits, explicit_expected_overrides) + + # Initialize Quantizer + q = DummyQuantizer(model, bits_activations=qbits.acts, bits_weights=qbits.wts, + bits_overrides=deepcopy(bits_overrides), train_with_fp_copy=train_with_fp_copy, + quantize_bias=quantize_bias) + + # Check number of bits for quantization were registered correctly + assert q.module_qbits_map == expected_qbits + + q.prepare_model() + expected_qbits.update(post_prepare_changes) + + for ptq in q.params_to_quantize: + assert params_quantizable(ptq.module) + assert expected_qbits[ptq.module_name].wts is not None + + # Check parameter names are as expected + assert ptq.q_attr_name in ['weight', 'bias'] + + # Check bias will be quantized only if flag is enabled + if ptq.q_attr_name == 'bias': + assert quantize_bias + + named_params = dict(ptq.module.named_parameters()) + if q.train_with_fp_copy: + # Checking parameter replacement is as expected + assert ptq.fp_attr_name == FP_BKP_PREFIX + ptq.q_attr_name + assert ptq.fp_attr_name in named_params + assert ptq.q_attr_name not in named_params + # Making sure the following doesn't throw an exception, + # so we know q_attr_name is still a buffer in the module + getattr(ptq.module, ptq.q_attr_name) + else: + # Make sure we didn't screw anything up + assert ptq.fp_attr_name == ptq.q_attr_name + assert ptq.fp_attr_name in named_params + + # Check number of bits registered correctly + assert ptq.num_bits == expected_qbits[ptq.module_name].wts + + q_named_modules = dict(model.named_modules()) + orig_named_modules = dict(m_orig.named_modules()) + for orig_name, orig_module in orig_named_modules.items(): + # Check no module name from original model is missing + assert orig_name in q_named_modules + + # Check module replacement is as expected + q_module = q_named_modules[orig_name] + expected_type = expected_type_replacements.get(type(orig_module)) + if expected_type is None or expected_qbits[orig_name] == QBits(None, None): + assert type(orig_module) == type(q_module) + else: + assert type(q_module) == expected_type + if expected_type == DummyWrapperLayer: + assert expected_qbits[orig_name + '.inner'] == q_module.qbits + else: + assert expected_qbits[orig_name] == q_module.qbits + + +@pytest.mark.parametrize( + "qbits, bits_overrides, explicit_expected_overrides", + [ + (QBits(8, 8), + OrderedDict([('conv1', {'acts': None, 'wts': None}), ('relu1', {'acts': None, 'wts': None}), + ('sub.*conv1', {'acts': 8, 'wts': 4}), ('sub.*conv2', {'acts': 4, 'wts': 4})]), + {'conv1': QBits(None, None), 'relu1': QBits(None, None), + 'sub1.conv1': QBits(8, 4), 'sub1.conv2': QBits(4, 4), 'sub2.conv1': QBits(8, 4), 'sub2.conv2': QBits(4, 4)}), + ] +) +def test_param_quantization(model, qbits, bits_overrides, explicit_expected_overrides, + train_with_fp_copy, quantize_bias): + # Build expected QBits + expected_qbits, post_prepare_changes = get_expected_qbits(model, qbits, explicit_expected_overrides) + + q = DummyQuantizer(model, bits_activations=qbits.acts, bits_weights=qbits.wts, + bits_overrides=deepcopy(bits_overrides), train_with_fp_copy=train_with_fp_copy, + quantize_bias=quantize_bias) + q.prepare_model() + expected_qbits.update(post_prepare_changes) + + q_model_pre_quant = deepcopy(model) + q.quantize_params() + for (name, pre_quant_module), 
post_quant_module in zip(q_model_pre_quant.named_modules(), model.modules()): + # Skip containers + # if len(list(pre_quant_module.modules())) > 1: + if has_children(pre_quant_module): + continue + + num_bits = expected_qbits[name].wts + + for param_name, pre_quant_param in pre_quant_module.named_parameters(): + quantizable = num_bits is not None + if param_name.endswith('bias'): + quantizable = quantizable and quantize_bias + + if quantizable and train_with_fp_copy: + # "param_name" and "pre_quant_param" refer to the float copy + + # Check the float copy didn't change + post_quant_fp_copy = getattr(post_quant_module, param_name) + assert torch.equal(pre_quant_param, post_quant_fp_copy) + + quant_param = getattr(post_quant_module, param_name.replace(FP_BKP_PREFIX, '')) + + # Check weights quantization properly recorded for autograd + gfn = quant_param.grad_fn + assert gfn is not None + assert str(type(gfn).__name__) == 'AddBackward0' + gfn = gfn.next_functions[0][0] + assert str(type(gfn).__name__) == 'AccumulateGrad' + assert id(gfn.variable) == id(post_quant_fp_copy) + else: + quant_param = getattr(post_quant_module, param_name) + + expected = dummy_quantize_params(pre_quant_param, num_bits) if quantizable else pre_quant_param + assert torch.equal(quant_param, expected) -- GitLab
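
The first-match-wins behaviour of the ordered bits_overrides mapping, which the documentation above describes and the overlap_pattern_override_proper / overlap_pattern_override_wrong cases in tests/test_quantizer.py exercise, can be illustrated with a small standalone sketch. The resolve_bits helper and its default bit-widths below are hypothetical and are not part of Distiller's API; they only mimic the documented lookup semantics, assuming each pattern is tried in insertion order with re.match, and the real Quantizer may implement the lookup differently.

import re
from collections import OrderedDict

def resolve_bits(module_name, bits_overrides, default_acts=8, default_wts=3):
    # Illustrative only: walk the ordered pattern map and stop at the first
    # regular expression that matches the module name (first match wins).
    for pattern, override in bits_overrides.items():
        if re.match(pattern, module_name):
            return override.get('acts', default_acts), override.get('wts', default_wts)
    return default_acts, default_wts

# Specific pattern listed before the broad one, as the documentation requires.
proper = OrderedDict([
    (r'block1\.conv1', {'acts': None, 'wts': 4}),
    (r'block1\.conv*', {'acts': None, 'wts': 2}),
])
# Broad pattern listed first: the specific override can never be reached.
wrong = OrderedDict(reversed(list(proper.items())))

print(resolve_bits('block1.conv1', proper))  # (None, 4) - the specific override wins
print(resolve_bits('block1.conv1', wrong))   # (None, 2) - the broad pattern shadows it

This mirrors why test_model_prep expects QBits(16, 16) for sub1.conv1 only when the narrower 'sub1\..*1' pattern precedes the broader 'sub1\..*' pattern, and QBits(8, 8) when the order is reversed.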