diff --git a/docs-src/docs/earlyexit.md b/docs-src/docs/earlyexit.md
new file mode 100644
index 0000000000000000000000000000000000000000..f36d25ed5b32755c4d2d2f2ffca248411ceb9114
--- /dev/null
+++ b/docs-src/docs/earlyexit.md
@@ -0,0 +1,42 @@
+# Early Exit Inference
+While Deep Neural Networks benefit from a large number of layers, it's often the case that many datapoints in classification tasks can be classified accurately with much less work. There have been several studies recently regarding the idea of exiting before the normal endpoint of the neural network. Panda et al in [Conditional Deep Learning for Energy-Efficient and Enhanced Pattern Recognition](#panda) points out that a lot of data points can be classified easily and require less processing than some more difficult points and they view this in terms of power savings. Surat et al in [BranchyNet: Fast Inference via Early Exiting from Deep Neural Networks](#branchynet) look at a selective approach to exit placement and criteria for exiting early.
+
+## Why Does Early Exit Work?
+Early Exit is a strategy with a straightforward and easy to understand concept Figure #fig(boundaries) shows a simple example in a 2-D feature space. While deep networks can representative more complex and expressive boundaries between classes (assuming we’re confident of avoiding over-fitting the data), it’s also clear that much of the data can be properly classified with even the simplest of classification boundaries.
+
+![Figure !fig(boundaries): Simple and more expressive classification boundaries](/docs-src/docs/imgs/decision_boundary.png)
+
+Data points far from the boundary can be considered "easy to classify" and achieve a high degree of confidence quicker than do data points close to the boundary. In fact, we can think of the area between the outer straight lines as being the region that is "difficult to classify" and require the full expressiveness of the neural network to accurately classify it.
+
+## Example code for Early Exit
+Both CIFAR10 and Imagenet code comes directly from publically available examples from Pytorch. The only edits are the exits that are inserted in a methodology similar to BranchyNet work.
+
+Deeper networks can benefit from multiple exits. Our examples illustrate both a single and a pair of early exits for CIFAR10 and Imagenet, respectively.
+
+Note that this code does not actually take exits. What it does is to compute statistics of loss and accuracy assuming exits were taken when criteria are met. Actually implementing exits can be tricky and architecture dependent and we plan to address these issues.
+
+### Heuristics
+The insertion of the exits are ad-hoc, but there are some heuristic principals guiding their placement and parameters. The earlier exits are placed, the more agressive the exit as it essentially prunes the rest of the network at a very early stage, thus saving a lot of work. However, a diminishing percentage of data will be directed through the exit if we are to preserve accuracy.
+
+There are other benefits to adding exits in that training the modified network now has backpropagation losses coming from the exits that affect the earlier layers more substantially than the last exit. This effect mitigates problems such as vanishing gradient.
+
+### Early Exit Hyperparameters
+There are two parameters that are required to enable early exit. Leave them undefined if you are not enabling Early Exit:
+
+1. **--earlyexit_thresholds** defines the
+thresholds for each of the early exits. The cross entropy measure must be **less than** the specified threshold to take a specific exit, otherwise the data continues along the regular path. For example, you could specify "--earlyexit_thresholds 0.9 1.2" and this would imply two early exits with corresponding thresholds of 0.9 and 1.2, respectively to take that exit.
+
+2. **--earlyexit_lossweights** provide the weights for the linear combination of losses during training to compute a signle, overall loss. We only specify weights for the early exits and assume that the sum of the weights (including final exit) are equal to 1.0. So an example of "--earlyexit_lossweights 0.2 0.3" implies two early exits weighted with values of 0.2 and 0.3, respectively and that the final exit has a value of 1.0-(0.2+0.3) = 0.5. Studies have shown that weighting the early exits more heavily will create more agressive early exits, but perhaps with a slight negative effect on accuracy.
+
+### CIFAR10
+In the case of CIFAR10, we have inserted a single exit after the first full layer grouping. The layers on the exit path itself includes a convolutional layer and a fully connected layer. If you move the exit, be sure to match the proper sizes for inputs and outputs to the exit layers.
+
+### Imagenet
+This supports training and inference of the imagenet dataset via several well known deep architectures. ResNet-50 is the architecture of interest in this study, however the exit is defined in the generic resnet code and could be used with other size resnets. There are two exits inserted in this example. Again, exit layers must have their sizes match properly.
+
+## References
+<div id="panda"></div> **Priyadarshini Panda, Abhronil Sengupta, Kaushik Roy**.
+    [*Conditional Deep Learning for Energy-Efficient and Enhanced Pattern Recognition*](https://arxiv.org/abs/1509.08971v6), arXiv:1509.08971v6, 2017.
+
+<div id="branchynet"></div> **Surat Teerapittayanon, Bradley McDanel, H. T. Kung**.
+    [*BranchyNet: Fast Inference via Early Exiting from Deep Neural Networks*](http://arxiv.org/abs/1709.01686), arXiv:1709.01686, 2017.
diff --git a/docs-src/docs/imgs/decision_boundary.png b/docs-src/docs/imgs/decision_boundary.png
new file mode 100644
index 0000000000000000000000000000000000000000..a22c4c42c20cd31df791354bbc012655359d74d9
Binary files /dev/null and b/docs-src/docs/imgs/decision_boundary.png differ
diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py
index 55f9b19f9a44faddfbf9bda2f43cc83149029a1f..02e86e2aa61f9a194612dc0151a064f4ceea6e66 100755
--- a/examples/classifier_compression/compress_classifier.py
+++ b/examples/classifier_compression/compress_classifier.py
@@ -145,6 +145,8 @@ parser.add_argument('--validation-size', '--vs', type=float_range, default=0.1,
 parser.add_argument('--adc', dest='ADC', action='store_true', help='temp HACK')
 parser.add_argument('--adc-params', dest='ADC_params', default=None, help='temp HACK')
 parser.add_argument('--confusion', dest='display_confusion', default=False, action='store_true', help='Display the confusion matrix')
+parser.add_argument('--earlyexit_lossweights', type=float, nargs='*', dest='earlyexit_lossweights', default=None, help='List of loss weights for early exits (e.g. --lossweights 0.1 0.3)')
+parser.add_argument('--earlyexit_thresholds', type=float, nargs='*', dest='earlyexit_thresholds', default=None, help='List of EarlyExit thresholds (e.g. --earlyexit 1.2 0.9)')
 
 
 def check_pytorch_version():
@@ -213,6 +215,12 @@ def main():
     args.dataset = 'cifar10' if 'cifar' in args.arch else 'imagenet'
     args.num_classes = 10 if args.dataset == 'cifar10' else 1000
 
+    if args.earlyexit_thresholds:
+        args.num_exits = len(args.earlyexit_thresholds) + 1
+        args.loss_exits = [0] * args.num_exits
+        args.losses_exits = []
+        args.exiterrors = []
+
     # Create the model
     model = create_model(args.pretrained, args.dataset, args.arch, device_ids=args.gpus)
     compression_scheduler = None
@@ -221,6 +229,10 @@ def main():
     tflogger = TensorBoardLogger(msglogger.logdir)
     pylogger = PythonLogger(msglogger)
 
+    # capture thresholds for early-exit training
+    if args.earlyexit_thresholds:
+        msglogger.info('=> using early-exit threshold values of %s', args.earlyexit_thresholds)
+
     # We can optionally resume from a checkpoint
     if args.resume:
         model, compression_scheduler, start_epoch = apputils.load_checkpoint(
@@ -290,8 +302,7 @@ def main():
                  OrderedDict([('Loss', vloss),
                               ('Top1', top1),
                               ('Top5', top5)]))
-        distiller.log_training_progress(stats, None, epoch, steps_completed=0, total_steps=1,
-                                        log_freq=1, loggers=[tflogger])
+        distiller.log_training_progress(stats, None, epoch, steps_completed=0, total_steps=1, log_freq=1, loggers=[tflogger])
 
         if compression_scheduler:
             compression_scheduler.on_epoch_end(epoch, optimizer)
@@ -321,6 +332,11 @@ def train(train_loader, model, criterion, optimizer, epoch,
     classerr = tnt.ClassErrorMeter(accuracy=True, topk=(1, 5))
     batch_time = tnt.AverageValueMeter()
     data_time = tnt.AverageValueMeter()
+    # For Early Exit, we define statistics for each exit - so exiterrors is analogous to classerr for the non-Early Exit case
+    if args.earlyexit_lossweights:
+        args.exiterrors = []
+        for exitnum in range(args.num_exits):
+            args.exiterrors.append(tnt.ClassErrorMeter(accuracy=True, topk=(1, 5)))
 
     total_samples = len(train_loader.sampler)
     batch_size = train_loader.batch_size
@@ -342,11 +358,16 @@ def train(train_loader, model, criterion, optimizer, epoch,
         # Execute the forward phase, compute the output and measure loss
         if compression_scheduler:
             compression_scheduler.on_minibatch_begin(epoch, train_step, steps_per_epoch, optimizer)
+
         output = model(input_var)
-        loss = criterion(output, target_var)
+        if not args.earlyexit_lossweights:
+            loss = criterion(output, target_var)
+            # Measure accuracy and record loss
+            classerr.add(output.data, target)
+        else:
+            # Measure accuracy and record loss
+            loss = earlyexit_loss(output, target_var, criterion, args)
 
-        # Measure accuracy and record loss
-        classerr.add(output.data, target)
         losses['objective_loss'].add(loss.item())
 
         if compression_scheduler:
@@ -369,14 +390,27 @@ def train(train_loader, model, criterion, optimizer, epoch,
         if steps_completed % args.print_freq == 0:
             # Log some statistics
             lr = optimizer.param_groups[0]['lr']
-            stats = ('Peformance/Training/',
-                     OrderedDict([
-                         ('Loss', losses['objective_loss'].mean),
-                         ('Reg Loss', losses['regularizer_loss'].mean),
-                         ('Top1', classerr.value(1)),
-                         ('Top5', classerr.value(5)),
-                         ('LR', lr),
-                         ('Time', batch_time.mean)]))
+            if not args.earlyexit_lossweights:
+                stats = ('Peformance/Training/',
+                         OrderedDict([
+                             ('Loss', losses['objective_loss'].mean),
+                             ('Reg Loss', losses['regularizer_loss'].mean),
+                             ('Top1', classerr.value(1)),
+                             ('Top5', classerr.value(5)),
+                             ('LR', lr),
+                             ('Time', batch_time.mean)]))
+            else:
+                # for Early Exit case, the Top1 and Top5 stats are computed for each exit.
+                stats_dict = OrderedDict()
+                stats_dict['Objective Loss'] = losses['objective_loss'].mean
+                for exitnum in range(args.num_exits):
+                    t1 = 'Top1_exit' + str(exitnum)
+                    t5 = 'Top5_exit' + str(exitnum)
+                    stats_dict[t1] = args.exiterrors[exitnum].value(1)
+                    stats_dict[t5] = args.exiterrors[exitnum].value(5)
+                stats_dict['LR'] = lr
+                stats_dict['Time'] = batch_time.mean
+                stats = ('Peformance/Training/', stats_dict)
 
             params = model.named_parameters() if args.log_params_histograms else None
             distiller.log_training_progress(stats,
@@ -406,6 +440,16 @@ def _validate(data_loader, model, criterion, loggers, args, epoch=-1):
     """Execute the validation/test loop."""
     losses = {'objective_loss': tnt.AverageValueMeter()}
     classerr = tnt.ClassErrorMeter(accuracy=True, topk=(1, 5))
+
+    if args.earlyexit_thresholds:
+        # for Early Exit, we have a list of errors and losses for each of the exits.
+        args.exiterrors = []
+        args.losses_exits = []
+        for exitnum in range(args.num_exits):
+            args.exiterrors.append(tnt.ClassErrorMeter(accuracy=True, topk=(1, 5)))
+            args.losses_exits.append(tnt.AverageValueMeter())
+        args.exit_taken = [0] * args.num_exits
+    
     batch_time = tnt.AverageValueMeter()
     total_samples = len(data_loader.sampler)
     batch_size = data_loader.batch_size
@@ -423,16 +467,21 @@ def _validate(data_loader, model, criterion, loggers, args, epoch=-1):
             target = target.cuda(async=True)
             input_var = get_inference_var(inputs)
             target_var = get_inference_var(target)
-
-            # compute output
+            # compute output from model
             output = model(input_var)
-            loss = criterion(output, target_var)
 
-            # measure accuracy and record loss
-            losses['objective_loss'].add(loss.item())
-            classerr.add(output.data, target)
-            if args.display_confusion:
-                confusion.add(output.data, target)
+            if not args.earlyexit_thresholds:
+                # compute loss
+                loss = criterion(output, target_var)
+                # measure accuracy and record loss
+                losses['objective_loss'].add(loss.item())
+                classerr.add(output.data, target)
+                if args.display_confusion:
+                    confusion.add(output.data, target)
+            else:
+                # If using Early Exit, then compute outputs at all exits - output is now a list of all exits
+                # from exit0 through exitN (i.e. [exit0, exit1, ... exitN])
+                earlyexit_validate_loss(output, target_var, criterion, args)
 
             # measure elapsed time
             batch_time.add(time.time() - end)
@@ -440,19 +489,55 @@ def _validate(data_loader, model, criterion, loggers, args, epoch=-1):
 
             steps_completed = (validation_step+1)
             if steps_completed % args.print_freq == 0:
-                stats = ('',
+                if not args.earlyexit_thresholds:
+                    stats = ('',
                          OrderedDict([('Loss', losses['objective_loss'].mean),
                                       ('Top1', classerr.value(1)),
                                       ('Top5', classerr.value(5))]))
+                else:
+                    stats_dict = OrderedDict()
+                    stats_dict['Test'] = validation_step
+                    for exitnum in range(args.num_exits):
+                        la_string = 'LossAvg' + str(exitnum)
+                        stats_dict[la_string] = args.losses_exits[exitnum].mean
+                        # Because of the nature of ClassErrorMeter, if an exit is never taken during the batch,
+                        # then accessing the value(k) will cause a divide by zero. So we'll build the OrderedDict
+                        # accordingly and we will not print for an exit error when that exit is never taken.
+                        if args.exit_taken[exitnum]:
+                            t1 = 'Top1_exit' + str(exitnum)
+                            t5 = 'Top5_exit' + str(exitnum)
+                            stats_dict[t1] = args.exiterrors[exitnum].value(1)
+                            stats_dict[t5] = args.exiterrors[exitnum].value(5)
+                    stats = ('Performance/Validation/', stats_dict)
+
                 distiller.log_training_progress(stats, None, epoch, steps_completed,
                                                 total_steps, args.print_freq, loggers)
-
-    msglogger.info('==> Top1: %.3f    Top5: %.3f    Loss: %.3f\n',
-                   classerr.value()[0], classerr.value()[1], losses['objective_loss'].mean)
-
-    if args.display_confusion:
-        msglogger.info('==> Confusion:\n%s', str(confusion.value()))
-    return classerr.value(1), classerr.value(5), losses['objective_loss'].mean
+    if not args.earlyexit_thresholds:
+        msglogger.info('==> Top1: %.3f    Top5: %.3f    Loss: %.3f\n',
+                       classerr.value()[0], classerr.value()[1], losses['objective_loss'].mean)
+        
+        if args.display_confusion:
+            msglogger.info('==> Confusion:\n%s', str(confusion.value()))
+        return classerr.value(1), classerr.value(5), losses['objective_loss'].mean
+    else:
+        #print some interesting summary stats for number of data points that could exit early
+        top1k_stats = [0] * args.num_exits
+        top5k_stats = [0] * args.num_exits
+        losses_exits_stats = [0] * args.num_exits
+        sum_exit_stats = 0
+        for exitnum in range(args.num_exits):
+            if args.exit_taken[exitnum]:
+                sum_exit_stats += args.exit_taken[exitnum]
+                msglogger.info("Exit %d: %d", exitnum, args.exit_taken[exitnum])
+                top1k_stats[exitnum] += args.exiterrors[exitnum].value(1)
+                top5k_stats[exitnum] += args.exiterrors[exitnum].value(5)
+                losses_exits_stats[exitnum] += args.losses_exits[exitnum].mean
+        for exitnum in range(args.num_exits):
+            if args.exit_taken[exitnum]:
+                msglogger.info("Percent Early Exit %d: %.3f", exitnum, (args.exit_taken[exitnum]*100.0) / sum_exit_stats)
+
+
+        return top1k_stats[args.num_exits-1], top5k_stats[args.num_exits-1], losses_exits_stats[args.num_exits-1]
 
 
 class PytorchNoGrad(object):
@@ -478,6 +563,41 @@ def get_inference_var(tensor):
         return torch.autograd.Variable(tensor)
     return torch.autograd.Variable(tensor, volatile=True)
 
+def earlyexit_loss(output, target_var, criterion, args):
+    loss = 0
+    sum_lossweights = 0
+    for exitnum in range(args.num_exits-1):
+        loss += (args.earlyexit_lossweights[exitnum] * criterion(output[exitnum], target_var))
+        sum_lossweights += args.earlyexit_lossweights[exitnum]
+        args.exiterrors[exitnum].add(output[exitnum].data, target_var)
+    # handle final exit
+    loss += (1.0 - sum_lossweights) * criterion(output[args.num_exits-1], target_var)
+    args.exiterrors[args.num_exits-1].add(output[args.num_exits-1].data, target_var)
+    return loss
+
+def earlyexit_validate_loss(output, target_var, criterion, args):
+    for exitnum in range(args.num_exits):
+        args.loss_exits[exitnum] = criterion(output[exitnum], target_var)
+        args.losses_exits[exitnum].add(args.loss_exits[exitnum].item())
+
+    # We need to go through this batch itself - this is now a vector of losses through the batch.
+    # Collecting stats on which exit early can be done across the batch at this time.
+    # Note that we can't use batch_size as last batch might be smaller
+    this_batch_size = target_var.size()[0]
+    for batchnum in range(this_batch_size):
+        # take the exit using CrossEntropyLoss as confidence measure (lower is more confident)
+        for exitnum in range(args.num_exits-1):
+            if args.loss_exits[exitnum].item() < args.earlyexit_thresholds[exitnum]:
+                # take the results from early exit since lower than threshold
+                args.exiterrors[exitnum].add(torch.tensor(np.array(output[exitnum].data[batchnum], ndmin=2)),
+                        torch.full([1], target_var[batchnum], dtype=torch.long))
+                args.exit_taken[exitnum] += 1
+            else:
+                # skip the early exits and include results from end of net
+                args.exiterrors[args.num_exits-1].add(torch.tensor(np.array(output[args.num_exits-1].data[batchnum], ndmin=2)),
+                        torch.full([1], target_var[batchnum], dtype=torch.long))
+                args.exit_taken[args.num_exits-1] += 1
+
 
 def evaluate_model(model, criterion, test_loader, loggers, args):
     # This sample application can be invoked to evaluate the accuracy of your model on
diff --git a/models/cifar10/__init__.py b/models/cifar10/__init__.py
index e4f636fe642f4c9f7179454a1ce26d1f4ce454c6..32fd20bb7ca286f5714d141f2f801eb1829cba0f 100755
--- a/models/cifar10/__init__.py
+++ b/models/cifar10/__init__.py
@@ -20,3 +20,4 @@ from .simplenet_cifar import *
 from .resnet_cifar import *
 from .preresnet_cifar import *
 from .vgg_cifar import *
+from .resnet_cifar_earlyexit import *
\ No newline at end of file
diff --git a/models/cifar10/resnet_cifar_earlyexit.py b/models/cifar10/resnet_cifar_earlyexit.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4f75d7288d93ff1efcc65cc7bde0332718923e7
--- /dev/null
+++ b/models/cifar10/resnet_cifar_earlyexit.py
@@ -0,0 +1,113 @@
+#
+# Copyright (c) 2018 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Resnet for CIFAR10
+
+Resnet for CIFAR10, based on "Deep Residual Learning for Image Recognition".
+This is based on TorchVision's implementation of ResNet for ImageNet, with appropriate
+changes for the 10-class Cifar-10 dataset.
+This ResNet also has layer gates, to be able to dynamically remove layers.
+
+@inproceedings{DBLP:conf/cvpr/HeZRS16,
+  author    = {Kaiming He and
+               Xiangyu Zhang and
+               Shaoqing Ren and
+               Jian Sun},
+  title     = {Deep Residual Learning for Image Recognition},
+  booktitle = {{CVPR}},
+  pages     = {770--778},
+  publisher = {{IEEE} Computer Society},
+  year      = {2016}
+}
+
+"""
+import torch.nn as nn
+import math
+import torch.utils.model_zoo as model_zoo
+import torchvision.models as models
+from .resnet_cifar import BasicBlock
+from .resnet_cifar import ResNetCifar
+
+
+__all__ = ['resnet20_cifar_earlyexit', 'resnet32_cifar_earlyexit', 'resnet44_cifar_earlyexit',
+    'resnet56_cifar_earlyexit', 'resnet110_cifar_earlyexit', 'resnet1202_cifar_earlyexit']
+
+NUM_CLASSES = 10
+
+def conv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                     padding=1, bias=False)
+
+
+class ResNetCifarEarlyExit(ResNetCifar):
+
+    def __init__(self, block, layers, num_classes=NUM_CLASSES):
+        super(ResNetCifarEarlyExit, self).__init__(block, layers, num_classes)
+
+        # Define early exit layers
+        self.linear_exit0 = nn.Linear(1600, num_classes)
+
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+
+        x = self.layer1(x)
+
+        # Add early exit layers
+        exit0 = nn.functional.avg_pool2d(x, 3)
+        exit0 = exit0.view(exit0.size(0), -1)
+        exit0 = self.linear_exit0(exit0)
+
+        x = self.layer2(x)
+        x = self.layer3(x)
+
+        x = self.avgpool(x)
+        x = x.view(x.size(0), -1)
+        x = self.fc(x)
+
+        # return a list of probabilities
+        output = []
+        output.append(exit0)
+        output.append(x)
+        return output
+
+
+def resnet20_cifar_earlyexit(**kwargs):
+    model = ResNetCifarEarlyExit(BasicBlock, [3, 3, 3], **kwargs)
+    return model
+
+def resnet32_cifar_earlyexit(**kwargs):
+    model = ResNetCifarEarlyExit(BasicBlock, [5, 5, 5], **kwargs)
+    return model
+
+def resnet44_cifar_earlyexit(**kwargs):
+    model = ResNetCifarEarlyExit(BasicBlock, [7, 7, 7], **kwargs)
+    return model
+
+def resnet56_cifar_earlyexit(**kwargs):
+    model = ResNetCifarEarlyExit(BasicBlock, [9, 9, 9], **kwargs)
+    return model
+
+def resnet110_cifar_earlyexit(**kwargs):
+    model = ResNetCifarEarlyExit(BasicBlock, [18, 18, 18], **kwargs)
+    return model
+
+def resnet1202_cifar_earlyexit(**kwargs):
+    model = ResNetCifarEarlyExit(BasicBlock, [200, 200, 200], **kwargs)
+    return model
\ No newline at end of file
diff --git a/models/imagenet/__init__.py b/models/imagenet/__init__.py
index 300ebd50ff354f6555d51a224c7f6f4c91491b36..5ed5d8ca4eeb7b11a7b644db0ea4902c543f85c4 100755
--- a/models/imagenet/__init__.py
+++ b/models/imagenet/__init__.py
@@ -19,3 +19,4 @@
 from .mobilenet import *
 from .preresnet_imagenet import *
 from .alexnet_batchnorm import *
+from .resnet_earlyexit import *
diff --git a/models/imagenet/resnet_earlyexit.py b/models/imagenet/resnet_earlyexit.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4b87423df0bff892afa9fffcaf6d0796a203367
--- /dev/null
+++ b/models/imagenet/resnet_earlyexit.py
@@ -0,0 +1,101 @@
+import torch.nn as nn
+import math
+import torch.utils.model_zoo as model_zoo
+import torchvision.models as models
+from torchvision.models.resnet import Bottleneck
+from torchvision.models.resnet import BasicBlock
+
+
+__all__ = ['resnet18_earlyexit', 'resnet34_earlyexit', 'resnet50_earlyexit', 'resnet101_earlyexit', 'resnet152_earlyexit']
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                     padding=1, bias=False)
+
+
+class ResNetEarlyExit(models.ResNet):
+
+    def __init__(self, block, layers, num_classes=1000):
+        super(ResNetEarlyExit, self).__init__(block, layers, num_classes)
+
+        # Define early exit layers
+        self.conv1_exit0 = nn.Conv2d(256, 50, kernel_size=7, stride=2, padding=3, bias=True)
+        self.conv2_exit0 = nn.Conv2d(50, 12, kernel_size=7, stride=2, padding=3, bias=True)
+        self.conv1_exit1 = nn.Conv2d(512, 12, kernel_size=7, stride=2, padding=3, bias=True)
+        self.fc_exit0 = nn.Linear(147 * block.expansion, num_classes)
+        self.fc_exit1 = nn.Linear(192 * block.expansion, num_classes)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+
+        x = self.layer1(x)
+
+        # Add early exit layers
+        exit0 = self.avgpool(x)
+        exit0 = self.conv1_exit0(exit0)
+        exit0 = self.conv2_exit0(exit0)
+        exit0 = self.avgpool(exit0)
+        exit0 = exit0.view(exit0.size(0), -1)
+        exit0 = self.fc_exit0(exit0)
+
+        x = self.layer2(x)
+
+        # Add early exit layers
+        exit1 = self.conv1_exit1(x)
+        exit1 = self.avgpool(exit1)
+        exit1 = exit1.view(exit1.size(0), -1)
+        exit1 = self.fc_exit1(exit1)
+
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        x = self.avgpool(x)
+        x = x.view(x.size(0), -1)
+        x = self.fc(x)
+
+        # return a list of probabilities
+        output = []
+        output.append(exit0)
+        output.append(exit1)
+        output.append(x)
+        return output
+
+
+def resnet18_earlyexit(**kwargs):
+    """Constructs a ResNet-18 model.
+    """
+    model = ResNetEarlyExit(BasicBlock, [2, 2, 2, 2], **kwargs)
+    return model
+
+
+def resnet34_earlyexit(**kwargs):
+    """Constructs a ResNet-34 model.
+    """
+    model = ResNetEarlyExit(BasicBlock, [3, 4, 6, 3], **kwargs)
+    return model
+
+
+def resnet50_earlyexit(**kwargs):
+    """Constructs a ResNet-50 model.
+    """
+    model = ResNetEarlyExit(Bottleneck, [3, 4, 6, 3], **kwargs)
+    return model
+
+
+def resnet101_earlyexit(**kwargs):
+    """Constructs a ResNet-101 model.
+    """
+    model = ResNetEarlyExit(Bottleneck, [3, 4, 23, 3], **kwargs)
+    return model
+
+
+def resnet152_earlyexit(**kwargs):
+    """Constructs a ResNet-152 model.
+    """
+    model = ResNetEarlyExit(Bottleneck, [3, 8, 36, 3], **kwargs)
+    return model