diff --git a/distiller/summary_graph.py b/distiller/summary_graph.py
index 2d2dab63b81f4bd1fe9b05e26bee75bf7d58d019..7c0b09d0e0d4aaa509f831e9eb17abc1316f567d 100755
--- a/distiller/summary_graph.py
+++ b/distiller/summary_graph.py
@@ -93,12 +93,13 @@ class SummaryGraph(object):
     Edge = collections.namedtuple('Edge', 'src dst')
 
     def __init__(self, model, dummy_input):
-        model = distiller.make_non_parallel_copy(model)
-        with torch.onnx.set_training(model, False):
+        self._src_model = model
+        model_clone = distiller.make_non_parallel_copy(model)
+        with torch.onnx.set_training(model_clone, False):
             
-            device = next(model.parameters()).device
+            device = next(model_clone.parameters()).device
             dummy_input = distiller.convert_tensors_recursively_to(dummy_input, device=device)
-            trace, _ = jit.get_trace_graph(model, dummy_input)
+            trace, _ = jit.get_trace_graph(model_clone, dummy_input)
 
             # Let ONNX do the heavy lifting: fusing the convolution nodes; fusing the nodes
             # composing a GEMM operation; etc.
@@ -152,7 +153,7 @@ class SummaryGraph(object):
         self.add_macs_attr()
         self.add_footprint_attr()
         self.add_arithmetic_intensity_attr()
-        del model
+        del model_clone
 
     def __create_op(self, onnx_node):
         op = {}
@@ -266,15 +267,13 @@ class SummaryGraph(object):
         return [op for op in self.ops.values() if attr in op['attrs'] and f(op)]
 
     def find_op(self, lost_op_name):
-        assert isinstance(lost_op_name, str)
-        return self.ops.get(lost_op_name, None)
+        return self.ops.get(distiller.normalize_module_name(lost_op_name), None)
 
     def find_param(self, data_name):
         return self.params.get(data_name, None)
 
     def predecessors(self, op, depth, done_list=None):
         """Returns a list of <op>'s predecessors"""
-
         if done_list is None:
             done_list = []
 
@@ -288,16 +287,18 @@ class SummaryGraph(object):
             done_list += preds
 
         if depth == 1:
-            return preds
+            ret = preds
         else:
             ret = []
             for predecessor in preds:
                 ret += self.predecessors(predecessor, depth-1, done_list)
-            return ret
+
+        return [distiller.denormalize_module_name(self._src_model, x) for x in ret]
 
     def predecessors_f(self, node_name, predecessors_types, done_list=None, logging=None):
         """Returns a list of <op>'s predecessors, if they match the <predecessors_types> criteria.
         """
+        node_name = distiller.normalize_module_name(node_name)
         node = self.find_op(node_name)
         node_is_an_op = True
         if node is None:
@@ -319,7 +320,7 @@ class SummaryGraph(object):
             # We check if we found the type of node we're looking for,
             # and that this is not the first node in our search.
             if node['type'] in predecessors_types and len(done_list) > 1:
-                return [node_name]
+                return [distiller.denormalize_module_name(self._src_model, node_name)]
 
             # This is an operation node
             preds = [edge.src for edge in self.edges if (edge.dst == node_name and
@@ -331,11 +332,11 @@ class SummaryGraph(object):
         ret = []
         for predecessor in preds:
             ret += self.predecessors_f(predecessor, predecessors_types, done_list, logging)
-        return ret
+
+        return [distiller.denormalize_module_name(self._src_model, node) for node in ret]
 
     def successors(self, node, depth, done_list=None):
         """Returns a list of <op>'s successors"""
-
         if done_list is None:
             done_list = []
 
@@ -351,12 +352,13 @@ class SummaryGraph(object):
             done_list += succs
 
         if depth == 1:
-            return succs
+            ret = succs
         else:
             ret = []
             for successor in succs:
                 ret += self.successors(successor, depth-1, done_list)
-            return ret
+
+        return [distiller.denormalize_module_name(self._src_model, x) for x in ret]
 
     def successors_f(self, node_name, successors_types, done_list=None, logging=None):
         """Returns a list of <op>'s successors, if they match the <successors_types> criteria.
@@ -367,7 +369,7 @@ class SummaryGraph(object):
 
         <node_name> and the returned list of successors are strings, because
         """
-
+        node_name = distiller.normalize_module_name(node_name)
         node = self.find_op(node_name)
         node_is_an_op = True
         if node is None:
@@ -389,7 +391,7 @@ class SummaryGraph(object):
             # We check if we found the type of node we're looking for,
             # and that this is not the first node in our search.
             if node['type'] in successors_types and len(done_list) > 1:
-                return [node_name]
+                return [distiller.denormalize_module_name(self._src_model, node_name)]
 
             # This is an operation node
             succs = [edge.dst for edge in self.edges if (edge.src == node_name and
@@ -401,4 +403,16 @@ class SummaryGraph(object):
         ret = []
         for successor in succs:
             ret += self.successors_f(successor, successors_types, done_list, logging)
-        return ret
+
+        return [distiller.denormalize_module_name(self._src_model, node) for node in ret]
+
+    def named_params_layers(self):
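+        """Yield (layer_name, param_name, param) for each of the source model's parameters."""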
+        for param_name, param in self._src_model.named_parameters():
+            # Strip the parameter suffix (e.g. '.weight') from param_name, then
+            # normalize what remains to obtain the layer name
+            normalized_layer_name = distiller.normalize_module_name(
+                '.'.join(param_name.split('.')[:-1]))
+            sgraph_layer_name = distiller.denormalize_module_name(
+                self._src_model, normalized_layer_name)
+            yield sgraph_layer_name, param_name, param
diff --git a/distiller/thinning.py b/distiller/thinning.py
index 78cd2879469b1c62f40094f2702f4e4e2e6ecedb..0ca0e2499ce749dc8e8e65a37689f0c5a4b3941a 100755
--- a/distiller/thinning.py
+++ b/distiller/thinning.py
@@ -31,8 +31,6 @@ from collections import namedtuple
 import torch
 from .policy import ScheduledTrainingPolicy
 import distiller
-from distiller import normalize_module_name, denormalize_module_name
-from distiller.models import create_model
 from .summary_graph import SummaryGraph
 msglogger = logging.getLogger(__name__)
 
@@ -64,34 +62,23 @@ __all__ = ['ThinningRecipe', 'resnet_cifar_remove_layers',
            'execute_thinning_recipes_list', 'get_normalized_recipe']
 
 
-def create_graph(dataset, arch):
+def create_graph(dataset, model):
+    dummy_input = None
     if dataset == 'imagenet':
         dummy_input = torch.randn((1, 3, 224, 224), requires_grad=False)
     elif dataset == 'cifar10':
         dummy_input = torch.randn((1, 3, 32, 32), requires_grad=False)
     assert dummy_input is not None, "Unsupported dataset ({}) - aborting draw operation".format(dataset)
 
-    model = create_model(False, dataset, arch, parallel=False)
-    assert model is not None
     dummy_input = dummy_input.to(distiller.model_device(model))
     return SummaryGraph(model, dummy_input)
 
 
 def get_normalized_recipe(recipe):
-    new_recipe = ThinningRecipe(modules={normalize_module_name(k): v for k, v in recipe.modules.items()},
-                                parameters={normalize_module_name(k): v for k, v in recipe.parameters.items()})
-    return new_recipe
-
-
-def param_name_2_layer_name(param_name):
-    """Convert a weights tensor's name to the name of the layer using the tensor.
-    
-    By convention, PyTorch modules name their weights parameters as self.weight
-    (see for example: torch.nn.modules.conv) which means that their fully-qualified 
-    name when enumerating a model's parameters is the modules name followed by '.weight'.
-    We exploit this convention to convert a weights tensor name to the fully-qualified 
-    module name."""
-    return param_name[:-len('.weight')]
+    return ThinningRecipe(
+        modules={distiller.normalize_module_name(k): v for k, v in recipe.modules.items()},
+        parameters={distiller.normalize_module_name(k): v for k, v in recipe.parameters.items()},
+        )
 
 
 def directives_equal(d1, d2):
@@ -120,9 +107,8 @@ def append_param_directive(thinning_recipe, param_name, directive):
     thinning_recipe.parameters[param_name] = param_directives
 
 
-def append_module_directive(model, thinning_recipe, module_name, key, val):
+def append_module_directive(thinning_recipe, module_name, key, val):
     msglogger.debug("\t[recipe] setting {}.{} = {}".format(module_name, key, val))
-    module_name = denormalize_module_name(model, module_name)
     mod_directive = thinning_recipe.modules.get(module_name, {})
     mod_directive[key] = val
     thinning_recipe.modules[module_name] = mod_directive
@@ -180,7 +166,7 @@ def resnet_cifar_remove_layers(model):
 
 
 def remove_channels(model, zeros_mask_dict, arch, dataset, optimizer):
-    sgraph = create_graph(dataset, arch)
+    sgraph = create_graph(dataset, model)
     thinning_recipe = create_thinning_recipe_channels(sgraph, model, zeros_mask_dict)
     apply_and_save_recipe(model, zeros_mask_dict, thinning_recipe, optimizer)
     return model
@@ -234,7 +220,7 @@ def apply_and_save_recipe(model, zeros_mask_dict, thinning_recipe, optimizer):
 
 
 def remove_filters(model, zeros_mask_dict, arch, dataset, optimizer):
-    sgraph = create_graph(dataset, arch)
+    sgraph = create_graph(dataset, model)
     thinning_recipe = create_thinning_recipe_filters(sgraph, model, zeros_mask_dict)
     apply_and_save_recipe(model, zeros_mask_dict, thinning_recipe, optimizer)
     return model
@@ -256,7 +242,7 @@ def create_thinning_recipe_channels(sgraph, model, zeros_mask_dict):
 
     # Traverse all of the model's parameters, search for zero-channels, and
     # create a thinning recipe that descibes the required changes to the model.
-    for param_name, param in model.named_parameters():
+    for layer_name, param_name, param in sgraph.named_params_layers():
         # We are only interested in 4D weights (of Convolution layers)
         if param.dim() != 4:
             continue
@@ -272,43 +258,35 @@ def create_thinning_recipe_channels(sgraph, model, zeros_mask_dict):
 
         # We are removing channels, so update the number of incoming channels (IFMs)
         # in the convolutional layer
-        layer_name = param_name_2_layer_name(param_name)
         assert isinstance(layers[layer_name], torch.nn.modules.Conv2d)
-        append_module_directive(model, thinning_recipe, layer_name, key='in_channels', val=num_nnz_channels)
+        append_module_directive(thinning_recipe, layer_name, key='in_channels', val=num_nnz_channels)
 
         # Select only the non-zero filters
         indices = nonzero_channels.data.squeeze()
         append_param_directive(thinning_recipe, param_name, (1, indices))
 
         # Find all instances of Convolution layers that immediately preceed this layer
-        predecessors = sgraph.predecessors_f(normalize_module_name(layer_name), ['Conv'])
-        # Convert the layers names to PyTorch's convoluted naming scheme (when DataParallel is used)
-        predecessors = [normalize_module_name(predecessor) for predecessor in predecessors]
+        predecessors = sgraph.predecessors_f(layer_name, ['Conv'])
         if len(predecessors) == 0:
-            msglogger.info("Could not find predecessors for name={} normal={} {}".format(
-                           layer_name, normalize_module_name(layer_name), denormalize_module_name(model, layer_name)))
+            msglogger.info("Could not find predecessors for name={}".format(layer_name))
         for predecessor in predecessors:
             # For each of the convolutional layers that preceed, we have to reduce the number of output channels.
-            append_module_directive(model, thinning_recipe, predecessor, key='out_channels', val=num_nnz_channels)
+            append_module_directive(thinning_recipe, predecessor, key='out_channels', val=num_nnz_channels)
 
             # Now remove channels from the weights tensor of the predecessor conv
-            append_param_directive(thinning_recipe, denormalize_module_name(model, predecessor)+'.weight', (0, indices))
+            append_param_directive(thinning_recipe, predecessor+'.weight', (0, indices))
 
-            if layers[denormalize_module_name(model, predecessor)].bias is not None:
+            if layers[predecessor].bias is not None:
                 # This convolution has bias coefficients
-                append_param_directive(thinning_recipe, denormalize_module_name(model, predecessor)+'.bias', (0, indices))
+                append_param_directive(thinning_recipe, predecessor+'.bias', (0, indices))
 
         # Now handle the BatchNormalization layer that follows the convolution
-        bn_layers = sgraph.predecessors_f(normalize_module_name(layer_name), ['BatchNormalization'])
-        if len(bn_layers) > 0:
-            # if len(bn_layers) != 1:
-            #     raise RuntimeError("{} should have exactly one BN predecessors, but has {}".format(layer_name, len(bn_layers)))
-            for bn_layer in bn_layers:
-                # Thinning of the BN layer that follows the convolution
-                bn_layer_name = denormalize_module_name(model, bn_layer)
-                msglogger.debug("[recipe] {}: predecessor BN module = {}".format(layer_name, bn_layer_name))
-                append_bn_thinning_directive(thinning_recipe, layers, bn_layer_name,
-                                             len_thin_features=num_nnz_channels, thin_features=indices)
+        bn_layers = sgraph.predecessors_f(layer_name, ['BatchNormalization'])
+        for bn_layer in bn_layers:
+            # Thinning of the BN layer that follows the convolution
+            msglogger.debug("[recipe] {}: predecessor BN module = {}".format(layer_name, bn_layer))
+            append_bn_thinning_directive(thinning_recipe, layers, bn_layer,
+                                         len_thin_features=num_nnz_channels, thin_features=indices)
 
     msglogger.debug(thinning_recipe)
     return thinning_recipe
@@ -329,7 +307,7 @@ def create_thinning_recipe_filters(sgraph, model, zeros_mask_dict):
     thinning_recipe = ThinningRecipe(modules={}, parameters={})
     layers = {mod_name: m for mod_name, m in model.named_modules()}
 
-    for param_name, param in model.named_parameters():
+    for layer_name, param_name, param in sgraph.named_params_layers():
         # We are only interested in 4D weights
         if param.dim() != 4:
             continue
@@ -343,7 +321,7 @@ def create_thinning_recipe_filters(sgraph, model, zeros_mask_dict):
             raise ValueError("Trying to set zero filters for parameter %s is not allowed" % param_name)
         # If there are non-zero filters in this tensor then continue to next tensor
         if num_filters <= num_nnz_filters:
-            msglogger.debug("Skipping {} shape={}".format(param_name_2_layer_name(param_name), param.shape))
+            msglogger.debug("Skipping {} shape={}".format(param_name, param.shape))
             continue
 
         msglogger.info("In tensor %s found %d/%d zero filters", param_name,
@@ -351,9 +329,8 @@ def create_thinning_recipe_filters(sgraph, model, zeros_mask_dict):
 
         # We are removing filters, so update the number of outgoing channels (OFMs)
         # in the convolutional layer
-        layer_name = param_name_2_layer_name(param_name)
         assert isinstance(layers[layer_name], torch.nn.modules.Conv2d)
-        append_module_directive(model, thinning_recipe, layer_name, key='out_channels', val=num_nnz_filters)
+        append_module_directive(thinning_recipe, layer_name, key='out_channels', val=num_nnz_filters)
 
         # Select only the non-zero filters
         indices = nonzero_filters.data.squeeze()
@@ -364,24 +341,20 @@ def create_thinning_recipe_filters(sgraph, model, zeros_mask_dict):
             append_param_directive(thinning_recipe, layer_name+'.bias', (0, indices))
 
         # Find all instances of Convolution or FC (GEMM) layers that immediately follow this layer
-        msglogger.debug("{} => {}".format(layer_name, normalize_module_name(layer_name)))
-        successors = sgraph.successors_f(normalize_module_name(layer_name), ['Conv', 'Gemm'])
-        # Convert the layers names to PyTorch's convoluted naming scheme (when DataParallel is used)
-        successors = [denormalize_module_name(model, successor) for successor in successors]
+        successors = sgraph.successors_f(layer_name, ['Conv', 'Gemm'])
         for successor in successors:
-
             if isinstance(layers[successor], torch.nn.modules.Conv2d):
                 # For each of the convolutional layers that follow, we have to reduce the number of input channels.
-                append_module_directive(model, thinning_recipe, successor, key='in_channels', val=num_nnz_filters)
+                append_module_directive(thinning_recipe, successor, key='in_channels', val=num_nnz_filters)
 
                 # Now remove channels from the weights tensor of the successor conv
-                append_param_directive(thinning_recipe, denormalize_module_name(model, successor)+'.weight', (1, indices))
+                append_param_directive(thinning_recipe, successor+'.weight', (1, indices))
 
             elif isinstance(layers[successor], torch.nn.modules.Linear):
                 # If a Linear (Fully-Connected) layer follows, we need to update it's in_features member
                 fm_size = layers[successor].in_features // layers[layer_name].out_channels
                 in_features = fm_size * num_nnz_filters
-                append_module_directive(model, thinning_recipe, successor, key='in_features', val=in_features)
+                append_module_directive(thinning_recipe, successor, key='in_features', val=in_features)
                 msglogger.debug("[recipe] Linear {}: fm_size = {}  layers[{}].out_channels={}".format(
                                 successor, in_features, layer_name, layers[layer_name].out_channels))
                 msglogger.debug("[recipe] {}: setting in_features = {}".format(successor, in_features))
@@ -391,18 +364,16 @@ def create_thinning_recipe_filters(sgraph, model, zeros_mask_dict):
                 fm_height = fm_width = int(math.sqrt(fm_size))
                 view_4D = (layers[successor].out_features, layers[layer_name].out_channels, fm_height, fm_width)
                 view_2D = (layers[successor].out_features, in_features)
-                append_param_directive(thinning_recipe,
-                                       denormalize_module_name(model, successor)+'.weight',
+                append_param_directive(thinning_recipe, successor+'.weight',
                                        (1, indices, view_4D, view_2D))
 
         # Now handle the BatchNormalization layer that follows the convolution
-        bn_layers = sgraph.successors_f(normalize_module_name(layer_name), ['BatchNormalization'])
+        bn_layers = sgraph.successors_f(layer_name, ['BatchNormalization'])
         if len(bn_layers) > 0:
             assert len(bn_layers) == 1
             # Thinning of the BN layer that follows the convolution
-            bn_layer_name = denormalize_module_name(model, bn_layers[0])
-            append_bn_thinning_directive(thinning_recipe, layers, bn_layer_name,
-                                         len_thin_features=num_nnz_filters, thin_features=indices)
+            append_bn_thinning_directive(thinning_recipe, layers, bn_layers[0],
+                                         len_thin_features=num_nnz_filters, thin_features=indices)
     return thinning_recipe
 
 
diff --git a/tests/test_summarygraph.py b/tests/test_summarygraph.py
index 0d378fa9f18df827cbde0af82ce785f6b6d71d7d..195a9825adba8ccc9cdfb5f0f65f471834170d72 100755
--- a/tests/test_summarygraph.py
+++ b/tests/test_summarygraph.py
@@ -161,6 +161,23 @@ def test_normalize_module_name():
     name_test('imagenet', 'alexnet')
 
 
+def named_params_layers_test_aux(dataset, arch, dataparallel: bool):
+    model = create_model(False, dataset, arch, parallel=dataparallel)
+    sgraph = SummaryGraph(model, get_input(dataset))
+    sgraph_layer_names = set(k for k, _, _ in sgraph.named_params_layers())
+    for layer_name in sgraph_layer_names:
+        assert sgraph.find_op(layer_name) is not None, \
+            '{} was not found in summary graph'.format(layer_name)
+
+
+def test_named_params_layers():
+    for data_parallel in (True, False):
+        named_params_layers_test_aux('imagenet', 'vgg19', data_parallel)
+        named_params_layers_test_aux('cifar10', 'resnet20_cifar', data_parallel)
+        named_params_layers_test_aux('imagenet', 'alexnet', data_parallel)
+        named_params_layers_test_aux('imagenet', 'resnext101_32x4d', data_parallel)
+
+
 def test_onnx_name_2_pytorch_name():
     assert "layer3.0.relu1" == onnx_name_2_pytorch_name("ResNet/Sequential[layer3]/BasicBlock[0]/ReLU[relu].1", 'Relu')
     assert "features.34" == onnx_name_2_pytorch_name('VGG/[features]/Sequential/Conv2d[34]', 'Conv')