diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py index a4730c304af9d8c82649bf9a7bcd598513b11c90..e4a5905322938283357163fcd63378a3fb38ea1b 100755 --- a/examples/classifier_compression/compress_classifier.py +++ b/examples/classifier_compression/compress_classifier.py @@ -280,16 +280,7 @@ def main(): compression_scheduler.on_epoch_end(epoch, optimizer) # Update the list of top scores achieved so far, and save the checkpoint - sparsity = distiller.model_sparsity(model) - perf_scores_history.append(distiller.MutableNamedTuple({'sparsity': sparsity, 'top1': top1, - 'top5': top5, 'epoch': epoch})) - # Keep perf_scores_history sorted from best to worst - # Sort by sparsity as main sort key, then sort by top1, top5 and epoch - perf_scores_history.sort(key=operator.attrgetter('sparsity', 'top1', 'top5', 'epoch'), reverse=True) - for score in perf_scores_history[:args.num_best_scores]: - msglogger.info('==> Best [Top1: %.3f Top5: %.3f Sparsity: %.2f on epoch: %d]', - score.top1, score.top5, score.sparsity, score.epoch) - + update_training_scores_history(perf_scores_history, model, top1, top5, epoch, args.num_best_scores) is_best = epoch == perf_scores_history[0].epoch apputils.save_checkpoint(epoch, args.arch, model, optimizer, compression_scheduler, perf_scores_history[0].top1, is_best, args.name, msglogger.logdir) @@ -514,6 +505,21 @@ def _validate(data_loader, model, criterion, loggers, args, epoch=-1): return total_top1, total_top5, losses_exits_stats[args.num_exits-1] +def update_training_scores_history(perf_scores_history, model, top1, top5, epoch, num_best_scores): + """ Update the list of top training scores achieved so far, and log the best scores so far""" + + model_sparsity, _, params_nnz_cnt = distiller.model_params_stats(model) + perf_scores_history.append(distiller.MutableNamedTuple({'params_nnz_cnt': -params_nnz_cnt, + 'sparsity': model_sparsity, + 'top1': top1, 'top5': top5, 'epoch': epoch})) + # Keep perf_scores_history sorted from best to worst + # Sort by sparsity as main sort key, then sort by top1, top5 and epoch + perf_scores_history.sort(key=operator.attrgetter('params_nnz_cnt', 'top1', 'top5', 'epoch'), reverse=True) + for score in perf_scores_history[:num_best_scores]: + msglogger.info('==> Best [Top1: %.3f Top5: %.3f Sparsity:%.2f Params: %d on epoch: %d]', + score.top1, score.top5, score.sparsity, -score.params_nnz_cnt, score.epoch) + + def earlyexit_loss(output, target, criterion, args): loss = 0 sum_lossweights = 0