Skip to content
Snippets Groups Projects
Commit e7d897a6 authored by Neta Zmora's avatar Neta Zmora
Browse files

model summary: Added tensor sizes annotation to generated model graph PNG

parent 9c701c1c
No related branches found
No related tags found
No related merge requests found
...@@ -489,7 +489,7 @@ def create_pydot_graph(op_nodes, data_nodes, param_nodes, edges, rankdir='TB', s ...@@ -489,7 +489,7 @@ def create_pydot_graph(op_nodes, data_nodes, param_nodes, edges, rankdir='TB', s
pydot_graph.add_node(pydot.Node(op_node[0], **style, label="\n".join(op_node))) pydot_graph.add_node(pydot.Node(op_node[0], **style, label="\n".join(op_node)))
for data_node in data_nodes: for data_node in data_nodes:
pydot_graph.add_node(pydot.Node(data_node[0], label="\n".join(data_node))) pydot_graph.add_node(pydot.Node(data_node[0], label="\n".join(data_node[1:])))
node_style = {'shape': 'oval', node_style = {'shape': 'oval',
'fillcolor': 'gray', 'fillcolor': 'gray',
...@@ -497,7 +497,7 @@ def create_pydot_graph(op_nodes, data_nodes, param_nodes, edges, rankdir='TB', s ...@@ -497,7 +497,7 @@ def create_pydot_graph(op_nodes, data_nodes, param_nodes, edges, rankdir='TB', s
if param_nodes is not None: if param_nodes is not None:
for param_node in param_nodes: for param_node in param_nodes:
pydot_graph.add_node(pydot.Node(param_node[0], **node_style, label="\n".join(param_node))) pydot_graph.add_node(pydot.Node(param_node[0], **node_style, label="\n".join(param_node[1:])))
for edge in edges: for edge in edges:
pydot_graph.add_edge(pydot.Edge(edge[0], edge[1])) pydot_graph.add_edge(pydot.Edge(edge[0], edge[1]))
...@@ -525,7 +525,7 @@ def create_png(sgraph, display_param_nodes=False, rankdir='TB', styles=None): ...@@ -525,7 +525,7 @@ def create_png(sgraph, display_param_nodes=False, rankdir='TB', styles=None):
data_nodes = [] data_nodes = []
param_nodes = [] param_nodes = []
for id, param in sgraph.params.items(): for id, param in sgraph.params.items():
n_data = (id, str(param['shape'])) n_data = (id, str(distiller.volume(param['shape'])), str(param['shape']))
if data_node_has_parent(sgraph, id): if data_node_has_parent(sgraph, id):
data_nodes.append(n_data) data_nodes.append(n_data)
else: else:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment