diff --git a/distiller/utils.py b/distiller/utils.py index d4f435c81a27ee3b6ffff006749efadc7aa8e4c9..996d0f5168a28a8cdd96cb0048231ddd428d6519 100755 --- a/distiller/utils.py +++ b/distiller/utils.py @@ -127,9 +127,13 @@ def normalize_module_name(layer_name): module and want to use the same module name whether the module is parallel or not. We call this module name normalization, and this is implemented here. """ - if layer_name.find("module.") >= 0: - return layer_name.replace("module.", "") - return layer_name.replace(".module", "") + modules = layer_name.split('.') + try: + idx = modules.index('module') + except ValueError: + return layer_name + del modules[idx] + return '.'.join(modules) def denormalize_module_name(parallel_model, normalized_name): diff --git a/tests/test_summarygraph.py b/tests/test_summarygraph.py index c92fe4fdeb15d859dd08d67d6245f427db373b4b..596fc54f0885507ba998bdb7746fdde21ec94987 100755 --- a/tests/test_summarygraph.py +++ b/tests/test_summarygraph.py @@ -126,26 +126,28 @@ def test_simplenet(): assert len(preds) == 1 -def name_test(dataset, arch): - model = create_model(False, dataset, arch, parallel=False) - modelp = create_model(False, dataset, arch, parallel=True) - assert model is not None and modelp is not None - - mod_names = [mod_name for mod_name, _ in model.named_modules()] - mod_names_p = [mod_name for mod_name, _ in modelp.named_modules()] - assert mod_names is not None and mod_names_p is not None - assert len(mod_names)+1 == len(mod_names_p) - - for i in range(len(mod_names)-1): - assert mod_names[i+1] == normalize_module_name(mod_names_p[i+2]) - logging.debug("{} {} {}".format(mod_names_p[i+2], mod_names[i+1], normalize_module_name(mod_names_p[i+2]))) - assert mod_names_p[i+2] == denormalize_module_name(modelp, mod_names[i+1]) - - def test_normalize_module_name(): - assert "features.0" == normalize_module_name("features.module.0") - assert "features.0" == normalize_module_name("module.features.0") - assert "features" == normalize_module_name("features.module") + def name_test(dataset, arch): + model = create_model(False, dataset, arch, parallel=False) + modelp = create_model(False, dataset, arch, parallel=True) + assert model is not None and modelp is not None + + mod_names = [mod_name for mod_name, _ in model.named_modules()] + mod_names_p = [mod_name for mod_name, _ in modelp.named_modules()] + assert mod_names is not None and mod_names_p is not None + assert len(mod_names) + 1 == len(mod_names_p) + + for i in range(len(mod_names) - 1): + assert mod_names[i + 1] == normalize_module_name(mod_names_p[i + 2]) + logging.debug( + "{} {} {}".format(mod_names_p[i + 2], mod_names[i + 1], normalize_module_name(mod_names_p[i + 2]))) + assert mod_names_p[i + 2] == denormalize_module_name(modelp, mod_names[i + 1]) + + assert normalize_module_name("features.module.0") == "features.0" + assert normalize_module_name("module.features.0") == "features.0" + assert normalize_module_name("features.module") == "features" + assert normalize_module_name('module') == '' + assert normalize_module_name('no.parallel.modules') == 'no.parallel.modules' name_test('imagenet', 'vgg19') name_test('cifar10', 'resnet20_cifar') name_test('imagenet', 'alexnet')