Skip to content
Snippets Groups Projects
Commit bdafebea authored by Neta Zmora's avatar Neta Zmora
Browse files

create_model: fix bug when trying to use an unsupported dataset

When trying to create a model using an unsupported dataset, create_model()
should raise a ValueError, but didn't.
One of the unit-tests (test_create_model_cifar in test_infra.py) was
designed to test this condition (creating a model using an unsupported
dataset - cifar100 in this case), but due to DataParallel implementation
details, the test condition did not function on multi-GPU machines and
did fail the test.
Unit tests should also be executed on single-GPU machines for full-coverage.
parent 4a331d73
No related branches found
No related tags found
No related merge requests found
......@@ -30,9 +30,11 @@ from distiller.utils import set_model_input_shape_attr
import logging
msglogger = logging.getLogger()
SUPPORTED_DATASETS = ('imagenet', 'cifar10', 'mnist')
# ResNet special treatment: we have our own version of ResNet, so we need to over-ride
# TorchVision's version.
RESNET_SYMS = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152']
RESNET_SYMS = ('ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152')
TORCHVISION_MODEL_NAMES = sorted(
name for name in torch_models.__dict__
......@@ -86,8 +88,11 @@ def create_model(pretrained, dataset, arch, parallel=True, device_ids=None):
-1 - CPU
>=0 - GPU device IDs
"""
model = None
dataset = dataset.lower()
if dataset not in SUPPORTED_DATASETS:
raise ValueError('Dataset {} is not supported'.format(dataset))
model = None
cadene = False
try:
if dataset == 'imagenet':
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment