diff --git a/distiller/model_summaries.py b/distiller/model_summaries.py index cf9b124929917e76e2361e759bd71e8639e223be..7ad42de3c42d75fec2b568fc427026cfefde88ce 100755 --- a/distiller/model_summaries.py +++ b/distiller/model_summaries.py @@ -418,8 +418,8 @@ def draw_model_to_file(sgraph, png_fname, display_param_nodes=False, rankdir='TB fid.write(png) -def draw_img_classifier_to_file(model, png_fname, dataset, display_param_nodes=False, - rankdir='TB', styles=None): +def draw_img_classifier_to_file(model, png_fname, dataset=None, display_param_nodes=False, + rankdir='TB', styles=None, input_shape=None): """Draw a PyTorch image classifier to a PNG file. This a helper function that simplifies the interface of draw_model_to_file(). @@ -436,8 +436,10 @@ def draw_img_classifier_to_file(model, png_fname, dataset, display_param_nodes=F styles['conv1'] = {'shape': 'oval', 'fillcolor': 'gray', 'style': 'rounded, filled'} + input_shape (tuple): List of integers representing the input shape. Used only if 'dataset' is None """ - dummy_input = distiller.get_dummy_input(dataset, distiller.model_device(model)) + dummy_input = distiller.get_dummy_input(dataset=dataset, + device=distiller.model_device(model), input_shape=input_shape) try: non_para_model = distiller.make_non_parallel_copy(model) g = SummaryGraph(non_para_model, dummy_input)