Skip to content
Snippets Groups Projects
Commit 321abb61 authored by Bar's avatar Bar Committed by Neta Zmora
Browse files

Integrate Cadene pretrained PyTorch models [fixes #142] (#184)

Integrate Cadene ```pretrainedmodels``` package.

This PR integrates a large set of pre-trained PyTorch image-classification and object-detection models which originate from https://github.com/Cadene/pretrained-models.pytorch.

*******************************************************************************************
PLEASE NOTE: 
This PR adds a dependency on he ```pretrainedmodels``` package, and you will need to install it using ```pip3 install pretrainedmodels```.  For new users, we have also updated the ```requirements.txt``` file.
*******************************************************************************************

Distiller does not currently support the compression of object-detectors (a sample application is required - and the community is invited to send us a PR).

Compression of some of these models may not be fully supported by Distiller due to bugs and/or missing features.  If you encounter any issues, please report to us.

Whenever there is contention on the names of models passed to the ```compress_classifier.py``` sample application, it will prefer to use the Cadene models at the lowest priority (e.g. Torchvision models are used in favor of Cadene models, when the same model is supported by both packages).

This PR also:
* Adds documentation to ```create_model```
* Adds tests for ```create_model```
parent cd2c5e73
No related branches found
No related tags found
No related merge requests found
......@@ -20,6 +20,7 @@ import torch
import torchvision.models as torch_models
from . import cifar10 as cifar10_models
from . import imagenet as imagenet_extra_models
import pretrainedmodels
import logging
msglogger = logging.getLogger()
......@@ -34,51 +35,63 @@ IMAGENET_MODEL_NAMES = sorted(name for name in torch_models.__dict__
IMAGENET_MODEL_NAMES.extend(sorted(name for name in imagenet_extra_models.__dict__
if name.islower() and not name.startswith("__")
and callable(imagenet_extra_models.__dict__[name])))
IMAGENET_MODEL_NAMES.extend(pretrainedmodels.model_names)
CIFAR10_MODEL_NAMES = sorted(name for name in cifar10_models.__dict__
if name.islower() and not name.startswith("__")
and callable(cifar10_models.__dict__[name]))
ALL_MODEL_NAMES = sorted(map(lambda s: s.lower(), set(IMAGENET_MODEL_NAMES + CIFAR10_MODEL_NAMES)))
ALL_MODEL_NAMES = sorted(map(lambda s: s.lower(),
set(IMAGENET_MODEL_NAMES + CIFAR10_MODEL_NAMES)))
def create_model(pretrained, dataset, arch, parallel=True, device_ids=None):
"""Create a pytorch model based on the model architecture and dataset
Args:
pretrained: True is you wish to load a pretrained model. Only torchvision models
have a pretrained model.
dataset:
arch:
parallel:
pretrained [boolean]: True is you wish to load a pretrained model.
Some models do not have a pretrained version.
dataset: dataset name (only 'imagenet' and 'cifar10' are supported)
arch: architecture name
parallel [boolean]: if set, use torch.nn.DataParallel
device_ids: Devices on which model should be created -
None - GPU if available, otherwise CPU
-1 - CPU
>=0 - GPU device IDs
"""
msglogger.info('==> using %s dataset' % dataset)
model = None
dataset = dataset.lower()
if dataset == 'imagenet':
str_pretrained = 'pretrained ' if pretrained else ''
msglogger.info("=> using %s%s model for ImageNet" % (str_pretrained, arch))
assert arch in torch_models.__dict__ or arch in imagenet_extra_models.__dict__, \
"Model %s is not supported for dataset %s" % (arch, 'ImageNet')
if arch in RESNET_SYMS:
model = imagenet_extra_models.__dict__[arch](pretrained=pretrained)
elif arch in torch_models.__dict__:
model = torch_models.__dict__[arch](pretrained=pretrained)
elif (arch in imagenet_extra_models.__dict__) and not pretrained:
model = imagenet_extra_models.__dict__[arch](pretrained=pretrained)
elif arch in pretrainedmodels.model_names:
model = pretrainedmodels.__dict__[arch](
num_classes=1000,
pretrained=(dataset if pretrained else None))
else:
assert not pretrained, "Model %s (ImageNet) does not have a pretrained model" % arch
model = imagenet_extra_models.__dict__[arch]()
error_message = ''
if arch not in IMAGENET_MODEL_NAMES:
error_message = "Model {} is not supported for dataset ImageNet".format(arch)
elif pretrained:
error_message = "Model {} (ImageNet) does not have a pretrained model".format(arch)
raise ValueError(error_message or 'Failed to find model {}'.format(arch))
msglogger.info("=> using {p}{a} model for ImageNet".format(a=arch,
p=('pretrained ' if pretrained else '')))
elif dataset == 'cifar10':
if pretrained:
raise ValueError("Model {} (CIFAR10) does not have a pretrained model".format(arch))
try:
model = cifar10_models.__dict__[arch]()
except KeyError:
raise ValueError("Model {} is not supported for dataset CIFAR10".format(arch))
msglogger.info("=> creating %s model for CIFAR10" % arch)
assert arch in cifar10_models.__dict__, "Model %s is not supported for dataset CIFAR10" % arch
assert not pretrained, "Model %s (CIFAR10) does not have a pretrained model" % arch
model = cifar10_models.__dict__[arch]()
else:
print("FATAL ERROR: create_model does not support models for dataset %s" % dataset)
exit()
raise ValueError('Could not recognize dataset {}'.format(dataset))
if torch.cuda.is_available() and device_ids != -1:
device = 'cuda'
......
......@@ -17,3 +17,4 @@ bqplot==0.10.5
pyyaml
pytest==3.5.1
xlsxwriter>=1.1.1
pretrainedmodels
......@@ -30,6 +30,40 @@ except ImportError:
import distiller
from distiller.apputils import save_checkpoint, load_checkpoint
from distiller.models import create_model
import pretrainedmodels
def test_create_model_cifar():
pretrained = False
model = create_model(pretrained, 'cifar10', 'resnet20_cifar')
with pytest.raises(ValueError):
# only cifar _10_ is currently supported
model = create_model(pretrained, 'cifar100', 'resnet20_cifar')
with pytest.raises(ValueError):
model = create_model(pretrained, 'cifar10', 'no_such_model!')
pretrained = True
with pytest.raises(ValueError):
# no pretrained models of cifar10
model = create_model(pretrained, 'cifar10', 'resnet20_cifar')
def test_create_model_imagenet():
model = create_model(False, 'imagenet', 'alexnet')
model = create_model(False, 'imagenet', 'resnet50')
model = create_model(True, 'imagenet', 'resnet50')
with pytest.raises(ValueError):
model = create_model(False, 'imagenet', 'no_such_model!')
def test_create_model_pretrainedmodels():
premodel_name = 'resnext101_32x4d'
model = create_model(True, 'imagenet', premodel_name)
with pytest.raises(ValueError):
model = create_model(False, 'imagenet', 'no_such_model!')
def test_load():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment