From 78e98a51803e7119aa97eea52eca471255640bcd Mon Sep 17 00:00:00 2001
From: Neta Zmora <31280975+nzmora@users.noreply.github.com>
Date: Tue, 20 Nov 2018 15:50:06 +0200
Subject: [PATCH] Bug fix: Resuming from checkpoint ignored the masks stored in
 the checkpoint (#76)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

When we resume from a checkpoint, we usually want to continue using the checkpoint’s
masks.  I say “usually” because I can see a situation where we want to prune a model
and checkpoint it, and then resume with the intention of fine-tuning without keeping the
masks.  This is what’s done in Song Han’s Dense-Sparse-Dense (DSD) training
(https://arxiv.org/abs/1607.04381).  But I didn’t want to add another argument to
```compress_classifier.py``` for the time being, so we ignore DSD for now.

There are two possible situations when we resume a checkpoint that has a serialized
```CompressionScheduler``` with pruning masks:
1. We plan to use a new ```CompressionScheduler``` that is defined in a
schedule YAML file.  In this case, we want to copy the masks from the serialized
```CompressionScheduler``` to the new ```CompressionScheduler``` that we are
constructing from the YAML file.  This is the first fix.
2. We are resuming a checkpoint, but without using a YAML schedule file.
In this case we want to use the ```CompressionScheduler``` that we loaded from the
checkpoint file.  All this ```CompressionScheduler``` does is keep applying the masks
as we train, so that we don’t lose them.  This is the second fix.  Both cases are
sketched below.
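
The following is a condensed sketch of how the two cases resolve in
```compress_classifier.py``` after this patch; it assumes ```compression_scheduler```
already holds the scheduler deserialized from the checkpoint, or ```None``` if no
checkpoint (or no scheduler) was loaded:

```python
# Condensed sketch of the resume logic after this patch. `compression_scheduler`
# is assumed to hold the scheduler loaded from the checkpoint, or None.
if args.compress:
    # Case 1: a YAML schedule was given. Pass the loaded scheduler into
    # file_config() so new policies are added to it and its masks survive.
    compression_scheduler = distiller.file_config(model, optimizer, args.compress,
                                                  compression_scheduler)
elif compression_scheduler is None:
    # Case 2, but nothing was loaded from a checkpoint: fall back to an empty scheduler.
    # If the checkpoint did carry a scheduler, we keep it so it keeps applying its masks.
    compression_scheduler = distiller.CompressionScheduler(model)
```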

For DSD, we would need a new flag that tells the application not to use the
```CompressionScheduler``` loaded from the checkpoint.
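
As a rough illustration only (not part of this patch), such an opt-out could look like
the sketch below; the ```--discard-masks``` flag name is hypothetical:

```python
# Hypothetical sketch, not part of this patch: an opt-out flag for DSD-style
# fine-tuning that ignores the scheduler (and masks) loaded from the checkpoint.
parser.add_argument('--discard-masks', action='store_true', default=False,
                    help='ignore the CompressionScheduler (and pruning masks) '
                         'stored in the checkpoint (DSD-style fine-tuning)')

# ...later, after loading the checkpoint:
if args.discard_masks:
    compression_scheduler = None  # the logic above then builds a fresh scheduler
```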
---
 distiller/config.py                           | 21 ++++++++++---------
 .../compress_classifier.py                    |  6 +++---
 2 files changed, 14 insertions(+), 13 deletions(-)

diff --git a/distiller/config.py b/distiller/config.py
index ad54cba..07a8f2f 100755
--- a/distiller/config.py
+++ b/distiller/config.py
@@ -49,10 +49,11 @@ msglogger = logging.getLogger()
 app_cfg_logger = logging.getLogger("app_cfg")
 
 
-def dict_config(model, optimizer, sched_dict):
+def dict_config(model, optimizer, sched_dict, scheduler=None):
     app_cfg_logger.debug('Schedule contents:\n' + json.dumps(sched_dict, indent=2))
 
-    schedule = distiller.CompressionScheduler(model)
+    if scheduler is None:
+        scheduler = distiller.CompressionScheduler(model)
 
     pruners = __factory('pruners', model, sched_dict)
     regularizers = __factory('regularizers', model, sched_dict)
@@ -106,7 +107,7 @@ def dict_config(model, optimizer, sched_dict):
             else:
                 raise ValueError("\nFATAL Parsing error while parsing the pruning schedule - unknown policy [%s]".format(policy_def))
 
-            add_policy_to_scheduler(policy, policy_def, schedule)
+            add_policy_to_scheduler(policy, policy_def, scheduler)
 
         # Any changes to the optmizer caused by a quantizer have occured by now, so safe to create LR schedulers
         lr_schedulers = __factory('lr_schedulers', model, sched_dict, optimizer=optimizer)
@@ -116,7 +117,7 @@ def dict_config(model, optimizer, sched_dict):
                 instance_name)
             lr_scheduler = lr_schedulers[instance_name]
             policy = distiller.LRPolicy(lr_scheduler)
-            add_policy_to_scheduler(policy, policy_def, schedule)
+            add_policy_to_scheduler(policy, policy_def, scheduler)
 
     except AssertionError:
         # propagate the assertion information
@@ -125,25 +126,25 @@ def dict_config(model, optimizer, sched_dict):
         print("\nFATAL Parsing error!\n%s" % json.dumps(policy_def, indent=1))
         print("Exception: %s %s" % (type(exception), exception))
         raise
-    return schedule
+    return scheduler
 
 
-def add_policy_to_scheduler(policy, policy_def, schedule):
+def add_policy_to_scheduler(policy, policy_def, scheduler):
     if 'epochs' in policy_def:
-        schedule.add_policy(policy, epochs=policy_def['epochs'])
+        scheduler.add_policy(policy, epochs=policy_def['epochs'])
     else:
-        schedule.add_policy(policy, starting_epoch=policy_def['starting_epoch'],
+        scheduler.add_policy(policy, starting_epoch=policy_def['starting_epoch'],
                             ending_epoch=policy_def['ending_epoch'],
                             frequency=policy_def['frequency'])
 
 
-def file_config(model, optimizer, filename):
+def file_config(model, optimizer, filename, scheduler=None):
     """Read the schedule from file"""
     with open(filename, 'r') as stream:
         msglogger.info('Reading compression schedule from: %s', filename)
         try:
             sched_dict = yaml_ordered_load(stream)
-            return dict_config(model, optimizer, sched_dict)
+            return dict_config(model, optimizer, sched_dict, scheduler)
         except yaml.YAMLError as exc:
             print("\nFATAL parsing error while parsing the schedule configuration file %s" % filename)
             raise
diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py
index ade512d..7c05154 100755
--- a/examples/classifier_compression/compress_classifier.py
+++ b/examples/classifier_compression/compress_classifier.py
@@ -154,7 +154,7 @@ parser.add_argument('--num-best-scores', dest='num_best_scores', default=1, type
                     help='number of best scores to track and report (default: 1)')
 parser.add_argument('--load-serialized', dest='load_serialized', action='store_true', default=False,
                     help='Load a model without DataParallel wrapping it')
-                    
+
 quant_group = parser.add_argument_group('Arguments controlling quantization at evaluation time'
                                         '("post-training quantization)')
 quant_group.add_argument('--quantize-eval', '--qe', action='store_true',
@@ -343,10 +343,10 @@ def main():
     if args.compress:
         # The main use-case for this sample application is CNN compression. Compression
         # requires a compression schedule configuration file in YAML.
-        compression_scheduler = distiller.file_config(model, optimizer, args.compress)
+        compression_scheduler = distiller.file_config(model, optimizer, args.compress, compression_scheduler)
         # Model is re-transferred to GPU in case parameters were added (e.g. PACTQuantizer)
         model.cuda()
-    else:
+    elif compression_scheduler is None:
         compression_scheduler = distiller.CompressionScheduler(model)
 
     args.kd_policy = None
-- 
GitLab