diff --git a/distiller/quantization/quantizer.py b/distiller/quantization/quantizer.py index fcbd9342e833cba0829cda581a4a41c01e8ef9ca..8ede4b25b34207d68c542889c61c436be96d2b60 100644 --- a/distiller/quantization/quantizer.py +++ b/distiller/quantization/quantizer.py @@ -281,28 +281,28 @@ class Quantizer(object): # We indicate this module wasn't replaced by a wrapper replace_msg(full_name) self.modules_processed[module] = full_name, None - continue - - # We use a type hint comment to let IDEs know replace_fn is a function - replace_fn = self.replacement_factory[type(module)] # type: Optional[Callable] - # If the replacement function wasn't specified - continue without replacing this module. - if replace_fn is not None: - valid_kwargs, invalid_kwargs = distiller.filter_kwargs(self.module_overrides_map[full_name], replace_fn) - if invalid_kwargs: - raise TypeError("""Quantizer of type %s doesn't accept \"%s\" - as override arguments for %s. Allowed kwargs: %s""" - % (type(self), list(invalid_kwargs), type(module), list(valid_kwargs))) - new_module = replace_fn(module, full_name, self.module_qbits_map, **valid_kwargs) - replace_msg(full_name, (module, new_module)) - # Add to history of prepared submodules - self.modules_processed[module] = full_name, new_module - setattr(container, name, new_module) - - # If a "leaf" module was replaced by a container, add the new layers to the QBits mapping - if not distiller.has_children(module) and distiller.has_children(new_module): - for sub_module_name, sub_module in new_module.named_modules(): - self._add_qbits_entry(full_name + '.' + sub_module_name, type(sub_module), current_qbits) - self.module_qbits_map[full_name] = QBits(acts=current_qbits.acts, wts=None, bias=None) + else: + # We use a type hint comment to let IDEs know replace_fn is a function + replace_fn = self.replacement_factory[type(module)] # type: Optional[Callable] + # If the replacement function wasn't specified - continue without replacing this module. + if replace_fn is not None: + valid_kwargs, invalid_kwargs = distiller.filter_kwargs(self.module_overrides_map[full_name], + replace_fn) + if invalid_kwargs: + raise TypeError("""Quantizer of type %s doesn't accept \"%s\" + as override arguments for %s. Allowed kwargs: %s""" + % (type(self), list(invalid_kwargs), type(module), list(valid_kwargs))) + new_module = replace_fn(module, full_name, self.module_qbits_map, **valid_kwargs) + replace_msg(full_name, (module, new_module)) + # Add to history of prepared submodules + self.modules_processed[module] = full_name, new_module + setattr(container, name, new_module) + + # If a "leaf" module was replaced by a container, add the new layers to the QBits mapping + if not distiller.has_children(module) and distiller.has_children(new_module): + for sub_module_name, sub_module in new_module.named_modules(): + self._add_qbits_entry(full_name + '.' + sub_module_name, type(sub_module), current_qbits) + self.module_qbits_map[full_name] = QBits(acts=current_qbits.acts, wts=None, bias=None) if distiller.has_children(module): # For container we call recursively diff --git a/tests/test_quantizer.py b/tests/test_quantizer.py index eecfa3ffe3b3065b607a7c0eac1675717701a066..049e563eda6d0c8034a9414af0bef43b58181716 100644 --- a/tests/test_quantizer.py +++ b/tests/test_quantizer.py @@ -147,30 +147,34 @@ class DummyQuantizer(Quantizer): ############################# # Other utils ############################# - -expected_type_replacements = {nn.Conv2d: DummyWrapperLayer, nn.ReLU: DummyQuantLayer, nn.Linear: DummyWrapperLayer} - - def params_quantizable(module): return isinstance(module, (nn.Conv2d, nn.Linear)) def get_expected_qbits(model, qbits, expected_overrides): - expected_qbits = {} - post_prepare_changes = {} + expected_type_replacements = {nn.Conv2d: DummyWrapperLayer, nn.ReLU: DummyQuantLayer, nn.Linear: DummyWrapperLayer} + + expected_qbits = OrderedDict() + post_prepare_qbbits_changes = OrderedDict() + post_prepare_expected_types = OrderedDict() prefix = 'module.' if isinstance(model, torch.nn.DataParallel) else '' for orig_name, orig_module in model.named_modules(): + orig_module_type = type(orig_module) bits_a, bits_w, bits_b = expected_overrides.get(orig_name.replace(prefix, '', 1), qbits) if not params_quantizable(orig_module): bits_w = bits_b = None expected_qbits[orig_name] = QBits(bits_a, bits_w, bits_b) + if expected_qbits[orig_name] == QBits(None, None, None): + post_prepare_expected_types[orig_name] = orig_module_type + else: + post_prepare_expected_types[orig_name] = expected_type_replacements.get(orig_module_type, orig_module_type) + # We're testing replacement of module with container + if post_prepare_expected_types[orig_name] == DummyWrapperLayer: + post_prepare_qbbits_changes[orig_name] = QBits(bits_a, None, None) + post_prepare_qbbits_changes[orig_name + '.inner'] = expected_qbits[orig_name] + post_prepare_expected_types[orig_name + '.inner'] = orig_module_type - # We're testing replacement of module with container - if isinstance(orig_module, (nn.Conv2d, nn.Linear)): - post_prepare_changes[orig_name] = QBits(bits_a, None, None) - post_prepare_changes[orig_name + '.inner'] = expected_qbits[orig_name] - - return expected_qbits, post_prepare_changes + return expected_qbits, post_prepare_qbbits_changes, post_prepare_expected_types ############################# @@ -251,7 +255,10 @@ bias_key = 'bits_bias' 'sub1.relu2': QBits(8, None, None), 'sub1.pool2': QBits(8, None, None)}), (QBits(8, 4, 32), OrderedDict([('conv1', {acts_key: 8, wts_key: 4, bias_key: None})]), - {'conv1': QBits(8, 4, None)}) + {'conv1': QBits(8, 4, None)}), + (QBits(None, 8, 32), + OrderedDict([('conv1', {acts_key: 8, wts_key: 8, bias_key: 32})]), + {'conv1': QBits(8, 8, 32)}) ], ids=[ 'no_override', @@ -260,7 +267,8 @@ bias_key = 'bits_bias' 'overlap_pattern_override_proper', # "proper" ==> Specific pattern before broader pattern 'overlap_pattern_override_wrong', # "wrong" ==> Broad pattern before specific pattern, so specific pattern # never actually matched - 'wts_quant_bias_not' + 'wts_quant_bias_not', + 'dont_quant_acts' ] ) def test_model_prep(model, optimizer, qbits, overrides, explicit_expected_overrides, @@ -270,7 +278,9 @@ def test_model_prep(model, optimizer, qbits, overrides, explicit_expected_overri m_orig = deepcopy(model) # Build expected QBits - expected_qbits, post_prepare_changes = get_expected_qbits(model, qbits, explicit_expected_overrides) + expected_qbits, post_prepare_changes, post_prepare_expected_types = get_expected_qbits(model, + qbits, + explicit_expected_overrides) # Initialize Quantizer q = DummyQuantizer(model, optimizer=optimizer, @@ -317,15 +327,12 @@ def test_model_prep(model, optimizer, qbits, overrides, explicit_expected_overri # Check module replacement is as expected q_module = q_named_modules[orig_name] - expected_type = expected_type_replacements.get(type(orig_module)) - if expected_type is None or expected_qbits[orig_name] == QBits(None, None, None): - assert type(orig_module) == type(q_module) - else: - assert type(q_module) == expected_type - if expected_type == DummyWrapperLayer: - assert expected_qbits[orig_name + '.inner'] == q_module.qbits - else: - assert expected_qbits[orig_name] == q_module.qbits + expected_type = post_prepare_expected_types[orig_name] + assert type(q_module) == expected_type + if expected_type == DummyWrapperLayer: + assert expected_qbits[orig_name + '.inner'] == q_module.qbits + elif expected_type == DummyQuantLayer: + assert expected_qbits[orig_name] == q_module.qbits @pytest.mark.parametrize( @@ -344,7 +351,7 @@ def test_model_prep(model, optimizer, qbits, overrides, explicit_expected_overri def test_param_quantization(model, optimizer, qbits, overrides, explicit_expected_overrides, train_with_fp_copy): # Build expected QBits - expected_qbits, post_prepare_changes = get_expected_qbits(model, qbits, explicit_expected_overrides) + expected_qbits, post_prepare_changes, _ = get_expected_qbits(model, qbits, explicit_expected_overrides) q = DummyQuantizer(model, optimizer=optimizer, bits_activations=qbits.acts, bits_weights=qbits.wts, bits_bias=qbits.bias,