From a0ebeb7effaf4199a35661f4ee0c085fec8b3ad7 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Wed, 15 May 2019 17:01:48 +0300
Subject: [PATCH] =?UTF-8?q?Revert=20"SummaryGraph:=20fix=20=E2=80=98weight?=
 =?UTF-8?q?s=5Fvol=E2=80=99=20attribute=20for=20conv=20and=20linear=20laye?=
 =?UTF-8?q?rs"?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This reverts commit a3f2ce2d5199c22dccbe21c870a60e4409b4c490.
---
 distiller/summary_graph.py | 25 ++++++++++---------------
 tests/test_summarygraph.py | 30 ++++++------------------------
 2 files changed, 16 insertions(+), 39 deletions(-)

diff --git a/distiller/summary_graph.py b/distiller/summary_graph.py
index d246fd5..d3facd5 100755
--- a/distiller/summary_graph.py
+++ b/distiller/summary_graph.py
@@ -21,6 +21,7 @@ import collections
 import torch
 import torch.jit as jit
 import logging
+from collections import OrderedDict
 msglogger = logging.getLogger()
 
 
@@ -99,17 +100,17 @@ class SummaryGraph(object):
             
             device = next(model_clone.parameters()).device
             dummy_input = distiller.convert_tensors_recursively_to(dummy_input, device=device)
-            trace, _ = jit.get_trace_graph(model_clone, dummy_input)
+            trace, _ = jit.get_trace_graph(model_clone, dummy_input, _force_outplace=True)
 
             # Let ONNX do the heavy lifting: fusing the convolution nodes; fusing the nodes
             # composing a GEMM operation; etc.
             torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX)
 
             graph = trace.graph()
-            self.ops = {}
-            self.params = {}
+            self.ops = OrderedDict()
+            self.params = OrderedDict()
             self.edges = []
-            self.temp = {}
+            self.temp = OrderedDict()
 
             in_out = list(graph.inputs()) + list(graph.outputs())
             for param in in_out:
@@ -148,7 +149,7 @@ class SummaryGraph(object):
                     self.__add_output(new_op, output)
                     self.edges.append(SummaryGraph.Edge(new_op['name'], output.uniqueName()))
 
-                new_op['attrs'] = {attr_name: node[attr_name] for attr_name in node.attributeNames()}
+                new_op['attrs'] = OrderedDict([(attr_name, node[attr_name]) for attr_name in node.attributeNames()])
 
         self.add_macs_attr()
         self.add_footprint_attr()
@@ -156,7 +157,7 @@ class SummaryGraph(object):
         del model_clone
 
     def __create_op(self, onnx_node):
-        op = {}
+        op = OrderedDict()
         op['name'] = onnx_node.scopeName()
         op['orig-name'] = onnx_node.scopeName()
         op['type'] = onnx_node.kind().lstrip('::onnx')
@@ -188,7 +189,7 @@ class SummaryGraph(object):
         return param
 
     def __tensor_desc(self, n):
-        tensor = {}
+        tensor = OrderedDict()
         tensor['id'] = n.uniqueName()
         try:
             # try parsing the FM tensor type.  For example: Float(1, 64, 8, 8)
@@ -245,14 +246,8 @@ class SummaryGraph(object):
                 ofm_vol = self.param_volume(conv_out)
                 ifm_vol = self.param_volume(conv_in)
                 if op['type'] == 'Conv' or op['type'] == 'Gemm':
-                    if op['type'] == 'Conv':
-                        kernel_size =  self.volume(op['attrs']['kernel_shape'])
-                        group = op['attrs']['group']
-                    else:
-                        kernel_size, group = 1, 1
-                    n_ifm = self.param_shape(conv_in)[1] / group
-                    n_ofm = self.param_shape(conv_out)[1] 
-                    weights_vol = kernel_size * n_ifm * n_ofm
+                    conv_w = op['inputs'][1]
+                    weights_vol = self.param_volume(conv_w)
                     op['attrs']['footprint'] = ofm_vol + ifm_vol + weights_vol
                     op['attrs']['fm_vol'] = ofm_vol + ifm_vol
                     op['attrs']['weights_vol'] = weights_vol
diff --git a/tests/test_summarygraph.py b/tests/test_summarygraph.py
index f74bfda..8cf1e7d 100755
--- a/tests/test_summarygraph.py
+++ b/tests/test_summarygraph.py
@@ -116,24 +116,6 @@ def test_layer_search():
     assert preds == ['layer1.0.conv2', 'conv1']
 
 
-def test_weights_size_attr():
-    def test(dataset, arch, dataparallel:bool):
-        model = create_model(False, dataset, arch, parallel=False)
-        sgraph = SummaryGraph(model, get_input(dataset))
-
-        distiller.assign_layer_fq_names(model)
-        for name, mod in model.named_modules():
-            if isinstance(mod, torch.nn.Conv2d) or isinstance(mod, torch.nn.Linear):
-                op = sgraph.find_op(name)
-                assert op is not None
-                assert op['attrs']['weights_vol'] == distiller.volume(mod.weight)
-
-    for data_parallel in (True, False):
-        test('cifar10', 'resnet20_cifar', data_parallel)
-        test('imagenet', 'alexnet', data_parallel)
-        test('imagenet', 'resnext101_32x4d', data_parallel)
-
-
 def test_vgg():
     g = create_graph('imagenet', 'vgg19')
     assert g is not None
@@ -188,11 +170,11 @@ def named_params_layers_test_aux(dataset, arch, dataparallel:bool):
 
 
 def test_named_params_layers():
-    for data_parallel in (True, False):
-        named_params_layers_test_aux('imagenet', 'vgg19', data_parallel)
-        named_params_layers_test_aux('cifar10', 'resnet20_cifar', data_parallel)
-        named_params_layers_test_aux('imagenet', 'alexnet', data_parallel)
-        named_params_layers_test_aux('imagenet', 'resnext101_32x4d', data_parallel)
+    for dataParallelModel in (True, False):
+        named_params_layers_test_aux('imagenet', 'vgg19', dataParallelModel)
+        named_params_layers_test_aux('cifar10', 'resnet20_cifar', dataParallelModel)
+        named_params_layers_test_aux('imagenet', 'alexnet', dataParallelModel)
+        named_params_layers_test_aux('imagenet', 'resnext101_32x4d', dataParallelModel)
 
 
 def test_onnx_name_2_pytorch_name():
@@ -231,4 +213,4 @@ def test_sg_macs():
 
 if __name__ == '__main__':
     #test_connectivity_summary()
-    test_sg_macs()
+    test_sg_macs()
\ No newline at end of file
-- 
GitLab