diff --git a/distiller/apputils/data_loaders.py b/distiller/apputils/data_loaders.py
index 381454dce88ad85c560f6dd40be9465a0fd779d2..d017249ecfad0f2eb22af0518e114b75727ea511 100755
--- a/distiller/apputils/data_loaders.py
+++ b/distiller/apputils/data_loaders.py
@@ -47,6 +47,24 @@ def classification_num_classes(dataset):
             'imagenet': 1000}.get(dataset, None)
 
 
+def classification_get_dummy_input(dataset, device=None):
+    """Generate a representative dummy (random) input for the specified dataset.
+
+    If a device is specified, the dummy input is moved to that device.
+    """
+    if dataset == 'imagenet':
+        dummy_input = torch.randn(1, 3, 224, 224)
+    elif dataset == 'cifar10':
+        dummy_input = torch.randn(1, 3, 32, 32)
+    elif dataset == 'mnist':
+        dummy_input = torch.randn(1, 1, 28, 28)
+    else:
+        raise ValueError("dataset %s is not supported" % dataset)
+    if device:
+        dummy_input = dummy_input.to(device)
+    return dummy_input
+
+
 def __dataset_factory(dataset):
     return {'cifar10': cifar10_get_datasets,
             'mnist': mnist_get_datasets,
diff --git a/distiller/models/mnist/simplenet_mnist.py b/distiller/models/mnist/simplenet_mnist.py
index 38515077bf1f068bdfb0b154fd6276fed16cea65..39b01e9d24aae0602ed14d3d5e0e23af2896bdb2 100755
--- a/distiller/models/mnist/simplenet_mnist.py
+++ b/distiller/models/mnist/simplenet_mnist.py
@@ -39,7 +39,7 @@ class Simplenet(nn.Module):
         x = F.max_pool2d(x, 2, 2)
         x = F.relu(self.conv2(x))
         x = F.max_pool2d(x, 2, 2)
-        x = x.view(-1, 4*4*50)
+        x = x.view(x.size(0), -1)
         x = F.relu(self.fc1(x))
         x = self.fc2(x)
         return F.log_softmax(x, dim=1)
diff --git a/distiller/utils.py b/distiller/utils.py
index 72994ab9d18d0f20234ab4d65d30191061e99607..5104197af393ec6e0fdae78dc60faac9ade6e3b2 100755
--- a/distiller/utils.py
+++ b/distiller/utils.py
@@ -25,14 +25,13 @@ from copy import deepcopy
 import logging
 import operator
 import random
-
 import numpy as np
 import torch
 import torch.nn as nn
 import torch.backends.cudnn as cudnn
 import yaml
-
 import inspect
+import distiller
 
 
 msglogger = logging.getLogger()
@@ -559,19 +558,7 @@ def has_children(module):
 
 
 def get_dummy_input(dataset, device=None):
-    """Generate a representative dummy (random) input for the specified dataset.
-
-    If a device is specified, then the dummay_input is moved to that device.
-    """
-    if dataset == 'imagenet':
-        dummy_input = torch.randn(1, 3, 224, 224)
-    elif dataset == 'cifar10':
-        dummy_input = torch.randn(1, 3, 32, 32)
-    else:
-        raise ValueError("dataset %s is not supported" % dataset)
-    if device:
-        dummy_input = dummy_input.to(device)
-    return dummy_input
+    return distiller.apputils.classification_get_dummy_input(dataset, device)
 
 
 def make_non_parallel_copy(model):
diff --git a/tests/test_model_summary.py b/tests/test_model_summary.py
index 89c81e690303f839e4c0e3280369699ee913635e..f187d12501d758d3c0cc3e04750b338eb0cd30d5 100755
--- a/tests/test_model_summary.py
+++ b/tests/test_model_summary.py
@@ -64,5 +64,12 @@ def test_summary(what):
     dataset = "cifar10"
     arch = "resnet20_cifar"
     model, _ = common.setup_test(arch, dataset, parallel=True)
 
-    distiller.model_summary(model, what, dataset=dataset)
\ No newline at end of file
+    distiller.model_summary(model, what, dataset=dataset)
+
+@pytest.mark.parametrize('what', SUMMARY_CHOICES)
+def test_mnist(what):
+    dataset = "mnist"
+    arch = "simplenet_mnist"
+    model, _ = common.setup_test(arch, dataset, parallel=True)
+    distiller.model_summary(model, what, dataset=dataset)
\ No newline at end of file
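
Reviewer notes (not part of the patch):

The dummy-input helper now lives in distiller/apputils/data_loaders.py, and distiller/utils.py keeps get_dummy_input() as a thin backward-compatible wrapper. A minimal usage sketch, assuming the patched package is importable and that distiller.apputils re-exports the helper (the patched utils.py relies on exactly that); the 'svhn' call is just an illustration of the unsupported-dataset path:

    # Sketch only -- exercises the relocated helper and the compatibility wrapper.
    from distiller.apputils.data_loaders import classification_get_dummy_input
    from distiller.utils import get_dummy_input

    dummy = classification_get_dummy_input('mnist')      # MNIST -> (1, 1, 28, 28)
    assert tuple(dummy.shape) == (1, 1, 28, 28)

    same = get_dummy_input('mnist', device='cpu')        # delegates to the helper above
    assert tuple(same.shape) == tuple(dummy.shape)

    try:
        classification_get_dummy_input('svhn')           # unsupported dataset
    except ValueError as err:
        print(err)                                       # "dataset svhn is not supported"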
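
The simplenet_mnist.py change replaces the hard-coded flatten, x.view(-1, 4*4*50), with a batch-first flatten, x.view(x.size(0), -1), which keeps the batch dimension explicit and infers the feature count instead of baking in the 4x4x50 feature-map size. A standalone shape check, assuming the same layer shapes as Simplenet (Conv2d(1, 20, 5, 1) and Conv2d(20, 50, 5, 1)):

    # Sketch only -- mirrors the forward pass up to the flatten.
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    conv1 = nn.Conv2d(1, 20, 5, 1)
    conv2 = nn.Conv2d(20, 50, 5, 1)

    x = torch.randn(1, 1, 28, 28)                # the MNIST dummy input added above
    x = F.max_pool2d(F.relu(conv1(x)), 2, 2)     # -> (1, 20, 12, 12)
    x = F.max_pool2d(F.relu(conv2(x)), 2, 2)     # -> (1, 50, 4, 4)

    flat = x.view(x.size(0), -1)                 # batch dim preserved, rest inferred
    print(flat.shape)                            # torch.Size([1, 800]) == (1, 4*4*50)

Both forms produce the same shape here; the batch-first form simply avoids relying on the 800-feature constant at the flatten point.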
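
The new test_mnist is parametrized over the same SUMMARY_CHOICES as test_summary, so every summary type is now also exercised against the MNIST model. To run only the added cases from the repository root, a command such as "pytest tests/test_model_summary.py -k test_mnist" can be used.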