diff --git a/distiller/quantization/range_linear.py b/distiller/quantization/range_linear.py
index 6a2430491e2e24a4eedc1c5434baff0c0ef4fade..67cd85c6149bfc19e808b130968eca3531358ba4 100644
--- a/distiller/quantization/range_linear.py
+++ b/distiller/quantization/range_linear.py
@@ -511,6 +511,10 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper):
         return tmpstr
 
 
+class NoStatsError(NotImplementedError):
+    pass
+
+
 class RangeLinearQuantConcatWrapper(RangeLinearQuantWrapper):
     def __init__(self, wrapped_module, num_bits_acts, mode=LinearQuantMode.SYMMETRIC, clip_acts=ClipMode.NONE,
                  activation_stats=None, clip_n_stds=None):
@@ -518,8 +522,8 @@ class RangeLinearQuantConcatWrapper(RangeLinearQuantWrapper):
             raise ValueError(self.__class__.__name__ + ' can only wrap distiller.modules.Concat modules')
 
         if not activation_stats:
-            raise ValueError(self.__class__.__name__ +
-                             ' must get activation stats, dynamic quantization not supported')
+            raise NoStatsError(self.__class__.__name__ +
+                               ' must get activation stats, dynamic quantization not supported')
 
         super(RangeLinearQuantConcatWrapper, self).__init__(wrapped_module, num_bits_acts, mode=mode,
                                                             clip_acts=clip_acts, activation_stats=activation_stats,
@@ -564,8 +568,8 @@ class RangeLinearQuantEltwiseAddWrapper(RangeLinearQuantWrapper):
             raise ValueError(self.__class__.__name__ + ' can only wrap distiller.modules.EltwiseAdd modules')
 
         if not activation_stats:
-            raise ValueError(self.__class__.__name__ +
-                             ' must get activation stats, dynamic quantization not supported')
+            raise NoStatsError(self.__class__.__name__ +
+                               ' must get activation stats, dynamic quantization not supported')
 
         super(RangeLinearQuantEltwiseAddWrapper, self).__init__(wrapped_module, num_bits_acts, mode=mode,
                                                                 clip_acts=clip_acts, activation_stats=activation_stats,
@@ -612,8 +616,8 @@ class RangeLinearQuantEltwiseMultWrapper(RangeLinearQuantWrapper):
             raise ValueError(self.__class__.__name__ + ' can only wrap distiller.modules.EltwiseMult modules')
 
         if not activation_stats:
-            raise ValueError(self.__class__.__name__ +
-                             ' must get activation stats, dynamic quantization not supported')
+            raise NoStatsError(self.__class__.__name__ +
+                               ' must get activation stats, dynamic quantization not supported')
 
         super(RangeLinearQuantEltwiseMultWrapper, self).__init__(wrapped_module, num_bits_acts, mode=mode,
                                                                  clip_acts=clip_acts, activation_stats=activation_stats,
@@ -773,9 +777,14 @@ class PostTrainLinearQuantizer(Quantizer):
                 return FP16Wrapper(module)
             norm_name = distiller.utils.normalize_module_name(name)
             clip = self.clip_acts if norm_name not in self.no_clip_layers else ClipMode.NONE
-            return wrapper_type(module, qbits_map[name].acts, mode=mode, clip_acts=clip,
-                                activation_stats=self.model_activation_stats.get(norm_name, None),
-                                clip_n_stds=clip_n_stds)
+            try:
+                return wrapper_type(module, qbits_map[name].acts, mode=mode, clip_acts=clip,
+                                    activation_stats=self.model_activation_stats.get(norm_name, None),
+                                    clip_n_stds=clip_n_stds)
+            except NoStatsError:
+                msglogger.warning('WARNING: {0} - quantization of {1} without stats not supported. '
+                                  'Keeping the original FP32 module'.format(name, module.__class__.__name__))
+                return module
 
         def replace_embedding(module, name, qbits_map, fp16=fp16):
             if fp16:
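For context, here is a minimal, self-contained sketch of the fallback pattern this diff introduces: wrappers that cannot quantize dynamically raise `NoStatsError` when no activation stats are supplied, and the module-replacement logic catches it and keeps the original FP32 module instead of aborting. Class and function names other than `NoStatsError` (`FakeQuantWrapper`, `replace_module`, `Concat` placeholder) are hypothetical and not Distiller's actual API.

```python
import logging

logging.basicConfig(level=logging.WARNING)
msglogger = logging.getLogger(__name__)


class NoStatsError(NotImplementedError):
    """Raised by a wrapper that requires activation stats but got none."""
    pass


class FakeQuantWrapper:
    """Hypothetical stand-in for a RangeLinearQuant*Wrapper that needs stats."""
    def __init__(self, wrapped_module, activation_stats=None):
        if not activation_stats:
            raise NoStatsError(self.__class__.__name__ +
                               ' must get activation stats, dynamic quantization not supported')
        self.wrapped_module = wrapped_module
        self.activation_stats = activation_stats


def replace_module(module, name, activation_stats=None):
    """Try to wrap `module`; on NoStatsError, fall back to the FP32 module."""
    try:
        return FakeQuantWrapper(module, activation_stats=activation_stats)
    except NoStatsError:
        msglogger.warning('WARNING: {0} - quantization of {1} without stats not supported. '
                          'Keeping the original FP32 module'.format(name, module.__class__.__name__))
        return module


if __name__ == '__main__':
    class Concat:  # placeholder for distiller.modules.Concat
        pass

    m = Concat()
    # Without stats, the original module is returned unchanged (with a warning).
    assert replace_module(m, 'block1.concat') is m
    # With stats, the wrapper is returned.
    assert isinstance(replace_module(m, 'block1.concat', activation_stats={'min': 0}),
                      FakeQuantWrapper)
```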