Skip to content
Snippets Groups Projects
Commit 321abb61 authored by Bar's avatar Bar Committed by Neta Zmora
Browse files

Integrate Cadene pretrained PyTorch models [fixes #142] (#184)

Integrate Cadene ```pretrainedmodels``` package.

This PR integrates a large set of pre-trained PyTorch image-classification and object-detection models which originate from https://github.com/Cadene/pretrained-models.pytorch.

*******************************************************************************************
PLEASE NOTE: 
This PR adds a dependency on he ```pretrainedmodels``` package, and you will need to install it using ```pip3 install pretrainedmodels```.  For new users, we have also updated the ```requirements.txt``` file.
*******************************************************************************************

Distiller does not currently support the compression of object-detectors (a sample application is required - and the community is invited to send us a PR).

Compression of some of these models may not be fully supported by Distiller due to bugs and/or missing features.  If you encounter any issues, please report to us.

Whenever there is contention on the names of models passed to the ```compress_classifier.py``` sample application, it will prefer to use the Cadene models at the lowest priority (e.g. Torchvision models are used in favor of Cadene models, when the same model is supported by both packages).

This PR also:
* Adds documentation to ```create_model```
* Adds tests for ```create_model```
parent cd2c5e73
No related branches found
No related tags found
No related merge requests found
...@@ -20,6 +20,7 @@ import torch ...@@ -20,6 +20,7 @@ import torch
import torchvision.models as torch_models import torchvision.models as torch_models
from . import cifar10 as cifar10_models from . import cifar10 as cifar10_models
from . import imagenet as imagenet_extra_models from . import imagenet as imagenet_extra_models
import pretrainedmodels
import logging import logging
msglogger = logging.getLogger() msglogger = logging.getLogger()
...@@ -34,51 +35,63 @@ IMAGENET_MODEL_NAMES = sorted(name for name in torch_models.__dict__ ...@@ -34,51 +35,63 @@ IMAGENET_MODEL_NAMES = sorted(name for name in torch_models.__dict__
IMAGENET_MODEL_NAMES.extend(sorted(name for name in imagenet_extra_models.__dict__ IMAGENET_MODEL_NAMES.extend(sorted(name for name in imagenet_extra_models.__dict__
if name.islower() and not name.startswith("__") if name.islower() and not name.startswith("__")
and callable(imagenet_extra_models.__dict__[name]))) and callable(imagenet_extra_models.__dict__[name])))
IMAGENET_MODEL_NAMES.extend(pretrainedmodels.model_names)
CIFAR10_MODEL_NAMES = sorted(name for name in cifar10_models.__dict__ CIFAR10_MODEL_NAMES = sorted(name for name in cifar10_models.__dict__
if name.islower() and not name.startswith("__") if name.islower() and not name.startswith("__")
and callable(cifar10_models.__dict__[name])) and callable(cifar10_models.__dict__[name]))
ALL_MODEL_NAMES = sorted(map(lambda s: s.lower(), set(IMAGENET_MODEL_NAMES + CIFAR10_MODEL_NAMES))) ALL_MODEL_NAMES = sorted(map(lambda s: s.lower(),
set(IMAGENET_MODEL_NAMES + CIFAR10_MODEL_NAMES)))
def create_model(pretrained, dataset, arch, parallel=True, device_ids=None): def create_model(pretrained, dataset, arch, parallel=True, device_ids=None):
"""Create a pytorch model based on the model architecture and dataset """Create a pytorch model based on the model architecture and dataset
Args: Args:
pretrained: True is you wish to load a pretrained model. Only torchvision models pretrained [boolean]: True is you wish to load a pretrained model.
have a pretrained model. Some models do not have a pretrained version.
dataset: dataset: dataset name (only 'imagenet' and 'cifar10' are supported)
arch: arch: architecture name
parallel: parallel [boolean]: if set, use torch.nn.DataParallel
device_ids: Devices on which model should be created - device_ids: Devices on which model should be created -
None - GPU if available, otherwise CPU None - GPU if available, otherwise CPU
-1 - CPU -1 - CPU
>=0 - GPU device IDs >=0 - GPU device IDs
""" """
msglogger.info('==> using %s dataset' % dataset)
model = None model = None
dataset = dataset.lower()
if dataset == 'imagenet': if dataset == 'imagenet':
str_pretrained = 'pretrained ' if pretrained else ''
msglogger.info("=> using %s%s model for ImageNet" % (str_pretrained, arch))
assert arch in torch_models.__dict__ or arch in imagenet_extra_models.__dict__, \
"Model %s is not supported for dataset %s" % (arch, 'ImageNet')
if arch in RESNET_SYMS: if arch in RESNET_SYMS:
model = imagenet_extra_models.__dict__[arch](pretrained=pretrained) model = imagenet_extra_models.__dict__[arch](pretrained=pretrained)
elif arch in torch_models.__dict__: elif arch in torch_models.__dict__:
model = torch_models.__dict__[arch](pretrained=pretrained) model = torch_models.__dict__[arch](pretrained=pretrained)
elif (arch in imagenet_extra_models.__dict__) and not pretrained:
model = imagenet_extra_models.__dict__[arch](pretrained=pretrained)
elif arch in pretrainedmodels.model_names:
model = pretrainedmodels.__dict__[arch](
num_classes=1000,
pretrained=(dataset if pretrained else None))
else: else:
assert not pretrained, "Model %s (ImageNet) does not have a pretrained model" % arch error_message = ''
model = imagenet_extra_models.__dict__[arch]() if arch not in IMAGENET_MODEL_NAMES:
error_message = "Model {} is not supported for dataset ImageNet".format(arch)
elif pretrained:
error_message = "Model {} (ImageNet) does not have a pretrained model".format(arch)
raise ValueError(error_message or 'Failed to find model {}'.format(arch))
msglogger.info("=> using {p}{a} model for ImageNet".format(a=arch,
p=('pretrained ' if pretrained else '')))
elif dataset == 'cifar10': elif dataset == 'cifar10':
if pretrained:
raise ValueError("Model {} (CIFAR10) does not have a pretrained model".format(arch))
try:
model = cifar10_models.__dict__[arch]()
except KeyError:
raise ValueError("Model {} is not supported for dataset CIFAR10".format(arch))
msglogger.info("=> creating %s model for CIFAR10" % arch) msglogger.info("=> creating %s model for CIFAR10" % arch)
assert arch in cifar10_models.__dict__, "Model %s is not supported for dataset CIFAR10" % arch
assert not pretrained, "Model %s (CIFAR10) does not have a pretrained model" % arch
model = cifar10_models.__dict__[arch]()
else: else:
print("FATAL ERROR: create_model does not support models for dataset %s" % dataset) raise ValueError('Could not recognize dataset {}'.format(dataset))
exit()
if torch.cuda.is_available() and device_ids != -1: if torch.cuda.is_available() and device_ids != -1:
device = 'cuda' device = 'cuda'
......
...@@ -17,3 +17,4 @@ bqplot==0.10.5 ...@@ -17,3 +17,4 @@ bqplot==0.10.5
pyyaml pyyaml
pytest==3.5.1 pytest==3.5.1
xlsxwriter>=1.1.1 xlsxwriter>=1.1.1
pretrainedmodels
...@@ -30,6 +30,40 @@ except ImportError: ...@@ -30,6 +30,40 @@ except ImportError:
import distiller import distiller
from distiller.apputils import save_checkpoint, load_checkpoint from distiller.apputils import save_checkpoint, load_checkpoint
from distiller.models import create_model from distiller.models import create_model
import pretrainedmodels
def test_create_model_cifar():
pretrained = False
model = create_model(pretrained, 'cifar10', 'resnet20_cifar')
with pytest.raises(ValueError):
# only cifar _10_ is currently supported
model = create_model(pretrained, 'cifar100', 'resnet20_cifar')
with pytest.raises(ValueError):
model = create_model(pretrained, 'cifar10', 'no_such_model!')
pretrained = True
with pytest.raises(ValueError):
# no pretrained models of cifar10
model = create_model(pretrained, 'cifar10', 'resnet20_cifar')
def test_create_model_imagenet():
model = create_model(False, 'imagenet', 'alexnet')
model = create_model(False, 'imagenet', 'resnet50')
model = create_model(True, 'imagenet', 'resnet50')
with pytest.raises(ValueError):
model = create_model(False, 'imagenet', 'no_such_model!')
def test_create_model_pretrainedmodels():
premodel_name = 'resnext101_32x4d'
model = create_model(True, 'imagenet', premodel_name)
with pytest.raises(ValueError):
model = create_model(False, 'imagenet', 'no_such_model!')
def test_load(): def test_load():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment