diff --git a/docs-src/docs/earlyexit.md b/docs-src/docs/earlyexit.md
index f36d25ed5b32755c4d2d2f2ffca248411ceb9114..3ba88c5496a1f2c19834f6952247b3376c25e317 100644
--- a/docs-src/docs/earlyexit.md
+++ b/docs-src/docs/earlyexit.md
@@ -24,9 +24,9 @@ There are other benefits to adding exits in that training the modified network n
 There are two parameters that are required to enable early exit. Leave them undefined if you are not enabling Early Exit:
 
 1. **--earlyexit_thresholds** defines the
-thresholds for each of the early exits. The cross entropy measure must be **less than** the specified threshold to take a specific exit, otherwise the data continues along the regular path. For example, you could specify "--earlyexit_thresholds 0.9 1.2" and this would imply two early exits with corresponding thresholds of 0.9 and 1.2, respectively to take that exit.
+thresholds for each of the early exits. The cross-entropy measure must be **less than** the specified threshold to take a specific exit; otherwise the data continues along the regular path. For example, specifying "--earlyexit_thresholds 0.9 1.2" implies two early exits with thresholds of 0.9 and 1.2, respectively.
 
-2. **--earlyexit_lossweights** provide the weights for the linear combination of losses during training to compute a signle, overall loss. We only specify weights for the early exits and assume that the sum of the weights (including final exit) are equal to 1.0. So an example of "--earlyexit_lossweights 0.2 0.3" implies two early exits weighted with values of 0.2 and 0.3, respectively and that the final exit has a value of 1.0-(0.2+0.3) = 0.5. Studies have shown that weighting the early exits more heavily will create more agressive early exits, but perhaps with a slight negative effect on accuracy.
+1. **--earlyexit_lossweights** provides the weights for the linear combination of losses during training to compute a single, overall loss. We only specify weights for the early exits and assume that the sum of the weights (including the final exit) is equal to 1.0. For example, "--earlyexit_lossweights 0.2 0.3" implies two early exits weighted with values of 0.2 and 0.3, respectively, and that the final exit has a weight of 1.0-(0.2+0.3) = 0.5. Studies have shown that weighting the early exits more heavily will create more aggressive early exits, but perhaps with a slight negative effect on accuracy.
 
 ### CIFAR10
 In the case of CIFAR10, we have inserted a single exit after the first full layer grouping. The layers on the exit path itself includes a convolutional layer and a fully connected layer. If you move the exit, be sure to match the proper sizes for inputs and outputs to the exit layers.
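Note: to make the loss-weighting arithmetic documented in the hunk above concrete, here is a minimal sketch, assuming hypothetical stand-in logits (`exit_outputs`) and the example weights `0.2 0.3` from the documentation; this is illustrative, not code from this repository:

```python
import torch
import torch.nn as nn

# Hypothetical example: two early exits plus the final exit,
# matching "--earlyexit_lossweights 0.2 0.3" above.
earlyexit_lossweights = [0.2, 0.3]
final_weight = 1.0 - sum(earlyexit_lossweights)          # 1.0 - (0.2 + 0.3) = 0.5

criterion = nn.CrossEntropyLoss()
exit_outputs = [torch.randn(8, 10) for _ in range(3)]    # stand-in logits: [exit0, exit1, final]
target = torch.randint(0, 10, (8,))

# Single overall training loss: linear combination of the per-exit losses,
# early-exit weights first, with the final exit taking the remaining weight.
weights = earlyexit_lossweights + [final_weight]
overall_loss = sum(w * criterion(out, target) for w, out in zip(weights, exit_outputs))
```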
diff --git a/docs-src/docs/imgs/decision_boundary.png b/docs-src/docs/imgs/decision_boundary.png
index a22c4c42c20cd31df791354bbc012655359d74d9..54c6da7e295985ff1096a3e27e21946f9efd606b 100644
Binary files a/docs-src/docs/imgs/decision_boundary.png and b/docs-src/docs/imgs/decision_boundary.png differ
diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py
index d0663c6b22bb891f391e6f6db44780295ece426d..6c5ab5903b45d1e677e590daa7c39fa7045d7279 100755
--- a/examples/classifier_compression/compress_classifier.py
+++ b/examples/classifier_compression/compress_classifier.py
@@ -580,8 +580,6 @@ def _validate(data_loader, model, criterion, loggers, args, epoch=-1):
                 if args.display_confusion:
                     confusion.add(output.data, target)
             else:
-                # If using Early Exit, then compute outputs at all exits - output is now a list of all exits
-                # from exit0 through exitN (i.e. [exit0, exit1, ... exitN])
                 earlyexit_validate_loss(output, target, criterion, args)

            # measure elapsed time
@@ -637,8 +635,14 @@ def _validate(data_loader, model, criterion, loggers, args, epoch=-1):
         if args.exit_taken[exitnum]:
             msglogger.info("Percent Early Exit %d: %.3f", exitnum,
                            (args.exit_taken[exitnum]*100.0) / sum_exit_stats)
-
-    return top1k_stats[args.num_exits-1], top5k_stats[args.num_exits-1], losses_exits_stats[args.num_exits-1]
+    total_top1 = 0
+    total_top5 = 0
+    for exitnum in range(args.num_exits):
+        total_top1 += (top1k_stats[exitnum] * (args.exit_taken[exitnum] / sum_exit_stats))
+        total_top5 += (top5k_stats[exitnum] * (args.exit_taken[exitnum] / sum_exit_stats))
+        msglogger.info("Accuracy Stats for exit %d: top1 = %.3f, top5 = %.3f", exitnum, top1k_stats[exitnum], top5k_stats[exitnum])
+    msglogger.info("Totals for entire network with early exits: top1 = %.3f, top5 = %.3f", total_top1, total_top5)
+    return total_top1, total_top5, losses_exits_stats[args.num_exits-1]


 def earlyexit_loss(output, target, criterion, args):
@@ -655,28 +659,35 @@


 def earlyexit_validate_loss(output, target, criterion, args):
+    # We need to go through each sample in the batch itself - in other words, we are
+    # not doing batch processing for exit criteria - we do this as though it were batchsize of 1
+    # but with a grouping of samples equal to the batch size.
+    # Note that final group might not be a full batch - so determine actual size.
+    this_batch_size = target.size()[0]
+    earlyexit_validate_criterion = nn.CrossEntropyLoss(reduction='none').cuda()
     for exitnum in range(args.num_exits):
-        args.loss_exits[exitnum] = criterion(output[exitnum], target)
-        args.losses_exits[exitnum].add(args.loss_exits[exitnum].item())
+        # calculate losses at each sample separately in the minibatch.
+        args.loss_exits[exitnum] = earlyexit_validate_criterion(output[exitnum], target)
+        # for batch_size > 1, we need to reduce this down to an average over the batch
+        args.losses_exits[exitnum].add(torch.mean(args.loss_exits[exitnum]))

-    # We need to go through this batch itself - this is now a vector of losses through the batch.
-    # Collecting stats on which exit early can be done across the batch at this time.
-    # Note that we can't use batch_size as last batch might be smaller
-    this_batch_size = target.size()[0]
-    for batchnum in range(this_batch_size):
+    for batch_index in range(this_batch_size):
+        earlyexit_taken = False
         # take the exit using CrossEntropyLoss as confidence measure (lower is more confident)
-        for exitnum in range(args.num_exits-1):
-            if args.loss_exits[exitnum].item() < args.earlyexit_thresholds[exitnum]:
+        for exitnum in range(args.num_exits - 1):
+            if args.loss_exits[exitnum][batch_index] < args.earlyexit_thresholds[exitnum]:
                 # take the results from early exit since lower than threshold
-                args.exiterrors[exitnum].add(torch.tensor(np.array(output[exitnum].data[batchnum], ndmin=2)),
-                                             torch.full([1], target[batchnum], dtype=torch.long))
+                args.exiterrors[exitnum].add(torch.tensor(np.array(output[exitnum].data[batch_index], ndmin=2)),
+                                             torch.full([1], target[batch_index], dtype=torch.long))
                 args.exit_taken[exitnum] += 1
-            else:
-                # skip the early exits and include results from end of net
-                args.exiterrors[args.num_exits-1].add(torch.tensor(np.array(output[args.num_exits-1].data[batchnum],
-                                                                            ndmin=2)),
-                                                      torch.full([1], target[batchnum], dtype=torch.long))
-                args.exit_taken[args.num_exits-1] += 1
+                earlyexit_taken = True
+                break  # since exit was taken, do not affect the stats of subsequent exits
+        # this sample does not exit early and therefore continues until final exit
+        if not earlyexit_taken:
+            exitnum = args.num_exits - 1
+            args.exiterrors[exitnum].add(torch.tensor(np.array(output[exitnum].data[batch_index], ndmin=2)),
+                                         torch.full([1], target[batch_index], dtype=torch.long))
+            args.exit_taken[exitnum] += 1


 def evaluate_model(model, criterion, test_loader, loggers, activations_collectors, args):
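Note: the new return path in the `@@ -637,8 +635,14 @@` hunk replaces the final-exit-only statistics with exit-weighted totals: each exit's top-1/top-5 accuracy is weighted by the fraction of validation samples that actually took that exit. A small worked sketch of that computation; the names mirror the variables above, but the values are made up:

```python
# Hypothetical counts: 30 samples took exit0, 20 took exit1, 50 took the final exit.
exit_taken = [30, 20, 50]
top1k_stats = [92.0, 88.5, 85.0]    # per-exit top-1 accuracy (%), invented for illustration

sum_exit_stats = sum(exit_taken)
total_top1 = sum(top1 * (taken / sum_exit_stats)
                 for top1, taken in zip(top1k_stats, exit_taken))
print(total_top1)                   # 92.0*0.3 + 88.5*0.2 + 85.0*0.5 = 87.8
```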
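And for the `earlyexit_validate_loss` rewrite in the final hunk, a self-contained sketch of the per-sample exit rule it implements: each sample takes the first early exit whose per-sample cross-entropy falls below that exit's threshold, and otherwise falls through to the final exit. The tensors and the `choose_exits` helper here are hypothetical stand-ins, not Distiller's API:

```python
import torch
import torch.nn as nn

def choose_exits(outputs, target, thresholds):
    """outputs: list of [batch, classes] logits, one per exit (final exit last).
    thresholds: cross-entropy threshold per early exit, e.g. [0.9, 1.2].
    Returns a [batch] tensor holding the exit index each sample takes."""
    per_sample_ce = nn.CrossEntropyLoss(reduction='none')  # one loss per sample, no batch mean
    losses = [per_sample_ce(out, target) for out in outputs]
    num_exits = len(outputs)
    chosen = torch.full_like(target, num_exits - 1)        # default: final exit
    for batch_index in range(target.size(0)):
        for exitnum in range(num_exits - 1):
            if losses[exitnum][batch_index] < thresholds[exitnum]:
                chosen[batch_index] = exitnum
                break                                      # first qualifying exit wins
    return chosen

# Example with random stand-in logits: two early exits plus the final exit.
outputs = [torch.randn(4, 10) for _ in range(3)]
target = torch.randint(0, 10, (4,))
print(choose_exits(outputs, target, thresholds=[0.9, 1.2]))
```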