diff --git a/distiller/model_summaries.py b/distiller/model_summaries.py
index cf9b124929917e76e2361e759bd71e8639e223be..7ad42de3c42d75fec2b568fc427026cfefde88ce 100755
--- a/distiller/model_summaries.py
+++ b/distiller/model_summaries.py
@@ -418,8 +418,8 @@ def draw_model_to_file(sgraph, png_fname, display_param_nodes=False, rankdir='TB
         fid.write(png)
 
 
-def draw_img_classifier_to_file(model, png_fname, dataset, display_param_nodes=False,
-                                rankdir='TB', styles=None):
+def draw_img_classifier_to_file(model, png_fname, dataset=None, display_param_nodes=False,
+                                rankdir='TB', styles=None, input_shape=None):
     """Draw a PyTorch image classifier to a PNG file.  This a helper function that
     simplifies the interface of draw_model_to_file().
 
@@ -436,8 +436,10 @@ def draw_img_classifier_to_file(model, png_fname, dataset, display_param_nodes=F
                 styles['conv1'] = {'shape': 'oval',
                                    'fillcolor': 'gray',
                                    'style': 'rounded, filled'}
+        input_shape (tuple): List of integers representing the input shape. Used only if 'dataset' is None
     """
-    dummy_input = distiller.get_dummy_input(dataset, distiller.model_device(model))
+    dummy_input = distiller.get_dummy_input(dataset=dataset,
+                                            device=distiller.model_device(model), input_shape=input_shape)
     try:
         non_para_model = distiller.make_non_parallel_copy(model)
         g = SummaryGraph(non_para_model, dummy_input)