Commit 470209b9 authored by Haim Barad, committed by Neta Zmora

Early Exit docs (#75)

* Updated stats computation - fixes issues with validation stats

* Clarification of output (docs)

* Update

* Moved validation stats to separate function
parent a2dce3a6
@@ -28,6 +28,9 @@ thresholds for each of the early exits. The cross entropy measure must be **less
1. **--earlyexit_lossweights** provides the weights for the linear combination of losses during training that is used to compute a single, overall loss. We specify weights only for the early exits, and we assume that the sum of all the weights (including the final exit) is equal to 1.0. So, for example, "--earlyexit_lossweights 0.2 0.3" implies two early exits weighted with values of 0.2 and 0.3, respectively, and that the final exit has a weight of 1.0-(0.2+0.3) = 0.5. Studies have shown that weighting the early exits more heavily will create more aggressive early exits, but perhaps with a slight negative effect on accuracy. A sketch of this weighted combination is shown below.
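To make the weighting concrete, here is a minimal sketch of how such a linear combination of per-exit losses could be computed during training. The function name `weighted_earlyexit_loss` and its arguments are illustrative, not part of the sample application; the only assumption carried over from the text is that the final exit implicitly receives a weight of 1.0 minus the sum of the early-exit weights.

```python
import torch.nn as nn

def weighted_earlyexit_loss(exit_outputs, target, earlyexit_lossweights):
    """Combine per-exit losses into a single overall training loss.

    exit_outputs: list of logits tensors, one per exit, with the final exit last.
    earlyexit_lossweights: weights for the early exits only (e.g. [0.2, 0.3]);
                           the final exit gets 1.0 - sum(earlyexit_lossweights).
    """
    criterion = nn.CrossEntropyLoss()
    weights = list(earlyexit_lossweights) + [1.0 - sum(earlyexit_lossweights)]
    total_loss = 0.0
    for weight, output in zip(weights, exit_outputs):
        total_loss = total_loss + weight * criterion(output, target)
    return total_loss

# "--earlyexit_lossweights 0.2 0.3" with three exits -> weights [0.2, 0.3, 0.5]
```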
### Output Stats
The example code outputs various statistics regarding the loss and accuracy at each of the exits. During training, the Top1 and Top5 stats represent the accuracy that would be obtained if all of the data were forced out through that exit (this is needed in order to compute the loss at that exit). During inference (i.e. the validation and test stages), the Top1 and Top5 stats are computed only over the data points that actually took that exit, i.e. those whose calculated entropy at that exit was lower than the threshold specified for that exit. A sketch of this exit-selection rule follows.
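As a rough illustration of the inference-time behavior described above, the sketch below routes a sample to the first early exit whose confidence measure falls below its threshold and otherwise lets it continue to the final exit, counting how often each exit is taken. The names `choose_exit`, `exit_measures` and `thresholds` are hypothetical; the sample application does this bookkeeping with per-exit meters during validation.

```python
def choose_exit(exit_measures, thresholds):
    """Return the index of the exit a sample takes.

    exit_measures: the entropy-based confidence measure for this sample at each
                   exit (final exit last); lower means more confident.
    thresholds:    one threshold per early exit.
    """
    num_exits = len(exit_measures)
    for exitnum in range(num_exits - 1):
        if exit_measures[exitnum] < thresholds[exitnum]:
            return exitnum          # confident enough -> take this early exit
    return num_exits - 1            # otherwise fall through to the final exit

# Toy example with two early exits and thresholds 1.2 and 0.8
exit_taken = [0, 0, 0]
for measures in ([0.9, 1.5, 0.3], [2.0, 0.5, 0.2], [2.0, 1.9, 0.1]):
    exit_taken[choose_exit(measures, [1.2, 0.8])] += 1
print(exit_taken)   # [1, 1, 1] -- one sample took each exit
```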
### CIFAR10
In the case of CIFAR10, we have inserted a single exit after the first full layer grouping. The exit path itself consists of a convolutional layer followed by a fully connected layer. If you move the exit, be sure to match the proper sizes for the inputs and outputs of the exit layers; a sketch of such an exit branch is shown below.
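For illustration only, here is a minimal sketch of such an exit branch. The channel count (32) and the 16x16 spatial size of the incoming feature map are assumptions made for this example, not values taken from the CIFAR10 model in this repository; the point is simply that the branch is one convolutional layer feeding one fully connected classifier, and that the linear layer's input size must match whatever feature map the exit is attached to.

```python
import torch
import torch.nn as nn

class ExitBranch(nn.Module):
    """Hypothetical early-exit branch: one conv layer plus one FC classifier."""
    def __init__(self, in_channels=32, feature_size=16, num_classes=10):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, 16, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.AvgPool2d(2)                           # 16x16 -> 8x8
        self.fc = nn.Linear(16 * (feature_size // 2) ** 2, num_classes)

    def forward(self, x):
        x = self.pool(self.relu(self.conv(x)))
        return self.fc(x.flatten(1))                          # exit logits

# Assumed feature map after the first layer grouping: batch of 4, 32 channels, 16x16
logits = ExitBranch()(torch.randn(4, 32, 16, 16))
print(logits.shape)   # torch.Size([4, 10])
```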
@@ -622,29 +622,7 @@ def _validate(data_loader, model, criterion, loggers, args, epoch=-1):
            msglogger.info('==> Confusion:\n%s\n', str(confusion.value()))
        return classerr.value(1), classerr.value(5), losses['objective_loss'].mean
    else:
        # Print some interesting summary stats for number of data points that could exit early
        top1k_stats = [0] * args.num_exits
        top5k_stats = [0] * args.num_exits
        losses_exits_stats = [0] * args.num_exits
        sum_exit_stats = 0
        for exitnum in range(args.num_exits):
            if args.exit_taken[exitnum]:
                sum_exit_stats += args.exit_taken[exitnum]
                msglogger.info("Exit %d: %d", exitnum, args.exit_taken[exitnum])
                top1k_stats[exitnum] += args.exiterrors[exitnum].value(1)
                top5k_stats[exitnum] += args.exiterrors[exitnum].value(5)
                losses_exits_stats[exitnum] += args.losses_exits[exitnum].mean
        for exitnum in range(args.num_exits):
            if args.exit_taken[exitnum]:
                msglogger.info("Percent Early Exit %d: %.3f", exitnum,
                               (args.exit_taken[exitnum]*100.0) / sum_exit_stats)
        total_top1 = 0
        total_top5 = 0
        for exitnum in range(args.num_exits):
            total_top1 += (top1k_stats[exitnum] * (args.exit_taken[exitnum] / sum_exit_stats))
            total_top5 += (top5k_stats[exitnum] * (args.exit_taken[exitnum] / sum_exit_stats))
            msglogger.info("Accuracy Stats for exit %d: top1 = %.3f, top5 = %.3f", exitnum, top1k_stats[exitnum], top5k_stats[exitnum])
        msglogger.info("Totals for entire network with early exits: top1 = %.3f, top5 = %.3f", total_top1, total_top5)
        total_top1, total_top5, losses_exits_stats = earlyexit_validate_stats(args)
        return total_top1, total_top5, losses_exits_stats[args.num_exits-1]
@@ -692,6 +670,31 @@ def earlyexit_validate_loss(output, target, criterion, args):
                                         torch.full([1], target[batch_index], dtype=torch.long))
            args.exit_taken[exitnum] += 1

def earlyexit_validate_stats(args):
    # Print some interesting summary stats for number of data points that could exit early
    top1k_stats = [0] * args.num_exits
    top5k_stats = [0] * args.num_exits
    losses_exits_stats = [0] * args.num_exits
    sum_exit_stats = 0
    for exitnum in range(args.num_exits):
        if args.exit_taken[exitnum]:
            sum_exit_stats += args.exit_taken[exitnum]
            msglogger.info("Exit %d: %d", exitnum, args.exit_taken[exitnum])
            top1k_stats[exitnum] += args.exiterrors[exitnum].value(1)
            top5k_stats[exitnum] += args.exiterrors[exitnum].value(5)
            losses_exits_stats[exitnum] += args.losses_exits[exitnum].mean
    for exitnum in range(args.num_exits):
        if args.exit_taken[exitnum]:
            msglogger.info("Percent Early Exit %d: %.3f", exitnum,
                           (args.exit_taken[exitnum]*100.0) / sum_exit_stats)
    # Overall accuracy is each exit's accuracy weighted by the fraction of samples that took that exit
    total_top1 = 0
    total_top5 = 0
    for exitnum in range(args.num_exits):
        total_top1 += (top1k_stats[exitnum] * (args.exit_taken[exitnum] / sum_exit_stats))
        total_top5 += (top5k_stats[exitnum] * (args.exit_taken[exitnum] / sum_exit_stats))
        msglogger.info("Accuracy Stats for exit %d: top1 = %.3f, top5 = %.3f", exitnum, top1k_stats[exitnum], top5k_stats[exitnum])
    msglogger.info("Totals for entire network with early exits: top1 = %.3f, top5 = %.3f", total_top1, total_top5)
    return total_top1, total_top5, losses_exits_stats

def evaluate_model(model, criterion, test_loader, loggers, activations_collectors, args):
    # This sample application can be invoked to evaluate the accuracy of your model on