diff --git a/distiller/config.py b/distiller/config.py index ad54cba42c2fca51a655894d2734f5bd56dc7639..07a8f2f1c9e7346a0802761880739ae30727faa2 100755 --- a/distiller/config.py +++ b/distiller/config.py @@ -49,10 +49,11 @@ msglogger = logging.getLogger() app_cfg_logger = logging.getLogger("app_cfg") -def dict_config(model, optimizer, sched_dict): +def dict_config(model, optimizer, sched_dict, scheduler=None): app_cfg_logger.debug('Schedule contents:\n' + json.dumps(sched_dict, indent=2)) - schedule = distiller.CompressionScheduler(model) + if scheduler is None: + scheduler = distiller.CompressionScheduler(model) pruners = __factory('pruners', model, sched_dict) regularizers = __factory('regularizers', model, sched_dict) @@ -106,7 +107,7 @@ def dict_config(model, optimizer, sched_dict): else: raise ValueError("\nFATAL Parsing error while parsing the pruning schedule - unknown policy [%s]".format(policy_def)) - add_policy_to_scheduler(policy, policy_def, schedule) + add_policy_to_scheduler(policy, policy_def, scheduler) # Any changes to the optmizer caused by a quantizer have occured by now, so safe to create LR schedulers lr_schedulers = __factory('lr_schedulers', model, sched_dict, optimizer=optimizer) @@ -116,7 +117,7 @@ def dict_config(model, optimizer, sched_dict): instance_name) lr_scheduler = lr_schedulers[instance_name] policy = distiller.LRPolicy(lr_scheduler) - add_policy_to_scheduler(policy, policy_def, schedule) + add_policy_to_scheduler(policy, policy_def, scheduler) except AssertionError: # propagate the assertion information @@ -125,25 +126,25 @@ def dict_config(model, optimizer, sched_dict): print("\nFATAL Parsing error!\n%s" % json.dumps(policy_def, indent=1)) print("Exception: %s %s" % (type(exception), exception)) raise - return schedule + return scheduler -def add_policy_to_scheduler(policy, policy_def, schedule): +def add_policy_to_scheduler(policy, policy_def, scheduler): if 'epochs' in policy_def: - schedule.add_policy(policy, epochs=policy_def['epochs']) + scheduler.add_policy(policy, epochs=policy_def['epochs']) else: - schedule.add_policy(policy, starting_epoch=policy_def['starting_epoch'], + scheduler.add_policy(policy, starting_epoch=policy_def['starting_epoch'], ending_epoch=policy_def['ending_epoch'], frequency=policy_def['frequency']) -def file_config(model, optimizer, filename): +def file_config(model, optimizer, filename, scheduler=None): """Read the schedule from file""" with open(filename, 'r') as stream: msglogger.info('Reading compression schedule from: %s', filename) try: sched_dict = yaml_ordered_load(stream) - return dict_config(model, optimizer, sched_dict) + return dict_config(model, optimizer, sched_dict, scheduler) except yaml.YAMLError as exc: print("\nFATAL parsing error while parsing the schedule configuration file %s" % filename) raise diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py index ade512ddcce63c22ca2ce62796632b092073fe98..7c051545d3e166938641980da1dac577d0ed5498 100755 --- a/examples/classifier_compression/compress_classifier.py +++ b/examples/classifier_compression/compress_classifier.py @@ -154,7 +154,7 @@ parser.add_argument('--num-best-scores', dest='num_best_scores', default=1, type help='number of best scores to track and report (default: 1)') parser.add_argument('--load-serialized', dest='load_serialized', action='store_true', default=False, help='Load a model without DataParallel wrapping it') - + quant_group = parser.add_argument_group('Arguments controlling quantization at evaluation time' '("post-training quantization)') quant_group.add_argument('--quantize-eval', '--qe', action='store_true', @@ -343,10 +343,10 @@ def main(): if args.compress: # The main use-case for this sample application is CNN compression. Compression # requires a compression schedule configuration file in YAML. - compression_scheduler = distiller.file_config(model, optimizer, args.compress) + compression_scheduler = distiller.file_config(model, optimizer, args.compress, compression_scheduler) # Model is re-transferred to GPU in case parameters were added (e.g. PACTQuantizer) model.cuda() - else: + elif compression_scheduler is None: compression_scheduler = distiller.CompressionScheduler(model) args.kd_policy = None