From c25d9ee20cff65ed47a1dc7d6594cbb7e1b6cd34 Mon Sep 17 00:00:00 2001
From: Yi-Syuan Chen <chenys1995@gmail.com>
Date: Tue, 11 Dec 2018 20:20:32 +0800
Subject: [PATCH] Save scheduler in quantize_eval checkpoint (#99)

---
 examples/classifier_compression/compress_classifier.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py
index 8ff0eb3..486c84a 100755
--- a/examples/classifier_compression/compress_classifier.py
+++ b/examples/classifier_compression/compress_classifier.py
@@ -364,7 +364,7 @@ def main():
         return sensitivity_analysis(model, criterion, test_loader, pylogger, args, sensitivities)
 
     if args.evaluate:
-        return evaluate_model(model, criterion, test_loader, pylogger, activations_collectors, args)
+        return evaluate_model(model, criterion, test_loader, pylogger, activations_collectors, args, compression_scheduler)
 
     if args.compress:
         # The main use-case for this sample application is CNN compression. Compression
@@ -722,7 +722,7 @@ def earlyexit_validate_stats(args):
     msglogger.info("Totals for entire network with early exits: top1 = %.3f, top5 = %.3f", total_top1, total_top5)
     return(total_top1, total_top5, losses_exits_stats)
 
-def evaluate_model(model, criterion, test_loader, loggers, activations_collectors, args):
+def evaluate_model(model, criterion, test_loader, loggers, activations_collectors, args, scheduler=None):
     # This sample application can be invoked to evaluate the accuracy of your model on
     # the test dataset.
     # You can optionally quantize the model to 8-bit integer before evaluation.
@@ -744,7 +744,7 @@ def evaluate_model(model, criterion, test_loader, loggers, activations_collector
 
     if args.quantize_eval:
         checkpoint_name = 'quantized'
-        apputils.save_checkpoint(0, args.arch, model, optimizer=None, best_top1=top1,
+        apputils.save_checkpoint(0, args.arch, model, optimizer=None, best_top1=top1, scheduler=scheduler,
                                  name='_'.join([args.name, checkpoint_name]) if args.name else checkpoint_name,
                                  dir=msglogger.logdir)
 
-- 
GitLab