Skip to content
Snippets Groups Projects
Unverified Commit 78e98a51 authored by Neta Zmora's avatar Neta Zmora Committed by GitHub
Browse files

Bug fix: Resuming from checkpoint ignored the masks stored in the checkpoint (#76)

When we resume from a checkpoint, we usually want to continue using the checkpoint’s
masks.  I say “usually” because I can see a situation where we want to prune a model
and checkpoint it, and then resume with the intention of fine-tuning w/o keeping the
masks.  This is what’s done in Song Han’s Dense-Sparse-Dense (DSD) training
(https://arxiv.org/abs/1607.04381).  But I didn’t want to add another argument to
```compress_classifier.py``` for the time being – so we ignore DSD.

There are two possible situations when we resume a checkpoint that has a serialized
```CompressionScheduler``` with pruning masks:
1. We are planning on using a new ```CompressionScheduler``` that is defined in a
schedule YAML file.  In this case, we want to copy the masks from the serialized
```CompressionScheduler``` to the new ```CompressionScheduler``` that we are
constructing from the YAML file.  This is one fix.
2. We are resuming a checkpoint, but without using a YAML schedule file.
In this case we want to use the ```CompressionScheduler``` that we loaded from the
checkpoint file.  All this ```CompressionScheduler``` does is keep applying the masks
as we train, so that we don’t lose them.  This is the second fix.

For DSD, we would need a new flag that would override using the ```CompressionScheduler```
that we load from the checkpoint.
parent ff6985ad
No related branches found
No related tags found
No related merge requests found
......@@ -49,10 +49,11 @@ msglogger = logging.getLogger()
app_cfg_logger = logging.getLogger("app_cfg")
def dict_config(model, optimizer, sched_dict):
def dict_config(model, optimizer, sched_dict, scheduler=None):
app_cfg_logger.debug('Schedule contents:\n' + json.dumps(sched_dict, indent=2))
schedule = distiller.CompressionScheduler(model)
if scheduler is None:
scheduler = distiller.CompressionScheduler(model)
pruners = __factory('pruners', model, sched_dict)
regularizers = __factory('regularizers', model, sched_dict)
......@@ -106,7 +107,7 @@ def dict_config(model, optimizer, sched_dict):
else:
raise ValueError("\nFATAL Parsing error while parsing the pruning schedule - unknown policy [%s]".format(policy_def))
add_policy_to_scheduler(policy, policy_def, schedule)
add_policy_to_scheduler(policy, policy_def, scheduler)
# Any changes to the optimizer caused by a quantizer have occurred by now, so safe to create LR schedulers
lr_schedulers = __factory('lr_schedulers', model, sched_dict, optimizer=optimizer)
......@@ -116,7 +117,7 @@ def dict_config(model, optimizer, sched_dict):
instance_name)
lr_scheduler = lr_schedulers[instance_name]
policy = distiller.LRPolicy(lr_scheduler)
add_policy_to_scheduler(policy, policy_def, schedule)
add_policy_to_scheduler(policy, policy_def, scheduler)
except AssertionError:
# propagate the assertion information
......@@ -125,25 +126,25 @@ def dict_config(model, optimizer, sched_dict):
print("\nFATAL Parsing error!\n%s" % json.dumps(policy_def, indent=1))
print("Exception: %s %s" % (type(exception), exception))
raise
return schedule
return scheduler
def add_policy_to_scheduler(policy, policy_def, scheduler):
    """Register `policy` on `scheduler` using the activation spec in `policy_def`.

    Args:
        policy: the policy object to schedule (e.g. a pruning policy or an
            LRPolicy as built elsewhere in this module).
        policy_def: dict parsed from the schedule YAML.  It carries either an
            explicit 'epochs' list, or a ('starting_epoch', 'ending_epoch',
            'frequency') triple describing a periodic activation range.
        scheduler: the CompressionScheduler that will apply the policy.
    """
    if 'epochs' in policy_def:
        # An explicit epoch list takes precedence over a range specification.
        scheduler.add_policy(policy, epochs=policy_def['epochs'])
    else:
        scheduler.add_policy(policy,
                             starting_epoch=policy_def['starting_epoch'],
                             ending_epoch=policy_def['ending_epoch'],
                             frequency=policy_def['frequency'])
def file_config(model, optimizer, filename, scheduler=None):
    """Read the compression schedule from a YAML file and build a scheduler.

    Args:
        model: the model the schedule applies to.
        optimizer: passed through to `dict_config` (used when constructing
            LR schedulers).
        filename: path of the YAML schedule file.
        scheduler: optional pre-existing CompressionScheduler to populate —
            e.g. one deserialized from a checkpoint, so that its pruning
            masks are kept.  When None, `dict_config` creates a fresh one.

    Returns:
        The CompressionScheduler returned by `dict_config`.

    Raises:
        yaml.YAMLError: if the schedule file cannot be parsed.
    """
    with open(filename, 'r') as stream:
        msglogger.info('Reading compression schedule from: %s', filename)
        try:
            sched_dict = yaml_ordered_load(stream)
            return dict_config(model, optimizer, sched_dict, scheduler)
        except yaml.YAMLError:
            # Surface the offending file name before re-raising for the caller.
            print("\nFATAL parsing error while parsing the schedule configuration file %s" % filename)
            raise
......
......@@ -154,7 +154,7 @@ parser.add_argument('--num-best-scores', dest='num_best_scores', default=1, type
help='number of best scores to track and report (default: 1)')
parser.add_argument('--load-serialized', dest='load_serialized', action='store_true', default=False,
help='Load a model without DataParallel wrapping it')
quant_group = parser.add_argument_group('Arguments controlling quantization at evaluation time'
'("post-training quantization)')
quant_group.add_argument('--quantize-eval', '--qe', action='store_true',
......@@ -343,10 +343,10 @@ def main():
if args.compress:
# The main use-case for this sample application is CNN compression. Compression
# requires a compression schedule configuration file in YAML.
compression_scheduler = distiller.file_config(model, optimizer, args.compress)
compression_scheduler = distiller.file_config(model, optimizer, args.compress, compression_scheduler)
# Model is re-transferred to GPU in case parameters were added (e.g. PACTQuantizer)
model.cuda()
else:
elif compression_scheduler is None:
compression_scheduler = distiller.CompressionScheduler(model)
args.kd_policy = None
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment