From 8e14ef0bbf606ccff4eedd87929964bb23046a0b Mon Sep 17 00:00:00 2001
From: Guy Jacob <guy.jacob@intel.com>
Date: Wed, 3 Jul 2019 11:18:07 +0300
Subject: [PATCH] Bugfix in normalize_module_name

Previous code looked for the patterns 'module.' and '.module' separately
and removed the first instance.
Issues with this:
* Too broad. If a user gives some module a name that has the prefix or
  suffix 'module', that pre/suffix would be removed.
* Doesn't catch the corner case of the "root" module in a model

Modified to split name by the module separator '.', and then remove the
first instance of the name 'module'
---
 distiller/utils.py         | 10 +++++++---
 tests/test_summarygraph.py | 40 ++++++++++++++++++++------------------
 2 files changed, 28 insertions(+), 22 deletions(-)

diff --git a/distiller/utils.py b/distiller/utils.py
index d4f435c..996d0f5 100755
--- a/distiller/utils.py
+++ b/distiller/utils.py
@@ -127,9 +127,13 @@ def normalize_module_name(layer_name):
     module and want to use the same module name whether the module is parallel or not.
     We call this module name normalization, and this is implemented here.
     """
-    if layer_name.find("module.") >= 0:
-        return layer_name.replace("module.", "")
-    return layer_name.replace(".module", "")
+    modules = layer_name.split('.')
+    try:
+        idx = modules.index('module')
+    except ValueError:
+        return layer_name
+    del modules[idx]
+    return '.'.join(modules)
 
 
 def denormalize_module_name(parallel_model, normalized_name):
diff --git a/tests/test_summarygraph.py b/tests/test_summarygraph.py
index c92fe4f..596fc54 100755
--- a/tests/test_summarygraph.py
+++ b/tests/test_summarygraph.py
@@ -126,26 +126,28 @@ def test_simplenet():
     assert len(preds) == 1
 
 
-def name_test(dataset, arch):
-    model = create_model(False, dataset, arch, parallel=False)
-    modelp = create_model(False, dataset, arch, parallel=True)
-    assert model is not None and modelp is not None
-
-    mod_names   = [mod_name for mod_name, _ in model.named_modules()]
-    mod_names_p = [mod_name for mod_name, _ in modelp.named_modules()]
-    assert mod_names is not None and mod_names_p is not None
-    assert len(mod_names)+1 == len(mod_names_p)
-
-    for i in range(len(mod_names)-1):
-        assert mod_names[i+1] == normalize_module_name(mod_names_p[i+2])
-        logging.debug("{} {} {}".format(mod_names_p[i+2], mod_names[i+1], normalize_module_name(mod_names_p[i+2])))
-        assert mod_names_p[i+2] == denormalize_module_name(modelp, mod_names[i+1])
-
-
 def test_normalize_module_name():
-    assert "features.0" == normalize_module_name("features.module.0")
-    assert "features.0" == normalize_module_name("module.features.0")
-    assert "features" == normalize_module_name("features.module")
+    def name_test(dataset, arch):
+        model = create_model(False, dataset, arch, parallel=False)
+        modelp = create_model(False, dataset, arch, parallel=True)
+        assert model is not None and modelp is not None
+
+        mod_names = [mod_name for mod_name, _ in model.named_modules()]
+        mod_names_p = [mod_name for mod_name, _ in modelp.named_modules()]
+        assert mod_names is not None and mod_names_p is not None
+        assert len(mod_names) + 1 == len(mod_names_p)
+
+        for i in range(len(mod_names) - 1):
+            assert mod_names[i + 1] == normalize_module_name(mod_names_p[i + 2])
+            logging.debug(
+                "{} {} {}".format(mod_names_p[i + 2], mod_names[i + 1], normalize_module_name(mod_names_p[i + 2])))
+            assert mod_names_p[i + 2] == denormalize_module_name(modelp, mod_names[i + 1])
+
+    assert normalize_module_name("features.module.0") == "features.0"
+    assert normalize_module_name("module.features.0") == "features.0"
+    assert normalize_module_name("features.module") == "features"
+    assert normalize_module_name('module') == ''
+    assert normalize_module_name('no.parallel.modules') == 'no.parallel.modules'
     name_test('imagenet', 'vgg19')
     name_test('cifar10', 'resnet20_cifar')
     name_test('imagenet', 'alexnet')
-- 
GitLab