From 4ad16ef00f9ea90c0d7834667bf86b12e795c12e Mon Sep 17 00:00:00 2001
From: Guy Jacob <guy.jacob@intel.com>
Date: Tue, 31 Mar 2020 15:52:18 +0300
Subject: [PATCH] EarlyExitMgr: Handle case where model was DataParallel-ed
 after creation

---
 distiller/early_exit.py | 26 +++++++++++++++++++-------
 1 file changed, 19 insertions(+), 7 deletions(-)

diff --git a/distiller/early_exit.py b/distiller/early_exit.py
index 721cf7f..07d7bf7 100644
--- a/distiller/early_exit.py
+++ b/distiller/early_exit.py
@@ -17,7 +17,7 @@
 
 __all__ = ["EarlyExitMgr"]
 
-
+import torch.nn as nn
 from distiller.modules import BranchPoint
 
 
@@ -46,9 +46,8 @@ class EarlyExitMgr(object):
         """
         outputs = []
         for exit_point in self.exit_points:
-            parent_name, node_name = _split_module_name(exit_point)
-            parent_module = _find_module(model, parent_name)
-            output = parent_module.__getattr__(node_name).output
+            branch_point = _get_branch_point_module(model, exit_point)
+            output = branch_point.output
             assert output is not None
             outputs.append(output)
         return outputs
@@ -61,9 +60,7 @@ class EarlyExitMgr(object):
         """
         outputs = []
         for exit_point in self.exit_points:
-            parent_name, node_name = _split_module_name(exit_point)
-            parent_module = _find_module(model, parent_name)
-            branch_point = parent_module.__getattr__(node_name)
+            branch_point = _get_branch_point_module(model, exit_point)
             branch_point.output = None
         return outputs
 
@@ -81,3 +78,18 @@ def _split_module_name(mod_name):
     parent = '.'.join(name_parts[:-1])
     node = name_parts[-1]
     return parent, node
+
+
+def _get_branch_point_module(model, exit_point):
+    parent_name, node_name = _split_module_name(exit_point)
+    parent_module = _find_module(model, parent_name)
+    try:
+        branch_point = parent_module.__getattr__(node_name)
+    except AttributeError:
+        # This handles the case where the parent module was data-paralleled after model creation
+        if isinstance(parent_module, nn.DataParallel):
+            branch_point = parent_module.module.__getattr__(node_name)
+            assert isinstance(branch_point, BranchPoint)
+        else:
+            raise
+    return branch_point
-- 
GitLab