diff --git a/distiller/summary_graph.py b/distiller/summary_graph.py index d3facd54e9ff58360b9f89b57c69872bc29c0ff6..a5f8ce41d3762395a02c3da786aeca003912b3f7 100755 --- a/distiller/summary_graph.py +++ b/distiller/summary_graph.py @@ -246,8 +246,14 @@ class SummaryGraph(object): ofm_vol = self.param_volume(conv_out) ifm_vol = self.param_volume(conv_in) if op['type'] == 'Conv' or op['type'] == 'Gemm': - conv_w = op['inputs'][1] - weights_vol = self.param_volume(conv_w) + if op['type'] == 'Conv': + kernel_size = self.volume(op['attrs']['kernel_shape']) + group = op['attrs']['group'] + else: + kernel_size, group = 1, 1 + n_ifm = self.param_shape(conv_in)[1] / group + n_ofm = self.param_shape(conv_out)[1] + weights_vol = kernel_size * n_ifm * n_ofm op['attrs']['footprint'] = ofm_vol + ifm_vol + weights_vol op['attrs']['fm_vol'] = ofm_vol + ifm_vol op['attrs']['weights_vol'] = weights_vol diff --git a/tests/test_summarygraph.py b/tests/test_summarygraph.py index 8cf1e7dd2632f3c238d95de87e95afa4e5c454e3..46fedb3b6a99734b46d0bc801247e93502a1016d 100755 --- a/tests/test_summarygraph.py +++ b/tests/test_summarygraph.py @@ -211,6 +211,24 @@ def test_sg_macs(): assert summary_macs == sg_macs +def test_weights_size_attr(): + def test(dataset, arch, dataparallel:bool): + model = create_model(False, dataset, arch, parallel=dataparallel) + sgraph = SummaryGraph(model, get_input(dataset)) + + distiller.assign_layer_fq_names(model) + for name, mod in model.named_modules(): + if isinstance(mod, torch.nn.Conv2d) or isinstance(mod, torch.nn.Linear): + op = sgraph.find_op(name) + assert op is not None + assert op['attrs']['weights_vol'] == distiller.volume(mod.weight) + + for data_parallel in (True, False): + test('cifar10', 'resnet20_cifar', data_parallel) + test('imagenet', 'alexnet', data_parallel) + test('imagenet', 'resnext101_32x4d', data_parallel) + + if __name__ == '__main__': #test_connectivity_summary() test_sg_macs() \ No newline at end of file