diff --git a/distiller/models/imagenet/resnet.py b/distiller/models/imagenet/resnet.py
index 6f0fd7ab11ca35492d2432d309dd9083f0b7c2c2..c5784343241006b5bd45788600cbc927421019c6 100755
--- a/distiller/models/imagenet/resnet.py
+++ b/distiller/models/imagenet/resnet.py
@@ -31,7 +31,8 @@ from distiller.modules import EltwiseAdd
 
 
 __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
-           'resnet152', 'resnext50_32x4d', 'resnext101_32x8d']
+           'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
+           'DistillerBottleneck']
 
 
 class DistillerBasicBlock(BasicBlock):
diff --git a/distiller/models/imagenet/resnet_earlyexit.py b/distiller/models/imagenet/resnet_earlyexit.py
index 03fd6c9b165a3f2ed988099374914feae13b7bce..def335c7d3570edab6b3a955361e67213f4e0589 100644
--- a/distiller/models/imagenet/resnet_earlyexit.py
+++ b/distiller/models/imagenet/resnet_earlyexit.py
@@ -1,9 +1,9 @@
 import torch.nn as nn
 import torchvision.models as models
-from torchvision.models.resnet import ResNet, BasicBlock, Bottleneck
 from .resnet import DistillerBottleneck
 import distiller
 
+
 __all__ = ['resnet50_earlyexit']
 
 
@@ -15,18 +15,15 @@ def conv3x3(in_planes, out_planes, stride=1):
 
 def get_exits_def(num_classes):
     expansion = 1 # models.ResNet.BasicBlock.expansion
-    exits_def = [('layer1.2.relu3', nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
-                                                  nn.Conv2d(256, 50, kernel_size=7, stride=2, padding=3, bias=True),
-                                                  nn.Conv2d(50, 12, kernel_size=7, stride=2, padding=3, bias=True),
-                                                  #nn.AdaptiveAvgPool2d((1, 1)),
+    exits_def = [('layer1.2.relu3', nn.Sequential(nn.Conv2d(256, 10, kernel_size=7, stride=2, padding=3, bias=True),
+                                                  nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
                                                   nn.Flatten(),
-                                                  nn.Linear(12 * expansion, num_classes))),
+                                                  nn.Linear(1960, num_classes))),
                                                   #distiller.modules.Print())),
                  ('layer2.3.relu3', nn.Sequential(nn.Conv2d(512, 12, kernel_size=7, stride=2, padding=3, bias=True),
-                                                  nn.AdaptiveAvgPool2d((1, 1)),
+                                                  nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
                                                   nn.Flatten(),
-                                                  #distiller.modules.Print()))]
-                                                  nn.Linear(12 * expansion, num_classes)))]
+                                                  nn.Linear(588, num_classes)))]
     return exits_def