Commit 47175961 authored by Guy Jacob, committed by GitHub

Separate LinearQuantMode for weights/activations (#451)

In PostTrainLinearQuantizer and QuantAwareTrainRangeLinearQuantizer
parent 012417a5
@@ -58,6 +58,16 @@ class LinearQuantMode(Enum):
ASYMMETRIC_SIGNED = 3
class ModuleQuantMode(namedtuple('ModuleQuantMode', ['activations', 'weights'])):
"""
Named tuple for configuring the LinearQuantMode of both weights and activations of a module
"""
def __new__(cls, activations, weights):
if not isinstance(activations, LinearQuantMode) or not isinstance(weights, LinearQuantMode):
raise ValueError('ModuleQuantMode must receive LinearQuantMode values')
return super(ModuleQuantMode, cls).__new__(cls, activations, weights)
class ClipMode(Enum):
# No clipping - absolute min/max values will be used
NONE = 0
@@ -85,7 +95,15 @@ def _verify_enum_value(val, enum_cls):
def verify_quant_mode(mode):
return _verify_enum_value(mode, LinearQuantMode)
if isinstance(mode, ModuleQuantMode):
return mode
if isinstance(mode, dict):
acts = _verify_enum_value(mode['activations'], LinearQuantMode)
wts = _verify_enum_value(mode['weights'], LinearQuantMode)
else:
acts = wts = _verify_enum_value(mode, LinearQuantMode)
return ModuleQuantMode(acts, wts)
def verify_clip_mode(mode):
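To illustrate the new contract of `verify_quant_mode`, here is a minimal usage sketch covering the three accepted input forms (the `distiller.quantization.range_linear` module path is an assumption; the function and class names come from the diff above):

```python
from distiller.quantization.range_linear import (LinearQuantMode, ModuleQuantMode,
                                                 verify_quant_mode)

# A single LinearQuantMode is broadcast to both activations and weights
mode = verify_quant_mode(LinearQuantMode.SYMMETRIC)
assert mode.activations == mode.weights == LinearQuantMode.SYMMETRIC

# A dict (e.g. parsed from a YAML config) sets each tensor type separately
mode = verify_quant_mode({'activations': LinearQuantMode.ASYMMETRIC_UNSIGNED,
                          'weights': LinearQuantMode.SYMMETRIC})
assert isinstance(mode, ModuleQuantMode)

# An existing ModuleQuantMode passes through unchanged
mode = verify_quant_mode(ModuleQuantMode(LinearQuantMode.ASYMMETRIC_UNSIGNED,
                                         LinearQuantMode.SYMMETRIC))
assert mode.weights == LinearQuantMode.SYMMETRIC
```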
@@ -218,21 +236,28 @@ def linear_dequantize_with_metadata(t, inplace=False):
def add_post_train_quant_args(argparser):
str_to_quant_mode_map = {'sym': LinearQuantMode.SYMMETRIC,
'asym_s': LinearQuantMode.ASYMMETRIC_SIGNED,
'asym_u': LinearQuantMode.ASYMMETRIC_UNSIGNED}
str_to_clip_mode_map = {'none': ClipMode.NONE, 'avg': ClipMode.AVG, 'n_std': ClipMode.N_STD,
'gauss': ClipMode.GAUSS, 'laplace': ClipMode.LAPLACE}
def from_dict(d, val_str):
str_to_quant_mode_map = OrderedDict([
('sym', LinearQuantMode.SYMMETRIC),
('asym_s', LinearQuantMode.ASYMMETRIC_SIGNED),
('asym_u', LinearQuantMode.ASYMMETRIC_UNSIGNED)
])
str_to_clip_mode_map = OrderedDict([
('none', ClipMode.NONE), ('avg', ClipMode.AVG), ('n_std', ClipMode.N_STD),
('gauss', ClipMode.GAUSS), ('laplace', ClipMode.LAPLACE)
])
def from_dict(val_str, d, optional):
if not val_str and optional:
return None
try:
return d[val_str]
except KeyError:
raise argparse.ArgumentTypeError('Must be one of {0} (received {1})'.format(list(d.keys()), val_str))
linear_quant_mode_str = partial(from_dict, str_to_quant_mode_map)
clip_mode_str = partial(from_dict, str_to_clip_mode_map)
linear_quant_mode_str = partial(from_dict, d=str_to_quant_mode_map, optional=False)
linear_quant_mode_str_optional = partial(from_dict, d=str_to_quant_mode_map, optional=True)
clip_mode_str = partial(from_dict, d=str_to_clip_mode_map, optional=False)
group = argparser.add_argument_group('Arguments controlling quantization at evaluation time '
'("post-training quantization")')
@@ -240,7 +265,14 @@ def add_post_train_quant_args(argparser):
help='Apply linear quantization to model before evaluation. Applicable only if '
'--evaluate is also set')
group.add_argument('--qe-mode', '--qem', type=linear_quant_mode_str, default='sym',
help='Linear quantization mode. Choices: ' + ' | '.join(str_to_quant_mode_map.keys()))
help='Default linear quantization mode (for weights and activations). '
'Choices: ' + ' | '.join(str_to_quant_mode_map.keys()))
group.add_argument('--qe-mode-acts', '--qema', type=linear_quant_mode_str_optional, default=None,
help='Linear quantization mode for activations. Overrides --qe-mode. '
'Choices: ' + ' | '.join(str_to_quant_mode_map.keys()))
group.add_argument('--qe-mode-wts', '--qemw', type=linear_quant_mode_str_optional, default=None,
help='Linear quantization mode for weights. Overrides --qe-mode. '
'Choices: ' + ' | '.join(str_to_quant_mode_map.keys()))
group.add_argument('--qe-bits-acts', '--qeba', type=int, default=8, metavar='NUM_BITS',
help='Number of bits for quantization of activations. Use 0 to not quantize activations. '
'Default value is 8')
@@ -269,7 +301,7 @@ def add_post_train_quant_args(argparser):
stats_group.add_argument('--qe-calibration', type=distiller.utils.float_range_argparse_checker(exc_min=True),
metavar='PORTION_OF_TEST_SET', default=None,
help='Run the model in evaluation mode on the specified portion of the test dataset and '
'collect statistics. Ignores all other \'qe--*\' arguments')
'collect statistics')
stats_group.add_argument('--qe-config-file', type=str, metavar='PATH',
help='Path to YAML file containing configuration for PostTrainLinearQuantizer (if present, '
'all other --qe* arguments are ignored)')
@@ -283,7 +315,7 @@ class RangeLinearQuantWrapper(nn.Module):
wrapped_module (torch.nn.Module): Module to be wrapped
num_bits_acts (int): Number of bits used for inputs and output quantization
num_bits_accum (int): Number of bits allocated for the accumulator of intermediate integer results
mode (LinearQuantMode): Quantization mode to use (symmetric / asymmetric-signed / unsigned)
mode (ModuleQuantMode / LinearQuantMode): Quantization mode to use (symmetric / asymmetric-signed / unsigned)
clip_acts (ClipMode): Activations clipping mode to use
activation_stats (dict): Dict containing activation stats, used for static calculation of quantization
parameters. Dict should be in the format exported by distiller.data_loggers.QuantCalibrationStatsCollector.
@@ -305,13 +337,16 @@ class RangeLinearQuantWrapper(nn.Module):
input_overrides = input_overrides or OrderedDict()
mode = verify_quant_mode(mode)
self.mode = mode
self.wrapped_module = wrapped_module
self.clip_half_range = clip_half_range
self.scale_approx_mult_bits = scale_approx_mult_bits
self.requires_quantized_inputs = requires_quantized_inputs
self.inputs_quant_auto_fallback = inputs_quant_auto_fallback
self.output_quant_settings = QuantSettings(num_bits_acts, mode, clip_acts, clip_n_stds, clip_half_range, False)
self.output_quant_settings = QuantSettings(num_bits_acts, mode.activations, clip_acts, clip_n_stds,
clip_half_range, False)
self.accum_quant_settings = QuantSettings(num_bits_accum, LinearQuantMode.SYMMETRIC,
ClipMode.NONE, None, False, False)
@@ -349,7 +384,7 @@ class RangeLinearQuantWrapper(nn.Module):
# so other than inspecting the contents there's not much to do with it)
self._dequant_out = True
signed = mode != LinearQuantMode.ASYMMETRIC_UNSIGNED
signed = mode.activations != LinearQuantMode.ASYMMETRIC_UNSIGNED
self.acts_min_q_val, self.acts_max_q_val = get_quantized_range(num_bits_acts, signed=signed)
# The accumulator is always signed
self.accum_min_q_val, self.accum_max_q_val = get_quantized_range(num_bits_accum, signed=True)
@@ -372,8 +407,9 @@ class RangeLinearQuantWrapper(nn.Module):
else:
self.inputs_quant_metadata_fallback = None
scale, zp = _get_quant_params_from_stats_dict(activation_stats['output'], num_bits_acts, mode, clip_acts,
clip_n_stds, clip_half_range, scale_approx_mult_bits)
scale, zp = _get_quant_params_from_stats_dict(activation_stats['output'], num_bits_acts, mode.activations,
clip_acts, clip_n_stds, clip_half_range,
scale_approx_mult_bits)
self.register_buffer('output_scale', scale)
self.register_buffer('output_zero_point', zp)
else:
@@ -569,7 +605,7 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper):
num_bits_acts (int): Number of bits used for inputs and output quantization
num_bits_params (int): Number of bits used for parameters (weights and bias) quantization
num_bits_accum (int): Number of bits allocated for the accumulator of intermediate integer results
mode (LinearQuantMode): Quantization mode to use (symmetric / asymmetric-signed/unsigned)
mode (ModuleQuantMode / LinearQuantMode): Quantization mode to use (symmetric / asymmetric-signed/unsigned)
clip_acts (ClipMode): See RangeLinearQuantWrapper
per_channel_wts (bool): Enable quantization of weights using separate quantization parameters per
output channel
@@ -595,7 +631,8 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper):
# If activations are not quantized, we do fake quantization of the parameters, that is - quant and de-quant
self.fake_quant_params = self.output_quant_settings.num_bits is None
self.wts_quant_settings = QuantSettings(num_bits_params, mode, ClipMode.NONE, None, False, per_channel_wts)
self.wts_quant_settings = QuantSettings(num_bits_params, self.mode.weights, ClipMode.NONE, None, False,
per_channel_wts)
self.params_min_q_val, self.params_max_q_val = get_quantized_range(
self.wts_quant_settings.num_bits,
@@ -696,7 +733,8 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper):
# to the input and weights and pass those to the wrapped model. Functionally, since at this point we're
# dealing solely with integer values, the results are the same either way.
if self.output_quant_settings.quant_mode != LinearQuantMode.SYMMETRIC and not self.is_simulated_quant_weight_shifted:
if self.wts_quant_settings.quant_mode != LinearQuantMode.SYMMETRIC and \
not self.is_simulated_quant_weight_shifted:
# We "store" the w_zero_point inside our wrapped module's weights to
# improve performance on inference.
self.wrapped_module.weight.data += self.w_zero_point
@@ -755,7 +793,7 @@ class RangeLinearQuantMatmulWrapper(RangeLinearQuantWrapper):
wrapped_module (distiller.modules.Matmul or distiller.modules.BatchMatmul): Module to be wrapped
num_bits_acts (int): Number of bits used for inputs and output quantization
num_bits_accum (int): Number of bits allocated for the accumulator of intermediate integer results
mode (LinearQuantMode): Quantization mode to use (symmetric / asymmetric-signed/unsigned)
mode (ModuleQuantMode / LinearQuantMode): Quantization mode to use (symmetric / asymmetric-signed/unsigned)
clip_acts (ClipMode): See RangeLinearQuantWrapper
activation_stats (dict): See RangeLinearQuantWrapper
clip_n_stds (int): See RangeLinearQuantWrapper
@@ -967,13 +1005,16 @@ class RangeLinearEmbeddingWrapper(nn.Module):
super(RangeLinearEmbeddingWrapper, self).__init__()
mode = verify_quant_mode(mode)
self.mode = mode
self.min_q_val, self.max_q_val = get_quantized_range(num_bits,
signed=mode != LinearQuantMode.ASYMMETRIC_UNSIGNED)
signed=mode.weights != LinearQuantMode.ASYMMETRIC_UNSIGNED)
if stats is None:
w_scale, w_zero_point = _get_quant_params_from_tensor(wrapped_module.weight, num_bits, self.mode)
w_scale, w_zero_point = _get_quant_params_from_tensor(wrapped_module.weight, num_bits, mode.weights)
else:
w_scale, w_zero_point = _get_quant_params_from_stats_dict(stats['output'], num_bits, mode)
w_scale, w_zero_point = _get_quant_params_from_stats_dict(stats['output'], num_bits, mode.weights)
device = wrapped_module.weight.device
@@ -1057,7 +1098,7 @@ class PostTrainLinearQuantizer(Quantizer):
* distiller.modules.EltwiseMult
* distiller.modules.Matmul
* distiller.modules.BatchMatmul
An existing module will need likely need to be modified to use the 'distiller.modules.*' modules. This needs to
An existing module will likely need to be modified to use the 'distiller.modules.*' modules. This needs to
be done BEFORE creating the quantizer. See the docs for more details:
https://nervanasystems.github.io/distiller/prepare_model_quant.html
@@ -1068,7 +1109,7 @@ class PostTrainLinearQuantizer(Quantizer):
model (torch.nn.Module): Model to be quantized
bits_activations/parameters/accum (int): Number of bits to be used when quantizing each tensor type
overrides (:obj:`OrderedDict`, optional): Overrides the layers quantization settings.
mode (LinearQuantMode): Quantization mode to use (symmetric / asymmetric-signed / unsigned)
mode (ModuleQuantMode / LinearQuantMode): Quantization mode to use (symmetric / asymmetric-signed / unsigned)
clip_acts (ClipMode): Activations clipping mode to use
per_channel_wts (bool): Enable quantization of weights using separate quantization parameters per
output channel
@@ -1129,11 +1170,12 @@ class PostTrainLinearQuantizer(Quantizer):
" * Optimizations for quantization of layers followed by Relu/Tanh/Sigmoid are only "
"supported when statistics are used.\nEND WARNING\n")
mode_dict = {'activations': _enum_to_str(mode.activations), 'weights': _enum_to_str(mode.weights)}
self.model.quantizer_metadata = {'type': type(self),
'params': {'bits_activations': bits_activations,
'bits_parameters': bits_parameters,
'bits_accum': bits_accum,
'mode': str(mode).split('.')[1],
'mode': mode_dict,
'clip_acts': _enum_to_str(clip_acts),
'clip_n_stds': clip_n_stds,
'clip_half_range': clip_half_range,
@@ -1329,11 +1371,14 @@ class PostTrainLinearQuantizer(Quantizer):
[(layer, OrderedDict([('clip_acts', 'NONE')]))
for layer in args.qe_no_clip_layers]
)
mode_acts = args.qe_mode_acts or args.qe_mode
mode_wts = args.qe_mode_wts or args.qe_mode
mode = ModuleQuantMode(mode_acts, mode_wts)
return cls(model,
bits_activations=args.qe_bits_acts,
bits_parameters=args.qe_bits_wts,
bits_accum=args.qe_bits_accum,
mode=args.qe_mode,
mode=mode,
clip_acts=args.qe_clip_acts,
per_channel_wts=args.qe_per_channel,
model_activation_stats=(None if args.qe_dynamic else args.qe_stats_file),
@@ -1659,7 +1704,8 @@ class QuantAwareTrainRangeLinearQuantizer(Quantizer):
mode = verify_quant_mode(mode)
self.model.quantizer_metadata['params']['mode'] = str(mode).split('.')[1]
mode_dict = {'activations': _enum_to_str(mode.activations), 'weights': _enum_to_str(mode.weights)}
self.model.quantizer_metadata['params']['mode'] = mode_dict
self.model.quantizer_metadata['params']['ema_decay'] = ema_decay
self.model.quantizer_metadata['params']['per_channel_wts'] = per_channel_wts
self.model.quantizer_metadata['params']['quantize_inputs'] = quantize_inputs
@@ -15,37 +15,56 @@ Post-training quantization can either be configured straight from the command-li
| Long Form | Short | Description | Default |
|--------------------------|-----------|---------------------------------------------------------------------------------------|---------|
| `--quantize-eval` | `--qe` | Apply linear quantization to model before evaluation | Off |
| `--qe-mode` | `--qem` | Linear quantization mode. Choices: "sym", "asym_u", "asym_s" | "sym" |
| `--qe-mode` | `--qem` | Default linear quantization mode (for weights and activations). Choices: "sym", "asym_u", "asym_s" | "sym" |
| `--qe-mode-acts` | `--qema` | Linear quantization mode for activations. **Overrides `--qe-mode`**. Choices: "sym", "asym_u", "asym_s" | None |
| `--qe-mode-wts` | `--qemw` | Linear quantization mode for weights. **Overrides `--qe-mode`**. Choices: "sym", "asym_u", "asym_s" | None |
| `--qe-bits-acts` | `--qeba` | # of bits for quantization of activations. Use 0 to not quantize activations | 8 |
| `--qe-bits-wts` | `--qebw` | # of bits for quantization of weights. Use 0 to not quantize weights | 8 |
| `--qe-bits-accum` | N/A | # of bits for quantization of the accumulator | 32 |
| `--qe-clip-acts` | `--qeca` | Set activations clipping mode. Choices: "none", "avg", "n_std" | "none" |
| `--qe-clip-acts` | `--qeca` | Set activations clipping mode. Choices: "none", "avg", "n_std", "gauss", "laplace" | "none" |
| `--qe-clip-n-stds` | N/A | When qe-clip-acts is set to 'n_std', this is the number of standard deviations to use | None |
| `--qe-no-clip-layers` | `--qencl` | List of layer names (space-separated) for which not to clip activations | '' |
| `--qe-per-channel` | `--qepc` | Enable per-channel quantization of weights (per output channel) | Off |
| `--qe-scale-approx-bits` | `--qesab` | Enable scale factor approximation using integer multiply + bit shift, using this number of bits for the integer multiplier | None |
| `--qe-stats-file` | N/A | Use stats file for static quantization of activations. See details below | None |
| `--qe-dynamic` | N/A | Perform dynamic quantization. See details below | None |
| `--qe-config-file` | N/A | Path to YAML config file. See section above. (ignores all other --qe* arguments) | None |
(Note that these arguments can be added to any existing `argparse.ArgumentParser` by passing it to `distiller.quantization.add_post_train_quant_args()`)
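For scripts other than the bundled samples, these flags can be wired into an existing parser with a single call; a minimal sketch (only `add_post_train_quant_args` is part of the API documented here, the rest of the script is assumed):

```python
import argparse

import distiller.quantization

parser = argparse.ArgumentParser(description='my evaluation script')
parser.add_argument('--arch', default='resnet50')

# Adds the --quantize-eval / --qe-* argument group documented in the table above
distiller.quantization.add_post_train_quant_args(parser)

# e.g. asymmetric-unsigned activations with symmetric weights
args = parser.parse_args(['--quantize-eval',
                          '--qe-mode-acts', 'asym_u',
                          '--qe-mode-wts', 'sym'])
```

The resulting namespace is what the image classification sample hands to `PostTrainLinearQuantizer.from_args(model, args)` (the factory method updated in the diff above) to build the quantizer.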
### "Net-Aware" Quantization
## "Net-Aware" Quantization
The term "net-aware" quantization, coined in [this](https://arxiv.org/abs/1811.09886) paper from Facebook (section 3.2.2), means we can achieve better quantization by considering sequences of operations instead of just quantizing each operation independently. This isn't exactly layer fusion - in Distiller we modify activation stats prior to setting quantization parameters, in to make sure that when a module is followed by certain activation functions, only the relevant ranges are quantized. We do this for:
* **ReLU** - Clip all negative values
* **Tanh / Sigmoid** - Clip according to the (approximated) saturation values for these functions. We use [-4, 4] for tanh and [-6, 6] for sigmoid.
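The stats adjustment can be pictured with the following sketch. It is purely illustrative - the helper name and the `'min'`/`'max'` keys inside the stats dict are assumptions, not Distiller's actual implementation:

```python
# Illustrative only: clamp a layer's recorded output range according to the
# activation function that follows it (helper name and stats-dict keys assumed).
def adjust_output_stats_for_successor(layer_stats, successor):
    out = layer_stats['output']
    if successor == 'relu':
        out['min'] = max(out['min'], 0.)    # ReLU discards negative values
    elif successor == 'tanh':
        out['min'] = max(out['min'], -4.)   # approximate saturation range of tanh
        out['max'] = min(out['max'], 4.)
    elif successor == 'sigmoid':
        out['min'] = max(out['min'], -6.)   # approximate saturation range of sigmoid
        out['max'] = min(out['max'], 6.)
    return layer_stats
```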
### Static vs. Dynamic Quantization of Activations
## Static vs. Dynamic Quantization of Activations
Distiller supports both "static" and "dynamic" post-training quantization of **activations**.
* **Static Quantization:** Pre-calculated tensor statistics are used to calculate the quantization parameters.
* **Dynamic Quantization:** Quantization parameters are re-calculated for each batch.
**Support for this mode is limited**. It isn't as fully featured as static quantization, and the accuracy results obtained when using it are likely not as representative of real-world results.
Specifically:
* Only convolution, FC (aka Linear) and embedding layers are supported at this time. Non-supported layers are kept in FP32, and a warning is displayed.
* "Net-aware" quantization, described above, isn't supported in dynamic mode.
### Static Quantization
Pre-calculated tensor statistics are used to calculate the quantization parameters. A preliminary step of collecting these statistics is required. This step is commonly referred to as the **calibration step**.
#### Generating stats
To generate stats, use the `--qe-calibration <VAL>` command line argument. `VAL` should be a numeric value in the range \[0 .. 1\], indicating how much of the test dataset should be used to collect statistics. For example, passing 0.05 will use 5% of the test set. Stats are saved in a YAML file named `acts_quantization_stats.yaml` in the run directory.
* In the image classification sample, if both `--qe-calibration` and `--quantize-eval` are passed, calibration will be followed by model quantization in the same run. If only the calibration argument is passed, then the script will exit after the calibration step.
* **NOTE:** The image classification sample performs static quantization by default. That means that if a stats file isn't passed (see next section), then a calibration step will be executed prior to quantization, using 5% of the test set (equivalent to using `--qe-calibration 0.05`).
#### Using previously generated stats
In most cases, there is no need to re-run calibration each time we quantize a model. A previously generated stats file can be passed via `--qe-stats-file <path_to_yaml_stats_file>`. This will skip the calibration step.
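Programmatically, static quantization boils down to pointing the quantizer at the stats file. A minimal sketch, assuming the constructor arguments documented above and that `prepare_model` accepts a dummy input of the right shape:

```python
import torch
import torchvision.models as models

from distiller.quantization import LinearQuantMode, PostTrainLinearQuantizer

model = models.resnet50(pretrained=True)
quantizer = PostTrainLinearQuantizer(
    model,
    bits_activations=8, bits_parameters=8, bits_accum=32,
    mode=LinearQuantMode.ASYMMETRIC_UNSIGNED,
    per_channel_wts=True,
    # Static quantization: scale / zero-point are pre-computed from this stats file.
    # Passing None here instead falls back to dynamic quantization.
    model_activation_stats='acts_quantization_stats.yaml')
quantizer.prepare_model(torch.zeros(1, 3, 224, 224))
# The model is now wrapped with RangeLinear* modules and ready for evaluation
```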
### Dynamic Quantization
Quantization parameters are re-calculated for each batch.
**Support for this mode is limited**. It isn't as fully featured as static quantization, and the accuracy results obtained when using it are likely not as representative of real-world results. Specifically:
* Only convolution, FC (aka Linear) and embedding layers are supported at this time. Non-supported layers are kept in FP32, and a warning is displayed.
* "Net-aware" quantization, described above, isn't supported in dynamic mode.
## Sample Invocations
@@ -57,7 +76,7 @@ cd <distiller_root>/examples/classifier_compression
All the paths used are relative to this directory.
All the examples below are using **static quantization** of activations, which means the first step is to collect activation statistics for the FP32 model. We will use the sample stats file located at:
All the examples below are using **static quantization** of activations. As discussed above, to avoid running a calibration step each time, we'll use a pre-generated stats file located at:
```
<distiller_root>/examples/quantization/post_train_quant/stats/resnet50_quant_stats.yaml
@@ -86,26 +105,28 @@ This table summarizes the settings and results for each run. The command lines f
| 9 | Asymmetric | 6 | 6 | No | none | 62.230% |
| 10 | Asymmetric | 6 | 6 | Yes | avg (exc. last layer) | 74.494% |
(Note that it's possible to define symmetric/asymmetric mode separately for weights and activations using `--qe-mode-wts` and `--qe-mode-acts`, respectively. For brevity and simplicity, here we use a single setting for both via the `--qe-mode` flag)
Command lines:
| | Command Line |
|----|--------------|
| 1 | `python compress_classifier.py -a resnet50 --pretrained <path_to_imagenet_dataset> --evaluate`
| 2 | `python compress_classifier.py -a resnet50 --pretrained <path_to_imagenet_dataset> --evaluate --quantize-eval --qe-stats-file ../quantization/post_train_quant/stats/resnet50_quant_stats.yaml`
| 3 | `python compress_classifier.py -a resnet50 --pretrained <path_to_imagenet_dataset> --evaluate --quantize-eval --qe-per-channel --qe-stats-file ../quantization/post_train_quant/stats/resnet50_quant_stats.yaml`
| 4 | `python compress_classifier.py -a resnet50 --pretrained <path_to_imagenet_dataset> --evaluate --quantize-eval --qe-per-channel --qe-clip-acts avg --qe-stats-file ../quantization/post_train_quant/stats/resnet50_quant_stats.yaml`
| 5 | `python compress_classifier.py -a resnet50 --pretrained <path_to_imagenet_dataset> --evaluate --quantize-eval --qe-per-channel --qe-clip-acts avg --qe-no-clip-layers fc --qe-stats-file ../quantization/post_train_quant/stats/resnet50_quant_stats.yaml`
| 2 | `python compress_classifier.py -a resnet50 --pretrained <path_to_imagenet_dataset> --evaluate --quantize-eval --qe-mode sym --qe-stats-file ../quantization/post_train_quant/stats/resnet50_quant_stats.yaml`
| 3 | `python compress_classifier.py -a resnet50 --pretrained <path_to_imagenet_dataset> --evaluate --quantize-eval --qe-mode sym --qe-per-channel --qe-stats-file ../quantization/post_train_quant/stats/resnet50_quant_stats.yaml`
| 4 | `python compress_classifier.py -a resnet50 --pretrained <path_to_imagenet_dataset> --evaluate --quantize-eval --qe-mode sym --qe-per-channel --qe-clip-acts avg --qe-stats-file ../quantization/post_train_quant/stats/resnet50_quant_stats.yaml`
| 5 | `python compress_classifier.py -a resnet50 --pretrained <path_to_imagenet_dataset> --evaluate --quantize-eval --qe-mode sym --qe-per-channel --qe-clip-acts avg --qe-no-clip-layers fc --qe-stats-file ../quantization/post_train_quant/stats/resnet50_quant_stats.yaml`
| 6 | `python compress_classifier.py -a resnet50 --pretrained <path_to_imagenet_dataset> --evaluate --quantize-eval --qe-mode asym_u --qe-stats-file ../quantization/post_train_quant/stats/resnet50_quant_stats.yaml`
| 7 | `python compress_classifier.py -a resnet50 --pretrained <path_to_imagenet_dataset> --evaluate --quantize-eval --qe-mode asym_u --qe-per-channel --qe-clip-acts avg --qe-no-clip-layers fc --qe-stats-file ../quantization/post_train_quant/stats/resnet50_quant_stats.yaml`
| 8 | `python compress_classifier.py -a resnet50 --pretrained <path_to_imagenet_dataset> --evaluate --quantize-eval --qe-bits-acts 6 --qe-bits-wts 6 --qe-stats-file ../quantization/post_train_quant/stats/resnet50_quant_stats.yaml`
| 8 | `python compress_classifier.py -a resnet50 --pretrained <path_to_imagenet_dataset> --evaluate --quantize-eval --qe-bits-acts 6 --qe-bits-wts 6 --qe-mode sym --qe-stats-file ../quantization/post_train_quant/stats/resnet50_quant_stats.yaml`
| 9 | `python compress_classifier.py -a resnet50 --pretrained <path_to_imagenet_dataset> --evaluate --quantize-eval --qe-bits-acts 6 --qe-bits-wts 6 --qe-mode asym_u --qe-stats-file ../quantization/post_train_quant/stats/resnet50_quant_stats.yaml`
| 10 | `python compress_classifier.py -a resnet50 --pretrained <path_to_imagenet_dataset> --evaluate --quantize-eval --qe-bits-acts 6 --qe-bits-wts 6 --qe-mode asym_u --qe-per-channel --qe-clip-acts avg --qe-no-clip-layers fc --qe-stats-file ../quantization/post_train_quant/stats/resnet50_quant_stats.yaml`
## Note 1: Accuracy Loss When Clipping Activations
### Note 1: Accuracy Loss When Clipping Activations
Notice the degradation in accuracy in run (4) - ~2.6% compared to per-channel without clipping. Let's recall that the output of the final layer of the model holds the "score" of each class (which, since we're using softmax, can be interpreted as the un-normalized log probability of each class). So if we clip the outputs of this layer, we're in fact "cutting off" the highest (and lowest) scores. If the highest scores for some sample are close enough, this can result in a wrong classification of that sample.
We can provide Distiller with a list of layers for which not to clip activations. Here we just want to skip the last layer, which in the ResNet-50 model is called `fc`. This is what we do in run (5), and we regain most of the accuracy.
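For reference, the `--qe-no-clip-layers fc` flag translates into an `overrides` entry like the one built in the diff above. A minimal sketch of doing the same directly in Python (string forms of the enums are assumed to be accepted, as in the YAML configs):

```python
from collections import OrderedDict

import torchvision.models as models

from distiller.quantization import LinearQuantMode, PostTrainLinearQuantizer

# Equivalent of "--qe-clip-acts avg --qe-no-clip-layers fc": clip activations
# everywhere except the final classifier, whose full score range is preserved.
overrides = OrderedDict([('fc', OrderedDict([('clip_acts', 'NONE')]))])

model = models.resnet50(pretrained=True)
quantizer = PostTrainLinearQuantizer(
    model,
    mode=LinearQuantMode.SYMMETRIC,
    clip_acts='AVG',
    overrides=overrides,
    model_activation_stats='../quantization/post_train_quant/stats/resnet50_quant_stats.yaml')
# ...then quantizer.prepare_model(dummy_input) and evaluate, as before
```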
## Note 2: Under 8-bits
### Note 2: Under 8-bits
Runs (8) - (10) are examples of trying post-training quantization below 8-bits. Notice how with the most basic settings we get a massive accuracy loss of ~53%. Even with asymmetric quantization and all other optimizations enabled, we still get a non-trivial degradation of just under 2% vs. FP32. In many cases, quantizing with less than 8-bits requires quantization-aware training. However, if we allow some layers to remain in 8-bit, we can regain some of the accuracy. We can do this by using a YAML configuration file and specifying overrides. As mentioned at the top of this document, check out the `resnet18_imagenet_post_train.yaml` file located in this directory for an example of how to do this.
@@ -53,7 +53,16 @@ quantizers:
bits_activations: 6
bits_parameters: 6
bits_accum: 32
# Quantization mode can be defined either with a single value for both weights and activations, or with
# a nested dictionary specifying weights and activations separately.
# All the results in the table above are using ASYMMETRIC for both weights and activations.
mode: ASYMMETRIC_UNSIGNED
# Example of mixed definition:
# mode:
# activations: ASYMMETRIC_UNSIGNED
# weights: SYMMETRIC
# Path to stats file assuming this is being invoked from the 'classifier_compression' example directory
model_activation_stats: ../quantization/post_train_quant/stats/resnet18_quant_stats.yaml
per_channel_wts: True
@@ -37,7 +37,15 @@ quantizers:
bits_activations: 8
bits_parameters: 8
bits_accum: 32
# Quantization mode can be defined either with a single value for both weights and activations, or with
# a nested dictionary specifying weights and activations separately
mode: ASYMMETRIC_UNSIGNED
# Example of mixed definition
# mode:
# activations: ASYMMETRIC_UNSIGNED
# weights: SYMMETRIC
# Path to stats file assuming this is being invoked from the 'classifier_compression' example directory
model_activation_stats: ../quantization/post_train_quant/stats/resnet18_quant_stats.yaml
per_channel_wts: True