Skip to content
Snippets Groups Projects
Commit 4ad16ef0 authored by Guy Jacob's avatar Guy Jacob
Browse files

EarlyExitMgr: Handle case where model was DataParallel-ed after creation

parent 2291fdcc
No related branches found
No related tags found
No related merge requests found
......@@ -17,7 +17,7 @@
__all__ = ["EarlyExitMgr"]
import torch.nn as nn
from distiller.modules import BranchPoint
......@@ -46,9 +46,8 @@ class EarlyExitMgr(object):
"""
outputs = []
for exit_point in self.exit_points:
parent_name, node_name = _split_module_name(exit_point)
parent_module = _find_module(model, parent_name)
output = parent_module.__getattr__(node_name).output
branch_point = _get_branch_point_module(model, exit_point)
output = branch_point.output
assert output is not None
outputs.append(output)
return outputs
......@@ -61,9 +60,7 @@ class EarlyExitMgr(object):
"""
outputs = []
for exit_point in self.exit_points:
parent_name, node_name = _split_module_name(exit_point)
parent_module = _find_module(model, parent_name)
branch_point = parent_module.__getattr__(node_name)
branch_point = _get_branch_point_module(model, exit_point)
branch_point.output = None
return outputs
......@@ -81,3 +78,18 @@ def _split_module_name(mod_name):
parent = '.'.join(name_parts[:-1])
node = name_parts[-1]
return parent, node
def _get_branch_point_module(model, exit_point):
parent_name, node_name = _split_module_name(exit_point)
parent_module = _find_module(model, parent_name)
try:
branch_point = parent_module.__getattr__(node_name)
except AttributeError:
# This handles the case where the parent module was data-paralleled after model creation
if isinstance(parent_module, nn.DataParallel):
branch_point = parent_module.module.__getattr__(node_name)
assert isinstance(branch_point, BranchPoint)
else:
raise
return branch_point
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment