diff --git a/distiller/models/__init__.py b/distiller/models/__init__.py index 49f5e7bfb77d7b33e9718e6f02326d0c5b3b82de..3ad0ceb4ba74c2a7ba15dcab494284016b169185 100755 --- a/distiller/models/__init__.py +++ b/distiller/models/__init__.py @@ -20,6 +20,7 @@ import torch import torchvision.models as torch_models from . import cifar10 as cifar10_models from . import imagenet as imagenet_extra_models +import pretrainedmodels import logging msglogger = logging.getLogger() @@ -34,51 +35,63 @@ IMAGENET_MODEL_NAMES = sorted(name for name in torch_models.__dict__ IMAGENET_MODEL_NAMES.extend(sorted(name for name in imagenet_extra_models.__dict__ if name.islower() and not name.startswith("__") and callable(imagenet_extra_models.__dict__[name]))) +IMAGENET_MODEL_NAMES.extend(pretrainedmodels.model_names) CIFAR10_MODEL_NAMES = sorted(name for name in cifar10_models.__dict__ if name.islower() and not name.startswith("__") and callable(cifar10_models.__dict__[name])) -ALL_MODEL_NAMES = sorted(map(lambda s: s.lower(), set(IMAGENET_MODEL_NAMES + CIFAR10_MODEL_NAMES))) +ALL_MODEL_NAMES = sorted(map(lambda s: s.lower(), + set(IMAGENET_MODEL_NAMES + CIFAR10_MODEL_NAMES))) def create_model(pretrained, dataset, arch, parallel=True, device_ids=None): """Create a pytorch model based on the model architecture and dataset Args: - pretrained: True is you wish to load a pretrained model. Only torchvision models - have a pretrained model. - dataset: - arch: - parallel: + pretrained [boolean]: True is you wish to load a pretrained model. + Some models do not have a pretrained version. + dataset: dataset name (only 'imagenet' and 'cifar10' are supported) + arch: architecture name + parallel [boolean]: if set, use torch.nn.DataParallel device_ids: Devices on which model should be created - None - GPU if available, otherwise CPU -1 - CPU >=0 - GPU device IDs """ - msglogger.info('==> using %s dataset' % dataset) - model = None + dataset = dataset.lower() if dataset == 'imagenet': - str_pretrained = 'pretrained ' if pretrained else '' - msglogger.info("=> using %s%s model for ImageNet" % (str_pretrained, arch)) - assert arch in torch_models.__dict__ or arch in imagenet_extra_models.__dict__, \ - "Model %s is not supported for dataset %s" % (arch, 'ImageNet') if arch in RESNET_SYMS: model = imagenet_extra_models.__dict__[arch](pretrained=pretrained) elif arch in torch_models.__dict__: model = torch_models.__dict__[arch](pretrained=pretrained) + elif (arch in imagenet_extra_models.__dict__) and not pretrained: + model = imagenet_extra_models.__dict__[arch](pretrained=pretrained) + elif arch in pretrainedmodels.model_names: + model = pretrainedmodels.__dict__[arch]( + num_classes=1000, + pretrained=(dataset if pretrained else None)) else: - assert not pretrained, "Model %s (ImageNet) does not have a pretrained model" % arch - model = imagenet_extra_models.__dict__[arch]() + error_message = '' + if arch not in IMAGENET_MODEL_NAMES: + error_message = "Model {} is not supported for dataset ImageNet".format(arch) + elif pretrained: + error_message = "Model {} (ImageNet) does not have a pretrained model".format(arch) + raise ValueError(error_message or 'Failed to find model {}'.format(arch)) + + msglogger.info("=> using {p}{a} model for ImageNet".format(a=arch, + p=('pretrained ' if pretrained else ''))) elif dataset == 'cifar10': + if pretrained: + raise ValueError("Model {} (CIFAR10) does not have a pretrained model".format(arch)) + try: + model = cifar10_models.__dict__[arch]() + except KeyError: + raise ValueError("Model {} is not supported for dataset CIFAR10".format(arch)) msglogger.info("=> creating %s model for CIFAR10" % arch) - assert arch in cifar10_models.__dict__, "Model %s is not supported for dataset CIFAR10" % arch - assert not pretrained, "Model %s (CIFAR10) does not have a pretrained model" % arch - model = cifar10_models.__dict__[arch]() else: - print("FATAL ERROR: create_model does not support models for dataset %s" % dataset) - exit() + raise ValueError('Could not recognize dataset {}'.format(dataset)) if torch.cuda.is_available() and device_ids != -1: device = 'cuda' diff --git a/requirements.txt b/requirements.txt index e2375a33acc75f6c3304249ca8bd10721dd75ce4..2a49ea1f5e7f839b91b0a0a2b5b588a414fa67ac 100755 --- a/requirements.txt +++ b/requirements.txt @@ -17,3 +17,4 @@ bqplot==0.10.5 pyyaml pytest==3.5.1 xlsxwriter>=1.1.1 +pretrainedmodels diff --git a/tests/test_infra.py b/tests/test_infra.py index d1d898d585cce6e06f81c54979789397cf3f2255..2d0524f5bb77cb9d50195e304b1a1d4a1c9e4e1e 100755 --- a/tests/test_infra.py +++ b/tests/test_infra.py @@ -30,6 +30,40 @@ except ImportError: import distiller from distiller.apputils import save_checkpoint, load_checkpoint from distiller.models import create_model +import pretrainedmodels + + + +def test_create_model_cifar(): + pretrained = False + model = create_model(pretrained, 'cifar10', 'resnet20_cifar') + with pytest.raises(ValueError): + # only cifar _10_ is currently supported + model = create_model(pretrained, 'cifar100', 'resnet20_cifar') + with pytest.raises(ValueError): + model = create_model(pretrained, 'cifar10', 'no_such_model!') + + pretrained = True + with pytest.raises(ValueError): + # no pretrained models of cifar10 + model = create_model(pretrained, 'cifar10', 'resnet20_cifar') + + +def test_create_model_imagenet(): + model = create_model(False, 'imagenet', 'alexnet') + model = create_model(False, 'imagenet', 'resnet50') + model = create_model(True, 'imagenet', 'resnet50') + + with pytest.raises(ValueError): + model = create_model(False, 'imagenet', 'no_such_model!') + + +def test_create_model_pretrainedmodels(): + premodel_name = 'resnext101_32x4d' + model = create_model(True, 'imagenet', premodel_name) + + with pytest.raises(ValueError): + model = create_model(False, 'imagenet', 'no_such_model!') def test_load():