From 34e5d20f181df5e9d928bc6ba51d3d17e4ce4490 Mon Sep 17 00:00:00 2001
From: Guy Jacob <guy.jacob@intel.com>
Date: Tue, 14 May 2019 11:17:38 +0300
Subject: [PATCH] Post-train quant: warn instead of exit when dynamic quant
 not supported

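The Concat, EltwiseAdd and EltwiseMult wrappers need pre-collected
activation stats; they now raise NoStatsError (a NotImplementedError
subclass) instead of ValueError, and PostTrainLinearQuantizer catches it,
logs a warning and keeps the original FP32 module instead of aborting the
whole quantization run.

Rough illustration of the new behaviour (not part of the patch). The
TinyModel below is made up for the example, and the constructor keyword
names are assumed from the current Distiller API rather than taken from
this commit:

    import torch.nn as nn
    import distiller
    from distiller.quantization import PostTrainLinearQuantizer

    class TinyModel(nn.Module):
        def __init__(self):
            super(TinyModel, self).__init__()
            self.fc = nn.Linear(8, 8)
            # Element-wise add is one of the modules that requires stats
            self.add = distiller.modules.EltwiseAdd()

        def forward(self, x):
            return self.add(x, self.fc(x))

    # No model_activation_stats passed -> dynamic quantization requested.
    # Keyword names below are assumed; check the current Distiller API.
    quantizer = PostTrainLinearQuantizer(TinyModel(), bits_activations=8,
                                         bits_parameters=8)
    # Before: a ValueError from the EltwiseAdd wrapper aborted the run.
    # Now: a warning is logged and the EltwiseAdd module stays in FP32.
    quantizer.prepare_model()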
---
 distiller/quantization/range_linear.py | 27 +++++++++++++++++---------
 1 file changed, 18 insertions(+), 9 deletions(-)

diff --git a/distiller/quantization/range_linear.py b/distiller/quantization/range_linear.py
index 6a24304..67cd85c 100644
--- a/distiller/quantization/range_linear.py
+++ b/distiller/quantization/range_linear.py
@@ -511,6 +511,10 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper):
         return tmpstr
 
 
+class NoStatsError(NotImplementedError):
+    pass
+
+
 class RangeLinearQuantConcatWrapper(RangeLinearQuantWrapper):
     def __init__(self, wrapped_module, num_bits_acts, mode=LinearQuantMode.SYMMETRIC, clip_acts=ClipMode.NONE,
                  activation_stats=None, clip_n_stds=None):
@@ -518,8 +522,8 @@ class RangeLinearQuantConcatWrapper(RangeLinearQuantWrapper):
             raise ValueError(self.__class__.__name__ + ' can only wrap distiller.modules.Concat modules')
 
         if not activation_stats:
-            raise ValueError(self.__class__.__name__ +
-                             ' must get activation stats, dynamic quantization not supported')
+            raise NoStatsError(self.__class__.__name__ +
+                               ' must get activation stats, dynamic quantization not supported')
 
         super(RangeLinearQuantConcatWrapper, self).__init__(wrapped_module, num_bits_acts, mode=mode,
                                                             clip_acts=clip_acts, activation_stats=activation_stats,
@@ -564,8 +568,8 @@ class RangeLinearQuantEltwiseAddWrapper(RangeLinearQuantWrapper):
             raise ValueError(self.__class__.__name__ + ' can only wrap distiller.modules.EltwiseAdd modules')
 
         if not activation_stats:
-            raise ValueError(self.__class__.__name__ +
-                             ' must get activation stats, dynamic quantization not supported')
+            raise NoStatsError(self.__class__.__name__ +
+                               ' must get activation stats, dynamic quantization not supported')
 
         super(RangeLinearQuantEltwiseAddWrapper, self).__init__(wrapped_module, num_bits_acts, mode=mode,
                                                                 clip_acts=clip_acts, activation_stats=activation_stats,
@@ -612,8 +616,8 @@ class RangeLinearQuantEltwiseMultWrapper(RangeLinearQuantWrapper):
             raise ValueError(self.__class__.__name__ + ' can only wrap distiller.modules.EltwiseMult modules')
 
         if not activation_stats:
-            raise ValueError(self.__class__.__name__ +
-                             ' must get activation stats, dynamic quantization not supported')
+            raise NoStatsError(self.__class__.__name__ +
+                               ' must get activation stats, dynamic quantization not supported')
 
         super(RangeLinearQuantEltwiseMultWrapper, self).__init__(wrapped_module, num_bits_acts, mode=mode,
                                                                  clip_acts=clip_acts, activation_stats=activation_stats,
@@ -773,9 +777,14 @@ class PostTrainLinearQuantizer(Quantizer):
                 return FP16Wrapper(module)
             norm_name = distiller.utils.normalize_module_name(name)
             clip = self.clip_acts if norm_name not in self.no_clip_layers else ClipMode.NONE
-            return wrapper_type(module, qbits_map[name].acts, mode=mode, clip_acts=clip,
-                                activation_stats=self.model_activation_stats.get(norm_name, None),
-                                clip_n_stds=clip_n_stds)
+            try:
+                return wrapper_type(module, qbits_map[name].acts, mode=mode, clip_acts=clip,
+                                    activation_stats=self.model_activation_stats.get(norm_name, None),
+                                    clip_n_stds=clip_n_stds)
+            except NoStatsError:
+                msglogger.warning('WARNING: {0} - quantization of {1} without stats not supported. '
+                                  'Keeping the original FP32 module'.format(name, module.__class__.__name__))
+                return module
 
         def replace_embedding(module, name, qbits_map, fp16=fp16):
             if fp16:
-- 
GitLab