From 0209264f75f7e5ef70ab0a4b049b19adff136d1f Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Thu, 30 May 2019 00:49:59 +0300
Subject: [PATCH] Added support from the MNIST dataset

Also added a simple network model for MNIST, under
distiller/models/mnist.
---
 distiller/apputils/data_loaders.py            | 65 ++++++++++++++++---
 distiller/models/__init__.py                  | 25 +++++--
 distiller/models/mnist/__init__.py            | 19 ++++++
 distiller/models/mnist/simplenet_mnist.py     | 49 ++++++++++++++
 .../compress_classifier.py                    |  4 +-
 5 files changed, 143 insertions(+), 19 deletions(-)
 create mode 100755 distiller/models/mnist/__init__.py
 create mode 100755 distiller/models/mnist/simplenet_mnist.py

diff --git a/distiller/apputils/data_loaders.py b/distiller/apputils/data_loaders.py
index 874dbdc..381454d 100755
--- a/distiller/apputils/data_loaders.py
+++ b/distiller/apputils/data_loaders.py
@@ -25,10 +25,32 @@ import torchvision.transforms as transforms
 import torchvision.datasets as datasets
 from torch.utils.data.sampler import Sampler
 import numpy as np
-
 import distiller
 
-DATASETS_NAMES = ['imagenet', 'cifar10']
+
+DATASETS_NAMES = ['imagenet', 'cifar10', 'mnist']
+
+
+def classification_dataset_str_from_arch(arch):
+    if 'cifar' in arch:
+        dataset = 'cifar10' 
+    elif 'mnist' in arch:
+        dataset = 'mnist' 
+    else:
+        dataset = 'imagenet'
+    return dataset
+
+
+def classification_num_classes(dataset):
+    return {'cifar10': 10,
+            'mnist': 10,
+            'imagenet': 1000}.get(dataset, None)
+
+
+def __dataset_factory(dataset):
+    return {'cifar10': cifar10_get_datasets,
+            'mnist': mnist_get_datasets,
+            'imagenet': imagenet_get_datasets}.get(dataset, None)
 
 
 def load_data(dataset, data_dir, batch_size, workers, validation_split=0.1, deterministic=False,
@@ -45,20 +67,43 @@ def load_data(dataset, data_dir, batch_size, workers, validation_split=0.1, dete
         deterministic: set to True if you want the data loading process to be deterministic.
           Note that deterministic data loading suffers from poor performance.
         effective_train/valid/test_size: portion of the datasets to load on each epoch.
-          The subset is chosen randomly each time. For the training and validation sets, this is applied AFTER
-          the split to those sets according to the validation_split parameter
-        fixed_subset: set to True to keep the same subset of data throughout the run (the size of the subset
-          is still determined according to the effective_train/valid/test_size args)
+          The subset is chosen randomly each time. For the training and validation sets,
+          this is applied AFTER the split to those sets according to the validation_split parameter
+        fixed_subset: set to True to keep the same subset of data throughout the run
+          (the size of the subset is still determined according to the effective_train/valid/test
+          size args)
     """
     if dataset not in DATASETS_NAMES:
         raise ValueError('load_data does not support dataset %s" % dataset')
-    datasets_fn = cifar10_get_datasets if dataset == 'cifar10' else imagenet_get_datasets
-    return get_data_loaders(datasets_fn, data_dir, batch_size, workers, validation_split=validation_split,
-                            deterministic=deterministic, effective_train_size=effective_train_size,
-                            effective_valid_size=effective_valid_size, effective_test_size=effective_test_size,
+    datasets_fn = __dataset_factory(dataset)
+    return get_data_loaders(datasets_fn, data_dir, batch_size, workers, 
+                            validation_split=validation_split,
+                            deterministic=deterministic, 
+                            effective_train_size=effective_train_size,
+                            effective_valid_size=effective_valid_size, 
+                            effective_test_size=effective_test_size,
                             fixed_subset=fixed_subset)
 
 
+def mnist_get_datasets(data_dir):
+    """Load the MNIST dataset."""
+    train_transform = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize((0.1307,), (0.3081,))
+    ])
+    train_dataset = datasets.MNIST(root=data_dir, train=True,
+                                   download=True, transform=train_transform)
+
+    test_transform = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize((0.1307,), (0.3081,))
+    ])
+    test_dataset = datasets.MNIST(root=data_dir, train=False,
+                                  transform=test_transform)
+
+    return train_dataset, test_dataset
+
+
 def cifar10_get_datasets(data_dir):
     """Load the CIFAR10 dataset.
 
diff --git a/distiller/models/__init__.py b/distiller/models/__init__.py
index 30e2c8d..8c3969a 100755
--- a/distiller/models/__init__.py
+++ b/distiller/models/__init__.py
@@ -19,6 +19,7 @@
 import torch
 import torchvision.models as torch_models
 from . import cifar10 as cifar10_models
+from . import mnist as mnist_models
 from . import imagenet as imagenet_extra_models
 import pretrainedmodels
 
@@ -41,8 +42,12 @@ CIFAR10_MODEL_NAMES = sorted(name for name in cifar10_models.__dict__
                              if name.islower() and not name.startswith("__")
                              and callable(cifar10_models.__dict__[name]))
 
+MNIST_MODEL_NAMES = sorted(name for name in mnist_models.__dict__
+                           if name.islower() and not name.startswith("__")
+                           and callable(mnist_models.__dict__[name]))
+
 ALL_MODEL_NAMES = sorted(map(lambda s: s.lower(),
-                            set(IMAGENET_MODEL_NAMES + CIFAR10_MODEL_NAMES)))
+                            set(IMAGENET_MODEL_NAMES + CIFAR10_MODEL_NAMES + MNIST_MODEL_NAMES)))
 
 
 def create_model(pretrained, dataset, arch, parallel=True, device_ids=None):
