diff --git a/distiller/quantization/q_utils.py b/distiller/quantization/q_utils.py index 306f319689aa6994c997b4017dd555498260d323..2e3ca7b989947a4efc22be85b100c816903bbc30 100644 --- a/distiller/quantization/q_utils.py +++ b/distiller/quantization/q_utils.py @@ -17,29 +17,70 @@ import torch +def _prep_saturation_val_tensor(sat_val): + is_scalar = not isinstance(sat_val, torch.Tensor) + out = torch.tensor(sat_val) + if not out.is_floating_point(): + out = out.to(torch.float32) + if out.dim() == 0: + out = out.unsqueeze(0) + return is_scalar, out + + def symmetric_linear_quantization_params(num_bits, saturation_val): + is_scalar, sat_val = _prep_saturation_val_tensor(saturation_val) + + if any(sat_val < 0): + raise ValueError('Saturation value must be >= 0') + # Leave one bit for sign n = 2 ** (num_bits - 1) - 1 - scale = n / saturation_val - if isinstance(scale, torch.Tensor): - zero_point = torch.zeros_like(scale) - else: - zero_point = 0.0 + + # If float values are all 0, we just want the quantized values to be 0 as well. So overriding the saturation + # value to 'n', so the scale becomes 1 + sat_val[sat_val == 0] = n + scale = n / sat_val + zero_point = torch.zeros_like(scale) + + if is_scalar: + # If input was scalar, return scalars + return scale.item(), zero_point.item() return scale, zero_point def asymmetric_linear_quantization_params(num_bits, saturation_min, saturation_max, integral_zero_point=True, signed=False): + scalar_min, sat_min = _prep_saturation_val_tensor(saturation_min) + scalar_max, sat_max = _prep_saturation_val_tensor(saturation_max) + is_scalar = scalar_min and scalar_max + + if scalar_max and not scalar_min: + sat_max = sat_max.to(sat_min.device) + elif scalar_min and not scalar_max: + sat_min = sat_min.to(sat_max.device) + + if any(sat_min > sat_max): + raise ValueError('saturation_min must be smaller than saturation_max') + n = 2 ** num_bits - 1 - scale = n / (saturation_max - saturation_min) - zero_point = scale * saturation_min + + # Make sure 0 is in the range + sat_min = torch.min(sat_min, torch.zeros_like(sat_min)) + sat_max = torch.max(sat_max, torch.zeros_like(sat_max)) + + diff = sat_max - sat_min + # If float values are all 0, we just want the quantized values to be 0 as well. So overriding the saturation + # value to 'n', so the scale becomes 1 + diff[diff == 0] = n + + scale = n / diff + zero_point = scale * sat_min if integral_zero_point: - if isinstance(zero_point, torch.Tensor): - zero_point = zero_point.round() - else: - zero_point = float(round(zero_point)) + zero_point = zero_point.round() if signed: zero_point += 2 ** (num_bits - 1) + if is_scalar: + return scale.item(), zero_point.item() return scale, zero_point diff --git a/tests/test_quant_utils.py b/tests/test_quant_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f613e50dcfb4c054f8581f1fe88afa19cf53d911 --- /dev/null +++ b/tests/test_quant_utils.py @@ -0,0 +1,159 @@ +# +# Copyright (c) 2019 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import torch +import pytest +import sys +import os +module_path = os.path.abspath(os.path.join('..')) +if module_path not in sys.path: + sys.path.append(module_path) +from distiller.quantization import q_utils as qu + + +def test_symmetric_qparams(): + with pytest.raises(ValueError): + # Negative scalar + qu.symmetric_linear_quantization_params(8, -5.) + + # Negative element in tensor + qu.symmetric_linear_quantization_params(8, torch.tensor([-5., 10.])) + + # Scalar positive + scale, zp = qu.symmetric_linear_quantization_params(8, 4.) + assert not isinstance(scale, torch.Tensor) + assert not isinstance(zp, torch.Tensor) + assert scale == 31.75 + assert zp == 0 + + # Scalar positive integer + scale, zp = qu.symmetric_linear_quantization_params(8, 4) + assert not isinstance(scale, torch.Tensor) + assert not isinstance(zp, torch.Tensor) + assert scale == 31.75 + assert zp == 0 + + # Scalar zero + scale, zp = qu.symmetric_linear_quantization_params(8, 0.) + assert scale == 1 + assert zp == 0 + + # Tensor positives + sat = torch.tensor([4., 10.]) + scale, zp = qu.symmetric_linear_quantization_params(8, sat) + assert torch.equal(scale, torch.tensor([31.75, 12.7])) + assert torch.equal(zp, torch.zeros_like(sat)) + + # Tensor positives - integer saturation values + sat = torch.tensor([4, 10]) + scale, zp = qu.symmetric_linear_quantization_params(8, sat) + assert torch.equal(scale, torch.tensor([31.75, 12.7])) + assert torch.equal(zp, torch.zeros_like(sat, dtype=torch.float32)) + + # Tensor with 0 + sat = torch.tensor([4., 0.]) + scale, zp = qu.symmetric_linear_quantization_params(8, sat) + assert torch.equal(scale, torch.tensor([31.75, 1.])) + assert torch.equal(zp, torch.zeros_like(sat)) + + +def test_asymmetric_qparams(): + with pytest.raises(ValueError): + # Test min > max + # min scalar, max scalar + qu.asymmetric_linear_quantization_params(8, 5., 4.) + # min scalar, max tensor + qu.asymmetric_linear_quantization_params(8, 5., torch.tensor([5., 3.])) + # min tensor, max scalar + qu.asymmetric_linear_quantization_params(8, torch.tensor([5., 3.]), 4.) + # min tensor, max tensor + qu.asymmetric_linear_quantization_params(8, torch.tensor([5., 3.]), torch.tensor([4., 7.])) + + # min scalar, max scalar + + # Min negative, max positive + scale, zp = qu.asymmetric_linear_quantization_params(8, -2., 10., integral_zero_point=True, signed=False) + assert not isinstance(scale, torch.Tensor) + assert not isinstance(zp, torch.Tensor) + assert scale == 21.25 + assert zp == -42 + + scale, zp = qu.asymmetric_linear_quantization_params(8, -2., 10., integral_zero_point=True, signed=True) + assert scale == 21.25 + assert zp == 86 + + scale, zp = qu.asymmetric_linear_quantization_params(8, -2., 10., integral_zero_point=False, signed=False) + assert scale == 21.25 + assert zp == -42.5 + + scale, zp = qu.asymmetric_linear_quantization_params(8, -2., 10., integral_zero_point=False, signed=True) + assert scale == 21.25 + assert zp == 85.5 + + # Integer saturation values + scale, zp = qu.asymmetric_linear_quantization_params(8, -2, 10, integral_zero_point=False, signed=True) + assert scale == 21.25 + assert zp == 85.5 + + # Both positive + scale, zp = qu.asymmetric_linear_quantization_params(8, 5., 10.) + assert scale == 25.5 + assert zp == 0 + + # Both negative + scale, zp = qu.asymmetric_linear_quantization_params(8, -10., -5.) + assert scale == 25.5 + assert zp == -255 + + # Both zero + scale, zp = qu.asymmetric_linear_quantization_params(8, 0., 0.) + assert scale == 1. + assert zp == 0 + + # min scalar, max tensor + scale, zp = qu.asymmetric_linear_quantization_params(8, -10., torch.tensor([-2., 5.])) + assert torch.equal(scale,torch.tensor([25.5, 17])) + assert torch.equal(zp, torch.tensor([-255., -170])) + + scale, zp = qu.asymmetric_linear_quantization_params(8, 0., torch.tensor([0., 5.])) + assert torch.equal(scale, torch.tensor([1., 51.])) + assert torch.equal(zp, torch.tensor([0., 0.])) + + # Integer saturation values + scale, zp = qu.asymmetric_linear_quantization_params(8, -10., torch.tensor([-2, 5])) + assert torch.equal(scale, torch.tensor([25.5, 17])) + assert torch.equal(zp, torch.tensor([-255., -170])) + + # min tensor, max scalar + scale, zp = qu.asymmetric_linear_quantization_params(8, torch.tensor([-2., 5.]), 10.) + assert torch.equal(scale, torch.tensor([21.25, 25.5])) + assert torch.equal(zp, torch.tensor([-42., 0.])) + + scale, zp = qu.asymmetric_linear_quantization_params(8, torch.tensor([0., -5.]), 0.) + assert torch.equal(scale, torch.tensor([1., 51.])) + assert torch.equal(zp, torch.tensor([0., -255.])) + + # Integer saturation values + scale, zp = qu.asymmetric_linear_quantization_params(8, torch.tensor([-2, 5]), 10.) + assert torch.equal(scale, torch.tensor([21.25, 25.5])) + assert torch.equal(zp, torch.tensor([-42., 0.])) + + # min tensor, max tensor + scale, zp = qu.asymmetric_linear_quantization_params(8, + torch.tensor([-2., 5., -10., 0.]), + torch.tensor([10., 10., -5., 0.])) + assert torch.equal(scale, torch.tensor([21.25, 25.5, 25.5, 1.])) + assert torch.equal(zp, torch.tensor([-42., 0., -255., 0.]))