Skip to content
Snippets Groups Projects
Commit a69dd5d6 authored by Lev Zlotnik's avatar Lev Zlotnik Committed by Guy Jacob
Browse files

Quantizer: Proper handling of modules that point to same object (#239)

parent 343e9a82
No related branches found
No related tags found
No related merge requests found
...@@ -173,6 +173,9 @@ class Quantizer(object): ...@@ -173,6 +173,9 @@ class Quantizer(object):
self.train_with_fp_copy = train_with_fp_copy self.train_with_fp_copy = train_with_fp_copy
self.params_to_quantize = [] self.params_to_quantize = []
# A dictionary of replaced modules and their respective names.
self.modules_replaced = OrderedDict()
def _add_qbits_entry(self, module_name, module_type, qbits): def _add_qbits_entry(self, module_name, module_type, qbits):
if module_type not in [nn.Conv2d, nn.Linear, nn.Embedding]: if module_type not in [nn.Conv2d, nn.Linear, nn.Embedding]:
# For now we support weights quantization only for Conv, FC and Embedding layers (so, for example, we don't # For now we support weights quantization only for Conv, FC and Embedding layers (so, for example, we don't
...@@ -184,6 +187,26 @@ class Quantizer(object): ...@@ -184,6 +187,26 @@ class Quantizer(object):
self.module_overrides_map[module_name] = entry self.module_overrides_map[module_name] = entry
def prepare_model(self): def prepare_model(self):
"""
Traverses the model and replaces sub-modules with quantized counterparts according to the bit-width
and overrides configuration provided to __init__(), and according to the replacement_factory as
defined by the Quantizer sub-class being used.
Note:
If multiple sub-modules within the model actually reference the same module, then that module
is replaced only once, according to the configuration (bit-width and/or overrides) of the
first encountered reference.
Toy Example - say a module is constructed using this bit of code:
shared_relu = nn.ReLU()
self.relu1 = shared_relu
self.relu2 = shared_relu
When traversing the model, a replacement will be generated when 'self.relu1' is encountered.
Let's call it 'new_relu1'. When 'self.relu2' will be encountered, it'll simply be replaced
with a reference to 'new_relu1'. Any override configuration made specifically for 'self.relu2'
will be ignored. A warning message will be shown.
"""
self._prepare_model_impl() self._prepare_model_impl()
msglogger.info('Quantized model:\n\n{0}\n'.format(self.model)) msglogger.info('Quantized model:\n\n{0}\n'.format(self.model))
...@@ -226,6 +249,14 @@ class Quantizer(object): ...@@ -226,6 +249,14 @@ class Quantizer(object):
# Iterate through model, insert quantization functions as appropriate # Iterate through model, insert quantization functions as appropriate
for name, module in container.named_children(): for name, module in container.named_children():
full_name = prefix + name full_name = prefix + name
if module in self.modules_replaced:
previous_name, previous_wrapper = self.modules_replaced[module]
warnings.warn("Module '{0}' references to same module as '{1}'."
' Replacing with reference the same wrapper.'.format(full_name, previous_name),
UserWarning)
msglogger.debug('Module {0}: Replacing \n{1} with \n{2}'.format(full_name, module, previous_wrapper))
setattr(container, name, previous_wrapper)
continue
current_qbits = self.module_qbits_map[full_name] current_qbits = self.module_qbits_map[full_name]
if current_qbits.acts is None and current_qbits.wts is None: if current_qbits.acts is None and current_qbits.wts is None:
if self.module_overrides_map[full_name]: if self.module_overrides_map[full_name]:
...@@ -241,6 +272,8 @@ class Quantizer(object): ...@@ -241,6 +272,8 @@ class Quantizer(object):
new_module = self.replacement_factory[type(module)](module, full_name, new_module = self.replacement_factory[type(module)](module, full_name,
self.module_qbits_map, **valid_kwargs) self.module_qbits_map, **valid_kwargs)
msglogger.debug('Module {0}: Replacing \n{1} with \n{2}'.format(full_name, module, new_module)) msglogger.debug('Module {0}: Replacing \n{1} with \n{2}'.format(full_name, module, new_module))
# Add to history of prepared submodules
self.modules_replaced[module] = full_name, new_module
setattr(container, name, new_module) setattr(container, name, new_module)
# If a "leaf" module was replaced by a container, add the new layers to the QBits mapping # If a "leaf" module was replaced by a container, add the new layers to the QBits mapping
......
...@@ -83,6 +83,33 @@ class DummyModel(nn.Sequential): ...@@ -83,6 +83,33 @@ class DummyModel(nn.Sequential):
p.data = torch.zeros_like(p) p.data = torch.zeros_like(p)
class DummyDenseWithRelu(nn.Module):
    """A fully-connected layer followed by ReLU.

    The ReLU can be injected from outside, which lets several of these
    modules reference one and the same activation instance (used to test
    shared-submodule handling in the quantizer).
    """
    def __init__(self, input_size, output_size, relu=None):
        super(DummyDenseWithRelu, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        # Reuse a caller-supplied ReLU when given; otherwise own a fresh one.
        self.relu = nn.ReLU() if relu is None else relu
        self.linear = nn.Linear(input_size, output_size)

    def forward(self, x):
        pre_activation = self.linear(x)
        return self.relu(pre_activation)
class DummyModelWithSharedSubmodule(nn.Module):
    """Two dense+ReLU stages in sequence, deliberately sharing a single ReLU.

    'dense2' is constructed with a reference to 'dense1.relu', so the same
    activation object appears under two module names — the scenario the
    quantizer's duplicate-module handling must cope with.
    """
    def __init__(self, input_size, hidden_size, output_size):
        super(DummyModelWithSharedSubmodule, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.dense1 = DummyDenseWithRelu(input_size, hidden_size)
        # Hand dense1's ReLU to dense2 so both stages point at one instance.
        self.dense2 = DummyDenseWithRelu(hidden_size, output_size,
                                         self.dense1.relu)

    def forward(self, x):
        return self.dense2(self.dense1(x))
############################# #############################
# Dummy Quantizer # Dummy Quantizer
############################# #############################
...@@ -111,6 +138,7 @@ class DummyQuantizer(Quantizer): ...@@ -111,6 +138,7 @@ class DummyQuantizer(Quantizer):
self.replacement_factory[nn.Conv2d] = _dummy_wrapper_layer self.replacement_factory[nn.Conv2d] = _dummy_wrapper_layer
self.replacement_factory[nn.ReLU] = _dummy_quant_layer self.replacement_factory[nn.ReLU] = _dummy_quant_layer
self.replacement_factory[nn.Linear] = _dummy_wrapper_layer
self.param_quantization_fn = dummy_quantize_params self.param_quantization_fn = dummy_quantize_params
...@@ -118,7 +146,7 @@ class DummyQuantizer(Quantizer): ...@@ -118,7 +146,7 @@ class DummyQuantizer(Quantizer):
# Other utils # Other utils
############################# #############################
expected_type_replacements = {nn.Conv2d: DummyWrapperLayer, nn.ReLU: DummyQuantLayer} expected_type_replacements = {nn.Conv2d: DummyWrapperLayer, nn.ReLU: DummyQuantLayer, nn.Linear: DummyWrapperLayer}
def params_quantizable(module): def params_quantizable(module):
...@@ -136,7 +164,7 @@ def get_expected_qbits(model, qbits, expected_overrides): ...@@ -136,7 +164,7 @@ def get_expected_qbits(model, qbits, expected_overrides):
expected_qbits[orig_name] = QBits(bits_a, bits_w, bits_b) expected_qbits[orig_name] = QBits(bits_a, bits_w, bits_b)
# We're testing replacement of module with container # We're testing replacement of module with container
if isinstance(orig_module, nn.Conv2d): if isinstance(orig_module, (nn.Conv2d, nn.Linear)):
post_prepare_changes[orig_name] = QBits(bits_a, None, None) post_prepare_changes[orig_name] = QBits(bits_a, None, None)
post_prepare_changes[orig_name + '.inner'] = expected_qbits[orig_name] post_prepare_changes[orig_name + '.inner'] = expected_qbits[orig_name]
...@@ -394,3 +422,16 @@ def test_overridable_args(model, optimizer, train_with_fp_copy): ...@@ -394,3 +422,16 @@ def test_overridable_args(model, optimizer, train_with_fp_copy):
q = DummyQuantizer(model_copy, optimizer=optimizer, overrides=overrides, train_with_fp_copy=train_with_fp_copy) q = DummyQuantizer(model_copy, optimizer=optimizer, overrides=overrides, train_with_fp_copy=train_with_fp_copy)
q.prepare_model() q.prepare_model()
assert model_copy.relu1.overridable_prop == 123 assert model_copy.relu1.overridable_prop == 123
def test_shared_submodule(optimizer, train_with_fp_copy):
    """A module referenced under two names must be replaced once, with both
    names ending up bound to the very same wrapper, and a UserWarning emitted
    for the second (ignored) reference.
    """
    with pytest.warns(UserWarning,
                      match="Module '{0}' references to same module as '{1}'.".format('dense2.relu', 'dense1.relu')):
        densenet = DummyModelWithSharedSubmodule(1024, 1024, 1000)
        quantizer = DummyQuantizer(densenet,
                                   bits_weights=8, bits_activations=8, bits_bias=32,
                                   optimizer=optimizer,
                                   train_with_fp_copy=train_with_fp_copy)
        quantizer.prepare_model()
        assert isinstance(quantizer.model.dense1.relu, DummyQuantLayer)
        # 'is', not '==': the contract under test is object identity — both
        # attribute paths must reference the single shared replacement wrapper.
        assert quantizer.model.dense1.relu is quantizer.model.dense2.relu
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment