From 671038d5489a8b6c02019498518e5d9c12bfa430 Mon Sep 17 00:00:00 2001 From: Neta Zmora <neta.zmora@intel.com> Date: Wed, 25 Apr 2018 18:47:17 +0300 Subject: [PATCH] Add some protection around the experimental features not supported in PyTorch 3.1 --- apputils/model_summaries.py | 32 +++++++++++++++++++++----------- jupyter/experimental.ipynb | 6 +++++- 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/apputils/model_summaries.py b/apputils/model_summaries.py index f7beec7..b1c04a7 100755 --- a/apputils/model_summaries.py +++ b/apputils/model_summaries.py @@ -257,17 +257,27 @@ def draw_model_to_file(sgraph, png_fname): fid.write(png) def draw_img_classifier_to_file(model, png_fname, dataset): - if dataset == 'imagenet': - dummy_input = Variable(torch.randn(1, 3, 224, 224), requires_grad=False) - elif dataset == 'cifar10': - dummy_input = Variable(torch.randn(1, 3, 32, 32)) - else: - print("Unsupported dataset (%s) - aborting draw operation" % dataset) - return - - g = SummaryGraph(model, dummy_input) - draw_model_to_file(g, png_fname) - + try: + if dataset == 'imagenet': + dummy_input = Variable(torch.randn(1, 3, 224, 224), requires_grad=False) + elif dataset == 'cifar10': + dummy_input = Variable(torch.randn(1, 3, 32, 32)) + else: + print("Unsupported dataset (%s) - aborting draw operation" % dataset) + return + + g = SummaryGraph(model, dummy_input) + draw_model_to_file(g, png_fname) + print("Network PNG image generation completed") + except TypeError as e: + print("An error has occured while generating the network PNG image.") + print("This feature is not supported on official PyTorch releases.") + print("Please check that you are using a valid PyTorch version.") + print("You are using pytorch version %s" %torch.__version__) + except FileNotFoundError: + print("An error has occured while generating the network PNG image.") + print("Please check that you have graphviz installed.") + print("\t$ sudo apt-get install graphviz") def create_png(sgraph): """Create a PNG object containing a graphiz-dot graph of the netowrk represented diff --git a/jupyter/experimental.ipynb b/jupyter/experimental.ipynb index 098ada8..d73e876 100644 --- a/jupyter/experimental.ipynb +++ b/jupyter/experimental.ipynb @@ -9,7 +9,11 @@ "<br>\n", "<font size=\"6\" color=\"red\"> ⚠ WARNING </font>\n", "<br>\n", - "<font size=\"4\" color=\"red\">This part of the notebook works correctly only on some advanced PyTorch versions (e.g. 0.4.0a0+410fd58), therefore is may not run correctly for you.</font>\n", + "<font size=\"4\" color=\"red\">This part of the notebook works correctly only on some advanced PyTorch versions (e.g. 0.4.0a0+410fd58), therefore is may not run correctly for you.</font><br><br>\n", + "Please also note that for generating a PNG image of the network (last cell of the notebook), you will need to have graphviz installed:\n", + " ```\n", + " $ sudo apt-get install graphviz\n", + " ```\n", "<br>" ] }, -- GitLab