diff --git a/distiller/apputils/data_loaders.py b/distiller/apputils/data_loaders.py
index 381454dce88ad85c560f6dd40be9465a0fd779d2..d017249ecfad0f2eb22af0518e114b75727ea511 100755
--- a/distiller/apputils/data_loaders.py
+++ b/distiller/apputils/data_loaders.py
@@ -47,6 +47,24 @@ def classification_num_classes(dataset):
             'imagenet': 1000}.get(dataset, None)
 
 
+def classification_get_dummy_input(dataset, device=None):
+    """Generate a representative dummy (random) input for the specified dataset.
+
+    If a device is specified, then the dummay_input is moved to that device.
+    """
+    if dataset == 'imagenet':
+        dummy_input = torch.randn(1, 3, 224, 224)
+    elif dataset == 'cifar10':
+        dummy_input = torch.randn(1, 3, 32, 32)
+    elif dataset == 'mnist':
+        dummy_input = torch.randn(1, 1, 28, 28)
+    else:
+        raise ValueError("dataset %s is not supported" % dataset)
+    if device:
+        dummy_input = dummy_input.to(device)
+    return dummy_input
+
+
 def __dataset_factory(dataset):
     return {'cifar10': cifar10_get_datasets,
             'mnist': mnist_get_datasets,
diff --git a/distiller/models/mnist/simplenet_mnist.py b/distiller/models/mnist/simplenet_mnist.py
index 38515077bf1f068bdfb0b154fd6276fed16cea65..39b01e9d24aae0602ed14d3d5e0e23af2896bdb2 100755
--- a/distiller/models/mnist/simplenet_mnist.py
+++ b/distiller/models/mnist/simplenet_mnist.py
@@ -39,7 +39,7 @@ class Simplenet(nn.Module):
         x = F.max_pool2d(x, 2, 2)
         x = F.relu(self.conv2(x))
         x = F.max_pool2d(x, 2, 2)
-        x = x.view(-1, 4*4*50)
+        x = x.view(x.size(0), -1)
         x = F.relu(self.fc1(x))
         x = self.fc2(x)
         return F.log_softmax(x, dim=1)
diff --git a/distiller/utils.py b/distiller/utils.py
index 72994ab9d18d0f20234ab4d65d30191061e99607..5104197af393ec6e0fdae78dc60faac9ade6e3b2 100755
--- a/distiller/utils.py
+++ b/distiller/utils.py
@@ -25,14 +25,13 @@ from copy import deepcopy
 import logging
 import operator
 import random
-
 import numpy as np
 import torch
 import torch.nn as nn
 import torch.backends.cudnn as cudnn
 import yaml
-
 import inspect
+import distiller
 
 msglogger = logging.getLogger()
 
@@ -559,19 +558,7 @@ def has_children(module):
 
 
 def get_dummy_input(dataset, device=None):
-    """Generate a representative dummy (random) input for the specified dataset.
-
-    If a device is specified, then the dummay_input is moved to that device.
-    """
-    if dataset == 'imagenet':
-        dummy_input = torch.randn(1, 3, 224, 224)
-    elif dataset == 'cifar10':
-        dummy_input = torch.randn(1, 3, 32, 32)
-    else:
-        raise ValueError("dataset %s is not supported" % dataset)
-    if device:
-        dummy_input = dummy_input.to(device)
-    return dummy_input
+    return distiller.apputils.classification_get_dummy_input(dataset, device)
 
 
 def make_non_parallel_copy(model):
diff --git a/tests/test_model_summary.py b/tests/test_model_summary.py
index 89c81e690303f839e4c0e3280369699ee913635e..f187d12501d758d3c0cc3e04750b338eb0cd30d5 100755
--- a/tests/test_model_summary.py
+++ b/tests/test_model_summary.py
@@ -64,5 +64,12 @@ def test_summary(what):
     dataset = "cifar10"
     arch = "resnet20_cifar"
     model, _ = common.setup_test(arch, dataset, parallel=True)
-
     distiller.model_summary(model, what, dataset=dataset)
+
+
+@pytest.mark.parametrize('what', SUMMARY_CHOICES)
+def test_mnist(what):
+    dataset = "mnist"
+    arch = "simplenet_mnist"
+    model, _ = common.setup_test(arch, dataset, parallel=True)
+    distiller.model_summary(model, what, dataset=dataset)
\ No newline at end of file