From af4bf3dc753e67c1f34958b1b5425d7ef50561c4 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Mon, 9 Jul 2018 01:44:54 +0300
Subject: [PATCH] Bug fix in connectivity_summary; extend the API of
 create_png()

*connectivity_summary() does not use SummaryGraph correctly:
 Recently we changed the internal representation of SummaryGraph.ops, but
 connectivity_summary() and connectivity_summary_verbose() were not updated.
 Fixed that.

*Extend the API of create_png():
 Add to the signature of create_png() and create_pydot_graph() rankdir and
 External styles.  These are explained in the docstrings.

*Added documentation to the PNG drawing functions

*Added tests to catch trivial connectivity_summary() bugs
---
 apputils/model_summaries.py | 134 ++++++++++++++++++++++++------------
 tests/test_summarygraph.py  |  48 +++++++++++--
 2 files changed, 132 insertions(+), 50 deletions(-)

diff --git a/apputils/model_summaries.py b/apputils/model_summaries.py
index 953ab8f..3cd46ca 100755
--- a/apputils/model_summaries.py
+++ b/apputils/model_summaries.py
@@ -178,7 +178,7 @@ class SummaryGraph(object):
             s = str(n.type())
             s = s[s.find('(')+1: s.find(')')]
             tensor['shape'] = tuple(map(lambda x: int(x), s.split(',')))
-        except:
+        except ValueError:
             # Size not specified in type
             tensor['shape'] = 0,
         return tensor
@@ -198,14 +198,14 @@ class SummaryGraph(object):
             op['attrs']['MACs'] = 0
             if op['type'] == 'Conv':
                 conv_out = op['outputs'][0]
-                conv_in =  op['inputs'][0]
+                conv_in = op['inputs'][0]
                 conv_w = op['attrs']['kernel_shape']
                 ofm_vol = self.param_volume(conv_out)
                 # MACs = volume(OFM) * (#IFM * K^2)
                 op['attrs']['MACs'] = ofm_vol * SummaryGraph.volume(conv_w) * self.params[conv_in]['shape'][1]
             elif op['type'] == 'Gemm':
-                conv_out =  op['outputs'][0]
-                conv_in =  op['inputs'][0]
+                conv_out = op['outputs'][0]
+                conv_in = op['inputs'][0]
                 n_ifm = self.param_shape(conv_in)[1]
                 n_ofm = self.param_shape(conv_out)[1]
                 # MACs = #IFM * #OFM
@@ -424,7 +424,7 @@ def connectivity_summary(sgraph):
     """
     df = pd.DataFrame(columns=['Name', 'Type', 'Inputs', 'Outputs'])
     pd.set_option('precision', 5)
-    for i, op in enumerate(sgraph.ops):
+    for i, op in enumerate(sgraph.ops.values()):
         df.loc[i] = [op['name'], op['type'], op['inputs'], op['outputs']]
     return df
 
@@ -443,7 +443,7 @@ def connectivity_summary_verbose(sgraph):
 
     df = pd.DataFrame(columns=['Name', 'Type', 'Inputs', 'Outputs'])
     pd.set_option('precision', 5)
-    for i, op in enumerate(sgraph.ops):
+    for i, op in enumerate(sgraph.ops.values()):
         outputs = []
         for blob in op['outputs']:
             if blob in sgraph.params:
@@ -467,16 +467,21 @@ def connectivity_tbl_summary(sgraph, verbose=False):
     return tabulate(df, headers='keys', tablefmt='psql')
 
 
-def create_pydot_graph(op_nodes, data_nodes, param_nodes, edges):
-    pydot_graph = pydot.Dot('Net', graph_type='digraph', rankdir='TB')
+def create_pydot_graph(op_nodes, data_nodes, param_nodes, edges, rankdir='TB', styles=None):
+    """Low-level API to create a PyDot graph (dot formatted).
+    """
+    pydot_graph = pydot.Dot('Net', graph_type='digraph', rankdir=rankdir)
 
-    node_style = {'shape': 'record',
-                  'fillcolor': '#6495ED',
-                  'style': 'rounded, filled'}
+    op_node_style = {'shape': 'record',
+                     'fillcolor': '#6495ED',
+                     'style': 'rounded, filled'}
 
     for op_node in op_nodes:
-        pydot_graph.add_node(pydot.Node(op_node[0], **node_style,
-                             label="\n".join(op_node)))
+        style = op_node_style
+        # Check if we should override the style of this node.
+        if styles is not None and op_node[0] in styles:
+            style = styles[op_node[0]]
+        pydot_graph.add_node(pydot.Node(op_node[0], **style, label="\n".join(op_node)))
 
     for data_node in data_nodes:
         pydot_graph.add_node(pydot.Node(data_node[0], label="\n".join(data_node)))
@@ -495,37 +500,20 @@ def create_pydot_graph(op_nodes, data_nodes, param_nodes, edges):
     return pydot_graph
 
 
-def draw_model_to_file(sgraph, png_fname, display_param_nodes=False):
-    """Create a PNG file, containing a graphiz-dot graph of the netowrk represented
-    by SummaryGraph 'sgraph'
-    """
-    png = create_png(sgraph, display_param_nodes=display_param_nodes)
-    with open(png_fname, 'wb') as fid:
-        fid.write(png)
-
-
-def draw_img_classifier_to_file(model, png_fname, dataset, display_param_nodes=False):
-    try:
-        if dataset == 'imagenet':
-            dummy_input = Variable(torch.randn(1, 3, 224, 224), requires_grad=False)
-        elif dataset == 'cifar10':
-            dummy_input = Variable(torch.randn(1, 3, 32, 32))
-        else:
-            print("Unsupported dataset (%s) - aborting draw operation" % dataset)
-            return
-
-        g = SummaryGraph(model, dummy_input)
-        draw_model_to_file(g, png_fname, display_param_nodes)
-        print("Network PNG image generation completed")
-    except FileNotFoundError:
-        print("An error has occured while generating the network PNG image.")
-        print("Please check that you have graphviz installed.")
-        print("\t$ sudo apt-get install graphviz")
-
-
-def create_png(sgraph, display_param_nodes=False):
+def create_png(sgraph, display_param_nodes=False, rankdir='TB', styles=None):
     """Create a PNG object containing a graphiz-dot graph of the network,
