diff --git a/apputils/checkpoint.py b/apputils/checkpoint.py
index 3792152f7227331f1b82f9634d1e76db57dc4582..26f7fff4f61c446f63a3c79453677e4c0cba0223 100755
--- a/apputils/checkpoint.py
+++ b/apputils/checkpoint.py
@@ -26,6 +26,7 @@ from errno import ENOENT
 import logging
 import torch
 import distiller
+from distiller.utils import normalize_module_name
 msglogger = logging.getLogger()
 
 
@@ -80,45 +81,60 @@ def load_checkpoint(model, chkpt_file, optimizer=None):
         chkpt_file: the checkpoint file
         optimizer: the optimizer to which we will load the serialized state
     """
+    if not os.path.isfile(chkpt_file):
+        raise IOError(ENOENT, 'Could not find a checkpoint file at', chkpt_file)
+
+    msglogger.info("=> loading checkpoint %s", chkpt_file)
+    checkpoint = torch.load(chkpt_file, map_location=lambda storage, loc: storage)
+    msglogger.debug("\n\t".join(['Checkpoint keys:'] + list(checkpoint)))
+
+    if 'state_dict' not in checkpoint:
+        raise ValueError("Checkpoint must contain the model parameters under the key 'state_dict'")
+
+    checkpoint_epoch = checkpoint.get('epoch', None)
+    start_epoch = checkpoint_epoch + 1 if checkpoint_epoch is not None else 0
+
+    best_top1 = checkpoint.get('best_top1', None)
+    if best_top1 is not None:
+        msglogger.info("   best top@1: %.3f", best_top1)
+
     compression_scheduler = None
-    start_epoch = 0
-
-    if os.path.isfile(chkpt_file):
-        msglogger.info("=> loading checkpoint %s", chkpt_file)
-        checkpoint = torch.load(chkpt_file, map_location = lambda storage, loc: storage)
-        msglogger.info("Checkpoint keys:\n{}".format("\n\t".join(k for k in checkpoint.keys())))
-        start_epoch = checkpoint.get('epoch', -1) + 1
-        best_top1 = checkpoint.get('best_top1', None)
-        if best_top1 is not None:
-            msglogger.info("   best top@1: %.3f", best_top1)
-
-        if 'compression_sched' in checkpoint:
-            compression_scheduler = distiller.CompressionScheduler(model)
-            compression_scheduler.load_state_dict(checkpoint['compression_sched'])
-            msglogger.info("Loaded compression schedule from checkpoint (epoch %d)",
-                           checkpoint['epoch'])
-        else:
-            msglogger.info("Warning: compression schedule data does not exist in the checkpoint")
-
-        if 'thinning_recipes' in checkpoint:
-            if 'compression_sched' not in checkpoint:
-                raise KeyError("Found thinning_recipes key, but missing mandatory key compression_sched")
-            msglogger.info("Loaded a thinning recipe from the checkpoint")
-            # Cache the recipes in case we need them later
-            model.thinning_recipes = checkpoint['thinning_recipes']
-            distiller.execute_thinning_recipes_list(model,
-                                                    compression_scheduler.zeros_mask_dict,
-                                                    model.thinning_recipes)
-
-        if 'quantizer_metadata' in checkpoint:
-            msglogger.info('Loaded quantizer metadata from the checkpoint')
-            qmd = checkpoint['quantizer_metadata']
-            quantizer = qmd['type'](model, **qmd['params'])
-            quantizer.prepare_model()
-
-        msglogger.info("=> loaded checkpoint '%s' (epoch %d)", chkpt_file, start_epoch-1)
-
-        model.load_state_dict(checkpoint['state_dict'])
-        return model, compression_scheduler, start_epoch
+    normalize_dataparallel_keys = False
+    if 'compression_sched' in checkpoint:
+        compression_scheduler = distiller.CompressionScheduler(model)
+        try:
+            compression_scheduler.load_state_dict(checkpoint['compression_sched'], normalize_dataparallel_keys)
+        except KeyError:
+            # A common cause of this KeyError is loading a GPU-trained model on the CPU.
+            # DataParallel is not used on the CPU, so we normalize the 'module.'-prefixed keys and retry.
+            normalize_dataparallel_keys = True
+            compression_scheduler.load_state_dict(checkpoint['compression_sched'], normalize_dataparallel_keys)
+        msglogger.info("Loaded compression schedule from checkpoint (epoch {})".format(
+            checkpoint_epoch))
     else:
-        raise IOError(ENOENT, 'Could not find a checkpoint file at', chkpt_file)
+        msglogger.info("Warning: compression schedule data does not exist in the checkpoint")
+
+    if 'thinning_recipes' in checkpoint:
+        if 'compression_sched' not in checkpoint:
+            raise KeyError("Found thinning_recipes key, but missing mandatory key compression_sched")
+        msglogger.info("Loaded a thinning recipe from the checkpoint")
+        # Cache the recipes in case we need them later
+        model.thinning_recipes = checkpoint['thinning_recipes']
+        if normalize_dataparallel_keys:
+            model.thinning_recipes = {normalize_module_name(k): v for k, v in model.thinning_recipes.items()}
+        distiller.execute_thinning_recipes_list(model,
+                                                compression_scheduler.zeros_mask_dict,
+                                                model.thinning_recipes)
+
+    if 'quantizer_metadata' in checkpoint:
+        msglogger.info('Loaded quantizer metadata from the checkpoint')
+        qmd = checkpoint['quantizer_metadata']
+        quantizer = qmd['type'](model, **qmd['params'])
+        quantizer.prepare_model()
+
+    msglogger.info("=> loaded checkpoint '{f}' (epoch {e})".format(f=str(chkpt_file),
+                                                                   e=checkpoint_epoch))
+    if normalize_dataparallel_keys:
+        checkpoint['state_dict'] = {normalize_module_name(k): v for k, v in checkpoint['state_dict'].items()}
+    model.load_state_dict(checkpoint['state_dict'])
+    return model, compression_scheduler, start_epoch
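
For reference, the newly imported normalize_module_name helper is assumed to behave roughly like the sketch below: it maps DataParallel-style names back to plain module names (the real implementation lives in distiller/utils.py, so treat this as illustrative only):

# Minimal sketch of the assumed behavior of distiller.utils.normalize_module_name:
# strip the 'module.' prefix that torch.nn.DataParallel prepends to submodule
# and parameter names when a model is wrapped for multi-GPU execution.
def normalize_module_name(layer_name):
    if 'module.' in layer_name:
        return layer_name.replace('module.', '')
    return layer_name

# Keys saved from a DataParallel-wrapped (GPU) model map back to plain names:
gpu_keys = ['module.conv1.weight', 'module.fc.bias']
cpu_keys = [normalize_module_name(k) for k in gpu_keys]
assert cpu_keys == ['conv1.weight', 'fc.bias']
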
diff --git a/distiller/scheduler.py b/distiller/scheduler.py
index 6e65238a0abc0d990207fd0e5782c3517bdddf2c..a5e360d01d3e7dfa0bb861d6d9a87fda83ed1fa6 100755
--- a/distiller/scheduler.py
+++ b/distiller/scheduler.py
@@ -18,12 +18,14 @@
 
 This implements the scheduling of the compression policies.
 """
+import contextlib
 from functools import partial
 import logging
+
 import torch
 from .quantization.quantizer import FP_BKP_PREFIX
 from .policy import PolicyLoss, LossComponent
-from .utils import model_device
+from .utils import model_device, normalize_module_name
 msglogger = logging.getLogger()
 
 
@@ -187,26 +189,31 @@ class CompressionScheduler(object):
         state = {'masks_dict': masks}
         return state
 
-    def load_state_dict(self, state):
+    def load_state_dict(self, state, normalize_dataparallel_keys):
         """Loads the scheduler state.
 
         Currently the scheduler state is comprised only of the set of pruning masks.
 
         Arguments:
             state_dict (dict): scheduler state. Should be an object returned
-                from a call to :meth:`state_dict`.  It is a dictionary of parameter
+                from a call to :meth:`state_dict`. It is a dictionary of parameter
                 names (keys) and parameter masks (values).
+            normalize_dataparallel_keys (bool): indicates if we should convert the keys from
+                DataParallel format. This should be set to True when loading a GPU
+                checkpoint onto a CPU (because we currently don't use DataParallel on
+                the CPU).
         """
         try:
             loaded_masks = state['masks_dict']
-        except Exception as exception:
-            print("ERROR: could not load the CompressionScheduler state")
-            print("Exception: %s %s" % (type(exception), exception))
-            print("\t\tFound the following keys in the state dictionary:")
-            for k in state.keys():
-                print("\t\t" + k)
-            exit(1)
-
+        except KeyError:
+            msglogger.error('could not load the CompressionScheduler state:'
+                            ' masks_dict is missing from state')
+            with contextlib.suppress(TypeError):
+                msglogger.debug('Scheduler state keys are: {}'.format(', '.join(state)))
+            raise
+
+        if normalize_dataparallel_keys:
+            loaded_masks = {normalize_module_name(k): v for k, v in loaded_masks.items()}
         device = model_device(self.model)
         for name, mask in self.zeros_mask_dict.items():
             masker = self.zeros_mask_dict[name]
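
A minimal sketch of how a caller might drive the new normalize_dataparallel_keys argument, mirroring the retry pattern used in load_checkpoint above; model and checkpoint are assumed to already be in scope:

# Hypothetical caller (model and checkpoint assumed to exist).
import distiller

scheduler = distiller.CompressionScheduler(model)
try:
    # First try the mask names as-is.
    scheduler.load_state_dict(checkpoint['compression_sched'],
                              normalize_dataparallel_keys=False)
except KeyError:
    # The masks were likely saved under DataParallel ('module.'-prefixed)
    # names; retry with key normalization enabled.
    scheduler.load_state_dict(checkpoint['compression_sched'],
                              normalize_dataparallel_keys=True)
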
diff --git a/tests/test_infra.py b/tests/test_infra.py
index 099598016c5368b19096379ab8ff592473acd2a1..a101bdb82d4b5d5e058a0e53992b7a4b6e153cab 100755
--- a/tests/test_infra.py
+++ b/tests/test_infra.py
@@ -17,13 +17,18 @@
 import logging
 import os
 import sys
+import tempfile
+
+import torch
 import pytest
 module_path = os.path.abspath(os.path.join('..'))
 if module_path not in sys.path:
     sys.path.append(module_path)
 
-from models import create_model
+import distiller
 from apputils import load_checkpoint
+from models import create_model
+
 
 def test_load():
     logger = logging.getLogger('simple_example')
@@ -34,7 +39,42 @@ def test_load():
     assert compression_scheduler is not None
     assert start_epoch == 180
 
+def test_load_state_dict():
+    # prepare lean checkpoint
+    state_dict_arrays = torch.load('../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar', map_location='cpu').get('state_dict')
+
+    with tempfile.NamedTemporaryFile() as tmpfile:
+        torch.save({'state_dict': state_dict_arrays}, tmpfile.name)
+        model = create_model(False, 'cifar10', 'resnet20_cifar')
+        model, compression_scheduler, start_epoch = load_checkpoint(model, tmpfile.name)
+
+    assert len(list(model.named_modules())) >= len([x for x in state_dict_arrays if x.endswith('weight')]) > 0
+    assert compression_scheduler is None
+    assert start_epoch == 0
+
+def test_load_dumb_checkpoint():
+    # prepare a bare state_dict checkpoint (missing the 'state_dict' wrapper key)
+    state_dict_arrays = torch.load('../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar', map_location='cpu').get('state_dict')
+
+    with tempfile.NamedTemporaryFile() as tmpfile:
+        torch.save(state_dict_arrays, tmpfile.name)
+        model = create_model(False, 'cifar10', 'resnet20_cifar')
+        with pytest.raises(ValueError):
+            model, compression_scheduler, start_epoch = load_checkpoint(model, tmpfile.name)
+
 def test_load_negative():
     with pytest.raises(FileNotFoundError):
         model = create_model(False, 'cifar10', 'resnet20_cifar')
         model, compression_scheduler, start_epoch = load_checkpoint(model, 'THIS_IS_AN_ERROR/checkpoint_trained_dense.pth.tar')
+
+
+def test_load_gpu_model_on_cpu():
+    model = create_model(False, 'cifar10', 'resnet20_cifar', device_ids=-1)
+    model, compression_scheduler, start_epoch = load_checkpoint(model,
+                                                                '../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar')
+    assert compression_scheduler is not None
+    assert start_epoch == 180
+    assert distiller.model_device(model) == 'cpu'
+
+if __name__ == '__main__':
+    test_load_gpu_model_on_cpu()
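
A minimal sketch of the "lean checkpoint" convention these tests exercise: any dict carrying the parameters under 'state_dict' is accepted by load_checkpoint, and loading it yields no scheduler and start_epoch == 0 (the file name below is illustrative):

# Illustrative round trip of a lean checkpoint (parameters only, no
# 'compression_sched' and no 'epoch'), assuming the same repo layout
# as the tests above.
import torch
from apputils import load_checkpoint
from models import create_model

model = create_model(False, 'cifar10', 'resnet20_cifar')
torch.save({'state_dict': model.state_dict()}, 'lean.pth.tar')
model, scheduler, start_epoch = load_checkpoint(model, 'lean.pth.tar')
assert scheduler is None and start_epoch == 0
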