Commit 5271625a authored by Lev Zlotnik, committed by Guy Jacob

Quantizer: Specify # bias bits + custom overrides (BREAKING) (#178)

* Bias handling:
  * Add 'bits_bias' parameter to explicitly specify # of bits for bias,
    similar to weights and activations.
  * BREAKING: Remove the now redundant 'quantize_bias' boolean parameter
* Custom overrides:
  * Expand the semantics of the overrides dict to allow overriding of
    other parameters in addition to bit-widths
  * Functions registered in the quantizer's 'replacement_factory' can
    define keyword arguments. Non bit-width entries in the overrides
    dict will be checked against the function signature and passed
  * BREAKING:
    * Changed the name of 'bits_overrides' to simply 'overrides'
    * Bit-width overrides must now be defined using the full parameter
      names - 'bits_activations/weights/bias' instead of the short-hands
      'acts' and 'wts' which were used so far.
  * Added/updated relevant tests
  * Modified all quantization YAMLs under 'examples' to reflect 
    these changes
  * Updated docs
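
For illustration, the same quantizer section before and after this change might look as follows (a minimal sketch based on the YAML examples updated below; the instance name, layer name and the `bits_bias` value are placeholders):

```
# Old format (no longer accepted):
quantizers:
  dorefa_quantizer:
    class: DorefaQuantizer
    bits_activations: 8
    bits_weights: 4
    quantize_bias: false
    bits_overrides:
      conv1:
        wts: null
        acts: null

# New format:
quantizers:
  dorefa_quantizer:
    class: DorefaQuantizer
    bits_activations: 8
    bits_weights: 4
    bits_bias: 32        # bias bit-width is now an explicit parameter
    overrides:           # was 'bits_overrides'
      conv1:
        bits_weights: null
        bits_activations: null
```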
parent acaf477b
Showing 246 additions and 199 deletions
@@ -36,7 +36,6 @@ import logging
 from collections import OrderedDict
 import yaml
 import json
-import inspect
 from torch.optim.lr_scheduler import *
 import distiller
 from distiller.thinning import *
@@ -44,6 +43,7 @@ from distiller.pruning import *
 from distiller.regularization import *
 from distiller.learning_rate import *
 from distiller.quantization import *
+from distiller.utils import filter_kwargs
 msglogger = logging.getLogger()
 app_cfg_logger = logging.getLogger("app_cfg")
@@ -196,7 +196,7 @@ def build_component(model, name, user_args, **extra_args):
         raise ValueError("Class named '{0}' does not exist".format(class_name)) from ex
     # First we check that the user defined dict itself does not contain invalid args
-    valid_args, invalid_args = __filter_kwargs(user_args, class_.__init__)
+    valid_args, invalid_args = filter_kwargs(user_args, class_.__init__)
     if invalid_args:
         raise ValueError(
             '{0} does not accept the following arguments: {1}'.format(class_name, list(invalid_args.keys())))
@@ -206,31 +206,11 @@ def build_component(model, name, user_args, **extra_args):
     valid_args.update(extra_args)
     valid_args['model'] = model
     valid_args['name'] = name
-    final_valid_args, _ = __filter_kwargs(valid_args, class_.__init__)
+    final_valid_args, _ = filter_kwargs(valid_args, class_.__init__)
     instance = class_(**final_valid_args)
     return instance

-def __filter_kwargs(dict_to_filter, function_to_call):
-    """Utility to check which arguments in the passed dictionary exist in a function's signature
-    The function returns two dicts, one with just the valid args from the input and one with the invalid args.
-    The caller can then decide to ignore the existence of invalid args, depending on context.
-    """
-    sig = inspect.signature(function_to_call)
-    filter_keys = [param.name for param in sig.parameters.values() if (param.kind == param.POSITIONAL_OR_KEYWORD)]
-    valid_args = {}
-    invalid_args = {}
-    for key in dict_to_filter:
-        if key in filter_keys:
-            valid_args[key] = dict_to_filter[key]
-        else:
-            invalid_args[key] = dict_to_filter[key]
-    return valid_args, invalid_args

 def __policy_params(policy_def, type):
     name = policy_def[type]['instance_name']
     args = policy_def[type].get('args', None)
...
@@ -103,11 +103,12 @@ class WRPNQuantizer(Quantizer):
     1. This class does not take care of layer widening as described in the paper
     2. The paper defines special handling for 1-bit weights which isn't supported here yet
     """
-    def __init__(self, model, optimizer, bits_activations=32, bits_weights=32, bits_overrides=None,
-                 quantize_bias=False):
+    def __init__(self, model, optimizer,
+                 bits_activations=32, bits_weights=32, bits_bias=None,
+                 overrides=None):
         super(WRPNQuantizer, self).__init__(model, optimizer=optimizer, bits_activations=bits_activations,
-                                            bits_weights=bits_weights, bits_overrides=bits_overrides,
-                                            train_with_fp_copy=True, quantize_bias=quantize_bias)
+                                            bits_weights=bits_weights, bits_bias=bits_bias,
+                                            train_with_fp_copy=True, overrides=overrides)

         def wrpn_quantize_param(param_fp, param_meta):
             scale, zero_point = symmetric_linear_quantization_params(param_meta.num_bits, 1)
@@ -159,11 +160,12 @@ class DorefaQuantizer(Quantizer):
     1. Gradients quantization not supported yet
     2. The paper defines special handling for 1-bit weights which isn't supported here yet
     """
-    def __init__(self, model, optimizer, bits_activations=32, bits_weights=32, bits_overrides=None,
-                 quantize_bias=False):
+    def __init__(self, model, optimizer,
+                 bits_activations=32, bits_weights=32, bits_bias=None,
+                 overrides=None):
         super(DorefaQuantizer, self).__init__(model, optimizer=optimizer, bits_activations=bits_activations,
-                                              bits_weights=bits_weights, bits_overrides=bits_overrides,
-                                              train_with_fp_copy=True, quantize_bias=quantize_bias)
+                                              bits_weights=bits_weights, bits_bias=bits_bias,
+                                              train_with_fp_copy=True, overrides=overrides)

         def relu_replace_fn(module, name, qbits_map):
             bits_acts = qbits_map[name].acts
@@ -188,11 +190,12 @@ class PACTQuantizer(Quantizer):
         act_clip_decay (float): L2 penalty applied to the clipping values, referred to as "lambda_alpha" in the paper.
             If None then the optimizer's default weight decay value is used (default: None)
     """
-    def __init__(self, model, optimizer, bits_activations=32, bits_weights=32, bits_overrides=None,
-                 quantize_bias=False, act_clip_init_val=8.0, act_clip_decay=None):
+    def __init__(self, model, optimizer,
+                 bits_activations=32, bits_weights=32, bits_bias=None,
+                 overrides=None, act_clip_init_val=8.0, act_clip_decay=None):
         super(PACTQuantizer, self).__init__(model, optimizer=optimizer, bits_activations=bits_activations,
-                                            bits_weights=bits_weights, bits_overrides=bits_overrides,
-                                            train_with_fp_copy=True, quantize_bias=quantize_bias)
+                                            bits_weights=bits_weights, bits_bias=bits_bias,
+                                            overrides=overrides, train_with_fp_copy=True)

         def relu_replace_fn(module, name, qbits_map):
             bits_acts = qbits_map[name].acts
...
@@ -21,10 +21,11 @@ import logging
 import torch
 import torch.nn as nn
 import distiller
+import warnings

 msglogger = logging.getLogger()

-QBits = namedtuple('QBits', ['acts', 'wts'])
+QBits = namedtuple('QBits', ['acts', 'wts', 'bias'])

 FP_BKP_PREFIX = 'float_'
@@ -62,6 +63,9 @@ class _ParamToQuant(object):
         self.q_attr_name = q_attr_name
         self.num_bits = num_bits

+    def __repr__(self):
+        return "ParamToQuant(module_name=%s,num_bits=%s)" % (self.module_name, self.num_bits)
+
 class Quantizer(object):
     r"""
@@ -72,16 +76,21 @@ class Quantizer(object):
         optimizer (torch.optim.Optimizer): An optimizer instance, required in cases where the quantizer is going
             to perform changes to existing model parameters and/or add new ones.
             Specifically, when train_with_fp_copy is True, this cannot be None.
-        bits_activations/weights (int): Default number of bits to use when quantizing each tensor type.
+        bits_activations/weights/bias (int): Default number of bits to use when quantizing each tensor type.
             Value of None means do not quantize.
-        bits_overrides (OrderedDict): Dictionary mapping regular expressions of layer name patterns to dictionary with
-            values for 'acts' and/or 'wts' to override the default values.
+        overrides (OrderedDict): Dictionary mapping regular expressions of layer name patterns to dictionary with
+            overrides of default values.
+            The keys in the overrides dictionary should be parameter names that the Quantizer accepts default values
+            for in its init function.
+            The parameters 'bits_activations', 'bits_weights', and 'bits_bias' which are accepted by the base Quantizer
+            are supported by default.
+            Other than those, each sub-class of Quantizer defines the set of parameter for which it supports
+            over-riding.
             OrderedDict is used to enable handling of overlapping name patterns. So, for example, one could define
             certain override parameters for a group of layers, e.g. 'conv*', but also define different parameters for
             specific layers in that group, e.g. 'conv1'.
             The patterns are evaluated eagerly - the first match wins. Therefore, the more specific patterns must
             come before the broad patterns.
-        quantize_bias (bool): Flag indicating whether to quantize bias (w. same number of bits as weights) or not.
         train_with_fp_copy (bool): If true, will modify layers with weights to keep both a quantized and
             floating-point copy, such that the following flow occurs in each training iteration:
             1. q_weights = quantize(fp_weights)
@@ -91,18 +100,19 @@ class Quantizer(object):
             3.2 We also back-prop through the 'quantize' operation from step 1
             4. Update fp_weights with gradients calculated in step 3.2
     """
-    def __init__(self, model, optimizer=None, bits_activations=None, bits_weights=None, bits_overrides=None,
-                 quantize_bias=False, train_with_fp_copy=False):
-        if bits_overrides is None:
-            bits_overrides = OrderedDict()
-        if not isinstance(bits_overrides, OrderedDict):
-            raise TypeError('bits_overrides must be an instance of collections.OrderedDict or None')
+    def __init__(self, model, optimizer=None,
+                 bits_activations=None, bits_weights=None, bits_bias=None,
+                 overrides=None, train_with_fp_copy=False):
+        if overrides is None:
+            overrides = OrderedDict()
+        if not isinstance(overrides, OrderedDict):
+            raise TypeError('overrides must be an instance of collections.OrderedDict or None')
         if train_with_fp_copy and optimizer is None:
             raise ValueError('optimizer cannot be None when train_with_fp_copy is True')

-        self.default_qbits = QBits(acts=bits_activations, wts=bits_weights)
-        self.quantize_bias = quantize_bias
+        self.default_qbits = QBits(acts=bits_activations, wts=bits_weights, bias=bits_bias)
+        self.overrides = overrides

         self.model = model
         self.optimizer = optimizer
@@ -111,36 +121,47 @@ class Quantizer(object):
         self.model.quantizer_metadata = {'type': type(self),
                                          'params': {'bits_activations': bits_activations,
                                                     'bits_weights': bits_weights,
-                                                    'bits_overrides': copy.deepcopy(bits_overrides),
-                                                    'quantize_bias': quantize_bias}}
+                                                    'bits_bias': bits_bias,
+                                                    'overrides': copy.deepcopy(overrides)}}

-        for k, v in bits_overrides.items():
-            qbits = QBits(acts=v.get('acts', self.default_qbits.acts), wts=v.get('wts', self.default_qbits.wts))
-            bits_overrides[k] = qbits
+        for k, v in self.overrides.items():
+            if any(old_bits_key in v.keys() for old_bits_key in ['acts', 'wts', 'bias']):
+                raise ValueError("Using 'acts' / 'wts' / 'bias' to specify bit-width overrides is deprecated.\n"
+                                 "Please use the full parameter names: "
+                                 "'bits_activations' / 'bits_weights' / 'bits_bias'")
+            qbits = QBits(acts=v.pop('bits_activations', self.default_qbits.acts),
+                          wts=v.pop('bits_weights', self.default_qbits.wts),
+                          bias=v.pop('bits_bias', self.default_qbits.bias))
+            v['bits'] = qbits

         # Prepare explicit mapping from each layer to QBits based on default + overrides
         patterns = []
-        regex = None
-        if bits_overrides:
-            patterns = list(bits_overrides.keys())
-            regex_str = '|'.join(['(^{0}$)'.format(pattern) for pattern in patterns])
-            regex = re.compile(regex_str)
+        regex_overrides = None
+        if overrides:
+            patterns = list(overrides.keys())
+            regex_overrides_str = '|'.join(['(^{0}$)'.format(pattern) for pattern in patterns])
+            regex_overrides = re.compile(regex_overrides_str)

         self.module_qbits_map = {}
+        self.module_overrides_map = {}
         for module_full_name, module in model.named_modules():
+            # Need to account for scenario where model is parallelized with DataParallel, which wraps the original
+            # module with a wrapper module called 'module' :)
+            name_to_match = module_full_name.replace('module.', '', 1)
             qbits = self.default_qbits
-            if regex:
-                # Need to account for scenario where model is parallelized with DataParallel, which wraps the original
-                # module with a wrapper module called 'module' :)
-                name_to_match = module_full_name.replace('module.', '', 1)
-                m = regex.match(name_to_match)
-                if m:
+            override_entry = self.overrides.get(name_to_match, OrderedDict())
+            if regex_overrides:
+                m_overrides = regex_overrides.match(name_to_match)
+                if m_overrides:
                     group_idx = 0
-                    groups = m.groups()
+                    groups = m_overrides.groups()
                     while groups[group_idx] is None:
                         group_idx += 1
-                    qbits = bits_overrides[patterns[group_idx]]
+                    override_entry = copy.deepcopy(override_entry or self.overrides[patterns[group_idx]])
+                    qbits = override_entry.pop('bits', self.default_qbits)
             self._add_qbits_entry(module_full_name, type(module), qbits)
+            self._add_override_entry(module_full_name, override_entry)

         # Mapping from module type to function generating a replacement module suited for quantization
         # To be populated by child classes
@@ -156,9 +177,12 @@ class Quantizer(object):
         if module_type not in [nn.Conv2d, nn.Linear, nn.Embedding]:
             # For now we support weights quantization only for Conv, FC and Embedding layers (so, for example, we don't
             # support quantization of batch norm scale parameters)
-            qbits = QBits(acts=qbits.acts, wts=None)
+            qbits = QBits(acts=qbits.acts, wts=None, bias=None)
         self.module_qbits_map[module_name] = qbits

+    def _add_override_entry(self, module_name, entry):
+        self.module_overrides_map[module_name] = entry
+
     def prepare_model(self):
         self._prepare_model_impl()
@@ -180,13 +204,8 @@ class Quantizer(object):
             curr_parameters = dict(module.named_parameters())
             for param_name, param in curr_parameters.items():
                 # Bias is usually quantized according to the accumulator's number of bits
-                # Temporary hack: Assume that number is 32 bits and hard-code it here
-                # TODO: Handle # of bits for bias quantization as "first-class" citizen, similarly to weights
-                n_bits = qbits.wts
-                if param_name.endswith('bias'):
-                    if not self.quantize_bias:
-                        continue
-                    n_bits = 32
+                # Handle # of bits for bias quantization as "first-class" citizen, similarly to weights
+                n_bits = qbits.bias if param_name.endswith('bias') else qbits.wts
                 fp_attr_name = param_name
                 if self.train_with_fp_copy:
                     hack_float_backup_parameter(module, param_name, n_bits)
@@ -209,9 +228,18 @@ class Quantizer(object):
             full_name = prefix + name
             current_qbits = self.module_qbits_map[full_name]
             if current_qbits.acts is None and current_qbits.wts is None:
+                if self.module_overrides_map[full_name]:
+                    raise ValueError("Adding overrides while not quantizing is not allowed.")
                 continue
             try:
-                new_module = self.replacement_factory[type(module)](module, full_name, self.module_qbits_map)
+                replace_fn = self.replacement_factory[type(module)]
+                valid_kwargs, invalid_kwargs = distiller.filter_kwargs(self.module_overrides_map[full_name], replace_fn)
+                if invalid_kwargs:
+                    raise TypeError("""Quantizer of type %s doesn't accept \"%s\"
+                                       as override arguments for %s. Allowed kwargs: %s"""
+                                    % (type(self), list(invalid_kwargs), type(module), list(valid_kwargs)))
+                new_module = self.replacement_factory[type(module)](module, full_name,
+                                                                    self.module_qbits_map, **valid_kwargs)
                 msglogger.debug('Module {0}: Replacing \n{1} with \n{2}'.format(full_name, module, new_module))
                 setattr(container, name, new_module)
@@ -219,7 +247,7 @@ class Quantizer(object):
                 if not distiller.has_children(module) and distiller.has_children(new_module):
                     for sub_module_name, sub_module in new_module.named_modules():
                         self._add_qbits_entry(full_name + '.' + sub_module_name, type(sub_module), current_qbits)
-                    self.module_qbits_map[full_name] = QBits(acts=current_qbits.acts, wts=None)
+                    self.module_qbits_map[full_name] = QBits(acts=current_qbits.acts, wts=None, bias=None)
             except KeyError:
                 pass
...
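To make the new replacement-function keyword-argument mechanism concrete, here is a small sketch (not part of the commit; the `clip_val` argument and the `relu1` layer name are hypothetical) of how a non-bit-width override entry would be forwarded to a replacement function by `prepare_model`:

```python
from collections import OrderedDict

import torch.nn as nn


# A replacement function in the style of those registered in a Quantizer
# sub-class's replacement_factory. Besides the usual (module, name, qbits_map)
# arguments it declares an extra keyword argument, 'clip_val' (hypothetical).
def relu_replace_fn(module, name, qbits_map, clip_val=None):
    bits_acts = qbits_map[name].acts
    print('replacing %s: %s-bit activations, clip_val=%s' % (name, bits_acts, clip_val))
    return nn.ReLU()  # a real quantizer would return a quantized activation module


# With the new overrides semantics, the non-bit-width key 'clip_val' is checked
# against relu_replace_fn's signature via filter_kwargs and, if accepted,
# passed through as a keyword argument when the module is replaced.
overrides = OrderedDict([
    ('relu1', OrderedDict([('bits_activations', 4), ('clip_val', 6.0)])),
])
```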
@@ -615,12 +615,12 @@ class PostTrainLinearQuantizer(Quantizer):
         model_activation_stats (str / dict / OrderedDict): Either a path to activation stats YAML file, or a dictionary
             containing the stats. If None then stats will be calculated dynamically.
     """
-    def __init__(self, model, bits_activations=8, bits_parameters=8, bits_accum=32, bits_overrides=None,
-                 mode=LinearQuantMode.SYMMETRIC, clip_acts=False, no_clip_layers=None, per_channel_wts=False,
-                 model_activation_stats=None):
+    def __init__(self, model, bits_activations=8, bits_parameters=8, bits_accum=32,
+                 overrides=None, mode=LinearQuantMode.SYMMETRIC, clip_acts=False, no_clip_layers=None,
+                 per_channel_wts=False, model_activation_stats=None):
         super(PostTrainLinearQuantizer, self).__init__(model, bits_activations=bits_activations,
-                                                       bits_weights=bits_parameters, bits_overrides=bits_overrides,
-                                                       train_with_fp_copy=False)
+                                                       bits_weights=bits_parameters, bits_bias=bits_accum,
+                                                       overrides=overrides, train_with_fp_copy=False)
         mode = verify_mode(mode)
@@ -644,7 +644,7 @@ class PostTrainLinearQuantizer(Quantizer):
         def replace_param_layer(module, name, qbits_map):
             norm_name = distiller.utils.normalize_module_name(name)
-            clip = self.clip_acts and norm_name not in self.no_clip_layers
+            clip = clip_acts and norm_name not in self.no_clip_layers
             return RangeLinearQuantParamLayerWrapper(module, qbits_map[name].acts, qbits_map[name].wts,
                                                      num_bits_accum=self.bits_accum, mode=mode, clip_acts=clip,
                                                      per_channel_wts=per_channel_wts,
@@ -667,8 +667,10 @@ class PostTrainLinearQuantizer(Quantizer):
         self.model_activation_stats = model_activation_stats or {}
         self.bits_accum = bits_accum
         self.mode = mode
         self.replacement_factory[nn.Conv2d] = replace_param_layer
         self.replacement_factory[nn.Linear] = replace_param_layer
         self.replacement_factory[distiller.modules.Concat] = partial(
             replace_non_param_layer, RangeLinearQuantConcatWrapper)
         self.replacement_factory[distiller.modules.EltwiseAdd] = partial(
@@ -687,8 +689,15 @@ class PostTrainLinearQuantizer(Quantizer):
             return distiller.config_component_from_file_by_class(model, args.qe_config_file,
                                                                  'PostTrainLinearQuantizer')
         else:
-            return cls(model, args.qe_bits_acts, args.qe_bits_wts, args.qe_bits_accum, None, args.qe_mode,
-                       args.qe_clip_acts, args.qe_no_clip_layers, args.qe_per_channel, args.qe_stats_file)
+            return cls(model,
+                       bits_activations=args.qe_bits_acts,
+                       bits_parameters=args.qe_bits_wts,
+                       bits_accum=args.qe_bits_accum,
+                       mode=args.qe_mode,
+                       clip_acts=args.qe_clip_acts,
+                       no_clip_layers=args.qe_no_clip_layers,
+                       per_channel_wts=args.qe_per_channel,
+                       model_activation_stats=args.qe_stats_file)

 ###############################################################################
@@ -788,14 +797,14 @@ class FakeQuantizationWrapper(nn.Module):
 class QuantAwareTrainRangeLinearQuantizer(Quantizer):
-    def __init__(self, model, optimizer=None, bits_activations=32, bits_weights=32, bits_overrides=None,
-                 quantize_bias=True, mode=LinearQuantMode.SYMMETRIC, ema_decay=0.999, per_channel_wts=False,
+    def __init__(self, model, optimizer=None, bits_activations=32, bits_weights=32, bits_bias=32,
+                 overrides=None, mode=LinearQuantMode.SYMMETRIC, ema_decay=0.999, per_channel_wts=False,
                  quantize_inputs=True, num_bits_inputs=None):
         super(QuantAwareTrainRangeLinearQuantizer, self).__init__(model, optimizer=optimizer,
                                                                   bits_activations=bits_activations,
                                                                   bits_weights=bits_weights,
-                                                                  bits_overrides=bits_overrides,
-                                                                  quantize_bias=quantize_bias,
+                                                                  bits_bias=bits_bias,
+                                                                  overrides=overrides,
                                                                   train_with_fp_copy=True)

         if isinstance(model, nn.DataParallel) and len(model.device_ids) > 1:
...
@@ -19,6 +19,8 @@
 This module contains various tensor sparsity/density measurement functions, together
 with some random helper functions.
 """
+import inspect
 import numpy as np
 import torch
 import torch.nn as nn
@@ -595,6 +597,26 @@ def float_range_argparse_checker(min_val=0., max_val=1., exc_min=False, exc_max=
     return checker
+def filter_kwargs(dict_to_filter, function_to_call):
+    """Utility to check which arguments in the passed dictionary exist in a function's signature
+    The function returns two dicts, one with just the valid args from the input and one with the invalid args.
+    The caller can then decide to ignore the existence of invalid args, depending on context.
+    """
+    sig = inspect.signature(function_to_call)
+    filter_keys = [param.name for param in sig.parameters.values() if (param.kind == param.POSITIONAL_OR_KEYWORD)]
+    valid_args = {}
+    invalid_args = {}
+    for key in dict_to_filter:
+        if key in filter_keys:
+            valid_args[key] = dict_to_filter[key]
+        else:
+            invalid_args[key] = dict_to_filter[key]
+    return valid_args, invalid_args
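As a standalone usage sketch of the `filter_kwargs` helper added above (the `make_quantizer` function and the argument values below are made up for illustration):

```python
import inspect


def filter_kwargs(dict_to_filter, function_to_call):
    # Mirrors the helper above: split a dict into the keys that match the
    # function's positional-or-keyword parameters and the keys that don't.
    sig = inspect.signature(function_to_call)
    filter_keys = [p.name for p in sig.parameters.values() if p.kind == p.POSITIONAL_OR_KEYWORD]
    valid_args = {k: v for k, v in dict_to_filter.items() if k in filter_keys}
    invalid_args = {k: v for k, v in dict_to_filter.items() if k not in filter_keys}
    return valid_args, invalid_args


def make_quantizer(model, bits_weights=8, bits_activations=8):
    return model, bits_weights, bits_activations


user_args = {'bits_weights': 4, 'bits_activations': 8, 'momentum': 0.9}
valid, invalid = filter_kwargs(user_args, make_quantizer)
print(valid)    # {'bits_weights': 4, 'bits_activations': 8}
print(invalid)  # {'momentum': 0.9}
```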
 def convert_tensors_recursively_to(val, *args, **kwargs):
     """ Applies `.to(*args, **kwargs)` to each tensor inside val tree. Other values remain the same."""
     if isinstance(val, torch.Tensor):
@@ -604,3 +626,4 @@ def convert_tensors_recursively_to(val, *args, **kwargs):
         return type(val)(convert_tensors_recursively_to(item, *args, **kwargs) for item in val)
     return val
@@ -92,6 +92,7 @@ For post-training quantization, this method is implemented by wrapping existing
     - Element-wise addition
     - Element-wise multiplication
     - Concatenation
+    - Embedding
 - All other layers are unaffected and are executed using their original FP32 implementation.
 - To automatically transform an existing model to a quantized model using this method, use the `PostTrainLinearQuantizer` class. For details on ways to invoke the quantizer see [here](schedule.md#post-training-quantization).
 - The transform performed by the Quantizer only works on sub-classes of `torch.nn.Module`. But operations such as element-wise addition / multiplication and concatenation do not have associated Modules in PyTorch. They are either overloaded operators, or simple functions in the `torch` namespace. To be able to quantize these operations, we've implemented very simple modules that wrap these operations [here](https://github.com/NervanaSystems/distiller/blob/master/distiller/distiller/modules). It is necessary to manually modify your model and replace any existing operator with a corresponding module. For an example, see our slightly modified [ResNet implementation](https://github.com/NervanaSystems/distiller/blob/quantization_updates/models/imagenet/resnet.py).
...
@@ -63,10 +63,10 @@ To execute the model transformation, call the `prepare_model` function of the `Q
 ### Flexible Bit-Widths
-- Each instance of `Quantizer` is parameterized by the number of bits to be used for quantization of different tensor types. The default ones are activations and weights. These are the `bits_activations` and `bits_weights` parameters in `Quantizer`'s constructor. Sub-classes may define bit-widths for other tensor types as needed.
+- Each instance of `Quantizer` is parameterized by the number of bits to be used for quantization of different tensor types. The default ones are activations and weights. These are the `bits_activations`, `bits_weights` and `bits_bias` parameters in `Quantizer`'s constructor. Sub-classes may define bit-widths for other tensor types as needed.
 - We also want to be able to override the default number of bits mentioned in the bullet above for certain layers. These could be very specific layers. However, many models are comprised of building blocks ("container" modules, such as Sequential) which contain several modules, and it is likely we'll want to override settings for entire blocks, or for a certain module across different blocks. When such building blocks are used, the names of the internal modules usually follow some pattern.
-- So, for this purpose, Quantizer also accepts a mapping of regular expressions to number of bits. This allows the user to override specific layers using they're exact name, or a group of layers via a regular expression. This mapping is passed via the `bits_overrides` parameter in the constructor.
-- The `bits_overrides` mapping is required to be an instance of [`collections.OrderedDict`](https://docs.python.org/3.5/library/collections.html#collections.OrderedDict) (as opposed to just a simple Python [`dict`](https://docs.python.org/3.5/library/stdtypes.html#dict)). This is done in order to enable handling of overlapping name patterns.
+- So, for this purpose, Quantizer also accepts a mapping of regular expressions to number of bits. This allows the user to override specific layers using they're exact name, or a group of layers via a regular expression. This mapping is passed via the `overrides` parameter in the constructor.
+- The `overrides` mapping is required to be an instance of [`collections.OrderedDict`](https://docs.python.org/3.5/library/collections.html#collections.OrderedDict) (as opposed to just a simple Python [`dict`](https://docs.python.org/3.5/library/stdtypes.html#dict)). This is done in order to enable handling of overlapping name patterns.
  So, for example, one could define certain override parameters for a group of layers, e.g. 'conv*', but also define different parameters for specific layers in that group, e.g. 'conv1'.
  The patterns are evaluated eagerly - the first match wins. Therefore, the more specific patterns must come before the broad patterns.
...
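The Flexible Bit-Widths section above translates to roughly the following Python usage under the new API (a sketch assuming Distiller is installed; the toy model, layer names and bit-width values are placeholders and the layer shapes are not meant to be meaningful):

```python
from collections import OrderedDict

import torch.nn as nn
import torch.optim as optim
from distiller.quantization import DorefaQuantizer

# Toy model whose sub-module names ('conv1', 'relu1', 'fc') stand in for a real network.
model = nn.Sequential(OrderedDict([
    ('conv1', nn.Conv2d(3, 16, 3)),
    ('relu1', nn.ReLU()),
    ('fc', nn.Linear(16, 10)),
]))
optimizer = optim.SGD(model.parameters(), lr=0.1)

# New-style overrides: an OrderedDict keyed by layer-name patterns, using the full
# parameter names ('bits_weights' / 'bits_activations' / 'bits_bias'), specific patterns first.
overrides = OrderedDict([
    ('conv1', {'bits_weights': None, 'bits_activations': None}),
    ('fc', {'bits_weights': None, 'bits_activations': None}),
])

quantizer = DorefaQuantizer(model, optimizer,
                            bits_activations=8, bits_weights=4, bits_bias=32,
                            overrides=overrides)
quantizer.prepare_model()
```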
@@ -250,24 +250,24 @@ quantizers:
     class: DorefaQuantizer
     bits_activations: 8
     bits_weights: 4
-    bits_overrides:
+    overrides:
       conv1:
-        wts: null
-        acts: null
+        bits_weights: null
+        bits_activations: null
       relu1:
-        wts: null
-        acts: null
+        bits_weights: null
+        bits_activations: null
       final_relu:
-        wts: null
-        acts: null
+        bits_weights: null
+        bits_activations: null
       fc:
-        wts: null
-        acts: null
+        bits_weights: null
+        bits_activations: null
 ```
 - The specific quantization method we're instantiating here is `DorefaQuantizer`.
 - Then we define the default bit-widths for activations and weights, in this case 8 and 4-bits, respectively.
-- Then, we define the `bits_overrides` mapping. In the example above, we choose not to quantize the first and last layer of the model. In the case of `DorefaQuantizer`, the weights are quantized as part of the convolution / FC layers, but the activations are quantized in separate layers, which replace the ReLU layers in the original model (remember - even though we replaced the ReLU modules with our own quantization modules, the name of the modules isn't changed). So, in all, we need to reference the first layer with parameters `conv1`, the first activation layer `relu1`, the last activation layer `final_relu` and the last layer with parameters `fc`.
+- Then, we define the `overrides` mapping. In the example above, we choose not to quantize the first and last layer of the model. In the case of `DorefaQuantizer`, the weights are quantized as part of the convolution / FC layers, but the activations are quantized in separate layers, which replace the ReLU layers in the original model (remember - even though we replaced the ReLU modules with our own quantization modules, the name of the modules isn't changed). So, in all, we need to reference the first layer with parameters `conv1`, the first activation layer `relu1`, the last activation layer `final_relu` and the last layer with parameters `fc`.
 - Specifying `null` means "do not quantize".
 - Note that for quantizers, we reference names of modules, not names of parameters as we do for pruners and regularizers.
@@ -276,10 +276,10 @@ quantizers:
 Suppose we have a sub-module in our model named `block1`, which contains multiple convolution layers which we would like to quantize to, say, 2-bits. The convolution layers are named `conv1`, `conv2` and so on. In that case we would define the following:
 ```
-bits_overrides:
+overrides:
   'block1\.conv*':
-    wts: 2
-    acts: null
+    bits_weights: 2
+    bits_activations: null
 ```
 - **RegEx Note**: Remember that the dot (`.`) is a meta-character (i.e. a reserved character) in regular expressions. So, to match the actual dot characters which separate sub-modules in PyTorch module names, we need to escape it: `\.`
@@ -287,13 +287,13 @@ bits_overrides:
 **Overlapping patterns** are also possible, which allows to define some override for a groups of layers and also "single-out" specific layers for different overrides. For example, let's take the last example and configure a different override for `block1.conv1`:
 ```
-bits_overrides:
+overrides:
   'block1\.conv1':
-    wts: 4
-    acts: null
+    bits_weights: 4
+    bits_activations: null
   'block1\.conv*':
-    wts: 2
-    acts: null
+    bits_weights: 2
+    bits_activations: null
 ```
 - **Important Note**: The patterns are evaluated eagerly - first match wins. So, to properly quantize a model using "broad" patterns and more "specific" patterns as just shown, make sure the specific pattern is listed **before** the broad one.
@@ -390,9 +390,10 @@ if args.quantize_eval:
 # Execute evaluation on model as usual
 ```
-Note that the command-line arguments don't expose the `bits_overrides` parameter of the quantizer, which allows fine-grained control over how each layer is quantized. To utilize this functionality, configure with a YAML file.
+Note that the command-line arguments don't expose the `overrides` parameter of the quantizer, which allows fine-grained control over how each layer is quantized. To utilize this functionality, configure with a YAML file.
-To see integration of these command line arguments in use, see the [image classification example](https://github.com/NervanaSystems/distiller/blob/master/examples/classifier_compression/compress_classifier.py). For examples invocations of post-training quantization see [here](https://github.com/NervanaSystems/distiller/blob/master/examples/quantization/post_training_quant).
+To see integration of these command line arguments in use, see the [image classification example](https://github.com/NervanaSystems/distiller/blob/master/examples/classifier_compression/compress_classifier.py).
+For examples invocations of post-training quantization see [here](https://github.com/NervanaSystems/distiller/blob/master/examples/quantization/post_training_quant).
 ### Collecting Statistics for Quantization
...
@@ -279,6 +279,7 @@ For any of the methods below that require quantization-aware training, please se
 <li>Element-wise addition</li>
 <li>Element-wise multiplication</li>
 <li>Concatenation</li>
+<li>Embedding</li>
 </ul>
 </li>
 <li>All other layers are unaffected and are executed using their original FP32 implementation.</li>
...
@@ -249,10 +249,10 @@ Each sub-class of <code>Quantizer</code> should populate the <code>replacement_f
 To execute the model transformation, call the <code>prepare_model</code> function of the <code>Quantizer</code> instance.</p>
 <h3 id="flexible-bit-widths">Flexible Bit-Widths</h3>
 <ul>
-<li>Each instance of <code>Quantizer</code> is parameterized by the number of bits to be used for quantization of different tensor types. The default ones are activations and weights. These are the <code>bits_activations</code> and <code>bits_weights</code> parameters in <code>Quantizer</code>'s constructor. Sub-classes may define bit-widths for other tensor types as needed.</li>
+<li>Each instance of <code>Quantizer</code> is parameterized by the number of bits to be used for quantization of different tensor types. The default ones are activations and weights. These are the <code>bits_activations</code>, <code>bits_weights</code> and <code>bits_bias</code> parameters in <code>Quantizer</code>'s constructor. Sub-classes may define bit-widths for other tensor types as needed.</li>
 <li>We also want to be able to override the default number of bits mentioned in the bullet above for certain layers. These could be very specific layers. However, many models are comprised of building blocks ("container" modules, such as Sequential) which contain several modules, and it is likely we'll want to override settings for entire blocks, or for a certain module across different blocks. When such building blocks are used, the names of the internal modules usually follow some pattern.</li>
-<li>So, for this purpose, Quantizer also accepts a mapping of regular expressions to number of bits. This allows the user to override specific layers using they're exact name, or a group of layers via a regular expression. This mapping is passed via the <code>bits_overrides</code> parameter in the constructor.</li>
-<li>The <code>bits_overrides</code> mapping is required to be an instance of <a href="https://docs.python.org/3.5/library/collections.html#collections.OrderedDict"><code>collections.OrderedDict</code></a> (as opposed to just a simple Python <a href="https://docs.python.org/3.5/library/stdtypes.html#dict"><code>dict</code></a>). This is done in order to enable handling of overlapping name patterns.<br />
+<li>So, for this purpose, Quantizer also accepts a mapping of regular expressions to number of bits. This allows the user to override specific layers using they're exact name, or a group of layers via a regular expression. This mapping is passed via the <code>overrides</code> parameter in the constructor.</li>
+<li>The <code>overrides</code> mapping is required to be an instance of <a href="https://docs.python.org/3.5/library/collections.html#collections.OrderedDict"><code>collections.OrderedDict</code></a> (as opposed to just a simple Python <a href="https://docs.python.org/3.5/library/stdtypes.html#dict"><code>dict</code></a>). This is done in order to enable handling of overlapping name patterns.<br />
 So, for example, one could define certain override parameters for a group of layers, e.g. 'conv*', but also define different parameters for specific layers in that group, e.g. 'conv1'.<br />
 The patterns are evaluated eagerly - the first match wins. Therefore, the more specific patterns must come before the broad patterns.</li>
 </ul>
...
@@ -273,5 +273,5 @@ And of course, if we used a sparse or compressed representation, then we are red
 <!--
 MkDocs version : 1.0.4
-Build Date UTC : 2019-03-28 17:45:12
+Build Date UTC : 2019-04-01 14:59:11
 -->
@@ -440,47 +440,47 @@ policies:
     class: DorefaQuantizer
     bits_activations: 8
     bits_weights: 4
-    bits_overrides:
+    overrides:
       conv1:
-        wts: null
-        acts: null
+        bits_weights: null
+        bits_activations: null
       relu1:
-        wts: null
-        acts: null
+        bits_weights: null
+        bits_activations: null
       final_relu:
-        wts: null
-        acts: null
+        bits_weights: null
+        bits_activations: null
       fc:
-        wts: null
-        acts: null
+        bits_weights: null
+        bits_activations: null
 </code></pre>
 <ul>
 <li>The specific quantization method we're instantiating here is <code>DorefaQuantizer</code>.</li>
 <li>Then we define the default bit-widths for activations and weights, in this case 8 and 4-bits, respectively. </li>
-<li>Then, we define the <code>bits_overrides</code> mapping. In the example above, we choose not to quantize the first and last layer of the model. In the case of <code>DorefaQuantizer</code>, the weights are quantized as part of the convolution / FC layers, but the activations are quantized in separate layers, which replace the ReLU layers in the original model (remember - even though we replaced the ReLU modules with our own quantization modules, the name of the modules isn't changed). So, in all, we need to reference the first layer with parameters <code>conv1</code>, the first activation layer <code>relu1</code>, the last activation layer <code>final_relu</code> and the last layer with parameters <code>fc</code>.</li>
+<li>Then, we define the <code>overrides</code> mapping. In the example above, we choose not to quantize the first and last layer of the model. In the case of <code>DorefaQuantizer</code>, the weights are quantized as part of the convolution / FC layers, but the activations are quantized in separate layers, which replace the ReLU layers in the original model (remember - even though we replaced the ReLU modules with our own quantization modules, the name of the modules isn't changed). So, in all, we need to reference the first layer with parameters <code>conv1</code>, the first activation layer <code>relu1</code>, the last activation layer <code>final_relu</code> and the last layer with parameters <code>fc</code>.</li>
 <li>Specifying <code>null</code> means "do not quantize".</li>
 <li>Note that for quantizers, we reference names of modules, not names of parameters as we do for pruners and regularizers.</li>
 </ul>
 <h3 id="defining-overrides-for-groups-of-layers-using-regular-expressions">Defining overrides for <strong>groups of layers</strong> using regular expressions</h3>
 <p>Suppose we have a sub-module in our model named <code>block1</code>, which contains multiple convolution layers which we would like to quantize to, say, 2-bits. The convolution layers are named <code>conv1</code>, <code>conv2</code> and so on. In that case we would define the following:</p>
-<pre><code>bits_overrides:
+<pre><code>overrides:
   'block1\.conv*':
-    wts: 2
-    acts: null
+    bits_weights: 2
+    bits_activations: null
 </code></pre>
 <ul>
 <li><strong>RegEx Note</strong>: Remember that the dot (<code>.</code>) is a meta-character (i.e. a reserved character) in regular expressions. So, to match the actual dot characters which separate sub-modules in PyTorch module names, we need to escape it: <code>\.</code></li>
 </ul>
 <p><strong>Overlapping patterns</strong> are also possible, which allows to define some override for a groups of layers and also "single-out" specific layers for different overrides. For example, let's take the last example and configure a different override for <code>block1.conv1</code>:</p>
-<pre><code>bits_overrides:
+<pre><code>overrides:
   'block1\.conv1':
-    wts: 4
-    acts: null
+    bits_weights: 4
+    bits_activations: null
   'block1\.conv*':
-    wts: 2
-    acts: null
+    bits_weights: 2
+    bits_activations: null
 </code></pre>
 <ul>
@@ -563,8 +563,9 @@ args = parser.parse_args()
 # Execute evaluation on model as usual
 </code></pre>
-<p>Note that the command-line arguments don't expose the <code>bits_overrides</code> parameter of the quantizer, which allows fine-grained control over how each layer is quantized. To utilize this functionality, configure with a YAML file.</p>
-<p>To see integration of these command line arguments in use, see the <a href="https://github.com/NervanaSystems/distiller/blob/master/examples/classifier_compression/compress_classifier.py">image classification example</a>. For examples invocations of post-training quantization see <a href="https://github.com/NervanaSystems/distiller/blob/master/examples/quantization/post_training_quant">here</a>.</p>
+<p>Note that the command-line arguments don't expose the <code>overrides</code> parameter of the quantizer, which allows fine-grained control over how each layer is quantized. To utilize this functionality, configure with a YAML file.</p>
+<p>To see integration of these command line arguments in use, see the <a href="https://github.com/NervanaSystems/distiller/blob/master/examples/classifier_compression/compress_classifier.py">image classification example</a>.
+For examples invocations of post-training quantization see <a href="https://github.com/NervanaSystems/distiller/blob/master/examples/quantization/post_training_quant">here</a>.</p>
 <h3 id="collecting-statistics-for-quantization">Collecting Statistics for Quantization</h3>
 <p>To collect generate statistics that can be used for static quantization of activations, do the following (shown here assuming the command line argument <code>--qe-calibration</code> shown above is used, which specifies the number of batches to use for statistics generation):</p>
 <pre><code class="python">if args.qe_calibration:
...
@@ -2,87 +2,87 @@
 <urlset xmlns="http://www.sitemaps.org/schemas/sitemap/0.9">
  <url>
   <loc>None</loc>
-  <lastmod>2019-03-28</lastmod>
+  <lastmod>2019-04-01</lastmod>
   <changefreq>daily</changefreq>
  </url>
  <url>
   <loc>None</loc>
-  <lastmod>2019-03-28</lastmod>
+  <lastmod>2019-04-01</lastmod>
   <changefreq>daily</changefreq>
  </url>
  <url>
   <loc>None</loc>
-  <lastmod>2019-03-28</lastmod>
+  <lastmod>2019-04-01</lastmod>
   <changefreq>daily</changefreq>
  </url>
  <url>
   <loc>None</loc>
-  <lastmod>2019-03-28</lastmod>
+  <lastmod>2019-04-01</lastmod>
   <changefreq>daily</changefreq>
  </url>
  <url>
   <loc>None</loc>
-  <lastmod>2019-03-28</lastmod>
+  <lastmod>2019-04-01</lastmod>
   <changefreq>daily</changefreq>
  </url>
  <url>
   <loc>None</loc>
-  <lastmod>2019-03-28</lastmod>
+  <lastmod>2019-04-01</lastmod>
   <changefreq>daily</changefreq>
  </url>
  <url>
   <loc>None</loc>
-  <lastmod>2019-03-28</lastmod>
+  <lastmod>2019-04-01</lastmod>
   <changefreq>daily</changefreq>
  </url>
  <url>
   <loc>None</loc>
-  <lastmod>2019-03-28</lastmod>
+  <lastmod>2019-04-01</lastmod>
   <changefreq>daily</changefreq>
  </url>
  <url>
   <loc>None</loc>
-  <lastmod>2019-03-28</lastmod>
+  <lastmod>2019-04-01</lastmod>
   <changefreq>daily</changefreq>
  </url>
  <url>
   <loc>None</loc>
-  <lastmod>2019-03-28</lastmod>
+  <lastmod>2019-04-01</lastmod>
   <changefreq>daily</changefreq>
  </url>
  <url>
   <loc>None</loc>
-  <lastmod>2019-03-28</lastmod>
+  <lastmod>2019-04-01</lastmod>
   <changefreq>daily</changefreq>
  </url>
  <url>
   <loc>None</loc>
-  <lastmod>2019-03-28</lastmod>
+  <lastmod>2019-04-01</lastmod>
   <changefreq>daily</changefreq>
  </url>
  <url>
   <loc>None</loc>
-  <lastmod>2019-03-28</lastmod>
+  <lastmod>2019-04-01</lastmod>
   <changefreq>daily</changefreq>
  </url>
  <url>
   <loc>None</loc>
-  <lastmod>2019-03-28</lastmod>
+  <lastmod>2019-04-01</lastmod>
   <changefreq>daily</changefreq>
  </url>
  <url>
   <loc>None</loc>
-  <lastmod>2019-03-28</lastmod>
+  <lastmod>2019-04-01</lastmod>
   <changefreq>daily</changefreq>
  </url>
  <url>
   <loc>None</loc>
-  <lastmod>2019-03-28</lastmod>
+  <lastmod>2019-04-01</lastmod>
   <changefreq>daily</changefreq>
  </url>
  <url>
   <loc>None</loc>
-  <lastmod>2019-03-28</lastmod>
+  <lastmod>2019-04-01</lastmod>
   <changefreq>daily</changefreq>
  </url>
 </urlset>
\ No newline at end of file
@@ -14,7 +14,7 @@
 # python compress_classifier.py -a=resnet18 -p=10 -j=22 <path_to_imagenet_dataset> --pretrained --evaluate --quantize-eval --qe-config-file ../quantization/post_train_quant/resnet18_imagenet_post_train.yaml
 # (Note that when '--qe-config-file' is passed, all other '--qe*' arguments are ignored. Only the settings in the YAML file are used)
 #
-# Specifically, configuring with a YAML file allows us to define the 'bits_overrides' section, which is cumbersome
+# Specifically, configuring with a YAML file allows us to define the 'overrides' section, which is cumbersome
 # to define programatically and not exposed as a command-line argument.
 #
 # To illustrate how this may come in handy, we'll try post-training quantization of ResNet-18 using 6-bits for weights
@@ -24,7 +24,7 @@
 # FP32: (Run the above command line without the part starting with '--quantize-eval')
 # ==> Top1: 69.758 Top5: 89.076 Loss: 1.251
 #
-# All layers 6-bits: (Just comment out the entire 'bits_overrides' section)
+# All layers 6-bits: (Just comment out the entire 'overrides' section)
 # ==> Top1: 64.702 Top5: 85.990 Loss: 1.467
 #
 # First + last layer with 8-bits, everything else in 6-bits (comment out just the section with the header '.*add')
@@ -46,14 +46,14 @@ quantizers:
     per_channel_wts: True
     clip_acts: True
     no_clip_layers: fc
-    bits_overrides:
+    overrides:
       # First and last layers + element-wise add layers in 8-bits
       conv1:
-        wts: 8
-        acts: 8
+        bits_weights: 8
+        bits_activations: 8
       .*add:
-        wts: 8
-        acts: 8
+        bits_weights: 8
+        bits_activations: 8
       fc:
-        wts: 8
-        acts: 8
+        bits_weights: 8
+        bits_activations: 8
@@ -3,20 +3,20 @@ quantizers:
     class: DorefaQuantizer
     bits_activations: 8
     bits_weights: 3
-    bits_overrides:
+    overrides:
       # Don't quantize first and last layer
       features.0:
-        wts: null
-        acts: null
+        bits_weights: null
+        bits_activations: null
       features.1:
-        wts: null
-        acts: null
+        bits_weights: null
+        bits_activations: null
       classifier.5:
-        wts: null
-        acts: null
+        bits_weights: null
+        bits_activations: null
       classifier.6:
-        wts: null
-        acts: null
+        bits_weights: null
+        bits_activations: null

 lr_schedulers:
   training_lr:
...
@@ -3,20 +3,20 @@ quantizers:
     class: DorefaQuantizer
     bits_activations: 8
     bits_weights: 3
-    bits_overrides:
+    overrides:
      # Don't quantize first and last layer
       conv1:
-        wts: null
-        acts: null
+        bits_weights: null
+        bits_activations: null
       relu1:
-        wts: null
-        acts: null
+        bits_weights: null
+        bits_activations: null
       final_relu:
-        wts: null
-        acts: null
+        bits_weights: null
+        bits_activations: null
       fc:
-        wts: null
-        acts: null
+        bits_weights: null
+        bits_activations: null

 lr_schedulers:
   training_lr:
...
@@ -22,20 +22,20 @@ quantizers:
     act_clip_init_val: 8.0
     bits_activations: 4
     bits_weights: 3
-    bits_overrides:
+    overrides:
       # Don't quantize first and last layers
       conv1:
-        wts: null
-        acts: null
+        bits_weights: null
+        bits_activations: null
       layer1.0.pre_relu:
-        wts: null
-        acts: null
+        bits_weights: null
+        bits_activations: null
       final_relu:
-        wts: null
-        acts: null
+        bits_weights: null
+        bits_activations: null
       fc:
-        wts: null
-        acts: null
+        bits_weights: null
+        bits_activations: null

 lr_schedulers:
   training_lr:
...
@@ -34,20 +34,20 @@ quantizers:
     class: DorefaQuantizer
     bits_activations: 8
     bits_weights: 3
-    bits_overrides:
+    overrides:
       # Don't quantize first and last layer
       conv1:
-        wts: null
-        acts: null
+        bits_weights: null
+        bits_activations: null
       layer1.0.pre_relu:
-        wts: null
-        acts: null
+        bits_weights: null
+        bits_activations: null
       final_relu:
-        wts: null
-        acts: null
+        bits_weights: null
+        bits_activations: null
       fc:
-        wts: null
-        acts: null
+        bits_weights: null
+        bits_activations: null

 lr_schedulers:
   training_lr:
...