From 4ad16ef00f9ea90c0d7834667bf86b12e795c12e Mon Sep 17 00:00:00 2001 From: Guy Jacob <guy.jacob@intel.com> Date: Tue, 31 Mar 2020 15:52:18 +0300 Subject: [PATCH] EarlyExitMgr: Handle case where model was DataParallel-ed after creation --- distiller/early_exit.py | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/distiller/early_exit.py b/distiller/early_exit.py index 721cf7f..07d7bf7 100644 --- a/distiller/early_exit.py +++ b/distiller/early_exit.py @@ -17,7 +17,7 @@ __all__ = ["EarlyExitMgr"] - +import torch.nn as nn from distiller.modules import BranchPoint @@ -46,9 +46,8 @@ class EarlyExitMgr(object): """ outputs = [] for exit_point in self.exit_points: - parent_name, node_name = _split_module_name(exit_point) - parent_module = _find_module(model, parent_name) - output = parent_module.__getattr__(node_name).output + branch_point = _get_branch_point_module(model, exit_point) + output = branch_point.output assert output is not None outputs.append(output) return outputs @@ -61,9 +60,7 @@ class EarlyExitMgr(object): """ outputs = [] for exit_point in self.exit_points: - parent_name, node_name = _split_module_name(exit_point) - parent_module = _find_module(model, parent_name) - branch_point = parent_module.__getattr__(node_name) + branch_point = _get_branch_point_module(model, exit_point) branch_point.output = None return outputs @@ -81,3 +78,18 @@ def _split_module_name(mod_name): parent = '.'.join(name_parts[:-1]) node = name_parts[-1] return parent, node + + +def _get_branch_point_module(model, exit_point): + parent_name, node_name = _split_module_name(exit_point) + parent_module = _find_module(model, parent_name) + try: + branch_point = parent_module.__getattr__(node_name) + except AttributeError: + # This handles the case where the parent module was data-paralleled after model creation + if isinstance(parent_module, nn.DataParallel): + branch_point = parent_module.module.__getattr__(node_name) + assert isinstance(branch_point, BranchPoint) + else: + raise + return branch_point -- GitLab