From c25d9ee20cff65ed47a1dc7d6594cbb7e1b6cd34 Mon Sep 17 00:00:00 2001 From: Yi-Syuan Chen <chenys1995@gmail.com> Date: Tue, 11 Dec 2018 20:20:32 +0800 Subject: [PATCH] Save scheduler in quantize_eval checkpoint (#99) --- examples/classifier_compression/compress_classifier.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py index 8ff0eb3..486c84a 100755 --- a/examples/classifier_compression/compress_classifier.py +++ b/examples/classifier_compression/compress_classifier.py @@ -364,7 +364,7 @@ def main(): return sensitivity_analysis(model, criterion, test_loader, pylogger, args, sensitivities) if args.evaluate: - return evaluate_model(model, criterion, test_loader, pylogger, activations_collectors, args) + return evaluate_model(model, criterion, test_loader, pylogger, activations_collectors, args, compression_scheduler) if args.compress: # The main use-case for this sample application is CNN compression. Compression @@ -722,7 +722,7 @@ def earlyexit_validate_stats(args): msglogger.info("Totals for entire network with early exits: top1 = %.3f, top5 = %.3f", total_top1, total_top5) return(total_top1, total_top5, losses_exits_stats) -def evaluate_model(model, criterion, test_loader, loggers, activations_collectors, args): +def evaluate_model(model, criterion, test_loader, loggers, activations_collectors, args, scheduler=None): # This sample application can be invoked to evaluate the accuracy of your model on # the test dataset. # You can optionally quantize the model to 8-bit integer before evaluation. @@ -744,7 +744,7 @@ def evaluate_model(model, criterion, test_loader, loggers, activations_collector if args.quantize_eval: checkpoint_name = 'quantized' - apputils.save_checkpoint(0, args.arch, model, optimizer=None, best_top1=top1, + apputils.save_checkpoint(0, args.arch, model, optimizer=None, best_top1=top1, scheduler=scheduler, name='_'.join([args.name, checkpoint_name]) if args.name else checkpoint_name, dir=msglogger.logdir) -- GitLab