From 78e98a51803e7119aa97eea52eca471255640bcd Mon Sep 17 00:00:00 2001
From: Neta Zmora <31280975+nzmora@users.noreply.github.com>
Date: Tue, 20 Nov 2018 15:50:06 +0200
Subject: [PATCH] Bug fix: Resuming from checkpoint ignored the masks stored
 in the checkpoint (#76)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

When we resume from a checkpoint, we usually want to continue using the
checkpoint's masks. I say "usually" because I can see a situation where we
want to prune a model and checkpoint it, and then resume with the intention
of fine-tuning w/o keeping the masks. This is what's done in Song Han's
Dense-Sparse-Dense (DSD) training (https://arxiv.org/abs/1607.04381). But I
didn't want to add another argument to ```compress_classifier.py``` for the
time being, so we ignore DSD.

There are two possible situations when we resume a checkpoint that has a
serialized ```CompressionScheduler``` with pruning masks:
1. We are planning on using a new ```CompressionScheduler``` that is defined
in a schedule YAML file. In this case, we want to copy the masks from the
serialized ```CompressionScheduler``` to the new ```CompressionScheduler```
that we are constructing from the YAML file. This is one fix.
2. We are resuming a checkpoint, but without using a YAML schedule file. In
this case we want to use the ```CompressionScheduler``` that we loaded from
the checkpoint file. All this ```CompressionScheduler``` does is keep
applying the masks as we train, so that we don't lose them. This is the
second fix.

For DSD, we would need a new flag that would override using the
```CompressionScheduler``` that we load from the checkpoint.
---
 distiller/config.py                              | 21 ++++++++++---------
 .../compress_classifier.py                       |  6 +++---
 2 files changed, 14 insertions(+), 13 deletions(-)

diff --git a/distiller/config.py b/distiller/config.py
index ad54cba..07a8f2f 100755
--- a/distiller/config.py
+++ b/distiller/config.py
@@ -49,10 +49,11 @@ msglogger = logging.getLogger()
 app_cfg_logger = logging.getLogger("app_cfg")
 
 
-def dict_config(model, optimizer, sched_dict):
+def dict_config(model, optimizer, sched_dict, scheduler=None):
     app_cfg_logger.debug('Schedule contents:\n' + json.dumps(sched_dict, indent=2))
 
-    schedule = distiller.CompressionScheduler(model)
+    if scheduler is None:
+        scheduler = distiller.CompressionScheduler(model)
 
     pruners = __factory('pruners', model, sched_dict)
     regularizers = __factory('regularizers', model, sched_dict)
@@ -106,7 +107,7 @@ def dict_config(model, optimizer, sched_dict):
             else:
                 raise ValueError("\nFATAL Parsing error while parsing the pruning schedule - unknown policy [%s]".format(policy_def))
 
-            add_policy_to_scheduler(policy, policy_def, schedule)
+            add_policy_to_scheduler(policy, policy_def, scheduler)
 
         # Any changes to the optmizer caused by a quantizer have occured by now, so safe to create LR schedulers
         lr_schedulers = __factory('lr_schedulers', model, sched_dict, optimizer=optimizer)
@@ -116,7 +117,7 @@ def dict_config(model, optimizer, sched_dict):
                                                              instance_name)
             lr_scheduler = lr_schedulers[instance_name]
             policy = distiller.LRPolicy(lr_scheduler)
-            add_policy_to_scheduler(policy, policy_def, schedule)
+            add_policy_to_scheduler(policy, policy_def, scheduler)
 
     except AssertionError:
         # propagate the assertion information
@@ -125,25 +126,25 @@ def dict_config(model, optimizer, sched_dict):
         print("\nFATAL Parsing error!\n%s" % json.dumps(policy_def, indent=1))
         print("Exception: %s %s" % (type(exception), exception))
         raise
-    return schedule
+    return scheduler
 
 
-def add_policy_to_scheduler(policy, policy_def, schedule):
+def add_policy_to_scheduler(policy, policy_def, scheduler):
     if 'epochs' in policy_def:
-        schedule.add_policy(policy, epochs=policy_def['epochs'])
+        scheduler.add_policy(policy, epochs=policy_def['epochs'])
     else:
-        schedule.add_policy(policy, starting_epoch=policy_def['starting_epoch'],
+        scheduler.add_policy(policy, starting_epoch=policy_def['starting_epoch'],
                             ending_epoch=policy_def['ending_epoch'],
                             frequency=policy_def['frequency'])
 
 
-def file_config(model, optimizer, filename):
+def file_config(model, optimizer, filename, scheduler=None):
     """Read the schedule from file"""
     with open(filename, 'r') as stream:
         msglogger.info('Reading compression schedule from: %s', filename)
         try:
             sched_dict = yaml_ordered_load(stream)
-            return dict_config(model, optimizer, sched_dict)
+            return dict_config(model, optimizer, sched_dict, scheduler)
         except yaml.YAMLError as exc:
             print("\nFATAL parsing error while parsing the schedule configuration file %s" % filename)
             raise
diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py
index ade512d..7c05154 100755
--- a/examples/classifier_compression/compress_classifier.py
+++ b/examples/classifier_compression/compress_classifier.py
@@ -154,7 +154,7 @@ parser.add_argument('--num-best-scores', dest='num_best_scores', default=1, type
                     help='number of best scores to track and report (default: 1)')
 parser.add_argument('--load-serialized', dest='load_serialized', action='store_true', default=False,
                     help='Load a model without DataParallel wrapping it')
-                    
+
 quant_group = parser.add_argument_group('Arguments controlling quantization at evaluation time'
                                         '("post-training quantization)')
 quant_group.add_argument('--quantize-eval', '--qe', action='store_true',
@@ -343,10 +343,10 @@ def main():
     if args.compress:
         # The main use-case for this sample application is CNN compression. Compression
         # requires a compression schedule configuration file in YAML.
-        compression_scheduler = distiller.file_config(model, optimizer, args.compress)
+        compression_scheduler = distiller.file_config(model, optimizer, args.compress, compression_scheduler)
         # Model is re-transferred to GPU in case parameters were added (e.g. PACTQuantizer)
         model.cuda()
-    else:
+    elif compression_scheduler is None:
         compression_scheduler = distiller.CompressionScheduler(model)
 
     args.kd_policy = None
-- 
GitLab
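
Below is a minimal sketch (not part of the patch) of the resume flow that the new ```scheduler``` argument enables. The ```load_scheduler_from_checkpoint``` helper, the checkpoint key names, and the file paths are assumptions for illustration only; the real sample app goes through ```distiller.apputils```, whose signatures may differ across versions. ```distiller.file_config``` and ```distiller.CompressionScheduler``` are used the same way as in the patched ```compress_classifier.py```.

```python
# Sketch of resuming training while keeping the checkpoint's pruning masks.
import torch
import torchvision
import distiller

model = torchvision.models.resnet18()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

resumed_checkpoint = 'checkpoint.pth.tar'   # placeholder path
schedule_yaml = 'pruning_schedule.yaml'     # placeholder YAML schedule; may be None


def load_scheduler_from_checkpoint(model, path):
    """Stand-in for distiller's checkpoint loading; the key names are assumptions."""
    chkpt = torch.load(path, map_location='cpu')
    model.load_state_dict(chkpt['state_dict'])
    scheduler = distiller.CompressionScheduler(model)
    # Restore the serialized masks so they keep being applied during training.
    scheduler.load_state_dict(chkpt['compression_sched'])
    return scheduler


compression_scheduler = load_scheduler_from_checkpoint(model, resumed_checkpoint)

if schedule_yaml:
    # Case 1 from the commit message: build policies from the YAML file, but
    # pass in the loaded scheduler so its masks are reused instead of being
    # replaced by a fresh, mask-less CompressionScheduler.
    compression_scheduler = distiller.file_config(model, optimizer, schedule_yaml,
                                                  compression_scheduler)
elif compression_scheduler is None:
    # Case 2: no YAML schedule and nothing loaded from the checkpoint --
    # fall back to an empty scheduler. If one *was* loaded, keep it so the
    # masks continue to be applied as training proceeds.
    compression_scheduler = distiller.CompressionScheduler(model)
```

In both cases the scheduler that reaches the training loop is the one carrying the checkpoint's masks, which is exactly the behavior the two fixes above restore.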