diff --git a/apputils/model_summaries.py b/apputils/model_summaries.py
index b956abc0914224b60c1fb8a96031e10225832c37..953ab8fd35fea0fec53070762cc992567325c00e 100755
--- a/apputils/model_summaries.py
+++ b/apputils/model_summaries.py
@@ -30,8 +30,10 @@ from torch.autograd import Variable
 import torch.jit as jit
 import pandas as pd
 from tabulate import tabulate
+import pydot
+
 
-def onnx_name_2_pytorch_name(name):
+def onnx_name_2_pytorch_name(name, op_type):
     # Convert a layer's name from an ONNX name, to a PyTorch name
     # For example:
     #   ResNet/Sequential[layer3]/BasicBlock[0]/ReLU[relu].1 ==> layer3.0.relu.1
@@ -45,12 +47,16 @@ def onnx_name_2_pytorch_name(name):
     name_parts = re.findall('\[.*?\]', name)
     name_parts = [part[1:-1] for part in name_parts]
 
-    new_name = '.'.join(name_parts) + instance
-    if new_name == '':
-        new_name = name
+    # If name doesn't have the pattern above, it probably means the op was called via
+    # some functional API and not via a module. Couple of examples:
+    #   x = x.view(...)
+    #   x = F.relu(x)
+    # In this case, to have a meaningful name, we use the op type
+    new_name = ('.'.join(name_parts) if len(name_parts) > 0 else op_type) + instance
 
     return new_name
 
+
 class SummaryGraph(object):
     """We use Pytorch's JIT tracer to run a forward pass and generate a trace graph, which
     is an internal representation of the model.  We then use ONNX to "clean" this
@@ -102,13 +108,18 @@ class SummaryGraph(object):
             for node in graph.nodes():
                 new_op = self.__create_op(node)
 
-                # in-place operators create very confusing graphs (Resnet, for example),
-                # so we "unroll" them
-                same = [op for op in self.ops.values() if op['orig-name'] == new_op['orig-name']]
+                # Operators with the same name create very confusing graphs (Resnet, for example),
+                # so we "unroll" them.
+                # Sometimes operations of different types have the same name (for example, when an
+                # operator is called via some functional API and not via a module), so we match on
+                # the (name, type) pair. Comparing tuples rather than concatenated strings avoids
+                # false matches such as 'ab' + 'c' == 'a' + 'bc'.
+                same = [op for op in self.ops.values() if
+                        (op['orig-name'], op['type']) == (new_op['orig-name'], new_op['type'])]
                 if len(same) > 0:
                     new_op['name'] += "." + str(len(same))
 
-                new_op['name'] = onnx_name_2_pytorch_name(new_op['name'])
+                new_op['name'] = onnx_name_2_pytorch_name(new_op['name'], new_op['type'])
                 assert len(new_op['name']) > 0
 
                 self.ops[new_op['name']] = new_op
@@ -137,21 +148,21 @@ class SummaryGraph(object):
         op['params'] = []
         return op
 
-
     def __add_input(self, op, n):
         param = self.__add_param(n)
-        if param is None: return
+        if param is None:
+            return
         if param['id'] not in op['inputs']:
             op['inputs'].append(param['id'])
 
     def __add_output(self, op, n):
         param = self.__add_param(n)
-        if param is None: return
+        if param is None:
+            return
         if param['id'] not in op['outputs']:
             op['outputs'].append(param['id'])
 
     def __add_param(self, n):
-        param = {}
         if n.uniqueName() not in self.params:
             param = self.__tensor_desc(n)
             self.params[n.uniqueName()] = param
@@ -168,7 +179,8 @@ class SummaryGraph(object):
             s = s[s.find('(')+1: s.find(')')]
             tensor['shape'] = tuple(map(lambda x: int(x), s.split(',')))
         except:
-            return None
+            # Size not specified in type: fall back to an explicit one-element placeholder shape
+            tensor['shape'] = (0,)
         return tensor
 
     def param_shape(self, param_id):
@@ -261,7 +273,6 @@ class SummaryGraph(object):
                 ret += self.predecessors(predecessor, depth-1, done_list) #, logging)
             return ret
 
-
     def predecessors_f(self, node_name, predecessors_types, done_list=None, logging=None):
         """Returns a list of <op>'s predecessors, if they match the <predecessors_types> criteria.
         """
@@ -299,7 +310,6 @@ class SummaryGraph(object):
             ret += self.predecessors_f(predecessor, predecessors_types, done_list, logging)
         return ret
 
-
     def successors(self, node, depth, done_list=None):
         """Returns a list of <op>'s successors"""
 
