Skip to content
Snippets Groups Projects
Unverified Commit d6efbe40 authored by Lev Zlotnik's avatar Lev Zlotnik Committed by GitHub
Browse files

Bug fix for shared module (#268)

* Fixed a bug where a shared module that was supposed to be skipped was not skipped on its second reference

* Added tests for new bug fix
parent fe27ab90
No related branches found
No related tags found
No related merge requests found
......@@ -176,7 +176,7 @@ class Quantizer(object):
self.params_to_quantize = []
# A dictionary of replaced modules and their respective names.
self.modules_replaced = OrderedDict()
self.modules_processed = OrderedDict()
def _add_qbits_entry(self, module_name, module_type, qbits):
if module_type not in [nn.Conv2d, nn.Linear, nn.Embedding]:
......@@ -248,18 +248,25 @@ class Quantizer(object):
# Iterate through model, insert quantization functions as appropriate
for name, module in container.named_children():
full_name = prefix + name
if module in self.modules_replaced:
previous_name, previous_wrapper = self.modules_replaced[module]
if module in self.modules_processed:
previous_name, previous_wrapper = self.modules_processed[module]
warnings.warn("Module '{0}' references to same module as '{1}'."
' Replacing with reference the same wrapper.'.format(full_name, previous_name),
UserWarning)
msglogger.debug('Module {0}: Replacing \n{1} with \n{2}'.format(full_name, module, previous_wrapper))
setattr(container, name, previous_wrapper)
if previous_wrapper:
msglogger.debug('Module {0}: Replacing \n{1} with \n{2}'.
format(full_name, module, previous_wrapper))
setattr(container, name, previous_wrapper)
else:
msglogger.debug('Module {0}: Skipping \n{1}.'.format(full_name, module))
continue
current_qbits = self.module_qbits_map[full_name]
if current_qbits.acts is None and current_qbits.wts is None:
if self.module_overrides_map[full_name]:
raise ValueError("Adding overrides while not quantizing is not allowed.")
# We indicate this module wasn't replaced by a wrapper
msglogger.debug('Module {0}: Skipping \n{1}.'.format(full_name, module))
self.modules_processed[module] = full_name, None
continue
# We use a type hint comment to let IDEs know replace_fn is a function
......@@ -274,7 +281,7 @@ class Quantizer(object):
new_module = replace_fn(module, full_name, self.module_qbits_map, **valid_kwargs)
msglogger.debug('Module {0}: Replacing \n{1} with \n{2}'.format(full_name, module, new_module))
# Add to history of prepared submodules
self.modules_replaced[module] = full_name, new_module
self.modules_processed[module] = full_name, new_module
setattr(container, name, new_module)
# If a "leaf" module was replaced by a container, add the new layers to the QBits mapping
......
......@@ -20,6 +20,7 @@ from copy import deepcopy
from collections import OrderedDict
import pytest
import distiller
from distiller.quantization import Quantizer
from distiller.quantization.quantizer import QBits, _ParamToQuant
from distiller.quantization.quantizer import FP_BKP_PREFIX
......@@ -423,14 +424,34 @@ def test_overridable_args(model, optimizer, train_with_fp_copy):
assert model_copy.relu1.overridable_prop == 123
@pytest.mark.parametrize(
    "overrides, expected_relu_type, is_skipped",
    [
        # No overrides: the shared ReLU is quantized, so the wrapper must be reused.
        (None, DummyQuantLayer, False),
        # Override dense1.relu to null bits: the shared ReLU is skipped, and the
        # skip decision must also apply to the second reference (dense2.relu).
        (distiller.utils.yaml_ordered_load("""
        dense1.relu:
            bits_activations: null
            bits_weights: null
        """), nn.ReLU, True)
    ]
)
def test_shared_submodule(optimizer, train_with_fp_copy, overrides, expected_relu_type, is_skipped):
    """Verify prepare_model handles a module referenced from two parents.

    Both references must resolve to the same outcome: either the same wrapper
    instance (quantized case) or the original module (skipped case). The
    modules_processed entry records (name, wrapper-or-None) for the first
    reference so the second reference can replay the same decision.
    """
    # The second reference (dense2.relu) must trigger the duplicate-module warning.
    with pytest.warns(UserWarning,
                      match="Module '{0}' references to same module as '{1}'.".format('dense2.relu', 'dense1.relu')):
        densenet = DummyModelWithSharedSubmodule(1024, 1024, 1000)
        # Keep a handle on the original shared module before it is (possibly) replaced,
        # since modules_processed is keyed by the original module object.
        relu = densenet.dense1.relu
        quantizer = DummyQuantizer(densenet,
                                   bits_weights=8, bits_activations=8, bits_bias=32,
                                   optimizer=optimizer,
                                   train_with_fp_copy=train_with_fp_copy,
                                   # deepcopy: the quantizer may mutate the overrides dict,
                                   # and parametrize values are shared across test runs.
                                   overrides=deepcopy(overrides))
        quantizer.prepare_model()
        # Both parents must end up pointing at the same object of the expected type.
        assert isinstance(quantizer.model.dense1.relu, expected_relu_type)
        assert quantizer.model.dense1.relu == quantizer.model.dense2.relu
        # The shared module must have been recorded as processed either way.
        assert quantizer.modules_processed[relu] is not None
        if is_skipped:
            # A skipped module is recorded with a None wrapper.
            assert quantizer.modules_processed[relu][1] is None
        else:
            # A replaced module is recorded with the wrapper now installed in both parents.
            assert quantizer.modules_processed[relu][1] == quantizer.model.dense1.relu
            assert quantizer.modules_processed[relu][1] == quantizer.model.dense2.relu
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment