From f8085cf483fae29987d532902af98d4bc19f2277 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Thu, 30 May 2019 12:31:21 +0300
Subject: [PATCH] MNIST support

- Added a model-summary test for MNIST (tests/test_model_summary.py).
- Added classification_get_dummy_input() to apputils/data_loaders.py
  and wrapped it with get_dummy_input() for (temporary) backward
  compatibility.
- Changed simplenet_mnist so that it supports thinning, by flattening
  the convolution output dynamically instead of hard-coding its size.
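
A minimal usage sketch of the new helper, for reference (the 'cuda'
device string is illustrative; get_dummy_input() in distiller/utils.py
now simply forwards to this function):

    import distiller.apputils

    # Random MNIST-shaped input: batch of 1, 1 channel, 28x28 pixels.
    dummy_input = distiller.apputils.classification_get_dummy_input('mnist')

    # Optionally move it to a device (the device name is illustrative):
    dummy_input = distiller.apputils.classification_get_dummy_input(
        'mnist', device='cuda')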
---
 distiller/apputils/data_loaders.py        | 18 ++++++++++++++++++
 distiller/models/mnist/simplenet_mnist.py |  2 +-
 distiller/utils.py                        | 17 ++---------------
 tests/test_model_summary.py               |  9 ++++++++-
 4 files changed, 29 insertions(+), 17 deletions(-)
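
Note on the simplenet_mnist change below: the hard-coded view size
breaks as soon as thinning removes channels from conv2. A small sketch
of the failure mode (the post-thinning channel count of 30 is
illustrative):

    import torch

    # conv2 output for a batch of 1: 50 channels, 4x4 spatial.
    x = torch.randn(1, 50, 4, 4)
    x.view(-1, 4*4*50)       # OK only while conv2 keeps all 50 channels

    # After thinning conv2 to, say, 30 channels:
    x = torch.randn(1, 30, 4, 4)
    x.view(x.size(0), -1)    # works: flattens to (1, 480) regardless
    # x.view(-1, 4*4*50)     # would raise RuntimeError: shape mismatch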

diff --git a/distiller/apputils/data_loaders.py b/distiller/apputils/data_loaders.py
index 381454d..d017249 100755
--- a/distiller/apputils/data_loaders.py
+++ b/distiller/apputils/data_loaders.py
@@ -47,6 +47,24 @@ def classification_num_classes(dataset):
             'imagenet': 1000}.get(dataset, None)
 
 
+def classification_get_dummy_input(dataset, device=None):
+    """Generate a representative dummy (random) input for the specified dataset.
+
+    If a device is specified, the dummy input is moved to that device.
+    """
+    if dataset == 'imagenet':
+        dummy_input = torch.randn(1, 3, 224, 224)
+    elif dataset == 'cifar10':
+        dummy_input = torch.randn(1, 3, 32, 32)
+    elif dataset == 'mnist':
+        dummy_input = torch.randn(1, 1, 28, 28)
+    else:
+        raise ValueError("dataset %s is not supported" % dataset)
+    if device:
+        dummy_input = dummy_input.to(device)
+    return dummy_input
+
+
 def __dataset_factory(dataset):
     return {'cifar10': cifar10_get_datasets,
             'mnist': mnist_get_datasets,
diff --git a/distiller/models/mnist/simplenet_mnist.py b/distiller/models/mnist/simplenet_mnist.py
index 3851507..39b01e9 100755
--- a/distiller/models/mnist/simplenet_mnist.py
+++ b/distiller/models/mnist/simplenet_mnist.py
@@ -39,7 +39,7 @@ class Simplenet(nn.Module):
         x = F.max_pool2d(x, 2, 2)
         x = F.relu(self.conv2(x))
         x = F.max_pool2d(x, 2, 2)
-        x = x.view(-1, 4*4*50)
+        x = x.view(x.size(0), -1)  # flatten per-sample, without hard-coding 4*4*50, so thinning can resize this layer
         x = F.relu(self.fc1(x))
         x = self.fc2(x)
         return F.log_softmax(x, dim=1)
diff --git a/distiller/utils.py b/distiller/utils.py
index 72994ab..5104197 100755
--- a/distiller/utils.py
+++ b/distiller/utils.py
@@ -25,14 +25,13 @@ from copy import deepcopy
 import logging
 import operator
 import random
-
 import numpy as np
 import torch
 import torch.nn as nn
 import torch.backends.cudnn as cudnn
 import yaml
-
 import inspect
+import distiller
 
 msglogger = logging.getLogger()
 
@@ -559,19 +558,7 @@ def has_children(module):
 
 
 def get_dummy_input(dataset, device=None):
-    """Generate a representative dummy (random) input for the specified dataset.
-
-    If a device is specified, then the dummay_input is moved to that device.
-    """
-    if dataset == 'imagenet':
-        dummy_input = torch.randn(1, 3, 224, 224)
-    elif dataset == 'cifar10':
-        dummy_input = torch.randn(1, 3, 32, 32)
-    else:
-        raise ValueError("dataset %s is not supported" % dataset)
-    if device:
-        dummy_input = dummy_input.to(device)
-    return dummy_input
+    return distiller.apputils.classification_get_dummy_input(dataset, device)
 
 
 def make_non_parallel_copy(model):
diff --git a/tests/test_model_summary.py b/tests/test_model_summary.py
index 89c81e6..f187d12 100755
--- a/tests/test_model_summary.py
+++ b/tests/test_model_summary.py
@@ -64,5 +64,12 @@ def test_summary(what):
     dataset = "cifar10"
     arch = "resnet20_cifar"
     model, _ = common.setup_test(arch, dataset, parallel=True)
-
     distiller.model_summary(model, what, dataset=dataset)
+
+
+@pytest.mark.parametrize('what', SUMMARY_CHOICES)
+def test_mnist(what):
+    dataset = "mnist"
+    arch = "simplenet_mnist"
+    model, _ = common.setup_test(arch, dataset, parallel=True)
+    distiller.model_summary(model, what, dataset=dataset)
-- 
GitLab