Commit 6913687f authored by Neta Zmora

Cifar models: remove explicit parameters initialization

Except for the case of VGG, our parameter initialization code matched
the default PyTorch initialization (per torch.nn operation), so writing
the initialization code ourselves only adds code and maintenance burden;
we also would not benefit from improvements made at the PyTorch level
(e.g. if FB finds a better initialization for nn.Conv2d than today's
Kaiming init, we would not pick it up).
The VGG initialization we had was "suspicious", so reverting to the
default seems reasonable.
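
A minimal sketch (an editor's illustration, not part of this commit) of why the hand-written initialization in the ResNet-style models was redundant: drawing weights from normal_(0, sqrt(2/n)) with n = k*k*out_channels is exactly Kaiming-normal initialization in fan_out mode, which torch.nn.init already provides.

import math
import torch.nn as nn

conv = nn.Conv2d(16, 32, kernel_size=3)

# The removed, hand-written initialization:
n = conv.kernel_size[0] * conv.kernel_size[1] * conv.out_channels
conv.weight.data.normal_(0, math.sqrt(2. / n))
print(conv.weight.std().item())   # ~ sqrt(2 / (3*3*32)) ~= 0.083

# The equivalent library call:
nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')
print(conv.weight.std().item())   # same scale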
parent fbdbe35a
@@ -83,14 +83,6 @@ class PlainCifar(nn.Module):
         self.avgpool = nn.AvgPool2d(8, stride=1)
         self.fc = nn.Linear(64 * block.expansion, num_classes)
 
-        for m in self.modules():
-            if isinstance(m, nn.Conv2d):
-                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
-                m.weight.data.normal_(0, math.sqrt(2. / n))
-            elif isinstance(m, nn.BatchNorm2d):
-                m.weight.data.fill_(1)
-                m.bias.data.zero_()
-
     def _make_layer(self, block, planes, num_blocks, stride, batch_norm=True):
         # Each layer is composed on 2*num_blocks blocks, and the first block usually
         # performs downsampling of the input, and doubling of the number of filters/feature-maps.
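
If an experiment still needs the old behaviour, the same logic can be applied from outside the model instead of being repeated in every class. A hedged sketch using a hypothetical helper (not part of the repository):

import math
import torch.nn as nn

def kaiming_fan_out_init(m):
    # Hypothetical helper reproducing the removed per-module initialization.
    if isinstance(m, nn.Conv2d):
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2. / n))
    elif isinstance(m, nn.BatchNorm2d):
        m.weight.data.fill_(1)
        m.bias.data.zero_()

# Usage (assuming `model` is an already-constructed network):
# model.apply(kaiming_fan_out_init)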
@@ -116,14 +116,6 @@ class PreactResNetCifar(nn.Module):
         self.avgpool = nn.AvgPool2d(8, stride=1)
         self.fc = nn.Linear(64 * block.expansion, num_classes)
 
-        for m in self.modules():
-            if isinstance(m, nn.Conv2d):
-                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
-                m.weight.data.normal_(0, math.sqrt(2. / n))
-            elif isinstance(m, nn.BatchNorm2d):
-                m.weight.data.fill_(1)
-                m.bias.data.zero_()
-
     def _make_layer(self, layer_gates, block, planes, blocks, stride=1, conv_downsample=False):
         downsample = None
         outplanes = planes * block.expansion
@@ -87,7 +87,6 @@ class BasicBlock(nn.Module):
 class ResNetCifar(nn.Module):
     def __init__(self, block, layers, num_classes=NUM_CLASSES):
         self.nlayers = 0
         # Each layer manages its own gates
@@ -109,14 +108,6 @@ class ResNetCifar(nn.Module):
         self.avgpool = nn.AvgPool2d(8, stride=1)
         self.fc = nn.Linear(64 * block.expansion, num_classes)
 
-        for m in self.modules():
-            if isinstance(m, nn.Conv2d):
-                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
-                m.weight.data.normal_(0, math.sqrt(2. / n))
-            elif isinstance(m, nn.BatchNorm2d):
-                m.weight.data.fill_(1)
-                m.bias.data.zero_()
-
     def _make_layer(self, layer_gates, block, planes, blocks, stride=1):
         downsample = None
         if stride != 1 or self.inplanes != planes * block.expansion:
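
With the explicit loops removed, each submodule keeps the initialization performed in its own constructor (every torch.nn layer calls its reset_parameters() method when built). A short sketch, under that assumption, of how the default initialization can be re-run explicitly if the weights have since been modified:

import torch.nn as nn

def reset_all_parameters(model: nn.Module):
    # Re-invoke each layer's own default initialization (Conv2d, Linear,
    # BatchNorm2d, etc. all implement reset_parameters()).
    for m in model.modules():
        if hasattr(m, 'reset_parameters'):
            m.reset_parameters()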
@@ -32,12 +32,10 @@ __all__ = [
 class VGGCifar(nn.Module):
-    def __init__(self, features, num_classes=10, init_weights=True):
+    def __init__(self, features, num_classes=10):
         super(VGGCifar, self).__init__()
         self.features = features
         self.classifier = nn.Linear(512, num_classes)
-        if init_weights:
-            self._initialize_weights()
 
     def forward(self, x):
         x = self.features(x)
@@ -45,19 +43,6 @@ class VGGCifar(nn.Module):
         x = self.classifier(x)
         return x
 
-    def _initialize_weights(self):
-        for m in self.modules():
-            if isinstance(m, nn.Conv2d):
-                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
-                if m.bias is not None:
-                    nn.init.constant_(m.bias, 0)
-            elif isinstance(m, nn.BatchNorm2d):
-                nn.init.constant_(m.weight, 1)
-                nn.init.constant_(m.bias, 0)
-            elif isinstance(m, nn.Linear):
-                nn.init.normal_(m.weight, 0, 0.01)
-                nn.init.constant_(m.bias, 0)
-
 def make_layers(cfg, batch_norm=False):
     layers = []
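
For the VGG classifier, reverting to the default does change the numbers: the removed code drew classifier weights from N(0, 0.01), while a freshly constructed nn.Linear initializes itself from a uniform distribution whose bound scales with 1/sqrt(fan_in) (roughly 0.044 for fan_in = 512 in recent PyTorch releases). A small sketch contrasting the two:

import torch.nn as nn

fc = nn.Linear(512, 10)

# Default initialization, which VGGCifar now relies on:
print(fc.weight.abs().max().item())   # bounded by roughly 1/sqrt(512) ~= 0.044

# The removed VGG-specific initialization instead used:
nn.init.normal_(fc.weight, 0, 0.01)
nn.init.constant_(fc.bias, 0)
print(fc.weight.std().item())          # ~= 0.01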