diff --git a/tests/test_post_train_quant.py b/tests/test_post_train_quant.py
index 6bae2a78c7b24bada3996a7d318b028af29ca4b2..12da4ec464241922465f360ef4161943cfecd5fd 100644
--- a/tests/test_post_train_quant.py
+++ b/tests/test_post_train_quant.py
@@ -28,14 +28,14 @@ from distiller.quantization import RangeLinearQuantParamLayerWrapper, LinearQuan
 
 @pytest.fixture()
 def conv_input():
-    return torch.cat((torch.tensor([[[[-7, 5], [2, -3]]]], dtype=torch.float64),
-                      torch.tensor([[[[-15, 10], [-1, 5]]]], dtype=torch.float64)), 0)
+    return torch.cat((torch.tensor([[[[-7, 5], [2, -3]]]], dtype=torch.float32),
+                      torch.tensor([[[[-15, 10], [-1, 5]]]], dtype=torch.float32)), 0)
 
 
 @pytest.fixture()
 def conv_weights():
     return torch.tensor([[[[-1, -0.5, 0], [0.5, 1, 1.5], [2, 2.5, 3]]],
-                         [[[-0.3, -0.25, -0.2], [-0.15, -0.1, -0.05], [0, 0.05, 0.1]]]], dtype=torch.float64)
+                         [[[-0.3, -0.25, -0.2], [-0.15, -0.1, -0.05], [0, 0.05, 0.1]]]], dtype=torch.float32)
 
 
 @pytest.mark.parametrize(
@@ -43,30 +43,30 @@ def conv_weights():
     [
         (LinearQuantMode.ASYMMETRIC_UNSIGNED, False, False,
          torch.cat((torch.tensor([[[[-3.648135333, -2.14596196], [0.858384784, 2.432090222]],
-                                   [[0.214596196, 0.500724457], [0.715320653, 0.786852719]]]], dtype=torch.float64),
+                                   [[0.214596196, 0.500724457], [0.715320653, 0.786852719]]]], dtype=torch.float32),
                     torch.tensor([[[[12.51811144, 13.01883589], [14.0918168, 14.59254133]],
-                                   [[1.359109242, 1.645237503], [1.573705438, 1.645237503]]]], dtype=torch.float64)),
+                                   [[1.359109242, 1.645237503], [1.573705438, 1.645237503]]]], dtype=torch.float32)),
                    dim=0)
          ),
         (LinearQuantMode.ASYMMETRIC_UNSIGNED, True, False,
          torch.cat((torch.tensor([[[[-1.089218234, -1.089218234], [1.055180164, 2.518817167]],
-                                   [[0.238266489, 0.476532978], [0.680761396, 0.782875606]]]], dtype=torch.float64),
+                                   [[0.238266489, 0.476532978], [0.680761396, 0.782875606]]]], dtype=torch.float32),
                     torch.tensor([[[[7.59048957, 7.59048957], [7.59048957, 7.59048957]],
-                                   [[1.123256304, 1.259408583], [1.089218234, 1.089218234]]]], dtype=torch.float64)),
+                                   [[1.123256304, 1.259408583], [1.089218234, 1.089218234]]]], dtype=torch.float32)),
                    dim=0)
          ),
         (LinearQuantMode.ASYMMETRIC_UNSIGNED, False, True,
          torch.cat((torch.tensor([[[[-3.648135333, -2.14596196], [0.858384784, 2.432090222]],
-                                   [[0.214596196, 0.429192392], [0.715320653, 0.858384784]]]], dtype=torch.float64),
+                                   [[0.214596196, 0.429192392], [0.715320653, 0.858384784]]]], dtype=torch.float32),
                     torch.tensor([[[[12.51811144, 13.01883589], [14.09181687, 14.59254133]],
-                                   [[1.430641307, 1.502173372], [1.573705438, 1.645237503]]]], dtype=torch.float64)),
+                                   [[1.430641307, 1.502173372], [1.573705438, 1.645237503]]]], dtype=torch.float32)),
                    dim=0)
          ),
         (LinearQuantMode.ASYMMETRIC_UNSIGNED, True, True,
          torch.cat((torch.tensor([[[[-1.089768056, -1.089768056], [1.055712804, 2.52008863]],
-                                   [[0.238386762, 0.408663021], [0.681105035, 0.817326042]]]], dtype=torch.float64),
+                                   [[0.238386762, 0.408663021], [0.681105035, 0.817326042]]]], dtype=torch.float32),
                     torch.tensor([[[[7.59432114, 7.59432114], [7.59432114, 7.59432114]],
-                                   [[1.191933811, 1.15787856], [1.123823308, 1.089768056]]]], dtype=torch.float64)),
+                                   [[1.191933811, 1.15787856], [1.123823308, 1.089768056]]]], dtype=torch.float32)),
                    dim=0)
          )
     ]
@@ -91,28 +91,28 @@ def test_conv_layer_wrapper(conv_input, conv_weights, mode, clip_acts, per_chann
 
 @pytest.fixture()
 def linear_input():
-    return torch.tensor([[-7, 5, 2, -3]], dtype=torch.float64)
+    return torch.tensor([[-7, 5, 2, -3]], dtype=torch.float32)
 
 
 @pytest.fixture()
 def linear_weights():
     return torch.tensor([[-1, 0.5, 0, 0.5],
                          [-0.05, 0, 0.05, 0.1],
-                         [0.3, 0.6, -0.1, -0.2]], dtype=torch.float64)
+                         [0.3, 0.6, -0.1, -0.2]], dtype=torch.float32)
 
 
 @pytest.fixture()
 def linear_bias():
-    return torch.tensor([-0.3, 0.1, -0.5], dtype=torch.float64)
+    return torch.tensor([-0.3, 0.1, -0.5], dtype=torch.float32)
 
 
 @pytest.mark.parametrize(
     "mode, clip_acts, per_channel_wts, expected_output",
     [
         (LinearQuantMode.ASYMMETRIC_UNSIGNED, False, False,
-         torch.tensor([[7.698556917, 0.262450804, 0.787352412]], dtype=torch.float64)),
+         torch.tensor([[7.686200692, 0.241135708, 0.783691051]], dtype=torch.float32)),
         (LinearQuantMode.ASYMMETRIC_UNSIGNED, False, True,
-         torch.tensor([[7.71233218, 0.262920415, 0.788761246]], dtype=torch.float64))
+         torch.tensor([[7.698823529, 0.241531719, 0.784978085]], dtype=torch.float32))
     ]
 )
 def test_linear_layer_wrapper(linear_input, linear_weights, linear_bias,