diff --git a/distiller/quantization/ptq_greedy_search.py b/distiller/quantization/ptq_greedy_search.py
index 2b5d96f381955872eec7db1dceba57a43a1d9bea..2d9d8a5b252f380375acda7abf0b4ddd9ef2894d 100644
--- a/distiller/quantization/ptq_greedy_search.py
+++ b/distiller/quantization/ptq_greedy_search.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2018 Intel Corporation
+# Copyright (c) 2019 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/distiller/quantization/range_linear.py b/distiller/quantization/range_linear.py
index 37101a8c20a623b686a9c3e528df4f3588f621b1..5a7c89bbd1b5c0a4dcc8b041c1858fa90b3c811d 100644
--- a/distiller/quantization/range_linear.py
+++ b/distiller/quantization/range_linear.py
@@ -370,6 +370,10 @@ class RangeLinearQuantWrapper(nn.Module):
 
         self.register_buffer('num_forwards', torch.zeros(1, dtype=torch.long))
 
+    def named_acts_quant_params(self):
+        yield 'output_scale', self.output_scale
+        yield 'output_zero_point', self.output_zero_point
+
     def forward(self, *inputs):
         if self.training:
             raise RuntimeError(self.__class__.__name__ + " can only be used in eval mode")
@@ -619,7 +623,7 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper):
         if self.is_simulated_quant_weight_shifted:
             # We want to return the weights to their integer representation:
             self.wrapped_module.weight.data -= self.w_zero_point
-            self.is_simulated_quant_weight_shifted.sub_(1) # i.e. is_simulated_quant_weight_shifted = False
+            self.is_simulated_quant_weight_shifted.fill_(False) # i.e. is_simulated_quant_weight_shifted = False
         return super(RangeLinearQuantParamLayerWrapper, self).state_dict(destination, prefix, keep_vars)
 
     def quantized_forward(self, input_q):
@@ -664,7 +668,7 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper):
             # We "store" the w_zero_point inside our wrapped module's weights to
             # improve performance on inference.
             self.wrapped_module.weight.data += self.w_zero_point
-            self.is_simulated_quant_weight_shifted.add_(1) # i.e. is_simulated_quant_weight_shifted = True
+            self.is_simulated_quant_weight_shifted.fill_(True) # i.e. is_simulated_quant_weight_shifted = True
 
         input_q += input_q.quant_metadata.zero_point
         accum = self.wrapped_module.forward(input_q)
@@ -943,6 +947,10 @@ class RangeLinearEmbeddingWrapper(nn.Module):
         self.quant_metadata = TensorQuantMetadata(self.w_scale, self.w_zero_point, self.min_q_val, self.max_q_val)
         self.wrapped_module = wrapped_module
 
+    def named_acts_quant_params(self):
+        yield 'w_scale', self.w_scale
+        yield 'w_zero_point', self.w_zero_point
+
     def forward(self, input):
         out_q = self.wrapped_module(input)
         out_f = linear_dequantize(out_q, self.w_scale, self.w_zero_point, inplace=True)
@@ -983,6 +991,10 @@ class RangeLinearFakeQuantWrapper(RangeLinearQuantWrapper):
         return output_scale, output_zero_point
 
 
+def _is_range_linear_wrapper(module):
+    return isinstance(module, (RangeLinearEmbeddingWrapper, RangeLinearQuantWrapper))
+
+
 class PostTrainLinearQuantizer(Quantizer):
     """
     Applies range-based linear quantization to a model.
@@ -1215,6 +1227,32 @@ class PostTrainLinearQuantizer(Quantizer):
             save_dir = msglogger.logdir if hasattr(msglogger, 'logdir') else '.'
             self.save_per_layer_parameters(save_dir)
 
+    def named_acts_quant_params(self):
+        for module_name, module in self.model.named_modules():
+            if _is_range_linear_wrapper(module):
+                for buff_name, buff in module.named_acts_quant_params():
+                    full_buff_name = "%s.%s" % (module_name, buff_name)
+                    yield full_buff_name, buff
+
+    def set_act_quant_param(self, name, val):
+        """
+        Sets the quant parameter by module_name.quant_param_name.
+        Args:
+            name (str): the name of the quant param [module_name].[quant_param_name]
+            val (int or float or torch.Tensor): the new value.
+        """
+        self.acts_quant_params[name].fill_(val)
+
+    def update_acts_quant_params(self, new_config):
+        """
+        Updates all the quant params using a dictionary.
+        Args:
+            new_config (dict): the new configuration dict.
+        """
+        for k, v in new_config.items():
+            self.set_act_quant_param(k, v)
+
+
     @classmethod
     def from_args(cls, model, args):
         """
@@ -1279,6 +1317,8 @@ class PostTrainLinearQuantizer(Quantizer):
             raise ValueError('PostTrainLinearQuantizer requires dummy input in order to perform certain optimizations')
         super(PostTrainLinearQuantizer, self).prepare_model(dummy_input)
 
+        self.acts_quant_params = OrderedDict(self.named_acts_quant_params())
+
     def _pre_prepare_model(self, dummy_input):
         if not self.has_bidi_distiller_lstm:
             self._apply_bn_folding(dummy_input)
diff --git a/tests/test_post_train_quant.py b/tests/test_post_train_quant.py
index 685c956cf660ad71898d076e222ad5df878d5602..4718b15d4d74c2734b6abc685fa9a09e0b93c87e 100644
--- a/tests/test_post_train_quant.py
+++ b/tests/test_post_train_quant.py
@@ -668,3 +668,75 @@ def test_stats_fusion_split_act(act1_type, act2_type, bn_out_stats, linear_out_e
     expected.pop('bn')  # After BN folding BN stats are removed
     expected['linear']['output'] = linear_out_expected_stats
     assert quantizer.model_activation_stats == expected
+
+
+###############################################################################
+# Test Get/Set scale & zero_point of wrappers
+###############################################################################
+@pytest.mark.parametrize(
+    'act1_type, act2_type, bn_out_stats',
+    [
+        ('relu', 'relu', stats_entry(-5., 5., -3., 3., 0., 0.5)),
+        ('relu', 'sigmoid', stats_entry(-5., 5., -3., 3., 0., 0.5)),
+        ('relu', 'tanh', stats_entry(-5., 5., -3., 3., 0., 0.5)),
+    ],
+    ids=['relu-relu', 'relu-sigmoid', 'relu-tanh']
+)
+def test_acts_quant_params_linear(act1_type, act2_type, bn_out_stats):
+    # prepare model:
+    model = LinearBNSplitAct(act1_type, act2_type)
+    stats = gen_stats_for_model(model)
+    stats['bn']['output'] = bn_out_stats
+    quantizer = PostTrainLinearQuantizer(model, model_activation_stats=deepcopy(stats))
+    quantizer.prepare_model(torch.randn(10, 10))
+    # get quant params:
+    expected_quant_params_keys = {
+        'linear.output_zero_point',
+        'linear.output_scale',
+        'act1.output_zero_point',
+        'act1.output_scale',
+        'act2.output_zero_point',
+        'act2.output_scale'
+    }
+    assert set(quantizer.acts_quant_params) == expected_quant_params_keys
+    quantizer.set_act_quant_param('linear.output_zero_point', 2.)
+    quantizer.set_act_quant_param('linear.output_scale', 30.)
+    assert model.linear.output_zero_point == 2.
+    assert model.linear.output_scale == 30.
+    expected_quant_param_linear_dict = {
+        'output_zero_point': torch.tensor(2.),
+        'output_scale': 30.
+    }
+    assert dict(model.linear.named_acts_quant_params()) == expected_quant_param_linear_dict
+    new_config = {
+        'linear.output_zero_point': 4.,
+        'act2.output_scale': 50
+    }
+    quantizer.update_acts_quant_params(new_config)
+    assert model.linear.output_zero_point == 4
+    assert model.act2.output_scale == 50
+
+
+class DummyWordLangModel(nn.Module):
+    def __init__(self, embedding, rnn):
+        super(DummyWordLangModel, self).__init__()
+        self.embedding = embedding
+        self.rnn = rnn
+
+    def forward(self, x):
+        return self.rnn(self.embedding(x))
+
+
+def test_acts_quant_params_rnn(rnn_model):
+    model = DummyWordLangModel(nn.Embedding(41, 20), rnn_model).cuda()
+    stats = gen_stats_for_model(model)
+    quantizer = PostTrainLinearQuantizer(model, model_activation_stats=deepcopy(stats))
+    dummy_input = torch.randint(0, 41, size=(79, 23))
+    quantizer.prepare_model(dummy_input)
+    new_config = {
+        'rnn.rnn.cells.0.act_o.output_scale': 4,
+        'embedding.w_scale': torch.tensor(59.0)
+    }
+    quantizer.update_acts_quant_params(new_config)
+    assert model.rnn.rnn.cells[0].act_o.output_scale == 4
+    assert model.embedding.w_scale == 59.0
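
Reviewer note: a minimal usage sketch of the activation-quant-params API introduced by this patch, assuming a wrapped model and pre-collected activation stats are already available. MyModel, stats, the dummy input shape and the 'linear.*' / 'act2.*' key names are hypothetical placeholders; the quantizer calls mirror the new tests above.

    import torch
    from distiller.quantization import PostTrainLinearQuantizer

    # Placeholder float model and activation stats (hypothetical; any model/stats
    # supported by PostTrainLinearQuantizer would do).
    model = MyModel().eval()
    quantizer = PostTrainLinearQuantizer(model, model_activation_stats=stats)
    quantizer.prepare_model(torch.randn(1, 10))  # acts_quant_params is populated here

    # Enumerate the scale / zero-point buffers of every wrapped layer,
    # keyed as "<module name>.<quant param name>".
    for name, buff in quantizer.named_acts_quant_params():
        print(name, buff)

    # Override a single quant param, or push a dict of overrides at once
    # (key names depend on the module names of the hypothetical model).
    quantizer.set_act_quant_param('linear.output_scale', 30.)
    quantizer.update_acts_quant_params({'linear.output_zero_point': 4.,
                                        'act2.output_scale': 50.})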