From 8b341593e8fe71919de149cdd00e269061cecaed Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Mon, 25 Nov 2019 16:58:15 +0200
Subject: [PATCH] Resnet50 early-exit update

Update the definition of the exits using info from Haim.

This is still very unsatsifactory because we don't have working
examples to show users :-(
---
 distiller/models/imagenet/resnet.py           |  3 ++-
 distiller/models/imagenet/resnet_earlyexit.py | 15 ++++++---------
 2 files changed, 8 insertions(+), 10 deletions(-)

diff --git a/distiller/models/imagenet/resnet.py b/distiller/models/imagenet/resnet.py
index 6f0fd7a..c578434 100755
--- a/distiller/models/imagenet/resnet.py
+++ b/distiller/models/imagenet/resnet.py
@@ -31,7 +31,8 @@ from distiller.modules import EltwiseAdd
 
 
 __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
-           'resnet152', 'resnext50_32x4d', 'resnext101_32x8d']
+           'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
+           'DistillerBottleneck']
 
 
 class DistillerBasicBlock(BasicBlock):
diff --git a/distiller/models/imagenet/resnet_earlyexit.py b/distiller/models/imagenet/resnet_earlyexit.py
index 03fd6c9..def335c 100644
--- a/distiller/models/imagenet/resnet_earlyexit.py
+++ b/distiller/models/imagenet/resnet_earlyexit.py
@@ -1,9 +1,9 @@
 import torch.nn as nn
 import torchvision.models as models
-from torchvision.models.resnet import ResNet, BasicBlock, Bottleneck
 from .resnet import DistillerBottleneck
 import distiller
 
+
 __all__ = ['resnet50_earlyexit']
 
 
@@ -15,18 +15,15 @@ def conv3x3(in_planes, out_planes, stride=1):
 
 def get_exits_def(num_classes):
     expansion = 1 # models.ResNet.BasicBlock.expansion
-    exits_def = [('layer1.2.relu3', nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
-                                                  nn.Conv2d(256, 50, kernel_size=7, stride=2, padding=3, bias=True),
-                                                  nn.Conv2d(50, 12, kernel_size=7, stride=2, padding=3, bias=True),
-                                                  #nn.AdaptiveAvgPool2d((1, 1)),
+    exits_def = [('layer1.2.relu3', nn.Sequential(nn.Conv2d(256, 10, kernel_size=7, stride=2, padding=3, bias=True),
+                                                  nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
                                                   nn.Flatten(),
-                                                  nn.Linear(12 * expansion, num_classes))),
+                                                  nn.Linear(1960, num_classes))),
                                                   #distiller.modules.Print())),
                  ('layer2.3.relu3', nn.Sequential(nn.Conv2d(512, 12, kernel_size=7, stride=2, padding=3, bias=True),
-                                                  nn.AdaptiveAvgPool2d((1, 1)),
+                                                  nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
                                                   nn.Flatten(),
-                                                  #distiller.modules.Print()))]
-                                                  nn.Linear(12 * expansion, num_classes)))]
+                                                  nn.Linear(588, num_classes)))]
     return exits_def
 
 
-- 
GitLab