Commit f92656c7 authored by Neta Zmora

Revert "Cifar models: remove explicit parameters initialization"

Said commit was wrong: the default initializations in PyTorch are not
the same as in our code.  For example, the default convolution
weight initialization uses Kaiming-uniform, while we used
Kaiming-normal.
For backward compatibility of the model behavior, we need to
revert to the old behavior.
This reverts commit 6913687f.
parent 8b341593
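For context, here is a minimal sketch (not part of this commit) contrasting PyTorch's default Conv2d initialization, which is Kaiming-uniform, with the explicit Kaiming-normal fan-out initialization that this revert restores. The layer shapes are illustrative, and the nn.init.kaiming_normal_ form shown in a comment is an assumed equivalent of the manual formula:

import math
import torch.nn as nn

# PyTorch's default Conv2d initialization (reset_parameters) uses Kaiming-uniform.
default_conv = nn.Conv2d(3, 16, kernel_size=3)

# The explicit initialization restored by this commit: Kaiming-normal over the
# fan-out, n = kernel_size[0] * kernel_size[1] * out_channels.
explicit_conv = nn.Conv2d(3, 16, kernel_size=3)
n = explicit_conv.kernel_size[0] * explicit_conv.kernel_size[1] * explicit_conv.out_channels
explicit_conv.weight.data.normal_(0, math.sqrt(2. / n))

# Assumed-equivalent modern form of the same initialization:
# nn.init.kaiming_normal_(explicit_conv.weight, mode='fan_out', nonlinearity='relu')

# Compare the resulting weight statistics (the distribution and the fan used differ).
print(default_conv.weight.std().item(), explicit_conv.weight.std().item())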
@@ -83,6 +83,14 @@ class PlainCifar(nn.Module):
         self.avgpool = nn.AvgPool2d(8, stride=1)
         self.fc = nn.Linear(64 * block.expansion, num_classes)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2. / n))
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
 
     def _make_layer(self, block, planes, num_blocks, stride, batch_norm=True):
         # Each layer is composed of 2*num_blocks blocks, and the first block usually
         # performs downsampling of the input, and doubling of the number of filters/feature-maps.
......
@@ -116,6 +116,14 @@ class PreactResNetCifar(nn.Module):
         self.avgpool = nn.AvgPool2d(8, stride=1)
         self.fc = nn.Linear(64 * block.expansion, num_classes)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2. / n))
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
 
     def _make_layer(self, layer_gates, block, planes, blocks, stride=1, conv_downsample=False):
         downsample = None
         outplanes = planes * block.expansion
......
@@ -87,6 +87,7 @@ class BasicBlock(nn.Module):
 class ResNetCifar(nn.Module):
     def __init__(self, block, layers, num_classes=NUM_CLASSES):
         self.nlayers = 0
         # Each layer manages its own gates
@@ -108,6 +109,14 @@ class ResNetCifar(nn.Module):
         self.avgpool = nn.AvgPool2d(8, stride=1)
         self.fc = nn.Linear(64 * block.expansion, num_classes)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2. / n))
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
 
     def _make_layer(self, layer_gates, block, planes, blocks, stride=1):
         downsample = None
         if stride != 1 or self.inplanes != planes * block.expansion:
......
@@ -32,10 +32,12 @@ __all__ = [
 class VGGCifar(nn.Module):
-    def __init__(self, features, num_classes=10):
+    def __init__(self, features, num_classes=10, init_weights=True):
         super(VGGCifar, self).__init__()
         self.features = features
         self.classifier = nn.Linear(512, num_classes)
+        if init_weights:
+            self._initialize_weights()
 
     def forward(self, x):
         x = self.features(x)
@@ -43,6 +45,19 @@ class VGGCifar(nn.Module):
         x = self.classifier(x)
         return x
+
+    def _initialize_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.Linear):
+                nn.init.normal_(m.weight, 0, 0.01)
+                nn.init.constant_(m.bias, 0)
 
 
 def make_layers(cfg, batch_norm=False):
     layers = []
......
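A hypothetical usage sketch of the updated VGGCifar constructor; the cfg list passed to make_layers is assumed for illustration and is not part of this diff:

# Assumed VGG16-style configuration ('M' = max-pool); any cfg accepted by make_layers works.
cfg = [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']
model = VGGCifar(make_layers(cfg, batch_norm=True))                         # init_weights=True: runs _initialize_weights()
scratch = VGGCifar(make_layers(cfg, batch_norm=True), init_weights=False)   # skip explicit init, e.g. before loading a checkpoint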