From 74b49c8d8108e4a3927cc07b71e123aec49f8192 Mon Sep 17 00:00:00 2001 From: Neta Zmora <neta.zmora@intel.com> Date: Sun, 2 Dec 2018 15:29:18 +0200 Subject: [PATCH] Address the concern raised in issue #94 In a previous commit, we chose to override the TorchVision ResNet models with our own modified versions of ResNet because collecting activation statistics after ReLU layers requires that each ReLU instance is only used once in the graph. However, overriding TorchVision models means that if someone wants to directly access the TorchVision model dictionary, he/she won't find the ResNet models in the dictionary keys. This is not very friendly and we should not break imported modules. This commit overrides the TorchVision ResNet models when accessed via the distiller.models.create_model API, w/o changing the TorchVision dictionary. --- models/__init__.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/models/__init__.py b/models/__init__.py index b20bb49..7bde406 100755 --- a/models/__init__.py +++ b/models/__init__.py @@ -27,8 +27,6 @@ msglogger = logging.getLogger() # ResNet special treatment: we have our own version of ResNet, so we need to over-ride # TorchVision's version. RESNET_SYMS = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'] -for sym in RESNET_SYMS: - torch_models.__dict__.pop(sym) IMAGENET_MODEL_NAMES = sorted(name for name in torch_models.__dict__ if name.islower() and not name.startswith("__") @@ -62,14 +60,13 @@ def create_model(pretrained, dataset, arch, parallel=True, device_ids=None): msglogger.info("=> using %s%s model for ImageNet" % (str_pretrained, arch)) assert arch in torch_models.__dict__ or arch in imagenet_extra_models.__dict__, \ "Model %s is not supported for dataset %s" % (arch, 'ImageNet') - if arch in torch_models.__dict__: + if arch in RESNET_SYMS: + model = imagenet_extra_models.__dict__[arch](pretrained=pretrained) + elif arch in torch_models.__dict__: model = torch_models.__dict__[arch](pretrained=pretrained) else: - if arch in RESNET_SYMS: - model = imagenet_extra_models.__dict__[arch](pretrained=pretrained) - else: - assert not pretrained, "Model %s (ImageNet) does not have a pretrained model" % arch - model = imagenet_extra_models.__dict__[arch]() + assert not pretrained, "Model %s (ImageNet) does not have a pretrained model" % arch + model = imagenet_extra_models.__dict__[arch]() elif dataset == 'cifar10': msglogger.info("=> creating %s model for CIFAR10" % arch) assert arch in cifar10_models.__dict__, "Model %s is not supported for dataset CIFAR10" % arch -- GitLab