diff --git a/distiller/apputils/checkpoint.py b/distiller/apputils/checkpoint.py
index 2cfbbba8d2878648cf89d90d909a1b1fa035ada7..fc90e8877deb6f8da54e491aea29f5e342fca328 100755
--- a/distiller/apputils/checkpoint.py
+++ b/distiller/apputils/checkpoint.py
@@ -62,10 +62,15 @@ def save_checkpoint(epoch, arch, model, optimizer=None, scheduler=None,
     filename_best = 'best.pth.tar' if name is None else name + '_best.pth.tar'
     fullpath_best = os.path.join(dir, filename_best)
 
-    checkpoint = {}
-    checkpoint['epoch'] = epoch
-    checkpoint['arch'] = arch
-    checkpoint['state_dict'] = model.state_dict()
+    checkpoint = {'epoch': epoch, 'state_dict': model.state_dict(), 'arch': arch}
+    try:
+        checkpoint['is_parallel'] = model.is_parallel
+        checkpoint['dataset'] = model.dataset
+        if not arch:
+            checkpoint['arch'] = model.arch
+    except NameError:
+        pass
+
     if optimizer is not None:
         checkpoint['optimizer_state_dict'] = optimizer.state_dict()
         checkpoint['optimizer_type'] = type(optimizer)
@@ -105,7 +110,10 @@ def load_checkpoint(model, chkpt_file, optimizer=None,
     """Load a pytorch training checkpoint.
 
     Args:
-        model: the pytorch model to which we will load the parameters
+        model: the pytorch model to which we will load the parameters.  You can
+        specify model=None if the checkpoint contains enough metadata to infer
+        the model.  The order of the arguments is misleading and clunky, and is
+        kept this way for backward compatibility.
         chkpt_file: the checkpoint file
         lean_checkpoint: if set, read into model only 'state_dict' field
         optimizer: [deprecated argument]
@@ -159,8 +167,24 @@ def load_checkpoint(model, chkpt_file, optimizer=None,
             msglogger.warning('Optimizer could not be loaded from checkpoint.')
             return None
 
+    def _create_model_from_ckpt():
+        try:
+            return distiller.models.create_model(False, checkpoint['dataset'], checkpoint['arch'],
+                                                 checkpoint['is_parallel'], device_ids=None)
+        except KeyError:
+            return None
+
+    def _sanity_check():
+        try:
+            if model.arch != checkpoint["arch"]:
+                raise ValueError("The model architecture does not match the checkpoint architecture")
+        except (NameError, KeyError):
+            # One of the values is missing so we can't perform the comparison
+            pass
+
     if not os.path.isfile(chkpt_file):
         raise IOError(ENOENT, 'Could not find a checkpoint file at', chkpt_file)
+    assert optimizer == None, "argument optimizer is deprecated and must be set to None"
 
     msglogger.info("=> loading checkpoint %s", chkpt_file)
     checkpoint = torch.load(chkpt_file, map_location=lambda storage, loc: storage)
@@ -171,9 +195,14 @@ def load_checkpoint(model, chkpt_file, optimizer=None,
     if 'state_dict' not in checkpoint:
         raise ValueError("Checkpoint must contain the model parameters under the key 'state_dict'")
 
+    if not model:
+        model = _create_model_from_ckpt()
+        if not model:
+            raise ValueError("You didn't provide a model, and the checkpoint doesn't contain"
+                             "enough information to create one")
+
     checkpoint_epoch = checkpoint.get('epoch', None)
     start_epoch = checkpoint_epoch + 1 if checkpoint_epoch is not None else 0
-
     compression_scheduler = None
     normalize_dataparallel_keys = False
     if 'compression_sched' in checkpoint:
@@ -217,4 +246,5 @@ def load_checkpoint(model, chkpt_file, optimizer=None,
     optimizer = _load_optimizer()
     msglogger.info("=> loaded checkpoint '{f}' (epoch {e})".format(f=str(chkpt_file),
                                                                    e=checkpoint_epoch))
+    _sanity_check()
     return model, compression_scheduler, optimizer, start_epoch
diff --git a/distiller/models/__init__.py b/distiller/models/__init__.py
index 01f78c3760a10c4ff8e45659b6eb2f016af11a4e..83e653938e833775b17317a834055e0e2b7ac423 100755
--- a/distiller/models/__init__.py
+++ b/distiller/models/__init__.py
@@ -111,14 +111,19 @@ def create_model(pretrained, dataset, arch, parallel=True, device_ids=None):
                                                                      arch, dataset))
     if torch.cuda.is_available() and device_ids != -1:
         device = 'cuda'
-        if (arch.startswith('alexnet') or arch.startswith('vgg')) and parallel:
-            model.features = torch.nn.DataParallel(model.features, device_ids=device_ids)
-        elif parallel:
-            model = torch.nn.DataParallel(model, device_ids=device_ids)
+        if parallel:
+            if arch.startswith('alexnet') or arch.startswith('vgg'):
+                model.features = torch.nn.DataParallel(model.features, device_ids=device_ids)
+            else:
+                model = torch.nn.DataParallel(model, device_ids=device_ids)
     else:
         device = 'cpu'
 
+    # Cache some attributes which describe the model
     _set_model_input_shape_attr(model, arch, dataset, pretrained, cadene)
+    model.arch = arch
+    model.dataset = dataset
+    model.is_parallel = parallel
     return model.to(device)
 
 
diff --git a/tests/test_infra.py b/tests/test_infra.py
index c778e514f713abfbd4fd3c1ebb76d60c4d8f4790..2a266de82ed97b7252750ea5f519682d964a2150 100755
--- a/tests/test_infra.py
+++ b/tests/test_infra.py
@@ -13,9 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import os
 import logging
 import tempfile
-
 import torch
 import pytest
 import distiller
@@ -182,8 +182,8 @@ def test_load_gpu_model_on_cpu_with_thinning():
     distiller.remove_filters(gpu_model, zeros_mask_dict, 'resnet20_cifar', 'cifar10', optimizer=None)
     assert hasattr(gpu_model, 'thinning_recipes')
     scheduler = distiller.CompressionScheduler(gpu_model)
-    save_checkpoint(epoch=0, arch='resnet20_cifar', model=gpu_model, scheduler=scheduler, optimizer=None,
-        dir='checkpoints')
+    save_checkpoint(epoch=0, arch='resnet20_cifar', model=gpu_model,
+                    scheduler=scheduler, optimizer=None, dir='checkpoints')
 
     CPU_DEVICE_ID = -1
     cpu_model = create_model(False, 'cifar10', 'resnet20_cifar', device_ids=CPU_DEVICE_ID)
@@ -269,3 +269,30 @@ def test_get_dummy_input(device):
     check_shape_device(t[0], shape[0], expected_device)
     check_shape_device(t[1][0], shape[1][0], expected_device)
     check_shape_device(t[1][1], shape[1][1], expected_device)
+
+
+def test_load_checkpoint_without_model():
+    checkpoint_filename = 'checkpoints/resnet20_cifar10_checkpoint.pth.tar'
+    # Load a checkpoint w/o specifying the model: this should fail because the loaded
+    # checkpoint is old and does not have the required metadata to create a model.
+    with pytest.raises(ValueError):
+        load_checkpoint(model=None, chkpt_file=checkpoint_filename)
+
+    for model_device in (None, 'cuda', 'cpu'):
+        # Now we create a new model, save a checkpoint, and load it w/o specifying the model.
+        # This should succeed because the checkpoint has enough metadata to create model.
+        model = create_model(False, 'cifar10', 'resnet20_cifar', 0)
+        model, compression_scheduler, optimizer, start_epoch = load_checkpoint(model, checkpoint_filename)
+        save_checkpoint(epoch=0, arch='resnet20_cifar', model=model, name='eraseme',
+                        scheduler=compression_scheduler, optimizer=None, dir='checkpoints')
+        temp_checkpoint = os.path.join("checkpoints", "eraseme_checkpoint.pth.tar")
+        model, compression_scheduler, optimizer, start_epoch = load_checkpoint(model=None,
+                                                                               chkpt_file=temp_checkpoint,
+                                                                               model_device=model_device)
+        assert compression_scheduler is not None
+        assert optimizer is None
+        assert start_epoch == 1
+        assert model
+        assert model.arch == "resnet20_cifar"
+        assert model.dataset == "cifar10"
+        os.remove(temp_checkpoint)
\ No newline at end of file