From 1a8c6bb823f066a89cb5bf538c64631ec895b220 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Wed, 10 Apr 2019 01:50:26 +0300
Subject: [PATCH] SummaryGraph - fix MACs calculation for grouped-convolutions

Also added tests
---
 distiller/summary_graph.py  |  7 +++--
 tests/test_model_summary.py | 63 ++++++++++++++++++++++++++-----------
 tests/test_summarygraph.py  | 22 +++++++++++--
 3 files changed, 67 insertions(+), 25 deletions(-)

diff --git a/distiller/summary_graph.py b/distiller/summary_graph.py
index 7c0b09d..4469294 100755
--- a/distiller/summary_graph.py
+++ b/distiller/summary_graph.py
@@ -217,11 +217,12 @@ class SummaryGraph(object):
                 conv_out = op['outputs'][0]
                 conv_in = op['inputs'][0]
                 conv_w = op['attrs']['kernel_shape']
+                groups = op['attrs']['group']
                 ofm_vol = self.param_volume(conv_out)
                 try:
-                    # MACs = volume(OFM) * (#IFM * K^2)
-                    op['attrs']['MACs'] = ofm_vol * SummaryGraph.volume(conv_w) * self.params[conv_in]['shape'][1]
-                except IndexError as e:
+                    # MACs = volume(OFM) * (#IFM * K^2) / #Groups
+                    op['attrs']['MACs'] = int(ofm_vol * SummaryGraph.volume(conv_w) * self.params[conv_in]['shape'][1] / groups)
+                except IndexError:
                     # Todo: change the method for calculating MACs
                     msglogger.error("An input to a Convolutional layer is missing shape information "
                                     "(MAC values will be wrong)")
diff --git a/tests/test_model_summary.py b/tests/test_model_summary.py
index 5053748..e5271a8 100755
--- a/tests/test_model_summary.py
+++ b/tests/test_model_summary.py
@@ -15,11 +15,11 @@
 #
 
 import logging
-import torch
 import distiller
 import pytest
 import common  # common test code
 
+
 # Logging configuration
 logging.basicConfig(level=logging.INFO)
 fh = logging.FileHandler('test.log')
@@ -28,30 +28,55 @@ logger.addHandler(fh)
 
 
 def test_png_generation():
-    DATASET = "cifar10"
-    ARCH = "resnet20_cifar"
-    model, zeros_mask_dict = common.setup_test(ARCH, DATASET, parallel=True)
+    dataset = "cifar10"
+    arch = "resnet20_cifar"
+    model, _ = common.setup_test(arch, dataset, parallel=True)
     # 2 different ways to create a PNG
-    distiller.draw_img_classifier_to_file(model, 'model.png', DATASET, True)
-    distiller.draw_img_classifier_to_file(model, 'model.png', DATASET, False)
-
+    distiller.draw_img_classifier_to_file(model, 'model.png', dataset, True)
+    distiller.draw_img_classifier_to_file(model, 'model.png', dataset, False)
+    
 
 def test_negative():
-    DATASET = "cifar10"
-    ARCH = "resnet20_cifar"
-    model, zeros_mask_dict = common.setup_test(ARCH, DATASET, parallel=True)
+    dataset = "cifar10"
+    arch = "resnet20_cifar"
+    model, _ = common.setup_test(arch, dataset, parallel=True)
 
     with pytest.raises(ValueError):
         # png is not a supported summary type, so we expect this to fail with a ValueError
-        distiller.model_summary(model, what='png', dataset=DATASET)
+        distiller.model_summary(model, what='png', dataset=dataset)
+
+
+def test_compute_summary():
+    dataset = "cifar10"
+    arch = "simplenet_cifar"
+    model, _ = common.setup_test(arch, dataset, parallel=True)
+    df_compute = distiller.model_performance_summary(model, common.get_dummy_input(dataset))
+    module_macs = df_compute.loc[:, 'MACs'].to_list()
+    #                     [conv1,  conv2,  fc1,   fc2,   fc3]
+    assert module_macs == [352800, 240000, 48000, 10080, 840]
+
+    dataset = "imagenet"
+    arch = "mobilenet"
+    model, _ = common.setup_test(arch, dataset, parallel=True)
+    df_compute = distiller.model_performance_summary(model, common.get_dummy_input(dataset))
+    module_macs = df_compute.loc[:, 'MACs'].to_list()
+    expected_macs = [10838016, 3612672, 25690112, 1806336, 25690112, 3612672, 51380224, 903168, 
+                     25690112, 1806336, 51380224, 451584, 25690112, 903168, 51380224, 903168, 
+                     51380224, 903168, 51380224, 903168, 51380224, 903168, 51380224, 225792, 
+                     25690112, 451584, 51380224, 1024000]
+    assert module_macs == expected_macs
 
 
 def test_summary():
-    DATASET = "cifar10"
-    ARCH = "resnet20_cifar"
-    model, zeros_mask_dict = common.setup_test(ARCH, DATASET, parallel=True)
-
-    distiller.model_summary(model, what='sparsity', dataset=DATASET)
-    distiller.model_summary(model, what='compute', dataset=DATASET)
-    distiller.model_summary(model, what='model', dataset=DATASET)
-    distiller.model_summary(model, what='modules', dataset=DATASET)
+    dataset = "cifar10"
+    arch = "resnet20_cifar"
+    model, _ = common.setup_test(arch, dataset, parallel=True)
+
+    distiller.model_summary(model, what='sparsity', dataset=dataset)
+    distiller.model_summary(model, what='compute', dataset=dataset)
+    distiller.model_summary(model, what='model', dataset=dataset)
+    distiller.model_summary(model, what='modules', dataset=dataset)
+
+
+if __name__ == '__main__':
+    test_compute_summary()
\ No newline at end of file
diff --git a/tests/test_summarygraph.py b/tests/test_summarygraph.py
index 195a982..8cf1e7d 100755
--- a/tests/test_summarygraph.py
+++ b/tests/test_summarygraph.py
@@ -166,8 +166,7 @@ def named_params_layers_test_aux(dataset, arch, dataparallel:bool):
     sgraph = SummaryGraph(model, get_input(dataset))
     sgraph_layer_names = set(k for k, i, j in sgraph.named_params_layers())
     for layer_name in sgraph_layer_names:
-        assert (sgraph.find_op(layer_name) is not None,
-            '{} was not found in summary graph'.format(layer_name))
+        assert sgraph.find_op(layer_name) is not None, '{} was not found in summary graph'.format(layer_name)
 
 
 def test_named_params_layers():
@@ -196,5 +195,22 @@ def test_connectivity_summary():
     assert len(verbose_summary) == 81
 
 
+def test_sg_macs():
+    '''Compare the MACs of different modules as computed by a SummaryGraph
+    and model summary.'''
+    import common
+    sg = create_graph('imagenet', 'mobilenet')
+    assert sg
+    model, _ = common.setup_test('mobilenet', 'imagenet', parallel=False)
+    df_compute = distiller.model_performance_summary(model, common.get_dummy_input('imagenet'))
+    modules_macs = df_compute.loc[:, ['Name', 'MACs']]
+    for name, mod in model.named_modules():
+        if isinstance(mod, (torch.nn.Conv2d, torch.nn.Linear)):
+            summary_macs = int(modules_macs.loc[modules_macs.Name == name].MACs)
+            sg_macs = sg.find_op(name)['attrs']['MACs']
+            assert summary_macs == sg_macs
+ 
+
 if __name__ == '__main__':
-    test_connectivity_summary()
+    #test_connectivity_summary()
+    test_sg_macs()
\ No newline at end of file
-- 
GitLab