Commit 3876a912 authored by Haim Barad, committed by Neta Zmora
Fixed validation stats and added new summary stats (#71)

* Fixed validation stats and added new summary stats

* Trimmed some comments.

* Improved figure for documentation

* Minor updates
parent 60a4f44a
@@ -24,9 +24,9 @@ There are other benefits to adding exits in that training the modified network n
 There are two parameters that are required to enable early exit. Leave them undefined if you are not enabling Early Exit:
 1. **--earlyexit_thresholds** defines the
-thresholds for each of the early exits. The cross entropy measure must be **less than** the specified threshold to take a specific exit, otherwise the data continues along the regular path. For example, you could specify "--earlyexit_thresholds 0.9 1.2" and this would imply two early exits with corresponding thresholds of 0.9 and 1.2, respectively to take that exit.
+thresholds for each of the early exits. The cross entropy measure must be **less than** the specified threshold to take a specific exit, otherwise the data continues along the regular path. For example, you could specify "--earlyexit_thresholds 0.9 1.2" and this implies two early exits with corresponding thresholds of 0.9 and 1.2, respectively, to take those exits.
-2. **--earlyexit_lossweights** provide the weights for the linear combination of losses during training to compute a signle, overall loss. We only specify weights for the early exits and assume that the sum of the weights (including final exit) are equal to 1.0. So an example of "--earlyexit_lossweights 0.2 0.3" implies two early exits weighted with values of 0.2 and 0.3, respectively and that the final exit has a value of 1.0-(0.2+0.3) = 0.5. Studies have shown that weighting the early exits more heavily will create more agressive early exits, but perhaps with a slight negative effect on accuracy.
+1. **--earlyexit_lossweights** provides the weights for the linear combination of losses during training to compute a single, overall loss. We specify weights only for the early exits and assume that the sum of all the weights (including the final exit) equals 1.0. So, for example, "--earlyexit_lossweights 0.2 0.3" implies two early exits weighted 0.2 and 0.3, respectively, with the final exit weighted 1.0 - (0.2 + 0.3) = 0.5. Studies have shown that weighting the early exits more heavily creates more aggressive early exits, but perhaps with a slight negative effect on accuracy.
 ### CIFAR10
 In the case of CIFAR10, we have inserted a single exit after the first full layer grouping. The exit path itself includes a convolutional layer and a fully connected layer. If you move the exit, be sure to match the proper sizes for inputs and outputs to the exit layers.
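The loss-weight arithmetic described above reduces to a simple linear combination. Here is a minimal sketch of it; the variable names and the per-exit loss values are made up for illustration:

```python
# A minimal sketch of the weighted training loss described above. With
# --earlyexit_lossweights 0.2 0.3, the final exit's weight is implicit:
# 1.0 - (0.2 + 0.3) = 0.5. The loss values here are invented.
earlyexit_lossweights = [0.2, 0.3]
final_weight = 1.0 - sum(earlyexit_lossweights)          # 0.5

exit_losses = [1.40, 1.10, 0.70]   # made-up losses: [exit0, exit1, final exit]
overall_loss = (sum(w * l for w, l in zip(earlyexit_lossweights, exit_losses[:-1]))
                + final_weight * exit_losses[-1])
print(overall_loss)   # 0.2*1.40 + 0.3*1.10 + 0.5*0.70 = 0.96
```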
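The CIFAR10 paragraph above describes the exit path's structure (one convolutional layer plus one fully connected layer) and warns about matching input/output sizes. The following is an illustrative sketch only, assuming a 16-channel 32x32 feature map after the first layer grouping; it is not the repository's actual exit module:

```python
import torch.nn as nn

# Hypothetical early-exit branch: one conv layer and one fully connected
# layer, as described above. The channel count and spatial size are
# assumptions; if you move the exit, recompute them to match.
class ExitBranch(nn.Module):
    def __init__(self, in_channels=16, spatial=32, num_classes=10):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.pool = nn.AvgPool2d(4)                        # 32x32 -> 8x8
        self.fc = nn.Linear(in_channels * (spatial // 4) ** 2, num_classes)

    def forward(self, x):
        x = self.pool(self.conv(x))                        # [N, 16, 8, 8]
        return self.fc(x.view(x.size(0), -1))             # [N, num_classes]
```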
docs-src/docs/imgs/decision_boundary.png: image updated (274 KiB -> 191 KiB)
@@ -580,8 +580,6 @@ def _validate(data_loader, model, criterion, loggers, args, epoch=-1):
             if args.display_confusion:
                 confusion.add(output.data, target)
         else:
-            # If using Early Exit, then compute outputs at all exits - output is now a list of all exits
-            # from exit0 through exitN (i.e. [exit0, exit1, ... exitN])
             earlyexit_validate_loss(output, target, criterion, args)

         # measure elapsed time
@@ -637,8 +635,14 @@ def _validate(data_loader, model, criterion, loggers, args, epoch=-1):
         if args.exit_taken[exitnum]:
             msglogger.info("Percent Early Exit %d: %.3f", exitnum,
                            (args.exit_taken[exitnum]*100.0) / sum_exit_stats)
-    return top1k_stats[args.num_exits-1], top5k_stats[args.num_exits-1], losses_exits_stats[args.num_exits-1]
+    total_top1 = 0
+    total_top5 = 0
+    for exitnum in range(args.num_exits):
+        total_top1 += (top1k_stats[exitnum] * (args.exit_taken[exitnum] / sum_exit_stats))
+        total_top5 += (top5k_stats[exitnum] * (args.exit_taken[exitnum] / sum_exit_stats))
+        msglogger.info("Accuracy Stats for exit %d: top1 = %.3f, top5 = %.3f", exitnum, top1k_stats[exitnum], top5k_stats[exitnum])
+    msglogger.info("Totals for entire network with early exits: top1 = %.3f, top5 = %.3f", total_top1, total_top5)
+    return total_top1, total_top5, losses_exits_stats[args.num_exits-1]

 def earlyexit_loss(output, target, criterion, args):
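This hunk changes what `_validate` returns: instead of reporting only the final exit's accuracy, the new summary statistic is each exit's accuracy weighted by the fraction of samples that actually took that exit. A standalone sketch of that arithmetic, with made-up numbers (`overall_accuracy` is an illustrative helper, not a repository function):

```python
# Overall accuracy = sum over exits of (exit accuracy * fraction of samples
# that took that exit), mirroring the total_top1/total_top5 computation above.
def overall_accuracy(per_exit_accuracy, exit_taken):
    total = sum(exit_taken)
    return sum(acc * taken / total
               for acc, taken in zip(per_exit_accuracy, exit_taken))

# Example: exit 0 handled 300 of 1000 samples at 88% top-1; the final exit
# handled the remaining 700 at 94% top-1.
print(overall_accuracy([88.0, 94.0], [300, 700]))   # 0.3*88 + 0.7*94 = 92.2
```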
@@ -655,28 +659,35 @@ def earlyexit_loss(output, target, criterion, args):
 def earlyexit_validate_loss(output, target, criterion, args):
+    # We need to go through each sample in the batch itself - in other words, we are
+    # not doing batch processing for exit criteria - we do this as though it were batchsize of 1
+    # but with a grouping of samples equal to the batch size.
+    # Note that final group might not be a full batch - so determine actual size.
+    this_batch_size = target.size()[0]
+    earlyexit_validate_criterion = nn.CrossEntropyLoss(reduction='none').cuda()
+
     for exitnum in range(args.num_exits):
-        args.loss_exits[exitnum] = criterion(output[exitnum], target)
-        args.losses_exits[exitnum].add(args.loss_exits[exitnum].item())
+        # calculate losses at each sample separately in the minibatch.
+        args.loss_exits[exitnum] = earlyexit_validate_criterion(output[exitnum], target)
+        # for batch_size > 1, we need to reduce this down to an average over the batch
+        args.losses_exits[exitnum].add(torch.mean(args.loss_exits[exitnum]))

-    # We need to go through this batch itself - this is now a vector of losses through the batch.
-    # Collecting stats on which exit early can be done across the batch at this time.
-    # Note that we can't use batch_size as last batch might be smaller
-    this_batch_size = target.size()[0]
-    for batchnum in range(this_batch_size):
+    for batch_index in range(this_batch_size):
+        earlyexit_taken = False
         # take the exit using CrossEntropyLoss as confidence measure (lower is more confident)
-        for exitnum in range(args.num_exits-1):
-            if args.loss_exits[exitnum].item() < args.earlyexit_thresholds[exitnum]:
+        for exitnum in range(args.num_exits - 1):
+            if args.loss_exits[exitnum][batch_index] < args.earlyexit_thresholds[exitnum]:
                 # take the results from early exit since lower than threshold
-                args.exiterrors[exitnum].add(torch.tensor(np.array(output[exitnum].data[batchnum], ndmin=2)),
-                                             torch.full([1], target[batchnum], dtype=torch.long))
+                args.exiterrors[exitnum].add(torch.tensor(np.array(output[exitnum].data[batch_index], ndmin=2)),
+                                             torch.full([1], target[batch_index], dtype=torch.long))
                 args.exit_taken[exitnum] += 1
-            else:
-                # skip the early exits and include results from end of net
-                args.exiterrors[args.num_exits-1].add(torch.tensor(np.array(output[args.num_exits-1].data[batchnum],
-                                                                            ndmin=2)),
-                                                      torch.full([1], target[batchnum], dtype=torch.long))
-                args.exit_taken[args.num_exits-1] += 1
+                earlyexit_taken = True
+                break  # since exit was taken, do not affect the stats of subsequent exits
+        # this sample does not exit early and therefore continues until final exit
+        if not earlyexit_taken:
+            exitnum = args.num_exits - 1
+            args.exiterrors[exitnum].add(torch.tensor(np.array(output[exitnum].data[batch_index], ndmin=2)),
+                                         torch.full([1], target[batch_index], dtype=torch.long))
+            args.exit_taken[exitnum] += 1

 def evaluate_model(model, criterion, test_loader, loggers, activations_collectors, args):
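The substantive fix in this hunk is twofold: losses are now computed per sample (`nn.CrossEntropyLoss(reduction='none')`) instead of averaged over the batch, and a sample that takes an early exit now `break`s out of the loop, so it no longer also increments the final exit's counters on every remaining iteration. A standalone sketch of that per-sample decision rule (`choose_exit` and the toy tensors are illustrative, not repository code):

```python
import torch
import torch.nn as nn

# reduction='none' keeps one cross-entropy loss per sample, as in the patch.
criterion = nn.CrossEntropyLoss(reduction='none')

def choose_exit(sample_losses, thresholds):
    """sample_losses: one loss per exit for a single sample, ordered
    [exit0, ..., final]. Returns the index of the exit the sample takes."""
    for exitnum, threshold in enumerate(thresholds):
        if sample_losses[exitnum] < threshold:
            return exitnum              # first sufficiently-confident exit wins
    return len(sample_losses) - 1       # no early exit fired: run the full net

# Toy example: 3 exits, batch of 4 samples, 10 classes.
outputs = [torch.randn(4, 10) for _ in range(3)]
target = torch.randint(0, 10, (4,))
losses = torch.stack([criterion(out, target) for out in outputs])   # [3, 4]
for i in range(target.size(0)):
    print("sample", i, "takes exit", choose_exit(losses[:, i], [0.9, 1.2]))
```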