diff --git a/distiller/quantization/range_linear.py b/distiller/quantization/range_linear.py
index 87f85466a54ef1f67b1f435edaf8c62e57d3384b..6e3fdb18518414d00025f72186b0a58070dd0ed3 100644
--- a/distiller/quantization/range_linear.py
+++ b/distiller/quantization/range_linear.py
@@ -242,9 +242,11 @@ def add_post_train_quant_args(argparser):
     group.add_argument('--qe-mode', '--qem', type=linear_quant_mode_str, default='sym',
                        help='Linear quantization mode. Choices: ' + ' | '.join(str_to_quant_mode_map.keys()))
     group.add_argument('--qe-bits-acts', '--qeba', type=int, default=8, metavar='NUM_BITS',
-                       help='Number of bits for quantization of activations')
+                       help='Number of bits for quantization of activations. Use 0 to not quantize activations. '
+                            'Default value is 8')
     group.add_argument('--qe-bits-wts', '--qebw', type=int, default=8, metavar='NUM_BITS',
-                       help='Number of bits for quantization of weights')
+                       help='Number of bits for quantization of weights. Use 0 to not quantize weights. '
+                            'Default value is 8')
     group.add_argument('--qe-bits-accum', type=int, default=32, metavar='NUM_BITS',
                        help='Number of bits for quantization of the accumulator')
     group.add_argument('--qe-clip-acts', '--qeca', type=clip_mode_str, default='none',
@@ -312,6 +314,15 @@ class RangeLinearQuantWrapper(nn.Module):
         self.output_quant_settings = QuantSettings(num_bits_acts, mode, clip_acts, clip_n_stds, clip_half_range, False)
         self.accum_quant_settings = QuantSettings(num_bits_accum, LinearQuantMode.SYMMETRIC,
                                                   ClipMode.NONE, None, False, False)
+
+        self.preset_act_stats = False
+        self.register_buffer('num_forwards', torch.zeros(1, dtype=torch.long))
+
+        # Activations not quantized - stop here
+        if num_bits_acts is None:
+            return
+
+        # Activations are quantized - setup quantization parameters
         if self.requires_quantized_inputs:
             self.inputs_quant_settings_overrides = OrderedDict()
             for k, v in input_overrides.items():
@@ -320,13 +331,12 @@ class RangeLinearQuantWrapper(nn.Module):
                     quant_settings = deepcopy(self.output_quant_settings)
                     quant_settings.clip_half_range = False
                 else:
-                    quant_settings = QuantSettings(v.pop('bits_activations', self.output_quant_settings.num_bits),
-                                                   verify_quant_mode(
-                                                       v.pop('mode', self.output_quant_settings.quant_mode)),
-                                                   verify_clip_mode(
-                                                       v.pop('clip_acts', self.output_quant_settings.clip_mode)),
-                                                   v.pop('clip_n_stds', self.output_quant_settings.clip_n_stds),
-                                                   False, False)
+                    quant_settings = QuantSettings(
+                        v.pop('bits_activations', self.output_quant_settings.num_bits),
+                        verify_quant_mode(v.pop('mode', self.output_quant_settings.quant_mode)),
+                        verify_clip_mode(v.pop('clip_acts', self.output_quant_settings.clip_mode)),
+                        v.pop('clip_n_stds', self.output_quant_settings.clip_n_stds),
+                        False, False)
                 if v:  # Poor man's input checking on input overrides dict
                     raise ValueError('Input overrides dict contains unsupported keys:', list(v.keys()))
@@ -369,10 +379,8 @@ class RangeLinearQuantWrapper(nn.Module):
         else:
             self.preset_act_stats = False
 
-        self.register_buffer('num_forwards', torch.zeros(1, dtype=torch.long))
-
     def named_acts_quant_params(self):
-        if self.preset_act_stats:
+        if self.output_quant_settings.num_bits is not None and self.preset_act_stats:
             # Output scale buffers are saved in the model only when stats are used
             yield 'output_scale', self.output_scale
             yield 'output_zero_point', self.output_zero_point
@@ -381,6 +389,13 @@ class RangeLinearQuantWrapper(nn.Module):
         if self.training:
             raise RuntimeError(self.__class__.__name__ + " can only be used in eval mode")
 
+        if self.output_quant_settings.num_bits is None:
+            # Pass through
+            out = self.wrapped_module(*inputs)
+            if self.clip_half_range:
+                out = f.relu(out)
+            return out
+
         device = inputs[0].device
         for buffer_name, buffer in self._buffers.items():
             setattr(self, buffer_name, buffer.to(device))
@@ -498,6 +513,9 @@ class RangeLinearQuantWrapper(nn.Module):
         raise NotImplementedError
 
     def extra_repr(self):
+        if self.output_quant_settings.num_bits is None:
+            return 'output_quant_settings=Not_Quantized'
+
         tmpstr = 'output_quant_settings={0}'.format(self.output_quant_settings)
         tmpstr += '\naccum_quant_settings={0}'.format(self.accum_quant_settings)
         tmpstr += '\nrequires_quantized_inputs={0}'.format(self.requires_quantized_inputs)
@@ -574,6 +592,9 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper):
         if not isinstance(wrapped_module, (nn.Conv2d, nn.Conv3d, nn.Linear)):
             raise ValueError(self.__class__.__name__ + ' can wrap only Conv2D, Conv3D and Linear modules')
 
