From a3f2ce2d5199c22dccbe21c870a60e4409b4c490 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Wed, 15 May 2019 16:48:01 +0300
Subject: [PATCH] =?UTF-8?q?SummaryGraph:=20fix=20=E2=80=98weights=5Fvol?=
 =?UTF-8?q?=E2=80=99=20attribute=20for=20conv=20and=20linear=20layers?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The weights_vol attribute reflects the size (volume) of a SummaryGraph
(SG) node’s weights tensor.  The calculation of the weights volume was
wrong.  This does not have any significant impact because this
attribute is not used.
---
 distiller/summary_graph.py | 25 +++++++++++++++----------
 tests/test_summarygraph.py | 30 ++++++++++++++++++++++++------
 2 files changed, 39 insertions(+), 16 deletions(-)

diff --git a/distiller/summary_graph.py b/distiller/summary_graph.py
index d3facd5..d246fd5 100755
--- a/distiller/summary_graph.py
+++ b/distiller/summary_graph.py
@@ -21,7 +21,6 @@ import collections
 import torch
 import torch.jit as jit
 import logging
-from collections import OrderedDict
 msglogger = logging.getLogger()
 
 
@@ -100,17 +99,17 @@ class SummaryGraph(object):
             
             device = next(model_clone.parameters()).device
             dummy_input = distiller.convert_tensors_recursively_to(dummy_input, device=device)
-            trace, _ = jit.get_trace_graph(model_clone, dummy_input, _force_outplace=True)
+            trace, _ = jit.get_trace_graph(model_clone, dummy_input)
 
             # Let ONNX do the heavy lifting: fusing the convolution nodes; fusing the nodes
             # composing a GEMM operation; etc.
             torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX)
 
             graph = trace.graph()
-            self.ops = OrderedDict()
-            self.params = OrderedDict()
+            self.ops = {}
+            self.params = {}
             self.edges = []
-            self.temp = OrderedDict()
+            self.temp = {}
 
             in_out = list(graph.inputs()) + list(graph.outputs())
             for param in in_out:
@@ -149,7 +148,7 @@ class SummaryGraph(object):
                     self.__add_output(new_op, output)
                     self.edges.append(SummaryGraph.Edge(new_op['name'], output.uniqueName()))
 
-                new_op['attrs'] = OrderedDict([(attr_name, node[attr_name]) for attr_name in node.attributeNames()])
+                new_op['attrs'] = {attr_name: node[attr_name] for attr_name in node.attributeNames()}
 
         self.add_macs_attr()
         self.add_footprint_attr()
@@ -157,7 +156,7 @@ class SummaryGraph(object):
         del model_clone
 
     def __create_op(self, onnx_node):
-        op = OrderedDict()
+        op = {}
         op['name'] = onnx_node.scopeName()
         op['orig-name'] = onnx_node.scopeName()
         op['type'] = onnx_node.kind().lstrip('::onnx')
@@ -189,7 +188,7 @@ class SummaryGraph(object):
         return param
 
     def __tensor_desc(self, n):
-        tensor = OrderedDict()
+        tensor = {}
         tensor['id'] = n.uniqueName()
         try:
             # try parsing the FM tensor type.  For example: Float(1, 64, 8, 8)
@@ -246,8 +245,14 @@ class SummaryGraph(object):
                 ofm_vol = self.param_volume(conv_out)
                 ifm_vol = self.param_volume(conv_in)
                 if op['type'] == 'Conv' or op['type'] == 'Gemm':
-                    conv_w = op['inputs'][1]
-                    weights_vol = self.param_volume(conv_w)
+                    if op['type'] == 'Conv':
+                        kernel_size =  self.volume(op['attrs']['kernel_shape'])
+                        group = op['attrs']['group']
+                    else:
+                        kernel_size, group = 1, 1
+                    n_ifm = self.param_shape(conv_in)[1] / group
+                    n_ofm = self.param_shape(conv_out)[1] 
+                    weights_vol = kernel_size * n_ifm * n_ofm
                     op['attrs']['footprint'] = ofm_vol + ifm_vol + weights_vol
                     op['attrs']['fm_vol'] = ofm_vol + ifm_vol
                     op['attrs']['weights_vol'] = weights_vol
diff --git a/tests/test_summarygraph.py b/tests/test_summarygraph.py
index 8cf1e7d..f74bfda 100755
--- a/tests/test_summarygraph.py
+++ b/tests/test_summarygraph.py
@@ -116,6 +116,24 @@ def test_layer_search():
     assert preds == ['layer1.0.conv2', 'conv1']
 
 
+def test_weights_size_attr():
+    def test(dataset, arch, dataparallel:bool):
+        model = create_model(False, dataset, arch, parallel=False)
+        sgraph = SummaryGraph(model, get_input(dataset))
+
+        distiller.assign_layer_fq_names(model)
+        for name, mod in model.named_modules():
+            if isinstance(mod, torch.nn.Conv2d) or isinstance(mod, torch.nn.Linear):
+                op = sgraph.find_op(name)
+                assert op is not None
+                assert op['attrs']['weights_vol'] == distiller.volume(mod.weight)
+
+    for data_parallel in (True, False):
+        test('cifar10', 'resnet20_cifar', data_parallel)
+        test('imagenet', 'alexnet', data_parallel)
+        test('imagenet', 'resnext101_32x4d', data_parallel)
+
+
 def test_vgg():
     g = create_graph('imagenet', 'vgg19')
     assert g is not None
@@ -170,11 +188,11 @@ def named_params_layers_test_aux(dataset, arch, dataparallel:bool):
 
 
 def test_named_params_layers():
-    for dataParallelModel in (True, False):
-        named_params_layers_test_aux('imagenet', 'vgg19', dataParallelModel)
-        named_params_layers_test_aux('cifar10', 'resnet20_cifar', dataParallelModel)
-        named_params_layers_test_aux('imagenet', 'alexnet', dataParallelModel)
-        named_params_layers_test_aux('imagenet', 'resnext101_32x4d', dataParallelModel)
+    for data_parallel in (True, False):
+        named_params_layers_test_aux('imagenet', 'vgg19', data_parallel)
+        named_params_layers_test_aux('cifar10', 'resnet20_cifar', data_parallel)
+        named_params_layers_test_aux('imagenet', 'alexnet', data_parallel)
+        named_params_layers_test_aux('imagenet', 'resnext101_32x4d', data_parallel)
 
 
 def test_onnx_name_2_pytorch_name():
@@ -213,4 +231,4 @@ def test_sg_macs():
 
 if __name__ == '__main__':
     #test_connectivity_summary()
-    test_sg_macs()
\ No newline at end of file
+    test_sg_macs()
-- 
GitLab