Skip to content
Snippets Groups Projects
Commit 74b49c8d authored by Neta Zmora's avatar Neta Zmora
Browse files

Address the concern raised in issue #94

In a previous commit, we chose to override the TorchVision ResNet
models with our own modified versions of ResNet, because collecting
activation statistics after ReLU layers requires that each ReLU
instance be used only once in the graph.

However, overriding the TorchVision models means that anyone who
accesses the TorchVision model dictionary directly won't find the
ResNet models among its keys.  This is not very friendly, and we
should not break imported modules.

This commit overrides the TorchVision ResNet models when they are
accessed via the distiller.models.create_model API, without changing
the TorchVision dictionary.
parent 44144f7c
No related branches found
No related tags found
No related merge requests found
...@@ -27,8 +27,6 @@ msglogger = logging.getLogger() ...@@ -27,8 +27,6 @@ msglogger = logging.getLogger()
# ResNet special treatment: we have our own version of ResNet, so we need to over-ride # ResNet special treatment: we have our own version of ResNet, so we need to over-ride
# TorchVision's version. # TorchVision's version.
RESNET_SYMS = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'] RESNET_SYMS = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152']
for sym in RESNET_SYMS:
torch_models.__dict__.pop(sym)
IMAGENET_MODEL_NAMES = sorted(name for name in torch_models.__dict__ IMAGENET_MODEL_NAMES = sorted(name for name in torch_models.__dict__
if name.islower() and not name.startswith("__") if name.islower() and not name.startswith("__")
...@@ -62,14 +60,13 @@ def create_model(pretrained, dataset, arch, parallel=True, device_ids=None): ...@@ -62,14 +60,13 @@ def create_model(pretrained, dataset, arch, parallel=True, device_ids=None):
msglogger.info("=> using %s%s model for ImageNet" % (str_pretrained, arch)) msglogger.info("=> using %s%s model for ImageNet" % (str_pretrained, arch))
assert arch in torch_models.__dict__ or arch in imagenet_extra_models.__dict__, \ assert arch in torch_models.__dict__ or arch in imagenet_extra_models.__dict__, \
"Model %s is not supported for dataset %s" % (arch, 'ImageNet') "Model %s is not supported for dataset %s" % (arch, 'ImageNet')
if arch in torch_models.__dict__: if arch in RESNET_SYMS:
model = imagenet_extra_models.__dict__[arch](pretrained=pretrained)
elif arch in torch_models.__dict__:
model = torch_models.__dict__[arch](pretrained=pretrained) model = torch_models.__dict__[arch](pretrained=pretrained)
else: else:
if arch in RESNET_SYMS: assert not pretrained, "Model %s (ImageNet) does not have a pretrained model" % arch
model = imagenet_extra_models.__dict__[arch](pretrained=pretrained) model = imagenet_extra_models.__dict__[arch]()
else:
assert not pretrained, "Model %s (ImageNet) does not have a pretrained model" % arch
model = imagenet_extra_models.__dict__[arch]()
elif dataset == 'cifar10': elif dataset == 'cifar10':
msglogger.info("=> creating %s model for CIFAR10" % arch) msglogger.info("=> creating %s model for CIFAR10" % arch)
assert arch in cifar10_models.__dict__, "Model %s is not supported for dataset CIFAR10" % arch assert arch in cifar10_models.__dict__, "Model %s is not supported for dataset CIFAR10" % arch
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment