Skip to content
Snippets Groups Projects
Unverified Commit db531db8 authored by Guy Jacob's avatar Guy Jacob Committed by GitHub
Browse files

DistillerModuleList conversion: Handle models w. duplicate modules (#338)

* By duplicate modules we mean:
  self.relu1 = nn.ReLU()
  self.relu2 = self.relu1
* The issue:
  The second module ('relu2') will not be returned by
  torch.nn.Module.named_modules/children()
* When converting to DistillerModuleList, in order to maintain the
  original order of modules and in order to have a correct mapping
  of names before/after the conversion - we need to take the duplicates
  into account
* Implemented an internal version of named_modules/children that includes
  duplicates
* Added test case for this + refactored the module list conversion tests
parent 69b1452a
No related branches found
No related tags found
No related merge requests found
......@@ -706,6 +706,24 @@ class _DistillerModuleList(object):
return main_str
def _named_children_with_duplicates(module):
    """Version of torch.nn.Module.named_children() that includes duplicate modules

    A "duplicate" is a child module object that is registered under more than one name, e.g.:
        self.relu1 = nn.ReLU()
        self.relu2 = self.relu1
    The stock named_children() reports such an object only once; here every registered name is
    reported.

    Args:
        module: Object exposing a ``_modules`` mapping of child name -> child (or None)

    Yields:
        (name, child) tuples in registration order. Entries whose value is None are skipped.
    """
    # NOTE: the loop variable deliberately has a different name than the 'module' argument, so the
    # argument is not rebound while its children are being iterated
    for name, child in module._modules.items():
        if child is not None:
            yield name, child
def _named_modules_with_duplicates(module, prefix=''):
    """Version of torch.nn.Module.named_modules() that includes duplicate modules

    Walks the module tree depth-first through each module's ``_modules`` mapping, without keeping
    track of modules that were already visited. Therefore a module object registered under several
    names is reported (together with its own sub-modules) once per name.

    Args:
        module: Root of the (sub-)tree to traverse. Must expose a ``_modules`` mapping
        prefix (str): Name reported for ``module`` itself. Descendant names are built by appending
          '.<child name>' to it (without the leading '.' when the prefix is empty)

    Yields:
        (name, module) tuples, starting with (prefix, module). Children registered as None are skipped.
    """
    yield prefix, module
    for name, submodule in module._modules.items():
        if submodule is None:
            continue
        submodule_prefix = prefix + ('.' if prefix else '') + name
        # Delegate to the recursive generator directly instead of re-yielding its items one by one
        yield from _named_modules_with_duplicates(submodule, submodule_prefix)
def _to_distiller_modulelist(model):
"""Replaces all instances of torch.nn.ModuleList in a model with DistillerModuleList instances
......@@ -713,9 +731,11 @@ def _to_distiller_modulelist(model):
model (torch.nn.Module): Model to convert
"""
def convert_container(container):
named_children = OrderedDict(container.named_children())
# To maintain a similar order of registered modules compared to the original container, we unregister
# all modules and then register them again
# We take care to include duplicated modules, which are not returned by the original named_modules/children
# implementation in torch.nn.Module
named_children = OrderedDict(_named_children_with_duplicates(container))
for n, _ in named_children.items():
delattr(container, n)
for name, child in named_children.items():
......@@ -732,9 +752,10 @@ def _to_distiller_modulelist(model):
convert_container(m)
return container
named_modules_orig = OrderedDict([(n, m) for n, m in model.named_modules() if not isinstance(m, nn.ModuleList)])
named_modules_orig = OrderedDict([(n, m) for n, m in _named_modules_with_duplicates(model)
if not isinstance(m, nn.ModuleList)])
model = convert_container(model)
named_modules_dmlist = OrderedDict(model.named_modules())
named_modules_dmlist = OrderedDict(_named_modules_with_duplicates(model))
converted_module_names_map = OrderedDict(zip(named_modules_dmlist.keys(), named_modules_orig.keys()))
return model, converted_module_names_map
......@@ -26,7 +26,8 @@ from distiller.apputils import *
from distiller import normalize_module_name, denormalize_module_name, \
SummaryGraph, onnx_name_2_pytorch_name
from distiller.model_summaries import connectivity_summary, connectivity_summary_verbose
from distiller.summary_graph import AdjacentsEntry, OpSimpleMetadata, _DistillerModuleList, _to_distiller_modulelist
from distiller.summary_graph import AdjacentsEntry, OpSimpleMetadata, _DistillerModuleList, \
_to_distiller_modulelist, _named_modules_with_duplicates
# Logging configuration
logging.basicConfig(level=logging.DEBUG)
......@@ -399,156 +400,187 @@ def test_adjacency_map(parallel, dedicated_modules):
check_adj_entry(adj_map[name], expected)
def test_distiller_module_list():
    """Convert two toy models with _to_distiller_modulelist and verify the converted models"""
    #############################################################
    # Model for testing conversion of nested ModuleLists
    #############################################################
    class ListsModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(3, 10, 5)
            bn = nn.BatchNorm2d(10)
            relu = nn.ReLU()
            innermost = nn.ModuleList([nn.Tanh(), nn.MaxPool2d(2)])
            self.post_conv1 = nn.ModuleList([bn, nn.ModuleList([relu, innermost])])
            self.conv2 = nn.Conv2d(10, 20, 3)
            self.post_conv2 = nn.ModuleList([nn.ReLU6(), nn.MaxPool2d(4)])

            # Original name of each ModuleList -> attribute / index path to follow in the converted model
            self.expected_mlist_to_dmlist = OrderedDict(
                (name, name.split('.'))
                for name in ('post_conv1', 'post_conv1.1', 'post_conv1.1.1', 'post_conv2'))
            # Original name of each list member -> expected name in the converted model
            self.expected_list_contents_name_changes = OrderedDict(
                (name, name.replace('.', '_'))
                for name in ('post_conv1.0', 'post_conv1.1.0', 'post_conv1.1.1.0', 'post_conv1.1.1.1',
                             'post_conv2.0', 'post_conv2.1'))

        def forward(self, x):
            out = self.conv1(x)
            out = self.post_conv1[0](out)
            nested = self.post_conv1[1]
            out = nested[0](out)
            for layer in nested[1]:
                out = layer(out)
            out = self.conv2(out)
            for layer in self.post_conv2:
                out = layer(out)
            return out

    #############################################################
    # Model for testing conversion in case of nested containers
    #############################################################
    class Block(nn.Module):
        def __init__(self, in_ch):
            super().__init__()
            out_ch = in_ch * 2
            self.in_ch = in_ch
            self.out_ch = out_ch
            self.conv = nn.Conv2d(in_ch, out_ch, 3)
            self.post_conv = nn.ModuleList([nn.BatchNorm2d(out_ch), nn.ReLU()])

        def forward(self, x):
            out = self.conv(x)
            for layer in self.post_conv:
                out = layer(out)
            return out

    class BlocksModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.block1 = Block(3)
            self.blocks2_3 = nn.Sequential(Block(6), Block(12))
            self.blocks4_5 = nn.ModuleList([Block(24), Block(48)])
            self.block6 = Block(96)

            # Original name of each ModuleList -> attribute / index path to follow in the converted model
            self.expected_mlist_to_dmlist = OrderedDict(
                (name, name.split('.'))
                for name in ('block1.post_conv', 'blocks2_3.0.post_conv', 'blocks2_3.1.post_conv',
                             'blocks4_5', 'blocks4_5.0.post_conv', 'blocks4_5.1.post_conv',
                             'block6.post_conv'))
            # Original module name -> expected name in the converted model
            self.expected_list_contents_name_changes = OrderedDict([
                ('block1.post_conv.0', 'block1.post_conv_0'),
                ('block1.post_conv.1', 'block1.post_conv_1'),
                ('blocks2_3.0.post_conv.0', 'blocks2_3.0.post_conv_0'),
                ('blocks2_3.0.post_conv.1', 'blocks2_3.0.post_conv_1'),
                ('blocks2_3.1.post_conv.0', 'blocks2_3.1.post_conv_0'),
                ('blocks2_3.1.post_conv.1', 'blocks2_3.1.post_conv_1'),
                ('blocks4_5.0', 'blocks4_5_0'),
                ('blocks4_5.0.conv', 'blocks4_5_0.conv'),
                ('blocks4_5.0.post_conv.0', 'blocks4_5_0.post_conv_0'),
                ('blocks4_5.0.post_conv.1', 'blocks4_5_0.post_conv_1'),
                ('blocks4_5.1', 'blocks4_5_1'),
                ('blocks4_5.1.conv', 'blocks4_5_1.conv'),
                ('blocks4_5.1.post_conv.0', 'blocks4_5_1.post_conv_0'),
                ('blocks4_5.1.post_conv.1', 'blocks4_5_1.post_conv_1'),
                ('block6.post_conv.0', 'block6.post_conv_0'),
                ('block6.post_conv.1', 'block6.post_conv_1'),
            ])

        def forward(self, x):
            out = self.blocks2_3(self.block1(x))
            for blk in self.blocks4_5:
                out = blk(out)
            return self.block6(out)

    def check(m):
        def check_equal_tensors(actual, expected):
            assert (actual == expected).all().item() == 1

        def follow_path(root, path):
            # Path items that parse as integers are list indices, all others are attribute names
            target = root
            for step in path:
                try:
                    target = target[int(step)]
                except ValueError:
                    target = getattr(target, step)
            return target

        m_dml, names_map = _to_distiller_modulelist(deepcopy(m))

        # Check all modules converted as expected. Entries are popped as they are verified, so that
        # anything left over at the end is an unexpected extra
        remaining_dml = OrderedDict(m_dml.named_modules())
        for name_orig, module_orig in m.named_modules():
            if name_orig in m.expected_mlist_to_dmlist:
                # ModuleLists must turn into an attribute with the expected name, which is not
                # registered as a module in the converted model
                assert name_orig not in remaining_dml
                converted = follow_path(m_dml, m.expected_mlist_to_dmlist[name_orig])
                assert isinstance(converted, _DistillerModuleList)
                continue

            # Any other module: name changed as expected, and the module type didn't change
            name_dml = m.expected_list_contents_name_changes.get(name_orig, name_orig)
            assert name_dml in remaining_dml
            assert name_dml in names_map
            assert names_map[name_dml] == name_orig
            assert type(remaining_dml[name_dml]) == type(module_orig)
            names_map.pop(name_dml)
            remaining_dml.pop(name_dml)

        assert not names_map, 'Unexpected contents in converted_module_names_map'
        assert not remaining_dml, 'Unexpected contents in converted model named_modules'

        # Now make sure all parameters and buffers didn't change
        for p_orig, p_dml in zip(m.parameters(), m_dml.parameters()):
            check_equal_tensors(p_dml, p_orig)
        for b_orig, b_dml in zip(m.buffers(), m_dml.buffers()):
            check_equal_tensors(b_dml, b_orig)

        # Check forward pass gives identical results
        x = torch.randn(1, 3, 50, 50)
        y_orig = m(x)
        y_dml = m_dml(x)
        check_equal_tensors(y_dml, y_orig)

    check(ListsModule())
    check(BlocksModule())
#############################################################
# Conversion to DistillerModuleList tests
#############################################################
# Model for testing conversion of nested ModuleLists
class ListsModule(nn.Module):
    """Small conv net whose post-conv layers are held in (nested) nn.ModuleList containers

    In addition to the layers, each instance carries the expectations used by the conversion test:
      * expected_mlist_to_dmlist: original name of each ModuleList -> list of attribute names / indices
        to follow from the converted model in order to reach the matching converted list
      * expected_list_contents_name_changes: original name of each list member -> the name it is
        expected to have in the converted model
    """
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 10, 5)
        bn = nn.BatchNorm2d(10)
        relu = nn.ReLU()
        innermost = nn.ModuleList([nn.Tanh(), nn.MaxPool2d(2)])
        self.post_conv1 = nn.ModuleList([bn, nn.ModuleList([relu, innermost])])
        self.conv2 = nn.Conv2d(10, 20, 3)
        self.post_conv2 = nn.ModuleList([nn.ReLU6(), nn.MaxPool2d(4)])

        # For this model the expected path is simply the dotted name split into its components
        self.expected_mlist_to_dmlist = OrderedDict(
            (name, name.split('.'))
            for name in ('post_conv1', 'post_conv1.1', 'post_conv1.1.1', 'post_conv2'))
        # ... and the expected new name is the dotted name with every '.' replaced by '_'
        self.expected_list_contents_name_changes = OrderedDict(
            (name, name.replace('.', '_'))
            for name in ('post_conv1.0', 'post_conv1.1.0', 'post_conv1.1.1.0', 'post_conv1.1.1.1',
                         'post_conv2.0', 'post_conv2.1'))

    def forward(self, x):
        out = self.conv1(x)
        out = self.post_conv1[0](out)
        nested = self.post_conv1[1]
        out = nested[0](out)
        for layer in nested[1]:
            out = layer(out)
        out = self.conv2(out)
        for layer in self.post_conv2:
            out = layer(out)
        return out
# Model for testing conversion in case of nested containers
class Block(nn.Module):
    """A 3x3 convolution that doubles the channel count, followed by a ModuleList of BatchNorm + ReLU

    Args:
        in_ch (int): Number of input channels. The number of output channels is in_ch * 2
    """
    def __init__(self, in_ch):
        super().__init__()
        out_ch = in_ch * 2
        self.in_ch = in_ch
        self.out_ch = out_ch
        self.conv = nn.Conv2d(in_ch, out_ch, 3)
        self.post_conv = nn.ModuleList([nn.BatchNorm2d(out_ch), nn.ReLU()])

    def forward(self, x):
        out = self.conv(x)
        # The list members are applied one after the other, in list order
        for layer in self.post_conv:
            out = layer(out)
        return out
class BlocksModule(nn.Module):
    """Chain of six Blocks held directly, inside an nn.Sequential and inside an nn.ModuleList

    In addition to the layers, each instance carries the expectations used by the conversion test:
      * expected_mlist_to_dmlist: original name of each ModuleList -> list of attribute names / indices
        to follow from the converted model in order to reach the matching converted list
      * expected_list_contents_name_changes: original module name -> the name it is expected to have
        in the converted model
    """
    def __init__(self):
        super().__init__()
        self.block1 = Block(3)
        self.blocks2_3 = nn.Sequential(Block(6), Block(12))
        self.blocks4_5 = nn.ModuleList([Block(24), Block(48)])
        self.block6 = Block(96)

        # For this model the expected path is simply the dotted name split into its components
        self.expected_mlist_to_dmlist = OrderedDict(
            (name, name.split('.'))
            for name in ('block1.post_conv', 'blocks2_3.0.post_conv', 'blocks2_3.1.post_conv',
                         'blocks4_5', 'blocks4_5.0.post_conv', 'blocks4_5.1.post_conv',
                         'block6.post_conv'))
        # Only the separator in front of a list index is expected to change from '.' to '_'. This
        # also affects everything nested under the members of 'blocks4_5'
        self.expected_list_contents_name_changes = OrderedDict([
            ('block1.post_conv.0', 'block1.post_conv_0'),
            ('block1.post_conv.1', 'block1.post_conv_1'),
            ('blocks2_3.0.post_conv.0', 'blocks2_3.0.post_conv_0'),
            ('blocks2_3.0.post_conv.1', 'blocks2_3.0.post_conv_1'),
            ('blocks2_3.1.post_conv.0', 'blocks2_3.1.post_conv_0'),
            ('blocks2_3.1.post_conv.1', 'blocks2_3.1.post_conv_1'),
            ('blocks4_5.0', 'blocks4_5_0'),
            ('blocks4_5.0.conv', 'blocks4_5_0.conv'),
            ('blocks4_5.0.post_conv.0', 'blocks4_5_0.post_conv_0'),
            ('blocks4_5.0.post_conv.1', 'blocks4_5_0.post_conv_1'),
            ('blocks4_5.1', 'blocks4_5_1'),
            ('blocks4_5.1.conv', 'blocks4_5_1.conv'),
            ('blocks4_5.1.post_conv.0', 'blocks4_5_1.post_conv_0'),
            ('blocks4_5.1.post_conv.1', 'blocks4_5_1.post_conv_1'),
            ('block6.post_conv.0', 'block6.post_conv_0'),
            ('block6.post_conv.1', 'block6.post_conv_1'),
        ])

    def forward(self, x):
        out = self.blocks2_3(self.block1(x))
        for blk in self.blocks4_5:
            out = blk(out)
        return self.block6(out)
# Model with duplicate modules
class ModelWithDuplicates(nn.Module):
    """Model in which the very same nn.ModuleList object is registered under two attribute names

    'post_conv2' is the same object as 'post_conv1', so it is a "duplicate" module.

    In addition to the layers, each instance carries the expectations used by the conversion test:
      * expected_mlist_to_dmlist: original name of each ModuleList -> list of attribute names to
        follow from the converted model in order to reach the matching converted list
      * expected_list_contents_name_changes: original name of each list member -> the name it is
        expected to have in the converted model
    """
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 10, 5)
        activations = nn.ModuleList([nn.ReLU(), nn.Tanh()])
        self.post_conv1 = activations
        self.conv2 = nn.Conv2d(10, 20, 3)
        # Same object as post_conv1, registered under a second name
        self.post_conv2 = self.post_conv1

        self.expected_mlist_to_dmlist = OrderedDict(
            (name, name.split('.')) for name in ('post_conv1', 'post_conv2'))
        self.expected_list_contents_name_changes = OrderedDict(
            (name, name.replace('.', '_'))
            for name in ('post_conv1.0', 'post_conv1.1', 'post_conv2.0', 'post_conv2.1'))

    def forward(self, x):
        out = self.conv1(x)
        for layer in self.post_conv1:
            out = layer(out)
        out = self.conv2(out)
        for layer in self.post_conv2:
            out = layer(out)
        return out
