From 8e14ef0bbf606ccff4eedd87929964bb23046a0b Mon Sep 17 00:00:00 2001 From: Guy Jacob <guy.jacob@intel.com> Date: Wed, 3 Jul 2019 11:18:07 +0300 Subject: [PATCH] Bugfix in normalize_module_name Previous code looked for the patterns 'module.' and '.module' separately and removed the first instance. Issues with this: * Too broad. If a user gives some module a name that has the prefix or suffix 'module', that pre/suffix would be removed. * Doesn't catch the corner case of the "root" module in a model Modified to split name by the module separator '.', and then remove the first instance of the name 'module' --- distiller/utils.py | 10 +++++++--- tests/test_summarygraph.py | 40 ++++++++++++++++++++------------------ 2 files changed, 28 insertions(+), 22 deletions(-) diff --git a/distiller/utils.py b/distiller/utils.py index d4f435c..996d0f5 100755 --- a/distiller/utils.py +++ b/distiller/utils.py @@ -127,9 +127,13 @@ def normalize_module_name(layer_name): module and want to use the same module name whether the module is parallel or not. We call this module name normalization, and this is implemented here. """ - if layer_name.find("module.") >= 0: - return layer_name.replace("module.", "") - return layer_name.replace(".module", "") + modules = layer_name.split('.') + try: + idx = modules.index('module') + except ValueError: + return layer_name + del modules[idx] + return '.'.join(modules) def denormalize_module_name(parallel_model, normalized_name): diff --git a/tests/test_summarygraph.py b/tests/test_summarygraph.py index c92fe4f..596fc54 100755 --- a/tests/test_summarygraph.py +++ b/tests/test_summarygraph.py @@ -126,26 +126,28 @@ def test_simplenet(): assert len(preds) == 1 -def name_test(dataset, arch): - model = create_model(False, dataset, arch, parallel=False) - modelp = create_model(False, dataset, arch, parallel=True) - assert model is not None and modelp is not None - - mod_names = [mod_name for mod_name, _ in model.named_modules()] - mod_names_p = [mod_name for mod_name, _ in modelp.named_modules()] - assert mod_names is not None and mod_names_p is not None - assert len(mod_names)+1 == len(mod_names_p) - - for i in range(len(mod_names)-1): - assert mod_names[i+1] == normalize_module_name(mod_names_p[i+2]) - logging.debug("{} {} {}".format(mod_names_p[i+2], mod_names[i+1], normalize_module_name(mod_names_p[i+2]))) - assert mod_names_p[i+2] == denormalize_module_name(modelp, mod_names[i+1]) - - def test_normalize_module_name(): - assert "features.0" == normalize_module_name("features.module.0") - assert "features.0" == normalize_module_name("module.features.0") - assert "features" == normalize_module_name("features.module") + def name_test(dataset, arch): + model = create_model(False, dataset, arch, parallel=False) + modelp = create_model(False, dataset, arch, parallel=True) + assert model is not None and modelp is not None + + mod_names = [mod_name for mod_name, _ in model.named_modules()] + mod_names_p = [mod_name for mod_name, _ in modelp.named_modules()] + assert mod_names is not None and mod_names_p is not None + assert len(mod_names) + 1 == len(mod_names_p) + + for i in range(len(mod_names) - 1): + assert mod_names[i + 1] == normalize_module_name(mod_names_p[i + 2]) + logging.debug( + "{} {} {}".format(mod_names_p[i + 2], mod_names[i + 1], normalize_module_name(mod_names_p[i + 2]))) + assert mod_names_p[i + 2] == denormalize_module_name(modelp, mod_names[i + 1]) + + assert normalize_module_name("features.module.0") == "features.0" + assert normalize_module_name("module.features.0") == "features.0" + assert normalize_module_name("features.module") == "features" + assert normalize_module_name('module') == '' + assert normalize_module_name('no.parallel.modules') == 'no.parallel.modules' name_test('imagenet', 'vgg19') name_test('cifar10', 'resnet20_cifar') name_test('imagenet', 'alexnet') -- GitLab