+        # If activations are not quantized, we do fake quantization of the parameters, that is - quant and de-quant
+        self.fake_quant_params = self.output_quant_settings.num_bits is None
+
         self.wts_quant_settings = QuantSettings(num_bits_params, mode, ClipMode.NONE, None, False, per_channel_wts)
 
         self.params_min_q_val, self.params_max_q_val = get_quantized_range(
@@ -591,20 +612,9 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper):
         linear_quantize_clamp(wrapped_module.weight.data, self.w_scale, self.w_zero_point,
                               self.params_min_q_val, self.params_max_q_val, inplace=True)
 
-        self.has_bias = hasattr(wrapped_module, 'bias') and wrapped_module.bias is not None
-        device = self.w_scale.device
-
-        if self.preset_act_stats:
-            t = torch.zeros_like(self.w_scale)
-            if self.wts_quant_settings.per_channel:
-                t = t.squeeze(dim=-1)
-            self.register_buffer('accum_scale', t)
-        else:
-            self.accum_scale = torch.ones(1).to(device)
-
         # Quantize bias
         self.has_bias = hasattr(wrapped_module, 'bias') and wrapped_module.bias is not None
-        if self.has_bias and not self.preset_act_stats:
+        if self.has_bias and (self.fake_quant_params or not self.preset_act_stats):
             b_scale, b_zero_point = _get_quant_params_from_tensor(wrapped_module.bias,
                                                                   self.wts_quant_settings.num_bits,
                                                                   self.wts_quant_settings.quant_mode)
@@ -612,8 +622,27 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper):
             self.register_buffer('b_zero_point', b_zero_point)
             base_b_q = linear_quantize_clamp(wrapped_module.bias.data, self.b_scale, self.b_zero_point,
                                              self.params_min_q_val, self.params_max_q_val)
-            # Dynamic ranges - save in auxiliary buffer, requantize each time based on dynamic input scale factor
-            self.register_buffer('base_b_q', base_b_q)
+            if not self.preset_act_stats:
+                # Dynamic ranges - save in auxiliary buffer,
+                # requantize each time based on dynamic input scale factor
+                self.register_buffer('base_b_q', base_b_q)
+
+        # Activations not quantized - de-quant parameters and return
+        if self.fake_quant_params:
+            linear_dequantize(wrapped_module.weight.data, self.w_scale, self.w_zero_point, inplace=True)
+            if self.has_bias:
+                wrapped_module.bias = torch.nn.Parameter(linear_dequantize(base_b_q, self.b_scale, self.b_zero_point))
+            return
+
+        # Activations are quantized - setup accumulator quantization parameters
+        device = self.w_scale.device
+        if self.preset_act_stats:
+            t = torch.zeros_like(self.w_scale)
+            if self.wts_quant_settings.per_channel:
+                t = t.squeeze(dim=-1)
+            self.register_buffer('accum_scale', t)
+        else:
+            self.accum_scale = torch.ones(1).to(device)
 
         # A flag indicating that the simulated quantized weights are pre-shifted. for faster performance.
         # In the first forward pass - `w_zero_point` is added into the weights, to allow faster inference,
@@ -623,7 +652,7 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper):
         self.register_buffer('is_simulated_quant_weight_shifted', torch.tensor(0, dtype=torch.uint8, device=device))
 
     def state_dict(self, destination=None, prefix='', keep_vars=False):
-        if self.is_simulated_quant_weight_shifted:
+        if not self.fake_quant_params and self.is_simulated_quant_weight_shifted:
             # We want to return the weights to their integer representation:
             self.wrapped_module.weight.data -= self.w_zero_point
             self.is_simulated_quant_weight_shifted.fill_(False)  # i.e. is_simulated_quant_weight_shifted = False
@@ -920,6 +949,11 @@ class FPWrapper(nn.Module):
 
         return result
 
+    def extra_repr(self):
+        tmpstr = 'float_dtype={}, convert_input={}, return_fp32={}'.format(self.dtype, self.convert_input,
+                                                                           self.return_fp32)
+        return tmpstr
+
 
 class FP16Wrapper(FPWrapper):
     def __init__(self, module, convert_input=True, return_fp32=True):
@@ -993,6 +1027,12 @@ class RangeLinearFakeQuantWrapper(RangeLinearQuantWrapper):
     def get_accum_to_output_re_quantization_params(self, output_scale, output_zero_point):
         return output_scale, output_zero_point
 
+    def extra_repr(self):
+        tmpstr = super(RangeLinearFakeQuantWrapper, self).extra_repr()
+        if self.dtype:
+            tmpstr += '\nwrapped_module_float_dtype={}.'.format(self.dtype)
+        return tmpstr
+
 
 _ptq_wrappers_int_only = (RangeLinearQuantWrapper, RangeLinearEmbeddingWrapper)
 _ptq_wrappers_all = _ptq_wrappers_int_only + (FPWrapper,)
@@ -1104,31 +1144,40 @@ class PostTrainLinearQuantizer(Quantizer):
                                                    'model_activation_stats': model_activation_stats,
                                                    'overrides': overrides_bkp}}
 
-        def replace_param_layer(module, name, qbits_map, per_channel_wts=per_channel_wts,
-                                mode=mode, fp16=fp16, scale_approx_mult_bits=scale_approx_mult_bits,
-                                clip_acts=clip_acts, clip_n_stds=clip_n_stds, clip_half_range=clip_half_range,
-                                input_overrides=None, fpq_module=fpq_module, fake=False):
+        def _check_fp16_arg(fp16, fpq_module):
             if fp16:
                 warnings.warn("Argument 'fp16' is deprecated. Please use 'fpq_module'(=16/32/64) argument.",
                               DeprecationWarning)
                 fpq_module = fpq_module or 16
+            return fpq_module
+
+        def replace_param_layer(module, name, qbits_map, per_channel_wts=per_channel_wts,
+                                mode=mode, fp16=fp16, scale_approx_mult_bits=scale_approx_mult_bits,
+                                clip_acts=clip_acts, clip_n_stds=clip_n_stds, clip_half_range=clip_half_range,
+                                input_overrides=None, fpq_module=fpq_module, fake=False):
+            fpq_module = _check_fp16_arg(fp16, fpq_module)
+            if fpq_module and not fake:
+                return FPWrapper(module, fpq_module)
+
             norm_name = distiller.utils.normalize_module_name(name)
+            activation_stats = self.model_activation_stats.get(norm_name, None)
             clip_acts = verify_clip_mode(clip_acts)