@pytest.mark.parametrize("model", [ListsModule(), BlocksModule(), ModelWithDuplicates()],
                         ids=['ListsModule', 'BlocksModule', 'ModelWithDuplicates'])
def test_distiller_module_list_conversion(model):
    """Convert a deep copy of 'model' with _to_distiller_modulelist and verify the result

    Verified, in this order:
      * Each original ModuleList is reachable in the converted model as a _DistillerModuleList, and
        is not listed among the converted model's modules
      * Every other module appears under its expected name with an unchanged type, and the returned
        names map points from that name back to the original one
      * Neither the names map nor the converted model contain anything else
      * Parameters, buffers and the forward pass output are equal between the two models

    Args:
        model: Model exposing the 'expected_mlist_to_dmlist' and
          'expected_list_contents_name_changes' mappings
    """
    def check_equal_tensors(actual, expected):
        assert (actual == expected).all().item() == 1

    def follow_path(root, path):
        # Path items that parse as integers are list indices, all others are attribute names
        target = root
        for step in path:
            try:
                target = target[int(step)]
            except ValueError:
                target = getattr(target, step)
        return target

    model_dml, names_map = _to_distiller_modulelist(deepcopy(model))

    # Check all modules converted as expected. Entries are popped as they are verified, so that
    # anything left over at the end is an unexpected extra
    remaining_dml = OrderedDict(_named_modules_with_duplicates(model_dml))
    for name_orig, module_orig in _named_modules_with_duplicates(model):
        if name_orig in model.expected_mlist_to_dmlist:
            # ModuleLists must turn into an attribute with the expected name, which is not
            # registered as a module in the converted model
            assert name_orig not in remaining_dml
            converted = follow_path(model_dml, model.expected_mlist_to_dmlist[name_orig])
            assert isinstance(converted, _DistillerModuleList)
            continue

        # Any other module: name changed as expected, and the module type didn't change
        name_dml = model.expected_list_contents_name_changes.get(name_orig, name_orig)
        assert name_dml in remaining_dml
        assert name_dml in names_map
        assert names_map[name_dml] == name_orig
        assert type(remaining_dml[name_dml]) == type(module_orig)
        names_map.pop(name_dml)
        remaining_dml.pop(name_dml)

    assert not names_map, 'Unexpected contents in converted_module_names_map'
    assert not remaining_dml, 'Unexpected contents in converted model named_modules'

    # Now make sure all parameters and buffers didn't change
    for p_orig, p_dml in zip(model.parameters(), model_dml.parameters()):
        check_equal_tensors(p_dml, p_orig)
    for b_orig, b_dml in zip(model.buffers(), model_dml.buffers()):
        check_equal_tensors(b_dml, b_orig)

    # Check forward pass gives identical results
    x = torch.randn(1, 3, 50, 50)
    y_orig = model(x)
    y_dml = model_dml(x)
    check_equal_tensors(y_dml, y_orig)
if __name__ == '__main__':
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment