From 6b904a6482c9487e94ad3a3d2a2f31786de47f26 Mon Sep 17 00:00:00 2001
From: Neta Zmora <31280975+nzmora@users.noreply.github.com>
Date: Mon, 25 Mar 2019 13:43:36 +0200
Subject: [PATCH] Fix ResNet50 Early Exit

- Fix the invocation of resnet50_earlyexit (missing 'pretrained') parameter.
- Remove all ResNet depths other than 50, to prevent confusion (these are currently not supported).
---
 distiller/models/imagenet/resnet_earlyexit.py | 32 ++-----------------
 1 file changed, 2 insertions(+), 30 deletions(-)

diff --git a/distiller/models/imagenet/resnet_earlyexit.py b/distiller/models/imagenet/resnet_earlyexit.py
index c4b8742..4e6ba99 100644
--- a/distiller/models/imagenet/resnet_earlyexit.py
+++ b/distiller/models/imagenet/resnet_earlyexit.py
@@ -6,7 +6,7 @@ from torchvision.models.resnet import Bottleneck
 from torchvision.models.resnet import BasicBlock
 
 
-__all__ = ['resnet18_earlyexit', 'resnet34_earlyexit', 'resnet50_earlyexit', 'resnet101_earlyexit', 'resnet152_earlyexit']
+__all__ = ['resnet50_earlyexit']
 
 
 def conv3x3(in_planes, out_planes, stride=1):
@@ -66,36 +66,8 @@ class ResNetEarlyExit(models.ResNet):
         return output
 
 
-def resnet18_earlyexit(**kwargs):
-    """Constructs a ResNet-18 model.
-    """
-    model = ResNetEarlyExit(BasicBlock, [2, 2, 2, 2], **kwargs)
-    return model
-
-
-def resnet34_earlyexit(**kwargs):
-    """Constructs a ResNet-34 model.
-    """
-    model = ResNetEarlyExit(BasicBlock, [3, 4, 6, 3], **kwargs)
-    return model
-
-
-def resnet50_earlyexit(**kwargs):
+def resnet50_earlyexit(pretrained=False, **kwargs):
     """Constructs a ResNet-50 model.
     """
     model = ResNetEarlyExit(Bottleneck, [3, 4, 6, 3], **kwargs)
     return model
-
-
-def resnet101_earlyexit(**kwargs):
-    """Constructs a ResNet-101 model.
-    """
-    model = ResNetEarlyExit(Bottleneck, [3, 4, 23, 3], **kwargs)
-    return model
-
-
-def resnet152_earlyexit(**kwargs):
-    """Constructs a ResNet-152 model.
-    """
-    model = ResNetEarlyExit(Bottleneck, [3, 8, 36, 3], **kwargs)
-    return model
-- 
GitLab