From 6242afedac390cafe42cbce35a1930525d3acb81 Mon Sep 17 00:00:00 2001
From: Neta Zmora <31280975+nzmora@users.noreply.github.com>
Date: Tue, 20 Nov 2018 17:58:15 +0200
Subject: [PATCH] Bug fix: value of best_top1 stored in the checkpoint may be
 wrong (#77)

* Bug fix: value of best_top1 stored in the checkpoint may be wrong

If you invoke compress_clasifier.py with --num-best-scores=n
with n>1, then the value of best_top1 stored in checkpoints is wrong.
---
 .../classifier_compression/compress_classifier.py     | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py
index 7c05154..380a4c4 100755
--- a/examples/classifier_compression/compress_classifier.py
+++ b/examples/classifier_compression/compress_classifier.py
@@ -399,19 +399,18 @@ def main():
         if compression_scheduler:
             compression_scheduler.on_epoch_end(epoch, optimizer)
 
-        # remember best top1 and save checkpoint
-        #sparsity = distiller.model_sparsity(model)
-        is_best = top1 > best_epochs[0].top1
-        if is_best:
+        # Update the list of top scores achieved so far, and save the checkpoint
+        is_best = top1 > best_epochs[-1].top1
+        if top1 > best_epochs[0].top1:
             best_epochs[0].epoch = epoch
             best_epochs[0].top1 = top1
-            #best_epoch.sparsity = sparsity
+            # Keep best_epochs sorted such that best_epochs[0] is the lowest top1 in the best_epochs list
             best_epochs = sorted(best_epochs, key=lambda score: score.top1)
         for score in reversed(best_epochs):
             if score.top1 > 0:
                 msglogger.info('==> Best Top1: %.3f on Epoch: %d', score.top1, score.epoch)
         apputils.save_checkpoint(epoch, args.arch, model, optimizer, compression_scheduler,
-                                 best_epochs[0].top1, is_best, args.name, msglogger.logdir)
+                                 best_epochs[-1].top1, is_best, args.name, msglogger.logdir)
 
     # Finally run results on the test set
     test(test_loader, model, criterion, [pylogger], activations_collectors, args=args)
-- 
GitLab