From a3f2ce2d5199c22dccbe21c870a60e4409b4c490 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Wed, 15 May 2019 16:48:01 +0300
Subject: [PATCH] =?UTF-8?q?SummaryGraph:=20fix=20=E2=80=98weights=5Fvol?=
 =?UTF-8?q?=E2=80=99=20attribute=20for=20conv=20and=20linear=20layers?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The weights_vol attribute reflects the size (volume) of an SG node's
weights tensor.  The calculation of the weights volume was wrong; it is
now derived from the layer attributes and feature-map shapes as
kernel_volume * (n_ifm / groups) * n_ofm.  This has no significant
impact because the attribute is currently unused.
---
 distiller/summary_graph.py | 25 +++++++++++++++----------
 tests/test_summarygraph.py | 30 ++++++++++++++++++++++++------
 2 files changed, 39 insertions(+), 16 deletions(-)

diff --git a/distiller/summary_graph.py b/distiller/summary_graph.py
index d3facd5..d246fd5 100755
--- a/distiller/summary_graph.py
+++ b/distiller/summary_graph.py
@@ -21,7 +21,6 @@ import collections
 import torch
 import torch.jit as jit
 import logging
-from collections import OrderedDict
 
 msglogger = logging.getLogger()
 
@@ -100,17 +99,17 @@ class SummaryGraph(object):
             device = next(model_clone.parameters()).device
             dummy_input = distiller.convert_tensors_recursively_to(dummy_input, device=device)
 
-            trace, _ = jit.get_trace_graph(model_clone, dummy_input, _force_outplace=True)
+            trace, _ = jit.get_trace_graph(model_clone, dummy_input)
 
             # Let ONNX do the heavy lifting: fusing the convolution nodes; fusing the nodes
             # composing a GEMM operation; etc.
             torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX)
 
             graph = trace.graph()
-            self.ops = OrderedDict()
-            self.params = OrderedDict()
+            self.ops = {}
+            self.params = {}
             self.edges = []
-            self.temp = OrderedDict()
+            self.temp = {}
 
             in_out = list(graph.inputs()) + list(graph.outputs())
             for param in in_out:
@@ -149,7 +148,7 @@ class SummaryGraph(object):
                     self.__add_output(new_op, output)
                     self.edges.append(SummaryGraph.Edge(new_op['name'], output.uniqueName()))
 
-                new_op['attrs'] = OrderedDict([(attr_name, node[attr_name]) for attr_name in node.attributeNames()])
+                new_op['attrs'] = {attr_name: node[attr_name] for attr_name in node.attributeNames()}
 
         self.add_macs_attr()
         self.add_footprint_attr()
@@ -157,7 +156,7 @@ class SummaryGraph(object):
         del model_clone
 
     def __create_op(self, onnx_node):
-        op = OrderedDict()
+        op = {}
         op['name'] = onnx_node.scopeName()
         op['orig-name'] = onnx_node.scopeName()
         op['type'] = onnx_node.kind().lstrip('::onnx')
@@ -189,7 +188,7 @@ class SummaryGraph(object):
         return param
 
     def __tensor_desc(self, n):
-        tensor = OrderedDict()
+        tensor = {}
         tensor['id'] = n.uniqueName()
         try:
             # try parsing the FM tensor type.  For example: Float(1, 64, 8, 8)
@@ -246,8 +245,14 @@ class SummaryGraph(object):
                 ofm_vol = self.param_volume(conv_out)
                 ifm_vol = self.param_volume(conv_in)
                 if op['type'] == 'Conv' or op['type'] == 'Gemm':
-                    conv_w = op['inputs'][1]
-                    weights_vol = self.param_volume(conv_w)
+                    if op['type'] == 'Conv':
+                        kernel_size = self.volume(op['attrs']['kernel_shape'])
+                        group = op['attrs']['group']
+                    else:
+                        kernel_size, group = 1, 1
+                    n_ifm = self.param_shape(conv_in)[1] / group
+                    n_ofm = self.param_shape(conv_out)[1]
+                    weights_vol = kernel_size * n_ifm * n_ofm
                     op['attrs']['footprint'] = ofm_vol + ifm_vol + weights_vol
                     op['attrs']['fm_vol'] = ofm_vol + ifm_vol
                     op['attrs']['weights_vol'] = weights_vol
diff --git a/tests/test_summarygraph.py b/tests/test_summarygraph.py
index 8cf1e7d..f74bfda 100755
--- a/tests/test_summarygraph.py
+++ b/tests/test_summarygraph.py
@@ -116,6 +116,24 @@ def test_layer_search():
     assert preds == ['layer1.0.conv2', 'conv1']
 
 
+def test_weights_size_attr():
+    def test(dataset, arch, dataparallel:bool):
+        model = create_model(False, dataset, arch, parallel=False)
+        sgraph = SummaryGraph(model, get_input(dataset))
+
+        distiller.assign_layer_fq_names(model)
+        for name, mod in model.named_modules():
+            if isinstance(mod, torch.nn.Conv2d) or isinstance(mod, torch.nn.Linear):
+                op = sgraph.find_op(name)
+                assert op is not None
+                assert op['attrs']['weights_vol'] == distiller.volume(mod.weight)
+
+    for data_parallel in (True, False):
+        test('cifar10', 'resnet20_cifar', data_parallel)
+        test('imagenet', 'alexnet', data_parallel)
+        test('imagenet', 'resnext101_32x4d', data_parallel)
+
+
 def test_vgg():
     g = create_graph('imagenet', 'vgg19')
     assert g is not None
@@ -170,11 +188,11 @@ def named_params_layers_test_aux(dataset, arch, dataparallel:bool):
 
 
 def test_named_params_layers():
-    for dataParallelModel in (True, False):
-        named_params_layers_test_aux('imagenet', 'vgg19', dataParallelModel)
-        named_params_layers_test_aux('cifar10', 'resnet20_cifar', dataParallelModel)
-        named_params_layers_test_aux('imagenet', 'alexnet', dataParallelModel)
-        named_params_layers_test_aux('imagenet', 'resnext101_32x4d', dataParallelModel)
+    for data_parallel in (True, False):
+        named_params_layers_test_aux('imagenet', 'vgg19', data_parallel)
+        named_params_layers_test_aux('cifar10', 'resnet20_cifar', data_parallel)
+        named_params_layers_test_aux('imagenet', 'alexnet', data_parallel)
+        named_params_layers_test_aux('imagenet', 'resnext101_32x4d', data_parallel)
 
 
 def test_onnx_name_2_pytorch_name():
@@ -213,4 +231,4 @@ def test_sg_macs():
 
 if __name__ == '__main__':
     #test_connectivity_summary()
-    test_sg_macs()
\ No newline at end of file
+    test_sg_macs()
-- 
GitLab
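
For reference (not part of the patch): a minimal sketch of the weights-volume
formula that the patched add_footprint_attr() implements, checked against the
element count of the weight tensors of standard torch.nn modules.  The helper
name expected_weights_vol and the example module instances are illustrative
assumptions, not code from Distiller.

import torch.nn as nn

def expected_weights_vol(mod):
    # Mirrors the patched add_footprint_attr() logic:
    #   weights_vol = kernel_volume * (n_ifm / groups) * n_ofm
    if isinstance(mod, nn.Conv2d):
        kernel_volume = mod.kernel_size[0] * mod.kernel_size[1]
        n_ifm, n_ofm, groups = mod.in_channels, mod.out_channels, mod.groups
    elif isinstance(mod, nn.Linear):
        kernel_volume, groups = 1, 1
        n_ifm, n_ofm = mod.in_features, mod.out_features
    else:
        raise TypeError("only Conv2d and Linear are handled here")
    return kernel_volume * (n_ifm // groups) * n_ofm

# Sanity check: the formula matches the true number of weight elements.
conv = nn.Conv2d(64, 128, kernel_size=3, groups=2)  # weight shape: (128, 32, 3, 3)
fc = nn.Linear(512, 10)                             # weight shape: (10, 512)
assert expected_weights_vol(conv) == conv.weight.numel()  # 3*3*(64/2)*128 = 36864
assert expected_weights_vol(fc) == fc.weight.numel()      # 1*512*10 = 5120

Note that the patch itself computes n_ifm with true division ('/'), so
weights_vol may end up stored as a float; the sketch uses integer division to
get an exact integer for the comparison.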