diff --git a/distiller/utils.py b/distiller/utils.py
index d4f435c81a27ee3b6ffff006749efadc7aa8e4c9..996d0f5168a28a8cdd96cb0048231ddd428d6519 100755
--- a/distiller/utils.py
+++ b/distiller/utils.py
@@ -127,9 +127,13 @@ def normalize_module_name(layer_name):
     module and want to use the same module name whether the module is parallel or not.
     We call this module name normalization, and this is implemented here.
     """
-    if layer_name.find("module.") >= 0:
-        return layer_name.replace("module.", "")
-    return layer_name.replace(".module", "")
+    parts = layer_name.split('.')
+    if 'module' not in parts:
+        # No dot-separated component equals 'module': return the name unchanged.
+        return layer_name
+    # Drop only the first 'module' component; other components are kept as-is.
+    parts.remove('module')
+    return '.'.join(parts)
 
 
 def denormalize_module_name(parallel_model, normalized_name):
diff --git a/tests/test_summarygraph.py b/tests/test_summarygraph.py
index c92fe4fdeb15d859dd08d67d6245f427db373b4b..596fc54f0885507ba998bdb7746fdde21ec94987 100755
--- a/tests/test_summarygraph.py
+++ b/tests/test_summarygraph.py
@@ -126,26 +126,28 @@ def test_simplenet():
     assert len(preds) == 1
 
 
-def name_test(dataset, arch):
-    model = create_model(False, dataset, arch, parallel=False)
-    modelp = create_model(False, dataset, arch, parallel=True)
-    assert model is not None and modelp is not None
-
-    mod_names   = [mod_name for mod_name, _ in model.named_modules()]
-    mod_names_p = [mod_name for mod_name, _ in modelp.named_modules()]
-    assert mod_names is not None and mod_names_p is not None
-    assert len(mod_names)+1 == len(mod_names_p)
-
-    for i in range(len(mod_names)-1):
-        assert mod_names[i+1] == normalize_module_name(mod_names_p[i+2])
-        logging.debug("{} {} {}".format(mod_names_p[i+2], mod_names[i+1], normalize_module_name(mod_names_p[i+2])))
-        assert mod_names_p[i+2] == denormalize_module_name(modelp, mod_names[i+1])
-
-
 def test_normalize_module_name():
-    assert "features.0" == normalize_module_name("features.module.0")
-    assert "features.0" == normalize_module_name("module.features.0")
-    assert "features" == normalize_module_name("features.module")
+    def name_test(dataset, arch):
+        """Compare module names of a model created with and without parallel=True."""
+        plain = create_model(False, dataset, arch, parallel=False)
+        wrapped = create_model(False, dataset, arch, parallel=True)
+        assert plain is not None and wrapped is not None
+        plain_names = [name for name, _ in plain.named_modules()]
+        wrapped_names = [name for name, _ in wrapped.named_modules()]
+        assert plain_names is not None and wrapped_names is not None
+        assert len(plain_names) + 1 == len(wrapped_names)
+        # The plain list is compared from its 2nd entry, the parallel list from its 3rd.
+        for expected, parallel_name in zip(plain_names[1:], wrapped_names[2:]):
+            normalized = normalize_module_name(parallel_name)
+            assert expected == normalized
+            logging.debug("{} {} {}".format(parallel_name, expected, normalized))
+            assert parallel_name == denormalize_module_name(wrapped, expected)
+
+    assert normalize_module_name("features.module.0") == "features.0"
+    assert normalize_module_name("module.features.0") == "features.0"
+    assert normalize_module_name("features.module") == "features"
+    assert normalize_module_name('module') == ''
+    assert normalize_module_name('no.parallel.modules') == 'no.parallel.modules'
     name_test('imagenet', 'vgg19')
     name_test('cifar10', 'resnet20_cifar')
     name_test('imagenet', 'alexnet')