From 470209b94bb79fb5097103cfb32fe9791a3d1961 Mon Sep 17 00:00:00 2001 From: Haim Barad <35232758+haim-barad@users.noreply.github.com> Date: Thu, 8 Nov 2018 15:35:55 +0200 Subject: [PATCH] Early Exit docs (#75) * Updated stats computation - fixes issues with validation stats * Clarification of output (docs) * Update * Moved validation stats to separate function --- docs-src/docs/earlyexit.md | 3 ++ .../compress_classifier.py | 49 ++++++++++--------- 2 files changed, 29 insertions(+), 23 deletions(-) diff --git a/docs-src/docs/earlyexit.md b/docs-src/docs/earlyexit.md index 3ba88c5..e28d21f 100644 --- a/docs-src/docs/earlyexit.md +++ b/docs-src/docs/earlyexit.md @@ -28,6 +28,9 @@ thresholds for each of the early exits. The cross entropy measure must be **less 1. **--earlyexit_lossweights** provide the weights for the linear combination of losses during training to compute a single, overall loss. We only specify weights for the early exits and assume that the sum of the weights (including final exit) are equal to 1.0. So an example of "--earlyexit_lossweights 0.2 0.3" implies two early exits weighted with values of 0.2 and 0.3, respectively and that the final exit has a value of 1.0-(0.2+0.3) = 0.5. Studies have shown that weighting the early exits more heavily will create more aggressive early exits, but perhaps with a slight negative effect on accuracy. +### Output Stats +The example code outputs various statistics regarding the loss and accuracy at each of the exits. During training, the Top1 and Top5 stats represent the accuracy should all of the data be forced out that exit (in order to compute the loss at that exit). During inference (i.e. validation and test stages), the Top1 and Top5 stats represent the accuracy for those data points that could exit because the calculated entropy at that exit was lower than the specified threshold for that exit. + ### CIFAR10 In the case of CIFAR10, we have inserted a single exit after the first full layer grouping. 
The layers on the exit path itself include a convolutional layer and a fully connected layer. If you move the exit, be sure to match the proper sizes for inputs and outputs to the exit layers. diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py index 2887f7e..ade512d 100755 --- a/examples/classifier_compression/compress_classifier.py +++ b/examples/classifier_compression/compress_classifier.py @@ -622,29 +622,7 @@ def _validate(data_loader, model, criterion, loggers, args, epoch=-1): msglogger.info('==> Confusion:\n%s\n', str(confusion.value())) return classerr.value(1), classerr.value(5), losses['objective_loss'].mean else: - # Print some interesting summary stats for number of data points that could exit early - top1k_stats = [0] * args.num_exits - top5k_stats = [0] * args.num_exits - losses_exits_stats = [0] * args.num_exits - sum_exit_stats = 0 - for exitnum in range(args.num_exits): - if args.exit_taken[exitnum]: - sum_exit_stats += args.exit_taken[exitnum] - msglogger.info("Exit %d: %d", exitnum, args.exit_taken[exitnum]) - top1k_stats[exitnum] += args.exiterrors[exitnum].value(1) - top5k_stats[exitnum] += args.exiterrors[exitnum].value(5) - losses_exits_stats[exitnum] += args.losses_exits[exitnum].mean - for exitnum in range(args.num_exits): - if args.exit_taken[exitnum]: - msglogger.info("Percent Early Exit %d: %.3f", exitnum, - (args.exit_taken[exitnum]*100.0) / sum_exit_stats) - total_top1 = 0 - total_top5 = 0 - for exitnum in range(args.num_exits): - total_top1 += (top1k_stats[exitnum] * (args.exit_taken[exitnum] / sum_exit_stats)) - total_top5 += (top5k_stats[exitnum] * (args.exit_taken[exitnum] / sum_exit_stats)) - msglogger.info("Accuracy Stats for exit %d: top1 = %.3f, top5 = %.3f", exitnum, top1k_stats[exitnum], top5k_stats[exitnum]) - msglogger.info("Totals for entire network with early exits: top1 = %.3f, top5 = %.3f", total_top1, total_top5) + total_top1, total_top5, 
losses_exits_stats = earlyexit_validate_stats(args) return total_top1, total_top5, losses_exits_stats[args.num_exits-1] @@ -692,6 +670,31 @@ def earlyexit_validate_loss(output, target, criterion, args): torch.full([1], target[batch_index], dtype=torch.long)) args.exit_taken[exitnum] += 1 +def earlyexit_validate_stats(args): + # Print some interesting summary stats for number of data points that could exit early + top1k_stats = [0] * args.num_exits + top5k_stats = [0] * args.num_exits + losses_exits_stats = [0] * args.num_exits + sum_exit_stats = 0 + for exitnum in range(args.num_exits): + if args.exit_taken[exitnum]: + sum_exit_stats += args.exit_taken[exitnum] + msglogger.info("Exit %d: %d", exitnum, args.exit_taken[exitnum]) + top1k_stats[exitnum] += args.exiterrors[exitnum].value(1) + top5k_stats[exitnum] += args.exiterrors[exitnum].value(5) + losses_exits_stats[exitnum] += args.losses_exits[exitnum].mean + for exitnum in range(args.num_exits): + if args.exit_taken[exitnum]: + msglogger.info("Percent Early Exit %d: %.3f", exitnum, + (args.exit_taken[exitnum]*100.0) / sum_exit_stats) + total_top1 = 0 + total_top5 = 0 + for exitnum in range(args.num_exits): + total_top1 += (top1k_stats[exitnum] * (args.exit_taken[exitnum] / sum_exit_stats)) + total_top5 += (top5k_stats[exitnum] * (args.exit_taken[exitnum] / sum_exit_stats)) + msglogger.info("Accuracy Stats for exit %d: top1 = %.3f, top5 = %.3f", exitnum, top1k_stats[exitnum], top5k_stats[exitnum]) + msglogger.info("Totals for entire network with early exits: top1 = %.3f, top5 = %.3f", total_top1, total_top5) + return(total_top1, total_top5, losses_exits_stats) def evaluate_model(model, criterion, test_loader, loggers, activations_collectors, args): # This sample application can be invoked to evaluate the accuracy of your model on -- GitLab