diff --git a/distiller/apputils/data_loaders.py b/distiller/apputils/data_loaders.py index 874dbdc92b7023fac9bd429d03752857fb749c09..381454dce88ad85c560f6dd40be9465a0fd779d2 100755 --- a/distiller/apputils/data_loaders.py +++ b/distiller/apputils/data_loaders.py @@ -25,10 +25,32 @@ import torchvision.transforms as transforms import torchvision.datasets as datasets from torch.utils.data.sampler import Sampler import numpy as np - import distiller -DATASETS_NAMES = ['imagenet', 'cifar10'] + +DATASETS_NAMES = ['imagenet', 'cifar10', 'mnist'] + + +def classification_dataset_str_from_arch(arch): + if 'cifar' in arch: + dataset = 'cifar10' + elif 'mnist' in arch: + dataset = 'mnist' + else: + dataset = 'imagenet' + return dataset + + +def classification_num_classes(dataset): + return {'cifar10': 10, + 'mnist': 10, + 'imagenet': 1000}.get(dataset, None) + + +def __dataset_factory(dataset): + return {'cifar10': cifar10_get_datasets, + 'mnist': mnist_get_datasets, + 'imagenet': imagenet_get_datasets}.get(dataset, None) def load_data(dataset, data_dir, batch_size, workers, validation_split=0.1, deterministic=False, @@ -45,20 +67,43 @@ def load_data(dataset, data_dir, batch_size, workers, validation_split=0.1, dete deterministic: set to True if you want the data loading process to be deterministic. Note that deterministic data loading suffers from poor performance. effective_train/valid/test_size: portion of the datasets to load on each epoch. - The subset is chosen randomly each time. For the training and validation sets, this is applied AFTER - the split to those sets according to the validation_split parameter - fixed_subset: set to True to keep the same subset of data throughout the run (the size of the subset - is still determined according to the effective_train/valid/test_size args) + The subset is chosen randomly each time. 
For the training and validation sets, + this is applied AFTER the split to those sets according to the validation_split parameter + fixed_subset: set to True to keep the same subset of data throughout the run + (the size of the subset is still determined according to the effective_train/valid/test + size args) """ if dataset not in DATASETS_NAMES: raise ValueError('load_data does not support dataset %s" % dataset') - datasets_fn = cifar10_get_datasets if dataset == 'cifar10' else imagenet_get_datasets - return get_data_loaders(datasets_fn, data_dir, batch_size, workers, validation_split=validation_split, - deterministic=deterministic, effective_train_size=effective_train_size, - effective_valid_size=effective_valid_size, effective_test_size=effective_test_size, + datasets_fn = __dataset_factory(dataset) + return get_data_loaders(datasets_fn, data_dir, batch_size, workers, + validation_split=validation_split, + deterministic=deterministic, + effective_train_size=effective_train_size, + effective_valid_size=effective_valid_size, + effective_test_size=effective_test_size, fixed_subset=fixed_subset) +def mnist_get_datasets(data_dir): + """Load the MNIST dataset.""" + train_transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)) + ]) + train_dataset = datasets.MNIST(root=data_dir, train=True, + download=True, transform=train_transform) + + test_transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)) + ]) + test_dataset = datasets.MNIST(root=data_dir, train=False, + transform=test_transform) + + return train_dataset, test_dataset + + def cifar10_get_datasets(data_dir): """Load the CIFAR10 dataset. 
diff --git a/distiller/models/__init__.py b/distiller/models/__init__.py index 30e2c8d7f1c770bf17b8ef8ff1ca0f7b953af98b..8c3969aacd73e27a80c6095dcc14e8736efba06c 100755 --- a/distiller/models/__init__.py +++ b/distiller/models/__init__.py @@ -19,6 +19,7 @@ import torch import torchvision.models as torch_models from . import cifar10 as cifar10_models +from . import mnist as mnist_models from . import imagenet as imagenet_extra_models import pretrainedmodels @@ -41,8 +42,12 @@ CIFAR10_MODEL_NAMES = sorted(name for name in cifar10_models.__dict__ if name.islower() and not name.startswith("__") and callable(cifar10_models.__dict__[name])) +MNIST_MODEL_NAMES = sorted(name for name in mnist_models.__dict__ + if name.islower() and not name.startswith("__") + and callable(mnist_models.__dict__[name])) + ALL_MODEL_NAMES = sorted(map(lambda s: s.lower(), - set(IMAGENET_MODEL_NAMES + CIFAR10_MODEL_NAMES))) + set(IMAGENET_MODEL_NAMES + CIFAR10_MODEL_NAMES + MNIST_MODEL_NAMES))) def create_model(pretrained, dataset, arch, parallel=True, device_ids=None): @@ -69,9 +74,8 @@ def create_model(pretrained, dataset, arch, parallel=True, device_ids=None): elif (arch in imagenet_extra_models.__dict__) and not pretrained: model = imagenet_extra_models.__dict__[arch]() elif arch in pretrainedmodels.model_names: - model = pretrainedmodels.__dict__[arch]( - num_classes=1000, - pretrained=(dataset if pretrained else None)) + model = pretrainedmodels.__dict__[arch](num_classes=1000, + pretrained=(dataset if pretrained else None)) else: error_message = '' if arch not in IMAGENET_MODEL_NAMES: @@ -80,8 +84,6 @@ def create_model(pretrained, dataset, arch, parallel=True, device_ids=None): error_message = "Model {} (ImageNet) does not have a pretrained model".format(arch) raise ValueError(error_message or 'Failed to find model {}'.format(arch)) - msglogger.info("=> using {p}{a} model for ImageNet".format(a=arch, - p=('pretrained ' if pretrained else ''))) elif dataset == 'cifar10': if pretrained: 
raise ValueError("Model {} (CIFAR10) does not have a pretrained model".format(arch)) @@ -89,10 +91,19 @@ def create_model(pretrained, dataset, arch, parallel=True, device_ids=None): model = cifar10_models.__dict__[arch]() except KeyError: raise ValueError("Model {} is not supported for dataset CIFAR10".format(arch)) - msglogger.info("=> creating %s model for CIFAR10" % arch) + + elif dataset == 'mnist': + if pretrained: + raise ValueError("Model {} (MNIST) does not have a pretrained model".format(arch)) + try: + model = mnist_models.__dict__[arch]() + except KeyError: + raise ValueError("Model {} is not supported for dataset MNIST".format(arch)) else: raise ValueError('Could not recognize dataset {}'.format(dataset)) + msglogger.info("=> creating a %s%s model with the %s dataset" % ('pretrained ' if pretrained else '', + arch, dataset)) if torch.cuda.is_available() and device_ids != -1: device = 'cuda' if (arch.startswith('alexnet') or arch.startswith('vgg')) and parallel: diff --git a/distiller/models/mnist/__init__.py b/distiller/models/mnist/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..125515ec276a21d8f97cead55dd6474ff43d34d8 --- /dev/null +++ b/distiller/models/mnist/__init__.py @@ -0,0 +1,19 @@ +# +# Copyright (c) 2018 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +"""This package contains MNIST image classification models for pytorch""" + +from .simplenet_mnist import * \ No newline at end of file diff --git a/distiller/models/mnist/simplenet_mnist.py b/distiller/models/mnist/simplenet_mnist.py new file mode 100755 index 0000000000000000000000000000000000000000..38515077bf1f068bdfb0b154fd6276fed16cea65 --- /dev/null +++ b/distiller/models/mnist/simplenet_mnist.py @@ -0,0 +1,49 @@ +# +# Copyright (c) 2018 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +#     http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""An implementation of a trivial MNIST model. 
+ +The original network definition is sourced here: https://github.com/pytorch/examples/blob/master/mnist/main.py +""" + +import torch.nn as nn +import torch.nn.functional as F + + +__all__ = ['simplenet_mnist'] + + +class Simplenet(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(1, 20, 5, 1) + self.conv2 = nn.Conv2d(20, 50, 5, 1) + self.fc1 = nn.Linear(4*4*50, 500) + self.fc2 = nn.Linear(500, 10) + + def forward(self, x): + x = F.relu(self.conv1(x)) + x = F.max_pool2d(x, 2, 2) + x = F.relu(self.conv2(x)) + x = F.max_pool2d(x, 2, 2) + x = x.view(-1, 4*4*50) + x = F.relu(self.fc1(x)) + x = self.fc2(x) + return F.log_softmax(x, dim=1) + +def simplenet_mnist(): + model = Simplenet() + return model diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py index b4e81f01c3f5e8739532e5090996cf5dddecc9fa..0f5b78c370928b617c3e7595d95c6e76360d2ed4 100755 --- a/examples/classifier_compression/compress_classifier.py +++ b/examples/classifier_compression/compress_classifier.py @@ -138,8 +138,8 @@ def main(): torch.cuda.set_device(args.gpus[0]) # Infer the dataset from the model name - args.dataset = 'cifar10' if 'cifar' in args.arch else 'imagenet' - args.num_classes = 10 if args.dataset == 'cifar10' else 1000 + args.dataset = distiller.apputils.classification_dataset_str_from_arch(args.arch) + args.num_classes = distiller.apputils.classification_num_classes(args.dataset) if args.earlyexit_thresholds: args.num_exits = len(args.earlyexit_thresholds) + 1