From db531db8e88a07031e45a54b96b07d1528b5da00 Mon Sep 17 00:00:00 2001
From: Guy Jacob <guy.jacob@intel.com>
Date: Mon, 29 Jul 2019 14:10:23 +0300
Subject: [PATCH] DistillerModuleList conversion: Handle models w. duplicate
 modules (#338)

* By duplicate modules we mean:
  self.relu1 = nn.ReLU()
  self.relu2 = self.relu1
* The issue:
  The second module ('relu2') will not be returned by
  torch.nn.Module.named_modules/children()
* When converting to DistillerModuleList, in order to maintain the
  original order of modules and in order to have a correct mapping
  of names before/after the conversion - we need to take the duplicates
  into account
* Implemented an internal version of named_modules/children that includes
  duplicates
* Added test case for this + refactored the module list conversion tests
---
 distiller/summary_graph.py |  27 ++-
 tests/test_summarygraph.py | 334 ++++++++++++++++++++-----------------
 2 files changed, 207 insertions(+), 154 deletions(-)

diff --git a/distiller/summary_graph.py b/distiller/summary_graph.py
index 7fb5ab5..0c43a3a 100755
--- a/distiller/summary_graph.py
+++ b/distiller/summary_graph.py
@@ -706,6 +706,24 @@ class _DistillerModuleList(object):
         return main_str
 
 
+def _named_children_with_duplicates(module):
+    """Version of torch.nn.Module.named_children() that includes duplicate modules"""
+    for name, module in module._modules.items():  # NOTE(review): loop variable rebinds the `module` parameter
+        if module is not None:  # skip child slots whose value is None
+            yield name, module  # no identity-based de-duplication is applied here
+
+
+def _named_modules_with_duplicates(module, prefix=''):
+    """Version of torch.nn.Module.named_modules() that includes duplicate modules"""
+    yield prefix, module  # pre-order: the module itself is yielded before any of its descendants
+    for name, submodule in module._modules.items():
+        if submodule is None:  # skip child slots whose value is None
+            continue
+        submodule_prefix = prefix + ('.' if prefix else '') + name  # dotted path, no leading '.' at the root
+        for m in _named_modules_with_duplicates(submodule, submodule_prefix):  # recurse; no "already seen" set is kept
+            yield m  # m is a (name, module) tuple
+
+
 def _to_distiller_modulelist(model):
     """Replaces all instances of torch.nn.ModuleList in a model with DistillerModuleList instances
 
@@ -713,9 +731,11 @@ def _to_distiller_modulelist(model):
         model (torch.nn.Module): Model to convert
     """
     def convert_container(container):
-        named_children = OrderedDict(container.named_children())
         # To maintain a similar order of registered modules compared to the original container, we unregister
         # all modules and then register them again
+        # We take care to include duplicated modules, which are not returned by the original named_modules/children
+        # implementation in torch.nn.Module
+        named_children = OrderedDict(_named_children_with_duplicates(container))
         for n, _ in named_children.items():
             delattr(container, n)
         for name, child in named_children.items():
@@ -732,9 +752,10 @@ def _to_distiller_modulelist(model):
                     convert_container(m)
         return container
 
-    named_modules_orig = OrderedDict([(n, m) for n, m in model.named_modules() if not isinstance(m, nn.ModuleList)])
+    named_modules_orig = OrderedDict([(n, m) for n, m in _named_modules_with_duplicates(model)
+                                      if not isinstance(m, nn.ModuleList)])
     model = convert_container(model)
-    named_modules_dmlist = OrderedDict(model.named_modules())
+    named_modules_dmlist = OrderedDict(_named_modules_with_duplicates(model))
     converted_module_names_map = OrderedDict(zip(named_modules_dmlist.keys(), named_modules_orig.keys()))
 
     return model, converted_module_names_map
diff --git a/tests/test_summarygraph.py b/tests/test_summarygraph.py
index e220c90..293f69b 100755
--- a/tests/test_summarygraph.py
+++ b/tests/test_summarygraph.py
@@ -26,7 +26,8 @@ from distiller.apputils import *
 from distiller import normalize_module_name, denormalize_module_name, \
     SummaryGraph, onnx_name_2_pytorch_name
 from distiller.model_summaries import connectivity_summary, connectivity_summary_verbose
-from distiller.summary_graph import AdjacentsEntry, OpSimpleMetadata, _DistillerModuleList, _to_distiller_modulelist
+from distiller.summary_graph import AdjacentsEntry, OpSimpleMetadata, _DistillerModuleList, \
+    _to_distiller_modulelist, _named_modules_with_duplicates
 
 # Logging configuration
 logging.basicConfig(level=logging.DEBUG)
@@ -399,156 +400,187 @@ def test_adjacency_map(parallel, dedicated_modules):
         check_adj_entry(adj_map[name], expected)
 
 
-def test_distiller_module_list():
-    #############################################################
-    # Model for testing conversion of nested ModuleLists
-    #############################################################
-    class ListsModule(nn.Module):
-        def __init__(self):
-            super(ListsModule, self).__init__()
-            self.conv1 = nn.Conv2d(3, 10, 5)
-            self.post_conv1 = nn.ModuleList([
-                nn.BatchNorm2d(10),
-                nn.ModuleList([
-                    nn.ReLU(),
-                    nn.ModuleList([nn.Tanh(), nn.MaxPool2d(2)])])])
-            self.conv2 = nn.Conv2d(10, 20, 3)
-            self.post_conv2 = nn.ModuleList([nn.ReLU6(), nn.MaxPool2d(4)])
-
-            self.expected_mlist_to_dmlist = OrderedDict([
-                ('post_conv1', ['post_conv1']),
-                ('post_conv1.1', ['post_conv1', '1']),
-                ('post_conv1.1.1', ['post_conv1', '1', '1']),
-                ('post_conv2', ['post_conv2']),
-            ])
-            self.expected_list_contents_name_changes = OrderedDict([
-                ('post_conv1.0', 'post_conv1_0'),
-                ('post_conv1.1.0', 'post_conv1_1_0'),
-                ('post_conv1.1.1.0', 'post_conv1_1_1_0'),
-                ('post_conv1.1.1.1', 'post_conv1_1_1_1'),
-                ('post_conv2.0', 'post_conv2_0'),
-                ('post_conv2.1', 'post_conv2_1'),
-            ])
-
-        def forward(self, x):
-            x = self.conv1(x)
-            x = self.post_conv1[0](x)
-            x = self.post_conv1[1][0](x)
-            for m in self.post_conv1[1][1]:
-                x = m(x)
-            x = self.conv2(x)
-            for m in self.post_conv2:
-                x = m(x)
-            return x
-
-    #############################################################
-    # Model for testing conversion in case of nested containers
-    #############################################################
-    class Block(nn.Module):
-        def __init__(self, in_ch):
-            super(Block, self).__init__()
-            self.in_ch = in_ch
-            self.out_ch = in_ch * 2
-            self.conv = nn.Conv2d(in_ch, self.out_ch, 3)
-            self.post_conv = nn.ModuleList([nn.BatchNorm2d(self.out_ch), nn.ReLU()])
-
-        def forward(self, x):
-            x = self.conv(x)
-            for m in self.post_conv:
-                x = m(x)
-            return x
-
-    class BlocksModule(nn.Module):
-        def __init__(self):
-            super(BlocksModule, self).__init__()
-            self.block1 = Block(3)
-            self.blocks2_3 = nn.Sequential(Block(6), Block(12))
-            self.blocks4_5 = nn.ModuleList([Block(24), Block(48)])
-            self.block6 = Block(96)
-
-            self.expected_mlist_to_dmlist = OrderedDict([
-                ('block1.post_conv', ['block1', 'post_conv']),
-                ('blocks2_3.0.post_conv', ['blocks2_3', '0', 'post_conv']),
-                ('blocks2_3.1.post_conv', ['blocks2_3', '1', 'post_conv']),
-                ('blocks4_5', ['blocks4_5']),
-                ('blocks4_5.0.post_conv', ['blocks4_5', '0', 'post_conv']),
-                ('blocks4_5.1.post_conv', ['blocks4_5', '1', 'post_conv']),
-                ('block6.post_conv', ['block6', 'post_conv']),
-            ])
-            self.expected_list_contents_name_changes = OrderedDict([
-                ('block1.post_conv.0', 'block1.post_conv_0'),
-                ('block1.post_conv.1', 'block1.post_conv_1'),
-                ('blocks2_3.0.post_conv.0', 'blocks2_3.0.post_conv_0'),
-                ('blocks2_3.0.post_conv.1', 'blocks2_3.0.post_conv_1'),
-                ('blocks2_3.1.post_conv.0', 'blocks2_3.1.post_conv_0'),
-                ('blocks2_3.1.post_conv.1', 'blocks2_3.1.post_conv_1'),
-                ('blocks4_5.0', 'blocks4_5_0'),
-                ('blocks4_5.0.conv', 'blocks4_5_0.conv'),
-                ('blocks4_5.0.post_conv.0', 'blocks4_5_0.post_conv_0'),
-                ('blocks4_5.0.post_conv.1', 'blocks4_5_0.post_conv_1'),
-                ('blocks4_5.1', 'blocks4_5_1'),
-                ('blocks4_5.1.conv', 'blocks4_5_1.conv'),
-                ('blocks4_5.1.post_conv.0', 'blocks4_5_1.post_conv_0'),
-                ('blocks4_5.1.post_conv.1', 'blocks4_5_1.post_conv_1'),
-                ('block6.post_conv.0', 'block6.post_conv_0'),
-                ('block6.post_conv.1', 'block6.post_conv_1'),
-            ])
-
-        def forward(self, x):
-            x = self.block1(x)
-            x = self.blocks2_3(x)
-            for block in self.blocks4_5:
-                x = block(x)
-            x = self.block6(x)
-            return x
-
-    def check(m):
-        def check_equal_tensors(actual, expected):
-            assert (actual == expected).all().item() == 1
-
-        m_dml, converted_module_names_map = _to_distiller_modulelist(deepcopy(m))
-
-        # Check all modules converted as expected
-        named_modules_dmlist = OrderedDict(m_dml.named_modules())
-        for name_orig, module_orig in m.named_modules():
-            if name_orig in m.expected_mlist_to_dmlist:
-                # Check ModuleLists were converted to an attribute with the expected name, which is not
-                # registered as a module in the converted model
-                assert name_orig not in named_modules_dmlist
-                attr_dml = m_dml
-                for attr_name in m.expected_mlist_to_dmlist[name_orig]:
-                    try:
-                        attr_dml = attr_dml[int(attr_name)]
-                    except ValueError:
-                        attr_dml = getattr(attr_dml, attr_name)
-                assert isinstance(attr_dml, _DistillerModuleList)
-            else:
-                # Check module name changed as expected, and that the module type didn't change
-                expected_name_dml = m.expected_list_contents_name_changes.get(name_orig, name_orig)
-                assert expected_name_dml in named_modules_dmlist
-                assert expected_name_dml in converted_module_names_map
-                assert converted_module_names_map[expected_name_dml] == name_orig
-                assert type(named_modules_dmlist[expected_name_dml]) == type(module_orig)
-                converted_module_names_map.pop(expected_name_dml)
-                named_modules_dmlist.pop(expected_name_dml)
-
-        assert not converted_module_names_map, 'Unexpected contents in converted_module_names_map'
-        assert not named_modules_dmlist, 'Unexpected contents in converted model named_modules'
-
-        # Now make sure all parameters and buffers didn't change
-        for p_orig, p_dml in zip(m.parameters(), m_dml.parameters()):
-            check_equal_tensors(p_dml, p_orig)
-        for b_orig, b_dml in zip(m.buffers(), m_dml.buffers()):
-            check_equal_tensors(b_dml, b_orig)
-
-        # Check forward pass gives identical results
-        x = torch.randn(1, 3, 50, 50)
-        y_orig = m(x)
-        y_dml = m_dml(x)
-        check_equal_tensors(y_dml, y_orig)
-
-    check(ListsModule())
-    check(BlocksModule())
+#############################################################
+# Conversion to DistillerModuleList tests
+#############################################################
+
+# Model for testing conversion of nested ModuleLists (a ModuleList inside a ModuleList inside a ModuleList)
+class ListsModule(nn.Module):
+    def __init__(self):
+        super(ListsModule, self).__init__()
+        self.conv1 = nn.Conv2d(3, 10, 5)
+        self.post_conv1 = nn.ModuleList([  # outer list: [BatchNorm2d, inner list]
+            nn.BatchNorm2d(10),
+            nn.ModuleList([  # inner list: [ReLU, innermost list]
+                nn.ReLU(),
+                nn.ModuleList([nn.Tanh(), nn.MaxPool2d(2)])])])
+        self.conv2 = nn.Conv2d(10, 20, 3)
+        self.post_conv2 = nn.ModuleList([nn.ReLU6(), nn.MaxPool2d(4)])
+
+        self.expected_mlist_to_dmlist = OrderedDict([  # original ModuleList name -> attribute/index path in converted model
+            ('post_conv1', ['post_conv1']),
+            ('post_conv1.1', ['post_conv1', '1']),
+            ('post_conv1.1.1', ['post_conv1', '1', '1']),
+            ('post_conv2', ['post_conv2']),
+        ])
+        self.expected_list_contents_name_changes = OrderedDict([  # original module name -> expected name after conversion
+            ('post_conv1.0', 'post_conv1_0'),
+            ('post_conv1.1.0', 'post_conv1_1_0'),
+            ('post_conv1.1.1.0', 'post_conv1_1_1_0'),
+            ('post_conv1.1.1.1', 'post_conv1_1_1_1'),
+            ('post_conv2.0', 'post_conv2_0'),
+            ('post_conv2.1', 'post_conv2_1'),
+        ])
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.post_conv1[0](x)  # BatchNorm2d
+        x = self.post_conv1[1][0](x)  # ReLU
+        for m in self.post_conv1[1][1]:  # Tanh, then MaxPool2d
+            x = m(x)
+        x = self.conv2(x)
+        for m in self.post_conv2:  # ReLU6, then MaxPool2d
+            x = m(x)
+        return x
+
+
+# Building block for testing conversion in case of nested containers: a conv followed by a ModuleList
+class Block(nn.Module):
+    def __init__(self, in_ch):
+        super(Block, self).__init__()
+        self.in_ch = in_ch
+        self.out_ch = in_ch * 2  # output channels are double the input channels
+        self.conv = nn.Conv2d(in_ch, self.out_ch, 3)
+        self.post_conv = nn.ModuleList([nn.BatchNorm2d(self.out_ch), nn.ReLU()])
+
+    def forward(self, x):
+        x = self.conv(x)
+        for m in self.post_conv:  # BatchNorm2d, then ReLU
+            x = m(x)
+        return x
+
+
+class BlocksModule(nn.Module):
+    def __init__(self):
+        super(BlocksModule, self).__init__()
+        self.block1 = Block(3)  # each following Block's in_ch equals the previous Block's out_ch (2 * in_ch)
+        self.blocks2_3 = nn.Sequential(Block(6), Block(12))  # Blocks nested in a Sequential container
+        self.blocks4_5 = nn.ModuleList([Block(24), Block(48)])  # Blocks nested in a ModuleList
+        self.block6 = Block(96)
+
+        self.expected_mlist_to_dmlist = OrderedDict([  # original ModuleList name -> attribute/index path in converted model
+            ('block1.post_conv', ['block1', 'post_conv']),
+            ('blocks2_3.0.post_conv', ['blocks2_3', '0', 'post_conv']),
+            ('blocks2_3.1.post_conv', ['blocks2_3', '1', 'post_conv']),
+            ('blocks4_5', ['blocks4_5']),
+            ('blocks4_5.0.post_conv', ['blocks4_5', '0', 'post_conv']),
+            ('blocks4_5.1.post_conv', ['blocks4_5', '1', 'post_conv']),
+            ('block6.post_conv', ['block6', 'post_conv']),
+        ])
+        self.expected_list_contents_name_changes = OrderedDict([  # original module name -> expected name after conversion
+            ('block1.post_conv.0', 'block1.post_conv_0'),
+            ('block1.post_conv.1', 'block1.post_conv_1'),
+            ('blocks2_3.0.post_conv.0', 'blocks2_3.0.post_conv_0'),
+            ('blocks2_3.0.post_conv.1', 'blocks2_3.0.post_conv_1'),
+            ('blocks2_3.1.post_conv.0', 'blocks2_3.1.post_conv_0'),
+            ('blocks2_3.1.post_conv.1', 'blocks2_3.1.post_conv_1'),
+            ('blocks4_5.0', 'blocks4_5_0'),
+            ('blocks4_5.0.conv', 'blocks4_5_0.conv'),
+            ('blocks4_5.0.post_conv.0', 'blocks4_5_0.post_conv_0'),
+            ('blocks4_5.0.post_conv.1', 'blocks4_5_0.post_conv_1'),
+            ('blocks4_5.1', 'blocks4_5_1'),
+            ('blocks4_5.1.conv', 'blocks4_5_1.conv'),
+            ('blocks4_5.1.post_conv.0', 'blocks4_5_1.post_conv_0'),
+            ('blocks4_5.1.post_conv.1', 'blocks4_5_1.post_conv_1'),
+            ('block6.post_conv.0', 'block6.post_conv_0'),
+            ('block6.post_conv.1', 'block6.post_conv_1'),
+        ])
+
+    def forward(self, x):
+        x = self.block1(x)
+        x = self.blocks2_3(x)
+        for block in self.blocks4_5:  # the list's Blocks are applied one after the other
+            x = block(x)
+        x = self.block6(x)
+        return x
+
+
+# Model with duplicate modules: the same ModuleList object is registered under two attribute names
+class ModelWithDuplicates(nn.Module):
+    def __init__(self):
+        super(ModelWithDuplicates, self).__init__()
+        self.conv1 = nn.Conv2d(3, 10, 5)
+        self.post_conv1 = nn.ModuleList([nn.ReLU(), nn.Tanh()])
+        self.conv2 = nn.Conv2d(10, 20, 3)
+        self.post_conv2 = self.post_conv1  # the duplicate: same object as post_conv1, second name
+
+        self.expected_mlist_to_dmlist = OrderedDict([  # original ModuleList name -> attribute/index path in converted model
+            ('post_conv1', ['post_conv1']),
+            ('post_conv2', ['post_conv2']),
+        ])
+        self.expected_list_contents_name_changes = OrderedDict([  # original module name -> expected name after conversion
+            ('post_conv1.0', 'post_conv1_0'),
+            ('post_conv1.1', 'post_conv1_1'),
+            ('post_conv2.0', 'post_conv2_0'),
+            ('post_conv2.1', 'post_conv2_1'),
+        ])
+
+    def forward(self, x):
+        x = self.conv1(x)
+        for m in self.post_conv1:  # ReLU, then Tanh
+            x = m(x)
+        x = self.conv2(x)
+        for m in self.post_conv2:  # the same ReLU and Tanh objects again, reached via the second name
+            x = m(x)
+        return x
+
+
+@pytest.mark.parametrize("model", [ListsModule(), BlocksModule(), ModelWithDuplicates()],  # instances are built when the decorator is evaluated
+                         ids=['ListsModule', 'BlocksModule', 'ModelWithDuplicates'])
+def test_distiller_module_list_conversion(model):
+    def check_equal_tensors(actual, expected):
+        assert (actual == expected).all().item() == 1  # every element must compare equal
+
+    model_dml, converted_module_names_map = _to_distiller_modulelist(deepcopy(model))  # convert a copy; `model` is kept for comparison
+
+    # Check all modules converted as expected
+    named_modules_dmlist = OrderedDict(_named_modules_with_duplicates(model_dml))
+    for name_orig, module_orig in _named_modules_with_duplicates(model):  # duplicate-aware walk over the original model
+        if name_orig in model.expected_mlist_to_dmlist:
+            # Check ModuleLists were converted to an attribute with the expected name, which is not
+            # registered as a module in the converted model
+            assert name_orig not in named_modules_dmlist
+            attr_dml = model_dml
+            for attr_name in model.expected_mlist_to_dmlist[name_orig]:  # walk the expected path one element at a time
+                try:
+                    attr_dml = attr_dml[int(attr_name)]  # numeric path element -> index into the current object
+                except ValueError:  # int() failed: non-numeric path element -> attribute lookup
+                    attr_dml = getattr(attr_dml, attr_name)
+            assert isinstance(attr_dml, _DistillerModuleList)
+        else:
+            # Check module name changed as expected, and that the module type didn't change
+            expected_name_dml = model.expected_list_contents_name_changes.get(name_orig, name_orig)  # default: name unchanged
+            assert expected_name_dml in named_modules_dmlist
+            assert expected_name_dml in converted_module_names_map
+            assert converted_module_names_map[expected_name_dml] == name_orig  # map goes converted name -> original name
+            assert type(named_modules_dmlist[expected_name_dml]) == type(module_orig)
+            converted_module_names_map.pop(expected_name_dml)  # remove matched entries so leftovers are caught below
+            named_modules_dmlist.pop(expected_name_dml)
+
+    assert not converted_module_names_map, 'Unexpected contents in converted_module_names_map'
+    assert not named_modules_dmlist, 'Unexpected contents in converted model named_modules'
+
+    # Now make sure all parameters and buffers didn't change
+    for p_orig, p_dml in zip(model.parameters(), model_dml.parameters()):  # compared pairwise in iteration order
+        check_equal_tensors(p_dml, p_orig)
+    for b_orig, b_dml in zip(model.buffers(), model_dml.buffers()):  # compared pairwise in iteration order
+        check_equal_tensors(b_dml, b_orig)
+
+    # Check forward pass gives identical results
+    x = torch.randn(1, 3, 50, 50)  # one random input, fed to both models
+    y_orig = model(x)
+    y_dml = model_dml(x)
+    check_equal_tensors(y_dml, y_orig)
 
 
 if __name__ == '__main__':
-- 
GitLab