From 08b5cd95704d850cfb845ed7785f739cbb57de54 Mon Sep 17 00:00:00 2001 From: Neta Zmora <neta.zmora@intel.com> Date: Wed, 15 May 2019 17:19:05 +0300 Subject: [PATCH] =?UTF-8?q?SummaryGraph:=20fix=20=E2=80=98weights=5Fvol?= =?UTF-8?q?=E2=80=99=20attribute=20for=20conv=20and=20linear=20layers?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The weights_vol attribute reflects the size (volume) of an SG node’s weights tensor. The calculation of the weights volume was wrong. This does not have any significant impact because this attribute is not used. --- distiller/summary_graph.py | 10 ++++++++-- tests/test_summarygraph.py | 18 ++++++++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/distiller/summary_graph.py b/distiller/summary_graph.py index d3facd5..a5f8ce4 100755 --- a/distiller/summary_graph.py +++ b/distiller/summary_graph.py @@ -246,8 +246,14 @@ class SummaryGraph(object): ofm_vol = self.param_volume(conv_out) ifm_vol = self.param_volume(conv_in) if op['type'] == 'Conv' or op['type'] == 'Gemm': - conv_w = op['inputs'][1] - weights_vol = self.param_volume(conv_w) + if op['type'] == 'Conv': + kernel_size = self.volume(op['attrs']['kernel_shape']) + group = op['attrs']['group'] + else: + kernel_size, group = 1, 1 + n_ifm = self.param_shape(conv_in)[1] / group + n_ofm = self.param_shape(conv_out)[1] + weights_vol = kernel_size * n_ifm * n_ofm op['attrs']['footprint'] = ofm_vol + ifm_vol + weights_vol op['attrs']['fm_vol'] = ofm_vol + ifm_vol op['attrs']['weights_vol'] = weights_vol diff --git a/tests/test_summarygraph.py b/tests/test_summarygraph.py index 8cf1e7d..46fedb3 100755 --- a/tests/test_summarygraph.py +++ b/tests/test_summarygraph.py @@ -211,6 +211,24 @@ def test_sg_macs(): assert summary_macs == sg_macs +def test_weights_size_attr(): + def test(dataset, arch, dataparallel:bool): + model = create_model(False, dataset, arch, parallel=dataparallel) + sgraph = SummaryGraph(model, get_input(dataset)) + + distiller.assign_layer_fq_names(model) + for name, mod in model.named_modules(): + if isinstance(mod, torch.nn.Conv2d) or isinstance(mod, torch.nn.Linear): + op = sgraph.find_op(name) + assert op is not None + assert op['attrs']['weights_vol'] == distiller.volume(mod.weight) + + for data_parallel in (True, False): + test('cifar10', 'resnet20_cifar', data_parallel) + test('imagenet', 'alexnet', data_parallel) + test('imagenet', 'resnext101_32x4d', data_parallel) + + if __name__ == '__main__': #test_connectivity_summary() test_sg_macs() \ No newline at end of file -- GitLab