Skip to content
Snippets Groups Projects
Commit 4ec96d90 authored by Guy Jacob's avatar Guy Jacob
Browse files

Fix bug in test that resulted in duplicate modules in a model

parent 8afac56f
No related branches found
No related tags found
No related merge requests found
......@@ -600,9 +600,9 @@ class LinearBNSplitAct(nn.Module):
super(LinearBNSplitAct, self).__init__()
self.linear = nn.Linear(10, 40)
self.bn = nn.BatchNorm1d(40)
acts_map = {'relu': nn.ReLU(), 'tanh': nn.Tanh(), 'sigmoid': nn.Sigmoid()}
self.act1 = acts_map[act1_type]
self.act2 = acts_map[act2_type]
acts_map = {'relu': nn.ReLU, 'tanh': nn.Tanh, 'sigmoid': nn.Sigmoid}
self.act1 = acts_map[act1_type]()
self.act2 = acts_map[act2_type]()
def forward(self, x):
x = self.linear(x)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment