diff --git a/apputils/model_summaries.py b/apputils/model_summaries.py index b956abc0914224b60c1fb8a96031e10225832c37..953ab8fd35fea0fec53070762cc992567325c00e 100755 --- a/apputils/model_summaries.py +++ b/apputils/model_summaries.py @@ -30,8 +30,10 @@ from torch.autograd import Variable import torch.jit as jit import pandas as pd from tabulate import tabulate +import pydot + -def onnx_name_2_pytorch_name(name): +def onnx_name_2_pytorch_name(name, op_type): # Convert a layer's name from an ONNX name, to a PyTorch name # For example: # ResNet/Sequential[layer3]/BasicBlock[0]/ReLU[relu].1 ==> layer3.0.relu.1 @@ -45,12 +47,16 @@ def onnx_name_2_pytorch_name(name): name_parts = re.findall('\[.*?\]', name) name_parts = [part[1:-1] for part in name_parts] - new_name = '.'.join(name_parts) + instance - if new_name == '': - new_name = name + # If name doesn't have the pattern above, it probably means the op was called via + # some functional API and not via a module. Couple of examples: + # x = x.view(...) + # x = F.relu(x) + # In this case, to have a meaningful name, we use the op type + new_name = ('.'.join(name_parts) if len(name_parts) > 0 else op_type) + instance return new_name + class SummaryGraph(object): """We use Pytorch's JIT tracer to run a forward pass and generate a trace graph, which is an internal representation of the model. We then use ONNX to "clean" this @@ -102,13 +108,18 @@ class SummaryGraph(object): for node in graph.nodes(): new_op = self.__create_op(node) - # in-place operators create very confusing graphs (Resnet, for example), - # so we "unroll" them - same = [op for op in self.ops.values() if op['orig-name'] == new_op['orig-name']] + # Operators with the same name create very confusing graphs (Resnet, for example), + # so we "unroll" them. + # Sometimes operations of different types have the same name, so we differentiate + # using both name and type + # (this happens, for example, when an operator is called via some functional API and + # not via a module) + same = [op for op in self.ops.values() if + op['orig-name'] + op['type'] == new_op['orig-name'] + new_op['type']] if len(same) > 0: new_op['name'] += "." + str(len(same)) - new_op['name'] = onnx_name_2_pytorch_name(new_op['name']) + new_op['name'] = onnx_name_2_pytorch_name(new_op['name'], new_op['type']) assert len(new_op['name']) > 0 self.ops[new_op['name']] = new_op @@ -137,21 +148,21 @@ class SummaryGraph(object): op['params'] = [] return op - def __add_input(self, op, n): param = self.__add_param(n) - if param is None: return + if param is None: + return if param['id'] not in op['inputs']: op['inputs'].append(param['id']) def __add_output(self, op, n): param = self.__add_param(n) - if param is None: return + if param is None: + return if param['id'] not in op['outputs']: op['outputs'].append(param['id']) def __add_param(self, n): - param = {} if n.uniqueName() not in self.params: param = self.__tensor_desc(n) self.params[n.uniqueName()] = param @@ -168,7 +179,8 @@ class SummaryGraph(object): s = s[s.find('(')+1: s.find(')')] tensor['shape'] = tuple(map(lambda x: int(x), s.split(','))) except: - return None + # Size not specified in type + tensor['shape'] = 0, return tensor def param_shape(self, param_id): @@ -261,7 +273,6 @@ class SummaryGraph(object): ret += self.predecessors(predecessor, depth-1, done_list) #, logging) return ret - def predecessors_f(self, node_name, predecessors_types, done_list=None, logging=None): """Returns a list of <op>'s predecessors, if they match the <predecessors_types> criteria. """ @@ -299,7 +310,6 @@ class SummaryGraph(object): ret += self.predecessors_f(predecessor, predecessors_types, done_list, logging) return ret - def successors(self, node, depth, done_list=None): """Returns a list of <op>'s successors""" @@ -370,6 +380,7 @@ class SummaryGraph(object): ret += self.successors_f(successor, successors_types, done_list, logging) return ret + def attributes_summary(sgraph, ignore_attrs): """Generate a summary of a graph's attributes. @@ -399,10 +410,12 @@ def attributes_summary(sgraph, ignore_attrs): df.loc[i] = [op['name'], op['type'], pretty_attrs(op['attrs'], ignore_attrs)] return df + def attributes_summary_tbl(sgraph, ignore_attrs): df = attributes_summary(sgraph, ignore_attrs) return tabulate(df, headers='keys', tablefmt='psql') + def connectivity_summary(sgraph): """Generate a summary of each node's connectivity. @@ -415,6 +428,7 @@ def connectivity_summary(sgraph): df.loc[i] = [op['name'], op['type'], op['inputs'], op['outputs']] return df + def connectivity_summary_verbose(sgraph): """Generate a summary of each node's connectivity, with details about the parameters. @@ -444,6 +458,7 @@ def connectivity_summary_verbose(sgraph): return df + def connectivity_tbl_summary(sgraph, verbose=False): if verbose: df = connectivity_summary_verbose(sgraph) @@ -452,9 +467,6 @@ def connectivity_tbl_summary(sgraph, verbose=False): return tabulate(df, headers='keys', tablefmt='psql') - -import pydot - def create_pydot_graph(op_nodes, data_nodes, param_nodes, edges): pydot_graph = pydot.Dot('Net', graph_type='digraph', rankdir='TB') @@ -467,32 +479,32 @@ def create_pydot_graph(op_nodes, data_nodes, param_nodes, edges): label="\n".join(op_node))) for data_node in data_nodes: - #pydot_graph.add_node(pydot.Node(data_node)) pydot_graph.add_node(pydot.Node(data_node[0], label="\n".join(data_node))) - node_style = {'shape': 'oval', 'fillcolor': 'gray', 'style': 'rounded, filled'} if param_nodes is not None: - for data_node in param_nodes: - pydot_graph.add_node(pydot.Node(data_node, **node_style)) + for param_node in param_nodes: + pydot_graph.add_node(pydot.Node(param_node[0], **node_style, label="\n".join(param_node))) for edge in edges: pydot_graph.add_edge(pydot.Edge(edge[0], edge[1])) return pydot_graph -def draw_model_to_file(sgraph, png_fname): + +def draw_model_to_file(sgraph, png_fname, display_param_nodes=False): """Create a PNG file, containing a graphiz-dot graph of the netowrk represented by SummaryGraph 'sgraph' """ - png = create_png(sgraph) + png = create_png(sgraph, display_param_nodes=display_param_nodes) with open(png_fname, 'wb') as fid: fid.write(png) -def draw_img_classifier_to_file(model, png_fname, dataset): + +def draw_img_classifier_to_file(model, png_fname, dataset, display_param_nodes=False): try: if dataset == 'imagenet': dummy_input = Variable(torch.randn(1, 3, 224, 224), requires_grad=False) @@ -503,41 +515,42 @@ def draw_img_classifier_to_file(model, png_fname, dataset): return g = SummaryGraph(model, dummy_input) - draw_model_to_file(g, png_fname) + draw_model_to_file(g, png_fname, display_param_nodes) print("Network PNG image generation completed") - except TypeError as e: - print("An error has occured while generating the network PNG image.") - print("This feature is not supported on official PyTorch releases.") - print("Please check that you are using a valid PyTorch version.") - print("You are using pytorch version %s" %torch.__version__) except FileNotFoundError: print("An error has occured while generating the network PNG image.") print("Please check that you have graphviz installed.") print("\t$ sudo apt-get install graphviz") + def create_png(sgraph, display_param_nodes=False): """Create a PNG object containing a graphiz-dot graph of the network, as represented by SummaryGraph 'sgraph' """ - op_nodes = [op['name'] for op in sgraph.ops] - data_nodes = [(id,param.shape) for (id, param) in sgraph.params.items() if data_node_has_parent(sgraph, id)] - param_nodes = [id for id in sgraph.params.keys() if not data_node_has_parent(sgraph, id)] + op_nodes = [op['name'] for op in sgraph.ops.values()] + data_nodes = [] + param_nodes = [] + for id, param in sgraph.params.items(): + n_data = (id, str(param['shape'])) + if data_node_has_parent(sgraph, id): + data_nodes.append(n_data) + else: + param_nodes.append(n_data) edges = sgraph.edges if not display_param_nodes: # Use only the edges that don't have a parameter source - edges = [edge for edge in sgraph.edges if edge.src in (data_nodes+op_nodes)] + non_param_ids = op_nodes + [dn[0] for dn in data_nodes] + edges = [edge for edge in sgraph.edges if edge.src in non_param_ids] param_nodes = None - if False: - data_nodes = None - - op_nodes = [(op['name'], op['type']) for op in sgraph.ops] + op_nodes = [(op['name'], op['type']) for op in sgraph.ops.values()] pydot_graph = create_pydot_graph(op_nodes, data_nodes, param_nodes, edges) png = pydot_graph.create_png() return png + def data_node_has_parent(g, id): for edge in g.edges: if edge.dst == id: return True diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py index 3a2d5500fb52fa243451b0b13ae5d99efb946ece..338ec78438c0bdbbc757dff475c34e744445983c 100755 --- a/examples/classifier_compression/compress_classifier.py +++ b/examples/classifier_compression/compress_classifier.py @@ -111,7 +111,7 @@ parser.add_argument('--act-stats', dest='activation_stats', action='store_true', help='collect activation statistics (WARNING: this slows down training)') parser.add_argument('--param-hist', dest='log_params_histograms', action='store_true', default=False, help='log the paramter tensors histograms to file (WARNING: this can use significant disk space)') -SUMMARY_CHOICES = ['sparsity', 'compute', 'optimizer', 'model', 'modules', 'png'] +SUMMARY_CHOICES = ['sparsity', 'compute', 'optimizer', 'model', 'modules', 'png', 'png_w_params'] parser.add_argument('--summary', type=str, choices=SUMMARY_CHOICES, help='print a summary of the model, and exit - options: ' + ' | '.join(SUMMARY_CHOICES)) @@ -193,7 +193,8 @@ def main(): args.dataset = 'cifar10' if 'cifar' in args.arch else 'imagenet' # Create the model - is_parallel = args.summary != 'png' # For PNG summary, parallel graphs are illegible + png_summary = args.summary is not None and args.summary.startswith('png') + is_parallel = not png_summary # For PNG summary, parallel graphs are illegible model = create_model(args.pretrained, args.dataset, args.arch, parallel=is_parallel, device_ids=args.gpus) compression_scheduler = None @@ -223,8 +224,8 @@ def main(): # This sample application can be invoked to produce various summary reports. if args.summary: which_summary = args.summary - if which_summary == 'png': - apputils.draw_img_classifier_to_file(model, 'model.png', args.dataset) + if which_summary.startswith('png'): + apputils.draw_img_classifier_to_file(model, 'model.png', args.dataset, which_summary == 'png_w_params') else: distiller.model_summary(model, optimizer, which_summary, args.dataset) exit() diff --git a/tests/test_summarygraph.py b/tests/test_summarygraph.py index 09c867b6f088271f39ee44e470574efc7e23b296..31145dbb8dedf6d486bf65f681570b3d68995bfb 100755 --- a/tests/test_summarygraph.py +++ b/tests/test_summarygraph.py @@ -33,6 +33,7 @@ fh = logging.FileHandler('test.log') logger = logging.getLogger() logger.addHandler(fh) + def get_input(dataset): if dataset == 'imagenet': return torch.randn((1, 3, 224, 224), requires_grad=False) @@ -40,6 +41,7 @@ def get_input(dataset): return torch.randn((1, 3, 32, 32)) return None + def create_graph(dataset, arch): dummy_input = get_input(dataset) assert dummy_input is not None, "Unsupported dataset ({}) - aborting draw operation".format(dataset) @@ -48,10 +50,12 @@ def create_graph(dataset, arch): assert model is not None return SummaryGraph(model, dummy_input) + def test_graph(): g = create_graph('cifar10', 'resnet20_cifar') assert g is not None + def test_connectivity(): g = create_graph('cifar10', 'resnet20_cifar') assert g is not None @@ -86,6 +90,7 @@ def test_connectivity(): assert preds == ['0', '1'] #logging.debug(preds) + def test_layer_search(): g = create_graph('cifar10', 'resnet20_cifar') assert g is not None @@ -118,12 +123,14 @@ def test_layer_search(): preds = g.predecessors_f('layer1.1.conv1', 'Conv', [], logging) assert preds == ['layer1.0.conv2', 'conv1'] + def normalize_layer_name(layer_name): start = layer_name.find('module.') if start != -1: layer_name = layer_name[:start] + layer_name[start + len('module.'):] return layer_name + def test_vgg(): g = create_graph('imagenet', 'vgg19') assert g is not None @@ -131,12 +138,15 @@ def test_vgg(): logging.debug(succs) succs = g.successors_f('features.34', 'Conv') + def test_normalize_layer_name(): assert "features.0", normalize_layer_name("features.module.0") assert "features.0", normalize_layer_name("module.features.0") assert "features.0", normalize_layer_name("features.0.module") + def test_onnx_name_2_pytorch_name(): - assert "layer3.0.relu1" == onnx_name_2_pytorch_name("ResNet/Sequential[layer3]/BasicBlock[0]/ReLU[relu].1") - assert "features.34" == onnx_name_2_pytorch_name('VGG/[features]/Sequential/Conv2d[34]') - #assert "features.module.34" == onnx_name_2_pytorch_name('VGG/DataParallel[features]/Sequential/Conv2d[34]') + assert "layer3.0.relu1" == onnx_name_2_pytorch_name("ResNet/Sequential[layer3]/BasicBlock[0]/ReLU[relu].1", 'Relu') + assert "features.34" == onnx_name_2_pytorch_name('VGG/[features]/Sequential/Conv2d[34]', 'Conv') + assert "Relu3" == onnx_name_2_pytorch_name('NameWithNoModule.3', 'Relu') + #assert "features.module.34" == onnx_name_2_pytorch_name('VGG/DataParallel[features]/Sequential/Conv2d[34]', 'Conv')