Skip to content
Snippets Groups Projects
Commit 08b5cd95 authored by Neta Zmora's avatar Neta Zmora
Browse files

SummaryGraph: fix ‘weights_vol’ attribute for conv and linear layers

The weights_vol attribute reflects the size (volume) of an SG node’s
weights tensor.  The calculation of the weights volume was wrong.
This does not have any significant impact because this attribute is
not used.
parent a0ebeb7e
No related branches found
No related tags found
No related merge requests found
...@@ -246,8 +246,14 @@ class SummaryGraph(object): ...@@ -246,8 +246,14 @@ class SummaryGraph(object):
ofm_vol = self.param_volume(conv_out) ofm_vol = self.param_volume(conv_out)
ifm_vol = self.param_volume(conv_in) ifm_vol = self.param_volume(conv_in)
if op['type'] == 'Conv' or op['type'] == 'Gemm': if op['type'] == 'Conv' or op['type'] == 'Gemm':
conv_w = op['inputs'][1] if op['type'] == 'Conv':
weights_vol = self.param_volume(conv_w) kernel_size = self.volume(op['attrs']['kernel_shape'])
group = op['attrs']['group']
else:
kernel_size, group = 1, 1
n_ifm = self.param_shape(conv_in)[1] / group
n_ofm = self.param_shape(conv_out)[1]
weights_vol = kernel_size * n_ifm * n_ofm
op['attrs']['footprint'] = ofm_vol + ifm_vol + weights_vol op['attrs']['footprint'] = ofm_vol + ifm_vol + weights_vol
op['attrs']['fm_vol'] = ofm_vol + ifm_vol op['attrs']['fm_vol'] = ofm_vol + ifm_vol
op['attrs']['weights_vol'] = weights_vol op['attrs']['weights_vol'] = weights_vol
......
...@@ -211,6 +211,24 @@ def test_sg_macs(): ...@@ -211,6 +211,24 @@ def test_sg_macs():
assert summary_macs == sg_macs assert summary_macs == sg_macs
def test_weights_size_attr():
def test(dataset, arch, dataparallel:bool):
model = create_model(False, dataset, arch, parallel=dataparallel)
sgraph = SummaryGraph(model, get_input(dataset))
distiller.assign_layer_fq_names(model)
for name, mod in model.named_modules():
if isinstance(mod, torch.nn.Conv2d) or isinstance(mod, torch.nn.Linear):
op = sgraph.find_op(name)
assert op is not None
assert op['attrs']['weights_vol'] == distiller.volume(mod.weight)
for data_parallel in (True, False):
test('cifar10', 'resnet20_cifar', data_parallel)
test('imagenet', 'alexnet', data_parallel)
test('imagenet', 'resnext101_32x4d', data_parallel)
if __name__ == '__main__': if __name__ == '__main__':
#test_connectivity_summary() #test_connectivity_summary()
test_sg_macs() test_sg_macs()
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment