From af4bf3dc753e67c1f34958b1b5425d7ef50561c4 Mon Sep 17 00:00:00 2001 From: Neta Zmora <neta.zmora@intel.com> Date: Mon, 9 Jul 2018 01:44:54 +0300 Subject: [PATCH] Bug fix in connectivity_summary; extend the API of create_png() *connectivity_summary() does not use SummaryGraph correctly: Recently we changed the internal representation of SummaryGraph.ops, but connectivity_summary() and connectivity_summary_verbose() were not updated. Fixed that. *Extend the API of create_png(): Add to the signature of create_png() and create_pydot_graph() rankdir and External styles. These are explained in the docstrings. *Added documentation to the PNG drawing functions *Added tests to catch trivial connectivity_summary() bugs --- apputils/model_summaries.py | 134 ++++++++++++++++++++++++------------ tests/test_summarygraph.py | 48 +++++++++++-- 2 files changed, 132 insertions(+), 50 deletions(-) diff --git a/apputils/model_summaries.py b/apputils/model_summaries.py index 953ab8f..3cd46ca 100755 --- a/apputils/model_summaries.py +++ b/apputils/model_summaries.py @@ -178,7 +178,7 @@ class SummaryGraph(object): s = str(n.type()) s = s[s.find('(')+1: s.find(')')] tensor['shape'] = tuple(map(lambda x: int(x), s.split(','))) - except: + except ValueError: # Size not specified in type tensor['shape'] = 0, return tensor @@ -198,14 +198,14 @@ class SummaryGraph(object): op['attrs']['MACs'] = 0 if op['type'] == 'Conv': conv_out = op['outputs'][0] - conv_in = op['inputs'][0] + conv_in = op['inputs'][0] conv_w = op['attrs']['kernel_shape'] ofm_vol = self.param_volume(conv_out) # MACs = volume(OFM) * (#IFM * K^2) op['attrs']['MACs'] = ofm_vol * SummaryGraph.volume(conv_w) * self.params[conv_in]['shape'][1] elif op['type'] == 'Gemm': - conv_out = op['outputs'][0] - conv_in = op['inputs'][0] + conv_out = op['outputs'][0] + conv_in = op['inputs'][0] n_ifm = self.param_shape(conv_in)[1] n_ofm = self.param_shape(conv_out)[1] # MACs = #IFM * #OFM @@ -424,7 +424,7 @@ def connectivity_summary(sgraph): """ df = pd.DataFrame(columns=['Name', 'Type', 'Inputs', 'Outputs']) pd.set_option('precision', 5) - for i, op in enumerate(sgraph.ops): + for i, op in enumerate(sgraph.ops.values()): df.loc[i] = [op['name'], op['type'], op['inputs'], op['outputs']] return df @@ -443,7 +443,7 @@ def connectivity_summary_verbose(sgraph): df = pd.DataFrame(columns=['Name', 'Type', 'Inputs', 'Outputs']) pd.set_option('precision', 5) - for i, op in enumerate(sgraph.ops): + for i, op in enumerate(sgraph.ops.values()): outputs = [] for blob in op['outputs']: if blob in sgraph.params: @@ -467,16 +467,21 @@ def connectivity_tbl_summary(sgraph, verbose=False): return tabulate(df, headers='keys', tablefmt='psql') -def create_pydot_graph(op_nodes, data_nodes, param_nodes, edges): - pydot_graph = pydot.Dot('Net', graph_type='digraph', rankdir='TB') +def create_pydot_graph(op_nodes, data_nodes, param_nodes, edges, rankdir='TB', styles=None): + """Low-level API to create a PyDot graph (dot formatted). + """ + pydot_graph = pydot.Dot('Net', graph_type='digraph', rankdir=rankdir) - node_style = {'shape': 'record', - 'fillcolor': '#6495ED', - 'style': 'rounded, filled'} + op_node_style = {'shape': 'record', + 'fillcolor': '#6495ED', + 'style': 'rounded, filled'} for op_node in op_nodes: - pydot_graph.add_node(pydot.Node(op_node[0], **node_style, - label="\n".join(op_node))) + style = op_node_style + # Check if we should override the style of this node. + if styles is not None and op_node[0] in styles: + style = styles[op_node[0]] + pydot_graph.add_node(pydot.Node(op_node[0], **style, label="\n".join(op_node))) for data_node in data_nodes: pydot_graph.add_node(pydot.Node(data_node[0], label="\n".join(data_node))) @@ -495,37 +500,20 @@ def create_pydot_graph(op_nodes, data_nodes, param_nodes, edges): return pydot_graph -def draw_model_to_file(sgraph, png_fname, display_param_nodes=False): - """Create a PNG file, containing a graphiz-dot graph of the netowrk represented - by SummaryGraph 'sgraph' - """ - png = create_png(sgraph, display_param_nodes=display_param_nodes) - with open(png_fname, 'wb') as fid: - fid.write(png) - - -def draw_img_classifier_to_file(model, png_fname, dataset, display_param_nodes=False): - try: - if dataset == 'imagenet': - dummy_input = Variable(torch.randn(1, 3, 224, 224), requires_grad=False) - elif dataset == 'cifar10': - dummy_input = Variable(torch.randn(1, 3, 32, 32)) - else: - print("Unsupported dataset (%s) - aborting draw operation" % dataset) - return - - g = SummaryGraph(model, dummy_input) - draw_model_to_file(g, png_fname, display_param_nodes) - print("Network PNG image generation completed") - except FileNotFoundError: - print("An error has occured while generating the network PNG image.") - print("Please check that you have graphviz installed.") - print("\t$ sudo apt-get install graphviz") - - -def create_png(sgraph, display_param_nodes=False): +def create_png(sgraph, display_param_nodes=False, rankdir='TB', styles=None): """Create a PNG object containing a graphiz-dot graph of the network, - as represented by SummaryGraph 'sgraph' + as represented by SummaryGraph 'sgraph'. + + Args: + sgraph (SummaryGraph): the SummaryGraph instance to draw. + display_param_nodes (boolean): if True, draw the parameter nodes + rankdir: diagram direction. 'TB'/'BT' is Top-to-Bottom/Bottom-to-Top + 'LR'/'R/L' is Left-to-Rt/Rt-to-Left + styles: a dictionary of styles. Key is module name. Value is + a legal pydot style dictionary. For example: + styles['conv1'] = {'shape': 'oval', + 'fillcolor': 'gray', + 'style': 'rounded, filled'} """ op_nodes = [op['name'] for op in sgraph.ops.values()] @@ -546,11 +534,69 @@ def create_png(sgraph, display_param_nodes=False): param_nodes = None op_nodes = [(op['name'], op['type']) for op in sgraph.ops.values()] - pydot_graph = create_pydot_graph(op_nodes, data_nodes, param_nodes, edges) + pydot_graph = create_pydot_graph(op_nodes, data_nodes, param_nodes, edges, rankdir, styles) png = pydot_graph.create_png() return png +def draw_model_to_file(sgraph, png_fname, display_param_nodes=False, rankdir='TB', styles=None): + """Create a PNG file, containing a graphiz-dot graph of the netowrk represented + by SummaryGraph 'sgraph' + + Args: + sgraph (SummaryGraph): the SummaryGraph instance to draw. + png_fname (string): PNG file name + display_param_nodes (boolean): if True, draw the parameter nodes + rankdir: diagram direction. 'TB'/'BT' is Top-to-Bottom/Bottom-to-Top + 'LR'/'R/L' is Left-to-Rt/Rt-to-Left + styles: a dictionary of styles. Key is module name. Value is + a legal pydot style dictionary. For example: + styles['conv1'] = {'shape': 'oval', + 'fillcolor': 'gray', + 'style': 'rounded, filled'} + """ + png = create_png(sgraph, display_param_nodes=display_param_nodes) + with open(png_fname, 'wb') as fid: + fid.write(png) + + +def draw_img_classifier_to_file(model, png_fname, dataset, display_param_nodes=False, + rankdir='TB', styles=None): + """Draw a PyTorch image classifier to a PNG file. This a helper function that + simplifies the interface of draw_model_to_file(). + + Args: + model: PyTorch model instance + png_fname (string): PNG file name + dataset (string): one of 'imagenet' or 'cifar10'. This is required in order to + create a dummy input of the correct shape. + display_param_nodes (boolean): if True, draw the parameter nodes + rankdir: diagram direction. 'TB'/'BT' is Top-to-Bottom/Bottom-to-Top + 'LR'/'R/L' is Left-to-Rt/Rt-to-Left + styles: a dictionary of styles. Key is module name. Value is + a legal pydot style dictionary. For example: + styles['conv1'] = {'shape': 'oval', + 'fillcolor': 'gray', + 'style': 'rounded, filled'} + """ + try: + if dataset == 'imagenet': + dummy_input = Variable(torch.randn(1, 3, 224, 224), requires_grad=False) + elif dataset == 'cifar10': + dummy_input = Variable(torch.randn(1, 3, 32, 32)) + else: + print("Unsupported dataset (%s) - aborting draw operation" % dataset) + return + + g = SummaryGraph(model, dummy_input) + draw_model_to_file(g, png_fname, display_param_nodes, rankdir, styles) + print("Network PNG image generation completed") + except FileNotFoundError: + print("An error has occured while generating the network PNG image.") + print("Please check that you have graphviz installed.") + print("\t$ sudo apt-get install graphviz") + + def data_node_has_parent(g, id): for edge in g.edges: if edge.dst == id: return True diff --git a/tests/test_summarygraph.py b/tests/test_summarygraph.py index 5bcb82d..f36f565 100755 --- a/tests/test_summarygraph.py +++ b/tests/test_summarygraph.py @@ -22,10 +22,8 @@ module_path = os.path.abspath(os.path.join('..')) if module_path not in sys.path: sys.path.append(module_path) import distiller - -import pytest from models import ALL_MODEL_NAMES, create_model -from apputils import SummaryGraph, onnx_name_2_pytorch_name +from apputils import * from distiller import normalize_module_name, denormalize_module_name # Logging configuration @@ -66,13 +64,13 @@ def test_connectivity(): edges = g.edges assert edges[0].src == '0' and edges[0].dst == 'conv1' - #logging.debug(g.ops[1]['name']) + # Test two sequential calls to predecessors (this was a bug once) preds = g.predecessors(g.find_op('bn1'), 1) preds = g.predecessors(g.find_op('bn1'), 1) assert preds == ['108', '2', '3', '4', '5'] # Test successors - succs = g.successors(g.find_op('bn1'), 2)#, logging) + succs = g.successors(g.find_op('bn1'), 2) assert succs == ['relu'] op = g.find_op('layer1.0') @@ -89,7 +87,6 @@ def test_connectivity(): assert preds == [] preds = g.predecessors(g.find_op('bn1'), 3) assert preds == ['0', '1'] - #logging.debug(preds) def test_layer_search(): @@ -133,6 +130,30 @@ def test_vgg(): succs = g.successors_f('features.34', 'Conv') +def test_simplenet(): + g = create_graph('cifar10', 'simplenet_cifar') + assert g is not None + preds = g.predecessors_f(normalize_module_name('module.conv1'), 'Conv') + logging.debug("[simplenet_cifar]: preds of module.conv1 = {}".format(preds)) + assert len(preds) == 0 + + preds = g.predecessors_f(normalize_module_name('module.conv2'), 'Conv') + logging.debug("[simplenet_cifar]: preds of module.conv2 = {}".format(preds)) + assert len(preds) == 1 + + +def test_simplenet(): + g = create_graph('cifar10', 'simplenet_cifar') + assert g is not None + preds = g.predecessors_f(normalize_module_name('module.conv1'), 'Conv') + logging.debug("[simplenet_cifar]: preds of module.conv1 = {}".format(preds)) + assert len(preds) == 0 + + preds = g.predecessors_f(normalize_module_name('module.conv2'), 'Conv') + logging.debug("[simplenet_cifar]: preds of module.conv2 = {}".format(preds)) + assert len(preds) == 1 + + def name_test(dataset, arch): model = create_model(False, dataset, arch, parallel=False) modelp = create_model(False, dataset, arch, parallel=True) @@ -163,3 +184,18 @@ def test_onnx_name_2_pytorch_name(): assert "features.34" == onnx_name_2_pytorch_name('VGG/[features]/Sequential/Conv2d[34]', 'Conv') assert "Relu3" == onnx_name_2_pytorch_name('NameWithNoModule.3', 'Relu') #assert "features.module.34" == onnx_name_2_pytorch_name('VGG/DataParallel[features]/Sequential/Conv2d[34]', 'Conv') + + +def test_connectivity_summary(): + g = create_graph('cifar10', 'resnet20_cifar') + assert g is not None + + summary = connectivity_summary(g) + assert len(summary) == 73 + + verbose_summary = connectivity_summary_verbose(g) + assert len(verbose_summary ) == 73 + + +if __name__ == '__main__': + test_connectivity_summary() -- GitLab