diff --git a/distiller/models/cifar10/resnet_cifar_earlyexit.py b/distiller/models/cifar10/resnet_cifar_earlyexit.py
index c4f75d7288d93ff1efcc65cc7bde0332718923e7..26b67c5bbb91002602c8604ec858d6f95124d29d 100644
--- a/distiller/models/cifar10/resnet_cifar_earlyexit.py
+++ b/distiller/models/cifar10/resnet_cifar_earlyexit.py
@@ -43,7 +43,7 @@ from .resnet_cifar import ResNetCifar
 
 
 __all__ = ['resnet20_cifar_earlyexit', 'resnet32_cifar_earlyexit', 'resnet44_cifar_earlyexit',
-    'resnet56_cifar_earlyexit', 'resnet110_cifar_earlyexit', 'resnet1202_cifar_earlyexit']
+           'resnet56_cifar_earlyexit', 'resnet110_cifar_earlyexit', 'resnet1202_cifar_earlyexit']
 
 NUM_CLASSES = 10
 
@@ -53,38 +53,47 @@ def conv3x3(in_planes, out_planes, stride=1):
                      padding=1, bias=False)
 
 
-class ResNetCifarEarlyExit(ResNetCifar):
+class ExitBranch(nn.Module):  # early-exit head: average-pool -> flatten -> linear
+    def __init__(self, num_classes):
+        super().__init__()
+        self.avg_pool = nn.AvgPool2d(3)  # kernel size 3 (stride defaults to the kernel size)
+        self.linear = nn.Linear(1600, num_classes)  # expects 1600 flattened features per sample
 
-    def __init__(self, block, layers, num_classes=NUM_CLASSES):
-        super(ResNetCifarEarlyExit, self).__init__(block, layers, num_classes)
+    def forward(self, x):
+        x = self.avg_pool(x)
+        x = x.view(x.size(0), -1)  # keep dim 0, flatten everything else for the linear layer
+        x = self.linear(x)
+        return x
 
-        # Define early exit layers
-        self.linear_exit0 = nn.Linear(1600, num_classes)
 
+class BranchPoint(nn.Module):  # runs original_m, then feeds its output to exit_m on the side
+    def __init__(self, original_m, exit_m):
+        super().__init__()
+        self.original_m = original_m  # wrapped module; its output is what forward() returns
+        self.exit_m = exit_m  # exit branch applied to original_m's output
+        self.output = None  # exit_m result of the latest forward(); None before the first call
 
     def forward(self, x):
-        x = self.conv1(x)
-        x = self.bn1(x)
-        x = self.relu(x)
+        x1 = self.original_m(x)  # call the module, not .forward(), so registered hooks run
+        x2 = self.exit_m(x1)
+        self.output = x2  # stashed so the owning model can read it after the forward pass
+        return x1
 
-        x = self.layer1(x)
 
-        # Add early exit layers
-        exit0 = nn.functional.avg_pool2d(x, 3)
-        exit0 = exit0.view(exit0.size(0), -1)
-        exit0 = self.linear_exit0(exit0)
+class ResNetCifarEarlyExit(ResNetCifar):  # ResNetCifar with one early exit fed by layer1's output
+    def __init__(self, block, layers, num_classes=NUM_CLASSES):
+        super().__init__(block, layers, num_classes)
 
-        x = self.layer2(x)
-        x = self.layer3(x)
+        # Define early exit branches and install them
+        self.exit_branch = ExitBranch(num_classes)  # same module is also held as self.layer1.exit_m
+        self.layer1 = BranchPoint(self.layer1, self.exit_branch)  # wrap layer1 with the exit
 
-        x = self.avgpool(x)
-        x = x.view(x.size(0), -1)
-        x = self.fc(x)
+    def forward(self, x):
+        x = super().forward(x)  # presumably calls self.layer1, which records the exit output
 
         # return a list of probabilities
-        output = []
-        output.append(exit0)
-        output.append(x)
+        exit0 = self.layer1.output  # value stored by BranchPoint.forward during the call above
+        output = (exit0, x)  # tuple: (early-exit output, final output)
         return output