From 245f483ccbd2fb170641c7b9605bfb2b053d40ef Mon Sep 17 00:00:00 2001
From: Guy Jacob <guy.jacob@intel.com>
Date: Tue, 28 Apr 2020 14:35:30 +0300
Subject: [PATCH] PTQ - Enable weights clipping in Embedding modules

---
 distiller/quantization/range_linear.py | 27 +++++++++++++++++++-------
 1 file changed, 20 insertions(+), 7 deletions(-)

diff --git a/distiller/quantization/range_linear.py b/distiller/quantization/range_linear.py
index 3bb9e40..e02f5ce 100644
--- a/distiller/quantization/range_linear.py
+++ b/distiller/quantization/range_linear.py
@@ -1318,7 +1318,8 @@ class FP16Wrapper(FPWrapper):
 
 
 class RangeLinearEmbeddingWrapper(nn.Module):
-    def __init__(self, wrapped_module, num_bits, mode=LinearQuantMode.SYMMETRIC, stats=None, save_fp_weights=False):
+    def __init__(self, wrapped_module, num_bits, mode=LinearQuantMode.SYMMETRIC, stats=None, save_fp_weights=False,
+                 clip_acts=ClipMode.NONE, clip_n_stds=None, clip_half_range=False):
         if not isinstance(wrapped_module, nn.Embedding):
             raise ValueError(self.__class__.__name__ + ' can only wrap torch.nn.Embedding modules')
 
@@ -1327,7 +1328,8 @@ class RangeLinearEmbeddingWrapper(nn.Module):
         mode = verify_quant_mode(mode)
         self.mode = mode
 
-        self.wts_quant_settings = QuantSettings(num_bits, self.mode.weights, ClipMode.NONE, None, False, False)
+        self.wts_quant_settings = QuantSettings(num_bits, self.mode.weights, clip_acts, clip_n_stds, clip_half_range,
+                                                False)
 
         self.params_min_q_val, self.params_max_q_val = get_quantized_range(
             self.wts_quant_settings.num_bits,
@@ -1339,9 +1341,13 @@ class RangeLinearEmbeddingWrapper(nn.Module):
             wrapped_module.register_buffer('float_weight', wrapped_module.weight.clone().detach())
 
         if stats is None:
-            w_scale, w_zero_point = _get_quant_params_from_tensor(wrapped_module.weight, num_bits, mode.weights)
+            w_scale, w_zero_point = _get_quant_params_from_tensor(wrapped_module.weight, num_bits, mode.weights,
+                                                                  clip=clip_acts, num_stds=clip_n_stds,
+                                                                  half_range=clip_half_range)
         else:
-            w_scale, w_zero_point = _get_quant_params_from_stats_dict(stats['output'], num_bits, mode.weights)
+            w_scale, w_zero_point = _get_quant_params_from_stats_dict(stats['output'], num_bits, mode.weights,
+                                                                      clip=clip_acts, num_stds=clip_n_stds,
+                                                                      half_range=clip_half_range)
 
         device = wrapped_module.weight.device
         self.register_buffer('w_scale', w_scale.to(device))
@@ -1358,7 +1364,9 @@ class RangeLinearEmbeddingWrapper(nn.Module):
         yield 'w_zero_point', self.w_zero_point
 
     def set_linear_quant_param(self, name, val):
-        if name in ['w_scale', 'w_zero_point']:
+        if name in dict(self.named_clipping()):
+            setattr(self, name, val)
+        elif name in ['w_scale', 'w_zero_point']:
             if self.save_fp_weights:
                 getattr(self, name).fill_(val)
                 self.wrapped_module.weight.data.copy_(self.wrapped_module.float_weight.data)
@@ -1714,13 +1722,18 @@ class PostTrainLinearQuantizer(Quantizer):
                               'Keeping the original FP32 module'.format(name, module.__class__.__name__),
                               UserWarning)
                 return module
-        def replace_embedding(module, name, qbits_map, fp16=fp16, fpq_module=fpq_module):
+        def replace_embedding(module, name, qbits_map, fp16=fp16, fpq_module=fpq_module, clip_acts=clip_acts,
+                              clip_n_stds=clip_n_stds, clip_half_range=clip_half_range):
             fpq_module = _check_fp16_arg(fp16, fpq_module)
             if fpq_module:
                 return FPWrapper(module, fpq_module, convert_input=False)
             norm_name = distiller.utils.normalize_module_name(name)
+            if not self.also_clip_weights:
+                clip_acts, clip_n_stds, clip_half_range = ClipMode.NONE, None, False
             return RangeLinearEmbeddingWrapper(module, qbits_map[name].wts, mode=mode,
-                                               stats=self.model_activation_stats.get(norm_name, None))
+                                               stats=self.model_activation_stats.get(norm_name, None),
+                                               save_fp_weights=self.save_fp_weights, clip_acts=clip_acts,
+                                               clip_n_stds=clip_n_stds, clip_half_range=clip_half_range)
 
         def replace_fake_quant(module, name, qbits_map, fp16=fp16, clip_acts=None,
                                clip_n_stds=clip_n_stds, clip_half_range=clip_half_range,
-- 
GitLab
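
Usage sketch (not part of the commit above): the snippet below illustrates how the clipping arguments introduced by this patch could be passed when constructing a RangeLinearEmbeddingWrapper directly, rather than through PostTrainLinearQuantizer. It assumes the distiller APIs referenced in the diff (RangeLinearEmbeddingWrapper, LinearQuantMode, ClipMode); the import path and the ClipMode.N_STD member are assumptions, not verified against this revision.

    # Minimal sketch, assuming the distiller APIs referenced in the patch above.
    # The import path and ClipMode.N_STD are assumptions, not part of the commit.
    import torch.nn as nn
    from distiller.quantization.range_linear import (
        RangeLinearEmbeddingWrapper, LinearQuantMode, ClipMode)

    embedding = nn.Embedding(num_embeddings=1000, embedding_dim=128)

    # With this patch, weight clipping is no longer hard-coded to ClipMode.NONE
    # and can be configured on the embedding wrapper:
    q_embedding = RangeLinearEmbeddingWrapper(
        embedding, num_bits=8, mode=LinearQuantMode.SYMMETRIC,
        clip_acts=ClipMode.N_STD,   # clip weights at +/- clip_n_stds standard deviations
        clip_n_stds=3.0,
        clip_half_range=False)

When going through PostTrainLinearQuantizer instead, note that replace_embedding resets the clipping settings to ClipMode.NONE unless the quantizer's also_clip_weights flag is set, as the patched code above shows.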