def _get_branch_point_module(model, exit_point):
    """Return the branch-point module registered under ``exit_point``.

    The parent of ``exit_point`` is located with ``_find_module`` and the last
    component of the name is then looked up on that parent.  If the lookup
    fails and the parent is an ``nn.DataParallel`` wrapper, the lookup is
    retried on the wrapped ``parent_module.module``.

    Args:
        model: the model that is searched for the exit point.
        exit_point: dotted, fully-qualified name of the exit-point module.

    Returns:
        The module found under ``exit_point``.

    Raises:
        AttributeError: if the name resolves neither on the parent module nor,
            for a data-parallel parent, on the module it wraps.
        TypeError: if the module found through the data-parallel fallback is
            not a ``BranchPoint``.
    """
    parent_name, node_name = _split_module_name(exit_point)
    parent_module = _find_module(model, parent_name)
    try:
        # __getattr__ is called directly (rather than through getattr()) to
        # keep the lookup semantics of the code this helper was factored from.
        return parent_module.__getattr__(node_name)
    except AttributeError:
        # Only a DataParallel parent gets a second chance; anything else is a
        # genuine lookup failure and is propagated unchanged.
        if not isinstance(parent_module, nn.DataParallel):
            raise

    # This handles the case where the parent module was data-paralleled after
    # model creation: the requested child hangs off the wrapped module.
    branch_point = parent_module.module.__getattr__(node_name)
    # Explicit check instead of `assert`, so the validation is not stripped
    # when Python runs with -O.
    if not isinstance(branch_point, BranchPoint):
        raise TypeError(
            "Exit point '{}' resolved to {} instead of a BranchPoint".format(
                exit_point, type(branch_point).__name__))
    return branch_point