@@ -370,6 +380,7 @@ class SummaryGraph(object):
             ret += self.successors_f(successor, successors_types, done_list, logging)
         return ret
 
+
 def attributes_summary(sgraph, ignore_attrs):
     """Generate a summary of a graph's attributes.
 
@@ -399,10 +410,12 @@ def attributes_summary(sgraph, ignore_attrs):
         df.loc[i] = [op['name'], op['type'], pretty_attrs(op['attrs'], ignore_attrs)]
     return df
 
+
 def attributes_summary_tbl(sgraph, ignore_attrs):
     df = attributes_summary(sgraph, ignore_attrs)
     return tabulate(df, headers='keys', tablefmt='psql')
 
+
 def connectivity_summary(sgraph):
     """Generate a summary of each node's connectivity.
 
@@ -415,6 +428,7 @@ def connectivity_summary(sgraph):
         df.loc[i] = [op['name'], op['type'], op['inputs'], op['outputs']]
     return df
 
+
 def connectivity_summary_verbose(sgraph):
     """Generate a summary of each node's connectivity, with details
     about the parameters.
@@ -444,6 +458,7 @@ def connectivity_summary_verbose(sgraph):
 
     return df
 
+
 def connectivity_tbl_summary(sgraph, verbose=False):
     if verbose:
         df = connectivity_summary_verbose(sgraph)
@@ -452,9 +467,6 @@ def connectivity_tbl_summary(sgraph, verbose=False):
     return tabulate(df, headers='keys', tablefmt='psql')
 
 
-
-import pydot
-
 def create_pydot_graph(op_nodes, data_nodes, param_nodes, edges):
     pydot_graph = pydot.Dot('Net', graph_type='digraph', rankdir='TB')
 
@@ -467,32 +479,32 @@ def create_pydot_graph(op_nodes, data_nodes, param_nodes, edges):
                              label="\n".join(op_node)))
 
     for data_node in data_nodes:
-        #pydot_graph.add_node(pydot.Node(data_node))
         pydot_graph.add_node(pydot.Node(data_node[0], label="\n".join(data_node)))
 
-
     node_style = {'shape': 'oval',
                   'fillcolor': 'gray',
                   'style': 'rounded, filled'}
 
     if param_nodes is not None:
-        for data_node in param_nodes:
-            pydot_graph.add_node(pydot.Node(data_node, **node_style))
+        for param_node in param_nodes:
+            pydot_graph.add_node(pydot.Node(param_node[0], **node_style, label="\n".join(param_node)))
 
     for edge in edges:
         pydot_graph.add_edge(pydot.Edge(edge[0], edge[1]))
 
     return pydot_graph
 
-def draw_model_to_file(sgraph, png_fname):
+
+def draw_model_to_file(sgraph, png_fname, display_param_nodes=False):
-    """Create a PNG file, containing a graphiz-dot graph of the netowrk represented
+    """Create a PNG file, containing a graphviz-dot graph of the network represented
     by SummaryGraph 'sgraph'
     """
-    png = create_png(sgraph)
+    png = create_png(sgraph, display_param_nodes=display_param_nodes)
     with open(png_fname, 'wb') as fid:
         fid.write(png)
 
-def draw_img_classifier_to_file(model, png_fname, dataset):
+
+def draw_img_classifier_to_file(model, png_fname, dataset, display_param_nodes=False):
     try:
         if dataset == 'imagenet':
             dummy_input = Variable(torch.randn(1, 3, 224, 224), requires_grad=False)
@@ -503,41 +515,42 @@ def draw_img_classifier_to_file(model, png_fname, dataset):
             return
 
         g = SummaryGraph(model, dummy_input)
-        draw_model_to_file(g, png_fname)
+        draw_model_to_file(g, png_fname, display_param_nodes)
         print("Network PNG image generation completed")
-    except TypeError as e:
-        print("An error has occured while generating the network PNG image.")
-        print("This feature is not supported on official PyTorch releases.")
-        print("Please check that you are using a valid PyTorch version.")
-        print("You are using pytorch version %s" %torch.__version__)
     except FileNotFoundError:
         print("An error has occured while generating the network PNG image.")
         print("Please check that you have graphviz installed.")
         print("\t$ sudo apt-get install graphviz")
 
+
 def create_png(sgraph, display_param_nodes=False):
-    """Create a PNG object containing a graphiz-dot graph of the network,
+    """Create a PNG object containing a graphviz-dot graph of the network,
     as represented by SummaryGraph 'sgraph'
     """
 
-    op_nodes = [op['name'] for op in sgraph.ops]
-    data_nodes = [(id,param.shape) for (id, param) in sgraph.params.items() if data_node_has_parent(sgraph, id)]
-    param_nodes = [id for id in sgraph.params.keys() if not data_node_has_parent(sgraph, id)]
+    op_nodes = [op['name'] for op in sgraph.ops.values()]
+    data_nodes = []
+    param_nodes = []
+    for param_id, param in sgraph.params.items():
+        n_data = (param_id, str(param['shape']))
+        if data_node_has_parent(sgraph, param_id):
+            data_nodes.append(n_data)
+        else:
+            param_nodes.append(n_data)
     edges = sgraph.edges
 
     if not display_param_nodes:
         # Use only the edges that don't have a parameter source
-        edges = [edge for edge in sgraph.edges if edge.src in (data_nodes+op_nodes)]
+        non_param_ids = op_nodes + [dn[0] for dn in data_nodes]
+        edges = [edge for edge in sgraph.edges if edge.src in non_param_ids]
         param_nodes = None
 
-    if False:
-        data_nodes = None
-
-    op_nodes = [(op['name'], op['type']) for op in sgraph.ops]
+    op_nodes = [(op['name'], op['type']) for op in sgraph.ops.values()]
     pydot_graph = create_pydot_graph(op_nodes, data_nodes, param_nodes, edges)
     png = pydot_graph.create_png()
     return png
 
+
 def data_node_has_parent(g, id):
     for edge in g.edges:
         if edge.dst == id: return True
diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py
index 3a2d5500fb52fa243451b0b13ae5d99efb946ece..338ec78438c0bdbbc757dff475c34e744445983c 100755
--- a/examples/classifier_compression/compress_classifier.py
+++ b/examples/classifier_compression/compress_classifier.py
@@ -111,7 +111,7 @@ parser.add_argument('--act-stats', dest='activation_stats', action='store_true',
                     help='collect activation statistics (WARNING: this slows down training)')
 parser.add_argument('--param-hist', dest='log_params_histograms', action='store_true', default=False,
                     help='log the paramter tensors histograms to file (WARNING: this can use significant disk space)')
-SUMMARY_CHOICES = ['sparsity', 'compute', 'optimizer', 'model', 'modules', 'png']
+SUMMARY_CHOICES = ['sparsity', 'compute', 'optimizer', 'model', 'modules', 'png', 'png_w_params']
 parser.add_argument('--summary', type=str, choices=SUMMARY_CHOICES,
                     help='print a summary of the model, and exit - options: ' +
                     ' | '.join(SUMMARY_CHOICES))
@@ -193,7 +193,8 @@ def main():
     args.dataset = 'cifar10' if 'cifar' in args.arch else 'imagenet'
 
     # Create the model
-    is_parallel = args.summary != 'png' # For PNG summary, parallel graphs are illegible
+    png_summary = args.summary is not None and args.summary.startswith('png')
+    is_parallel = not png_summary   # For PNG summary, parallel graphs are illegible
     model = create_model(args.pretrained, args.dataset, args.arch, parallel=is_parallel, device_ids=args.gpus)
 
     compression_scheduler = None
@@ -223,8 +224,8 @@ def main():
     # This sample application can be invoked to produce various summary reports.
     if args.summary:
         which_summary = args.summary
-        if which_summary == 'png':
-            apputils.draw_img_classifier_to_file(model, 'model.png', args.dataset)
+        if which_summary.startswith('png'):
+            apputils.draw_img_classifier_to_file(model, 'model.png', args.dataset, which_summary == 'png_w_params')
         else:
             distiller.model_summary(model, optimizer, which_summary, args.dataset)
         exit()
diff --git a/tests/test_summarygraph.py b/tests/test_summarygraph.py
index 09c867b6f088271f39ee44e470574efc7e23b296..31145dbb8dedf6d486bf65f681570b3d68995bfb 100755
--- a/tests/test_summarygraph.py
+++ b/tests/test_summarygraph.py
@@ -33,6 +33,7 @@ fh = logging.FileHandler('test.log')
 logger = logging.getLogger()
 logger.addHandler(fh)
 
+
 def get_input(dataset):
     if dataset == 'imagenet':
         return torch.randn((1, 3, 224, 224), requires_grad=False)
@@ -40,6 +41,7 @@ def get_input(dataset):
         return torch.randn((1, 3, 32, 32))
     return None
 
+
 def create_graph(dataset, arch):
     dummy_input = get_input(dataset)
     assert dummy_input is not None, "Unsupported dataset ({}) - aborting draw operation".format(dataset)
@@ -48,10 +50,12 @@ def create_graph(dataset, arch):
     assert model is not None
     return SummaryGraph(model, dummy_input)
 
+
 def test_graph():
     g = create_graph('cifar10', 'resnet20_cifar')
     assert g is not None
 
+
 def test_connectivity():
     g = create_graph('cifar10', 'resnet20_cifar')
     assert g is not None
@@ -86,6 +90,7 @@ def test_connectivity():
     assert preds == ['0', '1']
     #logging.debug(preds)
 
+
 def test_layer_search():
     g = create_graph('cifar10', 'resnet20_cifar')
     assert g is not None
@@ -118,12 +123,14 @@ def test_layer_search():
     preds = g.predecessors_f('layer1.1.conv1', 'Conv', [], logging)
     assert preds == ['layer1.0.conv2', 'conv1']
 
+
 def normalize_layer_name(layer_name):
-    start = layer_name.find('module.')
-    if start != -1:
-        layer_name = layer_name[:start] + layer_name[start + len('module.'):]
-    return layer_name
+    parts = layer_name.split('.')
+    if 'module' in parts:
+        parts.remove('module')  # drop only the first whole 'module' path component
+    return '.'.join(parts)
 
+
 def test_vgg():
     g = create_graph('imagenet', 'vgg19')
     assert g is not None
@@ -131,12 +138,15 @@ def test_vgg():
     logging.debug(succs)
     succs = g.successors_f('features.34', 'Conv')
 
+
 def test_normalize_layer_name():
-    assert "features.0", normalize_layer_name("features.module.0")
-    assert "features.0", normalize_layer_name("module.features.0")
-    assert "features.0", normalize_layer_name("features.0.module")
+    assert "features.0" == normalize_layer_name("features.module.0")
+    assert "features.0" == normalize_layer_name("module.features.0")
+    assert "features.0" == normalize_layer_name("features.0.module")
 
+
 def test_onnx_name_2_pytorch_name():
-    assert "layer3.0.relu1" == onnx_name_2_pytorch_name("ResNet/Sequential[layer3]/BasicBlock[0]/ReLU[relu].1")
-    assert "features.34" == onnx_name_2_pytorch_name('VGG/[features]/Sequential/Conv2d[34]')
-    #assert "features.module.34" == onnx_name_2_pytorch_name('VGG/DataParallel[features]/Sequential/Conv2d[34]')
+    assert "layer3.0.relu1" == onnx_name_2_pytorch_name("ResNet/Sequential[layer3]/BasicBlock[0]/ReLU[relu].1", 'Relu')
+    assert "features.34" == onnx_name_2_pytorch_name('VGG/[features]/Sequential/Conv2d[34]', 'Conv')
+    assert "Relu3" == onnx_name_2_pytorch_name('NameWithNoModule.3', 'Relu')
+    #assert "features.module.34" == onnx_name_2_pytorch_name('VGG/DataParallel[features]/Sequential/Conv2d[34]', 'Conv')