From a0ebeb7effaf4199a35661f4ee0c085fec8b3ad7 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Wed, 15 May 2019 17:01:48 +0300
Subject: [PATCH] Revert "SummaryGraph: fix ‘weights_vol’ attribute for conv
 and linear layers"
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This reverts commit a3f2ce2d5199c22dccbe21c870a60e4409b4c490.
---
 distiller/summary_graph.py | 25 ++++++++++---------------
 tests/test_summarygraph.py | 30 ++++++------------------------
 2 files changed, 16 insertions(+), 39 deletions(-)

diff --git a/distiller/summary_graph.py b/distiller/summary_graph.py
index d246fd5..d3facd5 100755
--- a/distiller/summary_graph.py
+++ b/distiller/summary_graph.py
@@ -21,6 +21,7 @@ import collections
 import torch
 import torch.jit as jit
 import logging
+from collections import OrderedDict
 
 msglogger = logging.getLogger()
 
@@ -99,17 +100,17 @@ class SummaryGraph(object):
 
             device = next(model_clone.parameters()).device
             dummy_input = distiller.convert_tensors_recursively_to(dummy_input, device=device)
-            trace, _ = jit.get_trace_graph(model_clone, dummy_input)
+            trace, _ = jit.get_trace_graph(model_clone, dummy_input, _force_outplace=True)
 
             # Let ONNX do the heavy lifting: fusing the convolution nodes; fusing the nodes
             # composing a GEMM operation; etc.
             torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX)
 
             graph = trace.graph()
-            self.ops = {}
-            self.params = {}
+            self.ops = OrderedDict()
+            self.params = OrderedDict()
             self.edges = []
-            self.temp = {}
+            self.temp = OrderedDict()
 
             in_out = list(graph.inputs()) + list(graph.outputs())
             for param in in_out:
@@ -148,7 +149,7 @@ class SummaryGraph(object):
                     self.__add_output(new_op, output)
                     self.edges.append(SummaryGraph.Edge(new_op['name'], output.uniqueName()))
 
-                new_op['attrs'] = {attr_name: node[attr_name] for attr_name in node.attributeNames()}
+                new_op['attrs'] = OrderedDict([(attr_name, node[attr_name]) for attr_name in node.attributeNames()])
 
         self.add_macs_attr()
         self.add_footprint_attr()
@@ -156,7 +157,7 @@ class SummaryGraph(object):
         del model_clone
 
     def __create_op(self, onnx_node):
-        op = {}
+        op = OrderedDict()
         op['name'] = onnx_node.scopeName()
         op['orig-name'] = onnx_node.scopeName()
         op['type'] = onnx_node.kind().lstrip('::onnx')
@@ -188,7 +189,7 @@ class SummaryGraph(object):
         return param
 
     def __tensor_desc(self, n):
-        tensor = {}
+        tensor = OrderedDict()
         tensor['id'] = n.uniqueName()
         try:
             # try parsing the FM tensor type.  For example: Float(1, 64, 8, 8)
@@ -245,14 +246,8 @@ class SummaryGraph(object):
             ofm_vol = self.param_volume(conv_out)
             ifm_vol = self.param_volume(conv_in)
             if op['type'] == 'Conv' or op['type'] == 'Gemm':
-                if op['type'] == 'Conv':
-                    kernel_size = self.volume(op['attrs']['kernel_shape'])
-                    group = op['attrs']['group']
-                else:
-                    kernel_size, group = 1, 1
-                n_ifm = self.param_shape(conv_in)[1] / group
-                n_ofm = self.param_shape(conv_out)[1]
-                weights_vol = kernel_size * n_ifm * n_ofm
+                conv_w = op['inputs'][1]
+                weights_vol = self.param_volume(conv_w)
                 op['attrs']['footprint'] = ofm_vol + ifm_vol + weights_vol
                 op['attrs']['fm_vol'] = ofm_vol + ifm_vol
                 op['attrs']['weights_vol'] = weights_vol
diff --git a/tests/test_summarygraph.py b/tests/test_summarygraph.py
index f74bfda..8cf1e7d 100755
--- a/tests/test_summarygraph.py
+++ b/tests/test_summarygraph.py
@@ -116,24 +116,6 @@ def test_layer_search():
     assert preds == ['layer1.0.conv2', 'conv1']
 
 
-def test_weights_size_attr():
-    def test(dataset, arch, dataparallel:bool):
-        model = create_model(False, dataset, arch, parallel=False)
-        sgraph = SummaryGraph(model, get_input(dataset))
-
-        distiller.assign_layer_fq_names(model)
-        for name, mod in model.named_modules():
-            if isinstance(mod, torch.nn.Conv2d) or isinstance(mod, torch.nn.Linear):
-                op = sgraph.find_op(name)
-                assert op is not None
-                assert op['attrs']['weights_vol'] == distiller.volume(mod.weight)
-
-    for data_parallel in (True, False):
-        test('cifar10', 'resnet20_cifar', data_parallel)
-        test('imagenet', 'alexnet', data_parallel)
-        test('imagenet', 'resnext101_32x4d', data_parallel)
-
-
 def test_vgg():
     g = create_graph('imagenet', 'vgg19')
     assert g is not None
@@ -188,11 +170,11 @@ def named_params_layers_test_aux(dataset, arch, dataparallel:bool):
 
 
 def test_named_params_layers():
-    for data_parallel in (True, False):
-        named_params_layers_test_aux('imagenet', 'vgg19', data_parallel)
-        named_params_layers_test_aux('cifar10', 'resnet20_cifar', data_parallel)
-        named_params_layers_test_aux('imagenet', 'alexnet', data_parallel)
-        named_params_layers_test_aux('imagenet', 'resnext101_32x4d', data_parallel)
+    for dataParallelModel in (True, False):
+        named_params_layers_test_aux('imagenet', 'vgg19', dataParallelModel)
+        named_params_layers_test_aux('cifar10', 'resnet20_cifar', dataParallelModel)
+        named_params_layers_test_aux('imagenet', 'alexnet', dataParallelModel)
+        named_params_layers_test_aux('imagenet', 'resnext101_32x4d', dataParallelModel)
 
 
 def test_onnx_name_2_pytorch_name():
@@ -231,4 +213,4 @@ def test_sg_macs():
 
 if __name__ == '__main__':
     #test_connectivity_summary()
-    test_sg_macs()
+    test_sg_macs()
\ No newline at end of file
--
GitLab
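
Note (not part of the patch): the substantive hunk in summary_graph.py swaps an analytic 'weights_vol'
formula (kernel_size * n_ifm * n_ofm, with n_ifm divided by the group count) for the raw volume of the
layer's weight tensor taken from op['inputs'][1]. The sketch below is only an illustration, assuming a
working PyTorch install, that the two computations agree for grouped convolutions and for linear layers;
the helper names are invented for this example and do not exist in distiller.

import torch.nn as nn

def weights_vol_from_tensor(module):
    # Count the elements of the module's weight tensor directly.
    return module.weight.numel()

def weights_vol_from_attrs(module):
    # Derive the same volume analytically from the layer attributes.
    if isinstance(module, nn.Conv2d):
        kernel_size = module.kernel_size[0] * module.kernel_size[1]
        n_ifm = module.in_channels // module.groups
        n_ofm = module.out_channels
        return kernel_size * n_ifm * n_ofm
    if isinstance(module, nn.Linear):
        return module.in_features * module.out_features
    raise TypeError("expected Conv2d or Linear")

# The two results match for a grouped conv and a fully-connected layer.
for m in (nn.Conv2d(64, 128, kernel_size=3, groups=2, bias=False), nn.Linear(512, 10)):
    assert weights_vol_from_tensor(m) == weights_vol_from_attrs(m)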