Commit f8085cf4 authored by Neta Zmora

MNIST support

- Added a test for MNIST.
- Added classification_get_dummy_input() to apputils/data_loaders.py and
  wrapped it with get_dummy_input() for (temporary) backward compatibility.
- Changed simplenet_mnist so that it supports thinning.
parent 0209264f
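
Taken together, the pieces work like this: the data loader now knows how to fabricate a correctly shaped MNIST input, so shape-dependent tooling (summaries, tracing, thinning) can operate on simplenet_mnist. Below is a rough sketch of the path the new test exercises; the create_model() argument order and the 'sparsity' summary choice are assumptions about Distiller's API, not something shown in this diff:

import distiller
from distiller.models import create_model

# Build the MNIST model by architecture name (argument order is an assumption
# patterned on how Distiller's tests construct models).
model = create_model(False, 'mnist', 'simplenet_mnist', parallel=True)

# The wrapper changed further down in this commit now understands 'mnist' and
# returns a (1, 1, 28, 28) tensor, which is what summary/thinning code traces with.
dummy_input = distiller.utils.get_dummy_input('mnist')

# model_summary() is called exactly as in the test added at the bottom of this
# commit; 'sparsity' is assumed to be one of SUMMARY_CHOICES.
distiller.model_summary(model, 'sparsity', dataset='mnist')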
apputils/data_loaders.py
@@ -47,6 +47,24 @@ def classification_num_classes(dataset):
             'imagenet': 1000}.get(dataset, None)
 
 
+def classification_get_dummy_input(dataset, device=None):
+    """Generate a representative dummy (random) input for the specified dataset.
+    If a device is specified, then the dummy_input is moved to that device.
+    """
+    if dataset == 'imagenet':
+        dummy_input = torch.randn(1, 3, 224, 224)
+    elif dataset == 'cifar10':
+        dummy_input = torch.randn(1, 3, 32, 32)
+    elif dataset == 'mnist':
+        dummy_input = torch.randn(1, 1, 28, 28)
+    else:
+        raise ValueError("dataset %s is not supported" % dataset)
+    if device:
+        dummy_input = dummy_input.to(device)
+    return dummy_input
+
+
 def __dataset_factory(dataset):
     return {'cifar10': cifar10_get_datasets,
             'mnist': mnist_get_datasets,
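
A short usage sketch of the new helper, called through the distiller.apputils path that the updated get_dummy_input() wrapper (later in this commit) relies on; the shapes in the comments simply restate what the function above returns:

import torch
import distiller

# Per-dataset shapes (batch dimension of 1):
#   'imagenet' -> (1, 3, 224, 224), 'cifar10' -> (1, 3, 32, 32), 'mnist' -> (1, 1, 28, 28)
dummy = distiller.apputils.classification_get_dummy_input('mnist')
assert dummy.shape == (1, 1, 28, 28)

# The optional device argument moves the tensor at creation time.
if torch.cuda.is_available():
    dummy = distiller.apputils.classification_get_dummy_input('mnist', device='cuda')
    assert dummy.is_cuda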
@@ -39,7 +39,7 @@ class Simplenet(nn.Module):
         x = F.max_pool2d(x, 2, 2)
         x = F.relu(self.conv2(x))
         x = F.max_pool2d(x, 2, 2)
-        x = x.view(-1, 4*4*50)
+        x = x.view(x.size(0), -1)
         x = F.relu(self.fc1(x))
         x = self.fc2(x)
         return F.log_softmax(x, dim=1)
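
The one-line change above is what makes the model thinning-friendly: once filter pruning is physically applied, conv2 may end up with fewer than 50 output channels and fc1 is rebuilt with fewer input features, so the hard-coded view(-1, 4*4*50) no longer matches the tensor. Flattening all non-batch dimensions adapts automatically. A minimal sketch of the idea; the layer widths (20/50 filters, 500 hidden units) are assumptions patterned on the classic PyTorch MNIST example and the old 4*4*50 constant, not a verbatim copy of simplenet_mnist:

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyMnistNet(nn.Module):
    """Illustrative stand-in for simplenet_mnist (layer widths are assumptions)."""
    def __init__(self, conv2_filters=50):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, conv2_filters, 5, 1)   # thinning may shrink this
        self.fc1 = nn.Linear(4 * 4 * conv2_filters, 500)  # rebuilt to match by thinning
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2, 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2, 2)
        # Flatten everything except the batch dimension: this stays correct even
        # when thinning removes filters from conv2 and shrinks fc1 accordingly,
        # whereas a hard-coded x.view(-1, 4*4*50) would raise a shape error.
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        return F.log_softmax(self.fc2(x), dim=1)

# The same forward() works before and after "thinning" conv2 from 50 to 30 filters.
for filters in (50, 30):
    out = TinyMnistNet(conv2_filters=filters)(torch.randn(1, 1, 28, 28))
    assert out.shape == (1, 10)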
@@ -25,14 +25,13 @@ from copy import deepcopy
 import logging
 import operator
 import random
 import numpy as np
 import torch
 import torch.nn as nn
 import torch.backends.cudnn as cudnn
 import yaml
 import inspect
+import distiller
 
 msglogger = logging.getLogger()
@@ -559,19 +558,7 @@ def has_children(module):
 def get_dummy_input(dataset, device=None):
-    """Generate a representative dummy (random) input for the specified dataset.
-    If a device is specified, then the dummay_input is moved to that device.
-    """
-    if dataset == 'imagenet':
-        dummy_input = torch.randn(1, 3, 224, 224)
-    elif dataset == 'cifar10':
-        dummy_input = torch.randn(1, 3, 32, 32)
-    else:
-        raise ValueError("dataset %s is not supported" % dataset)
-    if device:
-        dummy_input = dummy_input.to(device)
-    return dummy_input
+    return distiller.apputils.classification_get_dummy_input(dataset, device)
 
 
 def make_non_parallel_copy(model):
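
Keeping get_dummy_input() as a thin delegate preserves existing call sites while the implementation moves into apputils. The commit delegates silently; the sketch below adds an illustrative DeprecationWarning to show how the "(temporary)" shim could eventually be retired (the warning is not part of this commit):

import warnings
import distiller

def get_dummy_input(dataset, device=None):
    """Backward-compatibility shim: old callers keep working, new code should
    call distiller.apputils.classification_get_dummy_input() directly."""
    # Illustrative only -- the committed wrapper does not emit a warning.
    warnings.warn("get_dummy_input() is deprecated; use "
                  "distiller.apputils.classification_get_dummy_input() instead",
                  DeprecationWarning, stacklevel=2)
    return distiller.apputils.classification_get_dummy_input(dataset, device)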
@@ -64,5 +64,12 @@ def test_summary(what):
     dataset = "cifar10"
     arch = "resnet20_cifar"
     model, _ = common.setup_test(arch, dataset, parallel=True)
 
     distiller.model_summary(model, what, dataset=dataset)
+
+@pytest.mark.parametrize('what', SUMMARY_CHOICES)
+def test_mnist(what):
+    dataset = "mnist"
+    arch = "simplenet_mnist"
+    model, _ = common.setup_test(arch, dataset, parallel=True)
+    distiller.model_summary(model, what, dataset=dataset)
\ No newline at end of file
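
To run only the new MNIST case, select it by name with pytest; the test-file path is not shown in this diff, so the one below is an assumption:

# e.g. from the repository root (tests/test_summaries.py is assumed):
#     pytest tests/test_summaries.py -v -k test_mnist
# or programmatically:
import pytest
pytest.main(['-v', '-k', 'test_mnist'])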