From f8085cf483fae29987d532902af98d4bc19f2277 Mon Sep 17 00:00:00 2001 From: Neta Zmora <neta.zmora@intel.com> Date: Thu, 30 May 2019 12:31:21 +0300 Subject: [PATCH] MNIST support -Added a test for MNIST -Added classification_get_dummy_input() to apputils/data_loaders.py and wrapped it with get_dummy_input() for (temporary) backward compatibility. - Changed simplenet_mnist so that it supports thinning --- distiller/apputils/data_loaders.py | 18 ++++++++++++++++++ distiller/models/mnist/simplenet_mnist.py | 2 +- distiller/utils.py | 17 ++--------------- tests/test_model_summary.py | 9 ++++++++- 4 files changed, 29 insertions(+), 17 deletions(-) diff --git a/distiller/apputils/data_loaders.py b/distiller/apputils/data_loaders.py index 381454d..d017249 100755 --- a/distiller/apputils/data_loaders.py +++ b/distiller/apputils/data_loaders.py @@ -47,6 +47,24 @@ def classification_num_classes(dataset): 'imagenet': 1000}.get(dataset, None) +def classification_get_dummy_input(dataset, device=None): + """Generate a representative dummy (random) input for the specified dataset. + + If a device is specified, then the dummay_input is moved to that device. + """ + if dataset == 'imagenet': + dummy_input = torch.randn(1, 3, 224, 224) + elif dataset == 'cifar10': + dummy_input = torch.randn(1, 3, 32, 32) + elif dataset == 'mnist': + dummy_input = torch.randn(1, 1, 28, 28) + else: + raise ValueError("dataset %s is not supported" % dataset) + if device: + dummy_input = dummy_input.to(device) + return dummy_input + + def __dataset_factory(dataset): return {'cifar10': cifar10_get_datasets, 'mnist': mnist_get_datasets, diff --git a/distiller/models/mnist/simplenet_mnist.py b/distiller/models/mnist/simplenet_mnist.py index 3851507..39b01e9 100755 --- a/distiller/models/mnist/simplenet_mnist.py +++ b/distiller/models/mnist/simplenet_mnist.py @@ -39,7 +39,7 @@ class Simplenet(nn.Module): x = F.max_pool2d(x, 2, 2) x = F.relu(self.conv2(x)) x = F.max_pool2d(x, 2, 2) - x = x.view(-1, 4*4*50) + x = x.view(x.size(0), -1) x = F.relu(self.fc1(x)) x = self.fc2(x) return F.log_softmax(x, dim=1) diff --git a/distiller/utils.py b/distiller/utils.py index 72994ab..5104197 100755 --- a/distiller/utils.py +++ b/distiller/utils.py @@ -25,14 +25,13 @@ from copy import deepcopy import logging import operator import random - import numpy as np import torch import torch.nn as nn import torch.backends.cudnn as cudnn import yaml - import inspect +import distiller msglogger = logging.getLogger() @@ -559,19 +558,7 @@ def has_children(module): def get_dummy_input(dataset, device=None): - """Generate a representative dummy (random) input for the specified dataset. - - If a device is specified, then the dummay_input is moved to that device. - """ - if dataset == 'imagenet': - dummy_input = torch.randn(1, 3, 224, 224) - elif dataset == 'cifar10': - dummy_input = torch.randn(1, 3, 32, 32) - else: - raise ValueError("dataset %s is not supported" % dataset) - if device: - dummy_input = dummy_input.to(device) - return dummy_input + return distiller.apputils.classification_get_dummy_input(dataset, device) def make_non_parallel_copy(model): diff --git a/tests/test_model_summary.py b/tests/test_model_summary.py index 89c81e6..f187d12 100755 --- a/tests/test_model_summary.py +++ b/tests/test_model_summary.py @@ -64,5 +64,12 @@ def test_summary(what): dataset = "cifar10" arch = "resnet20_cifar" model, _ = common.setup_test(arch, dataset, parallel=True) - distiller.model_summary(model, what, dataset=dataset) + + +@pytest.mark.parametrize('what', SUMMARY_CHOICES) +def test_mnist(what): + dataset = "mnist" + arch = "simplenet_mnist" + model, _ = common.setup_test(arch, dataset, parallel=True) + distiller.model_summary(model, what, dataset=dataset) \ No newline at end of file -- GitLab