From 093a2e17be29de8dc1c0197a1a1b3890cf67ac48 Mon Sep 17 00:00:00 2001
From: Lev Zlotnik <46742999+levzlotnik@users.noreply.github.com>
Date: Wed, 19 Jun 2019 13:49:43 +0300
Subject: [PATCH] Model drawing: allow rendering models w/o specifying the
 dataset (#294)

This commit allows us to draw diagrams of models, even if we don't support the specific dataset used by the model.  All you need is to specify the dimensions of the model inputs.
* Added `input_shape` argument to `draw_img_classifier_to_file`
* updated docstring
---
 distiller/model_summaries.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/distiller/model_summaries.py b/distiller/model_summaries.py
index cf9b124..7ad42de 100755
--- a/distiller/model_summaries.py
+++ b/distiller/model_summaries.py
@@ -418,8 +418,8 @@ def draw_model_to_file(sgraph, png_fname, display_param_nodes=False, rankdir='TB
         fid.write(png)
 
 
-def draw_img_classifier_to_file(model, png_fname, dataset, display_param_nodes=False,
-                                rankdir='TB', styles=None):
+def draw_img_classifier_to_file(model, png_fname, dataset=None, display_param_nodes=False,
+                                rankdir='TB', styles=None, input_shape=None):
     """Draw a PyTorch image classifier to a PNG file.  This a helper function that
     simplifies the interface of draw_model_to_file().
 
@@ -436,8 +436,10 @@ def draw_img_classifier_to_file(model, png_fname, dataset, display_param_nodes=F
                 styles['conv1'] = {'shape': 'oval',
                                    'fillcolor': 'gray',
                                    'style': 'rounded, filled'}
+        input_shape (tuple): List of integers representing the input shape. Used only if 'dataset' is None
     """
-    dummy_input = distiller.get_dummy_input(dataset, distiller.model_device(model))
+    dummy_input = distiller.get_dummy_input(dataset=dataset,
+                                            device=distiller.model_device(model), input_shape=input_shape)
     try:
         non_para_model = distiller.make_non_parallel_copy(model)
         g = SummaryGraph(non_para_model, dummy_input)
-- 
GitLab