From 9cb0dd684368d543105222240a4889b5e80da98b Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Wed, 6 Mar 2019 18:12:07 +0200
Subject: [PATCH] compress_classifier.py: sort best scores by count of NNZ
 weights

A recent commit changed the sorting of the best-performing training
epochs to be based on the model's sparsity level, and then on its
Top1 and Top5 scores.
When we create thinned models, the sparsity remains low (even zero),
even though the physical size of the network is smaller.
This commit changes the sorting criterion to the count of non-zero
(NNZ) parameters, which captures both the sparsity and parameter-size
objectives:
- When sparsity is high, the number of NNZ parameters is low
(params_nnz_cnt = (1 - sparsity) * params_cnt, with sparsity expressed
as a fraction).
- When we remove structures (thinning), the sparsity may remain
constant, but the count of parameters (params_cnt) is lower, and
therefore params_nnz_cnt is once again lower.

Therefore, params_nnz_cnt is a good proxy for a sparsity objective
and/or a thinning objective.
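
For illustration, here is a minimal, self-contained sketch of the new
ranking (illustrative numbers only; 'Score' is a plain namedtuple used
as a stand-in for distiller.MutableNamedTuple and is not part of this
patch):

    import operator
    from collections import namedtuple

    Score = namedtuple('Score', ['params_nnz_cnt', 'top1', 'top5', 'epoch'])
    history = [
        Score(params_nnz_cnt=-100, top1=76.0, top5=93.0, epoch=1),  # dense baseline
        Score(params_nnz_cnt=-20,  top1=75.5, top5=92.8, epoch=2),  # 80% element sparsity
        Score(params_nnz_cnt=-20,  top1=75.2, top5=92.6, epoch=3),  # thinned to 20 params
    ]
    # Negating the NNZ count lets reverse=True place the smallest models first.
    history.sort(key=operator.attrgetter('params_nnz_cnt', 'top1', 'top5', 'epoch'),
                 reverse=True)
    print(history[0].epoch)  # -> 2: fewest NNZ params wins; ties broken by Top1

Both the pruned and the thinned checkpoints rank above the dense one,
even though the thinned model reports zero sparsity.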
---
 .../compress_classifier.py                    | 26 ++++++++++++-------
 1 file changed, 16 insertions(+), 10 deletions(-)

diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py
index a4730c3..e4a5905 100755
--- a/examples/classifier_compression/compress_classifier.py
+++ b/examples/classifier_compression/compress_classifier.py
@@ -280,16 +280,7 @@ def main():
             compression_scheduler.on_epoch_end(epoch, optimizer)
 
         # Update the list of top scores achieved so far, and save the checkpoint
-        sparsity = distiller.model_sparsity(model)
-        perf_scores_history.append(distiller.MutableNamedTuple({'sparsity': sparsity, 'top1': top1,
-                                                                'top5': top5, 'epoch': epoch}))
-        # Keep perf_scores_history sorted from best to worst
-        # Sort by sparsity as main sort key, then sort by top1, top5 and epoch
-        perf_scores_history.sort(key=operator.attrgetter('sparsity', 'top1', 'top5', 'epoch'), reverse=True)
-        for score in perf_scores_history[:args.num_best_scores]:
-            msglogger.info('==> Best [Top1: %.3f   Top5: %.3f   Sparsity: %.2f on epoch: %d]',
-                           score.top1, score.top5, score.sparsity, score.epoch)
-
+        update_training_scores_history(perf_scores_history, model, top1, top5, epoch, args.num_best_scores)
         is_best = epoch == perf_scores_history[0].epoch
         apputils.save_checkpoint(epoch, args.arch, model, optimizer, compression_scheduler,
                                  perf_scores_history[0].top1, is_best, args.name, msglogger.logdir)
@@ -514,6 +505,21 @@ def _validate(data_loader, model, criterion, loggers, args, epoch=-1):
         return total_top1, total_top5, losses_exits_stats[args.num_exits-1]
 
 
+def update_training_scores_history(perf_scores_history, model, top1, top5, epoch, num_best_scores):
+    """ Update the list of top training scores achieved so far, and log the best scores so far"""
+
+    model_sparsity, _, params_nnz_cnt = distiller.model_params_stats(model)
+    perf_scores_history.append(distiller.MutableNamedTuple({'params_nnz_cnt': -params_nnz_cnt,
+                                                            'sparsity': model_sparsity,
+                                                            'top1': top1, 'top5': top5, 'epoch': epoch}))
+    # Keep perf_scores_history sorted from best to worst
+    # Sort by the count of NNZ parameters as the main sort key, then by top1, top5 and epoch
+    perf_scores_history.sort(key=operator.attrgetter('params_nnz_cnt', 'top1', 'top5', 'epoch'), reverse=True)
+    for score in perf_scores_history[:num_best_scores]:
+        msglogger.info('==> Best [Top1: %.3f   Top5: %.3f   Sparsity: %.2f   Params: %d on epoch: %d]',
+                       score.top1, score.top5, score.sparsity, -score.params_nnz_cnt, score.epoch)
+
+
 def earlyexit_loss(output, target, criterion, args):
     loss = 0
     sum_lossweights = 0
-- 
GitLab