Commit 10ce938c authored by Guy Jacob

Updates to quantization parameters calculation:

* Always include 0 in the range
* Handle case where tensor is zeros only (fixes issue #115)
* Add unit tests
parent cfbc3798
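The two behavior changes described above can be illustrated with a short sketch. This is not part of the commit; it only assumes the distiller.quantization.q_utils import path used by the new tests below, and the expected values follow from the formulas in the diff:

import torch
from distiller.quantization import q_utils as qu

# An all-zeros saturation value now yields scale == 1 and zero_point == 0
# instead of an infinite scale (previously n / 0).
scale, zp = qu.symmetric_linear_quantization_params(8, torch.tensor([0., 0.]))

# With both saturation values positive, 0 is now pulled into the range, so the
# effective minimum is clamped to 0: scale == 25.5, zero_point == 0.
scale, zp = qu.asymmetric_linear_quantization_params(8, 5., 10.)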
@@ -17,29 +17,70 @@
 import torch
 
 
+def _prep_saturation_val_tensor(sat_val):
+    is_scalar = not isinstance(sat_val, torch.Tensor)
+    out = torch.tensor(sat_val)
+    if not out.is_floating_point():
+        out = out.to(torch.float32)
+    if out.dim() == 0:
+        out = out.unsqueeze(0)
+    return is_scalar, out
+
+
 def symmetric_linear_quantization_params(num_bits, saturation_val):
+    is_scalar, sat_val = _prep_saturation_val_tensor(saturation_val)
+
+    if any(sat_val < 0):
+        raise ValueError('Saturation value must be >= 0')
+
     # Leave one bit for sign
     n = 2 ** (num_bits - 1) - 1
-    scale = n / saturation_val
-    if isinstance(scale, torch.Tensor):
-        zero_point = torch.zeros_like(scale)
-    else:
-        zero_point = 0.0
+
+    # If float values are all 0, we just want the quantized values to be 0 as well. So overriding the saturation
+    # value to 'n', so the scale becomes 1
+    sat_val[sat_val == 0] = n
+    scale = n / sat_val
+    zero_point = torch.zeros_like(scale)
+
+    if is_scalar:
+        # If input was scalar, return scalars
+        return scale.item(), zero_point.item()
     return scale, zero_point
 
 
 def asymmetric_linear_quantization_params(num_bits, saturation_min, saturation_max,
                                           integral_zero_point=True, signed=False):
+    scalar_min, sat_min = _prep_saturation_val_tensor(saturation_min)
+    scalar_max, sat_max = _prep_saturation_val_tensor(saturation_max)
+    is_scalar = scalar_min and scalar_max
+
+    if scalar_max and not scalar_min:
+        sat_max = sat_max.to(sat_min.device)
+    elif scalar_min and not scalar_max:
+        sat_min = sat_min.to(sat_max.device)
+
+    if any(sat_min > sat_max):
+        raise ValueError('saturation_min must be smaller than saturation_max')
+
     n = 2 ** num_bits - 1
-    scale = n / (saturation_max - saturation_min)
-    zero_point = scale * saturation_min
+
+    # Make sure 0 is in the range
+    sat_min = torch.min(sat_min, torch.zeros_like(sat_min))
+    sat_max = torch.max(sat_max, torch.zeros_like(sat_max))
+
+    diff = sat_max - sat_min
+    # If float values are all 0, we just want the quantized values to be 0 as well. So overriding the saturation
+    # value to 'n', so the scale becomes 1
+    diff[diff == 0] = n
+
+    scale = n / diff
+    zero_point = scale * sat_min
     if integral_zero_point:
-        if isinstance(zero_point, torch.Tensor):
-            zero_point = zero_point.round()
-        else:
-            zero_point = float(round(zero_point))
+        zero_point = zero_point.round()
     if signed:
         zero_point += 2 ** (num_bits - 1)
+    if is_scalar:
+        return scale.item(), zero_point.item()
     return scale, zero_point
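To make the roles of the returned parameters concrete, here is an illustrative sketch. The quantize/dequantize helpers are hypothetical and not part of this commit; they apply one common convention, q = round(scale * x - zero_point), which is consistent with the zero_point = scale * sat_min computation above:

import torch
from distiller.quantization import q_utils as qu

def quantize(x, scale, zero_point):
    # Map float values onto the integer grid [0, 2**num_bits - 1]
    return torch.round(scale * x - zero_point)

def dequantize(q, scale, zero_point):
    # Approximate inverse of quantize()
    return (q + zero_point) / scale

scale, zp = qu.asymmetric_linear_quantization_params(8, -2., 10.)  # scale == 21.25, zp == -42
x = torch.tensor([-2., 0., 10.])
q = quantize(x, scale, zp)        # values land in [0, 255]; x == 0 maps exactly to level 42
x_hat = dequantize(q, scale, zp)  # recovers the inputs up to quantization error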
...
#
# Copyright (c) 2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
import pytest
import sys
import os
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
from distiller.quantization import q_utils as qu
def test_symmetric_qparams():
    # Negative scalar
    with pytest.raises(ValueError):
        qu.symmetric_linear_quantization_params(8, -5.)

    # Negative element in tensor
    with pytest.raises(ValueError):
        qu.symmetric_linear_quantization_params(8, torch.tensor([-5., 10.]))

    # Scalar positive
    scale, zp = qu.symmetric_linear_quantization_params(8, 4.)
    assert not isinstance(scale, torch.Tensor)
    assert not isinstance(zp, torch.Tensor)
    assert scale == 31.75
    assert zp == 0

    # Scalar positive integer
    scale, zp = qu.symmetric_linear_quantization_params(8, 4)
    assert not isinstance(scale, torch.Tensor)
    assert not isinstance(zp, torch.Tensor)
    assert scale == 31.75
    assert zp == 0

    # Scalar zero
    scale, zp = qu.symmetric_linear_quantization_params(8, 0.)
    assert scale == 1
    assert zp == 0

    # Tensor of positives
    sat = torch.tensor([4., 10.])
    scale, zp = qu.symmetric_linear_quantization_params(8, sat)
    assert torch.equal(scale, torch.tensor([31.75, 12.7]))
    assert torch.equal(zp, torch.zeros_like(sat))

    # Tensor of positives - integer saturation values
    sat = torch.tensor([4, 10])
    scale, zp = qu.symmetric_linear_quantization_params(8, sat)
    assert torch.equal(scale, torch.tensor([31.75, 12.7]))
    assert torch.equal(zp, torch.zeros_like(sat, dtype=torch.float32))

    # Tensor containing 0
    sat = torch.tensor([4., 0.])
    scale, zp = qu.symmetric_linear_quantization_params(8, sat)
    assert torch.equal(scale, torch.tensor([31.75, 1.]))
    assert torch.equal(zp, torch.zeros_like(sat))
def test_asymmetric_qparams():
    # Test min > max
    # min scalar, max scalar
    with pytest.raises(ValueError):
        qu.asymmetric_linear_quantization_params(8, 5., 4.)
    # min scalar, max tensor
    with pytest.raises(ValueError):
        qu.asymmetric_linear_quantization_params(8, 5., torch.tensor([5., 3.]))
    # min tensor, max scalar
    with pytest.raises(ValueError):
        qu.asymmetric_linear_quantization_params(8, torch.tensor([5., 3.]), 4.)
    # min tensor, max tensor
    with pytest.raises(ValueError):
        qu.asymmetric_linear_quantization_params(8, torch.tensor([5., 3.]), torch.tensor([4., 7.]))

    # min scalar, max scalar
    # Min negative, max positive
    scale, zp = qu.asymmetric_linear_quantization_params(8, -2., 10., integral_zero_point=True, signed=False)
    assert not isinstance(scale, torch.Tensor)
    assert not isinstance(zp, torch.Tensor)
    assert scale == 21.25
    assert zp == -42

    scale, zp = qu.asymmetric_linear_quantization_params(8, -2., 10., integral_zero_point=True, signed=True)
    assert scale == 21.25
    assert zp == 86

    scale, zp = qu.asymmetric_linear_quantization_params(8, -2., 10., integral_zero_point=False, signed=False)
    assert scale == 21.25
    assert zp == -42.5

    scale, zp = qu.asymmetric_linear_quantization_params(8, -2., 10., integral_zero_point=False, signed=True)
    assert scale == 21.25
    assert zp == 85.5

    # Integer saturation values
    scale, zp = qu.asymmetric_linear_quantization_params(8, -2, 10, integral_zero_point=False, signed=True)
    assert scale == 21.25
    assert zp == 85.5

    # Both positive
    scale, zp = qu.asymmetric_linear_quantization_params(8, 5., 10.)
    assert scale == 25.5
    assert zp == 0

    # Both negative
    scale, zp = qu.asymmetric_linear_quantization_params(8, -10., -5.)
    assert scale == 25.5
    assert zp == -255

    # Both zero
    scale, zp = qu.asymmetric_linear_quantization_params(8, 0., 0.)
    assert scale == 1.
    assert zp == 0

    # min scalar, max tensor
    scale, zp = qu.asymmetric_linear_quantization_params(8, -10., torch.tensor([-2., 5.]))
    assert torch.equal(scale, torch.tensor([25.5, 17.]))
    assert torch.equal(zp, torch.tensor([-255., -170.]))

    scale, zp = qu.asymmetric_linear_quantization_params(8, 0., torch.tensor([0., 5.]))
    assert torch.equal(scale, torch.tensor([1., 51.]))
    assert torch.equal(zp, torch.tensor([0., 0.]))

    # Integer saturation values
    scale, zp = qu.asymmetric_linear_quantization_params(8, -10., torch.tensor([-2, 5]))
    assert torch.equal(scale, torch.tensor([25.5, 17.]))
    assert torch.equal(zp, torch.tensor([-255., -170.]))

    # min tensor, max scalar
    scale, zp = qu.asymmetric_linear_quantization_params(8, torch.tensor([-2., 5.]), 10.)
    assert torch.equal(scale, torch.tensor([21.25, 25.5]))
    assert torch.equal(zp, torch.tensor([-42., 0.]))

    scale, zp = qu.asymmetric_linear_quantization_params(8, torch.tensor([0., -5.]), 0.)
    assert torch.equal(scale, torch.tensor([1., 51.]))
    assert torch.equal(zp, torch.tensor([0., -255.]))

    # Integer saturation values
    scale, zp = qu.asymmetric_linear_quantization_params(8, torch.tensor([-2, 5]), 10.)
    assert torch.equal(scale, torch.tensor([21.25, 25.5]))
    assert torch.equal(zp, torch.tensor([-42., 0.]))

    # min tensor, max tensor
    scale, zp = qu.asymmetric_linear_quantization_params(8,
                                                         torch.tensor([-2., 5., -10., 0.]),
                                                         torch.tensor([10., 10., -5., 0.]))
    assert torch.equal(scale, torch.tensor([21.25, 25.5, 25.5, 1.]))
    assert torch.equal(zp, torch.tensor([-42., 0., -255., 0.]))