diff --git a/distiller/summary_graph.py b/distiller/summary_graph.py
index d3facd54e9ff58360b9f89b57c69872bc29c0ff6..d246fd58dc13ea4c6e2f10399e015a719f830890 100755
--- a/distiller/summary_graph.py
+++ b/distiller/summary_graph.py
@@ -21,7 +21,6 @@ import collections
 import torch
 import torch.jit as jit
 import logging
-from collections import OrderedDict
 msglogger = logging.getLogger()
 
 
@@ -100,17 +99,17 @@ class SummaryGraph(object):
             
             device = next(model_clone.parameters()).device
             dummy_input = distiller.convert_tensors_recursively_to(dummy_input, device=device)
-            trace, _ = jit.get_trace_graph(model_clone, dummy_input, _force_outplace=True)
+            trace, _ = jit.get_trace_graph(model_clone, dummy_input)
 
             # Let ONNX do the heavy lifting: fusing the convolution nodes; fusing the nodes
             # composing a GEMM operation; etc.
             torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX)
 
             graph = trace.graph()
-            self.ops = OrderedDict()
-            self.params = OrderedDict()
+            self.ops = {}
+            self.params = {}
             self.edges = []
-            self.temp = OrderedDict()
+            self.temp = {}
 
             in_out = list(graph.inputs()) + list(graph.outputs())
             for param in in_out:
@@ -149,7 +148,7 @@ class SummaryGraph(object):
                     self.__add_output(new_op, output)
                     self.edges.append(SummaryGraph.Edge(new_op['name'], output.uniqueName()))
 
-                new_op['attrs'] = OrderedDict([(attr_name, node[attr_name]) for attr_name in node.attributeNames()])
+                new_op['attrs'] = {attr_name: node[attr_name] for attr_name in node.attributeNames()}
 
         self.add_macs_attr()
         self.add_footprint_attr()
@@ -157,7 +156,7 @@ class SummaryGraph(object):
         del model_clone
 
     def __create_op(self, onnx_node):
-        op = OrderedDict()
+        op = {}
         op['name'] = onnx_node.scopeName()
         op['orig-name'] = onnx_node.scopeName()
         op['type'] = onnx_node.kind().lstrip('::onnx')
@@ -189,7 +188,7 @@ class SummaryGraph(object):
         return param
 
     def __tensor_desc(self, n):
-        tensor = OrderedDict()
+        tensor = {}
         tensor['id'] = n.uniqueName()
         try:
             # try parsing the FM tensor type.  For example: Float(1, 64, 8, 8)
@@ -246,8 +245,16 @@ class SummaryGraph(object):
                 ofm_vol = self.param_volume(conv_out)
                 ifm_vol = self.param_volume(conv_in)
                 if op['type'] == 'Conv' or op['type'] == 'Gemm':
-                    conv_w = op['inputs'][1]
-                    weights_vol = self.param_volume(conv_w)
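+                    # Derive the weights volume from the op's shape attributes:
+                    # kernel volume * (input channels / groups) * output channels (bias excluded).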
+                    if op['type'] == 'Conv':
+                        kernel_size = self.volume(op['attrs']['kernel_shape'])
+                        group = op['attrs']['group']
+                    else:
+                        kernel_size, group = 1, 1
+                    n_ifm = self.param_shape(conv_in)[1] // group
+                    n_ofm = self.param_shape(conv_out)[1]
+                    weights_vol = kernel_size * n_ifm * n_ofm
                     op['attrs']['footprint'] = ofm_vol + ifm_vol + weights_vol
                     op['attrs']['fm_vol'] = ofm_vol + ifm_vol
                     op['attrs']['weights_vol'] = weights_vol
diff --git a/tests/test_summarygraph.py b/tests/test_summarygraph.py
index 8cf1e7dd2632f3c238d95de87e95afa4e5c454e3..f74bfdaafab546091b58a44fb39e54cbc7c43494 100755
--- a/tests/test_summarygraph.py
+++ b/tests/test_summarygraph.py
@@ -116,6 +116,25 @@ def test_layer_search():
     assert preds == ['layer1.0.conv2', 'conv1']
 
 
+def test_weights_size_attr():
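+    """SummaryGraph's 'weights_vol' attribute must equal the volume of each layer's weight tensor."""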
+    def test(dataset, arch, dataparallel: bool):
+        model = create_model(False, dataset, arch, parallel=dataparallel)
+        sgraph = SummaryGraph(model, get_input(dataset))
+
+        distiller.assign_layer_fq_names(model)
+        for name, mod in model.named_modules():
+            if isinstance(mod, (torch.nn.Conv2d, torch.nn.Linear)):
+                op = sgraph.find_op(name)
+                assert op is not None
+                assert op['attrs']['weights_vol'] == distiller.volume(mod.weight)
+
+    for data_parallel in (True, False):
+        test('cifar10', 'resnet20_cifar', data_parallel)
+        test('imagenet', 'alexnet', data_parallel)
+        test('imagenet', 'resnext101_32x4d', data_parallel)
+
+
 def test_vgg():
     g = create_graph('imagenet', 'vgg19')
     assert g is not None
@@ -170,11 +189,11 @@ def named_params_layers_test_aux(dataset, arch, dataparallel:bool):
 
 
 def test_named_params_layers():
-    for dataParallelModel in (True, False):
-        named_params_layers_test_aux('imagenet', 'vgg19', dataParallelModel)
-        named_params_layers_test_aux('cifar10', 'resnet20_cifar', dataParallelModel)
-        named_params_layers_test_aux('imagenet', 'alexnet', dataParallelModel)
-        named_params_layers_test_aux('imagenet', 'resnext101_32x4d', dataParallelModel)
+    for data_parallel in (True, False):
+        named_params_layers_test_aux('imagenet', 'vgg19', data_parallel)
+        named_params_layers_test_aux('cifar10', 'resnet20_cifar', data_parallel)
+        named_params_layers_test_aux('imagenet', 'alexnet', data_parallel)
+        named_params_layers_test_aux('imagenet', 'resnext101_32x4d', data_parallel)
 
 
 def test_onnx_name_2_pytorch_name():
@@ -213,4 +232,4 @@ def test_sg_macs():
 
 if __name__ == '__main__':
     #test_connectivity_summary()
-    test_sg_macs()
\ No newline at end of file
+    test_sg_macs()