diff --git a/models/__init__.py b/models/__init__.py index b20bb4951bed71dcc5f0a48257fc6b1293da036a..7bde40697c3cc328541a2f790b43efe47c587071 100755 --- a/models/__init__.py +++ b/models/__init__.py @@ -27,8 +27,6 @@ msglogger = logging.getLogger() # ResNet special treatment: we have our own version of ResNet, so we need to over-ride # TorchVision's version. RESNET_SYMS = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'] -for sym in RESNET_SYMS: - torch_models.__dict__.pop(sym) IMAGENET_MODEL_NAMES = sorted(name for name in torch_models.__dict__ if name.islower() and not name.startswith("__") @@ -62,14 +60,13 @@ def create_model(pretrained, dataset, arch, parallel=True, device_ids=None): msglogger.info("=> using %s%s model for ImageNet" % (str_pretrained, arch)) assert arch in torch_models.__dict__ or arch in imagenet_extra_models.__dict__, \ "Model %s is not supported for dataset %s" % (arch, 'ImageNet') - if arch in torch_models.__dict__: + if arch in RESNET_SYMS: + model = imagenet_extra_models.__dict__[arch](pretrained=pretrained) + elif arch in torch_models.__dict__: model = torch_models.__dict__[arch](pretrained=pretrained) else: - if arch in RESNET_SYMS: - model = imagenet_extra_models.__dict__[arch](pretrained=pretrained) - else: - assert not pretrained, "Model %s (ImageNet) does not have a pretrained model" % arch - model = imagenet_extra_models.__dict__[arch]() + assert not pretrained, "Model %s (ImageNet) does not have a pretrained model" % arch + model = imagenet_extra_models.__dict__[arch]() elif dataset == 'cifar10': msglogger.info("=> creating %s model for CIFAR10" % arch) assert arch in cifar10_models.__dict__, "Model %s is not supported for dataset CIFAR10" % arch