Commit 0dac3e07 authored by Guy Jacob

Fix dtypes in test_post_train_quant.py

parent 994e58d0
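The diff below switches every test tensor from torch.float64 to torch.float32. PyTorch constructs module parameters in float32 by default, so float64 fixture tensors trip a dtype mismatch when pushed through a wrapped layer. A minimal generic sketch of the failure mode, using plain torch.nn rather than the Distiller wrapper under test:

import torch

# PyTorch layers hold float32 ("float") parameters by default, so a
# float64 ("double") input raises a dtype-mismatch RuntimeError.
conv = torch.nn.Conv2d(in_channels=1, out_channels=2, kernel_size=2)

x64 = torch.randn(2, 1, 2, 2, dtype=torch.float64)
try:
    conv(x64)
except RuntimeError as err:
    print("float64 input rejected:", err)  # exact message varies by PyTorch version

# Casting the fixture data to float32 matches the layer's parameter dtype.
out = conv(x64.to(torch.float32))
print(out.shape)  # torch.Size([2, 2, 1, 1])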
@@ -28,14 +28,14 @@ from distiller.quantization import RangeLinearQuantParamLayerWrapper, LinearQuan
 
 @pytest.fixture()
 def conv_input():
-    return torch.cat((torch.tensor([[[[-7, 5], [2, -3]]]], dtype=torch.float64),
-                      torch.tensor([[[[-15, 10], [-1, 5]]]], dtype=torch.float64)), 0)
+    return torch.cat((torch.tensor([[[[-7, 5], [2, -3]]]], dtype=torch.float32),
+                      torch.tensor([[[[-15, 10], [-1, 5]]]], dtype=torch.float32)), 0)
 
 
 @pytest.fixture()
 def conv_weights():
     return torch.tensor([[[[-1, -0.5, 0], [0.5, 1, 1.5], [2, 2.5, 3]]],
-                         [[[-0.3, -0.25, -0.2], [-0.15, -0.1, -0.05], [0, 0.05, 0.1]]]], dtype=torch.float64)
+                         [[[-0.3, -0.25, -0.2], [-0.15, -0.1, -0.05], [0, 0.05, 0.1]]]], dtype=torch.float32)
 
 
 @pytest.mark.parametrize(
@@ -43,30 +43,30 @@ def conv_weights():
     [
         (LinearQuantMode.ASYMMETRIC_UNSIGNED, False, False,
          torch.cat((torch.tensor([[[[-3.648135333, -2.14596196], [0.858384784, 2.432090222]],
-                                    [[0.214596196, 0.500724457], [0.715320653, 0.786852719]]]], dtype=torch.float64),
+                                    [[0.214596196, 0.500724457], [0.715320653, 0.786852719]]]], dtype=torch.float32),
                     torch.tensor([[[[12.51811144, 13.01883589], [14.0918168, 14.59254133]],
-                                    [[1.359109242, 1.645237503], [1.573705438, 1.645237503]]]], dtype=torch.float64)),
+                                    [[1.359109242, 1.645237503], [1.573705438, 1.645237503]]]], dtype=torch.float32)),
                    dim=0)
          ),
         (LinearQuantMode.ASYMMETRIC_UNSIGNED, True, False,
          torch.cat((torch.tensor([[[[-1.089218234, -1.089218234], [1.055180164, 2.518817167]],
-                                    [[0.238266489, 0.476532978], [0.680761396, 0.782875606]]]], dtype=torch.float64),
+                                    [[0.238266489, 0.476532978], [0.680761396, 0.782875606]]]], dtype=torch.float32),
                     torch.tensor([[[[7.59048957, 7.59048957], [7.59048957, 7.59048957]],
-                                    [[1.123256304, 1.259408583], [1.089218234, 1.089218234]]]], dtype=torch.float64)),
+                                    [[1.123256304, 1.259408583], [1.089218234, 1.089218234]]]], dtype=torch.float32)),
                    dim=0)
          ),
         (LinearQuantMode.ASYMMETRIC_UNSIGNED, False, True,
          torch.cat((torch.tensor([[[[-3.648135333, -2.14596196], [0.858384784, 2.432090222]],
-                                    [[0.214596196, 0.429192392], [0.715320653, 0.858384784]]]], dtype=torch.float64),
+                                    [[0.214596196, 0.429192392], [0.715320653, 0.858384784]]]], dtype=torch.float32),
                     torch.tensor([[[[12.51811144, 13.01883589], [14.09181687, 14.59254133]],
-                                    [[1.430641307, 1.502173372], [1.573705438, 1.645237503]]]], dtype=torch.float64)),
+                                    [[1.430641307, 1.502173372], [1.573705438, 1.645237503]]]], dtype=torch.float32)),
                    dim=0)
          ),
         (LinearQuantMode.ASYMMETRIC_UNSIGNED, True, True,
          torch.cat((torch.tensor([[[[-1.089768056, -1.089768056], [1.055712804, 2.52008863]],
-                                    [[0.238386762, 0.408663021], [0.681105035, 0.817326042]]]], dtype=torch.float64),
+                                    [[0.238386762, 0.408663021], [0.681105035, 0.817326042]]]], dtype=torch.float32),
                     torch.tensor([[[[7.59432114, 7.59432114], [7.59432114, 7.59432114]],
-                                    [[1.191933811, 1.15787856], [1.123823308, 1.089768056]]]], dtype=torch.float64)),
+                                    [[1.191933811, 1.15787856], [1.123823308, 1.089768056]]]], dtype=torch.float32)),
                    dim=0)
          )
     ]
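All four conv cases above use LinearQuantMode.ASYMMETRIC_UNSIGNED. As a reference for reading the expected tensors, here is a generic sketch of asymmetric-unsigned range quantization (an illustration of the general technique, not Distiller's exact implementation): the tensor's [min, max] range is mapped onto the unsigned integer grid [0, 2^n - 1], rounded, then mapped back.

import torch

def quant_dequant_asym_unsigned(t, num_bits=8):
    # Hypothetical helper for illustration: map [t.min(), t.max()] onto
    # [0, 2**num_bits - 1], round to the integer grid, then dequantize.
    qmax = 2 ** num_bits - 1
    scale = qmax / (t.max() - t.min())
    zero_point = torch.round(t.min() * scale)
    q = torch.clamp(torch.round(t * scale) - zero_point, 0, qmax)
    return (q + zero_point) / scale

x = torch.tensor([[-7., 5., 2., -3.]])
print(quant_dequant_asym_unsigned(x))  # values land near the originals, on a 255-step grid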
@@ -91,28 +91,28 @@ def test_conv_layer_wrapper(conv_input, conv_weights, mode, clip_acts, per_chann
 
 @pytest.fixture()
 def linear_input():
-    return torch.tensor([[-7, 5, 2, -3]], dtype=torch.float64)
+    return torch.tensor([[-7, 5, 2, -3]], dtype=torch.float32)
 
 
 @pytest.fixture()
 def linear_weights():
     return torch.tensor([[-1, 0.5, 0, 0.5],
                          [-0.05, 0, 0.05, 0.1],
-                         [0.3, 0.6, -0.1, -0.2]], dtype=torch.float64)
+                         [0.3, 0.6, -0.1, -0.2]], dtype=torch.float32)
 
 
 @pytest.fixture()
 def linear_bias():
-    return torch.tensor([-0.3, 0.1, -0.5], dtype=torch.float64)
+    return torch.tensor([-0.3, 0.1, -0.5], dtype=torch.float32)
 
 
 @pytest.mark.parametrize(
     "mode, clip_acts, per_channel_wts, expected_output",
     [
         (LinearQuantMode.ASYMMETRIC_UNSIGNED, False, False,
-         torch.tensor([[7.698556917, 0.262450804, 0.787352412]], dtype=torch.float64)),
+         torch.tensor([[7.686200692, 0.241135708, 0.783691051]], dtype=torch.float32)),
         (LinearQuantMode.ASYMMETRIC_UNSIGNED, False, True,
-         torch.tensor([[7.71233218, 0.262920415, 0.788761246]], dtype=torch.float64))
+         torch.tensor([[7.698823529, 0.241531719, 0.784978085]], dtype=torch.float32))
     ]
 )
 def test_linear_layer_wrapper(linear_input, linear_weights, linear_bias,
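Note that in the linear hunk the expected values themselves change along with the dtype (e.g. 7.698556917 becomes 7.686200692), presumably because the reference outputs were re-derived with float32 arithmetic. For orientation, the unquantized float32 output of this layer on the fixture data is [7.7, 0.25, 0.8], which the quantized expectations approximate. A small sketch of that reference using plain torch.nn (the real test runs the input through the quantization wrapper instead):

import torch

# Float32 reference (no quantization) for the linear fixtures above.
fc = torch.nn.Linear(4, 3)
with torch.no_grad():
    fc.weight.copy_(torch.tensor([[-1, 0.5, 0, 0.5],
                                  [-0.05, 0, 0.05, 0.1],
                                  [0.3, 0.6, -0.1, -0.2]]))
    fc.bias.copy_(torch.tensor([-0.3, 0.1, -0.5]))
    out = fc(torch.tensor([[-7., 5., 2., -3.]]))

print(out)  # tensor([[7.7000, 0.2500, 0.8000]]), up to float32 rounding
print(torch.allclose(out, torch.tensor([[7.7, 0.25, 0.8]]), atol=1e-6))  # True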