From fc62caab0791f550281f27da41935570f1aefa08 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Thu, 7 Nov 2019 19:44:43 +0200
Subject: [PATCH] Fix Early-exit code

Fix the EE code so that it works with the current 'master' branch,
and add a test for high-level EE regression
---
 distiller/apputils/image_classifier.py        | 29 ++++++++++++-------
 .../compress_classifier.py                    |  2 +-
 tests/full_flow_tests.py                      |  2 ++
 3 files changed, 21 insertions(+), 12 deletions(-)

diff --git a/distiller/apputils/image_classifier.py b/distiller/apputils/image_classifier.py
index 451a0ec..6b42383 100755
--- a/distiller/apputils/image_classifier.py
+++ b/distiller/apputils/image_classifier.py
@@ -502,7 +502,7 @@ def train(train_loader, model, criterion, optimizer, epoch,
             errs['Top1'] = classerr.value(1)
             errs['Top5'] = classerr.value(5)
         else:
-            # for Early Exit case, the Top1 and Top5 stats are computed for each exit.
+            # For Early Exit case, the Top1 and Top5 stats are computed for each exit.
             for exitnum in range(args.num_exits):
                 errs['Top1_exit' + str(exitnum)] = args.exiterrors[exitnum].value(1)
                 errs['Top5_exit' + str(exitnum)] = args.exiterrors[exitnum].value(5)
@@ -533,8 +533,8 @@ def train(train_loader, model, criterion, optimizer, epoch,
     batch_time = tnt.AverageValueMeter()
     data_time = tnt.AverageValueMeter()
 
-    # For Early Exit, we define statistics for each exit
-    # So exiterrors is analogous to classerr for the non-Early Exit case
+    # For Early Exit, we define statistics for each exit, so
+    # exiterrors is analogous to classerr for the non-Early Exit case
     if early_exit_mode(args):
         args.exiterrors = []
         for exitnum in range(args.num_exits):
@@ -570,6 +570,7 @@ def train(train_loader, model, criterion, optimizer, epoch,
             acc_stats.append([classerr.value(1), classerr.value(5)])
         else:
             # Measure accuracy and record loss
+            classerr.add(output[args.num_exits-1].data, target) # add the last exit (original exit)
             loss = earlyexit_loss(output, target, criterion, args)
         # Record loss
         losses[OBJECTIVE_LOSS_KEY].add(loss.item())
@@ -739,24 +740,30 @@ def update_training_scores_history(perf_scores_history, model, top1, top5, epoch
 
 
 def earlyexit_loss(output, target, criterion, args):
-    loss = 0
-    sum_lossweights = 0
+    """Compute the weighted sum of the exit losses
+
+    Note that the last exit is the original exit of the model (i.e. the
+    exit that traverses the entire network).
+    """
+    weighted_loss = 0
+    sum_lossweights = sum(args.earlyexit_lossweights)
+    assert sum_lossweights < 1
     for exitnum in range(args.num_exits-1):
-        loss += (args.earlyexit_lossweights[exitnum] * criterion(output[exitnum], target))
-        sum_lossweights += args.earlyexit_lossweights[exitnum]
+        exit_loss = criterion(output[exitnum], target)
+        weighted_loss += args.earlyexit_lossweights[exitnum] * exit_loss
         args.exiterrors[exitnum].add(output[exitnum].data, target)
     # handle final exit
-    loss += (1.0 - sum_lossweights) * criterion(output[args.num_exits-1], target)
+    weighted_loss += (1.0 - sum_lossweights) * criterion(output[args.num_exits-1], target)
     args.exiterrors[args.num_exits-1].add(output[args.num_exits-1].data, target)
-    return loss
+    return weighted_loss
 
 
 def earlyexit_validate_loss(output, target, criterion, args):
     # We need to go through each sample in the batch itself - in other words, we are
-    # not doing batch processing for exit criteria - we do this as though it were batchsize of 1
+    # not doing batch processing for exit criteria - we do this as though it were batch size of 1,
     # but with a grouping of samples equal to the batch size.
     # Note that final group might not be a full batch - so determine actual size.
-    this_batch_size = target.size()[0]
+    this_batch_size = target.size(0)
     earlyexit_validate_criterion = nn.CrossEntropyLoss(reduce=False).to(args.device)
 
     for exitnum in range(args.num_exits):
diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py
index 5108dbf..09dab16 100755
--- a/examples/classifier_compression/compress_classifier.py
+++ b/examples/classifier_compression/compress_classifier.py
@@ -158,7 +158,7 @@ def early_exit_init(args):
 class ClassifierCompressorSampleApp(classifier.ClassifierCompressor):
     def __init__(self, args, script_dir):
         super().__init__(args, script_dir)
-        early_exit_init(args)
+        early_exit_init(self.args)
         # Save the randomly-initialized model before training (useful for lottery-ticket method)
         if args.save_untrained_model:
             ckpt_name = '_'.join((self.args.name or "", "untrained"))
diff --git a/tests/full_flow_tests.py b/tests/full_flow_tests.py
index d8617b7..052f068 100755
--- a/tests/full_flow_tests.py
+++ b/tests/full_flow_tests.py
@@ -120,6 +120,8 @@ def collateral_checker(log, run_dir, *collateral_list):
 TestConfig = namedtuple('TestConfig', ['args', 'dataset', 'checker_fn', 'checker_args'])
 
 test_configs = [
+    TestConfig('--arch resnet20_cifar_earlyexit --lr=0.3 --epochs=180 --earlyexit_thresholds 0.9 '
+               '--earlyexit_lossweights 0.3 --epochs 2', DS_CIFAR, accuracy_checker, [17.04, 64.42]),
     TestConfig('--arch simplenet_cifar --epochs 2', DS_CIFAR, accuracy_checker, [44.460, 91.230]),
     TestConfig('-a resnet20_cifar --resume {0} --quantize-eval --evaluate --qe-clip-acts avg --qe-no-clip-layers {1}'.
                format(os.path.join(examples_root, 'ssl', 'checkpoints', 'checkpoint_trained_dense.pth.tar'), 'fc'),
-- 
GitLab