Skip to content
Snippets Groups Projects
Commit 498b3cb8 authored by Guy Jacob's avatar Guy Jacob
Browse files

get_dummy_input: extend to return tuples of tensors + add tests

parent 4bc263ed
No related branches found
No related tags found
No related merge requests found
......@@ -567,8 +567,18 @@ def _validate_input_shape(dataset, input_shape):
if input_shape is None:
raise ValueError('Must provide either dataset name or input shape')
if not isinstance(input_shape, tuple):
raise ValueError('input shape should be a tuple')
return input_shape
raise TypeError('Shape should be a tuple of integers, or a tuple of tuples of integers')
def val_recurse(in_shape):
if all(isinstance(x, int) for x in in_shape):
if any(x < 0 for x in in_shape):
raise ValueError("Shape can't contain negative dimensions: {}".format(in_shape))
return in_shape
if all(isinstance(x, tuple) for x in in_shape):
return tuple(val_recurse(x) for x in in_shape)
raise TypeError('Shape should be a tuple of integers, or a tuple of tuples of integers')
return val_recurse(input_shape)
def get_dummy_input(dataset=None, device=None, input_shape=None):
......@@ -579,13 +589,22 @@ def get_dummy_input(dataset=None, device=None, input_shape=None):
Args:
dataset (str): Name of dataset from which to infer the shape
device (str or torch.device): Device on which to create the input
input_shape (tuple): List of integers representing the input shape. Used only if 'dataset' is None
input_shape (tuple): Tuple of integers representing the input shape. Can also be a tuple of tuples, allowing
arbitrarily complex collections of tensors. Used only if 'dataset' is None
"""
shape = _validate_input_shape(dataset, input_shape)
dummy_input = torch.randn(shape)
if device:
dummy_input = dummy_input.to(device)
return dummy_input
def create_single(shape):
t = torch.randn(shape)
if device:
t = t.to(device)
return t
def create_recurse(shape):
if all(isinstance(x, int) for x in shape):
return create_single(shape)
return tuple(create_recurse(s) for s in shape)
input_shape = _validate_input_shape(dataset, input_shape)
return create_recurse(input_shape)
def set_model_input_shape_attr(model, dataset=None, input_shape=None):
......@@ -594,7 +613,8 @@ def set_model_input_shape_attr(model, dataset=None, input_shape=None):
Args:
model (nn.Module): Model instance
dataset (str): Name of dataset from which to infer input shape
input_shape (tuple): List of integers representing the input shape. Used only if 'dataset' is None
input_shape (tuple): Tuple of integers representing the input shape. Can also be a tuple of tuples, allowing
arbitrarily complex collections of tensors. Used only if 'dataset' is None
"""
if not hasattr(model, 'input_shape'):
model.input_shape = _validate_input_shape(dataset, input_shape)
......
......@@ -190,3 +190,83 @@ def test_load_gpu_model_on_cpu_with_thinning():
cpu_model = create_model(False, 'cifar10', 'resnet20_cifar', device_ids=CPU_DEVICE_ID)
load_lean_checkpoint(cpu_model, "checkpoints/checkpoint.pth.tar")
assert distiller.model_device(cpu_model) == 'cpu'
def test_validate_input_shape():
with pytest.raises(ValueError):
distiller.utils._validate_input_shape('', None)
with pytest.raises(ValueError):
distiller.utils._validate_input_shape('not_a_dataset', None)
with pytest.raises(TypeError):
distiller.utils._validate_input_shape('', 'non_numeric_shape')
with pytest.raises(TypeError):
distiller.utils._validate_input_shape('', ('blah', 2))
with pytest.raises(TypeError):
distiller.utils._validate_input_shape('', (1.5, 2))
with pytest.raises(TypeError):
# Mix "flattened" shape and tuple
distiller.utils._validate_input_shape('', (1, 2, (3, 4)))
s = distiller.utils._validate_input_shape('imagenet', None)
assert s == (1, 3, 224, 224)
s = distiller.utils._validate_input_shape('imagenet', (1, 2))
assert s == (1, 3, 224, 224)
s = distiller.utils._validate_input_shape('', (1, 2))
assert s == (1, 2)
s = distiller.utils._validate_input_shape('', ((1, 2), (3, 4)))
assert s == ((1, 2), (3, 4))
s = distiller.utils._validate_input_shape('', ((1, 2), ((3, 4), (5, 6))))
assert s == ((1, 2), ((3, 4), (5, 6)))
@pytest.mark.parametrize('device', [None, 'cpu', 'cuda:0'])
def test_get_dummy_input(device):
def check_shape_device(t, exp_shape, exp_device):
assert t.shape == exp_shape
assert str(t.device) == exp_device
if device is None:
expected_device = 'cpu'
else:
if 'cuda' in device and not torch.cuda.is_available():
return
expected_device = device
with pytest.raises(ValueError):
distiller.utils.get_dummy_input('', None)
with pytest.raises(ValueError):
distiller.utils.get_dummy_input(dataset='not_a_dataset')
with pytest.raises(TypeError):
distiller.utils.get_dummy_input(input_shape='non_numeric_shape')
with pytest.raises(TypeError):
distiller.utils.get_dummy_input(input_shape=('blah', 2))
with pytest.raises(TypeError):
distiller.utils.get_dummy_input(input_shape=(1.5, 2))
with pytest.raises(TypeError):
# Mix "flattened" shape and tuple
distiller.utils.get_dummy_input(input_shape=(1, 2, (3, 4)))
t = distiller.utils.get_dummy_input(dataset='imagenet', device=device)
check_shape_device(t, (1, 3, 224, 224), expected_device)
t = distiller.utils.get_dummy_input(dataset='imagenet', device=device, input_shape=(1, 2))
check_shape_device(t, (1, 3, 224, 224), expected_device)
shape = 1, 2
t = distiller.utils.get_dummy_input(dataset='', device=device, input_shape=shape)
check_shape_device(t, shape, expected_device)
shape = ((1, 2), (3, 4))
t = distiller.utils.get_dummy_input(device=device, input_shape=shape)
assert isinstance(t, tuple)
check_shape_device(t[0], shape[0], expected_device)
check_shape_device(t[1], shape[1], expected_device)
shape = ((1, 2), ((3, 4), (5, 6)))
t = distiller.utils.get_dummy_input(device=device, input_shape=shape)
assert isinstance(t, tuple)
assert isinstance(t[0], torch.Tensor)
assert isinstance(t[1], tuple)
check_shape_device(t[0], shape[0], expected_device)
check_shape_device(t[1][0], shape[1][0], expected_device)
check_shape_device(t[1][1], shape[1][1], expected_device)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment