diff --git a/tests/test_post_train_quant.py b/tests/test_post_train_quant.py index 122129a07abab59a1c4615c61bb67b8018c508e3..9d932aac5fc2d7fdffced3ef6c10f268ae69a123 100644 --- a/tests/test_post_train_quant.py +++ b/tests/test_post_train_quant.py @@ -600,9 +600,9 @@ class LinearBNSplitAct(nn.Module): super(LinearBNSplitAct, self).__init__() self.linear = nn.Linear(10, 40) self.bn = nn.BatchNorm1d(40) - acts_map = {'relu': nn.ReLU(), 'tanh': nn.Tanh(), 'sigmoid': nn.Sigmoid()} - self.act1 = acts_map[act1_type] - self.act2 = acts_map[act2_type] + acts_map = {'relu': nn.ReLU, 'tanh': nn.Tanh, 'sigmoid': nn.Sigmoid} + self.act1 = acts_map[act1_type]() + self.act2 = acts_map[act2_type]() def forward(self, x): x = self.linear(x)