diff --git a/distiller/quantization/quantizer.py b/distiller/quantization/quantizer.py
index 8eee101ad4f535aa12f73732e557a1248d751db0..baa7f9c53cf4b58fe4e30ffe5debb6ed14c9087f 100644
--- a/distiller/quantization/quantizer.py
+++ b/distiller/quantization/quantizer.py
@@ -173,6 +173,9 @@ class Quantizer(object):
         self.train_with_fp_copy = train_with_fp_copy
         self.params_to_quantize = []
 
+        # A dictionary mapping replaced modules to their names and replacement wrappers.
+        self.modules_replaced = OrderedDict()
+
     def _add_qbits_entry(self, module_name, module_type, qbits):
         if module_type not in [nn.Conv2d, nn.Linear, nn.Embedding]:
             # For now we support weights quantization only for Conv, FC and Embedding layers (so, for example, we don't
@@ -184,6 +187,26 @@
         self.module_overrides_map[module_name] = entry
 
     def prepare_model(self):
+        """
+        Traverses the model and replaces sub-modules with quantized counterparts according to the bit-width
+        and overrides configuration provided to __init__(), and according to the replacement_factory as
+        defined by the Quantizer sub-class being used.
+
+        Note:
+            If multiple sub-modules within the model actually reference the same module, then that module
+            is replaced only once, according to the configuration (bit-width and/or overrides) of the
+            first encountered reference.
+            Toy example: say a module is constructed using this bit of code:
+
+                shared_relu = nn.ReLU()
+                self.relu1 = shared_relu
+                self.relu2 = shared_relu
+
+            When traversing the model, a replacement will be generated when 'self.relu1' is encountered.
+            Let's call it 'new_relu1'. When 'self.relu2' is encountered, it will simply be replaced
+            with a reference to 'new_relu1'. Any override configuration made specifically for
+            'self.relu2' will be ignored. A warning message will be shown.
+        """
         self._prepare_model_impl()
 
         msglogger.info('Quantized model:\n\n{0}\n'.format(self.model))
@@ -226,6 +249,14 @@
         # Iterate through model, insert quantization functions as appropriate
         for name, module in container.named_children():
             full_name = prefix + name
+            if module in self.modules_replaced:
+                previous_name, previous_wrapper = self.modules_replaced[module]
+                warnings.warn("Module '{0}' references the same module as '{1}'."
+                              ' Replacing with a reference to the same wrapper.'.format(full_name, previous_name),
+                              UserWarning)
+                msglogger.debug('Module {0}: Replacing \n{1} with \n{2}'.format(full_name, module, previous_wrapper))
+                setattr(container, name, previous_wrapper)
+                continue
             current_qbits = self.module_qbits_map[full_name]
             if current_qbits.acts is None and current_qbits.wts is None:
                 if self.module_overrides_map[full_name]:
@@ -241,6 +272,8 @@
             new_module = self.replacement_factory[type(module)](module, full_name, self.module_qbits_map,
                                                                 **valid_kwargs)
             msglogger.debug('Module {0}: Replacing \n{1} with \n{2}'.format(full_name, module, new_module))
+            # Add to the history of replaced modules, so shared modules are only replaced once
+            self.modules_replaced[module] = full_name, new_module
             setattr(container, name, new_module)
 
             # If a "leaf" module was replaced by a container, add the new layers to the QBits mapping
diff --git a/tests/test_quantizer.py b/tests/test_quantizer.py
index fdd47c77a87b04548de64cdb892f71b4e4139070..fea1fe09ce772b78b0d522d220d41ece5914f164 100644
--- a/tests/test_quantizer.py
+++ b/tests/test_quantizer.py
@@ -83,6 +83,33 @@ class DummyModel(nn.Sequential):
         p.data = torch.zeros_like(p)
 
 
+class DummyDenseWithRelu(nn.Module):
+    def __init__(self, input_size, output_size, relu=None):
+        super(DummyDenseWithRelu, self).__init__()
+        self.input_size = input_size
+        self.output_size = output_size
+        self.relu = relu or nn.ReLU()
+        self.linear = nn.Linear(input_size, output_size)
+
+    def forward(self, x):
+        return self.relu(self.linear(x))
+
+
+class DummyModelWithSharedSubmodule(nn.Module):
+    def __init__(self, input_size, hidden_size, output_size):
+        super(DummyModelWithSharedSubmodule, self).__init__()
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.output_size = output_size
+        self.dense1 = DummyDenseWithRelu(input_size, hidden_size)
+        self.dense2 = DummyDenseWithRelu(hidden_size, output_size, self.dense1.relu)
+
+    def forward(self, x):
+        x = self.dense1(x)
+        x = self.dense2(x)
+        return x
+
+
 #############################
 # Dummy Quantizer
 #############################
@@ -111,6 +138,7 @@
 
         self.replacement_factory[nn.Conv2d] = _dummy_wrapper_layer
         self.replacement_factory[nn.ReLU] = _dummy_quant_layer
+        self.replacement_factory[nn.Linear] = _dummy_wrapper_layer
         self.param_quantization_fn = dummy_quantize_params
 
 
@@ -118,7 +146,7 @@
 # Other utils
 #############################
 
-expected_type_replacements = {nn.Conv2d: DummyWrapperLayer, nn.ReLU: DummyQuantLayer}
+expected_type_replacements = {nn.Conv2d: DummyWrapperLayer, nn.ReLU: DummyQuantLayer, nn.Linear: DummyWrapperLayer}
 
 
 def params_quantizable(module):
@@ -136,7 +164,7 @@
         expected_qbits[orig_name] = QBits(bits_a, bits_w, bits_b)
 
         # We're testing replacement of module with container
-        if isinstance(orig_module, nn.Conv2d):
+        if isinstance(orig_module, (nn.Conv2d, nn.Linear)):
             post_prepare_changes[orig_name] = QBits(bits_a, None, None)
             post_prepare_changes[orig_name + '.inner'] = expected_qbits[orig_name]
 
@@ -394,3 +422,16 @@ def test_overridable_args(model, optimizer, train_with_fp_copy):
     q = DummyQuantizer(model_copy, optimizer=optimizer, overrides=overrides, train_with_fp_copy=train_with_fp_copy)
     q.prepare_model()
     assert model_copy.relu1.overridable_prop == 123
+
+
+def test_shared_submodule(optimizer, train_with_fp_copy):
+    with pytest.warns(UserWarning,
+                      match="Module '{0}' references the same module as "
'{1}'.".format('dense2.relu', 'dense1.relu')): + densenet = DummyModelWithSharedSubmodule(1024, 1024, 1000) + quantizer = DummyQuantizer(densenet, + bits_weights=8, bits_activations=8, bits_bias=32, + optimizer=optimizer, + train_with_fp_copy=train_with_fp_copy) + quantizer.prepare_model() + assert isinstance(quantizer.model.dense1.relu, DummyQuantLayer) + assert quantizer.model.dense1.relu == quantizer.model.dense2.relu