diff --git a/distiller/thinning.py b/distiller/thinning.py
index caa3192528b6a6518beafbc5cc349b2a83e734cb..dbd2c56e7f28ad29a6c836a8473a1ecbf9026399 100755
--- a/distiller/thinning.py
+++ b/distiller/thinning.py
@@ -35,6 +35,7 @@ from collections import namedtuple
 import torch
 from .policy import ScheduledTrainingPolicy
 import distiller
+from distiller import normalize_module_name, denormalize_module_name
 from apputils import SummaryGraph
 from models import ALL_MODEL_NAMES, create_model
 msglogger = logging.getLogger()
@@ -69,6 +70,7 @@ __all__ = ['ThinningRecipe', 'resnet_cifar_remove_layers',
            'find_nonzero_channels',
            'execute_thinning_recipes_list']
 
+
 def create_graph(dataset, arch):
     if dataset == 'imagenet':
         dummy_input = torch.randn((1, 3, 224, 224), requires_grad=False)
@@ -81,23 +83,6 @@ def create_graph(dataset, arch):
     return SummaryGraph(model, dummy_input)
 
 
-def normalize_layer_name(layer_name):
-    start = layer_name.find('module.')
-    normalized_layer_name = layer_name
-    if start != -1:
-        normalized_layer_name = layer_name[:start] + layer_name[start + len('module.'):]
-    return normalized_layer_name
-
-
-def denormalize_layer_name(model, normalized_name):
-    """Convert back from the normalized form of the layer name, to PyTorch's name
-    which contains "artifacts" if DataParallel is used.
-    """
-    ugly = [mod_name for mod_name, _ in model.named_modules() if normalize_layer_name(mod_name) == normalized_name]
-    assert len(ugly) == 1
-    return ugly[0]
-
-
 def param_name_2_layer_name(param_name):
     return param_name[:-len('weights')]
 
@@ -262,9 +247,9 @@ def create_thinning_recipe_channels(sgraph, model, zeros_mask_dict):
         append_param_directive(thinning_recipe, param_name, (1, indices))
 
         # Find all instances of Convolution layers that immediately precede this layer
-        predecessors = sgraph.predecessors_f(normalize_layer_name(layer_name), ['Conv'])
+        predecessors = sgraph.predecessors_f(normalize_module_name(layer_name), ['Conv'])
         # Convert the layer names to PyTorch's convoluted naming scheme (when DataParallel is used)
-        predecessors = [denormalize_layer_name(model, predecessor) for predecessor in predecessors]
+        predecessors = [denormalize_module_name(model, predecessor) for predecessor in predecessors]
         for predecessor in predecessors:
             # For each of the convolutional layers that precede, we have to reduce the number of output channels.
             append_module_directive(thinning_recipe, predecessor, key='out_channels', val=len(nonzero_channels))
@@ -273,11 +258,11 @@ def create_thinning_recipe_channels(sgraph, model, zeros_mask_dict):
             append_param_directive(thinning_recipe, predecessor+'.weight', (0, indices))
 
         # Now handle the BatchNormalization layer that follows the convolution
-        bn_layers = sgraph.predecessors_f(normalize_layer_name(layer_name), ['BatchNormalization'])
+        bn_layers = sgraph.predecessors_f(normalize_module_name(layer_name), ['BatchNormalization'])
         if len(bn_layers) > 0:
             assert len(bn_layers) == 1
             # Thinning of the BN layer that follows the convolution
-            bn_layer_name = denormalize_layer_name(model, bn_layers[0])
+            bn_layer_name = denormalize_module_name(model, bn_layers[0])
             bn_thinning(thinning_recipe, layers, bn_layer_name, len_thin_features=len(nonzero_channels), thin_features=indices)
 
     return thinning_recipe
@@ -331,9 +316,9 @@ def create_thinning_recipe_filters(sgraph, model, zeros_mask_dict):
             append_param_directive(thinning_recipe, layer_name+'.bias', (0, indices))
 
         # Find all instances of Convolution or FC (GEMM) layers that immediately follow this layer
-        successors = sgraph.successors_f(normalize_layer_name(layer_name), ['Conv', 'Gemm'])
+        successors = sgraph.successors_f(normalize_module_name(layer_name), ['Conv', 'Gemm'])
         # Convert the layer names to PyTorch's convoluted naming scheme (when DataParallel is used)
-        successors = [denormalize_layer_name(model, successor) for successor in successors]
+        successors = [denormalize_module_name(model, successor) for successor in successors]
         for successor in successors:
 
             if isinstance(layers[successor], torch.nn.modules.Conv2d):
@@ -360,11 +345,11 @@ def create_thinning_recipe_filters(sgraph, model, zeros_mask_dict):
                 append_param_directive(thinning_recipe, successor+'.weight', (1, indices, view_4D, view_2D))
 
         # Now handle the BatchNormalization layer that follows the convolution
-        bn_layers = sgraph.successors_f(normalize_layer_name(layer_name), ['BatchNormalization'])
+        bn_layers = sgraph.successors_f(normalize_module_name(layer_name), ['BatchNormalization'])
         if len(bn_layers) > 0:
             assert len(bn_layers) == 1
             # Thinning of the BN layer that follows the convolution
-            bn_layer_name = denormalize_layer_name(model, bn_layers[0])
+            bn_layer_name = denormalize_module_name(model, bn_layers[0])
             bn_thinning(thinning_recipe, layers, bn_layer_name, len_thin_features=len(nonzero_filters), thin_features=indices)
     return thinning_recipe
 
@@ -412,6 +397,7 @@ class FilterRemover(ScheduledTrainingPolicy):
         # The epoch has ended and we reset the 'done' flag, so that the FilterRemover instance can be reused
         self.done = False
 
+
 def execute_thinning_recipes_list(model, zeros_mask_dict, recipe_list):
     # Invoke this function when you want to use a list of thinning recipes to convert a programmed model
     # to a thinned model. For example, this is invoked when loading a model from a checkpoint.
@@ -419,6 +405,7 @@ def execute_thinning_recipes_list(model, zeros_mask_dict, recipe_list):
         msglogger.info("recipe %d:" % i)
         execute_thinning_recipe(model, zeros_mask_dict, recipe, loaded_from_file=True)
 
+
 def execute_thinning_recipe(model, zeros_mask_dict, recipe, loaded_from_file=False):
     """Apply a thinning recipe to a model.
 
diff --git a/distiller/utils.py b/distiller/utils.py
index 7b6f303c8dad68ea1011570b0ca386e08aa6ac25..c977b809973ee8c30939782221af06479b4bf132 100755
--- a/distiller/utils.py
+++ b/distiller/utils.py
@@ -24,14 +24,17 @@ import numpy as np
 import torch
 from torch.autograd import Variable
 
+
 def to_np(var):
     return var.data.cpu().numpy()
 
+
 def to_var(tensor, cuda=True):
     if cuda and torch.cuda.is_available():
         tensor = tensor.cuda()
     return Variable(tensor)
 
+
 def size2str(torch_size):
     if isinstance(torch_size, torch.Size):
         return size_to_str(torch_size)
@@ -41,15 +44,45 @@ def size2str(torch_size):
         return size_to_str(torch_size.data.size())
     raise TypeError
 
+
 def size_to_str(torch_size):
     """Convert a pytorch Size object to a string"""
     assert isinstance(torch_size, torch.Size)
     return '('+(', ').join(['%d' % v for v in torch_size])+')'
 
+
+def normalize_module_name(layer_name):
+    """Normalize a module's name.
+
+    PyTorch lets you parallelize the computation of a model by wrapping it with a
+    DataParallel module.  Unfortunately, this changes the fully-qualified name of a module,
+    even though the module's actual functionality does not change.
+    Often, when we search for modules by name, we are indifferent to the DataParallel
+    wrapper and want to use the same module name whether the model is parallel or not.
+    We call this module-name normalization, and it is implemented here.
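+
+    For example, "module.features.0" and "features.module.0" both normalize to "features.0".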
+    """
+    if layer_name.find("module.") >= 0:
+        return layer_name.replace("module.", "")
+    return layer_name.replace(".module", "")
+
+
+def denormalize_module_name(parallel_model, normalized_name):
+    """Convert back from the normalized form of the layer name, to PyTorch's name
+    which contains "artifacts" if DataParallel is used.
+    """
+    fully_qualified_name = [mod_name for mod_name, _ in parallel_model.named_modules() if
+                            normalize_module_name(mod_name) == normalized_name]
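+    # named_modules() yields parent modules before their children, so when several names
+    # normalize to the same string (e.g. "features" and "features.module"), the last
+    # entry is the most deeply nested match.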
+    if len(fully_qualified_name) > 0:
+        return fully_qualified_name[-1]
+    else:
+        return ""   # Did not find a module with the name <normalized_name>
+
+
 def volume(tensor):
     """return the volume of a pytorch tensor"""
     return np.prod(tensor.shape)
 
+
 def density(tensor):
     """Computes the density of a tensor.
 
@@ -84,6 +117,7 @@ def sparsity(tensor):
     """
     return 1.0 - density(tensor)
 
+
 def sparsity_3D(tensor):
     """Filter-wise sparsity for 4D tensors"""
     if tensor.dim() != 4:
@@ -93,10 +127,12 @@ def sparsity_3D(tensor):
     nonzero_filters = len(torch.nonzero(view_3d.abs().sum(dim=1)))
     return 1 - nonzero_filters/num_filters
 
+
 def density_3D(tensor):
     """Filter-wise density for 4D tensors"""
     return 1 - sparsity_3D(tensor)
 
+
 def sparsity_2D(tensor):
     """Create a list of sparsity levels for each channel in the tensor 't'
 
@@ -128,10 +164,12 @@ def sparsity_2D(tensor):
     nonzero_structs = len(torch.nonzero(view_2d.abs().sum(dim=1)))
     return 1 - nonzero_structs/num_structs
 
+
 def density_2D(tensor):
     """Kernel-wise sparsity for 4D tensors"""
     return 1 - sparsity_2D(tensor)
 
+
 def sparsity_ch(tensor):
     """Channel-wise sparsity for 4D tensors"""
     if tensor.dim() != 4:
@@ -150,10 +188,12 @@ def sparsity_ch(tensor):
     nonzero_channels = len(torch.nonzero(k_sums_mat.abs().sum(dim=1)))
     return 1 - nonzero_channels/num_kernels_per_filter
 
+
 def density_ch(tensor):
     """Channel-wise density for 4D tensors"""
     return 1 - sparsity_ch(tensor)
 
+
 def sparsity_cols(tensor):
     """Column-wise sparsity for 2D tensors"""
     if tensor.dim() != 2:
@@ -163,10 +203,12 @@ def sparsity_cols(tensor):
     nonzero_cols = len(torch.nonzero(tensor.abs().sum(dim=0)))
     return 1 - nonzero_cols/num_cols
 
+
 def density_cols(tensor):
     """Column-wise density for 2D tensors"""
     return 1 - sparsity_cols(tensor)
 
+
 def sparsity_rows(tensor):
     """Row-wise sparsity for 2D matrices"""
     if tensor.dim() != 2:
@@ -176,10 +218,12 @@ def sparsity_rows(tensor):
     nonzero_rows = len(torch.nonzero(tensor.abs().sum(dim=1)))
     return 1 - nonzero_rows/num_rows
 
+
 def density_rows(tensor):
     """Row-wise density for 2D tensors"""
     return 1 - sparsity_rows(tensor)
 
+
 def log_training_progress(stats_dict, params_dict, epoch, steps_completed, total_steps, log_freq, loggers):
     """Log information about the training progress, and the distribution of the weight tensors.
 
diff --git a/tests/test_summarygraph.py b/tests/test_summarygraph.py
index 31145dbb8dedf6d486bf65f681570b3d68995bfb..50ce5753da84840bf1a8a89859ae67f3d950eec7 100755
--- a/tests/test_summarygraph.py
+++ b/tests/test_summarygraph.py
@@ -26,6 +26,7 @@ import distiller
 import pytest
 from models import ALL_MODEL_NAMES, create_model
 from apputils import SummaryGraph, onnx_name_2_pytorch_name
+from distiller import normalize_module_name, denormalize_module_name
 
 # Logging configuration
 logging.basicConfig(level=logging.DEBUG)
@@ -124,13 +125,6 @@ def test_layer_search():
     assert preds == ['layer1.0.conv2', 'conv1']
 
 
-def normalize_layer_name(layer_name):
-    start = layer_name.find('module.')
-    if start != -1:
-        layer_name = layer_name[:start] + layer_name[start + len('module.'):]
-    return layer_name
-
-
 def test_vgg():
     g = create_graph('imagenet', 'vgg19')
     assert g is not None
@@ -139,10 +133,29 @@ def test_vgg():
     succs = g.successors_f('features.34', 'Conv')
 
 
-def test_normalize_layer_name():
-    assert "features.0", normalize_layer_name("features.module.0")
-    assert "features.0", normalize_layer_name("module.features.0")
-    assert "features.0", normalize_layer_name("features.0.module")
+def name_test(dataset, arch):
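+    """Check that normalize_module_name() and denormalize_module_name() map between the
+    module names of a model and of its parallelized (DataParallel-wrapped) counterpart."""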
+    model = create_model(False, dataset, arch, parallel=False)
+    modelp = create_model(False, dataset, arch, parallel=True)
+    assert model is not None and modelp is not None
+
+    mod_names   = [mod_name for mod_name, _ in model.named_modules()]
+    mod_names_p = [mod_name for mod_name, _ in modelp.named_modules()]
+    assert mod_names is not None and mod_names_p is not None
+    assert len(mod_names)+1 == len(mod_names_p)
+
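+    # The index offsets below skip the root module ('') in both lists, plus the extra
+    # entry that the DataParallel wrapper adds to the parallel model's module list.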
+    for i in range(len(mod_names)-1):
+        assert mod_names[i+1] == normalize_module_name(mod_names_p[i+2])
+        logging.debug("{} {} {}".format(mod_names_p[i+2], mod_names[i+1], normalize_module_name(mod_names_p[i+2])))
+        assert mod_names_p[i+2] == denormalize_module_name(modelp, mod_names[i+1])
+
+
+def test_normalize_module_name():
+    assert "features.0" == normalize_module_name("features.module.0")
+    assert "features.0" == normalize_module_name("module.features.0")
+    assert "features" == normalize_module_name("features.module")
+    name_test('imagenet', 'vgg19')
+    name_test('cifar10', 'resnet20_cifar')
+    name_test('imagenet', 'alexnet')
 
 
 def test_onnx_name_2_pytorch_name():