diff --git a/distiller/early_exit.py b/distiller/early_exit.py
index 721cf7f93c7756e2d3038f6d82b60307478fe383..07d7bf7661ffb058a0bf6e2aecbfab8723f84f6d 100644
--- a/distiller/early_exit.py
+++ b/distiller/early_exit.py
@@ -17,7 +17,7 @@
 
 __all__ = ["EarlyExitMgr"]
 
-
+import torch.nn as nn
 from distiller.modules import BranchPoint
 
 
@@ -46,9 +46,8 @@ class EarlyExitMgr(object):
         """
         outputs = []
         for exit_point in self.exit_points:
-            parent_name, node_name = _split_module_name(exit_point)
-            parent_module = _find_module(model, parent_name)
-            output = parent_module.__getattr__(node_name).output
+            branch_point = _get_branch_point_module(model, exit_point)
+            output = branch_point.output
             assert output is not None
             outputs.append(output)
         return outputs
@@ -61,9 +60,7 @@ class EarlyExitMgr(object):
         """
         outputs = []
         for exit_point in self.exit_points:
-            parent_name, node_name = _split_module_name(exit_point)
-            parent_module = _find_module(model, parent_name)
-            branch_point = parent_module.__getattr__(node_name)
+            branch_point = _get_branch_point_module(model, exit_point)
             branch_point.output = None
         return outputs
 
@@ -81,3 +78,18 @@ def _split_module_name(mod_name):
     parent = '.'.join(name_parts[:-1])
     node = name_parts[-1]
     return parent, node
+
+
+def _get_branch_point_module(model, exit_point):
+    parent_name, node_name = _split_module_name(exit_point)
+    parent_module = _find_module(model, parent_name)
+    try:
+        branch_point = parent_module.__getattr__(node_name)
+    except AttributeError:
+        # This handles the case where the parent module was data-paralleled after model creation
+        if isinstance(parent_module, nn.DataParallel):
+            branch_point = parent_module.module.__getattr__(node_name)
+            assert isinstance(branch_point, BranchPoint)
+        else:
+            raise
+    return branch_point