diff --git a/distiller/apputils/data_loaders.py b/distiller/apputils/data_loaders.py
index 7fc31ec6837929faca7593931d556490b8b3ec42..ff4eeeaea23e0b98ad5fa7dad09af4bd34852f4b 100755
--- a/distiller/apputils/data_loaders.py
+++ b/distiller/apputils/data_loaders.py
@@ -47,22 +47,15 @@ def classification_num_classes(dataset):
             'imagenet': 1000}.get(dataset, None)
 
 
-def classification_get_dummy_input(dataset, device=None):
-    """Generate a representative dummy (random) input for the specified dataset.
-
-    If a device is specified, then the dummay_input is moved to that device.
-    """
+def classification_get_input_shape(dataset):
     if dataset == 'imagenet':
-        dummy_input = torch.randn(1, 3, 224, 224)
+        return 1, 3, 224, 224
     elif dataset == 'cifar10':
-        dummy_input = torch.randn(1, 3, 32, 32)
+        return 1, 3, 32, 32
     elif dataset == 'mnist':
-        dummy_input = torch.randn(1, 1, 28, 28)
+        return 1, 1, 28, 28
     else:
         raise ValueError("dataset %s is not supported" % dataset)
-    if device:
-        dummy_input = dummy_input.to(device)
-    return dummy_input
 
 
 def __dataset_factory(dataset):
diff --git a/distiller/models/__init__.py b/distiller/models/__init__.py
index 8c3969aacd73e27a80c6095dcc14e8736efba06c..6eea6fddd289ee3acb04617dc2b66c6f5a4e14c5 100755
--- a/distiller/models/__init__.py
+++ b/distiller/models/__init__.py
@@ -23,6 +23,8 @@ from . import mnist as mnist_models
 from . import imagenet as imagenet_extra_models
 import pretrainedmodels
 
+from distiller.utils import set_model_input_shape_attr
+
 import logging
 msglogger = logging.getLogger()
 
@@ -66,6 +68,7 @@ def create_model(pretrained, dataset, arch, parallel=True, device_ids=None):
     """
     model = None
     dataset = dataset.lower()
+    cadene = False
     if dataset == 'imagenet':
         if arch in RESNET_SYMS:
             model = imagenet_extra_models.__dict__[arch](pretrained=pretrained)
@@ -74,6 +77,7 @@ def create_model(pretrained, dataset, arch, parallel=True, device_ids=None):
         elif (arch in imagenet_extra_models.__dict__) and not pretrained:
             model = imagenet_extra_models.__dict__[arch]()
         elif arch in pretrainedmodels.model_names:
+            cadene = True
             model = pretrainedmodels.__dict__[arch](num_classes=1000,
                                                     pretrained=(dataset if pretrained else None))
         else:
@@ -113,4 +117,15 @@ def create_model(pretrained, dataset, arch, parallel=True, device_ids=None):
     else:
         device = 'cpu'
 
+    if cadene and pretrained:
+        # When using pre-trained weights, Cadene models already have an input size attribute
+        # We add the batch dimension to it
+        input_size = model.module.input_size if isinstance(model, torch.nn.DataParallel) else model.input_size
+        shape = tuple([1] + input_size)
+        set_model_input_shape_attr(model, input_shape=shape)
+    elif arch == 'inception_v3':
+        set_model_input_shape_attr(model, input_shape=(1, 3, 299, 299))
+    else:
+        set_model_input_shape_attr(model, dataset=dataset)
+
     return model.to(device)
diff --git a/distiller/utils.py b/distiller/utils.py
index 5104197af393ec6e0fdae78dc60faac9ade6e3b2..7b3777e8709e4b64798125a0fb7d14563d722a31 100755
--- a/distiller/utils.py
+++ b/distiller/utils.py
@@ -557,8 +557,47 @@ def has_children(module):
         return False
 
 
-def get_dummy_input(dataset, device=None):
-    return distiller.apputils.classification_get_dummy_input(dataset, device)
+def _validate_input_shape(dataset, input_shape):
+    if dataset:
+        try:
+            return tuple(distiller.apputils.classification_get_input_shape(dataset))
+        except ValueError:
+            raise ValueError("Can't infer input shape for dataset {}, please pass shape directly".format(dataset))
+    else:
+        if input_shape is None:
+            raise ValueError('Must provide either dataset name or input shape')
+        if not isinstance(input_shape, tuple):
+            raise ValueError('input shape should be a tuple')
+        return input_shape
+
+
+def get_dummy_input(dataset=None, device=None, input_shape=None):
+    """Generate a representative dummy (random) input.
+
+    If a device is specified, then the dummy_input is moved to that device.
+
+    Args:
+        dataset (str): Name of dataset from which to infer the shape
+        device (str or torch.device): Device on which to create the input
+        input_shape (tuple): List of integers representing the input shape. Used only if 'dataset' is None
+    """
+    shape = _validate_input_shape(dataset, input_shape)
+    dummy_input = torch.randn(shape)
+    if device:
+        dummy_input = dummy_input.to(device)
+    return dummy_input
+
+
+def set_model_input_shape_attr(model, dataset=None, input_shape=None):
+    """Sets an attribute named 'input_shape' within the model instance, specifying the expected input shape
+
+    Args:
+          model (nn.Module): Model instance
+          dataset (str): Name of dataset from which to infer input shape
+          input_shape (tuple): List of integers representing the input shape. Used only if 'dataset' is None
+    """
+    if not hasattr(model, 'input_shape'):
+        model.input_shape = _validate_input_shape(dataset, input_shape)
 
 
 def make_non_parallel_copy(model):