diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py
index 7c051545d3e166938641980da1dac577d0ed5498..380a4c455a3571bd129b80cfd5b8231fff1aedaa 100755
--- a/examples/classifier_compression/compress_classifier.py
+++ b/examples/classifier_compression/compress_classifier.py
@@ -399,19 +399,18 @@ def main():
         if compression_scheduler:
             compression_scheduler.on_epoch_end(epoch, optimizer)
 
-        # remember best top1 and save checkpoint
-        #sparsity = distiller.model_sparsity(model)
-        is_best = top1 > best_epochs[0].top1
-        if is_best:
+        # Update the list of top scores achieved so far, and save the checkpoint
+        is_best = top1 > best_epochs[-1].top1
+        if top1 > best_epochs[0].top1:
             best_epochs[0].epoch = epoch
             best_epochs[0].top1 = top1
-            #best_epoch.sparsity = sparsity
+        # Keep best_epochs sorted such that best_epochs[0] is the lowest top1 in the best_epochs list
         best_epochs = sorted(best_epochs, key=lambda score: score.top1)
         for score in reversed(best_epochs):
             if score.top1 > 0:
                 msglogger.info('==> Best Top1: %.3f on Epoch: %d', score.top1, score.epoch)
         apputils.save_checkpoint(epoch, args.arch, model, optimizer, compression_scheduler,
-                                 best_epochs[0].top1, is_best, args.name, msglogger.logdir)
+                                 best_epochs[-1].top1, is_best, args.name, msglogger.logdir)
 
     # Finally run results on the test set
     test(test_loader, model, criterion, [pylogger], activations_collectors, args=args)