-            if fpq_module:
-                if not fake:
-                    return FPWrapper(module, fpq_module)
-                else:
-                    return RangeLinearFakeQuantWrapper(module, qbits_map[name].acts, mode=mode, clip_acts=clip_acts,
-                                                       activation_stats=self.model_activation_stats.get(norm_name,
-                                                                                                         None),
-                                                       clip_n_stds=clip_n_stds, clip_half_range=clip_half_range,
-                                                       scale_approx_mult_bits=scale_approx_mult_bits,
-                                                       fpq_module=fpq_module)
-
-            return RangeLinearQuantParamLayerWrapper(module, qbits_map[name].acts, qbits_map[name].wts,
+
+            qbits = qbits_map[name]
+            if qbits.acts is not None and qbits.wts is None:
+                # Quantizing only activations equals fake-quantization
+                fake = True
+
+            if fake:
+                return RangeLinearFakeQuantWrapper(module, qbits.acts, mode=mode, clip_acts=clip_acts,
+                                                   activation_stats=activation_stats,
+                                                   clip_n_stds=clip_n_stds, clip_half_range=clip_half_range,
+                                                   scale_approx_mult_bits=scale_approx_mult_bits,
+                                                   fpq_module=fpq_module)
+
+            return RangeLinearQuantParamLayerWrapper(module, qbits.acts, qbits.wts,
                                                      num_bits_accum=self.bits_accum, mode=mode, clip_acts=clip_acts,
                                                      per_channel_wts=per_channel_wts,
-                                                     activation_stats=self.model_activation_stats.get(norm_name, None),
+                                                     activation_stats=activation_stats,
                                                      clip_n_stds=clip_n_stds, clip_half_range=clip_half_range,
                                                      scale_approx_mult_bits=scale_approx_mult_bits,
                                                      input_overrides=input_overrides,
@@ -1139,36 +1188,37 @@ class PostTrainLinearQuantizer(Quantizer):
                                     clip_acts=clip_acts, clip_n_stds=clip_n_stds, clip_half_range=clip_half_range,
                                     input_overrides=None, inputs_quant_auto_fallback=inputs_quant_auto_fallback,
                                     fpq_module=fpq_module, fake=False):
+            fpq_module = _check_fp16_arg(fp16, fpq_module)
+            if fpq_module and not fake:
+                return FPWrapper(module, fpq_module)
+
             norm_name = distiller.utils.normalize_module_name(name)
+            activation_stats = self.model_activation_stats.get(norm_name, None)
             clip_acts = verify_clip_mode(clip_acts)
-            if fp16:
-                warnings.warn("Argument 'fp16' is deprecated. Please use 'fpq_module'(=16/32/64) argument.",
-                              DeprecationWarning)
-                fpq_module = fpq_module or 16
-            if fpq_module:
-                if fake:
-                    return RangeLinearFakeQuantWrapper(module, qbits_map[name].acts, mode=mode, clip_acts=clip_acts,
-                                                       activation_stats=self.model_activation_stats.get(norm_name, None),
-                                                       clip_n_stds=clip_n_stds, clip_half_range=clip_half_range,
-                                                       scale_approx_mult_bits=scale_approx_mult_bits,
-                                                       fpq_module=fpq_module)
-                else:
-                    return FPWrapper(module, fpq_module)
+            qbits = qbits_map[name]
+
+            if fake:
+                return RangeLinearFakeQuantWrapper(module, qbits.acts, mode=mode, clip_acts=clip_acts,
+                                                   activation_stats=activation_stats,
+                                                   clip_n_stds=clip_n_stds, clip_half_range=clip_half_range,
+                                                   scale_approx_mult_bits=scale_approx_mult_bits,
+                                                   fpq_module=fpq_module)
             try:
-                return wrapper_type(module, qbits_map[name].acts, mode=mode, clip_acts=clip_acts,
-                                    activation_stats=self.model_activation_stats.get(norm_name, None),
-                                    clip_n_stds=clip_n_stds, clip_half_range=clip_half_range,
+                return wrapper_type(module, qbits.acts, mode=mode, clip_acts=clip_acts,
+                                    activation_stats=activation_stats,
+                                    clip_n_stds=clip_n_stds, clip_half_range=clip_half_range,
                                     scale_approx_mult_bits=scale_approx_mult_bits, input_overrides=input_overrides,
                                     inputs_quant_auto_fallback=inputs_quant_auto_fallback)
             except NoStatsError:
-                msglogger.warning('WARNING: {0} - quantization of {1} without stats not supported. '
-                                  'Keeping the original FP32 module'.format(name, module.__class__.__name__))
+                warnings.warn('WARNING: {0} - quantization of {1} without stats not supported. '
+                              'Keeping the original FP32 module'.format(name, module.__class__.__name__), UserWarning)
                 return module
 
-        def replace_embedding(module, name, qbits_map, fp16=fp16):
-            if fp16:
-                return FP16Wrapper(module, convert_input=False)
+        def replace_embedding(module, name, qbits_map, fp16=fp16, fpq_module=fpq_module):
+            fpq_module = _check_fp16_arg(fp16, fpq_module)
+            if fpq_module:
+                return FPWrapper(module, fpq_module, convert_input=False)
             norm_name = distiller.utils.normalize_module_name(name)
             return RangeLinearEmbeddingWrapper(module, qbits_map[name].wts, mode=mode,
                                                stats=self.model_activation_stats.get(norm_name, None))
 
@@ -1182,16 +1232,16 @@ class PostTrainLinearQuantizer(Quantizer):
                 pred = self.adjacency_map[name].predecessors[0].name
                 if isinstance(named_modules[pred], RangeLinearQuantWrapper):
                     return nn.Identity()
-            norm_name = distiller.utils.normalize_module_name(name)
-            clip_acts = verify_clip_mode(clip_acts)
+
             if distiller.has_children(module):
                 return module
-            if fp16:
-                warnings.warn("Argument 'fp16' is deprecated. Please use 'fpq_module'(=16/32/64) argument.",
-                              DeprecationWarning)
-                fpq_module = 16
+
+            fpq_module = _check_fp16_arg(fp16, fpq_module)
             if not fake:
                 return FPWrapper(module, fpq_module)
+
+            norm_name = distiller.utils.normalize_module_name(name)
+            clip_acts = verify_clip_mode(clip_acts)
             return RangeLinearFakeQuantWrapper(module, qbits_map[name].acts, mode=mode, clip_acts=clip_acts,
                                                activation_stats=self.model_activation_stats.get(norm_name, None),
                                                clip_n_stds=clip_n_stds, clip_half_range=clip_half_range,
@@ -1271,6 +1321,10 @@ class PostTrainLinearQuantizer(Quantizer):
             return distiller.config_component_from_file_by_class(model, args.qe_config_file,
                                                                   'PostTrainLinearQuantizer')
         else:
+            if args.qe_bits_acts == 0:
+                args.qe_bits_acts = None
+            if args.qe_bits_wts == 0:
+                args.qe_bits_wts = None
             overrides = OrderedDict(
                 [(layer, OrderedDict([('clip_acts', 'NONE')]))
                  for layer in args.qe_no_clip_layers]
@@ -1496,8 +1550,6 @@ class PostTrainLinearQuantizer(Quantizer):
                 buffer.data = buffer.data.to(device)
 
 
-
-
 ###############################################################################
 # Quantization-aware training
 ###############################################################################
diff --git a/examples/quantization/post_train_quant/command_line.md b/examples/quantization/post_train_quant/command_line.md
index 061d78f0aa1cddd8b99bb374370dcc124f4f4946..f2eaa4fb37519291c5fb764feb4874abe3d2e017 100644
--- a/examples/quantization/post_train_quant/command_line.md
+++ b/examples/quantization/post_train_quant/command_line.md
@@ -16,8 +16,8 @@ Post-training quantization can either be configured straight from the command-li
 |--------------------------|-----------|---------------------------------------------------------------------------------------|---------|
 | `--quantize-eval`        | `--qe`    | Apply linear quantization to model before evaluation                                   | Off     |
 | `--qe-mode`              | `--qem`   | Linear quantization mode. Choices: "sym", "asym_u", "asym_s"                            | "sym"   |
-| `--qe-bits-acts`         | `--qeba`  | # of bits for quantization of activations                                               | 8       |
-| `--qe-bits-wts`          | `--qebw`  | # of bits for quantization of weights                                                   | 8       |
+| `--qe-bits-acts`         | `--qeba`  | # of bits for quantization of activations. Use 0 to not quantize activations            | 8       |
+| `--qe-bits-wts`          | `--qebw`  | # of bits for quantization of weights. Use 0 to not quantize weights                    | 8       |
 | `--qe-bits-accum`        | N/A       | # of bits for quantization of the accumulator                                           | 32      |
 | `--qe-clip-acts`         | `--qeca`  | Set activations clipping mode. Choices: "none", "avg", "n_std"                          | "none"  |
Choices: "none", "avg", "n_std" | "none" | | `--qe-clip-n-stds` | N/A | When qe-clip-acts is set to 'n_std', this is the number of standard deviations to use | None | diff --git a/tests/test_post_train_quant.py b/tests/test_post_train_quant.py index a30524f53604889abc34fa7a9be793ce81b60fa8..d54ea330a14c8b898e83819d663074fcee0dcdd4 100644 --- a/tests/test_post_train_quant.py +++ b/tests/test_post_train_quant.py @@ -743,3 +743,64 @@ def test_acts_quant_params_rnn(rnn_model): quantizer.update_acts_quant_params(new_config) assert model.rnn.rnn.cells[0].act_o.output_scale == 4 assert model.embedding.w_scale == 59.0 + + +############################################################################### +# Test wrappers with weights-only quantization +############################################################################### +@pytest.fixture(params=[False, True], ids=['perch_off', 'perch_on']) +def per_channel(request): + return request.param + + +@pytest.fixture(params=[False, True], ids=['no_bias', 'with_bias']) +def bias(request): + return request.param + + +def _fake_quant_tensor(tensor, n_bits, mode, per_channel): + q_min, q_max = q_utils.get_quantized_range(n_bits, mode != LinearQuantMode.ASYMMETRIC_UNSIGNED) + scale, zp = _get_quant_params_from_tensor(tensor, n_bits, mode, per_channel=per_channel) + q_utils.linear_quantize_clamp(tensor, scale, zp, q_min, q_max, inplace=True) + q_utils.linear_dequantize(tensor, scale, zp, inplace=True) + + +def _test_wts_only_quant(layer, x, per_channel, bias, num_bits): + layer.weight.data = torch.rand_like(layer.weight) + if bias: + layer.bias.data = torch.rand_like(layer.bias) + mode = LinearQuantMode.ASYMMETRIC_UNSIGNED + + layer_ptq = RangeLinearQuantParamLayerWrapper(deepcopy(layer), None, num_bits, mode=mode, per_channel_wts=per_channel) + layer_ptq.eval() + + layer_manual_q = deepcopy(layer) + _fake_quant_tensor(layer_manual_q.weight.data, num_bits, mode, per_channel) + assert torch.equal(layer_ptq.wrapped_module.weight, layer_manual_q.weight) + if bias: + _fake_quant_tensor(layer_manual_q.bias.data, num_bits, mode, False) + assert torch.equal(layer_ptq.wrapped_module.bias, layer_manual_q.bias) + + y_ptq = layer_ptq(x) + y_manual_q = layer_manual_q(x) + + assert torch.equal(y_ptq, y_manual_q) + + +def test_conv_layer_wrapper_params_only(per_channel, bias): + distiller.set_deterministic() + in_ch = 3 + layer = torch.nn.Conv2d(in_ch, 10, 3, bias=bias) + x = torch.rand(5, in_ch, 5, 5) + + _test_wts_only_quant(layer, x, per_channel, bias, 8) + + +def test_linear_layer_wrapper_params_only(per_channel, bias): + distiller.set_deterministic() + in_features = 50 + layer = torch.nn.Linear(in_features, 30, bias=bias) + + x = torch.rand(5, in_features) + + _test_wts_only_quant(layer, x, per_channel, bias)