diff --git a/distiller/summary_graph.py b/distiller/summary_graph.py
index d3facd54e9ff58360b9f89b57c69872bc29c0ff6..a5f8ce41d3762395a02c3da786aeca003912b3f7 100755
--- a/distiller/summary_graph.py
+++ b/distiller/summary_graph.py
@@ -246,8 +246,14 @@ class SummaryGraph(object):
                 ofm_vol = self.param_volume(conv_out)
                 ifm_vol = self.param_volume(conv_in)
                 if op['type'] == 'Conv' or op['type'] == 'Gemm':
-                    conv_w = op['inputs'][1]
-                    weights_vol = self.param_volume(conv_w)
+                    if op['type'] == 'Conv':
+                        kernel_size =  self.volume(op['attrs']['kernel_shape'])
+                        group = op['attrs']['group']
+                    else:
+                        kernel_size, group = 1, 1
+                    n_ifm = self.param_shape(conv_in)[1] / group
+                    n_ofm = self.param_shape(conv_out)[1] 
+                    weights_vol = kernel_size * n_ifm * n_ofm
                     op['attrs']['footprint'] = ofm_vol + ifm_vol + weights_vol
                     op['attrs']['fm_vol'] = ofm_vol + ifm_vol
                     op['attrs']['weights_vol'] = weights_vol
diff --git a/tests/test_summarygraph.py b/tests/test_summarygraph.py
index 8cf1e7dd2632f3c238d95de87e95afa4e5c454e3..46fedb3b6a99734b46d0bc801247e93502a1016d 100755
--- a/tests/test_summarygraph.py
+++ b/tests/test_summarygraph.py
@@ -211,6 +211,24 @@ def test_sg_macs():
             assert summary_macs == sg_macs
  
 
+def test_weights_size_attr():
+    def test(dataset, arch, dataparallel:bool):
+        model = create_model(False, dataset, arch, parallel=dataparallel)
+        sgraph = SummaryGraph(model, get_input(dataset))
+
+        distiller.assign_layer_fq_names(model)
+        for name, mod in model.named_modules():
+            if isinstance(mod, torch.nn.Conv2d) or isinstance(mod, torch.nn.Linear):
+                op = sgraph.find_op(name)
+                assert op is not None
+                assert op['attrs']['weights_vol'] == distiller.volume(mod.weight)
+
+    for data_parallel in (True, False):
+        test('cifar10', 'resnet20_cifar', data_parallel)
+        test('imagenet', 'alexnet', data_parallel)
+        test('imagenet', 'resnext101_32x4d', data_parallel)
+
+
 if __name__ == '__main__':
     #test_connectivity_summary()
     test_sg_macs()
\ No newline at end of file