diff --git a/distiller/apputils/checkpoint.py b/distiller/apputils/checkpoint.py
index ce7648250553b744012f122c33068cbde2a53696..2cfbbba8d2878648cf89d90d909a1b1fa035ada7 100755
--- a/distiller/apputils/checkpoint.py
+++ b/distiller/apputils/checkpoint.py
@@ -77,7 +77,6 @@ def save_checkpoint(epoch, arch, model, optimizer=None, scheduler=None,
         checkpoint['quantizer_metadata'] = model.quantizer_metadata
 
     checkpoint['extras'] = extras
-
     torch.save(checkpoint, fullpath)
     if is_best:
         shutil.copyfile(fullpath, fullpath_best)
@@ -101,8 +100,8 @@ def get_contents_table(d):
     return tabulate(contents, headers=["Key", "Type", "Value"], tablefmt="psql")
 
 
-def load_checkpoint(model, chkpt_file, optimizer=None, model_device=None, *, 
-                    lean_checkpoint=False, strict=False):
+def load_checkpoint(model, chkpt_file, optimizer=None,
+                    model_device=None, lean_checkpoint=False, strict=False):
     """Load a pytorch training checkpoint.
 
     Args:
@@ -114,6 +113,52 @@ def load_checkpoint(model, chkpt_file, optimizer=None, model_device=None, *,
                 This should be set to either 'cpu' or 'cuda'.
     :returns: updated model, compression_scheduler, optimizer, start_epoch
     """
+    def _load_compression_scheduler():
+        normalize_keys = False
+        try:
+            compression_scheduler.load_state_dict(checkpoint['compression_sched'], normalize_keys)
+        except KeyError as e:
+            # A very common source of this KeyError is loading a GPU model on the CPU.
+            # We rename all of the DataParallel keys because DataParallel does not execute on the CPU.
+            normalize_keys = True
+            compression_scheduler.load_state_dict(checkpoint['compression_sched'], normalize_keys)
+        msglogger.info("Loaded compression schedule from checkpoint (epoch {})".format(
+            checkpoint_epoch))
+        return normalize_keys
+
+    def _load_and_execute_thinning_recipes():
+        msglogger.info("Loaded a thinning recipe from the checkpoint")
+        # Cache the recipes in case we need them later
+        model.thinning_recipes = checkpoint['thinning_recipes']
+        if normalize_dataparallel_keys:
+            model.thinning_recipes = [distiller.get_normalized_recipe(recipe)
+                                      for recipe in model.thinning_recipes]
+        distiller.execute_thinning_recipes_list(model,
+                                                compression_scheduler.zeros_mask_dict,
+                                                model.thinning_recipes)
+
+    def _load_optimizer():
+        """Initialize optimizer with model parameters and load src_state_dict"""
+        try:
+            cls, src_state_dict = checkpoint['optimizer_type'], checkpoint['optimizer_state_dict']
+            # Initialize the dest_optimizer with a dummy learning rate,
+            # this is required to support SGD.__init__()
+            dest_optimizer = cls(model.parameters(), lr=1)
+            dest_optimizer.load_state_dict(src_state_dict)
+            msglogger.info('Optimizer of type {type} was loaded from checkpoint'.format(
+                            type=type(dest_optimizer)))
+            optimizer_param_groups = dest_optimizer.state_dict()['param_groups']
+            msglogger.info('Optimizer Args: {}'.format(
+                            dict((k, v) for k, v in optimizer_param_groups[0].items()
+                                 if k != 'params')))
+            return dest_optimizer
+        except KeyError:
+            # Older checkpoints do support optimizer loading: They either had an 'optimizer' field
+            # (different name) which was not used during the load, or they didn't even checkpoint
+            # the optimizer.
+            msglogger.warning('Optimizer could not be loaded from checkpoint.')
+            return None
+
     if not os.path.isfile(chkpt_file):
         raise IOError(ENOENT, 'Could not find a checkpoint file at', chkpt_file)
 
@@ -133,30 +178,15 @@ def load_checkpoint(model, chkpt_file, optimizer=None, model_device=None, *,
     normalize_dataparallel_keys = False
     if 'compression_sched' in checkpoint:
         compression_scheduler = distiller.CompressionScheduler(model)
-        try:
-            compression_scheduler.load_state_dict(checkpoint['compression_sched'], normalize_dataparallel_keys)
-        except KeyError as e:
-            # A very common source of this KeyError is loading a GPU model on the CPU.
-            # We rename all of the DataParallel keys because DataParallel does not execute on the CPU.
-            normalize_dataparallel_keys = True
-            compression_scheduler.load_state_dict(checkpoint['compression_sched'], normalize_dataparallel_keys)
-        msglogger.info("Loaded compression schedule from checkpoint (epoch {})".format(
-            checkpoint_epoch))
+        normalize_dataparallel_keys = _load_compression_scheduler()
     else:
         msglogger.info("Warning: compression schedule data does not exist in the checkpoint")
 
     if 'thinning_recipes' in checkpoint:
-        if 'compression_sched' not in checkpoint:
-            msglogger.warning("Found thinning_recipes key, but missing mandatory key compression_sched")
+        if not compression_scheduler:
+            msglogger.warning("Found thinning_recipes key, but missing key compression_scheduler")
             compression_scheduler = distiller.CompressionScheduler(model)
-        msglogger.info("Loaded a thinning recipe from the checkpoint")
-        # Cache the recipes in case we need them later
-        model.thinning_recipes = checkpoint['thinning_recipes']
-        if normalize_dataparallel_keys:
-            model.thinning_recipes = [distiller.get_normalized_recipe(recipe) for recipe in model.thinning_recipes]
-        distiller.execute_thinning_recipes_list(model,
-                                                compression_scheduler.zeros_mask_dict,
-                                                model.thinning_recipes)
+        _load_and_execute_thinning_recipes()
 
     if 'quantizer_metadata' in checkpoint:
         msglogger.info('Loaded quantizer metadata from the checkpoint')
@@ -165,49 +195,26 @@ def load_checkpoint(model, chkpt_file, optimizer=None, model_device=None, *,
         quantizer.prepare_model(qmd['dummy_input'])
 
     if normalize_dataparallel_keys:
-            checkpoint['state_dict'] = {normalize_module_name(k): v for k, v in checkpoint['state_dict'].items()}
+        checkpoint['state_dict'] = {normalize_module_name(k): v for k, v in checkpoint['state_dict'].items()}
     anomalous_keys = model.load_state_dict(checkpoint['state_dict'], strict)
     if anomalous_keys:
         # This is pytorch 1.1+
         missing_keys, unexpected_keys = anomalous_keys
         if unexpected_keys:
-            msglogger.warning("Warning: the loaded checkpoint (%s) contains %d unexpected state keys" % (chkpt_file, len(unexpected_keys)))
+            msglogger.warning("Warning: the loaded checkpoint (%s) contains %d unexpected state keys" %
+                              (chkpt_file, len(unexpected_keys)))
         if missing_keys:
-            raise ValueError("The loaded checkpoint (%s) is missing %d state keys" % (chkpt_file, len(missing_keys)))
-            
+            raise ValueError("The loaded checkpoint (%s) is missing %d state keys" %
+                             (chkpt_file, len(missing_keys)))
+
     if model_device is not None:
         model.to(model_device)
 
     if lean_checkpoint:
         msglogger.info("=> loaded 'state_dict' from checkpoint '{}'".format(str(chkpt_file)))
-        return (model, None, None, 0)
-
-    def _load_optimizer(cls, src_state_dict, model):
-        """Initiate optimizer with model parameters and load src_state_dict"""
-        # initiate the dest_optimizer with a dummy learning rate,
-        # this is required to support SGD.__init__()
-        dest_optimizer = cls(model.parameters(), lr=1)
-        dest_optimizer.load_state_dict(src_state_dict)
-        return dest_optimizer
-
-    try:
-        optimizer = _load_optimizer(checkpoint['optimizer_type'],
-            checkpoint['optimizer_state_dict'], model)
-    except KeyError:
-        # Older checkpoints do support optimizer loading: They either had an 'optimizer' field 
-        # (different name) which was not used during the load, or they didn't even checkpoint
-        # the optimizer. 
-        optimizer = None
-
-    if optimizer is not None:
-        msglogger.info('Optimizer of type {type} was loaded from checkpoint'.format(
-            type=type(optimizer)))
-        msglogger.info('Optimizer Args: {}'.format(
-            dict((k,v) for k,v in optimizer.state_dict()['param_groups'][0].items()
-                            if k != 'params')))
-    else:
-        msglogger.warning('Optimizer could not be loaded from checkpoint.')
+        return model, None, None, 0
 
+    optimizer = _load_optimizer()
     msglogger.info("=> loaded checkpoint '{f}' (epoch {e})".format(f=str(chkpt_file),
                                                                    e=checkpoint_epoch))
-    return (model, compression_scheduler, optimizer, start_epoch)
+    return model, compression_scheduler, optimizer, start_epoch