Skip to content
Snippets Groups Projects
Commit af4bf3dc authored by Neta Zmora's avatar Neta Zmora
Browse files

Bug fix in connectivity_summary; extend the API of create_png()

*connectivity_summary() does not use SummaryGraph correctly:
 Recently we changed the internal representation of SummaryGraph.ops, but
 connectivity_summary() and connectivity_summary_verbose() were not updated.
 Fixed that.

*Extend the API of create_png():
 Add to the signature of create_png() and create_pydot_graph() rankdir and
 External styles.  These are explained in the docstrings.

*Added documentation to the PNG drawing functions

*Added tests to catch trivial connectivity_summary() bugs
parent c14efaa9
No related branches found
No related tags found
No related merge requests found
...@@ -178,7 +178,7 @@ class SummaryGraph(object): ...@@ -178,7 +178,7 @@ class SummaryGraph(object):
s = str(n.type()) s = str(n.type())
s = s[s.find('(')+1: s.find(')')] s = s[s.find('(')+1: s.find(')')]
tensor['shape'] = tuple(map(lambda x: int(x), s.split(','))) tensor['shape'] = tuple(map(lambda x: int(x), s.split(',')))
except: except ValueError:
# Size not specified in type # Size not specified in type
tensor['shape'] = 0, tensor['shape'] = 0,
return tensor return tensor
...@@ -198,14 +198,14 @@ class SummaryGraph(object): ...@@ -198,14 +198,14 @@ class SummaryGraph(object):
op['attrs']['MACs'] = 0 op['attrs']['MACs'] = 0
if op['type'] == 'Conv': if op['type'] == 'Conv':
conv_out = op['outputs'][0] conv_out = op['outputs'][0]
conv_in = op['inputs'][0] conv_in = op['inputs'][0]
conv_w = op['attrs']['kernel_shape'] conv_w = op['attrs']['kernel_shape']
ofm_vol = self.param_volume(conv_out) ofm_vol = self.param_volume(conv_out)
# MACs = volume(OFM) * (#IFM * K^2) # MACs = volume(OFM) * (#IFM * K^2)
op['attrs']['MACs'] = ofm_vol * SummaryGraph.volume(conv_w) * self.params[conv_in]['shape'][1] op['attrs']['MACs'] = ofm_vol * SummaryGraph.volume(conv_w) * self.params[conv_in]['shape'][1]
elif op['type'] == 'Gemm': elif op['type'] == 'Gemm':
conv_out = op['outputs'][0] conv_out = op['outputs'][0]
conv_in = op['inputs'][0] conv_in = op['inputs'][0]
n_ifm = self.param_shape(conv_in)[1] n_ifm = self.param_shape(conv_in)[1]
n_ofm = self.param_shape(conv_out)[1] n_ofm = self.param_shape(conv_out)[1]
# MACs = #IFM * #OFM # MACs = #IFM * #OFM
...@@ -424,7 +424,7 @@ def connectivity_summary(sgraph): ...@@ -424,7 +424,7 @@ def connectivity_summary(sgraph):
""" """
df = pd.DataFrame(columns=['Name', 'Type', 'Inputs', 'Outputs']) df = pd.DataFrame(columns=['Name', 'Type', 'Inputs', 'Outputs'])
pd.set_option('precision', 5) pd.set_option('precision', 5)
for i, op in enumerate(sgraph.ops): for i, op in enumerate(sgraph.ops.values()):
df.loc[i] = [op['name'], op['type'], op['inputs'], op['outputs']] df.loc[i] = [op['name'], op['type'], op['inputs'], op['outputs']]
return df return df
...@@ -443,7 +443,7 @@ def connectivity_summary_verbose(sgraph): ...@@ -443,7 +443,7 @@ def connectivity_summary_verbose(sgraph):
df = pd.DataFrame(columns=['Name', 'Type', 'Inputs', 'Outputs']) df = pd.DataFrame(columns=['Name', 'Type', 'Inputs', 'Outputs'])
pd.set_option('precision', 5) pd.set_option('precision', 5)
for i, op in enumerate(sgraph.ops): for i, op in enumerate(sgraph.ops.values()):
outputs = [] outputs = []
for blob in op['outputs']: for blob in op['outputs']:
if blob in sgraph.params: if blob in sgraph.params:
...@@ -467,16 +467,21 @@ def connectivity_tbl_summary(sgraph, verbose=False): ...@@ -467,16 +467,21 @@ def connectivity_tbl_summary(sgraph, verbose=False):
return tabulate(df, headers='keys', tablefmt='psql') return tabulate(df, headers='keys', tablefmt='psql')
def create_pydot_graph(op_nodes, data_nodes, param_nodes, edges): def create_pydot_graph(op_nodes, data_nodes, param_nodes, edges, rankdir='TB', styles=None):
pydot_graph = pydot.Dot('Net', graph_type='digraph', rankdir='TB') """Low-level API to create a PyDot graph (dot formatted).
"""
pydot_graph = pydot.Dot('Net', graph_type='digraph', rankdir=rankdir)
node_style = {'shape': 'record', op_node_style = {'shape': 'record',
'fillcolor': '#6495ED', 'fillcolor': '#6495ED',
'style': 'rounded, filled'} 'style': 'rounded, filled'}
for op_node in op_nodes: for op_node in op_nodes:
pydot_graph.add_node(pydot.Node(op_node[0], **node_style, style = op_node_style
label="\n".join(op_node))) # Check if we should override the style of this node.
if styles is not None and op_node[0] in styles:
style = styles[op_node[0]]
pydot_graph.add_node(pydot.Node(op_node[0], **style, label="\n".join(op_node)))
for data_node in data_nodes: for data_node in data_nodes:
pydot_graph.add_node(pydot.Node(data_node[0], label="\n".join(data_node))) pydot_graph.add_node(pydot.Node(data_node[0], label="\n".join(data_node)))
...@@ -495,37 +500,20 @@ def create_pydot_graph(op_nodes, data_nodes, param_nodes, edges): ...@@ -495,37 +500,20 @@ def create_pydot_graph(op_nodes, data_nodes, param_nodes, edges):
return pydot_graph return pydot_graph
def draw_model_to_file(sgraph, png_fname, display_param_nodes=False): def create_png(sgraph, display_param_nodes=False, rankdir='TB', styles=None):
"""Create a PNG file, containing a graphiz-dot graph of the netowrk represented
by SummaryGraph 'sgraph'
"""
png = create_png(sgraph, display_param_nodes=display_param_nodes)
with open(png_fname, 'wb') as fid:
fid.write(png)
def draw_img_classifier_to_file(model, png_fname, dataset, display_param_nodes=False):
try:
if dataset == 'imagenet':
dummy_input = Variable(torch.randn(1, 3, 224, 224), requires_grad=False)
elif dataset == 'cifar10':
dummy_input = Variable(torch.randn(1, 3, 32, 32))
else:
print("Unsupported dataset (%s) - aborting draw operation" % dataset)
return
g = SummaryGraph(model, dummy_input)
draw_model_to_file(g, png_fname, display_param_nodes)
print("Network PNG image generation completed")
except FileNotFoundError:
print("An error has occured while generating the network PNG image.")
print("Please check that you have graphviz installed.")
print("\t$ sudo apt-get install graphviz")
def create_png(sgraph, display_param_nodes=False):
"""Create a PNG object containing a graphiz-dot graph of the network, """Create a PNG object containing a graphiz-dot graph of the network,
as represented by SummaryGraph 'sgraph' as represented by SummaryGraph 'sgraph'.
Args:
sgraph (SummaryGraph): the SummaryGraph instance to draw.
display_param_nodes (boolean): if True, draw the parameter nodes
rankdir: diagram direction. 'TB'/'BT' is Top-to-Bottom/Bottom-to-Top
'LR'/'R/L' is Left-to-Rt/Rt-to-Left
styles: a dictionary of styles. Key is module name. Value is
a legal pydot style dictionary. For example:
styles['conv1'] = {'shape': 'oval',
'fillcolor': 'gray',
'style': 'rounded, filled'}
""" """
op_nodes = [op['name'] for op in sgraph.ops.values()] op_nodes = [op['name'] for op in sgraph.ops.values()]
...@@ -546,11 +534,69 @@ def create_png(sgraph, display_param_nodes=False): ...@@ -546,11 +534,69 @@ def create_png(sgraph, display_param_nodes=False):
param_nodes = None param_nodes = None
op_nodes = [(op['name'], op['type']) for op in sgraph.ops.values()] op_nodes = [(op['name'], op['type']) for op in sgraph.ops.values()]
pydot_graph = create_pydot_graph(op_nodes, data_nodes, param_nodes, edges) pydot_graph = create_pydot_graph(op_nodes, data_nodes, param_nodes, edges, rankdir, styles)
png = pydot_graph.create_png() png = pydot_graph.create_png()
return png return png
def draw_model_to_file(sgraph, png_fname, display_param_nodes=False, rankdir='TB', styles=None):
"""Create a PNG file, containing a graphiz-dot graph of the netowrk represented
by SummaryGraph 'sgraph'
Args:
sgraph (SummaryGraph): the SummaryGraph instance to draw.
png_fname (string): PNG file name
display_param_nodes (boolean): if True, draw the parameter nodes
rankdir: diagram direction. 'TB'/'BT' is Top-to-Bottom/Bottom-to-Top
'LR'/'R/L' is Left-to-Rt/Rt-to-Left
styles: a dictionary of styles. Key is module name. Value is
a legal pydot style dictionary. For example:
styles['conv1'] = {'shape': 'oval',
'fillcolor': 'gray',
'style': 'rounded, filled'}
"""
png = create_png(sgraph, display_param_nodes=display_param_nodes)
with open(png_fname, 'wb') as fid:
fid.write(png)
def draw_img_classifier_to_file(model, png_fname, dataset, display_param_nodes=False,
rankdir='TB', styles=None):
"""Draw a PyTorch image classifier to a PNG file. This a helper function that
simplifies the interface of draw_model_to_file().
Args:
model: PyTorch model instance
png_fname (string): PNG file name
dataset (string): one of 'imagenet' or 'cifar10'. This is required in order to
create a dummy input of the correct shape.
display_param_nodes (boolean): if True, draw the parameter nodes
rankdir: diagram direction. 'TB'/'BT' is Top-to-Bottom/Bottom-to-Top
'LR'/'R/L' is Left-to-Rt/Rt-to-Left
styles: a dictionary of styles. Key is module name. Value is
a legal pydot style dictionary. For example:
styles['conv1'] = {'shape': 'oval',
'fillcolor': 'gray',
'style': 'rounded, filled'}
"""
try:
if dataset == 'imagenet':
dummy_input = Variable(torch.randn(1, 3, 224, 224), requires_grad=False)
elif dataset == 'cifar10':
dummy_input = Variable(torch.randn(1, 3, 32, 32))
else:
print("Unsupported dataset (%s) - aborting draw operation" % dataset)
return
g = SummaryGraph(model, dummy_input)
draw_model_to_file(g, png_fname, display_param_nodes, rankdir, styles)
print("Network PNG image generation completed")
except FileNotFoundError:
print("An error has occured while generating the network PNG image.")
print("Please check that you have graphviz installed.")
print("\t$ sudo apt-get install graphviz")
def data_node_has_parent(g, id): def data_node_has_parent(g, id):
for edge in g.edges: for edge in g.edges:
if edge.dst == id: return True if edge.dst == id: return True
......
...@@ -22,10 +22,8 @@ module_path = os.path.abspath(os.path.join('..')) ...@@ -22,10 +22,8 @@ module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path: if module_path not in sys.path:
sys.path.append(module_path) sys.path.append(module_path)
import distiller import distiller
import pytest
from models import ALL_MODEL_NAMES, create_model from models import ALL_MODEL_NAMES, create_model
from apputils import SummaryGraph, onnx_name_2_pytorch_name from apputils import *
from distiller import normalize_module_name, denormalize_module_name from distiller import normalize_module_name, denormalize_module_name
# Logging configuration # Logging configuration
...@@ -66,13 +64,13 @@ def test_connectivity(): ...@@ -66,13 +64,13 @@ def test_connectivity():
edges = g.edges edges = g.edges
assert edges[0].src == '0' and edges[0].dst == 'conv1' assert edges[0].src == '0' and edges[0].dst == 'conv1'
#logging.debug(g.ops[1]['name'])
# Test two sequential calls to predecessors (this was a bug once) # Test two sequential calls to predecessors (this was a bug once)
preds = g.predecessors(g.find_op('bn1'), 1) preds = g.predecessors(g.find_op('bn1'), 1)
preds = g.predecessors(g.find_op('bn1'), 1) preds = g.predecessors(g.find_op('bn1'), 1)
assert preds == ['108', '2', '3', '4', '5'] assert preds == ['108', '2', '3', '4', '5']
# Test successors # Test successors
succs = g.successors(g.find_op('bn1'), 2)#, logging) succs = g.successors(g.find_op('bn1'), 2)
assert succs == ['relu'] assert succs == ['relu']
op = g.find_op('layer1.0') op = g.find_op('layer1.0')
...@@ -89,7 +87,6 @@ def test_connectivity(): ...@@ -89,7 +87,6 @@ def test_connectivity():
assert preds == [] assert preds == []
preds = g.predecessors(g.find_op('bn1'), 3) preds = g.predecessors(g.find_op('bn1'), 3)
assert preds == ['0', '1'] assert preds == ['0', '1']
#logging.debug(preds)
def test_layer_search(): def test_layer_search():
...@@ -133,6 +130,30 @@ def test_vgg(): ...@@ -133,6 +130,30 @@ def test_vgg():
succs = g.successors_f('features.34', 'Conv') succs = g.successors_f('features.34', 'Conv')
def test_simplenet():
g = create_graph('cifar10', 'simplenet_cifar')
assert g is not None
preds = g.predecessors_f(normalize_module_name('module.conv1'), 'Conv')
logging.debug("[simplenet_cifar]: preds of module.conv1 = {}".format(preds))
assert len(preds) == 0
preds = g.predecessors_f(normalize_module_name('module.conv2'), 'Conv')
logging.debug("[simplenet_cifar]: preds of module.conv2 = {}".format(preds))
assert len(preds) == 1
def test_simplenet():
g = create_graph('cifar10', 'simplenet_cifar')
assert g is not None
preds = g.predecessors_f(normalize_module_name('module.conv1'), 'Conv')
logging.debug("[simplenet_cifar]: preds of module.conv1 = {}".format(preds))
assert len(preds) == 0
preds = g.predecessors_f(normalize_module_name('module.conv2'), 'Conv')
logging.debug("[simplenet_cifar]: preds of module.conv2 = {}".format(preds))
assert len(preds) == 1
def name_test(dataset, arch): def name_test(dataset, arch):
model = create_model(False, dataset, arch, parallel=False) model = create_model(False, dataset, arch, parallel=False)
modelp = create_model(False, dataset, arch, parallel=True) modelp = create_model(False, dataset, arch, parallel=True)
...@@ -163,3 +184,18 @@ def test_onnx_name_2_pytorch_name(): ...@@ -163,3 +184,18 @@ def test_onnx_name_2_pytorch_name():
assert "features.34" == onnx_name_2_pytorch_name('VGG/[features]/Sequential/Conv2d[34]', 'Conv') assert "features.34" == onnx_name_2_pytorch_name('VGG/[features]/Sequential/Conv2d[34]', 'Conv')
assert "Relu3" == onnx_name_2_pytorch_name('NameWithNoModule.3', 'Relu') assert "Relu3" == onnx_name_2_pytorch_name('NameWithNoModule.3', 'Relu')
#assert "features.module.34" == onnx_name_2_pytorch_name('VGG/DataParallel[features]/Sequential/Conv2d[34]', 'Conv') #assert "features.module.34" == onnx_name_2_pytorch_name('VGG/DataParallel[features]/Sequential/Conv2d[34]', 'Conv')
def test_connectivity_summary():
g = create_graph('cifar10', 'resnet20_cifar')
assert g is not None
summary = connectivity_summary(g)
assert len(summary) == 73
verbose_summary = connectivity_summary_verbose(g)
assert len(verbose_summary ) == 73
if __name__ == '__main__':
test_connectivity_summary()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment