Skip to content
Snippets Groups Projects
Unverified Commit 25d7a7b7 authored by Lev Zlotnik's avatar Lev Zlotnik Committed by GitHub
Browse files

Fixed KeyError treatment in Quantizer._pre_process_container (#263)

* Made quantizer.replacemet_factory as defaultdict and removed 'except KeyError' in pre_process_container

* Added explanations and type hint for replace_fn 
parent 6b832025
No related branches found
No related tags found
No related merge requests found
......@@ -14,7 +14,7 @@
# limitations under the License.
#
from collections import namedtuple, OrderedDict
from collections import namedtuple, OrderedDict, defaultdict
import re
import copy
import logging
......@@ -22,6 +22,7 @@ import torch
import torch.nn as nn
import distiller
import warnings
from typing import Callable, Optional
msglogger = logging.getLogger()
......@@ -165,7 +166,8 @@ class Quantizer(object):
# Mapping from module type to function generating a replacement module suited for quantization
# To be populated by child classes
self.replacement_factory = {}
# Unspecified layer types return None by default.
self.replacement_factory = defaultdict(lambda: None)
# Pointer to parameters quantization function, triggered during training process
# To be populated by child classes
self.param_quantization_fn = None
......@@ -259,15 +261,17 @@ class Quantizer(object):
if self.module_overrides_map[full_name]:
raise ValueError("Adding overrides while not quantizing is not allowed.")
continue
try:
replace_fn = self.replacement_factory[type(module)]
# This hints pycharm the replace_fn is a function
replace_fn: Optional[Callable] = self.replacement_factory[type(module)]
# If the replacement function wasn't specified - continue without replacing this module.
if replace_fn is not None:
valid_kwargs, invalid_kwargs = distiller.filter_kwargs(self.module_overrides_map[full_name], replace_fn)
if invalid_kwargs:
raise TypeError("""Quantizer of type %s doesn't accept \"%s\"
as override arguments for %s. Allowed kwargs: %s"""
% (type(self), list(invalid_kwargs), type(module), list(valid_kwargs)))
new_module = self.replacement_factory[type(module)](module, full_name,
self.module_qbits_map, **valid_kwargs)
new_module = replace_fn(module, full_name, self.module_qbits_map, **valid_kwargs)
msglogger.debug('Module {0}: Replacing \n{1} with \n{2}'.format(full_name, module, new_module))
# Add to history of prepared submodules
self.modules_replaced[module] = full_name, new_module
......@@ -278,8 +282,6 @@ class Quantizer(object):
for sub_module_name, sub_module in new_module.named_modules():
self._add_qbits_entry(full_name + '.' + sub_module_name, type(sub_module), current_qbits)
self.module_qbits_map[full_name] = QBits(acts=current_qbits.acts, wts=None, bias=None)
except KeyError:
pass
if distiller.has_children(module):
# For container we call recursively
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment