Skip to content
Snippets Groups Projects
Commit 093a2e17 authored by Lev Zlotnik's avatar Lev Zlotnik Committed by Neta Zmora
Browse files

Model drawing: allow rendering models w/o specifying the dataset (#294)

This commit allows us to draw diagrams of models, even if we don't support the specific dataset used by the model.  All you need is to specify the dimensions of the model inputs.
* Added `input_shape` argument to `draw_img_classifier_to_file`
* updated docstring
parent 2cab7741
No related branches found
No related tags found
No related merge requests found
......@@ -418,8 +418,8 @@ def draw_model_to_file(sgraph, png_fname, display_param_nodes=False, rankdir='TB
fid.write(png)
def draw_img_classifier_to_file(model, png_fname, dataset, display_param_nodes=False,
rankdir='TB', styles=None):
def draw_img_classifier_to_file(model, png_fname, dataset=None, display_param_nodes=False,
rankdir='TB', styles=None, input_shape=None):
"""Draw a PyTorch image classifier to a PNG file. This a helper function that
simplifies the interface of draw_model_to_file().
......@@ -436,8 +436,10 @@ def draw_img_classifier_to_file(model, png_fname, dataset, display_param_nodes=F
styles['conv1'] = {'shape': 'oval',
'fillcolor': 'gray',
'style': 'rounded, filled'}
input_shape (tuple): List of integers representing the input shape. Used only if 'dataset' is None
"""
dummy_input = distiller.get_dummy_input(dataset, distiller.model_device(model))
dummy_input = distiller.get_dummy_input(dataset=dataset,
device=distiller.model_device(model), input_shape=input_shape)
try:
non_para_model = distiller.make_non_parallel_copy(model)
g = SummaryGraph(non_para_model, dummy_input)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment