From 74b49c8d8108e4a3927cc07b71e123aec49f8192 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Sun, 2 Dec 2018 15:29:18 +0200
Subject: [PATCH] Address the concern raised in issue #94

In a previous commit, we chose to override the TorchVision ResNet
models with our own modified versions of ResNet because collecting
activation statistics after ReLU layers requires that each ReLU
instance is only used once in the graph.

However, overriding TorchVision models means that if someone wants
to directly access the TorchVision model dictionary, he/she won't
find the ResNet models in the dictionary keys.  This is not very
friendly and we should not break imported modules.

This commit overrides the TorchVision ResNet models when accessed
via the distiller.models.create_model API, w/o changing the
TorchVision dictionary.
---
 models/__init__.py | 13 +++++--------
 1 file changed, 5 insertions(+), 8 deletions(-)

diff --git a/models/__init__.py b/models/__init__.py
index b20bb49..7bde406 100755
--- a/models/__init__.py
+++ b/models/__init__.py
@@ -27,8 +27,6 @@ msglogger = logging.getLogger()
 # ResNet special treatment: we have our own version of ResNet, so we need to over-ride
 # TorchVision's version.
 RESNET_SYMS = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152']
-for sym in RESNET_SYMS:
-    torch_models.__dict__.pop(sym)
 
 IMAGENET_MODEL_NAMES = sorted(name for name in torch_models.__dict__
                               if name.islower() and not name.startswith("__")
@@ -62,14 +60,13 @@ def create_model(pretrained, dataset, arch, parallel=True, device_ids=None):
         msglogger.info("=> using %s%s model for ImageNet" % (str_pretrained, arch))
         assert arch in torch_models.__dict__ or arch in imagenet_extra_models.__dict__, \
             "Model %s is not supported for dataset %s" % (arch, 'ImageNet')
-        if arch in torch_models.__dict__:
+        if arch in RESNET_SYMS:
+            model = imagenet_extra_models.__dict__[arch](pretrained=pretrained)
+        elif arch in torch_models.__dict__:
             model = torch_models.__dict__[arch](pretrained=pretrained)
         else:
-            if arch in RESNET_SYMS:
-                model = imagenet_extra_models.__dict__[arch](pretrained=pretrained)
-            else:
-                assert not pretrained, "Model %s (ImageNet) does not have a pretrained model" % arch
-                model = imagenet_extra_models.__dict__[arch]()
+            assert not pretrained, "Model %s (ImageNet) does not have a pretrained model" % arch
+            model = imagenet_extra_models.__dict__[arch]()
     elif dataset == 'cifar10':
         msglogger.info("=> creating %s model for CIFAR10" % arch)
         assert arch in cifar10_models.__dict__, "Model %s is not supported for dataset CIFAR10" % arch
-- 
GitLab