From 98b5469593a52b925623e5204120255605cf488f Mon Sep 17 00:00:00 2001
From: Guy Jacob <guy.jacob@intel.com>
Date: Sun, 31 Mar 2019 16:35:27 +0300
Subject: [PATCH] Post train quant - Add wrapper for embedding layer

---
 distiller/quantization/range_linear.py | 35 ++++++++++++++++++++++++++
 1 file changed, 35 insertions(+)

diff --git a/distiller/quantization/range_linear.py b/distiller/quantization/range_linear.py
index 234bfe8..2445a57 100644
--- a/distiller/quantization/range_linear.py
+++ b/distiller/quantization/range_linear.py
@@ -570,6 +570,34 @@ class RangeLinearQuantEltwiseMultWrapper(RangeLinearQuantWrapper):
         return output_scale / self.accum_scale, output_zero_point
 
 
+class RangeLinearEmbeddingWrapper(nn.Module):
+    def __init__(self, wrapped_module, num_bits, mode=LinearQuantMode.SYMMETRIC, stats=None):
+        if not isinstance(wrapped_module, nn.Embedding):
+            raise ValueError(self.__class__.__name__ + ' can only wrap torch.nn.Embedding modules')
+
+        super(RangeLinearEmbeddingWrapper, self).__init__()
+
+        self.min_q_val, self.max_q_val = get_quantized_range(num_bits,
+                                                             signed=mode != LinearQuantMode.ASYMMETRIC_UNSIGNED)
+
+        if stats is None:
+            w_scale, w_zero_point = _get_quant_params_from_tensor(wrapped_module.weight, num_bits, mode)
+        else:
+            w_scale, w_zero_point = _get_quant_params_from_stats_dict(stats['output'], num_bits, mode)
+
+        self.register_buffer('w_scale', w_scale)
+        self.register_buffer('w_zero_point', w_zero_point)
+        linear_quantize_clamp(wrapped_module.weight.data, self.w_scale, self.w_zero_point, self.min_q_val,
+                              self.max_q_val, inplace=True)
+
+        self.wrapped_module = wrapped_module
+
+    def forward(self, input):
+        out_q = self.wrapped_module(input)
+        out_f = linear_dequantize(out_q, self.w_scale, self.w_zero_point, inplace=True)
+        return out_f
+
+
 class PostTrainLinearQuantizer(Quantizer):
     """
     Applies range-based linear quantization to a model.
@@ -628,6 +656,12 @@ class PostTrainLinearQuantizer(Quantizer):
             return wrapper_type(module, qbits_map[name].acts, mode=mode, clip_acts=clip,
                                 activation_stats=self.model_activation_stats.get(norm_name, None))
 
+        def replace_embedding(module, name, qbits_map):
+            norm_name = distiller.utils.normalize_module_name(name)
+            return RangeLinearEmbeddingWrapper(module, qbits_map[name].wts, mode=mode,
+                                               stats=self.model_activation_stats.get(norm_name, None))
+
+
         self.clip_acts = clip_acts
         self.no_clip_layers = [] or no_clip_layers
         self.model_activation_stats = model_activation_stats or {}
@@ -641,6 +675,7 @@ class PostTrainLinearQuantizer(Quantizer):
             replace_non_param_layer, RangeLinearQuantEltwiseAddWrapper)
         self.replacement_factory[distiller.modules.EltwiseMult] = partial(
             replace_non_param_layer, RangeLinearQuantEltwiseMultWrapper)
+        self.replacement_factory[nn.Embedding] = replace_embedding
 
     @classmethod
     def from_args(cls, model, args):
--
GitLab
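
Usage note (not part of the patch): a minimal sketch of how the new RangeLinearEmbeddingWrapper can be exercised on its own, assuming the distiller revision this patch applies to, with LinearQuantMode and the wrapper both importable from distiller.quantization.range_linear. With stats=None the scale and zero-point are derived from the embedding weights themselves; the weights are quantized in place and the looked-up rows are dequantized on the way out. The embedding dimensions and indices below are made up for illustration.

import torch
import torch.nn as nn
from distiller.quantization.range_linear import LinearQuantMode, RangeLinearEmbeddingWrapper

# Hypothetical embedding table, used only for this example.
emb = nn.Embedding(num_embeddings=1000, embedding_dim=128)

# 8-bit symmetric post-training quantization of the embedding weights;
# with stats=None the scale/zero-point come straight from the weight tensor.
q_emb = RangeLinearEmbeddingWrapper(emb, num_bits=8, mode=LinearQuantMode.SYMMETRIC)

# The forward pass looks up rows in the quantized table and dequantizes them,
# so downstream float layers see ordinary float activations.
indices = torch.tensor([[1, 2, 4, 5], [4, 3, 9, 0]])
out = q_emb(indices)
print(out.shape)  # torch.Size([2, 4, 128])

When the full PostTrainLinearQuantizer flow is used instead, the replacement_factory entry registered in this patch applies the same wrapping automatically to every nn.Embedding in the model.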