diff --git a/distiller/model_summaries.py b/distiller/model_summaries.py
index 5a89685e22b7ff74fa3840fe1f7afab7ac5b07f2..cf9b124929917e76e2361e759bd71e8639e223be 100755
--- a/distiller/model_summaries.py
+++ b/distiller/model_summaries.py
@@ -51,7 +51,7 @@ def model_summary(model, what, dataset=None):
         distiller.log_weights_sparsity(model, -1, loggers=[pylogger, csvlogger])
     elif what == 'compute':
         try:
-            dummy_input = dataset_dummy_input(dataset)
+            dummy_input = distiller.get_dummy_input(dataset, distiller.model_device(model))
         except ValueError as e:
             print(e)
             return
@@ -319,7 +319,7 @@ def connectivity_tbl_summary(sgraph, verbose=False):
     return tabulate(df, headers='keys', tablefmt='psql')
 
 
-def create_pydot_graph(op_nodes, data_nodes, param_nodes, edges, rankdir='TB', styles=None):
+def create_pydot_graph(op_nodes_desc, data_nodes, param_nodes, edges, rankdir='TB', styles=None):
     """Low-level API to create a PyDot graph (dot formatted).
     """
     pydot_graph = pydot.Dot('Net', graph_type='digraph', rankdir=rankdir)
@@ -328,7 +328,7 @@ def create_pydot_graph(op_nodes, data_nodes, param_nodes, edges, rankdir='TB', s
                      'fillcolor': '#6495ED',
                      'style': 'rounded, filled'}
 
-    for op_node in op_nodes:
+    for op_node in op_nodes_desc:
         style = op_node_style
         # Check if we should override the style of this node.
         if styles is not None and op_node[0] in styles:
@@ -368,6 +368,12 @@ def create_png(sgraph, display_param_nodes=False, rankdir='TB', styles=None):
                                    'style': 'rounded, filled'}
     """
 
+    def annotate_op_node(op):
+        if op['type'] == 'Conv':
+            return ["sh={}".format(distiller.size2str(op['attrs']['kernel_shape'])),
+                    "g={}".format(str(op['attrs']['group']))]
+        return ''   
+
     op_nodes = [op['name'] for op in sgraph.ops.values()]
     data_nodes = []
     param_nodes = []
@@ -385,8 +391,8 @@ def create_png(sgraph, display_param_nodes=False, rankdir='TB', styles=None):
         edges = [edge for edge in sgraph.edges if edge.src in non_param_ids]
         param_nodes = None
 
-    op_nodes = [(op['name'], op['type']) for op in sgraph.ops.values()]
-    pydot_graph = create_pydot_graph(op_nodes, data_nodes, param_nodes, edges, rankdir, styles)
+    op_nodes_desc = [(op['name'], op['type'], *annotate_op_node(op)) for op in sgraph.ops.values()]
+    pydot_graph = create_pydot_graph(op_nodes_desc, data_nodes, param_nodes, edges, rankdir, styles)
     png = pydot_graph.create_png()
     return png
 
@@ -431,7 +437,7 @@ def draw_img_classifier_to_file(model, png_fname, dataset, display_param_nodes=F
                                    'fillcolor': 'gray',
                                    'style': 'rounded, filled'}
     """
-    dummy_input = dataset_dummy_input(dataset)
+    dummy_input = distiller.get_dummy_input(dataset, distiller.model_device(model))
     try:
         non_para_model = distiller.make_non_parallel_copy(model)
         g = SummaryGraph(non_para_model, dummy_input)
@@ -446,16 +452,6 @@ def draw_img_classifier_to_file(model, png_fname, dataset, display_param_nodes=F
         del non_para_model
 
 
-def dataset_dummy_input(dataset):
-    if dataset == 'imagenet':
-        dummy_input = torch.randn(1, 3, 224, 224)
-    elif dataset == 'cifar10':
-        dummy_input = torch.randn(1, 3, 32, 32)
-    else:
-        raise ValueError("Unsupported dataset (%s) - aborting operation" % dataset)
-    return dummy_input
-
-
 def export_img_classifier_to_onnx(model, onnx_fname, dataset, add_softmax=True, **kwargs):
     """Export a PyTorch image classifier to ONNX.
 
@@ -463,7 +459,7 @@ def export_img_classifier_to_onnx(model, onnx_fname, dataset, add_softmax=True,
         add_softmax: when True, adds softmax layer to the output model.
         kwargs: arguments to be passed to torch.onnx.export
     """
-    dummy_input = dataset_dummy_input(dataset).to('cuda')
+    dummy_input = distiller.get_dummy_input(dataset, distiller.model_device(model))
     # Pytorch doesn't support exporting modules wrapped in DataParallel
     non_para_model = distiller.make_non_parallel_copy(model)
 
diff --git a/distiller/utils.py b/distiller/utils.py
index 580fa518ccb1b6ed5880fc99dbe113d148c2a2e4..f0b24a57afc72ed8b7ea0e7be0307ec1c362b2d2 100755
--- a/distiller/utils.py
+++ b/distiller/utils.py
@@ -60,12 +60,14 @@ def size2str(torch_size):
         return size_to_str(torch_size.size())
     if isinstance(torch_size, torch.autograd.Variable):
         return size_to_str(torch_size.data.size())
+    if isinstance(torch_size, tuple) or isinstance(torch_size, list):
+        return size_to_str(torch_size)
     raise TypeError
 
 
 def size_to_str(torch_size):
     """Convert a pytorch Size object to a string"""
-    assert isinstance(torch_size, torch.Size)
+    assert isinstance(torch_size, torch.Size) or isinstance(torch_size, tuple) or isinstance(torch_size, list)
     return '('+(', ').join(['%d' % v for v in torch_size])+')'