From 4ec96d9029931c359c943c8eb2a351f07dd72a81 Mon Sep 17 00:00:00 2001 From: Guy Jacob <guy.jacob@intel.com> Date: Sun, 21 Jul 2019 15:37:17 +0300 Subject: [PATCH] Fix bug in test that resulted in duplicate modules in a model --- tests/test_post_train_quant.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_post_train_quant.py b/tests/test_post_train_quant.py index 122129a..9d932aa 100644 --- a/tests/test_post_train_quant.py +++ b/tests/test_post_train_quant.py @@ -600,9 +600,9 @@ class LinearBNSplitAct(nn.Module): super(LinearBNSplitAct, self).__init__() self.linear = nn.Linear(10, 40) self.bn = nn.BatchNorm1d(40) - acts_map = {'relu': nn.ReLU(), 'tanh': nn.Tanh(), 'sigmoid': nn.Sigmoid()} - self.act1 = acts_map[act1_type] - self.act2 = acts_map[act2_type] + acts_map = {'relu': nn.ReLU, 'tanh': nn.Tanh, 'sigmoid': nn.Sigmoid} + self.act1 = acts_map[act1_type]() + self.act2 = acts_map[act2_type]() def forward(self, x): x = self.linear(x) -- GitLab