From 08b5cd95704d850cfb845ed7785f739cbb57de54 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Wed, 15 May 2019 17:19:05 +0300
Subject: [PATCH] =?UTF-8?q?SummaryGraph:=20fix=20=E2=80=98weights=5Fvol?=
 =?UTF-8?q?=E2=80=99=20attribute=20for=20conv=20and=20linear=20layers?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The weights_vol attribute reflects the size (volume) of an SG node’s
weights tensor.  The calculation of the weights volume was wrong.
This does not have any significant impact because this attribute is
not used.
---
 distiller/summary_graph.py | 10 ++++++++--
 tests/test_summarygraph.py | 18 ++++++++++++++++++
 2 files changed, 26 insertions(+), 2 deletions(-)

diff --git a/distiller/summary_graph.py b/distiller/summary_graph.py
index d3facd5..a5f8ce4 100755
--- a/distiller/summary_graph.py
+++ b/distiller/summary_graph.py
@@ -246,8 +246,14 @@ class SummaryGraph(object):
                 ofm_vol = self.param_volume(conv_out)
                 ifm_vol = self.param_volume(conv_in)
                 if op['type'] == 'Conv' or op['type'] == 'Gemm':
-                    conv_w = op['inputs'][1]
-                    weights_vol = self.param_volume(conv_w)
+                    if op['type'] == 'Conv':
+                        kernel_size =  self.volume(op['attrs']['kernel_shape'])
+                        group = op['attrs']['group']
+                    else:
+                        kernel_size, group = 1, 1
+                    n_ifm = self.param_shape(conv_in)[1] / group
+                    n_ofm = self.param_shape(conv_out)[1] 
+                    weights_vol = kernel_size * n_ifm * n_ofm
                     op['attrs']['footprint'] = ofm_vol + ifm_vol + weights_vol
                     op['attrs']['fm_vol'] = ofm_vol + ifm_vol
                     op['attrs']['weights_vol'] = weights_vol
diff --git a/tests/test_summarygraph.py b/tests/test_summarygraph.py
index 8cf1e7d..46fedb3 100755
--- a/tests/test_summarygraph.py
+++ b/tests/test_summarygraph.py
@@ -211,6 +211,24 @@ def test_sg_macs():
             assert summary_macs == sg_macs
  
 
+def test_weights_size_attr():
+    def test(dataset, arch, dataparallel:bool):
+        model = create_model(False, dataset, arch, parallel=dataparallel)
+        sgraph = SummaryGraph(model, get_input(dataset))
+
+        distiller.assign_layer_fq_names(model)
+        for name, mod in model.named_modules():
+            if isinstance(mod, torch.nn.Conv2d) or isinstance(mod, torch.nn.Linear):
+                op = sgraph.find_op(name)
+                assert op is not None
+                assert op['attrs']['weights_vol'] == distiller.volume(mod.weight)
+
+    for data_parallel in (True, False):
+        test('cifar10', 'resnet20_cifar', data_parallel)
+        test('imagenet', 'alexnet', data_parallel)
+        test('imagenet', 'resnext101_32x4d', data_parallel)
+
+
 if __name__ == '__main__':
     #test_connectivity_summary()
     test_sg_macs()
\ No newline at end of file
-- 
GitLab