From db531db8e88a07031e45a54b96b07d1528b5da00 Mon Sep 17 00:00:00 2001 From: Guy Jacob <guy.jacob@intel.com> Date: Mon, 29 Jul 2019 14:10:23 +0300 Subject: [PATCH] DistillerModuleList conversion: Handle models w. duplicate modules (#338) * By duplicate modules we mean: self.relu1 = nn.Relu() self.relu2 = self.relu1 * The issue: The second module ('relu2') will not be returned by torch.nn.Module.named_modules/children() * When converting to DistillerModuleList, in order to maintain the original order of modules and in order to have a correct mapping of names before/after the conversion - we need to take the duplicates into account * Implemented an internal version of named_modules/children that includes duplicates * Added test case for this + refactored the module list conversion tests --- distiller/summary_graph.py | 27 ++- tests/test_summarygraph.py | 334 ++++++++++++++++++++----------------- 2 files changed, 207 insertions(+), 154 deletions(-) diff --git a/distiller/summary_graph.py b/distiller/summary_graph.py index 7fb5ab5..0c43a3a 100755 --- a/distiller/summary_graph.py +++ b/distiller/summary_graph.py @@ -706,6 +706,24 @@ class _DistillerModuleList(object): return main_str +def _named_children_with_duplicates(module): + """Version of torch.nn.Module.named_children() that includes duplicate modules""" + for name, module in module._modules.items(): + if module is not None: + yield name, module + + +def _named_modules_with_duplicates(module, prefix=''): + """Version of torch.nn.Module.named_modules() that includes duplicate modules""" + yield prefix, module + for name, submodule in module._modules.items(): + if submodule is None: + continue + submodule_prefix = prefix + ('.' if prefix else '') + name + for m in _named_modules_with_duplicates(submodule, submodule_prefix): + yield m + + def _to_distiller_modulelist(model): """Replaces all instances of torch.nn.ModuleList in a model with DistillerModuleList instances @@ -713,9 +731,11 @@ def _to_distiller_modulelist(model): model (torch.nn.Module): Model to convert """ def convert_container(container): - named_children = OrderedDict(container.named_children()) # To maintain a similar order of registered modules compared to the original container, we unregister # all modules and then register them again + # We take care to include duplicated modules, which are not returned by the original named_moduels/children + # implementation in torch.nn.Module + named_children = OrderedDict(_named_children_with_duplicates(container)) for n, _ in named_children.items(): delattr(container, n) for name, child in named_children.items(): @@ -732,9 +752,10 @@ def _to_distiller_modulelist(model): convert_container(m) return container - named_modules_orig = OrderedDict([(n, m) for n, m in model.named_modules() if not isinstance(m, nn.ModuleList)]) + named_modules_orig = OrderedDict([(n, m) for n, m in _named_modules_with_duplicates(model) + if not isinstance(m, nn.ModuleList)]) model = convert_container(model) - named_modules_dmlist = OrderedDict(model.named_modules()) + named_modules_dmlist = OrderedDict(_named_modules_with_duplicates(model)) converted_module_names_map = OrderedDict(zip(named_modules_dmlist.keys(), named_modules_orig.keys())) return model, converted_module_names_map diff --git a/tests/test_summarygraph.py b/tests/test_summarygraph.py index e220c90..293f69b 100755 --- a/tests/test_summarygraph.py +++ b/tests/test_summarygraph.py @@ -26,7 +26,8 @@ from distiller.apputils import * from distiller import normalize_module_name, denormalize_module_name, \ SummaryGraph, onnx_name_2_pytorch_name from distiller.model_summaries import connectivity_summary, connectivity_summary_verbose -from distiller.summary_graph import AdjacentsEntry, OpSimpleMetadata, _DistillerModuleList, _to_distiller_modulelist +from distiller.summary_graph import AdjacentsEntry, OpSimpleMetadata, _DistillerModuleList, \ + _to_distiller_modulelist, _named_modules_with_duplicates # Logging configuration logging.basicConfig(level=logging.DEBUG) @@ -399,156 +400,187 @@ def test_adjacency_map(parallel, dedicated_modules): check_adj_entry(adj_map[name], expected) -def test_distiller_module_list(): - ############################################################# - # Model for testing conversion of nested ModuleLists - ############################################################# - class ListsModule(nn.Module): - def __init__(self): - super(ListsModule, self).__init__() - self.conv1 = nn.Conv2d(3, 10, 5) - self.post_conv1 = nn.ModuleList([ - nn.BatchNorm2d(10), - nn.ModuleList([ - nn.ReLU(), - nn.ModuleList([nn.Tanh(), nn.MaxPool2d(2)])])]) - self.conv2 = nn.Conv2d(10, 20, 3) - self.post_conv2 = nn.ModuleList([nn.ReLU6(), nn.MaxPool2d(4)]) - - self.expected_mlist_to_dmlist = OrderedDict([ - ('post_conv1', ['post_conv1']), - ('post_conv1.1', ['post_conv1', '1']), - ('post_conv1.1.1', ['post_conv1', '1', '1']), - ('post_conv2', ['post_conv2']), - ]) - self.expected_list_contents_name_changes = OrderedDict([ - ('post_conv1.0', 'post_conv1_0'), - ('post_conv1.1.0', 'post_conv1_1_0'), - ('post_conv1.1.1.0', 'post_conv1_1_1_0'), - ('post_conv1.1.1.1', 'post_conv1_1_1_1'), - ('post_conv2.0', 'post_conv2_0'), - ('post_conv2.1', 'post_conv2_1'), - ]) - - def forward(self, x): - x = self.conv1(x) - x = self.post_conv1[0](x) - x = self.post_conv1[1][0](x) - for m in self.post_conv1[1][1]: - x = m(x) - x = self.conv2(x) - for m in self.post_conv2: - x = m(x) - return x - - ############################################################# - # Model for testing conversion in case of nested containers - ############################################################# - class Block(nn.Module): - def __init__(self, in_ch): - super(Block, self).__init__() - self.in_ch = in_ch - self.out_ch = in_ch * 2 - self.conv = nn.Conv2d(in_ch, self.out_ch, 3) - self.post_conv = nn.ModuleList([nn.BatchNorm2d(self.out_ch), nn.ReLU()]) - - def forward(self, x): - x = self.conv(x) - for m in self.post_conv: - x = m(x) - return x - - class BlocksModule(nn.Module): - def __init__(self): - super(BlocksModule, self).__init__() - self.block1 = Block(3) - self.blocks2_3 = nn.Sequential(Block(6), Block(12)) - self.blocks4_5 = nn.ModuleList([Block(24), Block(48)]) - self.block6 = Block(96) - - self.expected_mlist_to_dmlist = OrderedDict([ - ('block1.post_conv', ['block1', 'post_conv']), - ('blocks2_3.0.post_conv', ['blocks2_3', '0', 'post_conv']), - ('blocks2_3.1.post_conv', ['blocks2_3', '1', 'post_conv']), - ('blocks4_5', ['blocks4_5']), - ('blocks4_5.0.post_conv', ['blocks4_5', '0', 'post_conv']), - ('blocks4_5.1.post_conv', ['blocks4_5', '1', 'post_conv']), - ('block6.post_conv', ['block6', 'post_conv']), - ]) - self.expected_list_contents_name_changes = OrderedDict([ - ('block1.post_conv.0', 'block1.post_conv_0'), - ('block1.post_conv.1', 'block1.post_conv_1'), - ('blocks2_3.0.post_conv.0', 'blocks2_3.0.post_conv_0'), - ('blocks2_3.0.post_conv.1', 'blocks2_3.0.post_conv_1'), - ('blocks2_3.1.post_conv.0', 'blocks2_3.1.post_conv_0'), - ('blocks2_3.1.post_conv.1', 'blocks2_3.1.post_conv_1'), - ('blocks4_5.0', 'blocks4_5_0'), - ('blocks4_5.0.conv', 'blocks4_5_0.conv'), - ('blocks4_5.0.post_conv.0', 'blocks4_5_0.post_conv_0'), - ('blocks4_5.0.post_conv.1', 'blocks4_5_0.post_conv_1'), - ('blocks4_5.1', 'blocks4_5_1'), - ('blocks4_5.1.conv', 'blocks4_5_1.conv'), - ('blocks4_5.1.post_conv.0', 'blocks4_5_1.post_conv_0'), - ('blocks4_5.1.post_conv.1', 'blocks4_5_1.post_conv_1'), - ('block6.post_conv.0', 'block6.post_conv_0'), - ('block6.post_conv.1', 'block6.post_conv_1'), - ]) - - def forward(self, x): - x = self.block1(x) - x = self.blocks2_3(x) - for block in self.blocks4_5: - x = block(x) - x = self.block6(x) - return x - - def check(m): - def check_equal_tensors(actual, expected): - assert (actual == expected).all().item() == 1 - - m_dml, converted_module_names_map = _to_distiller_modulelist(deepcopy(m)) - - # Check all modules converted as expected - named_modules_dmlist = OrderedDict(m_dml.named_modules()) - for name_orig, module_orig in m.named_modules(): - if name_orig in m.expected_mlist_to_dmlist: - # Check ModuleLists were converted to an attribute with the expected name, which is not - # registered as a module in the converted model - assert name_orig not in named_modules_dmlist - attr_dml = m_dml - for attr_name in m.expected_mlist_to_dmlist[name_orig]: - try: - attr_dml = attr_dml[int(attr_name)] - except ValueError: - attr_dml = getattr(attr_dml, attr_name) - assert isinstance(attr_dml, _DistillerModuleList) - else: - # Check module name changed as expected, and that the module type didn't change - expected_name_dml = m.expected_list_contents_name_changes.get(name_orig, name_orig) - assert expected_name_dml in named_modules_dmlist - assert expected_name_dml in converted_module_names_map - assert converted_module_names_map[expected_name_dml] == name_orig - assert type(named_modules_dmlist[expected_name_dml]) == type(module_orig) - converted_module_names_map.pop(expected_name_dml) - named_modules_dmlist.pop(expected_name_dml) - - assert not converted_module_names_map, 'Unexpected contents in converted_module_names_map' - assert not named_modules_dmlist, 'Unexpected contents in converted model named_modules' - - # Now make sure all parameters and buffers didn't change - for p_orig, p_dml in zip(m.parameters(), m_dml.parameters()): - check_equal_tensors(p_dml, p_orig) - for b_orig, b_dml in zip(m.buffers(), m_dml.buffers()): - check_equal_tensors(b_dml, b_orig) - - # Check forward pass gives identical results - x = torch.randn(1, 3, 50, 50) - y_orig = m(x) - y_dml = m_dml(x) - check_equal_tensors(y_dml, y_orig) - - check(ListsModule()) - check(BlocksModule()) +############################################################# +# Conversion to DistillerModuleList tests +############################################################# + +# Model for testing conversion of nested ModuleLists +class ListsModule(nn.Module): + def __init__(self): + super(ListsModule, self).__init__() + self.conv1 = nn.Conv2d(3, 10, 5) + self.post_conv1 = nn.ModuleList([ + nn.BatchNorm2d(10), + nn.ModuleList([ + nn.ReLU(), + nn.ModuleList([nn.Tanh(), nn.MaxPool2d(2)])])]) + self.conv2 = nn.Conv2d(10, 20, 3) + self.post_conv2 = nn.ModuleList([nn.ReLU6(), nn.MaxPool2d(4)]) + + self.expected_mlist_to_dmlist = OrderedDict([ + ('post_conv1', ['post_conv1']), + ('post_conv1.1', ['post_conv1', '1']), + ('post_conv1.1.1', ['post_conv1', '1', '1']), + ('post_conv2', ['post_conv2']), + ]) + self.expected_list_contents_name_changes = OrderedDict([ + ('post_conv1.0', 'post_conv1_0'), + ('post_conv1.1.0', 'post_conv1_1_0'), + ('post_conv1.1.1.0', 'post_conv1_1_1_0'), + ('post_conv1.1.1.1', 'post_conv1_1_1_1'), + ('post_conv2.0', 'post_conv2_0'), + ('post_conv2.1', 'post_conv2_1'), + ]) + + def forward(self, x): + x = self.conv1(x) + x = self.post_conv1[0](x) + x = self.post_conv1[1][0](x) + for m in self.post_conv1[1][1]: + x = m(x) + x = self.conv2(x) + for m in self.post_conv2: + x = m(x) + return x + + +# Model for testing conversion in case of nested containers +class Block(nn.Module): + def __init__(self, in_ch): + super(Block, self).__init__() + self.in_ch = in_ch + self.out_ch = in_ch * 2 + self.conv = nn.Conv2d(in_ch, self.out_ch, 3) + self.post_conv = nn.ModuleList([nn.BatchNorm2d(self.out_ch), nn.ReLU()]) + + def forward(self, x): + x = self.conv(x) + for m in self.post_conv: + x = m(x) + return x + + +class BlocksModule(nn.Module): + def __init__(self): + super(BlocksModule, self).__init__() + self.block1 = Block(3) + self.blocks2_3 = nn.Sequential(Block(6), Block(12)) + self.blocks4_5 = nn.ModuleList([Block(24), Block(48)]) + self.block6 = Block(96) + + self.expected_mlist_to_dmlist = OrderedDict([ + ('block1.post_conv', ['block1', 'post_conv']), + ('blocks2_3.0.post_conv', ['blocks2_3', '0', 'post_conv']), + ('blocks2_3.1.post_conv', ['blocks2_3', '1', 'post_conv']), + ('blocks4_5', ['blocks4_5']), + ('blocks4_5.0.post_conv', ['blocks4_5', '0', 'post_conv']), + ('blocks4_5.1.post_conv', ['blocks4_5', '1', 'post_conv']), + ('block6.post_conv', ['block6', 'post_conv']), + ]) + self.expected_list_contents_name_changes = OrderedDict([ + ('block1.post_conv.0', 'block1.post_conv_0'), + ('block1.post_conv.1', 'block1.post_conv_1'), + ('blocks2_3.0.post_conv.0', 'blocks2_3.0.post_conv_0'), + ('blocks2_3.0.post_conv.1', 'blocks2_3.0.post_conv_1'), + ('blocks2_3.1.post_conv.0', 'blocks2_3.1.post_conv_0'), + ('blocks2_3.1.post_conv.1', 'blocks2_3.1.post_conv_1'), + ('blocks4_5.0', 'blocks4_5_0'), + ('blocks4_5.0.conv', 'blocks4_5_0.conv'), + ('blocks4_5.0.post_conv.0', 'blocks4_5_0.post_conv_0'), + ('blocks4_5.0.post_conv.1', 'blocks4_5_0.post_conv_1'), + ('blocks4_5.1', 'blocks4_5_1'), + ('blocks4_5.1.conv', 'blocks4_5_1.conv'), + ('blocks4_5.1.post_conv.0', 'blocks4_5_1.post_conv_0'), + ('blocks4_5.1.post_conv.1', 'blocks4_5_1.post_conv_1'), + ('block6.post_conv.0', 'block6.post_conv_0'), + ('block6.post_conv.1', 'block6.post_conv_1'), + ]) + + def forward(self, x): + x = self.block1(x) + x = self.blocks2_3(x) + for block in self.blocks4_5: + x = block(x) + x = self.block6(x) + return x + + +# Model with duplicate modules +class ModelWithDuplicates(nn.Module): + def __init__(self): + super(ModelWithDuplicates, self).__init__() + self.conv1 = nn.Conv2d(3, 10, 5) + self.post_conv1 = nn.ModuleList([nn.ReLU(), nn.Tanh()]) + self.conv2 = nn.Conv2d(10, 20, 3) + self.post_conv2 = self.post_conv1 + + self.expected_mlist_to_dmlist = OrderedDict([ + ('post_conv1', ['post_conv1']), + ('post_conv2', ['post_conv2']), + ]) + self.expected_list_contents_name_changes = OrderedDict([ + ('post_conv1.0', 'post_conv1_0'), + ('post_conv1.1', 'post_conv1_1'), + ('post_conv2.0', 'post_conv2_0'), + ('post_conv2.1', 'post_conv2_1'), + ]) + + def forward(self, x): + x = self.conv1(x) + for m in self.post_conv1: + x = m(x) + x = self.conv2(x) + for m in self.post_conv2: + x = m(x) + return x + + +@pytest.mark.parametrize("model", [ListsModule(), BlocksModule(), ModelWithDuplicates()], + ids=['ListsModule', 'BlocksModule', 'ModelWithDuplicates']) +def test_distiller_module_list_conversion(model): + def check_equal_tensors(actual, expected): + assert (actual == expected).all().item() == 1 + + model_dml, converted_module_names_map = _to_distiller_modulelist(deepcopy(model)) + + # Check all modules converted as expected + named_modules_dmlist = OrderedDict(_named_modules_with_duplicates(model_dml)) + for name_orig, module_orig in _named_modules_with_duplicates(model): + if name_orig in model.expected_mlist_to_dmlist: + # Check ModuleLists were converted to an attribute with the expected name, which is not + # registered as a module in the converted model + assert name_orig not in named_modules_dmlist + attr_dml = model_dml + for attr_name in model.expected_mlist_to_dmlist[name_orig]: + try: + attr_dml = attr_dml[int(attr_name)] + except ValueError: + attr_dml = getattr(attr_dml, attr_name) + assert isinstance(attr_dml, _DistillerModuleList) + else: + # Check module name changed as expected, and that the module type didn't change + expected_name_dml = model.expected_list_contents_name_changes.get(name_orig, name_orig) + assert expected_name_dml in named_modules_dmlist + assert expected_name_dml in converted_module_names_map + assert converted_module_names_map[expected_name_dml] == name_orig + assert type(named_modules_dmlist[expected_name_dml]) == type(module_orig) + converted_module_names_map.pop(expected_name_dml) + named_modules_dmlist.pop(expected_name_dml) + + assert not converted_module_names_map, 'Unexpected contents in converted_module_names_map' + assert not named_modules_dmlist, 'Unexpected contents in converted model named_modules' + + # Now make sure all parameters and buffers didn't change + for p_orig, p_dml in zip(model.parameters(), model_dml.parameters()): + check_equal_tensors(p_dml, p_orig) + for b_orig, b_dml in zip(model.buffers(), model_dml.buffers()): + check_equal_tensors(b_dml, b_orig) + + # Check forward pass gives identical results + x = torch.randn(1, 3, 50, 50) + y_orig = model(x) + y_dml = model_dml(x) + check_equal_tensors(y_dml, y_orig) if __name__ == '__main__': -- GitLab