From 2cab774155e6b8e34e71da1adbc85ee6c997f442 Mon Sep 17 00:00:00 2001
From: Guy Jacob <guy.jacob@intel.com>
Date: Tue, 18 Jun 2019 14:23:51 +0300
Subject: [PATCH] Input shape attribute in models (#292)

Setting this attribute will become a requirement for some upcoming features

* Add utility function to set input_shape attribute in models
* Set this attribute in our classification models factory
  (models.create_model())
* Add input_shape parameter to get_dummy_input() as an alternative to dataset
  (see the usage sketch below)
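
A minimal usage sketch (illustrative only; 'resnet56_cifar' is just an example
of an arch name registered in the factory, and any registered model name works):

    import distiller
    from distiller.models import create_model

    # The factory now records the expected input shape on the model instance
    # ('resnet56_cifar' is an example arch; parallel=False keeps the snippet simple)
    model = create_model(pretrained=False, dataset='cifar10',
                         arch='resnet56_cifar', parallel=False)
    assert model.input_shape == (1, 3, 32, 32)

    # get_dummy_input() can infer the shape from a dataset name, or take an
    # explicit shape tuple when no dataset name applies
    dummy = distiller.utils.get_dummy_input(dataset='cifar10')
    dummy = distiller.utils.get_dummy_input(input_shape=model.input_shape)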
---
 distiller/apputils/data_loaders.py | 15 +++--------
 distiller/models/__init__.py       | 15 +++++++++++
 distiller/utils.py                 | 43 ++++++++++++++++++++++++++++--
 3 files changed, 60 insertions(+), 13 deletions(-)

diff --git a/distiller/apputils/data_loaders.py b/distiller/apputils/data_loaders.py
index 7fc31ec..ff4eeea 100755
--- a/distiller/apputils/data_loaders.py
+++ b/distiller/apputils/data_loaders.py
@@ -47,22 +47,15 @@ def classification_num_classes(dataset):
             'imagenet': 1000}.get(dataset, None)
 
 
-def classification_get_dummy_input(dataset, device=None):
-    """Generate a representative dummy (random) input for the specified dataset.
-
-    If a device is specified, then the dummay_input is moved to that device.
-    """
+def classification_get_input_shape(dataset):
     if dataset == 'imagenet':
-        dummy_input = torch.randn(1, 3, 224, 224)
+        return 1, 3, 224, 224
     elif dataset == 'cifar10':
-        dummy_input = torch.randn(1, 3, 32, 32)
+        return 1, 3, 32, 32
     elif dataset == 'mnist':
-        dummy_input = torch.randn(1, 1, 28, 28)
+        return 1, 1, 28, 28
     else:
         raise ValueError("dataset %s is not supported" % dataset)
-    if device:
-        dummy_input = dummy_input.to(device)
-    return dummy_input
 
 
 def __dataset_factory(dataset):
diff --git a/distiller/models/__init__.py b/distiller/models/__init__.py
index 8c3969a..6eea6fd 100755
--- a/distiller/models/__init__.py
+++ b/distiller/models/__init__.py
@@ -23,6 +23,8 @@ from . import mnist as mnist_models
 from . import imagenet as imagenet_extra_models
 import pretrainedmodels
 
+from distiller.utils import set_model_input_shape_attr
+
 import logging
 msglogger = logging.getLogger()
 
@@ -66,6 +68,7 @@ def create_model(pretrained, dataset, arch, parallel=True, device_ids=None):
     """
     model = None
     dataset = dataset.lower()
+    cadene = False
     if dataset == 'imagenet':
         if arch in RESNET_SYMS:
             model = imagenet_extra_models.__dict__[arch](pretrained=pretrained)
@@ -74,6 +77,7 @@ def create_model(pretrained, dataset, arch, parallel=True, device_ids=None):
         elif (arch in imagenet_extra_models.__dict__) and not pretrained:
             model = imagenet_extra_models.__dict__[arch]()
         elif arch in pretrainedmodels.model_names:
+            cadene = True
             model = pretrainedmodels.__dict__[arch](num_classes=1000,
                                                     pretrained=(dataset if pretrained else None))
         else:
@@ -113,4 +117,15 @@ def create_model(pretrained, dataset, arch, parallel=True, device_ids=None):
     else:
         device = 'cpu'
 
+    if cadene and pretrained:
+        # When using pre-trained weights, Cadene models already have an input size attribute
+        # We add the batch dimension to it
+        input_size = model.module.input_size if isinstance(model, torch.nn.DataParallel) else model.input_size
+        shape = tuple([1] + input_size)
+        set_model_input_shape_attr(model, input_shape=shape)
+    elif arch == 'inception_v3':
+        set_model_input_shape_attr(model, input_shape=(1, 3, 299, 299))
+    else:
+        set_model_input_shape_attr(model, dataset=dataset)
+
     return model.to(device)
diff --git a/distiller/utils.py b/distiller/utils.py
index 5104197..7b3777e 100755
--- a/distiller/utils.py
+++ b/distiller/utils.py
@@ -557,8 +557,47 @@ def has_children(module):
         return False
 
 
-def get_dummy_input(dataset, device=None):
-    return distiller.apputils.classification_get_dummy_input(dataset, device)
+def _validate_input_shape(dataset, input_shape):
+    if dataset:
+        try:
+            return tuple(distiller.apputils.classification_get_input_shape(dataset))
+        except ValueError:
+            raise ValueError("Can't infer input shape for dataset {}, please pass shape directly".format(dataset))
+    else:
+        if input_shape is None:
+            raise ValueError('Must provide either dataset name or input shape')
+        if not isinstance(input_shape, tuple):
+            raise ValueError('input_shape should be a tuple')
+        return input_shape
+
+
+def get_dummy_input(dataset=None, device=None, input_shape=None):
+    """Generate a representative dummy (random) input.
+
+    If a device is specified, then the dummy_input is moved to that device.
+
+    Args:
+        dataset (str): Name of dataset from which to infer the shape
+        device (str or torch.device): Device on which to create the input
+        input_shape (tuple): Tuple of integers representing the input shape. Used only if 'dataset' is None
+    """
+    shape = _validate_input_shape(dataset, input_shape)
+    dummy_input = torch.randn(shape)
+    if device:
+        dummy_input = dummy_input.to(device)
+    return dummy_input
+
+
+def set_model_input_shape_attr(model, dataset=None, input_shape=None):
+    """Sets an attribute named 'input_shape' within the model instance, specifying the expected input shape
+
+    Args:
+          model (nn.Module): Model instance
+          dataset (str): Name of dataset from which to infer input shape
+          input_shape (tuple): Tuple of integers representing the input shape. Used only if 'dataset' is None
+    """
+    if not hasattr(model, 'input_shape'):
+        model.input_shape = _validate_input_shape(dataset, input_shape)
 
 
 def make_non_parallel_copy(model):
-- 
GitLab