Skip to content
Snippets Groups Projects
Unverified Commit 6242afed authored by Neta Zmora's avatar Neta Zmora Committed by GitHub
Browse files

Bug fix: value of best_top1 stored in the checkpoint may be wrong (#77)

* Bug fix: value of best_top1 stored in the checkpoint may be wrong

If you invoke compress_clasifier.py with --num-best-scores=n
with n>1, then the value of best_top1 stored in checkpoints is wrong.
parent 78e98a51
No related branches found
No related tags found
No related merge requests found
...@@ -399,19 +399,18 @@ def main(): ...@@ -399,19 +399,18 @@ def main():
if compression_scheduler: if compression_scheduler:
compression_scheduler.on_epoch_end(epoch, optimizer) compression_scheduler.on_epoch_end(epoch, optimizer)
# remember best top1 and save checkpoint # Update the list of top scores achieved so far, and save the checkpoint
#sparsity = distiller.model_sparsity(model) is_best = top1 > best_epochs[-1].top1
is_best = top1 > best_epochs[0].top1 if top1 > best_epochs[0].top1:
if is_best:
best_epochs[0].epoch = epoch best_epochs[0].epoch = epoch
best_epochs[0].top1 = top1 best_epochs[0].top1 = top1
#best_epoch.sparsity = sparsity # Keep best_epochs sorted such that best_epochs[0] is the lowest top1 in the best_epochs list
best_epochs = sorted(best_epochs, key=lambda score: score.top1) best_epochs = sorted(best_epochs, key=lambda score: score.top1)
for score in reversed(best_epochs): for score in reversed(best_epochs):
if score.top1 > 0: if score.top1 > 0:
msglogger.info('==> Best Top1: %.3f on Epoch: %d', score.top1, score.epoch) msglogger.info('==> Best Top1: %.3f on Epoch: %d', score.top1, score.epoch)
apputils.save_checkpoint(epoch, args.arch, model, optimizer, compression_scheduler, apputils.save_checkpoint(epoch, args.arch, model, optimizer, compression_scheduler,
best_epochs[0].top1, is_best, args.name, msglogger.logdir) best_epochs[-1].top1, is_best, args.name, msglogger.logdir)
# Finally run results on the test set # Finally run results on the test set
test(test_loader, model, criterion, [pylogger], activations_collectors, args=args) test(test_loader, model, criterion, [pylogger], activations_collectors, args=args)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment