diff --git a/distiller/apputils/checkpoint.py b/distiller/apputils/checkpoint.py
index 91005ebc7bb1407d2b73840ec611ee6ca6ebacf8..a3c4ae62463c10ec95dc6b1e7fdfb73503c225b9 100755
--- a/distiller/apputils/checkpoint.py
+++ b/distiller/apputils/checkpoint.py
@@ -60,7 +60,8 @@ def save_checkpoint(epoch, arch, model, optimizer=None, scheduler=None,
     if best_top1 is not None:
         checkpoint['best_top1'] = best_top1
     if optimizer is not None:
-        checkpoint['optimizer'] = optimizer.state_dict()
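+        # Save both the optimizer's state and its class, so load_checkpoint() can re-instantiate it.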
+        checkpoint['optimizer_state_dict'] = optimizer.state_dict()
+        checkpoint['optimizer_type'] = type(optimizer)
     if scheduler is not None:
         checkpoint['compression_sched'] = scheduler.state_dict()
     if hasattr(model, 'thinning_recipes'):
@@ -73,13 +74,22 @@ def save_checkpoint(epoch, arch, model, optimizer=None, scheduler=None,
         shutil.copyfile(fullpath, fullpath_best)
 
 
-def load_checkpoint(model, chkpt_file, optimizer=None):
-    """Load a pytorch training checkpoint
+def load_lean_checkpoint(model, chkpt_file, model_device=None):
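+    """Load only the model weights (the 'state_dict' field) from a checkpoint file."""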
+    return load_checkpoint(model, chkpt_file, model_device=model_device,
+                           lean_checkpoint=True)[0]
+
+
+def load_checkpoint(model, chkpt_file, optimizer=None, model_device=None, *, lean_checkpoint=False):
+    """Load a pytorch training checkpoint.
 
     Args:
         model: the pytorch model to which we will load the parameters
         chkpt_file: the checkpoint file
-        optimizer: the optimizer to which we will load the serialized state
+        optimizer: [deprecated argument]
+        model_device [str]: if set, call model.to(model_device)
+                This should be set to either 'cpu' or 'cuda'.
+        lean_checkpoint: if set, load only the 'state_dict' field into the model
+    :returns: updated model, compression_scheduler, optimizer, start_epoch
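+
+    Example (illustrative):
+        model, scheduler, optimizer, start_epoch = load_checkpoint(model, 'checkpoint.pth.tar')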
     """
     if not os.path.isfile(chkpt_file):
         raise IOError(ENOENT, 'Could not find a checkpoint file at', chkpt_file)
@@ -132,9 +142,43 @@ def load_checkpoint(model, chkpt_file, optimizer=None):
         quantizer = qmd['type'](model, **qmd['params'])
         quantizer.prepare_model()
 
-    msglogger.info("=> loaded checkpoint '{f}' (epoch {e})".format(f=str(chkpt_file),
-                                                                   e=checkpoint_epoch))
     if normalize_dataparallel_keys:
             checkpoint['state_dict'] = {normalize_module_name(k): v for k, v in checkpoint['state_dict'].items()}
     model.load_state_dict(checkpoint['state_dict'])
-    return (model, compression_scheduler, start_epoch)
+    if model_device is not None:
+        model.to(model_device)
+
+    if lean_checkpoint:
+        msglogger.info("=> loaded 'state_dict' from checkpoint '{}'".format(str(chkpt_file)))
+        return (model, None, None, 0)
+
+    def _load_optimizer(cls, src_state_dict, model):
+        """Initiate optimizer with model parameters and load src_state_dict"""
+        # initiate the dest_optimizer with a dummy learning rate,
+        # this is required to support SGD.__init__()
+        dest_optimizer = cls(model.parameters(), lr=1)
+        dest_optimizer.load_state_dict(src_state_dict)
+        return dest_optimizer
+
+    try:
+        optimizer = _load_optimizer(checkpoint['optimizer_type'],
+            checkpoint['optimizer_state_dict'], model)
+    except KeyError:
+        if 'optimizer' not in checkpoint:
+            raise
+        # Older checkpoints stored the state under the 'optimizer' field and did not
+        # record the optimizer's type, so the optimizer cannot be reconstructed here.
+        optimizer = None
+
+    if optimizer is not None:
+        msglogger.info('Optimizer of type {type} was loaded from checkpoint'.format(
+            type=type(optimizer)))
+        msglogger.info('Optimizer Args: {}'.format(
+            {k: v for k, v in optimizer.state_dict()['param_groups'][0].items()
+             if k != 'params'}))
+    else:
+        msglogger.warning('Optimizer could not be loaded from checkpoint.')
+
+    msglogger.info("=> loaded checkpoint '{f}' (epoch {e})".format(f=str(chkpt_file),
+                                                                   e=checkpoint_epoch))
+    return (model, compression_scheduler, optimizer, start_epoch)
diff --git a/distiller/config.py b/distiller/config.py
index c0fb982efc0b9469a187f3f665b3380b0db7a149..9367e88e70abab1be7cbaf21e43dfb0f44a3a88b 100755
--- a/distiller/config.py
+++ b/distiller/config.py
@@ -49,7 +49,7 @@ msglogger = logging.getLogger()
 app_cfg_logger = logging.getLogger("app_cfg")
 
 
-def dict_config(model, optimizer, sched_dict, scheduler=None):
+def dict_config(model, optimizer, sched_dict, scheduler=None, resumed_epoch=None):
     app_cfg_logger.debug('Schedule contents:\n' + json.dumps(sched_dict, indent=2))
 
     if scheduler is None:
@@ -110,7 +110,8 @@ def dict_config(model, optimizer, sched_dict, scheduler=None):
             add_policy_to_scheduler(policy, policy_def, scheduler)
 
         # Any changes to the optmizer caused by a quantizer have occured by now, so safe to create LR schedulers
-        lr_schedulers = __factory('lr_schedulers', model, sched_dict, optimizer=optimizer)
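+        # When resuming from a checkpoint, pass the last completed epoch so the PyTorch LR schedulers
+        # continue from where they left off (otherwise last_epoch defaults to -1).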
+        lr_schedulers = __factory('lr_schedulers', model, sched_dict, optimizer=optimizer,
+                                  last_epoch=(resumed_epoch if resumed_epoch is not None else -1))
         for policy_def in lr_policies:
             instance_name, args = __policy_params(policy_def, 'lr_scheduler')
             assert instance_name in lr_schedulers, "LR-scheduler {} was not defined in the list of lr-schedulers".format(
@@ -138,13 +139,13 @@ def add_policy_to_scheduler(policy, policy_def, scheduler):
                             frequency=policy_def['frequency'])
 
 
-def file_config(model, optimizer, filename, scheduler=None):
+def file_config(model, optimizer, filename, scheduler=None, resumed_epoch=None):
     """Read the schedule from file"""
     with open(filename, 'r') as stream:
         msglogger.info('Reading compression schedule from: %s', filename)
         try:
             sched_dict = distiller.utils.yaml_ordered_load(stream)
-            return dict_config(model, optimizer, sched_dict, scheduler)
+            return dict_config(model, optimizer, sched_dict, scheduler, resumed_epoch)
         except yaml.YAMLError as exc:
             print("\nFATAL parsing error while parsing the schedule configuration file %s" % filename)
             raise
diff --git a/distiller/policy.py b/distiller/policy.py
index f42219f94322fb27177e126e796582832f0a8538..db45e6f18a4efd0f9a04fb88b5c53cd0b6214d52 100755
--- a/distiller/policy.py
+++ b/distiller/policy.py
@@ -21,6 +21,7 @@
 - LRPolicy: learning-rate decay scheduling
 """
 import torch
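+# Imported explicitly for the isinstance() check against ReduceLROnPlateau in LRPolicy.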
+import torch.optim.lr_scheduler
 from collections import namedtuple
 
 import logging
@@ -42,7 +43,7 @@ class ScheduledTrainingPolicy(object):
         self.classes = classes
         self.layers = layers
 
-    def on_epoch_begin(self, model, zeros_mask_dict, meta):
+    def on_epoch_begin(self, model, zeros_mask_dict, meta, **kwargs):
         """A new epcoh is about to begin"""
         pass
 
@@ -115,7 +116,7 @@ class PruningPolicy(ScheduledTrainingPolicy):
         self.mini_batch_id = 0          # The ID of the mini_batch within the present epoch
         self.global_mini_batch_id = 0   # The ID of the mini_batch within the present training session
 
-    def on_epoch_begin(self, model, zeros_mask_dict, meta):
+    def on_epoch_begin(self, model, zeros_mask_dict, meta, **kwargs):
         msglogger.debug("Pruner {} is about to prune".format(self.pruner.name))
         self.mini_batch_id = 0
         self.is_last_epoch = meta['current_epoch'] == (meta['ending_epoch'] - 1)
@@ -169,7 +170,7 @@ class RegularizationPolicy(ScheduledTrainingPolicy):
         self.keep_mask = keep_mask
         self.is_last_epoch = False
 
-    def on_epoch_begin(self, model, zeros_mask_dict, meta):
+    def on_epoch_begin(self, model, zeros_mask_dict, meta, **kwargs):
         self.is_last_epoch = meta['current_epoch'] == (meta['ending_epoch'] - 1)
 
     def before_backward_pass(self, model, epoch, minibatch_id, minibatches_per_epoch, loss,
@@ -202,15 +203,19 @@ class RegularizationPolicy(ScheduledTrainingPolicy):
 
 
 class LRPolicy(ScheduledTrainingPolicy):
-    """ Learning-rate decay scheduling policy.
+    """Learning-rate decay scheduling policy.
 
     """
     def __init__(self, lr_scheduler):
         super(LRPolicy, self).__init__()
         self.lr_scheduler = lr_scheduler
 
-    def on_epoch_begin(self, model, zeros_mask_dict, meta):
-        self.lr_scheduler.step()
+    def on_epoch_begin(self, model, zeros_mask_dict, meta, **kwargs):
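+        # Step with an explicit epoch so the LR schedule stays aligned with the compression schedule's epoch count.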
+        if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
+            # Note: ReduceLROnPlateau doesn't inherit from _LRScheduler
+            self.lr_scheduler.step(kwargs['metrics'], epoch=meta['current_epoch'])
+        else:
+            self.lr_scheduler.step(epoch=meta['current_epoch'])
 
 
 class QuantizationPolicy(ScheduledTrainingPolicy):
diff --git a/distiller/scheduler.py b/distiller/scheduler.py
index 2348eea106dc13636e50b7841d40d4f69b4a94d3..7104ac19815093ac6fd0a4bd630391b8ae36590b 100755
--- a/distiller/scheduler.py
+++ b/distiller/scheduler.py
@@ -106,12 +106,12 @@ class CompressionScheduler(object):
                                        'ending_epoch': ending_epoch,
                                        'frequency': frequency}
 
-    def on_epoch_begin(self, epoch, optimizer=None):
-        if epoch in self.policies:
-            for policy in self.policies[epoch]:
-                meta = self.sched_metadata[policy]
-                meta['current_epoch'] = epoch
-                policy.on_epoch_begin(self.model, self.zeros_mask_dict, meta)
+    def on_epoch_begin(self, epoch, optimizer=None, **kwargs):
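+        # Extra keyword arguments (e.g. 'metrics' for ReduceLROnPlateau) are forwarded to each policy.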
+        for policy in self.policies.get(epoch, list()):
+            meta = self.sched_metadata[policy]
+            meta['current_epoch'] = epoch
+            policy.on_epoch_begin(self.model, self.zeros_mask_dict, meta,
+                                  **kwargs)
 
     def on_minibatch_begin(self, epoch, minibatch_id, minibatches_per_epoch, optimizer=None):
         if epoch in self.policies:
diff --git a/distiller/thinning.py b/distiller/thinning.py
index 94758369c3b051e3c0ae5e0a5fea3e0c29e6c20e..78cd2879469b1c62f40094f2702f4e4e2e6ecedb 100755
--- a/distiller/thinning.py
+++ b/distiller/thinning.py
@@ -463,8 +463,12 @@ def optimizer_thinning(optimizer, param, dim, indices, new_shape=None):
     This function is brittle as it is tested on SGD only and relies on the internal representation of
     the SGD optimizer, which can change w/o notice.
     """
-    if optimizer is None or not isinstance(optimizer, torch.optim.SGD):
+    if optimizer is None:
         return False
+
+    if not isinstance(optimizer, torch.optim.SGD):
+        raise NotImplementedError('optimizer thinning supports only SGD')
+
     for group in optimizer.param_groups:
         momentum = group.get('momentum', 0)
         if momentum == 0:
@@ -473,7 +477,7 @@ def optimizer_thinning(optimizer, param, dim, indices, new_shape=None):
             if id(p) != id(param):
                 continue
             param_state = optimizer.state[p]
-            if 'momentum_buffer' in param_state:
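+            # A 'momentum_buffer' entry may exist with a None value, so test the value rather than key membership.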
+            if param_state.get('momentum_buffer', None) is not None:
                 param_state['momentum_buffer'] = torch.index_select(param_state['momentum_buffer'], dim, indices)
                 if new_shape is not None:
                     msglogger.debug("optimizer_thinning: new shape {}".format(*new_shape))
@@ -526,9 +530,6 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, optimizer, loaded_fr
                         grad_selection_view = param.grad.resize_(*directive[2])
                         if grad_selection_view.size(dim) != len_indices:
                             param.grad = torch.index_select(grad_selection_view, dim, indices)
-                            if optimizer_thinning(optimizer, param, dim, indices, directive[3]):
-                                msglogger.debug("Updated [4D] velocity buffer for {} (dim={},size={},shape={})".
-                                                format(param_name, dim, len_indices, directive[3]))
 
                 param.data = param.view(*directive[3])
                 if param.grad is not None:
@@ -542,8 +543,13 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, optimizer, loaded_fr
                 # not exist, and therefore won't need to be re-dimensioned.
                 if param.grad is not None and param.grad.size(dim) != len_indices:
                     param.grad = torch.index_select(param.grad, dim, indices.to(param.device))
-                    if optimizer_thinning(optimizer, param, dim, indices):
-                        msglogger.debug("Updated velocity buffer %s" % param_name)
+
+            # Update the optimizer's momentum buffers to match the thinned parameter
+            if optimizer_thinning(optimizer, param, dim, indices,
+                                  new_shape=directive[3] if len(directive) == 4 else None):
+                msglogger.debug("Updated velocity buffer %s" % param_name)
+            else:
+                msglogger.debug('Failed to update the optimizer with the thinning directive')
 
             if not loaded_from_file:
                 # If the masks are loaded from a checkpoint file, then we don't need to change
diff --git a/distiller/utils.py b/distiller/utils.py
index 875b4043e1c8a3cb86037edd9965500db3a74605..ac4a31abe323ef85948c9fce20e2c02c9813fe1c 100755
--- a/distiller/utils.py
+++ b/distiller/utils.py
@@ -39,6 +39,10 @@ def model_device(model):
     return 'cpu'
 
 
+def optimizer_device_name(opt):
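+    # Return the device (e.g. 'cpu' or 'cuda:0') of the first parameter tracked in the optimizer's state.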
+    return str(list(list(opt.state)[0])[0].device)
+
+
 def to_np(var):
     return var.data.cpu().numpy()
 
diff --git a/examples/agp-pruning/mobilenet.imagenet.schedule_agp.yaml b/examples/agp-pruning/mobilenet.imagenet.schedule_agp.yaml
index 6e587355ad79523db6bcdbba28e2392c5dfeabd6..dd51be800c79caab762907c503575b18b35b1ccb 100755
--- a/examples/agp-pruning/mobilenet.imagenet.schedule_agp.yaml
+++ b/examples/agp-pruning/mobilenet.imagenet.schedule_agp.yaml
@@ -1,4 +1,5 @@
-# time python3 compress_classifier.py -a=mobilenet -p=50 --lr=0.001 ../../../data.imagenet/ -j=22 --resume=mobilenet_sgd_68.848.pth.tar --epochs=96 --compress=../agp-pruning//mobilenet.imagenet.schedule_agp.yaml
+# time python3 compress_classifier.py -a=mobilenet -p=50 --lr=0.001 ../../../data.imagenet/ -j=22 --resume-from=mobilenet_sgd_68.848.pth.tar --epochs=96 --compress=../agp-pruning//mobilenet.imagenet.schedule_agp.yaml  --reset-optimizer --vs=0
+#
 #
 # A pretrained MobileNet (width=1) can be downloaded from: https://github.com/marvis/pytorch-mobilenet (top1: 68.848; top5: 88.740)
 #
@@ -86,18 +87,18 @@ lr_schedulers:
 policies:
   - pruner:
       instance_name : 'conv50_pruner'
-    starting_epoch: 103
-    ending_epoch: 123
+    starting_epoch: 0
+    ending_epoch: 20
     frequency: 2
 
   - pruner:
       instance_name : 'conv60_pruner'
-    starting_epoch: 103
-    ending_epoch: 123
+    starting_epoch: 0
+    ending_epoch: 20
     frequency: 2
 
   - lr_scheduler:
       instance_name: pruning_lr
-    starting_epoch: 103
-    ending_epoch: 200
+    starting_epoch: 0
+    ending_epoch: 100
     frequency: 1
diff --git a/examples/agp-pruning/resnet20_filters.schedule_agp.yaml b/examples/agp-pruning/resnet20_filters.schedule_agp.yaml
index aa04ed5253a80803336d30dc3d1deba68b4a3810..362e52c00c800528060502c14002e65cddf3f07c 100755
--- a/examples/agp-pruning/resnet20_filters.schedule_agp.yaml
+++ b/examples/agp-pruning/resnet20_filters.schedule_agp.yaml
@@ -14,7 +14,7 @@
 #     Total sparsity: 41.10
 #     # of parameters: 120,000  (=55.7% of the baseline parameters)
 #
-# time python3 compress_classifier.py --arch resnet20_cifar  ../../../data.cifar10 -p=50 --lr=0.3 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp.yaml -j=1 --deterministic --resume=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --validation-split=0
+# time python3 compress_classifier.py --arch resnet20_cifar  ../../../data.cifar10 -p=50 --lr=0.3 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp.yaml  --resume-from=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --vs=0 --reset-optimizer
 #
 # Parameters:
 # +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
@@ -109,29 +109,29 @@ extensions:
 policies:
   - pruner:
       instance_name : low_pruner
-    starting_epoch: 180
-    ending_epoch: 210
+    starting_epoch: 0
+    ending_epoch: 30
     frequency: 2
 
   - pruner:
       instance_name : fine_pruner
-    starting_epoch: 210
-    ending_epoch: 230
+    starting_epoch: 30
+    ending_epoch: 50
     frequency: 2
 
   - pruner:
       instance_name : fc_pruner
-    starting_epoch: 210
-    ending_epoch: 230
+    starting_epoch: 30
+    ending_epoch: 50
     frequency: 2
 
 # After completeing the pruning, we perform network thinning and continue fine-tuning.
   - extension:
       instance_name: net_thinner
-    epochs: [212]
+    epochs: [32]
 
   - lr_scheduler:
       instance_name: pruning_lr
-    starting_epoch: 180
+    starting_epoch: 0
     ending_epoch: 400
     frequency: 1
diff --git a/examples/agp-pruning/resnet20_filters.schedule_agp_2.yaml b/examples/agp-pruning/resnet20_filters.schedule_agp_2.yaml
index 4c9d3b5005a40a7cf74a1f7d69323e221f80daf8..1ea06610ad142adb4b641f1f4ab0c0f03dc7e7af 100755
--- a/examples/agp-pruning/resnet20_filters.schedule_agp_2.yaml
+++ b/examples/agp-pruning/resnet20_filters.schedule_agp_2.yaml
@@ -14,8 +14,7 @@
 #     Total sparsity: 41.84
 #     # of parameters: 143,488 (=53% of the baseline parameters)
 #
-# time python3 compress_classifier.py --arch resnet20_cifar  ../../../data.cifar10 -p=50 --lr=0.1 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp_2.yaml -j=1 --deterministic --resume=../ssl/checkpoints/checkpoint_trained_dense.pth.tar
-#
+# time python3 compress_classifier.py --arch resnet20_cifar  ../../../data.cifar10 -p=50 --lr=0.1 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp_2.yaml -j=1 --deterministic --resume-from=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --reset-optimizer
 #
 # Parameters:
 # +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
@@ -103,20 +102,20 @@ extensions:
 policies:
   - pruner:
       instance_name : low_pruner
-    starting_epoch: 180
-    ending_epoch: 200
+    starting_epoch: 0
+    ending_epoch: 20
     frequency: 2
 
   - pruner:
       instance_name : fine_pruner
-    starting_epoch: 200
-    ending_epoch: 220
+    starting_epoch: 20
+    ending_epoch: 40
     frequency: 2
 
 # After completeing the pruning, we perform network thinning and continue fine-tuning.
   - extension:
       instance_name: net_thinner
-    epochs: [202]
+    epochs: [22]
 
   - lr_scheduler:
       instance_name: pruning_lr
diff --git a/examples/agp-pruning/resnet20_filters.schedule_agp_3.yaml b/examples/agp-pruning/resnet20_filters.schedule_agp_3.yaml
index 38798587c8b421a69f7a4a42a6d01116f7d247dd..f958ec13ce230b5764e1bcec22eeef49d1671752 100755
--- a/examples/agp-pruning/resnet20_filters.schedule_agp_3.yaml
+++ b/examples/agp-pruning/resnet20_filters.schedule_agp_3.yaml
@@ -14,7 +14,7 @@
 #     Total sparsity: 56.41%
 #     # of parameters: 95922  (=35.4% of the baseline parameters ==> 64.6% sparsity)
 #
-# time python3 compress_classifier.py --arch resnet20_cifar  ../../../data.cifar10 -p=50 --lr=0.4 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp_3.yaml -j=1 --deterministic --resume=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --validation-split=0
+# time python3 compress_classifier.py --arch resnet20_cifar  ../../../data.cifar10 -p=50 --lr=0.4 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp_3.yaml  --resume-from=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --reset-optimizer --vs=0
 #
 # Parameters:
 # +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
@@ -116,26 +116,26 @@ extensions:
 policies:
   - pruner:
       instance_name : low_pruner
-    starting_epoch: 180
-    ending_epoch: 210
+    starting_epoch: 0
+    ending_epoch: 30
     frequency: 2
 
   - pruner:
       instance_name : fine_pruner1
-    starting_epoch: 210
-    ending_epoch: 230
+    starting_epoch: 30
+    ending_epoch: 50
     frequency: 2
 
   - pruner:
       instance_name : fine_pruner2
-    starting_epoch: 210
-    ending_epoch: 230
+    starting_epoch: 30
+    ending_epoch: 50
     frequency: 2
 
   - pruner:
       instance_name : fc_pruner
-    starting_epoch: 210
-    ending_epoch: 230
+    starting_epoch: 30
+    ending_epoch: 50
     frequency: 2
 
   # Currently the thinner is disabled until the the structure pruner is done, because it interacts
@@ -150,10 +150,10 @@ policies:
 # After completeing the pruning, we perform network thinning and continue fine-tuning.
   - extension:
       instance_name: net_thinner
-    epochs: [212]
+    epochs: [32]
 
   - lr_scheduler:
       instance_name: pruning_lr
-    starting_epoch: 180
+    starting_epoch: 0
     ending_epoch: 400
     frequency: 1
diff --git a/examples/agp-pruning/resnet20_filters.schedule_agp_4.yaml b/examples/agp-pruning/resnet20_filters.schedule_agp_4.yaml
index 3f85c2db2c780446a71887a976ab145da41baf96..147d7b8e0cd86b3afa6af9545d397f645b6f2736 100755
--- a/examples/agp-pruning/resnet20_filters.schedule_agp_4.yaml
+++ b/examples/agp-pruning/resnet20_filters.schedule_agp_4.yaml
@@ -14,7 +14,7 @@
 #     Total sparsity: 39.66
 #     # of parameters: 78,776  (=29.1% of the baseline parameters)
 #
-# time python3 compress_classifier.py --arch resnet20_cifar  ../../../data.cifar10 -p=50 --lr=0.3 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp_4.yaml -j=1 --deterministic --resume=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --validation-split=0
+# time python3 compress_classifier.py --arch resnet20_cifar  ../../../data.cifar10 -p=50 --lr=0.3 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp_4.yaml --resume-from=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --reset-optimizer --vs=0
 #
 # Parameters:
 # +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
@@ -110,26 +110,26 @@ extensions:
 policies:
   - pruner:
       instance_name : low_pruner
-    starting_epoch: 180
-    ending_epoch: 210
+    starting_epoch: 0
+    ending_epoch: 30
     frequency: 2
 
   - pruner:
       instance_name : low_pruner_2
-    starting_epoch: 180
-    ending_epoch: 210
+    starting_epoch: 0
+    ending_epoch: 30
     frequency: 2
 
   - pruner:
       instance_name : fine_pruner
-    starting_epoch: 210
-    ending_epoch: 230
+    starting_epoch: 30
+    ending_epoch: 50
     frequency: 2
 
   - pruner:
       instance_name : fc_pruner
-    starting_epoch: 210
-    ending_epoch: 230
+    starting_epoch: 30
+    ending_epoch: 50
     frequency: 2
 
   # Currently the thinner is disabled until the the structure pruner is done, because it interacts
@@ -144,11 +144,10 @@ policies:
 # After completeing the pruning, we perform network thinning and continue fine-tuning.
   - extension:
       instance_name: net_thinner
-    #epochs: [181]
-    epochs: [212]
+    epochs: [32]
 
   - lr_scheduler:
       instance_name: pruning_lr
-    starting_epoch: 180
+    starting_epoch: 0
     ending_epoch: 400
     frequency: 1
diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py
index c33cc3fa60d469d06dd51545663d3d6fccc35f3b..e7b318c937969d98d4ee09e661161d5d4320bdea 100755
--- a/examples/classifier_compression/compress_classifier.py
+++ b/examples/classifier_compression/compress_classifier.py
@@ -87,6 +87,8 @@ def main():
 
     # Parse arguments
     args = parser.get_parser().parse_args()
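+    # The parser no longer provides a default for --epochs, so apply the previous default (90) here.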
+    if args.epochs is None:
+        args.epochs = 90
 
     if not os.path.exists(args.output_dir):
         os.makedirs(args.output_dir)
@@ -98,6 +100,7 @@ def main():
     msglogger.debug("Distiller: %s", distiller.__version__)
 
     start_epoch = 0
+    ending_epoch = args.epochs
     perf_scores_history = []
     if args.deterministic:
         # Experiment reproducibility is sometimes important.  Pete Warden expounded about this
@@ -155,19 +158,36 @@ def main():
     if args.earlyexit_thresholds:
         msglogger.info('=> using early-exit threshold values of %s', args.earlyexit_thresholds)
 
-    # We can optionally resume from a checkpoint
-    if args.resume:
-        model, compression_scheduler, start_epoch = apputils.load_checkpoint(model, chkpt_file=args.resume)
-        model.to(args.device)
+    # TODO(barrh): args.deprecated_resume is deprecated since v0.3.1
+    if args.deprecated_resume:
+        msglogger.warning('The "--resume" flag is deprecated. Please use "--resume-from=YOUR_PATH" instead.')
+        if not args.reset_optimizer:
+            msglogger.warning('If you wish to also reset the optimizer, call with: --reset-optimizer')
+            args.reset_optimizer = True
+        args.resumed_checkpoint_path = args.deprecated_resume
 
-    # Define loss function (criterion) and optimizer
+    # We can optionally resume from a checkpoint
+    optimizer = None
+    if args.resumed_checkpoint_path:
+        model, compression_scheduler, optimizer, start_epoch = apputils.load_checkpoint(
+            model, args.resumed_checkpoint_path, model_device=args.device)
+    elif args.load_model_path:
+        model = apputils.load_lean_checkpoint(model, args.load_model_path,
+                                              model_device=args.device)
+    if args.reset_optimizer:
+        start_epoch = 0
+        if optimizer is not None:
+            optimizer = None
+            msglogger.info('\nreset_optimizer flag set: Overriding resumed optimizer and resetting epoch count to 0')
+
+    # Define loss function (criterion)
     criterion = nn.CrossEntropyLoss().to(args.device)
 
-    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr,
-                                momentum=args.momentum,
-                                weight_decay=args.weight_decay)
-    msglogger.info('Optimizer Type: %s', type(optimizer))
-    msglogger.info('Optimizer Args: %s', optimizer.defaults)
+    if optimizer is None:
+        optimizer = torch.optim.SGD(model.parameters(),
+            lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
+        msglogger.info('Optimizer Type: %s', type(optimizer))
+        msglogger.info('Optimizer Args: %s', optimizer.defaults)
 
     if args.AMC:
         return automated_deep_compression(model, criterion, optimizer, pylogger, args)
@@ -211,7 +231,8 @@ def main():
     if args.compress:
         # The main use-case for this sample application is CNN compression. Compression
         # requires a compression schedule configuration file in YAML.
-        compression_scheduler = distiller.file_config(model, optimizer, args.compress, compression_scheduler)
+        compression_scheduler = distiller.file_config(model, optimizer, args.compress, compression_scheduler,
+            (start_epoch-1) if args.resumed_checkpoint_path else None)
         # Model is re-transferred to GPU in case parameters were added (e.g. PACTQuantizer)
         model.to(args.device)
     elif compression_scheduler is None:
@@ -219,10 +240,10 @@ def main():
 
     if args.thinnify:
         #zeros_mask_dict = distiller.create_model_masks_dict(model)
-        assert args.resume is not None, "You must use --resume to provide a checkpoint file to thinnify"
+        assert args.resumed_checkpoint_path is not None, "You must use --resume-from to provide a checkpoint file to thinnify"
         distiller.remove_filters(model, compression_scheduler.zeros_mask_dict, args.arch, args.dataset, optimizer=None)
         apputils.save_checkpoint(0, args.arch, model, optimizer=None, scheduler=compression_scheduler,
-                                 name="{}_thinned".format(args.resume.replace(".pth.tar", "")), dir=msglogger.logdir)
+                                 name="{}_thinned".format(args.resumed_checkpoint_path.replace(".pth.tar", "")), dir=msglogger.logdir)
         print("Note: your model may have collapsed to random inference, so you may want to fine-tune")
         return
 
@@ -230,7 +251,7 @@ def main():
     if args.kd_teacher:
         teacher = create_model(args.kd_pretrained, args.dataset, args.kd_teacher, device_ids=args.gpus)
         if args.kd_resume:
-            teacher, _, _ = apputils.load_checkpoint(teacher, chkpt_file=args.kd_resume)
+            teacher = apputils.load_lean_checkpoint(teacher, args.kd_resume)
         dlw = distiller.DistillationLossWeights(args.kd_distill_wt, args.kd_student_wt, args.kd_teacher_wt)
         args.kd_policy = distiller.KnowledgeDistillationPolicy(model, teacher, args.kd_temp, dlw)
         compression_scheduler.add_policy(args.kd_policy, starting_epoch=args.kd_start_epoch, ending_epoch=args.epochs,
@@ -243,11 +264,17 @@ def main():
                        ' | '.join(['{:.2f}'.format(val) for val in dlw]))
         msglogger.info('\tStarting from Epoch: %s', args.kd_start_epoch)
 
-    for epoch in range(start_epoch, start_epoch + args.epochs):
+    if start_epoch >= ending_epoch:
+        msglogger.error(
+            'epoch count is too low: starting epoch is {} but total number of epochs is {}'.format(
+                start_epoch, ending_epoch))
+        raise ValueError('Epochs parameter is too low. Nothing to do.')
+    for epoch in range(start_epoch, ending_epoch):
         # This is the main training loop.
         msglogger.info('\n')
         if compression_scheduler:
-            compression_scheduler.on_epoch_begin(epoch)
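+            # 'metrics' is consumed by LRPolicy when the LR scheduler is ReduceLROnPlateau; on the
+            # first epoch there is no validation loss yet, so a large sentinel value is passed instead.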
+            compression_scheduler.on_epoch_begin(epoch,
+                metrics=(vloss if (epoch != start_epoch) else 10**6))
 
         # Train for one epoch
         with collectors_context(activations_collectors["train"]) as collectors:
@@ -595,7 +622,7 @@ def evaluate_model(model, criterion, test_loader, loggers, activations_collector
     # the test dataset.
     # You can optionally quantize the model to 8-bit integer before evaluation.
     # For example:
-    # python3 compress_classifier.py --arch resnet20_cifar  ../data.cifar10 -p=50 --resume=checkpoint.pth.tar --evaluate
+    # python3 compress_classifier.py --arch resnet20_cifar  ../data.cifar10 -p=50 --resume-from=checkpoint.pth.tar --evaluate
 
     if not isinstance(loggers, list):
         loggers = [loggers]
diff --git a/examples/classifier_compression/parser.py b/examples/classifier_compression/parser.py
index 9b51f513031c3ae40654db503f0704f158ee23d4..2783b29c311edad9a7f5da1132d7c0978055b2ac 100755
--- a/examples/classifier_compression/parser.py
+++ b/examples/classifier_compression/parser.py
@@ -36,24 +36,40 @@ def get_parser():
                         ' (default: resnet18)')
     parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                         help='number of data loading workers (default: 4)')
-    parser.add_argument('--epochs', default=90, type=int, metavar='N',
-                        help='number of total epochs to run')
+    parser.add_argument('--epochs', type=int, metavar='N',
+                        help='number of total epochs to run (default: 90)')
     parser.add_argument('-b', '--batch-size', default=256, type=int,
                         metavar='N', help='mini-batch size (default: 256)')
-    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
-                        metavar='LR', help='initial learning rate')
-    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
-                        help='momentum')
-    parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
-                        metavar='W', help='weight decay (default: 1e-4)')
+
+    optimizer_args = parser.add_argument_group('Optimizer arguments')
+    optimizer_args.add_argument('--lr', '--learning-rate', default=0.1,
+                    type=float, metavar='LR', help='initial learning rate')
+    optimizer_args.add_argument('--momentum', default=0.9, type=float,
+                    metavar='M', help='momentum')
+    optimizer_args.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
+                    metavar='W', help='weight decay (default: 1e-4)')
+
     parser.add_argument('--print-freq', '-p', default=10, type=int,
                         metavar='N', help='print frequency (default: 10)')
-    parser.add_argument('--resume', default='', type=str, metavar='PATH',
-                        help='path to latest checkpoint (default: none)')
+
+    load_checkpoint_group = parser.add_argument_group('Resuming arguments')
+    load_checkpoint_group_exc = load_checkpoint_group.add_mutually_exclusive_group()
+    # TODO(barrh): args.deprecated_resume is deprecated since v0.3.1
+    load_checkpoint_group_exc.add_argument('--resume', dest='deprecated_resume', default='', type=str,
+                        metavar='PATH', help=argparse.SUPPRESS)
+    load_checkpoint_group_exc.add_argument('--resume-from', dest='resumed_checkpoint_path', default='',
+                        type=str, metavar='PATH',
+                        help='path to latest checkpoint. Use to resume paused training session.')
+    load_checkpoint_group_exc.add_argument('--exp-load-weights-from', dest='load_model_path',
+                        default='', type=str, metavar='PATH',
+                        help='path to checkpoint to load weights from (excluding other fields) (experimental)')
+    load_checkpoint_group.add_argument('--pretrained', dest='pretrained', action='store_true',
+                        help='use pre-trained model')
+    load_checkpoint_group.add_argument('--reset-optimizer', action='store_true',
+                        help='Flag to override optimizer if resumed from checkpoint. This will reset epochs count.')
+
     parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                         help='evaluate model on validation set')
-    parser.add_argument('--pretrained', dest='pretrained', action='store_true',
-                        help='use pre-trained model')
     parser.add_argument('--activation-stats', '--act-stats', nargs='+', metavar='PHASE', default=list(),
                         help='collect activation statistics on phases: train, valid, and/or test'
                         ' (WARNING: this slows down training)')
diff --git a/examples/network_surgery/resnet20.network_surgery.yaml b/examples/network_surgery/resnet20.network_surgery.yaml
index 06644da5de7cd594fc29713b5890f0fd7203e4b7..e77b980142ac5dbc298163fdacbce67a321cae4e 100755
--- a/examples/network_surgery/resnet20.network_surgery.yaml
+++ b/examples/network_surgery/resnet20.network_surgery.yaml
@@ -31,7 +31,7 @@
 #     Total sparsity: 69.1%
 #     # of parameters: 83,671
 #
-# time python3 compress_classifier.py --arch resnet20_cifar  ../../../data.cifar10 -p=50 --lr=0.01 --epochs=180 --compress=../network_surgery/resnet20.network_surgery.yaml -j=1 --deterministic  --validation-split=0 --resume=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --masks-sparsity --num-best-scores=10
+# time python3 compress_classifier.py --arch resnet20_cifar  ../../../data.cifar10 -p=50 --lr=0.01 --epochs=180 --compress=../network_surgery/resnet20.network_surgery.yaml --vs=0 --resume-from=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --reset-optimizer --masks-sparsity --num-best-scores=5
 #
 # Parameters:
 # +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
@@ -127,13 +127,13 @@ policies:
         mini_batch_pruning_frequency: 1
         mask_on_forward_only: True
         # use_double_copies: True
-    starting_epoch: 180
-    ending_epoch: 280
+    starting_epoch: 0
+    ending_epoch: 100
     frequency: 1
 
 
   - lr_scheduler:
       instance_name: training_lr
-    starting_epoch: 225
+    starting_epoch: 0
     ending_epoch: 400
     frequency: 1
diff --git a/examples/network_trimming/resnet56_cifar_activation_apoz.yaml b/examples/network_trimming/resnet56_cifar_activation_apoz.yaml
index cf6c7220d39d85ce95e1f933174ff774beeebbd7..c12f9f01f5a456ba80e113711e76f6a04bb9aea2 100755
--- a/examples/network_trimming/resnet56_cifar_activation_apoz.yaml
+++ b/examples/network_trimming/resnet56_cifar_activation_apoz.yaml
@@ -14,7 +14,7 @@
 #     Total MACs: 78,856,832
 #
 #
-# time python3 compress_classifier.py -a=resnet56_cifar -p=50 ../../../data.cifar10 --epochs=70 --lr=0.1 --compress=../network_trimming/resnet56_cifar_activation_apoz.yaml --resume=checkpoint.resnet56_cifar_baseline.pth.tar -j=1 --deterministic --act-stats=valid
+# time python3 compress_classifier.py -a=resnet56_cifar -p=50 ../../../data.cifar10 --epochs=70 --lr=0.1 --compress=../network_trimming/resnet56_cifar_activation_apoz.yaml --resume-from=checkpoint.resnet56_cifar_baseline.pth.tar --reset-optimizer --act-stats=valid
 #
 # Parameters:
 # +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
@@ -167,35 +167,34 @@ lr_schedulers:
 policies:
   - pruner:
       instance_name: filter_pruner_60
-    starting_epoch: 181
-    ending_epoch: 200
+    starting_epoch: 1
+    ending_epoch: 20
     frequency: 2
 
   - pruner:
       instance_name: filter_pruner_50
-    starting_epoch: 181
-    ending_epoch: 200
+    starting_epoch: 1
+    ending_epoch: 20
     frequency: 2
 
   - pruner:
       instance_name: filter_pruner_30
-    starting_epoch: 181
-    ending_epoch: 200
+    starting_epoch: 1
+    ending_epoch: 20
     frequency: 2
 
   - pruner:
       instance_name: filter_pruner_10
-    starting_epoch: 181
-    ending_epoch: 200
+    starting_epoch: 1
+    ending_epoch: 20
     frequency: 2
 
   - extension:
       instance_name: net_thinner
-    epochs: [200]
-    #epochs: [181]
+    epochs: [20]
 
   - lr_scheduler:
       instance_name: exp_finetuning_lr
-    starting_epoch: 190
+    starting_epoch: 10
     ending_epoch: 300
     frequency: 1
diff --git a/examples/network_trimming/resnet56_cifar_activation_apoz_v2.yaml b/examples/network_trimming/resnet56_cifar_activation_apoz_v2.yaml
index 4f65e441bb6afa19cf1b8de8ece040fed2272c0e..bf6d078200696cece0291a7088b488c997e04dd5 100755
--- a/examples/network_trimming/resnet56_cifar_activation_apoz_v2.yaml
+++ b/examples/network_trimming/resnet56_cifar_activation_apoz_v2.yaml
@@ -14,7 +14,7 @@
 #     Total MACs: 67,797,632
 #
 #
-# time python3 compress_classifier.py -a=resnet56_cifar -p=50 ../../../data.cifar10 --epochs=70 --lr=0.1 --compress=../network_trimming/resnet56_cifar_activation_apoz_v2.yaml --resume=checkpoint.resnet56_cifar_baseline.pth.tar -j=1 --deterministic --act-stats=valid
+# time python3 compress_classifier.py -a=resnet56_cifar -p=50 ../../../data.cifar10 --epochs=70 --lr=0.1 --compress=../network_trimming/resnet56_cifar_activation_apoz_v2.yaml --resume-from=checkpoint.resnet56_cifar_baseline.pth.tar --reset-optimizer --act-stats=valid
 #
 # Parameters:
 # +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
@@ -168,34 +168,34 @@ lr_schedulers:
 policies:
   - pruner:
       instance_name: filter_pruner_60
-    starting_epoch: 181
-    ending_epoch: 200
+    starting_epoch: 1
+    ending_epoch: 20
     frequency: 2
 
   - pruner:
       instance_name: filter_pruner_50
-    starting_epoch: 181
-    ending_epoch: 200
+    starting_epoch: 1
+    ending_epoch: 20
     frequency: 2
 
   - pruner:
       instance_name: filter_pruner_30
-    starting_epoch: 181
-    ending_epoch: 200
+    starting_epoch: 1
+    ending_epoch: 20
     frequency: 2
 
   - pruner:
       instance_name: filter_pruner_10
-    starting_epoch: 181
-    ending_epoch: 200
+    starting_epoch: 1
+    ending_epoch: 20
     frequency: 2
 
   - extension:
       instance_name: net_thinner
-    epochs: [200]
+    epochs: [20]
 
   - lr_scheduler:
       instance_name: exp_finetuning_lr
-    starting_epoch: 190
+    starting_epoch: 10
     ending_epoch: 300
     frequency: 1
diff --git a/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_channel_rank.yaml b/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_channel_rank.yaml
index 4d934c5bd8867809d2aebaf240d6cb9853e10cbc..68c760456eaee8ff59ad721ca261b0fe3e231215 100755
--- a/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_channel_rank.yaml
+++ b/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_channel_rank.yaml
@@ -5,7 +5,7 @@
 # However, instead of one-shot filter ranking and pruning, we perform one-shot channel ranking and
 # pruning, using L1-magnitude of the structures.
 #
-# time python3 compress_classifier.py -a=resnet56_cifar -p=50 ../../../data.cifar10 --epochs=70 --lr=0.1 --compress=../pruning_filters_for_efficient_convnets/resnet56_cifar_channel_rank.yaml --resume=checkpoint.resnet56_cifar_baseline.pth.tar -j=1 --deterministic
+# time python3 compress_classifier.py -a=resnet56_cifar -p=50 ../../../data.cifar10 --epochs=70 --lr=0.1 --compress=../pruning_filters_for_efficient_convnets/resnet56_cifar_channel_rank.yaml --resume-from=checkpoint.resnet56_cifar_baseline.pth.tar --reset-optimizer --vs=0
 #
 # Baseline results:
 #     Top1: 92.850    Top5: 99.780    Loss: 0.464
@@ -164,26 +164,26 @@ lr_schedulers:
 policies:
   - pruner:
       instance_name: filter_pruner_70
-    epochs: [180]
+    epochs: [0]
 
   - pruner:
       instance_name: filter_pruner_60
-    epochs: [180]
+    epochs: [0]
 
   - pruner:
       instance_name: filter_pruner_40
-    epochs: [180]
+    epochs: [0]
 
   - pruner:
       instance_name: filter_pruner_20
-    epochs: [180]
+    epochs: [0]
 
   - extension:
       instance_name: net_thinner
-    epochs: [180]
+    epochs: [0]
 
   - lr_scheduler:
       instance_name: exp_finetuning_lr
-    starting_epoch: 190
+    starting_epoch: 10
     ending_epoch: 300
     frequency: 1
diff --git a/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank.yaml b/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank.yaml
index d4dce9622f98f6d4d664f0a241c336f52c4a1a80..e52adf6700b90cba9fb89617fcb751f9dcb9a054 100755
--- a/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank.yaml
+++ b/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank.yaml
@@ -14,7 +14,7 @@
 # You may either train this model from scratch, or download it from the link below.
 # https://s3-us-west-1.amazonaws.com/nndistiller/pruning_filters_for_efficient_convnets/checkpoint.resnet56_cifar_baseline.pth.tar
 #
-# time python3 compress_classifier.py -a=resnet56_cifar -p=50 ../../../data.cifar10 --epochs=70 --lr=0.1 --compress=../pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank.yaml --resume=checkpoint.resnet56_cifar_baseline.pth.tar -j=1 --deterministic
+# time python3 compress_classifier.py -a=resnet56_cifar -p=50 ../../../data.cifar10 --epochs=70 --lr=0.1 --compress=../pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank.yaml --resume-from=checkpoint.resnet56_cifar_baseline.pth.tar  --reset-optimizer --vs=0
 #
 # Results: 62.7% of the original convolution MACs (when calculated using direct convolution)
 #
@@ -175,26 +175,26 @@ lr_schedulers:
 policies:
   - pruner:
       instance_name: filter_pruner_60
-    epochs: [180]
+    epochs: [0]
 
   - pruner:
       instance_name: filter_pruner_50
-    epochs: [180]
+    epochs: [0]
 
   - pruner:
       instance_name: filter_pruner_30
-    epochs: [180]
+    epochs: [0]
 
   - pruner:
       instance_name: filter_pruner_10
-    epochs: [180]
+    epochs: [0]
 
   - extension:
       instance_name: net_thinner
-    epochs: [180]
+    epochs: [0]
 
   - lr_scheduler:
       instance_name: exp_finetuning_lr
-    starting_epoch: 190
+    starting_epoch: 10
     ending_epoch: 300
     frequency: 1
diff --git a/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank_v2.yaml b/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank_v2.yaml
index 6e00522e43dd568918e08625ede4baba212929b8..58ac2e87b66f9963df2870225789d0cf156dcdf8 100755
--- a/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank_v2.yaml
+++ b/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank_v2.yaml
@@ -15,7 +15,7 @@
 # You may either train this model from scratch, or download it from the link below.
 # https://s3-us-west-1.amazonaws.com/nndistiller/pruning_filters_for_efficient_convnets/checkpoint.resnet56_cifar_baseline.pth.tar
 #
-# time python3 compress_classifier.py -a=resnet56_cifar -p=50 ../../../data.cifar10 --epochs=70 --lr=0.1 --compress=../pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank_v2.yaml --resume=checkpoint.resnet56_cifar_baseline.pth.tar -j=1 --deterministic
+# time python3 compress_classifier.py -a=resnet56_cifar -p=50 ../../../data.cifar10 --epochs=70 --lr=0.1 --compress=../pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank_v2.yaml --resume-from=checkpoint.resnet56_cifar_baseline.pth.tar --reset-optimizer --vs=0
 #
 # Results: 53.9% (1.85x) of the original convolution MACs (when calculated using direct convolution)
 #
@@ -177,26 +177,26 @@ lr_schedulers:
 policies:
   - pruner:
       instance_name: filter_pruner_70
-    epochs: [180]
+    epochs: [0]
 
   - pruner:
       instance_name: filter_pruner_60
-    epochs: [180]
+    epochs: [0]
 
   - pruner:
       instance_name: filter_pruner_40
-    epochs: [180]
+    epochs: [0]
 
   - pruner:
       instance_name: filter_pruner_20
-    epochs: [180]
+    epochs: [0]
 
   - extension:
       instance_name: net_thinner
-    epochs: [180]
+    epochs: [0]
 
   - lr_scheduler:
       instance_name: exp_finetuning_lr
-    starting_epoch: 190
+    starting_epoch: 10
     ending_epoch: 300
     frequency: 1
diff --git a/examples/ssl/ssl_channels-removal_finetuning.yaml b/examples/ssl/ssl_channels-removal_finetuning.yaml
index 549629cf20cd4a102b1610fd29c6ea46b5e6fa59..89f65b88b168cc3a11913ac236fb04aa96e0a3e7 100755
--- a/examples/ssl/ssl_channels-removal_finetuning.yaml
+++ b/examples/ssl/ssl_channels-removal_finetuning.yaml
@@ -3,7 +3,7 @@
 #
 # We save the output (i.e. checkpoint.pth.tar) in ../ssl/checkpoints/checkpoint_trained_channel_regularized_resnet20_finetuned.pth.tar
 #
-# time python3 compress_classifier.py --arch resnet20_cifar  ../../../data.cifar10 -p=50 --lr=0.3 --epochs=98 --compress=../ssl/ssl_channels-removal_finetuning.yaml -j=1 --deterministic --resume=../ssl/checkpoints/checkpoint_trained_channel_regularized_resnet20.pth.tar
+# time python3 compress_classifier.py --arch resnet20_cifar  ../../../data.cifar10 -p=50 --lr=0.3 --epochs=98 --compress=../ssl/ssl_channels-removal_finetuning.yaml --resume-from=../ssl/checkpoints/checkpoint_trained_channel_regularized_resnet20.pth.tar --reset-optimizer
 #
 # Parameters:
 # +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
@@ -58,6 +58,6 @@ lr_schedulers:
 policies:
   - lr_scheduler:
       instance_name: training_lr
-    starting_epoch: 45
+    starting_epoch: 0
     ending_epoch: 300
     frequency: 1
diff --git a/examples/ssl/ssl_channels-removal_finetuning_x1.8.yaml b/examples/ssl/ssl_channels-removal_finetuning_x1.8.yaml
index 195b09865815e871343407f6e14c74ff4eb8729b..4e96499b9000556c64c201d8568985d8746f8062 100755
--- a/examples/ssl/ssl_channels-removal_finetuning_x1.8.yaml
+++ b/examples/ssl/ssl_channels-removal_finetuning_x1.8.yaml
@@ -7,7 +7,7 @@
 #
 # Total MACs: 22,583,936 == 55.3% compute density
 #
-# time python3 compress_classifier.py --arch resnet20_cifar  ../../../data.cifar10 -p=50 --lr=0.2 --epochs=98 --compress=../ssl/ssl_channels-removal_finetuning.yaml -j=1 --deterministic --resume=<enter-your-model>
+# time python3 compress_classifier.py --arch resnet20_cifar  ../../../data.cifar10 -p=50 --lr=0.2 --epochs=98 --compress=../ssl/ssl_channels-removal_finetuning_x1.8.yaml --reset-optimizer --resume-from=../ssl/checkpoints/checkpoint_trained_channel_regularized_resnet20.pth.tar
 #
 # Parameters:
 # +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
@@ -66,6 +66,6 @@ lr_schedulers:
 policies:
   - lr_scheduler:
       instance_name: training_lr
-    starting_epoch: 45
+    starting_epoch: 0
     ending_epoch: 300
     frequency: 1
diff --git a/examples/ssl/ssl_filter-removal_training.yaml b/examples/ssl/ssl_filter-removal_training.yaml
index 84c30e861bca8c8f147c5d124947e6804471dec7..7dade219c405e4e2468d3d8c12f58e015b28d28e 100755
--- a/examples/ssl/ssl_filter-removal_training.yaml
+++ b/examples/ssl/ssl_filter-removal_training.yaml
@@ -9,7 +9,7 @@
 # time python3 compress_classifier.py --arch resnet20_cifar  ../../../data.cifar10 -p=50 --lr=0.3 --epochs=180 --compress=../ssl/ssl_filter-removal_training.yaml -j=1 --deterministic --name="filters"
 #
 # To fine-tune:
-# time python3 compress_classifier.py --arch resnet20_cifar  ../../../data.cifar10 -p=50 --lr=0.2 --epochs=98 --compress=../ssl/ssl_channels-removal_finetuning.yaml -j=1 --deterministic --resume=...
+# time python3 compress_classifier.py --arch resnet20_cifar  ../../../data.cifar10 -p=50 --lr=0.2 --epochs=98 --compress=../ssl/ssl_channels-removal_finetuning.yaml --reset-optimizer --resume-from=...
 #
 # Parameters:
 # +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
diff --git a/examples/ssl/vgg16_cifar_ssl_channels_training.yaml b/examples/ssl/vgg16_cifar_ssl_channels_training.yaml
index 9b7155bfab9b1355b207b66230c7744ca03e7c2b..736bc8b523781a2df3520352581e083ada00c329 100755
--- a/examples/ssl/vgg16_cifar_ssl_channels_training.yaml
+++ b/examples/ssl/vgg16_cifar_ssl_channels_training.yaml
@@ -5,7 +5,7 @@
 # time python3 compress_classifier.py --arch vgg16_cifar  ../../../data.cifar10 -p=50 --lr=0.05 --epochs=180 --compress=../ssl/vgg16_cifar_ssl_channels_training.yaml -j=1 --deterministic
 #
 # The results below are from the SSL training session, and you can follow-up with some fine-tuning:
-# time python3 compress_classifier.py --arch vgg16_cifar  ../../../data.cifar10 --resume=checkpoint.vgg16_cifar.pth.tar --lr=0.01 --epochs=20
+# time python3 compress_classifier.py --arch vgg16_cifar  ../../../data.cifar10 --resume-from=checkpoint.vgg16_cifar.pth.tar  --reset-optimizer --lr=0.01 --epochs=20
 # ==> Top1: 91.010    Top5: 99.480    Loss: 0.513
 #
 # Parameters:
@@ -74,14 +74,13 @@ extensions:
 policies:
   - lr_scheduler:
       instance_name: training_lr
-    starting_epoch: 45
+    starting_epoch: 0
     ending_epoch: 300
     frequency: 1
 
-# After completeing the regularization, we perform network thinning and exit.
   - extension:
       instance_name: net_thinner
-    epochs: [179]
+    epochs: [0]
 
   - regularizer:
       instance_name: Channels_groups_regularizer
diff --git a/tests/checkpoints/resnet20_cifar10_checkpoint.pth.tar b/tests/checkpoints/resnet20_cifar10_checkpoint.pth.tar
new file mode 100644
index 0000000000000000000000000000000000000000..d15a84f847090d934bf91d19736ff577a2752476
Binary files /dev/null and b/tests/checkpoints/resnet20_cifar10_checkpoint.pth.tar differ
diff --git a/tests/test_infra.py b/tests/test_infra.py
index 2d0524f5bb77cb9d50195e304b1a1d4a1c9e4e1e..478ad3e1b9397d7f3daa1d1d6f22e26130159b24 100755
--- a/tests/test_infra.py
+++ b/tests/test_infra.py
@@ -28,7 +28,7 @@ except ImportError:
         sys.path.append(module_path)
     import distiller
 import distiller
-from distiller.apputils import save_checkpoint, load_checkpoint
+from distiller.apputils import save_checkpoint, load_checkpoint, load_lean_checkpoint
 from distiller.models import create_model
 import pretrainedmodels
 
@@ -66,30 +66,65 @@ def test_create_model_pretrainedmodels():
         model = create_model(False, 'imagenet', 'no_such_model!')
 
 
+def _is_similar_param_groups(opt_a, opt_b):
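+    """Compare the hyper-parameters of the first param_group of two optimizer state-dicts, ignoring 'params'."""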
+    for k in opt_a['param_groups'][0]:
+        val_a = opt_a['param_groups'][0][k]
+        val_b = opt_b['param_groups'][0][k]
+        if (val_a != val_b) and (k != 'params'):
+            return False
+    return True
+
+
 def test_load():
     logger = logging.getLogger('simple_example')
     logger.setLevel(logging.INFO)
 
-    model = create_model(False, 'cifar10', 'resnet20_cifar')
-    model, compression_scheduler, start_epoch = load_checkpoint(model, '../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar')
+    checkpoint_filename = 'checkpoints/resnet20_cifar10_checkpoint.pth.tar'
+    src_optimizer_state_dict = torch.load(checkpoint_filename)['optimizer_state_dict']
+
+    model = create_model(False, 'cifar10', 'resnet20_cifar', 0)
+    model, compression_scheduler, optimizer, start_epoch = load_checkpoint(
+        model, checkpoint_filename)
     assert compression_scheduler is not None
-    assert start_epoch == 180
+    assert optimizer is not None, 'Failed to load the optimizer'
+    if not _is_similar_param_groups(src_optimizer_state_dict, optimizer.state_dict()):
+        assert src_optimizer_state_dict == optimizer.state_dict()  # expected to fail; surfaces the mismatch in the test output
+    assert start_epoch == 1
+
+
+def test_load_state_dict_implicit():
+    # prepare lean checkpoint
+    state_dict_arrays = torch.load('../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar').get('state_dict')
+
+    with tempfile.NamedTemporaryFile() as tmpfile:
+        torch.save({'state_dict': state_dict_arrays}, tmpfile.name)
+        model = create_model(False, 'cifar10', 'resnet20_cifar')
+        with pytest.raises(KeyError):
+            model, compression_scheduler, optimizer, start_epoch = load_checkpoint(model, tmpfile.name)
 
 
-def test_load_state_dict():
+def test_load_lean_checkpoint_1():
     # prepare lean checkpoint
     state_dict_arrays = torch.load('../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar').get('state_dict')
 
     with tempfile.NamedTemporaryFile() as tmpfile:
         torch.save({'state_dict': state_dict_arrays}, tmpfile.name)
         model = create_model(False, 'cifar10', 'resnet20_cifar')
-        model, compression_scheduler, start_epoch = load_checkpoint(model, tmpfile.name)
+        model, compression_scheduler, optimizer, start_epoch = load_checkpoint(
+            model, tmpfile.name, lean_checkpoint=True)
 
-    assert len(list(model.named_modules())) >= len([x for x in state_dict_arrays if x.endswith('weight')]) > 0
     assert compression_scheduler is None
+    assert optimizer is None
     assert start_epoch == 0
 
 
+def test_load_lean_checkpoint_2():
+    checkpoint_filename = '../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar'
+
+    model = create_model(False, 'cifar10', 'resnet20_cifar', 0)
+    model = load_lean_checkpoint(model, checkpoint_filename)
+
+
 def test_load_dumb_checkpoint():
     # prepare lean checkpoint
     state_dict_arrays = torch.load('../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar').get('state_dict')
@@ -98,24 +133,42 @@ def test_load_dumb_checkpoint():
         torch.save(state_dict_arrays, tmpfile.name)
         model = create_model(False, 'cifar10', 'resnet20_cifar')
         with pytest.raises(ValueError):
-            model, compression_scheduler, start_epoch = load_checkpoint(model, tmpfile.name)
+            model, compression_scheduler, optimizer, start_epoch = load_checkpoint(model, tmpfile.name)
 
 
 def test_load_negative():
     with pytest.raises(FileNotFoundError):
         model = create_model(False, 'cifar10', 'resnet20_cifar')
-        model, compression_scheduler, start_epoch = load_checkpoint(model, 'THIS_IS_AN_ERROR/checkpoint_trained_dense.pth.tar')
+        model, compression_scheduler, optimizer, start_epoch = load_checkpoint(model,
+            'THIS_IS_AN_ERROR/checkpoint_trained_dense.pth.tar')
 
 
 def test_load_gpu_model_on_cpu():
     # Issue #148
     CPU_DEVICE_ID = -1
+    CPU_DEVICE_NAME = 'cpu'
+    checkpoint_filename = 'checkpoints/resnet20_cifar10_checkpoint.pth.tar'
+
     model = create_model(False, 'cifar10', 'resnet20_cifar', device_ids=CPU_DEVICE_ID)
-    model, compression_scheduler, start_epoch = load_checkpoint(model,
-                                                                '../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar')
+    model, compression_scheduler, optimizer, start_epoch = load_checkpoint(
+        model, checkpoint_filename)
+
     assert compression_scheduler is not None
-    assert start_epoch == 180
-    assert distiller.model_device(model) == 'cpu'
+    assert optimizer is not None
+    assert distiller.utils.optimizer_device_name(optimizer) == CPU_DEVICE_NAME
+    assert start_epoch == 1
+    assert distiller.model_device(model) == CPU_DEVICE_NAME
+
+
+def test_load_gpu_model_on_cpu_lean_checkpoint():
+    CPU_DEVICE_ID = -1
+    CPU_DEVICE_NAME = 'cpu'
+    checkpoint_filename = '../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar'
+
+    model = create_model(False, 'cifar10', 'resnet20_cifar', device_ids=CPU_DEVICE_ID)
+    model = load_lean_checkpoint(model, checkpoint_filename,
+                                 model_device=CPU_DEVICE_NAME)
+    assert distiller.model_device(model) == CPU_DEVICE_NAME
 
 
 def test_load_gpu_model_on_cpu_with_thinning():
@@ -137,13 +190,10 @@ def test_load_gpu_model_on_cpu_with_thinning():
     distiller.remove_filters(gpu_model, zeros_mask_dict, 'resnet20_cifar', 'cifar10', optimizer=None)
     assert hasattr(gpu_model, 'thinning_recipes')
     scheduler = distiller.CompressionScheduler(gpu_model)
-    save_checkpoint(epoch=0, arch='resnet20_cifar', model=gpu_model, scheduler=scheduler, optimizer=None)
+    save_checkpoint(epoch=0, arch='resnet20_cifar', model=gpu_model, scheduler=scheduler, optimizer=None,
+        dir='checkpoints')
 
     CPU_DEVICE_ID = -1
     cpu_model = create_model(False, 'cifar10', 'resnet20_cifar', device_ids=CPU_DEVICE_ID)
-    load_checkpoint(cpu_model, "checkpoint.pth.tar")
+    load_lean_checkpoint(cpu_model, "checkpoints/checkpoint.pth.tar")
     assert distiller.model_device(cpu_model) == 'cpu'
-
-
-if __name__ == '__main__':
-    test_load_gpu_model_on_cpu()
diff --git a/tests/test_pruning.py b/tests/test_pruning.py
index 8b9c4ddf45e088a87d5b3cfa0cb0e033347eb811..443e45240083ad88108ebd4940764244c6a9c6e9 100755
--- a/tests/test_pruning.py
+++ b/tests/test_pruning.py
@@ -23,7 +23,7 @@ import distiller
 import common
 import pytest
 from distiller.models import create_model
-from distiller.apputils import save_checkpoint, load_checkpoint
+from distiller.apputils import save_checkpoint, load_checkpoint, load_lean_checkpoint
 
 # Logging configuration
 logging.basicConfig(level=logging.INFO)
@@ -296,7 +296,7 @@ def arbitrary_channel_pruning(config, channels_to_remove, is_parallel):
     conv2 = common.find_module_by_name(model_2, pair[1])
     assert conv2 is not None
     with pytest.raises(KeyError):
-        model_2, compression_scheduler, start_epoch = load_checkpoint(model_2, 'checkpoint.pth.tar')
+        model_2 = load_lean_checkpoint(model_2, 'checkpoint.pth.tar')
     compression_scheduler = distiller.CompressionScheduler(model)
     hasattr(model, 'thinning_recipes')
 
@@ -304,13 +304,13 @@ def arbitrary_channel_pruning(config, channels_to_remove, is_parallel):
 
     # (2)
     save_checkpoint(epoch=0, arch=config.arch, model=model, optimizer=None, scheduler=compression_scheduler)
-    model_2, compression_scheduler, start_epoch = load_checkpoint(model_2, 'checkpoint.pth.tar')
+    model_2 = load_lean_checkpoint(model_2, 'checkpoint.pth.tar')
     assert hasattr(model_2, 'thinning_recipes')
     logger.info("test_arbitrary_channel_pruning - Done")
 
     # (3)
     save_checkpoint(epoch=0, arch=config.arch, model=model_2, optimizer=None, scheduler=compression_scheduler)
-    model_2, compression_scheduler, start_epoch = load_checkpoint(model_2, 'checkpoint.pth.tar')
+    model_2 = load_lean_checkpoint(model_2, 'checkpoint.pth.tar')
     assert hasattr(model_2, 'thinning_recipes')
     logger.info("test_arbitrary_channel_pruning - Done 2")