From 245f483ccbd2fb170641c7b9605bfb2b053d40ef Mon Sep 17 00:00:00 2001
From: Guy Jacob <guy.jacob@intel.com>
Date: Tue, 28 Apr 2020 14:35:30 +0300
Subject: [PATCH] PTQ - Enable weight clipping in Embedding modules

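Previously, RangeLinearEmbeddingWrapper always quantized the embedding
weights without clipping (ClipMode.NONE, i.e. using the full min/max
range of the weight tensor). With this change:

* RangeLinearEmbeddingWrapper accepts clip_acts, clip_n_stds and
  clip_half_range, and forwards them to _get_quant_params_from_tensor /
  _get_quant_params_from_stats_dict when computing the weight scale and
  zero-point.
* PostTrainLinearQuantizer.replace_embedding passes the quantizer's
  clipping settings (and save_fp_weights) to the wrapper, falling back
  to ClipMode.NONE when also_clip_weights is disabled.
* set_linear_quant_param handles the clipping parameters reported by
  named_clipping() in addition to w_scale / w_zero_point.

A minimal usage sketch of the wrapper with the new arguments (the
embedding shape, bit-width and clipping values below are illustrative
only, not part of this patch):

    import torch.nn as nn
    from distiller.quantization.range_linear import (
        RangeLinearEmbeddingWrapper, LinearQuantMode, ClipMode)

    emb = nn.Embedding(1000, 128)
    # Quantize the embedding weights to 8 bits, clipping outliers at
    # 3 standard deviations instead of using the full min/max range
    q_emb = RangeLinearEmbeddingWrapper(emb, num_bits=8,
                                        mode=LinearQuantMode.SYMMETRIC,
                                        clip_acts=ClipMode.N_STD,
                                        clip_n_stds=3)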
---
 distiller/quantization/range_linear.py | 27 +++++++++++++++++++-------
 1 file changed, 20 insertions(+), 7 deletions(-)

diff --git a/distiller/quantization/range_linear.py b/distiller/quantization/range_linear.py
index 3bb9e40..e02f5ce 100644
--- a/distiller/quantization/range_linear.py
+++ b/distiller/quantization/range_linear.py
@@ -1318,7 +1318,8 @@ class FP16Wrapper(FPWrapper):
 
 
 class RangeLinearEmbeddingWrapper(nn.Module):
-    def __init__(self, wrapped_module, num_bits, mode=LinearQuantMode.SYMMETRIC, stats=None, save_fp_weights=False):
+    def __init__(self, wrapped_module, num_bits, mode=LinearQuantMode.SYMMETRIC, stats=None, save_fp_weights=False,
+                 clip_acts=ClipMode.NONE, clip_n_stds=None, clip_half_range=False):
         if not isinstance(wrapped_module, nn.Embedding):
             raise ValueError(self.__class__.__name__ + ' can only wrap torch.nn.Embedding modules')
 
@@ -1327,7 +1328,8 @@ class RangeLinearEmbeddingWrapper(nn.Module):
         mode = verify_quant_mode(mode)
         self.mode = mode
 
-        self.wts_quant_settings = QuantSettings(num_bits, self.mode.weights, ClipMode.NONE, None, False, False)
+        self.wts_quant_settings = QuantSettings(num_bits, self.mode.weights, clip_acts, clip_n_stds, clip_half_range,
+                                                False)
 
         self.params_min_q_val, self.params_max_q_val = get_quantized_range(
             self.wts_quant_settings.num_bits,
@@ -1339,9 +1341,13 @@ class RangeLinearEmbeddingWrapper(nn.Module):
             wrapped_module.register_buffer('float_weight', wrapped_module.weight.clone().detach())
 
         if stats is None:
-            w_scale, w_zero_point = _get_quant_params_from_tensor(wrapped_module.weight, num_bits, mode.weights)
+            w_scale, w_zero_point = _get_quant_params_from_tensor(wrapped_module.weight, num_bits, mode.weights,
+                                                                  clip=clip_acts, num_stds=clip_n_stds,
+                                                                  half_range=clip_half_range)
         else:
-            w_scale, w_zero_point = _get_quant_params_from_stats_dict(stats['output'], num_bits, mode.weights)
+            w_scale, w_zero_point = _get_quant_params_from_stats_dict(stats['output'], num_bits, mode.weights,
+                                                                      clip=clip_acts, num_stds=clip_n_stds,
+                                                                      half_range=clip_half_range)
 
         device = wrapped_module.weight.device
         self.register_buffer('w_scale', w_scale.to(device))
@@ -1358,7 +1364,9 @@ class RangeLinearEmbeddingWrapper(nn.Module):
             yield 'w_zero_point', self.w_zero_point
 
     def set_linear_quant_param(self, name, val):
-        if name in ['w_scale', 'w_zero_point']:
+        if name in dict(self.named_clipping()):
+            setattr(self, name, val)
+        elif name in ['w_scale', 'w_zero_point']:
             if self.save_fp_weights:
                 getattr(self, name).fill_(val)
                 self.wrapped_module.weight.data.copy_(self.wrapped_module.float_weight.data)
@@ -1714,13 +1722,18 @@ class PostTrainLinearQuantizer(Quantizer):
                               'Keeping the original FP32 module'.format(name, module.__class__.__name__), UserWarning)
                 return module
 
-        def replace_embedding(module, name, qbits_map, fp16=fp16, fpq_module=fpq_module):
+        def replace_embedding(module, name, qbits_map, fp16=fp16, fpq_module=fpq_module, clip_acts=clip_acts,
+                              clip_n_stds=clip_n_stds, clip_half_range=clip_half_range):
             fpq_module = _check_fp16_arg(fp16, fpq_module)
             if fpq_module:
                 return FPWrapper(module, fpq_module, convert_input=False)
             norm_name = distiller.utils.normalize_module_name(name)
+            if not self.also_clip_weights:
+                clip_acts, clip_n_stds, clip_half_range = ClipMode.NONE, None, False
             return RangeLinearEmbeddingWrapper(module, qbits_map[name].wts, mode=mode,
-                                               stats=self.model_activation_stats.get(norm_name, None))
+                                               stats=self.model_activation_stats.get(norm_name, None),
+                                               save_fp_weights=self.save_fp_weights, clip_acts=clip_acts,
+                                               clip_n_stds=clip_n_stds, clip_half_range=clip_half_range)
 
         def replace_fake_quant(module, name, qbits_map, fp16=fp16,
                                clip_acts=None, clip_n_stds=clip_n_stds, clip_half_range=clip_half_range,
-- 
GitLab