diff --git a/distiller/quantization/q_utils.py b/distiller/quantization/q_utils.py
index 306f319689aa6994c997b4017dd555498260d323..2e3ca7b989947a4efc22be85b100c816903bbc30 100644
--- a/distiller/quantization/q_utils.py
+++ b/distiller/quantization/q_utils.py
@@ -17,29 +17,83 @@
 import torch
 
 
+def _prep_saturation_val_tensor(sat_val):
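+    """Convert a saturation value to a 1-D float tensor; also report whether the input was a scalar."""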
+    is_scalar = not isinstance(sat_val, torch.Tensor)
+    out = torch.tensor(sat_val) if is_scalar else sat_val.clone().detach()
+    if not out.is_floating_point():
+        out = out.to(torch.float32)
+    if out.dim() == 0:
+        out = out.unsqueeze(0)
+    return is_scalar, out
+
+
 def symmetric_linear_quantization_params(num_bits, saturation_val):
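+    """Compute the scale and zero-point for symmetric linear quantization.
+
+    With b = num_bits, n = 2^(b-1) - 1 (one bit is reserved for the sign),
+    so scale = n / saturation_val and the zero-point is always 0.
+    """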
+    is_scalar, sat_val = _prep_saturation_val_tensor(saturation_val)
+
+    if any(sat_val < 0):
+        raise ValueError('Saturation value must be >= 0')
+
     # Leave one bit for sign
     n = 2 ** (num_bits - 1) - 1
-    scale = n / saturation_val
-    if isinstance(scale, torch.Tensor):
-        zero_point = torch.zeros_like(scale)
-    else:
-        zero_point = 0.0
+
+    # If the float values are all 0, we want the quantized values to be 0 as well,
+    # so we override the saturation value with 'n', which makes the scale 1
+    sat_val[sat_val == 0] = n
+    scale = n / sat_val
+    zero_point = torch.zeros_like(scale)
+
+    if is_scalar:
+        # If input was scalar, return scalars
+        return scale.item(), zero_point.item()
     return scale, zero_point
 
 
 def asymmetric_linear_quantization_params(num_bits, saturation_min, saturation_max,
                                           integral_zero_point=True, signed=False):
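+    """Compute the scale and zero-point for asymmetric linear quantization.
+
+    With b = num_bits, n = 2^b - 1, so scale = n / (sat_max - sat_min) and
+    zero_point = scale * sat_min, following the convention
+    q = round(scale * x - zero_point). With signed=True, the zero-point is
+    shifted by 2^(b-1) into the signed integer range.
+    """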
+    scalar_min, sat_min = _prep_saturation_val_tensor(saturation_min)
+    scalar_max, sat_max = _prep_saturation_val_tensor(saturation_max)
+    is_scalar = scalar_min and scalar_max
+
+    if scalar_max and not scalar_min:
+        sat_max = sat_max.to(sat_min.device)
+    elif scalar_min and not scalar_max:
+        sat_min = sat_min.to(sat_max.device)
+
+    if any(sat_min > sat_max):
+        raise ValueError('saturation_min must not be greater than saturation_max')
+
     n = 2 ** num_bits - 1
-    scale = n / (saturation_max - saturation_min)
-    zero_point = scale * saturation_min
+
+    # Make sure 0 is in the range
+    sat_min = torch.min(sat_min, torch.zeros_like(sat_min))
+    sat_max = torch.max(sat_max, torch.zeros_like(sat_max))
+
+    diff = sat_max - sat_min
+    # If the float values are all 0, we want the quantized values to be 0 as well,
+    # so we override the range with 'n', which makes the scale 1
+    diff[diff == 0] = n
+
+    scale = n / diff
+    zero_point = scale * sat_min
     if integral_zero_point:
-        if isinstance(zero_point, torch.Tensor):
-            zero_point = zero_point.round()
-        else:
-            zero_point = float(round(zero_point))
+        zero_point = zero_point.round()
     if signed:
         zero_point += 2 ** (num_bits - 1)
+    if is_scalar:
+        return scale.item(), zero_point.item()
     return scale, zero_point
 
 
diff --git a/tests/test_quant_utils.py b/tests/test_quant_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f613e50dcfb4c054f8581f1fe88afa19cf53d911
--- /dev/null
+++ b/tests/test_quant_utils.py
@@ -0,0 +1,166 @@
+#
+# Copyright (c) 2019 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import torch
+import pytest
+import sys
+import os
+module_path = os.path.abspath('..')
+if module_path not in sys.path:
+    sys.path.append(module_path)
+from distiller.quantization import q_utils as qu
+
+
+def test_symmetric_qparams():
+    # Negative scalar
+    with pytest.raises(ValueError):
+        qu.symmetric_linear_quantization_params(8, -5.)
+
+    # Negative element in tensor
+    with pytest.raises(ValueError):
+        qu.symmetric_linear_quantization_params(8, torch.tensor([-5., 10.]))
+
+    # Scalar positive
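+    # n = 2**7 - 1 = 127 -> scale = 127 / 4 = 31.75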
+    scale, zp = qu.symmetric_linear_quantization_params(8, 4.)
+    assert not isinstance(scale, torch.Tensor)
+    assert not isinstance(zp, torch.Tensor)
+    assert scale == 31.75
+    assert zp == 0
+
+    # Scalar positive integer
+    scale, zp = qu.symmetric_linear_quantization_params(8, 4)
+    assert not isinstance(scale, torch.Tensor)
+    assert not isinstance(zp, torch.Tensor)
+    assert scale == 31.75
+    assert zp == 0
+
+    # Scalar zero
+    scale, zp = qu.symmetric_linear_quantization_params(8, 0.)
+    assert scale == 1
+    assert zp == 0
+
+    # Tensor positives
+    sat = torch.tensor([4., 10.])
+    scale, zp = qu.symmetric_linear_quantization_params(8, sat)
+    assert torch.equal(scale, torch.tensor([31.75, 12.7]))
+    assert torch.equal(zp, torch.zeros_like(sat))
+
+    # Tensor positives - integer saturation values
+    sat = torch.tensor([4, 10])
+    scale, zp = qu.symmetric_linear_quantization_params(8, sat)
+    assert torch.equal(scale, torch.tensor([31.75, 12.7]))
+    assert torch.equal(zp, torch.zeros_like(sat, dtype=torch.float32))
+
+    # Tensor with 0
+    sat = torch.tensor([4., 0.])
+    scale, zp = qu.symmetric_linear_quantization_params(8, sat)
+    assert torch.equal(scale, torch.tensor([31.75, 1.]))
+    assert torch.equal(zp, torch.zeros_like(sat))
+
+
+def test_asymmetric_qparams():
+    # Test min > max; each call needs its own pytest.raises block, since
+    # statements after the first raise would otherwise never execute
+    # min scalar, max scalar
+    with pytest.raises(ValueError):
+        qu.asymmetric_linear_quantization_params(8, 5., 4.)
+    # min scalar, max tensor
+    with pytest.raises(ValueError):
+        qu.asymmetric_linear_quantization_params(8, 5., torch.tensor([5., 3.]))
+    # min tensor, max scalar
+    with pytest.raises(ValueError):
+        qu.asymmetric_linear_quantization_params(8, torch.tensor([5., 3.]), 4.)
+    # min tensor, max tensor
+    with pytest.raises(ValueError):
+        qu.asymmetric_linear_quantization_params(8, torch.tensor([5., 3.]), torch.tensor([4., 7.]))
+
+    # min scalar, max scalar
+
+    # Min negative, max positive
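+    # n = 255; range [-2, 10] -> scale = 255 / 12 = 21.25; zp = round(-42.5) = -42 (round half to even)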
+    scale, zp = qu.asymmetric_linear_quantization_params(8, -2., 10., integral_zero_point=True, signed=False)
+    assert not isinstance(scale, torch.Tensor)
+    assert not isinstance(zp, torch.Tensor)
+    assert scale == 21.25
+    assert zp == -42
+
+    scale, zp = qu.asymmetric_linear_quantization_params(8, -2., 10., integral_zero_point=True, signed=True)
+    assert scale == 21.25
+    assert zp == 86
+
+    scale, zp = qu.asymmetric_linear_quantization_params(8, -2., 10., integral_zero_point=False, signed=False)
+    assert scale == 21.25
+    assert zp == -42.5
+
+    scale, zp = qu.asymmetric_linear_quantization_params(8, -2., 10., integral_zero_point=False, signed=True)
+    assert scale == 21.25
+    assert zp == 85.5
+
+    # Integer saturation values
+    scale, zp = qu.asymmetric_linear_quantization_params(8, -2, 10, integral_zero_point=False, signed=True)
+    assert scale == 21.25
+    assert zp == 85.5
+
+    # Both positive
+    scale, zp = qu.asymmetric_linear_quantization_params(8, 5., 10.)
+    assert scale == 25.5
+    assert zp == 0
+
+    # Both negative
+    scale, zp = qu.asymmetric_linear_quantization_params(8, -10., -5.)
+    assert scale == 25.5
+    assert zp == -255
+
+    # Both zero
+    scale, zp = qu.asymmetric_linear_quantization_params(8, 0., 0.)
+    assert scale == 1.
+    assert zp == 0
+
+    # min scalar, max tensor
+    scale, zp = qu.asymmetric_linear_quantization_params(8, -10., torch.tensor([-2., 5.]))
+    assert torch.equal(scale, torch.tensor([25.5, 17]))
+    assert torch.equal(zp, torch.tensor([-255., -170]))
+
+    scale, zp = qu.asymmetric_linear_quantization_params(8, 0., torch.tensor([0., 5.]))
+    assert torch.equal(scale, torch.tensor([1., 51.]))
+    assert torch.equal(zp, torch.tensor([0., 0.]))
+
+    # Integer saturation values
+    scale, zp = qu.asymmetric_linear_quantization_params(8, -10., torch.tensor([-2, 5]))
+    assert torch.equal(scale, torch.tensor([25.5, 17]))
+    assert torch.equal(zp, torch.tensor([-255., -170]))
+
+    # min tensor, max scalar
+    scale, zp = qu.asymmetric_linear_quantization_params(8, torch.tensor([-2., 5.]), 10.)
+    assert torch.equal(scale, torch.tensor([21.25, 25.5]))
+    assert torch.equal(zp, torch.tensor([-42., 0.]))
+
+    scale, zp = qu.asymmetric_linear_quantization_params(8, torch.tensor([0., -5.]), 0.)
+    assert torch.equal(scale, torch.tensor([1., 51.]))
+    assert torch.equal(zp, torch.tensor([0., -255.]))
+
+    # Integer saturation values
+    scale, zp = qu.asymmetric_linear_quantization_params(8, torch.tensor([-2, 5]), 10.)
+    assert torch.equal(scale, torch.tensor([21.25, 25.5]))
+    assert torch.equal(zp, torch.tensor([-42., 0.]))
+
+    # min tensor, max tensor
+    scale, zp = qu.asymmetric_linear_quantization_params(8,
+                                                         torch.tensor([-2., 5., -10., 0.]),
+                                                         torch.tensor([10., 10., -5., 0.]))
+    assert torch.equal(scale, torch.tensor([21.25, 25.5, 25.5, 1.]))
+    assert torch.equal(zp, torch.tensor([-42., 0., -255., 0.]))