diff --git a/distiller/quantization/quantizer.py b/distiller/quantization/quantizer.py index baa7f9c53cf4b58fe4e30ffe5debb6ed14c9087f..44efa5e38c8c4fa4f499f3a3b2dbdea877138fc1 100644 --- a/distiller/quantization/quantizer.py +++ b/distiller/quantization/quantizer.py @@ -221,14 +221,11 @@ class Quantizer(object): for module_name, module in self.model.named_modules(): qbits = self.module_qbits_map[module_name] - if qbits.wts is None: - continue - curr_parameters = dict(module.named_parameters()) for param_name, param in curr_parameters.items(): - # Bias is usually quantized according to the accumulator's number of bits - # Handle # of bits for bias quantization as "first-class" citizen, similarly to weights n_bits = qbits.bias if param_name.endswith('bias') else qbits.wts + if n_bits is None: + continue fp_attr_name = param_name if self.train_with_fp_copy: hack_float_backup_parameter(module, param_name, n_bits) diff --git a/tests/test_quantizer.py b/tests/test_quantizer.py index fea1fe09ce772b78b0d522d220d41ece5914f164..d1f151eb68522711744fe2c64453488dfb75f33d 100644 --- a/tests/test_quantizer.py +++ b/tests/test_quantizer.py @@ -245,14 +245,18 @@ bias_key = 'bits_bias' 'sub1.relu1': QBits(8, None, None), 'sub1.pool1': QBits(8, None, None), 'sub1.conv2': QBits(8, 8, 32), 'sub1.bn2': QBits(8, None, None), 'sub1.relu2': QBits(8, None, None), 'sub1.pool2': QBits(8, None, None)}), + (QBits(8, 4, 32), + OrderedDict([('conv1', {acts_key: 8, wts_key: 4, bias_key: None})]), + {'conv1': QBits(8, 4, None)}) ], ids=[ 'no_override', 'simple_override', 'pattern_override', 'overlap_pattern_override_proper', # "proper" ==> Specific pattern before broader pattern - 'overlap_pattern_override_wrong' # "wrong" ==> Broad pattern before specific pattern, so specific pattern + 'overlap_pattern_override_wrong', # "wrong" ==> Broad pattern before specific pattern, so specific pattern # never actually matched + 'wts_quant_bias_not' ] ) def test_model_prep(model, optimizer, qbits, overrides, explicit_expected_overrides, @@ -327,10 +331,10 @@ def test_model_prep(model, optimizer, qbits, overrides, explicit_expected_overri OrderedDict([('conv1', {acts_key: None, wts_key: None, bias_key: None}), ('relu1', {acts_key: None, wts_key: None, bias_key: None}), ('sub.*conv1', {acts_key: 8, wts_key: 4, bias_key: 32}), - ('sub.*conv2', {acts_key: 4, wts_key: 4, bias_key: 32})]), + ('sub.*conv2', {acts_key: 4, wts_key: 4, bias_key: None})]), {'conv1': QBits(None, None, None), 'relu1': QBits(None, None, None), - 'sub1.conv1': QBits(8, 4, 32), 'sub1.conv2': QBits(4, 4, 32), 'sub2.conv1': QBits(8, 4, 32), - 'sub2.conv2': QBits(4, 4, 32)}), + 'sub1.conv1': QBits(8, 4, 32), 'sub1.conv2': QBits(4, 4, None), 'sub2.conv1': QBits(8, 4, 32), + 'sub2.conv2': QBits(4, 4, None)}), ] ) def test_param_quantization(model, optimizer, qbits, overrides, explicit_expected_overrides, @@ -352,14 +356,9 @@ def test_param_quantization(model, optimizer, qbits, overrides, explicit_expecte if has_children(pre_quant_module): continue - num_qbits = expected_qbits[name].wts - for param_name, pre_quant_param in pre_quant_module.named_parameters(): - quantizable = num_qbits is not None - if param_name.endswith('bias'): - num_bits = expected_qbits[name].bias - else: - num_bits = num_qbits + num_bits = expected_qbits[name].bias if param_name.endswith('bias') else expected_qbits[name].wts + quantizable = num_bits is not None if quantizable and train_with_fp_copy: # "param_name" and "pre_quant_param" refer to the float copy