diff --git a/docs-src/docs/earlyexit.md b/docs-src/docs/earlyexit.md
index f36d25ed5b32755c4d2d2f2ffca248411ceb9114..3ba88c5496a1f2c19834f6952247b3376c25e317 100644
--- a/docs-src/docs/earlyexit.md
+++ b/docs-src/docs/earlyexit.md
@@ -24,9 +24,43 @@ There are other benefits to adding exits in that training the modified network n
 There are two parameters that are required to enable early exit. Leave them undefined if you are not enabling Early Exit:
 
 1. **--earlyexit_thresholds** defines the
-thresholds for each of the early exits. The cross entropy measure must be **less than** the specified threshold to take a specific exit, otherwise the data continues along the regular path. For example, you could specify "--earlyexit_thresholds 0.9 1.2" and this would imply two early exits with corresponding thresholds of 0.9 and 1.2, respectively to take that exit.
+thresholds for each of the early exits. The cross entropy measure must be **less than** the specified threshold to take a specific exit; otherwise, the data continues along the regular path. For example, specifying "--earlyexit_thresholds 0.9 1.2" implies two early exits with corresponding thresholds of 0.9 and 1.2, respectively.
 
-2. **--earlyexit_lossweights** provide the weights for the linear combination of losses during training to compute a signle, overall loss. We only specify weights for the early exits and assume that the sum of the weights (including final exit) are equal to 1.0. So an example of "--earlyexit_lossweights 0.2 0.3" implies two early exits weighted with values of 0.2 and 0.3, respectively and that the final exit has a value of 1.0-(0.2+0.3) = 0.5. Studies have shown that weighting the early exits more heavily will create more agressive early exits, but perhaps with a slight negative effect on accuracy.
+1. **--earlyexit_lossweights** provides the weights for the linear combination of losses during training to compute a single, overall loss. We only specify weights for the early exits and assume that the sum of the weights (including the final exit) is equal to 1.0. So, for example, "--earlyexit_lossweights 0.2 0.3" implies two early exits weighted with values of 0.2 and 0.3, respectively, and that the final exit has a weight of 1.0-(0.2+0.3) = 0.5 (see the sketch after this list). Studies have shown that weighting the early exits more heavily creates more aggressive early exits, but perhaps with a slight negative effect on accuracy.
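+
+As a minimal sketch of how these two flags interact (all numbers illustrative, and
+this is not Distiller's actual training code), the final exit's implicit weight and
+the combined loss can be computed like so:
+
+```python
+# Hedged sketch: combining per-exit losses given "--earlyexit_lossweights 0.2 0.3"
+# (two early exits plus the final exit).
+lossweights = [0.2, 0.3]               # weights supplied on the command line
+final_weight = 1.0 - sum(lossweights)  # the final exit gets the remainder: 0.5
+exit_losses = [1.7, 1.1, 0.6]          # hypothetical per-exit cross-entropy losses
+overall_loss = sum(w * l for w, l in zip(lossweights + [final_weight], exit_losses))
+print(overall_loss)                    # 0.2*1.7 + 0.3*1.1 + 0.5*0.6 = 0.97
+```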
 
 ### CIFAR10
-In the case of CIFAR10, we have inserted a single exit after the first full layer grouping. The layers on the exit path itself includes a convolutional layer and a fully connected layer. If you move the exit, be sure to match the proper sizes for inputs and outputs to the exit layers.
+In the case of CIFAR10, we have inserted a single exit after the first full layer grouping. The layers on the exit path itself include a convolutional layer and a fully connected layer. If you move the exit, be sure to match the proper sizes for inputs and outputs to the exit layers.
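+
+A hedged sketch of what such an exit branch might look like (a hypothetical `ExitBranch`
+module; channel counts and layer sizes are assumptions, not Distiller's exact definition):
+
+```python
+import torch.nn as nn
+import torch.nn.functional as F
+
+class ExitBranch(nn.Module):
+    """Hypothetical early-exit path: one conv layer feeding a fully connected classifier."""
+    def __init__(self, in_channels=16, num_classes=10):
+        super().__init__()
+        self.conv = nn.Conv2d(in_channels, 32, kernel_size=3, padding=1)
+        self.pool = nn.AvgPool2d(kernel_size=32)  # collapse the 32x32 feature map to 1x1
+        self.fc = nn.Linear(32, num_classes)
+
+    def forward(self, x):
+        x = self.pool(F.relu(self.conv(x)))
+        return self.fc(x.view(x.size(0), -1))
+```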
diff --git a/docs-src/docs/imgs/decision_boundary.png b/docs-src/docs/imgs/decision_boundary.png
index a22c4c42c20cd31df791354bbc012655359d74d9..54c6da7e295985ff1096a3e27e21946f9efd606b 100644
Binary files a/docs-src/docs/imgs/decision_boundary.png and b/docs-src/docs/imgs/decision_boundary.png differ
diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py
index d0663c6b22bb891f391e6f6db44780295ece426d..6c5ab5903b45d1e677e590daa7c39fa7045d7279 100755
--- a/examples/classifier_compression/compress_classifier.py
+++ b/examples/classifier_compression/compress_classifier.py
@@ -580,8 +580,6 @@ def _validate(data_loader, model, criterion, loggers, args, epoch=-1):
                 if args.display_confusion:
                     confusion.add(output.data, target)
             else:
-                # If using Early Exit, then compute outputs at all exits - output is now a list of all exits
-                # from exit0 through exitN (i.e. [exit0, exit1, ... exitN])
                 earlyexit_validate_loss(output, target, criterion, args)
 
             # measure elapsed time
@@ -637,8 +635,17 @@
             if args.exit_taken[exitnum]:
                 msglogger.info("Percent Early Exit %d: %.3f", exitnum,
                                (args.exit_taken[exitnum]*100.0) / sum_exit_stats)
-
-        return top1k_stats[args.num_exits-1], top5k_stats[args.num_exits-1], losses_exits_stats[args.num_exits-1]
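+        # Overall accuracy is the exit-weighted average: each exit's top1/top5 scores
+        # contribute in proportion to the fraction of samples that took that exit.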
+        total_top1 = 0
+        total_top5 = 0
+        for exitnum in range(args.num_exits):
+            total_top1 += (top1k_stats[exitnum] * (args.exit_taken[exitnum] / sum_exit_stats))
+            total_top5 += (top5k_stats[exitnum] * (args.exit_taken[exitnum] / sum_exit_stats))
+            msglogger.info("Accuracy Stats for exit %d: top1 = %.3f, top5 = %.3f", exitnum, top1k_stats[exitnum], top5k_stats[exitnum])
+        msglogger.info("Totals for entire network with early exits: top1 = %.3f, top5 = %.3f", total_top1, total_top5)
+        return total_top1, total_top5, losses_exits_stats[args.num_exits-1]
 
 
 def earlyexit_loss(output, target, criterion, args):
@@ -655,28 +662,37 @@
 
 
 def earlyexit_validate_loss(output, target, criterion, args):
+    # We need to go through each sample in the batch individually: exit criteria are
+    # evaluated per sample rather than per batch, as though the batch size were 1.
+    # Note that the final batch might not be full, so determine the actual size
+    # from the target tensor rather than assuming a fixed batch size.
+    this_batch_size = target.size()[0]
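+    # reduction='none' makes the criterion return a per-sample loss vector rather than
+    # a batch average, so each sample's exit threshold can be evaluated individually.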
+    earlyexit_validate_criterion = nn.CrossEntropyLoss(reduction='none').cuda()
     for exitnum in range(args.num_exits):
-        args.loss_exits[exitnum] = criterion(output[exitnum], target)
-        args.losses_exits[exitnum].add(args.loss_exits[exitnum].item())
+        # calculate the loss for each sample in the minibatch separately
+        args.loss_exits[exitnum] = earlyexit_validate_criterion(output[exitnum], target)
+        # for batch_size > 1, reduce the per-sample losses to a batch average for the running stats
+        args.losses_exits[exitnum].add(torch.mean(args.loss_exits[exitnum]).item())
 
-    # We need to go through this batch itself - this is now a vector of losses through the batch.
-    # Collecting stats on which exit early can be done across the batch at this time.
-    # Note that we can't use batch_size as last batch might be smaller
-    this_batch_size = target.size()[0]
-    for batchnum in range(this_batch_size):
+    for batch_index in range(this_batch_size):
+        earlyexit_taken = False
         # take the exit using CrossEntropyLoss as confidence measure (lower is more confident)
-        for exitnum in range(args.num_exits-1):
-            if args.loss_exits[exitnum].item() < args.earlyexit_thresholds[exitnum]:
+        for exitnum in range(args.num_exits - 1):
+            if args.loss_exits[exitnum][batch_index] < args.earlyexit_thresholds[exitnum]:
                 # take the results from early exit since lower than threshold
-                args.exiterrors[exitnum].add(torch.tensor(np.array(output[exitnum].data[batchnum], ndmin=2)),
-                        torch.full([1], target[batchnum], dtype=torch.long))
+                args.exiterrors[exitnum].add(torch.tensor(np.array(output[exitnum].data[batch_index].cpu(), ndmin=2)),
+                        torch.full([1], target[batch_index], dtype=torch.long))
                 args.exit_taken[exitnum] += 1
-            else:
-                # skip the early exits and include results from end of net
-                args.exiterrors[args.num_exits-1].add(torch.tensor(np.array(output[args.num_exits-1].data[batchnum],
-                                                                            ndmin=2)),
-                        torch.full([1], target[batchnum], dtype=torch.long))
-                args.exit_taken[args.num_exits-1] += 1
+                earlyexit_taken = True
+                break  # since exit was taken, do not affect the stats of subsequent exits
+        # this sample does not exit early and therefore continues until the final exit
+        if not earlyexit_taken:
+            exitnum = args.num_exits - 1
+            args.exiterrors[exitnum].add(torch.tensor(np.array(output[exitnum].data[batch_index].cpu(), ndmin=2)),
+                    torch.full([1], target[batch_index], dtype=torch.long))
+            args.exit_taken[exitnum] += 1
 
 
 def evaluate_model(model, criterion, test_loader, loggers, activations_collectors, args):