From 6913687f61425e9102151d8a45caaee63ec40618 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Sat, 16 Nov 2019 17:44:26 +0200
Subject: [PATCH] Cifar models: remove explicit parameter initialization

Except for VGG, our parameter initialization code matched the default
PyTorch initialization (per torch.nn operation), so writing the
initialization code ourselves only adds code and maintenance burden,
and we would miss improvements made at the PyTorch level (e.g. if FB
finds a better initialization for nn.Conv2d than today's Kaiming init,
we would not pick it up).
The VGG initialization we had was "suspicious", so reverting to the
default seems reasonable.
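
For reference, a minimal sketch of why removing the loops is safe.
This assumes current PyTorch default behavior; the module sizes and
prints below are illustrative only, not part of the change:

    import torch
    import torch.nn as nn

    # nn.BatchNorm2d already initializes weight=1 and bias=0 in its
    # reset_parameters(), so the removed fill_(1)/zero_() loop was
    # redundant.
    bn = nn.BatchNorm2d(16)
    print(torch.all(bn.weight == 1).item(), torch.all(bn.bias == 0).item())

    # nn.Conv2d applies its own Kaiming-style init in reset_parameters(),
    # which is what the removed manual normal_(0, sqrt(2/n)) loop was
    # approximating.
    conv = nn.Conv2d(3, 16, kernel_size=3, bias=False)
    print(conv.weight.std().item())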
---
 distiller/models/cifar10/plain_cifar.py     |  8 --------
 distiller/models/cifar10/preresnet_cifar.py |  8 --------
 distiller/models/cifar10/resnet_cifar.py    |  9 ---------
 distiller/models/cifar10/vgg_cifar.py       | 17 +----------------
 4 files changed, 1 insertion(+), 41 deletions(-)

diff --git a/distiller/models/cifar10/plain_cifar.py b/distiller/models/cifar10/plain_cifar.py
index 7f668a5..0ef8851 100755
--- a/distiller/models/cifar10/plain_cifar.py
+++ b/distiller/models/cifar10/plain_cifar.py
@@ -83,14 +83,6 @@ class PlainCifar(nn.Module):
         self.avgpool = nn.AvgPool2d(8, stride=1)
         self.fc = nn.Linear(64 * block.expansion, num_classes)
 
-        for m in self.modules():
-            if isinstance(m, nn.Conv2d):
-                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
-                m.weight.data.normal_(0, math.sqrt(2. / n))
-            elif isinstance(m, nn.BatchNorm2d):
-                m.weight.data.fill_(1)
-                m.bias.data.zero_()
-
     def _make_layer(self, block, planes, num_blocks, stride, batch_norm=True):
         # Each layer is composed on 2*num_blocks blocks, and the first block usually
         # performs downsampling of the input, and doubling of the number of filters/feature-maps.
diff --git a/distiller/models/cifar10/preresnet_cifar.py b/distiller/models/cifar10/preresnet_cifar.py
index 9a8b2e9..4210647 100644
--- a/distiller/models/cifar10/preresnet_cifar.py
+++ b/distiller/models/cifar10/preresnet_cifar.py
@@ -116,14 +116,6 @@ class PreactResNetCifar(nn.Module):
         self.avgpool = nn.AvgPool2d(8, stride=1)
         self.fc = nn.Linear(64 * block.expansion, num_classes)
 
-        for m in self.modules():
-            if isinstance(m, nn.Conv2d):
-                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
-                m.weight.data.normal_(0, math.sqrt(2. / n))
-            elif isinstance(m, nn.BatchNorm2d):
-                m.weight.data.fill_(1)
-                m.bias.data.zero_()
-
     def _make_layer(self, layer_gates, block, planes, blocks, stride=1, conv_downsample=False):
         downsample = None
         outplanes = planes * block.expansion
diff --git a/distiller/models/cifar10/resnet_cifar.py b/distiller/models/cifar10/resnet_cifar.py
index ca31731..dc8432f 100755
--- a/distiller/models/cifar10/resnet_cifar.py
+++ b/distiller/models/cifar10/resnet_cifar.py
@@ -87,7 +87,6 @@ class BasicBlock(nn.Module):
 
 
 class ResNetCifar(nn.Module):
-
     def __init__(self, block, layers, num_classes=NUM_CLASSES):
         self.nlayers = 0
         # Each layer manages its own gates
@@ -109,14 +108,6 @@ class ResNetCifar(nn.Module):
         self.avgpool = nn.AvgPool2d(8, stride=1)
         self.fc = nn.Linear(64 * block.expansion, num_classes)
 
-        for m in self.modules():
-            if isinstance(m, nn.Conv2d):
-                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
-                m.weight.data.normal_(0, math.sqrt(2. / n))
-            elif isinstance(m, nn.BatchNorm2d):
-                m.weight.data.fill_(1)
-                m.bias.data.zero_()
-
     def _make_layer(self, layer_gates, block, planes, blocks, stride=1):
         downsample = None
         if stride != 1 or self.inplanes != planes * block.expansion:
diff --git a/distiller/models/cifar10/vgg_cifar.py b/distiller/models/cifar10/vgg_cifar.py
index 0b5a5bb..ec82379 100755
--- a/distiller/models/cifar10/vgg_cifar.py
+++ b/distiller/models/cifar10/vgg_cifar.py
@@ -32,12 +32,10 @@ __all__ = [
 
 
 class VGGCifar(nn.Module):
-    def __init__(self, features, num_classes=10, init_weights=True):
+    def __init__(self, features, num_classes=10):
         super(VGGCifar, self).__init__()
         self.features = features
         self.classifier = nn.Linear(512, num_classes)
-        if init_weights:
-            self._initialize_weights()
 
     def forward(self, x):
         x = self.features(x)
@@ -45,19 +43,6 @@ class VGGCifar(nn.Module):
         x = self.classifier(x)
         return x
 
-    def _initialize_weights(self):
-        for m in self.modules():
-            if isinstance(m, nn.Conv2d):
-                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
-                if m.bias is not None:
-                    nn.init.constant_(m.bias, 0)
-            elif isinstance(m, nn.BatchNorm2d):
-                nn.init.constant_(m.weight, 1)
-                nn.init.constant_(m.bias, 0)
-            elif isinstance(m, nn.Linear):
-                nn.init.normal_(m.weight, 0, 0.01)
-                nn.init.constant_(m.bias, 0)
-
 
 def make_layers(cfg, batch_norm=False):
     layers = []
-- 
GitLab