diff --git a/distiller/models/__init__.py b/distiller/models/__init__.py
index 49f5e7bfb77d7b33e9718e6f02326d0c5b3b82de..3ad0ceb4ba74c2a7ba15dcab494284016b169185 100755
--- a/distiller/models/__init__.py
+++ b/distiller/models/__init__.py
@@ -20,6 +20,7 @@ import torch
 import torchvision.models as torch_models
 from . import cifar10 as cifar10_models
 from . import imagenet as imagenet_extra_models
+import pretrainedmodels
 
 import logging
 msglogger = logging.getLogger()
@@ -34,51 +35,63 @@ IMAGENET_MODEL_NAMES = sorted(name for name in torch_models.__dict__
 IMAGENET_MODEL_NAMES.extend(sorted(name for name in imagenet_extra_models.__dict__
                                    if name.islower() and not name.startswith("__")
                                    and callable(imagenet_extra_models.__dict__[name])))
+IMAGENET_MODEL_NAMES.extend(pretrainedmodels.model_names)
 
 CIFAR10_MODEL_NAMES = sorted(name for name in cifar10_models.__dict__
                              if name.islower() and not name.startswith("__")
                              and callable(cifar10_models.__dict__[name]))
 
-ALL_MODEL_NAMES = sorted(map(lambda s: s.lower(), set(IMAGENET_MODEL_NAMES + CIFAR10_MODEL_NAMES)))
+ALL_MODEL_NAMES = sorted(map(lambda s: s.lower(),
+                            set(IMAGENET_MODEL_NAMES + CIFAR10_MODEL_NAMES)))
 
 
 def create_model(pretrained, dataset, arch, parallel=True, device_ids=None):
     """Create a pytorch model based on the model architecture and dataset
 
     Args:
-        pretrained: True is you wish to load a pretrained model.  Only torchvision models
-          have a pretrained model.
-        dataset:
-        arch:
-        parallel:
+        pretrained [boolean]: True is you wish to load a pretrained model.
+            Some models do not have a pretrained version.
+        dataset: dataset name (only 'imagenet' and 'cifar10' are supported)
+        arch: architecture name
+        parallel [boolean]: if set, use torch.nn.DataParallel
         device_ids: Devices on which model should be created -
             None - GPU if available, otherwise CPU
             -1 - CPU
             >=0 - GPU device IDs
     """
-    msglogger.info('==> using %s dataset' % dataset)
-
     model = None
+    dataset = dataset.lower()
     if dataset == 'imagenet':
-        str_pretrained = 'pretrained ' if pretrained else ''
-        msglogger.info("=> using %s%s model for ImageNet" % (str_pretrained, arch))
-        assert arch in torch_models.__dict__ or arch in imagenet_extra_models.__dict__, \
-            "Model %s is not supported for dataset %s" % (arch, 'ImageNet')
         if arch in RESNET_SYMS:
             model = imagenet_extra_models.__dict__[arch](pretrained=pretrained)
         elif arch in torch_models.__dict__:
             model = torch_models.__dict__[arch](pretrained=pretrained)
+        elif (arch in imagenet_extra_models.__dict__) and not pretrained:
+            model = imagenet_extra_models.__dict__[arch](pretrained=pretrained)
+        elif arch in pretrainedmodels.model_names:
+            model = pretrainedmodels.__dict__[arch](
+                        num_classes=1000,
+                        pretrained=(dataset if pretrained else None))
         else:
-            assert not pretrained, "Model %s (ImageNet) does not have a pretrained model" % arch
-            model = imagenet_extra_models.__dict__[arch]()
+            error_message = ''
+            if arch not in IMAGENET_MODEL_NAMES:
+                error_message = "Model {} is not supported for dataset ImageNet".format(arch)
+            elif pretrained:
+                error_message = "Model {} (ImageNet) does not have a pretrained model".format(arch)
+            raise ValueError(error_message or 'Failed to find model {}'.format(arch))
+
+        msglogger.info("=> using {p}{a} model for ImageNet".format(a=arch,
+            p=('pretrained ' if pretrained else '')))
     elif dataset == 'cifar10':
+        if pretrained:
+            raise ValueError("Model {} (CIFAR10) does not have a pretrained model".format(arch))
+        try:
+            model = cifar10_models.__dict__[arch]()
+        except KeyError:
+            raise ValueError("Model {} is not supported for dataset CIFAR10".format(arch))
         msglogger.info("=> creating %s model for CIFAR10" % arch)
-        assert arch in cifar10_models.__dict__, "Model %s is not supported for dataset CIFAR10" % arch
-        assert not pretrained, "Model %s (CIFAR10) does not have a pretrained model" % arch
-        model = cifar10_models.__dict__[arch]()
     else:
-        print("FATAL ERROR: create_model does not support models for dataset %s" % dataset)
-        exit()
+        raise ValueError('Could not recognize dataset {}'.format(dataset))
 
     if torch.cuda.is_available() and device_ids != -1:
         device = 'cuda'
diff --git a/requirements.txt b/requirements.txt
index e2375a33acc75f6c3304249ca8bd10721dd75ce4..2a49ea1f5e7f839b91b0a0a2b5b588a414fa67ac 100755
--- a/requirements.txt
+++ b/requirements.txt
@@ -17,3 +17,4 @@ bqplot==0.10.5
 pyyaml
 pytest==3.5.1
 xlsxwriter>=1.1.1
+pretrainedmodels
diff --git a/tests/test_infra.py b/tests/test_infra.py
index d1d898d585cce6e06f81c54979789397cf3f2255..2d0524f5bb77cb9d50195e304b1a1d4a1c9e4e1e 100755
--- a/tests/test_infra.py
+++ b/tests/test_infra.py
@@ -30,6 +30,40 @@ except ImportError:
 import distiller
 from distiller.apputils import save_checkpoint, load_checkpoint
 from distiller.models import create_model
+import pretrainedmodels
+
+
+
+def test_create_model_cifar():
+    pretrained = False
+    model = create_model(pretrained, 'cifar10', 'resnet20_cifar')
+    with pytest.raises(ValueError):
+        # only cifar _10_ is currently supported
+        model = create_model(pretrained, 'cifar100', 'resnet20_cifar')
+    with pytest.raises(ValueError):
+        model = create_model(pretrained, 'cifar10', 'no_such_model!')
+
+    pretrained = True
+    with pytest.raises(ValueError):
+        # no pretrained models of cifar10
+        model = create_model(pretrained, 'cifar10', 'resnet20_cifar')
+
+
+def test_create_model_imagenet():
+    model = create_model(False, 'imagenet', 'alexnet')
+    model = create_model(False, 'imagenet', 'resnet50')
+    model = create_model(True, 'imagenet', 'resnet50')
+
+    with pytest.raises(ValueError):
+        model = create_model(False, 'imagenet', 'no_such_model!')
+
+
+def test_create_model_pretrainedmodels():
+    premodel_name = 'resnext101_32x4d'
+    model = create_model(True, 'imagenet', premodel_name)
+
+    with pytest.raises(ValueError):
+        model = create_model(False, 'imagenet', 'no_such_model!')
 
 
 def test_load():