From 98b5469593a52b925623e5204120255605cf488f Mon Sep 17 00:00:00 2001
From: Guy Jacob <guy.jacob@intel.com>
Date: Sun, 31 Mar 2019 16:35:27 +0300
Subject: [PATCH] Post train quant - Add wrapper for embedding layer

---
 distiller/quantization/range_linear.py | 35 ++++++++++++++++++++++++++
 1 file changed, 35 insertions(+)

diff --git a/distiller/quantization/range_linear.py b/distiller/quantization/range_linear.py
index 234bfe8..2445a57 100644
--- a/distiller/quantization/range_linear.py
+++ b/distiller/quantization/range_linear.py
@@ -570,6 +570,34 @@ class RangeLinearQuantEltwiseMultWrapper(RangeLinearQuantWrapper):
         return output_scale / self.accum_scale, output_zero_point
 
 
+class RangeLinearEmbeddingWrapper(nn.Module):
+    def __init__(self, wrapped_module, num_bits, mode=LinearQuantMode.SYMMETRIC, stats=None):
+        if not isinstance(wrapped_module, nn.Embedding):
+            raise ValueError(self.__class__.__name__ + ' can only wrap torch.nn.Embedding modules')
+
+        super(RangeLinearEmbeddingWrapper, self).__init__()
+
+        self.min_q_val, self.max_q_val = get_quantized_range(num_bits,
+                                                             signed=mode != LinearQuantMode.ASYMMETRIC_UNSIGNED)
+
+        if stats is None:
+            w_scale, w_zero_point = _get_quant_params_from_tensor(wrapped_module.weight, num_bits, mode)
+        else:
+            w_scale, w_zero_point = _get_quant_params_from_stats_dict(stats['output'], num_bits, mode)
+
+        self.register_buffer('w_scale', w_scale)
+        self.register_buffer('w_zero_point', w_zero_point)
+        linear_quantize_clamp(wrapped_module.weight.data, self.w_scale, self.w_zero_point, self.min_q_val,
+                              self.max_q_val, inplace=True)
+
+        self.wrapped_module = wrapped_module
+
+    def forward(self, input):
+        out_q = self.wrapped_module(input)
+        out_f = linear_dequantize(out_q, self.w_scale, self.w_zero_point, inplace=True)
+        return out_f
+
+
 class PostTrainLinearQuantizer(Quantizer):
     """
     Applies range-based linear quantization to a model.
@@ -628,6 +656,12 @@ class PostTrainLinearQuantizer(Quantizer):
             return wrapper_type(module, qbits_map[name].acts, mode=mode, clip_acts=clip,
                                 activation_stats=self.model_activation_stats.get(norm_name, None))
 
+        def replace_embedding(module, name, qbits_map):
+            norm_name = distiller.utils.normalize_module_name(name)
+            return RangeLinearEmbeddingWrapper(module, qbits_map[name].wts, mode=mode,
+                                               stats=self.model_activation_stats.get(norm_name, None))
+
+
         self.clip_acts = clip_acts
         self.no_clip_layers = [] or no_clip_layers
         self.model_activation_stats = model_activation_stats or {}
@@ -641,6 +675,7 @@ class PostTrainLinearQuantizer(Quantizer):
             replace_non_param_layer, RangeLinearQuantEltwiseAddWrapper)
         self.replacement_factory[distiller.modules.EltwiseMult] = partial(
             replace_non_param_layer, RangeLinearQuantEltwiseMultWrapper)
+        self.replacement_factory[nn.Embedding] = replace_embedding
 
     @classmethod
     def from_args(cls, model, args):
--
GitLab
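
Usage note (not part of the patch): a minimal sketch of how the new RangeLinearEmbeddingWrapper can be exercised on its own, assuming the distiller revision this patch applies to, with LinearQuantMode and the wrapper both importable from distiller.quantization.range_linear. With stats=None the scale and zero-point are derived from the embedding weights themselves; the weights are quantized in place and the looked-up rows are dequantized on the way out. The embedding dimensions and indices below are made up for illustration.

import torch
import torch.nn as nn
from distiller.quantization.range_linear import LinearQuantMode, RangeLinearEmbeddingWrapper

# Hypothetical embedding table, used only for this example.
emb = nn.Embedding(num_embeddings=1000, embedding_dim=128)

# 8-bit symmetric post-training quantization of the embedding weights;
# with stats=None the scale/zero-point come straight from the weight tensor.
q_emb = RangeLinearEmbeddingWrapper(emb, num_bits=8, mode=LinearQuantMode.SYMMETRIC)

# The forward pass looks up rows in the quantized table and dequantizes them,
# so downstream float layers see ordinary float activations.
indices = torch.tensor([[1, 2, 4, 5], [4, 3, 9, 0]])
out = q_emb(indices)
print(out.shape)  # torch.Size([2, 4, 128])

When the full PostTrainLinearQuantizer flow is used instead, the replacement_factory entry registered in this patch applies the same wrapping automatically to every nn.Embedding in the model.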