diff --git a/distiller/quantization/quantizer.py b/distiller/quantization/quantizer.py
index 6b1c24e1e1e75d9cd2976a3ae5dddfb33e2bae0d..0c21ab36f51d18ffd62faeeb886703eee10ec018 100644
--- a/distiller/quantization/quantizer.py
+++ b/distiller/quantization/quantizer.py
@@ -176,7 +176,7 @@ class Quantizer(object):
         self.params_to_quantize = []
 
         # A dictionary of replaced modules and their respective names.
-        self.modules_replaced = OrderedDict()
+        self.modules_processed = OrderedDict()
 
     def _add_qbits_entry(self, module_name, module_type, qbits):
         if module_type not in [nn.Conv2d, nn.Linear, nn.Embedding]:
@@ -248,18 +248,25 @@ class Quantizer(object):
         # Iterate through model, insert quantization functions as appropriate
         for name, module in container.named_children():
             full_name = prefix + name
-            if module in self.modules_replaced:
-                previous_name, previous_wrapper = self.modules_replaced[module]
+            if module in self.modules_processed:
+                previous_name, previous_wrapper = self.modules_processed[module]
                 warnings.warn("Module '{0}' references to same module as '{1}'."
                               ' Replacing with reference the same wrapper.'.format(full_name, previous_name),
                               UserWarning)
-                msglogger.debug('Module {0}: Replacing \n{1} with \n{2}'.format(full_name, module, previous_wrapper))
-                setattr(container, name, previous_wrapper)
+                if previous_wrapper:
+                    msglogger.debug('Module {0}: Replacing \n{1} with \n{2}'.
+                                    format(full_name, module, previous_wrapper))
+                    setattr(container, name, previous_wrapper)
+                else:
+                    msglogger.debug('Module {0}: Skipping \n{1}.'.format(full_name, module))
                 continue
             current_qbits = self.module_qbits_map[full_name]
             if current_qbits.acts is None and current_qbits.wts is None:
                 if self.module_overrides_map[full_name]:
                     raise ValueError("Adding overrides while not quantizing is not allowed.")
+                # We indicate this module wasn't replaced by a wrapper
+                msglogger.debug('Module {0}: Skipping \n{1}.'.format(full_name, module))
+                self.modules_processed[module] = full_name, None
                 continue
 
             # We use a type hint comment to let IDEs know replace_fn is a function
@@ -274,7 +281,7 @@ class Quantizer(object):
                 new_module = replace_fn(module, full_name, self.module_qbits_map, **valid_kwargs)
                 msglogger.debug('Module {0}: Replacing \n{1} with \n{2}'.format(full_name, module, new_module))
                 # Add to history of prepared submodules
-                self.modules_replaced[module] = full_name, new_module
+                self.modules_processed[module] = full_name, new_module
                 setattr(container, name, new_module)
 
                 # If a "leaf" module was replaced by a container, add the new layers to the QBits mapping
diff --git a/tests/test_quantizer.py b/tests/test_quantizer.py
index d1f151eb68522711744fe2c64453488dfb75f33d..d1f3763bdbb4662ed644cc082938e101963bb4d9 100644
--- a/tests/test_quantizer.py
+++ b/tests/test_quantizer.py
@@ -20,6 +20,7 @@ from copy import deepcopy
 from collections import OrderedDict
 
 import pytest
+import distiller
 from distiller.quantization import Quantizer
 from distiller.quantization.quantizer import QBits, _ParamToQuant
 from distiller.quantization.quantizer import FP_BKP_PREFIX
@@ -423,14 +424,34 @@ def test_overridable_args(model, optimizer, train_with_fp_copy):
     assert model_copy.relu1.overridable_prop == 123
 
 
-def test_shared_submodule(optimizer, train_with_fp_copy):
+@pytest.mark.parametrize(
+    "overrides, expected_relu_type, is_skipped",
+    [
+        (None, DummyQuantLayer, False),
+        (distiller.utils.yaml_ordered_load("""
+        dense1.relu:
+            bits_activations: null
+            bits_weights: null
+        """), nn.ReLU, True)
+    ]
+)
+def test_shared_submodule(optimizer,
+                          train_with_fp_copy, overrides, expected_relu_type, is_skipped):
     with pytest.warns(UserWarning,
                       match="Module '{0}' references to same module as '{1}'.".format('dense2.relu', 'dense1.relu')):
         densenet = DummyModelWithSharedSubmodule(1024, 1024, 1000)
+        relu = densenet.dense1.relu
         quantizer = DummyQuantizer(densenet, bits_weights=8, bits_activations=8, bits_bias=32,
                                    optimizer=optimizer,
-                                   train_with_fp_copy=train_with_fp_copy)
+                                   train_with_fp_copy=train_with_fp_copy,
+                                   overrides=deepcopy(overrides))
         quantizer.prepare_model()
 
-        assert isinstance(quantizer.model.dense1.relu, DummyQuantLayer)
+        assert isinstance(quantizer.model.dense1.relu, expected_relu_type)
         assert quantizer.model.dense1.relu == quantizer.model.dense2.relu
+        assert quantizer.modules_processed[relu] is not None
+        if is_skipped:
+            assert quantizer.modules_processed[relu][1] is None
+        else:
+            assert quantizer.modules_processed[relu][1] == quantizer.model.dense1.relu
+            assert quantizer.modules_processed[relu][1] == quantizer.model.dense2.relu
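
Usage note (illustration only, not part of the patch): a minimal sketch of the bookkeeping this change introduces, reusing the `DummyModelWithSharedSubmodule` and `DummyQuantizer` helpers from tests/test_quantizer.py. After `prepare_model()`, a module that was replaced maps to `(name, wrapper)` in `modules_processed`, while a module skipped via a null-bits override maps to `(name, None)`, so a second reference to the same shared module is handled consistently (re-wired or skipped) instead of crashing:

    from copy import deepcopy
    import torch.nn as nn
    import distiller

    # Override that disables quantization for the shared ReLU (same YAML as in the test)
    overrides = distiller.utils.yaml_ordered_load("""
    dense1.relu:
        bits_activations: null
        bits_weights: null
    """)

    model = DummyModelWithSharedSubmodule(1024, 1024, 1000)
    shared_relu = model.dense1.relu

    # Assumes optimizer=None / train_with_fp_copy=False are acceptable here,
    # as in the non-fp-copy case of the test
    quantizer = DummyQuantizer(model, bits_weights=8, bits_activations=8, bits_bias=32,
                               optimizer=None, train_with_fp_copy=False,
                               overrides=deepcopy(overrides))
    quantizer.prepare_model()  # warns: dense2.relu references the same module as dense1.relu

    # The shared module was processed exactly once and recorded as skipped...
    name, wrapper = quantizer.modules_processed[shared_relu]
    assert name == 'dense1.relu' and wrapper is None
    # ...and both references still point at the original, unquantized ReLU
    assert isinstance(quantizer.model.dense1.relu, nn.ReLU)
    assert quantizer.model.dense1.relu is quantizer.model.dense2.relu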