Commit df9a00ce authored by Gal Novik, committed by Guy Jacob

PACT quantizer (#30)

* Adding PACT quantization method
* Move logic modifying the optimizer due to changes the quantizer makes into the Quantizer itself
* Updated documentation and tests
parent 9f0c0832
Showing 310 additions and 70 deletions
...@@ -56,7 +56,7 @@ def dict_config(model, optimizer, sched_dict): ...@@ -56,7 +56,7 @@ def dict_config(model, optimizer, sched_dict):
pruners = __factory('pruners', model, sched_dict) pruners = __factory('pruners', model, sched_dict)
regularizers = __factory('regularizers', model, sched_dict) regularizers = __factory('regularizers', model, sched_dict)
quantizers = __factory('quantizers', model, sched_dict) quantizers = __factory('quantizers', model, sched_dict, optimizer=optimizer)
if len(quantizers) > 1: if len(quantizers) > 1:
print("\nError: Multiple Quantizers not supported") print("\nError: Multiple Quantizers not supported")
exit(1) exit(1)
...@@ -92,12 +92,6 @@ def dict_config(model, optimizer, sched_dict):
             quantizer = quantizers[instance_name]
             policy = distiller.QuantizationPolicy(quantizer)
-
-            # Quantizers for training modify the models parameters, need to update the optimizer
-            if quantizer.train_with_fp_copy:
-                optimizer_type = type(optimizer)
-                new_optimizer = optimizer_type(model.parameters(), **optimizer.defaults)
-                optimizer.__setstate__({'param_groups': new_optimizer.param_groups})
         elif 'lr_scheduler' in policy_def:
             # LR schedulers take an optimizer in their CTOR, so postpone handling until we're certain
             # a quantization policy was initialized (if exists)
......
...@@ -16,7 +16,7 @@
 from .quantizer import Quantizer
 from .range_linear import RangeLinearQuantWrapper, RangeLinearQuantParamLayerWrapper, SymmetricLinearQuantizer
-from .clipped_linear import LinearQuantizeSTE, ClippedLinearQuantization, WRPNQuantizer, DorefaQuantizer
+from .clipped_linear import LinearQuantizeSTE, ClippedLinearQuantization, WRPNQuantizer, DorefaQuantizer, PACTQuantizer
 del quantizer
 del range_linear
......
...@@ -43,6 +43,34 @@ class LinearQuantizeSTE(torch.autograd.Function):
         return grad_output, None, None, None
+
+
+class LearnedClippedLinearQuantizeSTE(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input, clip_val, num_bits, dequantize, inplace):
+        ctx.save_for_backward(input, clip_val)
+        if inplace:
+            ctx.mark_dirty(input)
+        scale_factor = asymmetric_linear_quantization_scale_factor(num_bits, 0, clip_val.data[0])
+        output = clamp(input, 0, clip_val.data[0], inplace)
+        output = linear_quantize(output, scale_factor, inplace)
+        if dequantize:
+            output = linear_dequantize(output, scale_factor, inplace)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, clip_val = ctx.saved_tensors
+        grad_input = grad_output.clone()
+        grad_input[input.le(0)] = 0
+        grad_input[input.ge(clip_val.data[0])] = 0
+
+        grad_alpha = grad_output.clone()
+        grad_alpha[input.lt(clip_val.data[0])] = 0
+        grad_alpha = grad_alpha.sum().expand_as(clip_val)
+
+        # Straight-through estimator for the scale factor calculation
+        return grad_input, grad_alpha, None, None, None
+
+
 class ClippedLinearQuantization(nn.Module):
     def __init__(self, num_bits, clip_val, dequantize=True, inplace=False):
         super(ClippedLinearQuantization, self).__init__()
...@@ -63,6 +91,24 @@ class ClippedLinearQuantization(nn.Module):
                                                    inplace_str)
+
+
+class LearnedClippedLinearQuantization(nn.Module):
+    def __init__(self, num_bits, init_act_clip_val, dequantize=True, inplace=False):
+        super(LearnedClippedLinearQuantization, self).__init__()
+        self.num_bits = num_bits
+        self.clip_val = nn.Parameter(torch.Tensor([init_act_clip_val]))
+        self.dequantize = dequantize
+        self.inplace = inplace
+
+    def forward(self, input):
+        input = LearnedClippedLinearQuantizeSTE.apply(input, self.clip_val, self.num_bits, self.dequantize, self.inplace)
+        return input
+
+    def __repr__(self):
+        inplace_str = ', inplace' if self.inplace else ''
+        return '{0}(num_bits={1}, clip_val={2}{3})'.format(self.__class__.__name__, self.num_bits, self.clip_val,
+                                                           inplace_str)
+
+
 class WRPNQuantizer(Quantizer):
     """
     Quantizer using the WRPN quantization scheme, as defined in:
...@@ -72,10 +118,11 @@ class WRPNQuantizer(Quantizer):
     1. This class does not take care of layer widening as described in the paper
     2. The paper defines special handling for 1-bit weights which isn't supported here yet
     """
-    def __init__(self, model, bits_activations=32, bits_weights=32, bits_overrides=OrderedDict(), quantize_bias=False):
-        super(WRPNQuantizer, self).__init__(model, bits_activations=bits_activations, bits_weights=bits_weights,
-                                            bits_overrides=bits_overrides, train_with_fp_copy=True,
-                                            quantize_bias=quantize_bias)
+    def __init__(self, model, optimizer, bits_activations=32, bits_weights=32, bits_overrides=OrderedDict(),
+                 quantize_bias=False):
+        super(WRPNQuantizer, self).__init__(model, optimizer=optimizer, bits_activations=bits_activations,
+                                            bits_weights=bits_weights, bits_overrides=bits_overrides,
+                                            train_with_fp_copy=True, quantize_bias=quantize_bias)

         def wrpn_quantize_param(param_fp, num_bits):
             scale_factor = symmetric_linear_quantization_scale_factor(num_bits, 1)
...@@ -94,6 +141,15 @@ class WRPNQuantizer(Quantizer):
         self.replacement_factory[nn.ReLU] = relu_replace_fn
+
+
+def dorefa_quantize_param(param_fp, num_bits):
+    scale_factor = asymmetric_linear_quantization_scale_factor(num_bits, 0, 1)
+
+    out = param_fp.tanh()
+    out = out / (2 * out.abs().max()) + 0.5
+    out = LinearQuantizeSTE.apply(out, scale_factor, True, False)
+    out = 2 * out - 1
+    return out
+
+
 class DorefaQuantizer(Quantizer):
     """
     Quantizer using the DoReFa scheme, as defined in:
...@@ -104,18 +160,11 @@ class DorefaQuantizer(Quantizer):
     1. Gradients quantization not supported yet
     2. The paper defines special handling for 1-bit weights which isn't supported here yet
     """
-    def __init__(self, model, bits_activations=32, bits_weights=32, bits_overrides=OrderedDict(), quantize_bias=False):
-        super(DorefaQuantizer, self).__init__(model, bits_activations=bits_activations, bits_weights=bits_weights,
-                                              bits_overrides=bits_overrides, train_with_fp_copy=True,
-                                              quantize_bias=quantize_bias)
-
-        def dorefa_quantize_param(param_fp, num_bits):
-            scale_factor = asymmetric_linear_quantization_scale_factor(num_bits, 0, 1)
-
-            out = param_fp.tanh()
-            out = out / (2 * out.abs().max()) + 0.5
-            out = LinearQuantizeSTE.apply(out, scale_factor, True, False)
-            out = 2 * out - 1
-            return out
+    def __init__(self, model, optimizer, bits_activations=32, bits_weights=32, bits_overrides=OrderedDict(),
+                 quantize_bias=False):
+        super(DorefaQuantizer, self).__init__(model, optimizer=optimizer, bits_activations=bits_activations,
+                                              bits_weights=bits_weights, bits_overrides=bits_overrides,
+                                              train_with_fp_copy=True, quantize_bias=quantize_bias)

         def relu_replace_fn(module, name, qbits_map):
             bits_acts = qbits_map[name].acts
...@@ -126,3 +175,44 @@ class DorefaQuantizer(Quantizer):
         self.param_quantization_fn = dorefa_quantize_param
         self.replacement_factory[nn.ReLU] = relu_replace_fn
+
+
+class PACTQuantizer(Quantizer):
+    """
+    Quantizer using the PACT quantization scheme, as defined in:
+    Choi et al., PACT: Parameterized Clipping Activation for Quantized Neural Networks
+    (https://arxiv.org/abs/1805.06085)
+
+    Args:
+        act_clip_init_val (float): Initial clipping value for activations, referred to as "alpha" in the paper
+            (default: 8.0)
+        act_clip_decay (float): L2 penalty applied to the clipping values, referred to as "lambda_alpha" in the paper.
+            If None then the optimizer's default weight decay value is used (default: None)
+    """
+    def __init__(self, model, optimizer, bits_activations=32, bits_weights=32, bits_overrides=OrderedDict(),
+                 quantize_bias=False, act_clip_init_val=8.0, act_clip_decay=None):
+        super(PACTQuantizer, self).__init__(model, optimizer=optimizer, bits_activations=bits_activations,
+                                            bits_weights=bits_weights, bits_overrides=bits_overrides,
+                                            train_with_fp_copy=True, quantize_bias=quantize_bias)
+
+        def relu_replace_fn(module, name, qbits_map):
+            bits_acts = qbits_map[name].acts
+            if bits_acts is None:
+                return module
+            return LearnedClippedLinearQuantization(bits_acts, act_clip_init_val, dequantize=True,
+                                                    inplace=module.inplace)
+
+        self.param_quantization_fn = dorefa_quantize_param
+        self.replacement_factory[nn.ReLU] = relu_replace_fn
+
+        self.act_clip_decay = act_clip_decay
+
+    # In PACT, LearnedClippedLinearQuantization is used for activations, which contains a learned 'clip_val' parameter
+    # We optimize this value separately from the main model parameters
+    def _get_updated_optimizer_params_groups(self):
+        base_group = {'params': [param for name, param in self.model.named_parameters() if 'clip_val' not in name]}
+        clip_val_group = {'params': [param for name, param in self.model.named_parameters() if 'clip_val' in name]}
+        if self.act_clip_decay is not None:
+            clip_val_group['weight_decay'] = self.act_clip_decay
+        return [base_group, clip_val_group]
...@@ -58,6 +58,9 @@ class Quantizer(object):
     Args:
         model (torch.nn.Module): The model to be quantized
+        optimizer (torch.optim.Optimizer): An optimizer instance, required in cases where the quantizer is going
+            to perform changes to existing model parameters and/or add new ones.
+            Specifically, when train_with_fp_copy is True, this cannot be None.
         bits_activations/weights (int): Default number of bits to use when quantizing each tensor type.
             Value of None means do not quantize.
         bits_overrides (OrderedDict): Dictionary mapping regular expressions of layer name patterns to dictionary with
...@@ -77,15 +80,18 @@ class Quantizer(object):
         3.2 We also back-prop through the 'quantize' operation from step 1
     4. Update fp_weights with gradients calculated in step 3.2
     """
-    def __init__(self, model, bits_activations=None, bits_weights=None, bits_overrides=OrderedDict(),
+    def __init__(self, model, optimizer=None, bits_activations=None, bits_weights=None, bits_overrides=OrderedDict(),
                  quantize_bias=False, train_with_fp_copy=False):
         if not isinstance(bits_overrides, OrderedDict):
             raise TypeError('bits_overrides must be an instance of collections.OrderedDict')
+        if train_with_fp_copy and optimizer is None:
+            raise ValueError('optimizer cannot be None when train_with_fp_copy is True')

         self.default_qbits = QBits(acts=bits_activations, wts=bits_weights)
         self.quantize_bias = quantize_bias

         self.model = model
+        self.optimizer = optimizer

         # Stash some quantizer data in the model so we can re-apply the quantizer on a resuming model
         self.model.quantizer_metadata = {'type': type(self),
...@@ -166,6 +172,12 @@ class Quantizer(object):
                     msglogger.info(
                         "Parameter '{0}' will be quantized to {1} bits".format(param_full_name, qbits.wts))

+        # If an optimizer was passed, assume we need to update it
+        if self.optimizer:
+            optimizer_type = type(self.optimizer)
+            new_optimizer = optimizer_type(self._get_updated_optimizer_params_groups(), **self.optimizer.defaults)
+            self.optimizer.__setstate__({'param_groups': new_optimizer.param_groups})
+
         msglogger.info('Quantized model:\n\n{0}\n'.format(self.model))

     def _pre_process_container(self, container, prefix=''):
...@@ -192,6 +204,20 @@ class Quantizer(object):
                 # For container we call recursively
                 self._pre_process_container(module, full_name + '.')

+    def _get_updated_optimizer_params_groups(self):
+        """
+        Returns a list of model parameter groups and optimizer hyper-parameter overrides,
+        as expected by the __init__ function of torch.optim.Optimizer.
+
+        This is called after all model changes were made in prepare_model, in case an Optimizer instance was
+        passed to __init__.
+
+        Subclasses which add parameters to the model should override as needed.
+
+        :return: List of parameter groups
+        """
+        # Default implementation - just return all model parameters as one group
+        return [{'params': self.model.parameters()}]
+
     def quantize_params(self):
         """
         Quantize all parameters using the parameters using self.param_quantization_fn (using the defined number
......
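For clarity, the optimizer update that prepare_model now performs (shown in the hunk above) can be demonstrated in isolation. The following is a minimal sketch and not part of the commit: it assumes a plain SGD optimizer and a toy nn.Linear module, and shows how rebuilding the parameter groups and transplanting them via __setstate__ makes an already-constructed optimizer track a parameter that was added afterwards, which is the same pattern PACTQuantizer relies on for its clip_val parameters.

import torch
import torch.nn as nn

# Toy setup (illustrative only): the optimizer is created before the new parameter exists
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=1e-4)

# Mimic what a quantizer does during prepare_model(): add a learnable clipping value
model.clip_val = nn.Parameter(torch.tensor([8.0]))

# Same pattern as Quantizer.prepare_model(): build fresh parameter groups and transplant them
param_groups = [
    {'params': [p for n, p in model.named_parameters() if 'clip_val' not in n]},
    {'params': [p for n, p in model.named_parameters() if 'clip_val' in n], 'weight_decay': 5e-4},
]
new_optimizer = type(optimizer)(param_groups, **optimizer.defaults)
optimizer.__setstate__({'param_groups': new_optimizer.param_groups})

# The existing optimizer instance now optimizes clip_val as well, with its own weight decay
assert any(p is model.clip_val for group in optimizer.param_groups for p in group['params'])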
...@@ -148,9 +148,6 @@ class CompressionScheduler(object):
                     name_parts[-1] = name_parts[-1].replace(FP_BKP_PREFIX, '', 1)
                     name = '.'.join(name_parts)
                 self.zeros_mask_dict[name].apply_mask(param)
-                else:
-                    raise

     def state_dict(self):
         """Returns the state of the scheduler as a :class:`dict`.
......
...@@ -29,6 +29,14 @@ This method requires training the model with quantization, as discussed [here](q
 - Gradients quantization as proposed in the paper is not supported yet.
 - The paper defines special handling for binary weights which isn't supported in Distiller yet.
+
+## PACT
+
+(As proposed in [PACT: Parameterized Clipping Activation for Quantized Neural Networks](https://arxiv.org/abs/1805.06085))
+
+This method is similar to DoReFa, but the upper clipping values, \(\alpha\), of the activation functions are learned parameters instead of hard coded to 1. Note that per the paper's recommendation, \(\alpha\) is shared per layer.
+
+This method requires training the model with quantization, as discussed [here](quantization/#training-with-quantization). Use the `PACTQuantizer` class to transform an existing model to a model suitable for training with quantization using PACT.
+
 ## WRPN

 (As proposed in [WRPN: Wide Reduced-Precision Networks](https://arxiv.org/abs/1709.01134))
......
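For reference (this restatement is not part of the commit): with \(k\) denoting the number of activation bits and \(\alpha\) the learned clipping value (`clip_val`), the activation transformation implemented by `LearnedClippedLinearQuantizeSTE` in the clipped_linear.py hunk above is

\[ y = \mathrm{clip}(x, 0, \alpha), \qquad y_q = \mathrm{round}\!\left(y \cdot \frac{2^k - 1}{\alpha}\right) \cdot \frac{\alpha}{2^k - 1} \]

and its backward pass uses the straight-through estimates

\[ \frac{\partial y_q}{\partial x} \approx \mathbf{1}[0 < x < \alpha], \qquad \frac{\partial y_q}{\partial \alpha} \approx \mathbf{1}[x \ge \alpha]. \]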
...@@ -76,13 +76,23 @@ The `Quantizer` class also provides an API to quantize the weights of all layers
 ### Training with Quantization

-The `Quantizer` class supports training with quantization in the loop, as described [here](quantization.md#training-with-quantization). This is enabled by setting `train_with_fp_copy=True` in the `Quantizer` constructor. At model transformation, in each module that has parameters that should be quantized, a new `torch.nn.Parameter` is added, which will maintain the required full precision copy of the parameters. Note that this is done in-place - a new module **is not** created. We preferred not to sub-class the existing PyTorch modules for this purpose. In order to this in-place, and also guarantee proper back-propagation through the weights quantization function, we employ the following "hack":
-
-1. The existing `torch.nn.Parameter`, e.g. `weights`, is replaced by a `torch.nn.Parameter` named `float_weight`.
-2. To maintain the existing functionality of the module, we then register a `buffer` in the module with the original name - `weights`.
-3. During training, `float_weight` will be passed to `param_quantization_fn` and the result will be stored in `weight`.
-
-**Important Note**: Since this process modifies the model's parameters, it must be done **before** a PyTorch `Optimizer` is created (this refers to any of the sub-classes defined [here](https://pytorch.org/docs/stable/optim.html#algorithms)).
+The `Quantizer` class supports training with quantization in the loop. This requires handling of a couple of flows / scenarios:
+
+1. Maintaining a full precision copy of the weights, as described [here](quantization.md#training-with-quantization). This is enabled by setting `train_with_fp_copy=True` in the `Quantizer` constructor. At model transformation, in each module that has parameters that should be quantized, a new `torch.nn.Parameter` is added, which will maintain the required full precision copy of the parameters. Note that this is done in-place - a new module **is not** created. We preferred not to sub-class the existing PyTorch modules for this purpose. In order to do this in-place, and also guarantee proper back-propagation through the weights quantization function, we employ the following "hack":
+
+    1. The existing `torch.nn.Parameter`, e.g. `weights`, is replaced by a `torch.nn.Parameter` named `float_weight`.
+    2. To maintain the existing functionality of the module, we then register a `buffer` in the module with the original name - `weights`.
+    3. During training, `float_weight` will be passed to `param_quantization_fn` and the result will be stored in `weight`.
+
+2. In addition, some quantization methods may introduce additional learned parameters to the model. For example, in the [PACT](algo_quantization.md#PACT) method, activations are clipped to a value \(\alpha\), which is a learned per-layer parameter.
+
+To support these two cases, the `Quantizer` class also accepts an instance of a `torch.optim.Optimizer` (normally this would be an instance of one of its sub-classes). The quantizer will take care of modifying the optimizer according to the changes made to the parameters.
+
+!!! Note "Optimizing New Parameters"
+    In cases where new parameters are required by the scheme, it is likely that they'll need to be optimized separately from the main model parameters. In that case, the sub-class for the specific method should override `Quantizer._get_updated_optimizer_params_groups()`, and return the proper groups plus any desired hyper-parameter overrides (see the usage sketch following this diff).
+
+### Examples
+
 The base `Quantizer` class is implemented in `distiller/quantization/quantizer.py`.
-For a simple sub-class implementing symmetric linear quantization, see `SymmetricLinearQuantizer` in `distiller/quantization/range_linear.py`. For examples of lower-precision methods using training with quantization see `DorefaQuantizer` and `WRPNQuantizer` in `distiller/quantization/clipped_linear.py`
+For a simple sub-class implementing symmetric linear quantization, see `SymmetricLinearQuantizer` in `distiller/quantization/range_linear.py`.
+In `distiller/quantization/clipped_linear.py` there are examples of lower-precision methods which use training with quantization. Specifically, see `PACTQuantizer` for an example of overriding `Quantizer._get_updated_optimizer_params_groups()`.
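To make the flow described above concrete, here is a minimal usage sketch. It is not part of the commit; the toy model and the hyper-parameter values are placeholders, but the call pattern follows the `PACTQuantizer` constructor added in this commit: the optimizer is created over the original parameters first and handed to the quantizer, which rebuilds the optimizer's parameter groups via `_get_updated_optimizer_params_groups()` when `prepare_model()` is called.

import torch
import torch.nn as nn
from distiller.quantization import PACTQuantizer

# Placeholder model standing in for a real network; any model with ReLU activations works the same way
model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), nn.Conv2d(16, 16, 3), nn.ReLU())

# The optimizer is created first, over the original parameters...
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=2e-4)

# ...and passed to the quantizer. prepare_model() replaces the ReLUs with
# LearnedClippedLinearQuantization modules (adding per-layer clip_val parameters),
# adds float_weight copies of the weights, and then rebuilds the optimizer's
# parameter groups so the clip values get their own weight decay.
quantizer = PACTQuantizer(model, optimizer,
                          bits_activations=4, bits_weights=3,
                          act_clip_init_val=8.0, act_clip_decay=5e-4)
quantizer.prepare_model()

# Two parameter groups now: the main model parameters and the clip values
print([len(group['params']) for group in optimizer.param_groups])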
...@@ -106,6 +106,8 @@
 <li><a class="toctree-l4" href="#dorefa">DoReFa</a></li>
+
+<li><a class="toctree-l4" href="#pact">PACT</a></li>
 <li><a class="toctree-l4" href="#wrpn">WRPN</a></li>
 <li><a class="toctree-l4" href="#symmetric-linear-quantization">Symmetric Linear Quantization</a></li>
...@@ -195,6 +197,10 @@
 <li>Gradients quantization as proposed in the paper is not supported yet.</li>
 <li>The paper defines special handling for binary weights which isn't supported in Distiller yet.</li>
 </ul>
+<h2 id="pact">PACT</h2>
+<p>(As proposed in <a href="https://arxiv.org/abs/1805.06085">PACT: Parameterized Clipping Activation for Quantized Neural Networks</a>)</p>
+<p>This method is similar to DoReFa, but the upper clipping values, <script type="math/tex">\alpha</script>, of the activation functions are learned parameters instead of hard coded to 1. Note that per the paper's recommendation, <script type="math/tex">\alpha</script> is shared per layer.</p>
+<p>This method requires training the model with quantization, as discussed <a href="../quantization/#training-with-quantization">here</a>. Use the <code>PACTQuantizer</code> class to transform an existing model to a model suitable for training with quantization using PACT.</p>
 <h2 id="wrpn">WRPN</h2>
 <p>(As proposed in <a href="https://arxiv.org/abs/1709.01134">WRPN: Wide Reduced-Precision Networks</a>) </p>
 <p>In this method, activations are clipped to <script type="math/tex">[0, 1]</script> and quantized as follows (<script type="math/tex">k</script> is the number of bits used for quantization):</p>
......
...@@ -231,15 +231,29 @@ To execute the model transformation, call the <code>prepare_model</code> functio
 <h3 id="weights-quantization">Weights Quantization</h3>
 <p>The <code>Quantizer</code> class also provides an API to quantize the weights of all layers at once. To use it, the <code>param_quantization_fn</code> attribute needs to point to a function that accepts a tensor and the number of bits. During model transformation, the <code>Quantizer</code> class will build a list of all model parameters that need to be quantized along with their bit-width. Then, the <code>quantize_params</code> function can be called, which will iterate over all parameters and quantize them using <code>params_quantization_fn</code>.</p>
 <h3 id="training-with-quantization">Training with Quantization</h3>
-<p>The <code>Quantizer</code> class supports training with quantization in the loop, as described <a href="../quantization/index.html#training-with-quantization">here</a>. This is enabled by setting <code>train_with_fp_copy=True</code> in the <code>Quantizer</code> constructor. At model transformation, in each module that has parameters that should be quantized, a new <code>torch.nn.Parameter</code> is added, which will maintain the required full precision copy of the parameters. Note that this is done in-place - a new module <strong>is not</strong> created. We preferred not to sub-class the existing PyTorch modules for this purpose. In order to this in-place, and also guarantee proper back-propagation through the weights quantization function, we employ the following "hack":</p>
+<p>The <code>Quantizer</code> class supports training with quantization in the loop. This requires handling of a couple of flows / scenarios:</p>
+<ol>
+<li>
+<p>Maintaining a full precision copy of the weights, as described <a href="../quantization/index.html#training-with-quantization">here</a>. This is enabled by setting <code>train_with_fp_copy=True</code> in the <code>Quantizer</code> constructor. At model transformation, in each module that has parameters that should be quantized, a new <code>torch.nn.Parameter</code> is added, which will maintain the required full precision copy of the parameters. Note that this is done in-place - a new module <strong>is not</strong> created. We preferred not to sub-class the existing PyTorch modules for this purpose. In order to do this in-place, and also guarantee proper back-propagation through the weights quantization function, we employ the following "hack": </p>
 <ol>
 <li>The existing <code>torch.nn.Parameter</code>, e.g. <code>weights</code>, is replaced by a <code>torch.nn.Parameter</code> named <code>float_weight</code>.</li>
 <li>To maintain the existing functionality of the module, we then register a <code>buffer</code> in the module with the original name - <code>weights</code>.</li>
 <li>During training, <code>float_weight</code> will be passed to <code>param_quantization_fn</code> and the result will be stored in <code>weight</code>.</li>
 </ol>
-<p><strong>Important Note</strong>: Since this process modifies the model's parameters, it must be done <strong>before</strong> an PyTorch <code>Optimizer</code> is created (this refers to any of the sub-classes defined <a href="https://pytorch.org/docs/stable/optim.html#algorithms">here</a>).</p>
+</li>
+<li>
+<p>In addition, some quantization methods may introduce additional learned parameters to the model. For example, in the <a href="../algo_quantization/index.html#PACT">PACT</a> method, activations are clipped to a value <script type="math/tex">\alpha</script>, which is a learned per-layer parameter.</p>
+</li>
+</ol>
+<p>To support these two cases, the <code>Quantizer</code> class also accepts an instance of a <code>torch.optim.Optimizer</code> (normally this would be an instance of one of its sub-classes). The quantizer will take care of modifying the optimizer according to the changes made to the parameters. </p>
+<div class="admonition note">
+<p class="admonition-title">Optimizing New Parameters</p>
+<p>In cases where new parameters are required by the scheme, it is likely that they'll need to be optimized separately from the main model parameters. In that case, the sub-class for the specific method should override <code>Quantizer._get_updated_optimizer_params_groups()</code>, and return the proper groups plus any desired hyper-parameter overrides.</p>
+</div>
+<h3 id="examples">Examples</h3>
 <p>The base <code>Quantizer</code> class is implemented in <code>distiller/quantization/quantizer.py</code>.<br />
-For a simple sub-class implementing symmetric linear quantization, see <code>SymmetricLinearQuantizer</code> in <code>distiller/quantization/range_linear.py</code>. For examples of lower-precision methods using training with quantization see <code>DorefaQuantizer</code> and <code>WRPNQuantizer</code> in <code>distiller/quantization/clipped_linear.py</code></p>
+For a simple sub-class implementing symmetric linear quantization, see <code>SymmetricLinearQuantizer</code> in <code>distiller/quantization/range_linear.py</code>.<br />
+In <code>distiller/quantization/clipped_linear.py</code> there are examples of lower-precision methods which use training with quantization. Specifically, see <code>PACTQuantizer</code> for an example of overriding <code>Quantizer._get_updated_optimizer_params_groups()</code>.</p>
 </div>
 </div>
......
...@@ -246,5 +246,5 @@ And of course, if we used a sparse or compressed representation, then we are red
 <!--
 MkDocs version : 0.17.2
-Build Date UTC : 2018-07-17 10:59:29
+Build Date UTC : 2018-07-22 11:48:56
 -->
...@@ -4,7 +4,7 @@
     <url>
      <loc>/index.html</loc>
-     <lastmod>2018-07-17</lastmod>
+     <lastmod>2018-07-22</lastmod>
      <changefreq>daily</changefreq>
     </url>
...@@ -12,7 +12,7 @@
     <url>
      <loc>/install/index.html</loc>
-     <lastmod>2018-07-17</lastmod>
+     <lastmod>2018-07-22</lastmod>
      <changefreq>daily</changefreq>
     </url>
...@@ -20,7 +20,7 @@
     <url>
      <loc>/usage/index.html</loc>
-     <lastmod>2018-07-17</lastmod>
+     <lastmod>2018-07-22</lastmod>
      <changefreq>daily</changefreq>
     </url>
...@@ -28,7 +28,7 @@
     <url>
      <loc>/schedule/index.html</loc>
-     <lastmod>2018-07-17</lastmod>
+     <lastmod>2018-07-22</lastmod>
      <changefreq>daily</changefreq>
     </url>
...@@ -37,19 +37,19 @@
     <url>
      <loc>/pruning/index.html</loc>
-     <lastmod>2018-07-17</lastmod>
+     <lastmod>2018-07-22</lastmod>
      <changefreq>daily</changefreq>
     </url>
     <url>
      <loc>/regularization/index.html</loc>
-     <lastmod>2018-07-17</lastmod>
+     <lastmod>2018-07-22</lastmod>
      <changefreq>daily</changefreq>
     </url>
     <url>
      <loc>/quantization/index.html</loc>
-     <lastmod>2018-07-17</lastmod>
+     <lastmod>2018-07-22</lastmod>
      <changefreq>daily</changefreq>
     </url>
...@@ -59,13 +59,13 @@
     <url>
      <loc>/algo_pruning/index.html</loc>
-     <lastmod>2018-07-17</lastmod>
+     <lastmod>2018-07-22</lastmod>
      <changefreq>daily</changefreq>
     </url>
     <url>
      <loc>/algo_quantization/index.html</loc>
-     <lastmod>2018-07-17</lastmod>
+     <lastmod>2018-07-22</lastmod>
      <changefreq>daily</changefreq>
     </url>
...@@ -74,7 +74,7 @@
     <url>
      <loc>/model_zoo/index.html</loc>
-     <lastmod>2018-07-17</lastmod>
+     <lastmod>2018-07-22</lastmod>
      <changefreq>daily</changefreq>
     </url>
...@@ -82,7 +82,7 @@
     <url>
      <loc>/jupyter/index.html</loc>
-     <lastmod>2018-07-17</lastmod>
+     <lastmod>2018-07-22</lastmod>
      <changefreq>daily</changefreq>
     </url>
...@@ -90,7 +90,7 @@
     <url>
      <loc>/design/index.html</loc>
-     <lastmod>2018-07-17</lastmod>
+     <lastmod>2018-07-22</lastmod>
      <changefreq>daily</changefreq>
     </url>
......
...@@ -312,8 +312,11 @@ def main():
         # The main use-case for this sample application is CNN compression. Compression
         # requires a compression schedule configuration file in YAML.
         compression_scheduler = distiller.file_config(model, optimizer, args.compress)
+        # Model is re-transferred to GPU in case parameters were added (e.g. PACTQuantizer)
+        model.cuda()

     best_epoch = start_epoch
     for epoch in range(start_epoch, start_epoch + args.epochs):
         # This is the main training loop.
         msglogger.info('\n')
......
+# time python3 compress_classifier.py -a preact_resnet20_cifar --lr 0.1 -p 50 -b 128 ../../../data.cifar10/ -j 1
+# --epochs 200 --compress=../quantization/preact_resnet20_cifar_base_fp32.yaml --out-dir="logs/" --wd=0.0002 --vs=0
+#
+#2018-07-18 12:25:56,477 - --- validate (epoch=199)-----------
+#2018-07-18 12:25:56,477 - 10000 samples (128 per mini-batch)
+#2018-07-18 12:25:57,810 - Epoch: [199][ 50/ 78]  Loss 0.312961  Top1 92.140625  Top5 99.765625
+#2018-07-18 12:25:58,402 - ==> Top1: 92.270  Top5: 99.800  Loss: 0.307
+#
+#2018-07-18 12:25:58,404 - ==> Best validation Top1: 92.560  Epoch: 127
+#2018-07-18 12:25:58,404 - Saving checkpoint to: logs/checkpoint.pth.tar
+#2018-07-18 12:25:58,418 - --- test ---------------------
+#2018-07-18 12:25:58,418 - 10000 samples (128 per mini-batch)
+#2018-07-18 12:25:59,664 - Test: [ 50/ 78]  Loss 0.312961  Top1 92.140625  Top5 99.765625
+#2018-07-18 12:26:00,248 - ==> Top1: 92.270  Top5: 99.800  Loss: 0.307
+
 lr_schedulers:
   training_lr:
     class: MultiStepMultiGammaLR
......
# time python3 compress_classifier.py -a preact_resnet20_cifar --lr 0.1 -p 50 -b 128 ../../../data.cifar10/ -j 1
# --epochs 200 --compress=../quantization/preact_resnet20_cifar_pact.yaml --out-dir="logs/" --wd=0.0002 --vs=0
#
#2018-07-18 17:28:56,710 - --- validate (epoch=199)-----------
#2018-07-18 17:28:56,710 - 10000 samples (128 per mini-batch)
#2018-07-18 17:28:58,070 - Epoch: [199][ 50/ 78]  Loss 0.349229  Top1 91.140625  Top5 99.671875
#2018-07-18 17:28:58,670 - ==> Top1: 91.440  Top5: 99.680  Loss: 0.348
#
#2018-07-18 17:28:58,671 - ==> Best validation Top1: 91.860  Epoch: 147
#2018-07-18 17:28:58,672 - Saving checkpoint to: logs/checkpoint.pth.tar
#2018-07-18 17:28:58,687 - --- test ---------------------
#2018-07-18 17:28:58,687 - 10000 samples (128 per mini-batch)
#2018-07-18 17:29:00,006 - Test: [ 50/ 78]  Loss 0.349229  Top1 91.140625  Top5 99.671875
#2018-07-18 17:29:00,560 - ==> Top1: 91.440  Top5: 99.680  Loss: 0.348

quantizers:
  pact_quantizer:
    class: PACTQuantizer
    act_clip_init_val: 8.0
    bits_activations: 4
    bits_weights: 3
    bits_overrides:
      # Don't quantize first and last layers
      conv1:
        wts: null
        acts: null
      layer1.0.pre_relu:
        wts: null
        acts: null
      final_relu:
        wts: null
        acts: null
      fc:
        wts: null
        acts: null

lr_schedulers:
  training_lr:
    class: MultiStepLR
    milestones: [60, 120]
    gammas: 0.1

policies:
  - quantizer:
      instance_name: pact_quantizer
    starting_epoch: 0
    ending_epoch: 200
    frequency: 1

  - lr_scheduler:
      instance_name: training_lr
    starting_epoch: 0
    ending_epoch: 121
    frequency: 1
...@@ -106,8 +106,8 @@ test_configs = [
                format(os.path.join(examples_root, 'ssl', 'checkpoints', 'checkpoint_trained_dense.pth.tar')),
                DS_CIFAR, accuracy_checker, [91.620, 99.630]),
     TestConfig('-a preact_resnet20_cifar --epochs 2 --compress {0}'.
-               format(os.path.join('full_flow_tests', 'preact_resnet20_cifar_dorefa_test.yaml')),
-               DS_CIFAR, accuracy_checker, [41.190, 91.160])
+               format(os.path.join('full_flow_tests', 'preact_resnet20_cifar_pact_test.yaml')),
+               DS_CIFAR, accuracy_checker, [48.290, 94.460])
 ]
......
 quantizers:
-  dorefa_quantizer:
-    class: DorefaQuantizer
-    bits_activations: 8
+  pact_quantizer:
+    class: PACTQuantizer
+    act_clip_init_val: 8.0
+    bits_activations: 4
     bits_weights: 3
     bits_overrides:
-      # Don't quantize first and last layer
+      # Don't quantize first and last layers
       conv1:
         wts: null
         acts: null
...@@ -20,13 +21,13 @@ quantizers:
 lr_schedulers:
   training_lr:
-    class: MultiStepMultiGammaLR
-    milestones: [80, 120, 160]
-    gammas: [0.1, 0.1, 0.2]
+    class: MultiStepLR
+    milestones: [60, 120]
+    gammas: 0.1

 policies:
   - quantizer:
-      instance_name: dorefa_quantizer
+      instance_name: pact_quantizer
     starting_epoch: 0
     ending_epoch: 200
     frequency: 1
...@@ -34,5 +35,5 @@ policies:
   - lr_scheduler:
       instance_name: training_lr
     starting_epoch: 0
-    ending_epoch: 161
+    ending_epoch: 121
     frequency: 1
...@@ -95,9 +95,9 @@ def dummy_quantize_params(param, num_bits):
 class DummyQuantizer(Quantizer):
-    def __init__(self, model, bits_activations=None, bits_weights=None, bits_overrides=OrderedDict(),
+    def __init__(self, model, optimizer=None, bits_activations=None, bits_weights=None, bits_overrides=OrderedDict(),
                  quantize_bias=False, train_with_fp_copy=False):
-        super(DummyQuantizer, self).__init__(model, bits_activations, bits_weights, bits_overrides, quantize_bias,
+        super(DummyQuantizer, self).__init__(model, optimizer, bits_activations, bits_weights, bits_overrides, quantize_bias,
                                              train_with_fp_copy)

         self.replacement_factory[nn.Conv2d] = lambda module, name, qbits_map: DummyWrapperLayer(module, qbits_map[name])
...@@ -143,6 +143,12 @@ def fixture_model():
     return DummyModel()

+
+# TODO: Test optimizer modifications in 'test_model_prep'
+@pytest.fixture(name='optimizer')
+def fixture_optimizer(model):
+    return torch.optim.SGD(model.parameters(), lr=0.1)
+
+
 @pytest.fixture(name='train_with_fp_copy', params=[False, True], ids=['fp_copy_off', 'fp_copy_on'])
 def fixture_train_with_fp_copy(request):
     return request.param
...@@ -213,7 +219,7 @@ def test_overrides_ordered_dict(model):
         # never actually matched
     ]
 )
-def test_model_prep(model, qbits, bits_overrides, explicit_expected_overrides,
+def test_model_prep(model, optimizer, qbits, bits_overrides, explicit_expected_overrides,
                     train_with_fp_copy, quantize_bias, parallel):
     if parallel:
         model = torch.nn.DataParallel(model)
...@@ -223,7 +229,7 @@ def test_model_prep(model, qbits, bits_overrides, explicit_expected_overrides,
     expected_qbits, post_prepare_changes = get_expected_qbits(model, qbits, explicit_expected_overrides)

     # Initialize Quantizer
-    q = DummyQuantizer(model, bits_activations=qbits.acts, bits_weights=qbits.wts,
+    q = DummyQuantizer(model, optimizer=optimizer, bits_activations=qbits.acts, bits_weights=qbits.wts,
                        bits_overrides=deepcopy(bits_overrides), train_with_fp_copy=train_with_fp_copy,
                        quantize_bias=quantize_bias)
...@@ -290,12 +296,12 @@ def test_model_prep(model, qbits, bits_overrides, explicit_expected_overrides,
          'sub1.conv1': QBits(8, 4), 'sub1.conv2': QBits(4, 4), 'sub2.conv1': QBits(8, 4), 'sub2.conv2': QBits(4, 4)}),
     ]
 )
-def test_param_quantization(model, qbits, bits_overrides, explicit_expected_overrides,
+def test_param_quantization(model, optimizer, qbits, bits_overrides, explicit_expected_overrides,
                             train_with_fp_copy, quantize_bias):
     # Build expected QBits
     expected_qbits, post_prepare_changes = get_expected_qbits(model, qbits, explicit_expected_overrides)

-    q = DummyQuantizer(model, bits_activations=qbits.acts, bits_weights=qbits.wts,
+    q = DummyQuantizer(model, optimizer=optimizer, bits_activations=qbits.acts, bits_weights=qbits.wts,
                        bits_overrides=deepcopy(bits_overrides), train_with_fp_copy=train_with_fp_copy,
                        quantize_bias=quantize_bias)
     q.prepare_model()
......