From f92656c7f4e3d32458e6509caeea9204e81d27de Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Thu, 28 Nov 2019 00:42:15 +0200
Subject: [PATCH] Revert "Cifar models: remove explicit parameters
 initialization"

Said commit was wrong: the default initializations in PyTorch are
not the same as in our code.  For example, PyTorch's default
convolution weight initialization uses Kaiming-uniform, while our
code used Kaiming-normal.
For backward compatibility of the model behavior, we need to
revert to the old behavior.
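
For reference, a minimal sketch of the explicit scheme being restored
(the helper name init_weights_legacy is illustrative and not part of
this patch): it reproduces the Kaiming-normal (fan-out) convolution
init and constant BatchNorm init re-added below, as opposed to the
Kaiming-uniform default used by nn.Conv2d:

    import math
    import torch.nn as nn

    def init_weights_legacy(model):
        # Illustrative helper, not part of this patch.
        # Kaiming-normal (fan-out) for conv weights; this mirrors the
        # per-model loops restored below and differs from the
        # Kaiming-uniform default of nn.Conv2d.reset_parameters().
        for m in model.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
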
This reverts commit 6913687f61425e9102151d8a45caaee63ec40618.
---
 distiller/models/cifar10/plain_cifar.py     |  8 ++++++++
 distiller/models/cifar10/preresnet_cifar.py |  8 ++++++++
 distiller/models/cifar10/resnet_cifar.py    |  9 +++++++++
 distiller/models/cifar10/vgg_cifar.py       | 17 ++++++++++++++++-
 4 files changed, 41 insertions(+), 1 deletion(-)

diff --git a/distiller/models/cifar10/plain_cifar.py b/distiller/models/cifar10/plain_cifar.py
index 0ef8851..7f668a5 100755
--- a/distiller/models/cifar10/plain_cifar.py
+++ b/distiller/models/cifar10/plain_cifar.py
@@ -83,6 +83,14 @@ class PlainCifar(nn.Module):
         self.avgpool = nn.AvgPool2d(8, stride=1)
         self.fc = nn.Linear(64 * block.expansion, num_classes)
 
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2. / n))
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+
     def _make_layer(self, block, planes, num_blocks, stride, batch_norm=True):
         # Each layer is composed on 2*num_blocks blocks, and the first block usually
         # performs downsampling of the input, and doubling of the number of filters/feature-maps.
diff --git a/distiller/models/cifar10/preresnet_cifar.py b/distiller/models/cifar10/preresnet_cifar.py
index 4210647..9a8b2e9 100644
--- a/distiller/models/cifar10/preresnet_cifar.py
+++ b/distiller/models/cifar10/preresnet_cifar.py
@@ -116,6 +116,14 @@ class PreactResNetCifar(nn.Module):
         self.avgpool = nn.AvgPool2d(8, stride=1)
         self.fc = nn.Linear(64 * block.expansion, num_classes)
 
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2. / n))
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+
     def _make_layer(self, layer_gates, block, planes, blocks, stride=1, conv_downsample=False):
         downsample = None
         outplanes = planes * block.expansion
diff --git a/distiller/models/cifar10/resnet_cifar.py b/distiller/models/cifar10/resnet_cifar.py
index dc8432f..ca31731 100755
--- a/distiller/models/cifar10/resnet_cifar.py
+++ b/distiller/models/cifar10/resnet_cifar.py
@@ -87,6 +87,7 @@ class BasicBlock(nn.Module):
 
 
 class ResNetCifar(nn.Module):
+
     def __init__(self, block, layers, num_classes=NUM_CLASSES):
         self.nlayers = 0
         # Each layer manages its own gates
@@ -108,6 +109,14 @@ class ResNetCifar(nn.Module):
         self.avgpool = nn.AvgPool2d(8, stride=1)
         self.fc = nn.Linear(64 * block.expansion, num_classes)
 
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2. / n))
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+
     def _make_layer(self, layer_gates, block, planes, blocks, stride=1):
         downsample = None
         if stride != 1 or self.inplanes != planes * block.expansion:
diff --git a/distiller/models/cifar10/vgg_cifar.py b/distiller/models/cifar10/vgg_cifar.py
index ec82379..0b5a5bb 100755
--- a/distiller/models/cifar10/vgg_cifar.py
+++ b/distiller/models/cifar10/vgg_cifar.py
@@ -32,10 +32,12 @@ __all__ = [
 
 
 class VGGCifar(nn.Module):
-    def __init__(self, features, num_classes=10):
+    def __init__(self, features, num_classes=10, init_weights=True):
         super(VGGCifar, self).__init__()
         self.features = features
         self.classifier = nn.Linear(512, num_classes)
+        if init_weights:
+            self._initialize_weights()
 
     def forward(self, x):
         x = self.features(x)
@@ -43,6 +45,19 @@ class VGGCifar(nn.Module):
         x = self.classifier(x)
         return x
 
+    def _initialize_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.Linear):
+                nn.init.normal_(m.weight, 0, 0.01)
+                nn.init.constant_(m.bias, 0)
+
 
 def make_layers(cfg, batch_norm=False):
     layers = []
-- 
GitLab