diff --git a/distiller/apputils/data_loaders.py b/distiller/apputils/data_loaders.py index 7fc31ec6837929faca7593931d556490b8b3ec42..ff4eeeaea23e0b98ad5fa7dad09af4bd34852f4b 100755 --- a/distiller/apputils/data_loaders.py +++ b/distiller/apputils/data_loaders.py @@ -47,22 +47,15 @@ def classification_num_classes(dataset): 'imagenet': 1000}.get(dataset, None) -def classification_get_dummy_input(dataset, device=None): - """Generate a representative dummy (random) input for the specified dataset. - - If a device is specified, then the dummay_input is moved to that device. - """ +def classification_get_input_shape(dataset): if dataset == 'imagenet': - dummy_input = torch.randn(1, 3, 224, 224) + return 1, 3, 224, 224 elif dataset == 'cifar10': - dummy_input = torch.randn(1, 3, 32, 32) + return 1, 3, 32, 32 elif dataset == 'mnist': - dummy_input = torch.randn(1, 1, 28, 28) + return 1, 1, 28, 28 else: raise ValueError("dataset %s is not supported" % dataset) - if device: - dummy_input = dummy_input.to(device) - return dummy_input def __dataset_factory(dataset): diff --git a/distiller/models/__init__.py b/distiller/models/__init__.py index 8c3969aacd73e27a80c6095dcc14e8736efba06c..6eea6fddd289ee3acb04617dc2b66c6f5a4e14c5 100755 --- a/distiller/models/__init__.py +++ b/distiller/models/__init__.py @@ -23,6 +23,8 @@ from . import mnist as mnist_models from . import imagenet as imagenet_extra_models import pretrainedmodels +from distiller.utils import set_model_input_shape_attr + import logging msglogger = logging.getLogger() @@ -66,6 +68,7 @@ def create_model(pretrained, dataset, arch, parallel=True, device_ids=None): """ model = None dataset = dataset.lower() + cadene = False if dataset == 'imagenet': if arch in RESNET_SYMS: model = imagenet_extra_models.__dict__[arch](pretrained=pretrained) @@ -74,6 +77,7 @@ def create_model(pretrained, dataset, arch, parallel=True, device_ids=None): elif (arch in imagenet_extra_models.__dict__) and not pretrained: model = imagenet_extra_models.__dict__[arch]() elif arch in pretrainedmodels.model_names: + cadene = True model = pretrainedmodels.__dict__[arch](num_classes=1000, pretrained=(dataset if pretrained else None)) else: @@ -113,4 +117,15 @@ def create_model(pretrained, dataset, arch, parallel=True, device_ids=None): else: device = 'cpu' + if cadene and pretrained: + # When using pre-trained weights, Cadene models already have an input size attribute + # We add the batch dimension to it + input_size = model.module.input_size if isinstance(model, torch.nn.DataParallel) else model.input_size + shape = tuple([1] + input_size) + set_model_input_shape_attr(model, input_shape=shape) + elif arch == 'inception_v3': + set_model_input_shape_attr(model, input_shape=(1, 3, 299, 299)) + else: + set_model_input_shape_attr(model, dataset=dataset) + return model.to(device) diff --git a/distiller/utils.py b/distiller/utils.py index 5104197af393ec6e0fdae78dc60faac9ade6e3b2..7b3777e8709e4b64798125a0fb7d14563d722a31 100755 --- a/distiller/utils.py +++ b/distiller/utils.py @@ -557,8 +557,47 @@ def has_children(module): return False -def get_dummy_input(dataset, device=None): - return distiller.apputils.classification_get_dummy_input(dataset, device) +def _validate_input_shape(dataset, input_shape): + if dataset: + try: + return tuple(distiller.apputils.classification_get_input_shape(dataset)) + except ValueError: + raise ValueError("Can't infer input shape for dataset {}, please pass shape directly".format(dataset)) + else: + if input_shape is None: + raise ValueError('Must provide either dataset name or input shape') + if not isinstance(input_shape, tuple): + raise ValueError('input shape should be a tuple') + return input_shape + + +def get_dummy_input(dataset=None, device=None, input_shape=None): + """Generate a representative dummy (random) input. + + If a device is specified, then the dummy_input is moved to that device. + + Args: + dataset (str): Name of dataset from which to infer the shape + device (str or torch.device): Device on which to create the input + input_shape (tuple): List of integers representing the input shape. Used only if 'dataset' is None + """ + shape = _validate_input_shape(dataset, input_shape) + dummy_input = torch.randn(shape) + if device: + dummy_input = dummy_input.to(device) + return dummy_input + + +def set_model_input_shape_attr(model, dataset=None, input_shape=None): + """Sets an attribute named 'input_shape' within the model instance, specifying the expected input shape + + Args: + model (nn.Module): Model instance + dataset (str): Name of dataset from which to infer input shape + input_shape (tuple): List of integers representing the input shape. Used only if 'dataset' is None + """ + if not hasattr(model, 'input_shape'): + model.input_shape = _validate_input_shape(dataset, input_shape) def make_non_parallel_copy(model):