From 321abb61daf7893336c2461224560260168b5f9f Mon Sep 17 00:00:00 2001
From: Bar <elhararb@gmail.com>
Date: Mon, 11 Mar 2019 17:51:22 +0200
Subject: [PATCH] Integrate Cadene pretrained PyTorch models [fixes #142]
 (#184)

Integrate Cadene ```pretrainedmodels``` package.

This PR integrates a large set of pre-trained PyTorch image-classification and object-detection models which originate from https://github.com/Cadene/pretrained-models.pytorch.

*******************************************************************************************
PLEASE NOTE:
This PR adds a dependency on he ```pretrainedmodels``` package, and you will need to install it using ```pip3 install pretrainedmodels```.  For new users, we have also updated the ```requirements.txt``` file.
*******************************************************************************************

Distiller does not currently support the compression of object-detectors (a sample application is required - and the community is invited to send us a PR).

Compression of some of these models may not be fully supported by Distiller due to bugs and/or missing features.  If you encounter any issues, please report to us.

Whenever there is contention on the names of models passed to the ```compress_classifier.py``` sample application, it will prefer to use the Cadene models at the lowest priority (e.g. Torchvision models are used in favor of Cadene models, when the same model is supported by both packages).

This PR also:
* Adds documentation to ```create_model```
* Adds tests for ```create_model```
---
 distiller/models/__init__.py | 51 ++++++++++++++++++++++--------------
 requirements.txt             |  1 +
 tests/test_infra.py          | 34 ++++++++++++++++++++++++
 3 files changed, 67 insertions(+), 19 deletions(-)

diff --git a/distiller/models/__init__.py b/distiller/models/__init__.py
index 49f5e7b..3ad0ceb 100755
--- a/distiller/models/__init__.py
+++ b/distiller/models/__init__.py
@@ -20,6 +20,7 @@ import torch
 import torchvision.models as torch_models
 from . import cifar10 as cifar10_models
 from . import imagenet as imagenet_extra_models
+import pretrainedmodels
 
 import logging
 msglogger = logging.getLogger()
@@ -34,51 +35,63 @@ IMAGENET_MODEL_NAMES = sorted(name for name in torch_models.__dict__
 IMAGENET_MODEL_NAMES.extend(sorted(name for name in imagenet_extra_models.__dict__
                                    if name.islower() and not name.startswith("__")
                                    and callable(imagenet_extra_models.__dict__[name])))
+IMAGENET_MODEL_NAMES.extend(pretrainedmodels.model_names)
 
 CIFAR10_MODEL_NAMES = sorted(name for name in cifar10_models.__dict__
                              if name.islower() and not name.startswith("__")
                              and callable(cifar10_models.__dict__[name]))
 
-ALL_MODEL_NAMES = sorted(map(lambda s: s.lower(), set(IMAGENET_MODEL_NAMES + CIFAR10_MODEL_NAMES)))
+ALL_MODEL_NAMES = sorted(map(lambda s: s.lower(),
+                            set(IMAGENET_MODEL_NAMES + CIFAR10_MODEL_NAMES)))
 
 
 def create_model(pretrained, dataset, arch, parallel=True, device_ids=None):
     """Create a pytorch model based on the model architecture and dataset
 
     Args:
-        pretrained: True is you wish to load a pretrained model.  Only torchvision models
-          have a pretrained model.
-        dataset:
-        arch:
-        parallel:
+        pretrained [boolean]: True is you wish to load a pretrained model.
+            Some models do not have a pretrained version.
+        dataset: dataset name (only 'imagenet' and 'cifar10' are supported)
+        arch: architecture name
+        parallel [boolean]: if set, use torch.nn.DataParallel
         device_ids: Devices on which model should be created -
             None - GPU if available, otherwise CPU
             -1 - CPU
             >=0 - GPU device IDs
     """
-    msglogger.info('==> using %s dataset' % dataset)
-
     model = None
+    dataset = dataset.lower()
     if dataset == 'imagenet':
-        str_pretrained = 'pretrained ' if pretrained else ''
-        msglogger.info("=> using %s%s model for ImageNet" % (str_pretrained, arch))
-        assert arch in torch_models.__dict__ or arch in imagenet_extra_models.__dict__, \
-            "Model %s is not supported for dataset %s" % (arch, 'ImageNet')
         if arch in RESNET_SYMS:
             model = imagenet_extra_models.__dict__[arch](pretrained=pretrained)
         elif arch in torch_models.__dict__:
             model = torch_models.__dict__[arch](pretrained=pretrained)
+        elif (arch in imagenet_extra_models.__dict__) and not pretrained:
+            model = imagenet_extra_models.__dict__[arch](pretrained=pretrained)
+        elif arch in pretrainedmodels.model_names:
+            model = pretrainedmodels.__dict__[arch](
+                        num_classes=1000,
+                        pretrained=(dataset if pretrained else None))
         else:
-            assert not pretrained, "Model %s (ImageNet) does not have a pretrained model" % arch
-            model = imagenet_extra_models.__dict__[arch]()
+            error_message = ''
+            if arch not in IMAGENET_MODEL_NAMES:
+                error_message = "Model {} is not supported for dataset ImageNet".format(arch)
+            elif pretrained:
+                error_message = "Model {} (ImageNet) does not have a pretrained model".format(arch)
+            raise ValueError(error_message or 'Failed to find model {}'.format(arch))
+
+        msglogger.info("=> using {p}{a} model for ImageNet".format(a=arch,
+            p=('pretrained ' if pretrained else '')))
     elif dataset == 'cifar10':
+        if pretrained:
+            raise ValueError("Model {} (CIFAR10) does not have a pretrained model".format(arch))
+        try:
+            model = cifar10_models.__dict__[arch]()
+        except KeyError:
+            raise ValueError("Model {} is not supported for dataset CIFAR10".format(arch))
         msglogger.info("=> creating %s model for CIFAR10" % arch)
-        assert arch in cifar10_models.__dict__, "Model %s is not supported for dataset CIFAR10" % arch
-        assert not pretrained, "Model %s (CIFAR10) does not have a pretrained model" % arch
-        model = cifar10_models.__dict__[arch]()
     else:
-        print("FATAL ERROR: create_model does not support models for dataset %s" % dataset)
-        exit()
+        raise ValueError('Could not recognize dataset {}'.format(dataset))
 
     if torch.cuda.is_available() and device_ids != -1:
         device = 'cuda'
diff --git a/requirements.txt b/requirements.txt
index e2375a3..2a49ea1 100755
--- a/requirements.txt
+++ b/requirements.txt
@@ -17,3 +17,4 @@ bqplot==0.10.5
 pyyaml
 pytest==3.5.1
 xlsxwriter>=1.1.1
+pretrainedmodels
diff --git a/tests/test_infra.py b/tests/test_infra.py
index d1d898d..2d0524f 100755
--- a/tests/test_infra.py
+++ b/tests/test_infra.py
@@ -30,6 +30,40 @@ except ImportError:
 import distiller
 from distiller.apputils import save_checkpoint, load_checkpoint
 from distiller.models import create_model
+import pretrainedmodels
+
+
+
+def test_create_model_cifar():
+    pretrained = False
+    model = create_model(pretrained, 'cifar10', 'resnet20_cifar')
+    with pytest.raises(ValueError):
+        # only cifar _10_ is currently supported
+        model = create_model(pretrained, 'cifar100', 'resnet20_cifar')
+    with pytest.raises(ValueError):
+        model = create_model(pretrained, 'cifar10', 'no_such_model!')
+
+    pretrained = True
+    with pytest.raises(ValueError):
+        # no pretrained models of cifar10
+        model = create_model(pretrained, 'cifar10', 'resnet20_cifar')
+
+
+def test_create_model_imagenet():
+    model = create_model(False, 'imagenet', 'alexnet')
+    model = create_model(False, 'imagenet', 'resnet50')
+    model = create_model(True, 'imagenet', 'resnet50')
+
+    with pytest.raises(ValueError):
+        model = create_model(False, 'imagenet', 'no_such_model!')
+
+
+def test_create_model_pretrainedmodels():
+    premodel_name = 'resnext101_32x4d'
+    model = create_model(True, 'imagenet', premodel_name)
+
+    with pytest.raises(ValueError):
+        model = create_model(False, 'imagenet', 'no_such_model!')
 
 
 def test_load():
-- 
GitLab