diff --git a/distiller/model_summaries.py b/distiller/model_summaries.py index 5a89685e22b7ff74fa3840fe1f7afab7ac5b07f2..cf9b124929917e76e2361e759bd71e8639e223be 100755 --- a/distiller/model_summaries.py +++ b/distiller/model_summaries.py @@ -51,7 +51,7 @@ def model_summary(model, what, dataset=None): distiller.log_weights_sparsity(model, -1, loggers=[pylogger, csvlogger]) elif what == 'compute': try: - dummy_input = dataset_dummy_input(dataset) + dummy_input = distiller.get_dummy_input(dataset, distiller.model_device(model)) except ValueError as e: print(e) return @@ -319,7 +319,7 @@ def connectivity_tbl_summary(sgraph, verbose=False): return tabulate(df, headers='keys', tablefmt='psql') -def create_pydot_graph(op_nodes, data_nodes, param_nodes, edges, rankdir='TB', styles=None): +def create_pydot_graph(op_nodes_desc, data_nodes, param_nodes, edges, rankdir='TB', styles=None): """Low-level API to create a PyDot graph (dot formatted). """ pydot_graph = pydot.Dot('Net', graph_type='digraph', rankdir=rankdir) @@ -328,7 +328,7 @@ def create_pydot_graph(op_nodes, data_nodes, param_nodes, edges, rankdir='TB', s 'fillcolor': '#6495ED', 'style': 'rounded, filled'} - for op_node in op_nodes: + for op_node in op_nodes_desc: style = op_node_style # Check if we should override the style of this node. if styles is not None and op_node[0] in styles: @@ -368,6 +368,12 @@ def create_png(sgraph, display_param_nodes=False, rankdir='TB', styles=None): 'style': 'rounded, filled'} """ + def annotate_op_node(op): + if op['type'] == 'Conv': + return ["sh={}".format(distiller.size2str(op['attrs']['kernel_shape'])), + "g={}".format(str(op['attrs']['group']))] + return '' + op_nodes = [op['name'] for op in sgraph.ops.values()] data_nodes = [] param_nodes = [] @@ -385,8 +391,8 @@ def create_png(sgraph, display_param_nodes=False, rankdir='TB', styles=None): edges = [edge for edge in sgraph.edges if edge.src in non_param_ids] param_nodes = None - op_nodes = [(op['name'], op['type']) for op in sgraph.ops.values()] - pydot_graph = create_pydot_graph(op_nodes, data_nodes, param_nodes, edges, rankdir, styles) + op_nodes_desc = [(op['name'], op['type'], *annotate_op_node(op)) for op in sgraph.ops.values()] + pydot_graph = create_pydot_graph(op_nodes_desc, data_nodes, param_nodes, edges, rankdir, styles) png = pydot_graph.create_png() return png @@ -431,7 +437,7 @@ def draw_img_classifier_to_file(model, png_fname, dataset, display_param_nodes=F 'fillcolor': 'gray', 'style': 'rounded, filled'} """ - dummy_input = dataset_dummy_input(dataset) + dummy_input = distiller.get_dummy_input(dataset, distiller.model_device(model)) try: non_para_model = distiller.make_non_parallel_copy(model) g = SummaryGraph(non_para_model, dummy_input) @@ -446,16 +452,6 @@ def draw_img_classifier_to_file(model, png_fname, dataset, display_param_nodes=F del non_para_model -def dataset_dummy_input(dataset): - if dataset == 'imagenet': - dummy_input = torch.randn(1, 3, 224, 224) - elif dataset == 'cifar10': - dummy_input = torch.randn(1, 3, 32, 32) - else: - raise ValueError("Unsupported dataset (%s) - aborting operation" % dataset) - return dummy_input - - def export_img_classifier_to_onnx(model, onnx_fname, dataset, add_softmax=True, **kwargs): """Export a PyTorch image classifier to ONNX. @@ -463,7 +459,7 @@ def export_img_classifier_to_onnx(model, onnx_fname, dataset, add_softmax=True, add_softmax: when True, adds softmax layer to the output model. kwargs: arguments to be passed to torch.onnx.export """ - dummy_input = dataset_dummy_input(dataset).to('cuda') + dummy_input = distiller.get_dummy_input(dataset, distiller.model_device(model)) # Pytorch doesn't support exporting modules wrapped in DataParallel non_para_model = distiller.make_non_parallel_copy(model) diff --git a/distiller/utils.py b/distiller/utils.py index 580fa518ccb1b6ed5880fc99dbe113d148c2a2e4..f0b24a57afc72ed8b7ea0e7be0307ec1c362b2d2 100755 --- a/distiller/utils.py +++ b/distiller/utils.py @@ -60,12 +60,14 @@ def size2str(torch_size): return size_to_str(torch_size.size()) if isinstance(torch_size, torch.autograd.Variable): return size_to_str(torch_size.data.size()) + if isinstance(torch_size, tuple) or isinstance(torch_size, list): + return size_to_str(torch_size) raise TypeError def size_to_str(torch_size): """Convert a pytorch Size object to a string""" - assert isinstance(torch_size, torch.Size) + assert isinstance(torch_size, torch.Size) or isinstance(torch_size, tuple) or isinstance(torch_size, list) return '('+(', ').join(['%d' % v for v in torch_size])+')'