-    as represented by SummaryGraph 'sgraph'
+    as represented by SummaryGraph 'sgraph'.
+
+    Args:
+        sgraph (SummaryGraph): the SummaryGraph instance to draw.
+        display_param_nodes (boolean): if True, draw the parameter nodes
+        rankdir: diagram direction.  'TB'/'BT' is Top-to-Bottom/Bottom-to-Top
+                 'LR'/'R/L' is Left-to-Rt/Rt-to-Left
+        styles: a dictionary of styles.  Key is module name.  Value is
+                a legal pydot style dictionary.  For example:
+                styles['conv1'] = {'shape': 'oval',
+                                   'fillcolor': 'gray',
+                                   'style': 'rounded, filled'}
     """
 
     op_nodes = [op['name'] for op in sgraph.ops.values()]
@@ -546,11 +534,69 @@ def create_png(sgraph, display_param_nodes=False):
         param_nodes = None
 
     op_nodes = [(op['name'], op['type']) for op in sgraph.ops.values()]
-    pydot_graph = create_pydot_graph(op_nodes, data_nodes, param_nodes, edges)
+    pydot_graph = create_pydot_graph(op_nodes, data_nodes, param_nodes, edges, rankdir, styles)
     png = pydot_graph.create_png()
     return png
 
 
+def draw_model_to_file(sgraph, png_fname, display_param_nodes=False, rankdir='TB', styles=None):
+    """Create a PNG file, containing a graphiz-dot graph of the netowrk represented
+    by SummaryGraph 'sgraph'
+
+    Args:
+        sgraph (SummaryGraph): the SummaryGraph instance to draw.
+        png_fname (string): PNG file name
+        display_param_nodes (boolean): if True, draw the parameter nodes
+        rankdir: diagram direction.  'TB'/'BT' is Top-to-Bottom/Bottom-to-Top
+                 'LR'/'R/L' is Left-to-Rt/Rt-to-Left
+        styles: a dictionary of styles.  Key is module name.  Value is
+                a legal pydot style dictionary.  For example:
+                styles['conv1'] = {'shape': 'oval',
+                                   'fillcolor': 'gray',
+                                   'style': 'rounded, filled'}
+        """
+    png = create_png(sgraph, display_param_nodes=display_param_nodes)
+    with open(png_fname, 'wb') as fid:
+        fid.write(png)
+
+
+def draw_img_classifier_to_file(model, png_fname, dataset, display_param_nodes=False,
+                                rankdir='TB', styles=None):
+    """Draw a PyTorch image classifier to a PNG file.  This a helper function that
+    simplifies the interface of draw_model_to_file().
+
+    Args:
+        model: PyTorch model instance
+        png_fname (string): PNG file name
+        dataset (string): one of 'imagenet' or 'cifar10'.  This is required in order to
+                          create a dummy input of the correct shape.
+        display_param_nodes (boolean): if True, draw the parameter nodes
+        rankdir: diagram direction.  'TB'/'BT' is Top-to-Bottom/Bottom-to-Top
+                 'LR'/'R/L' is Left-to-Rt/Rt-to-Left
+        styles: a dictionary of styles.  Key is module name.  Value is
+                a legal pydot style dictionary.  For example:
+                styles['conv1'] = {'shape': 'oval',
+                                   'fillcolor': 'gray',
+                                   'style': 'rounded, filled'}
+    """
+    try:
+        if dataset == 'imagenet':
+            dummy_input = Variable(torch.randn(1, 3, 224, 224), requires_grad=False)
+        elif dataset == 'cifar10':
+            dummy_input = Variable(torch.randn(1, 3, 32, 32))
+        else:
+            print("Unsupported dataset (%s) - aborting draw operation" % dataset)
+            return
+
+        g = SummaryGraph(model, dummy_input)
+        draw_model_to_file(g, png_fname, display_param_nodes, rankdir, styles)
+        print("Network PNG image generation completed")
+    except FileNotFoundError:
+        print("An error has occured while generating the network PNG image.")
+        print("Please check that you have graphviz installed.")
+        print("\t$ sudo apt-get install graphviz")
+
+
 def data_node_has_parent(g, id):
     for edge in g.edges:
         if edge.dst == id: return True
diff --git a/tests/test_summarygraph.py b/tests/test_summarygraph.py
index 5bcb82d..f36f565 100755
--- a/tests/test_summarygraph.py
+++ b/tests/test_summarygraph.py
@@ -22,10 +22,8 @@ module_path = os.path.abspath(os.path.join('..'))
 if module_path not in sys.path:
     sys.path.append(module_path)
 import distiller
-
-import pytest
 from models import ALL_MODEL_NAMES, create_model
-from apputils import SummaryGraph, onnx_name_2_pytorch_name
+from apputils import *
 from distiller import normalize_module_name, denormalize_module_name
 
 # Logging configuration
@@ -66,13 +64,13 @@ def test_connectivity():
 
     edges = g.edges
     assert edges[0].src == '0' and edges[0].dst == 'conv1'
-    #logging.debug(g.ops[1]['name'])
+
     # Test two sequential calls to predecessors (this was a bug once)
     preds = g.predecessors(g.find_op('bn1'), 1)
     preds = g.predecessors(g.find_op('bn1'), 1)
     assert preds == ['108', '2', '3', '4', '5']
     # Test successors
-    succs = g.successors(g.find_op('bn1'), 2)#, logging)
+    succs = g.successors(g.find_op('bn1'), 2)
     assert succs == ['relu']
 
     op = g.find_op('layer1.0')
@@ -89,7 +87,6 @@ def test_connectivity():
     assert preds == []
     preds = g.predecessors(g.find_op('bn1'), 3)
     assert preds == ['0', '1']
-    #logging.debug(preds)
 
 
 def test_layer_search():
@@ -133,6 +130,30 @@ def test_vgg():
     succs = g.successors_f('features.34', 'Conv')
 
 
+def test_simplenet():
+    g = create_graph('cifar10', 'simplenet_cifar')
+    assert g is not None
+    preds = g.predecessors_f(normalize_module_name('module.conv1'), 'Conv')
+    logging.debug("[simplenet_cifar]: preds of module.conv1 = {}".format(preds))
+    assert len(preds) == 0
+
+    preds = g.predecessors_f(normalize_module_name('module.conv2'), 'Conv')
+    logging.debug("[simplenet_cifar]: preds of module.conv2 = {}".format(preds))
+    assert len(preds) == 1
+
+
+def test_simplenet():
+    g = create_graph('cifar10', 'simplenet_cifar')
+    assert g is not None
+    preds = g.predecessors_f(normalize_module_name('module.conv1'), 'Conv')
+    logging.debug("[simplenet_cifar]: preds of module.conv1 = {}".format(preds))
+    assert len(preds) == 0
+
+    preds = g.predecessors_f(normalize_module_name('module.conv2'), 'Conv')
+    logging.debug("[simplenet_cifar]: preds of module.conv2 = {}".format(preds))
+    assert len(preds) == 1
+
+
 def name_test(dataset, arch):
     model = create_model(False, dataset, arch, parallel=False)
     modelp = create_model(False, dataset, arch, parallel=True)
@@ -163,3 +184,18 @@ def test_onnx_name_2_pytorch_name():
     assert "features.34" == onnx_name_2_pytorch_name('VGG/[features]/Sequential/Conv2d[34]', 'Conv')
     assert "Relu3" == onnx_name_2_pytorch_name('NameWithNoModule.3', 'Relu')
     #assert "features.module.34" == onnx_name_2_pytorch_name('VGG/DataParallel[features]/Sequential/Conv2d[34]', 'Conv')
+
+
+def test_connectivity_summary():
+    g = create_graph('cifar10', 'resnet20_cifar')
+    assert g is not None
+
+    summary = connectivity_summary(g)
+    assert len(summary) == 73
+
+    verbose_summary = connectivity_summary_verbose(g)
+    assert len(verbose_summary  ) == 73
+
+
+if __name__ == '__main__':
+    test_connectivity_summary()
-- 
GitLab