@@ -69,9 +74,8 @@ def create_model(pretrained, dataset, arch, parallel=True, device_ids=None):
         elif (arch in imagenet_extra_models.__dict__) and not pretrained:
             model = imagenet_extra_models.__dict__[arch]()
         elif arch in pretrainedmodels.model_names:
-            model = pretrainedmodels.__dict__[arch](
-                        num_classes=1000,
-                        pretrained=(dataset if pretrained else None))
+            model = pretrainedmodels.__dict__[arch](num_classes=1000,
+                                                    pretrained=(dataset if pretrained else None))
         else:
             error_message = ''
             if arch not in IMAGENET_MODEL_NAMES:
@@ -80,8 +84,6 @@ def create_model(pretrained, dataset, arch, parallel=True, device_ids=None):
                 error_message = "Model {} (ImageNet) does not have a pretrained model".format(arch)
             raise ValueError(error_message or 'Failed to find model {}'.format(arch))
 
-        msglogger.info("=> using {p}{a} model for ImageNet".format(a=arch,
-            p=('pretrained ' if pretrained else '')))
     elif dataset == 'cifar10':
         if pretrained:
             raise ValueError("Model {} (CIFAR10) does not have a pretrained model".format(arch))
@@ -89,10 +91,19 @@ def create_model(pretrained, dataset, arch, parallel=True, device_ids=None):
             model = cifar10_models.__dict__[arch]()
         except KeyError:
             raise ValueError("Model {} is not supported for dataset CIFAR10".format(arch))
-        msglogger.info("=> creating %s model for CIFAR10" % arch)
+
+    elif dataset == 'mnist':
+        if pretrained:
+            raise ValueError("Model {} (MNIST) does not have a pretrained model".format(arch))
+        try:
+            model = mnist_models.__dict__[arch]()
+        except KeyError:
+            raise ValueError("Model {} is not supported for dataset MNIST".format(arch))
     else:
         raise ValueError('Could not recognize dataset {}'.format(dataset))
 
+    msglogger.info("=> creating a %s%s model with the %s dataset" % ('pretrained ' if pretrained else '', 
+                                                                     arch, dataset))
     if torch.cuda.is_available() and device_ids != -1:
         device = 'cuda'
         if (arch.startswith('alexnet') or arch.startswith('vgg')) and parallel:
diff --git a/distiller/models/mnist/__init__.py b/distiller/models/mnist/__init__.py
new file mode 100755
index 0000000..125515e
--- /dev/null
+++ b/distiller/models/mnist/__init__.py
@@ -0,0 +1,19 @@
+#
+# Copyright (c) 2018 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""This package contains MNIST image classification models for pytorch"""
+
+from .simplenet_mnist import *
\ No newline at end of file
diff --git a/distiller/models/mnist/simplenet_mnist.py b/distiller/models/mnist/simplenet_mnist.py
new file mode 100755
index 0000000..3851507
--- /dev/null
+++ b/distiller/models/mnist/simplenet_mnist.py
@@ -0,0 +1,49 @@
+#
+# Copyright (c) 2018 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""An implementation of a trivial MNIST model.
+ 
+The original network definition is sourced here: https://github.com/pytorch/examples/blob/master/mnist/main.py
+"""
+
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+__all__ = ['simplenet_mnist']
+
+
+class Simplenet(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 20, 5, 1)
+        self.conv2 = nn.Conv2d(20, 50, 5, 1)
+        self.fc1 = nn.Linear(4*4*50, 500)
+        self.fc2 = nn.Linear(500, 10)
+        
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        x = F.max_pool2d(x, 2, 2)
+        x = F.relu(self.conv2(x))
+        x = F.max_pool2d(x, 2, 2)
+        x = x.view(-1, 4*4*50)
+        x = F.relu(self.fc1(x))
+        x = self.fc2(x)
+        return F.log_softmax(x, dim=1)
+        
+def simplenet_mnist():
+    model = Simplenet()
+    return model
diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py
index b4e81f0..0f5b78c 100755
--- a/examples/classifier_compression/compress_classifier.py
+++ b/examples/classifier_compression/compress_classifier.py
@@ -138,8 +138,8 @@ def main():
             torch.cuda.set_device(args.gpus[0])
 
     # Infer the dataset from the model name
-    args.dataset = 'cifar10' if 'cifar' in args.arch else 'imagenet'
-    args.num_classes = 10 if args.dataset == 'cifar10' else 1000
+    args.dataset = distiller.apputils.classification_dataset_str_from_arch(args.arch)
+    args.num_classes = distiller.apputils.classification_num_classes(args.dataset)
 
     if args.earlyexit_thresholds:
         args.num_exits = len(args.earlyexit_thresholds) + 1
-- 
GitLab