Skip to content
Snippets Groups Projects
Unverified Commit ce3528e4, authored by Guy Jacob, committed by GitHub
Browse files

[Quantizer] Fix handling when default bits_activations == None (#345)

parent e65ec8fc
No related branches found
No related tags found
No related merge requests found
......@@ -281,28 +281,28 @@ class Quantizer(object):
# We indicate this module wasn't replaced by a wrapper
replace_msg(full_name)
self.modules_processed[module] = full_name, None
continue
# We use a type hint comment to let IDEs know replace_fn is a function
replace_fn = self.replacement_factory[type(module)] # type: Optional[Callable]
# If the replacement function wasn't specified - continue without replacing this module.
if replace_fn is not None:
valid_kwargs, invalid_kwargs = distiller.filter_kwargs(self.module_overrides_map[full_name], replace_fn)
if invalid_kwargs:
raise TypeError("""Quantizer of type %s doesn't accept \"%s\"
as override arguments for %s. Allowed kwargs: %s"""
% (type(self), list(invalid_kwargs), type(module), list(valid_kwargs)))
new_module = replace_fn(module, full_name, self.module_qbits_map, **valid_kwargs)
replace_msg(full_name, (module, new_module))
# Add to history of prepared submodules
self.modules_processed[module] = full_name, new_module
setattr(container, name, new_module)
# If a "leaf" module was replaced by a container, add the new layers to the QBits mapping
if not distiller.has_children(module) and distiller.has_children(new_module):
for sub_module_name, sub_module in new_module.named_modules():
self._add_qbits_entry(full_name + '.' + sub_module_name, type(sub_module), current_qbits)
self.module_qbits_map[full_name] = QBits(acts=current_qbits.acts, wts=None, bias=None)
else:
# We use a type hint comment to let IDEs know replace_fn is a function
replace_fn = self.replacement_factory[type(module)] # type: Optional[Callable]
# If the replacement function wasn't specified - continue without replacing this module.
if replace_fn is not None:
valid_kwargs, invalid_kwargs = distiller.filter_kwargs(self.module_overrides_map[full_name],
replace_fn)
if invalid_kwargs:
raise TypeError("""Quantizer of type %s doesn't accept \"%s\"
as override arguments for %s. Allowed kwargs: %s"""
% (type(self), list(invalid_kwargs), type(module), list(valid_kwargs)))
new_module = replace_fn(module, full_name, self.module_qbits_map, **valid_kwargs)
replace_msg(full_name, (module, new_module))
# Add to history of prepared submodules
self.modules_processed[module] = full_name, new_module
setattr(container, name, new_module)
# If a "leaf" module was replaced by a container, add the new layers to the QBits mapping
if not distiller.has_children(module) and distiller.has_children(new_module):
for sub_module_name, sub_module in new_module.named_modules():
self._add_qbits_entry(full_name + '.' + sub_module_name, type(sub_module), current_qbits)
self.module_qbits_map[full_name] = QBits(acts=current_qbits.acts, wts=None, bias=None)
if distiller.has_children(module):
# For container we call recursively
......
......@@ -147,30 +147,34 @@ class DummyQuantizer(Quantizer):
#############################
# Other utils
#############################
expected_type_replacements = {nn.Conv2d: DummyWrapperLayer, nn.ReLU: DummyQuantLayer, nn.Linear: DummyWrapperLayer}
# Report whether `module` is an nn.Conv2d or nn.Linear instance.
# get_expected_qbits uses this check to decide which modules keep their
# weights/bias bit-widths; for every other module those are set to None.
def params_quantizable(module):
return isinstance(module, (nn.Conv2d, nn.Linear))
def get_expected_qbits(model, qbits, expected_overrides):
expected_qbits = {}
post_prepare_changes = {}
expected_type_replacements = {nn.Conv2d: DummyWrapperLayer, nn.ReLU: DummyQuantLayer, nn.Linear: DummyWrapperLayer}
expected_qbits = OrderedDict()
post_prepare_qbbits_changes = OrderedDict()
post_prepare_expected_types = OrderedDict()
prefix = 'module.' if isinstance(model, torch.nn.DataParallel) else ''
for orig_name, orig_module in model.named_modules():
orig_module_type = type(orig_module)
bits_a, bits_w, bits_b = expected_overrides.get(orig_name.replace(prefix, '', 1), qbits)
if not params_quantizable(orig_module):
bits_w = bits_b = None
expected_qbits[orig_name] = QBits(bits_a, bits_w, bits_b)
if expected_qbits[orig_name] == QBits(None, None, None):
post_prepare_expected_types[orig_name] = orig_module_type
else:
post_prepare_expected_types[orig_name] = expected_type_replacements.get(orig_module_type, orig_module_type)
# We're testing replacement of module with container
if post_prepare_expected_types[orig_name] == DummyWrapperLayer:
post_prepare_qbbits_changes[orig_name] = QBits(bits_a, None, None)
post_prepare_qbbits_changes[orig_name + '.inner'] = expected_qbits[orig_name]
post_prepare_expected_types[orig_name + '.inner'] = orig_module_type
# We're testing replacement of module with container
if isinstance(orig_module, (nn.Conv2d, nn.Linear)):
post_prepare_changes[orig_name] = QBits(bits_a, None, None)
post_prepare_changes[orig_name + '.inner'] = expected_qbits[orig_name]
return expected_qbits, post_prepare_changes
return expected_qbits, post_prepare_qbbits_changes, post_prepare_expected_types
#############################
......@@ -251,7 +255,10 @@ bias_key = 'bits_bias'
'sub1.relu2': QBits(8, None, None), 'sub1.pool2': QBits(8, None, None)}),
(QBits(8, 4, 32),
OrderedDict([('conv1', {acts_key: 8, wts_key: 4, bias_key: None})]),
{'conv1': QBits(8, 4, None)})
{'conv1': QBits(8, 4, None)}),
(QBits(None, 8, 32),
OrderedDict([('conv1', {acts_key: 8, wts_key: 8, bias_key: 32})]),
{'conv1': QBits(8, 8, 32)})
],
ids=[
'no_override',
......@@ -260,7 +267,8 @@ bias_key = 'bits_bias'
'overlap_pattern_override_proper', # "proper" ==> Specific pattern before broader pattern
'overlap_pattern_override_wrong', # "wrong" ==> Broad pattern before specific pattern, so specific pattern
# never actually matched
'wts_quant_bias_not'
'wts_quant_bias_not',
'dont_quant_acts'
]
)
def test_model_prep(model, optimizer, qbits, overrides, explicit_expected_overrides,
......@@ -270,7 +278,9 @@ def test_model_prep(model, optimizer, qbits, overrides, explicit_expected_overri
m_orig = deepcopy(model)
# Build expected QBits
expected_qbits, post_prepare_changes = get_expected_qbits(model, qbits, explicit_expected_overrides)
expected_qbits, post_prepare_changes, post_prepare_expected_types = get_expected_qbits(model,
qbits,
explicit_expected_overrides)
# Initialize Quantizer
q = DummyQuantizer(model, optimizer=optimizer,
......@@ -317,15 +327,12 @@ def test_model_prep(model, optimizer, qbits, overrides, explicit_expected_overri
# Check module replacement is as expected
q_module = q_named_modules[orig_name]
expected_type = expected_type_replacements.get(type(orig_module))
if expected_type is None or expected_qbits[orig_name] == QBits(None, None, None):
assert type(orig_module) == type(q_module)
else:
assert type(q_module) == expected_type
if expected_type == DummyWrapperLayer:
assert expected_qbits[orig_name + '.inner'] == q_module.qbits
else:
assert expected_qbits[orig_name] == q_module.qbits
expected_type = post_prepare_expected_types[orig_name]
assert type(q_module) == expected_type
if expected_type == DummyWrapperLayer:
assert expected_qbits[orig_name + '.inner'] == q_module.qbits
elif expected_type == DummyQuantLayer:
assert expected_qbits[orig_name] == q_module.qbits
@pytest.mark.parametrize(
......@@ -344,7 +351,7 @@ def test_model_prep(model, optimizer, qbits, overrides, explicit_expected_overri
def test_param_quantization(model, optimizer, qbits, overrides, explicit_expected_overrides,
train_with_fp_copy):
# Build expected QBits
expected_qbits, post_prepare_changes = get_expected_qbits(model, qbits, explicit_expected_overrides)
expected_qbits, post_prepare_changes, _ = get_expected_qbits(model, qbits, explicit_expected_overrides)
q = DummyQuantizer(model, optimizer=optimizer,
bits_activations=qbits.acts, bits_weights=qbits.wts, bits_bias=qbits.bias,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment