diff --git a/tests/test_post_train_quant.py b/tests/test_post_train_quant.py index 6bae2a78c7b24bada3996a7d318b028af29ca4b2..12da4ec464241922465f360ef4161943cfecd5fd 100644 --- a/tests/test_post_train_quant.py +++ b/tests/test_post_train_quant.py @@ -28,14 +28,14 @@ from distiller.quantization import RangeLinearQuantParamLayerWrapper, LinearQuan @pytest.fixture() def conv_input(): - return torch.cat((torch.tensor([[[[-7, 5], [2, -3]]]], dtype=torch.float64), - torch.tensor([[[[-15, 10], [-1, 5]]]], dtype=torch.float64)), 0) + return torch.cat((torch.tensor([[[[-7, 5], [2, -3]]]], dtype=torch.float32), + torch.tensor([[[[-15, 10], [-1, 5]]]], dtype=torch.float32)), 0) @pytest.fixture() def conv_weights(): return torch.tensor([[[[-1, -0.5, 0], [0.5, 1, 1.5], [2, 2.5, 3]]], - [[[-0.3, -0.25, -0.2], [-0.15, -0.1, -0.05], [0, 0.05, 0.1]]]], dtype=torch.float64) + [[[-0.3, -0.25, -0.2], [-0.15, -0.1, -0.05], [0, 0.05, 0.1]]]], dtype=torch.float32) @pytest.mark.parametrize( @@ -43,30 +43,30 @@ def conv_weights(): [ (LinearQuantMode.ASYMMETRIC_UNSIGNED, False, False, torch.cat((torch.tensor([[[[-3.648135333, -2.14596196], [0.858384784, 2.432090222]], - [[0.214596196, 0.500724457], [0.715320653, 0.786852719]]]], dtype=torch.float64), + [[0.214596196, 0.500724457], [0.715320653, 0.786852719]]]], dtype=torch.float32), torch.tensor([[[[12.51811144, 13.01883589], [14.0918168, 14.59254133]], - [[1.359109242, 1.645237503], [1.573705438, 1.645237503]]]], dtype=torch.float64)), + [[1.359109242, 1.645237503], [1.573705438, 1.645237503]]]], dtype=torch.float32)), dim=0) ), (LinearQuantMode.ASYMMETRIC_UNSIGNED, True, False, torch.cat((torch.tensor([[[[-1.089218234, -1.089218234], [1.055180164, 2.518817167]], - [[0.238266489, 0.476532978], [0.680761396, 0.782875606]]]], dtype=torch.float64), + [[0.238266489, 0.476532978], [0.680761396, 0.782875606]]]], dtype=torch.float32), torch.tensor([[[[7.59048957, 7.59048957], [7.59048957, 7.59048957]], - [[1.123256304, 1.259408583], [1.089218234, 1.089218234]]]], dtype=torch.float64)), + [[1.123256304, 1.259408583], [1.089218234, 1.089218234]]]], dtype=torch.float32)), dim=0) ), (LinearQuantMode.ASYMMETRIC_UNSIGNED, False, True, torch.cat((torch.tensor([[[[-3.648135333, -2.14596196], [0.858384784, 2.432090222]], - [[0.214596196, 0.429192392], [0.715320653, 0.858384784]]]], dtype=torch.float64), + [[0.214596196, 0.429192392], [0.715320653, 0.858384784]]]], dtype=torch.float32), torch.tensor([[[[12.51811144, 13.01883589], [14.09181687, 14.59254133]], - [[1.430641307, 1.502173372], [1.573705438, 1.645237503]]]], dtype=torch.float64)), + [[1.430641307, 1.502173372], [1.573705438, 1.645237503]]]], dtype=torch.float32)), dim=0) ), (LinearQuantMode.ASYMMETRIC_UNSIGNED, True, True, torch.cat((torch.tensor([[[[-1.089768056, -1.089768056], [1.055712804, 2.52008863]], - [[0.238386762, 0.408663021], [0.681105035, 0.817326042]]]], dtype=torch.float64), + [[0.238386762, 0.408663021], [0.681105035, 0.817326042]]]], dtype=torch.float32), torch.tensor([[[[7.59432114, 7.59432114], [7.59432114, 7.59432114]], - [[1.191933811, 1.15787856], [1.123823308, 1.089768056]]]], dtype=torch.float64)), + [[1.191933811, 1.15787856], [1.123823308, 1.089768056]]]], dtype=torch.float32)), dim=0) ) ] @@ -91,28 +91,28 @@ def test_conv_layer_wrapper(conv_input, conv_weights, mode, clip_acts, per_chann @pytest.fixture() def linear_input(): - return torch.tensor([[-7, 5, 2, -3]], dtype=torch.float64) + return torch.tensor([[-7, 5, 2, -3]], dtype=torch.float32) @pytest.fixture() def linear_weights(): return torch.tensor([[-1, 0.5, 0, 0.5], [-0.05, 0, 0.05, 0.1], - [0.3, 0.6, -0.1, -0.2]], dtype=torch.float64) + [0.3, 0.6, -0.1, -0.2]], dtype=torch.float32) @pytest.fixture() def linear_bias(): - return torch.tensor([-0.3, 0.1, -0.5], dtype=torch.float64) + return torch.tensor([-0.3, 0.1, -0.5], dtype=torch.float32) @pytest.mark.parametrize( "mode, clip_acts, per_channel_wts, expected_output", [ (LinearQuantMode.ASYMMETRIC_UNSIGNED, False, False, - torch.tensor([[7.698556917, 0.262450804, 0.787352412]], dtype=torch.float64)), + torch.tensor([[7.686200692, 0.241135708, 0.783691051]], dtype=torch.float32)), (LinearQuantMode.ASYMMETRIC_UNSIGNED, False, True, - torch.tensor([[7.71233218, 0.262920415, 0.788761246]], dtype=torch.float64)) + torch.tensor([[7.698823529, 0.241531719, 0.784978085]], dtype=torch.float32)) ] ) def test_linear_layer_wrapper(linear_input, linear_weights, linear_bias,