From 671038d5489a8b6c02019498518e5d9c12bfa430 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Wed, 25 Apr 2018 18:47:17 +0300
Subject: [PATCH] Add some protection around the experimental features not
 supported in PyTorch 3.1

---
 apputils/model_summaries.py | 32 +++++++++++++++++++++-----------
 jupyter/experimental.ipynb  |  6 +++++-
 2 files changed, 26 insertions(+), 12 deletions(-)

diff --git a/apputils/model_summaries.py b/apputils/model_summaries.py
index f7beec7..b1c04a7 100755
--- a/apputils/model_summaries.py
+++ b/apputils/model_summaries.py
@@ -257,17 +257,27 @@ def draw_model_to_file(sgraph, png_fname):
         fid.write(png)
 
 def draw_img_classifier_to_file(model, png_fname, dataset):
-    if dataset == 'imagenet':
-        dummy_input = Variable(torch.randn(1, 3, 224, 224), requires_grad=False)
-    elif dataset == 'cifar10':
-        dummy_input = Variable(torch.randn(1, 3, 32, 32))
-    else:
-        print("Unsupported dataset (%s) - aborting draw operation" % dataset)
-        return
-
-    g = SummaryGraph(model, dummy_input)
-    draw_model_to_file(g, png_fname)
-
+    try:
+        if dataset == 'imagenet':
+            dummy_input = Variable(torch.randn(1, 3, 224, 224), requires_grad=False)
+        elif dataset == 'cifar10':
+            dummy_input = Variable(torch.randn(1, 3, 32, 32))
+        else:
+            print("Unsupported dataset (%s) - aborting draw operation" % dataset)
+            return
+
+        g = SummaryGraph(model, dummy_input)
+        draw_model_to_file(g, png_fname)
+        print("Network PNG image generation completed")
+    except TypeError as e:
+        print("An error has occured while generating the network PNG image.")
+        print("This feature is not supported on official PyTorch releases.")
+        print("Please check that you are using a valid PyTorch version.")
+        print("You are using pytorch version %s" %torch.__version__)
+    except FileNotFoundError:
+        print("An error has occured while generating the network PNG image.")
+        print("Please check that you have graphviz installed.")
+        print("\t$ sudo apt-get install graphviz")
 
 def create_png(sgraph):
     """Create a PNG object containing a graphiz-dot graph of the netowrk represented
diff --git a/jupyter/experimental.ipynb b/jupyter/experimental.ipynb
index 098ada8..d73e876 100644
--- a/jupyter/experimental.ipynb
+++ b/jupyter/experimental.ipynb
@@ -9,7 +9,11 @@
     "<br>\n",
     "<font size=\"6\" color=\"red\"> &#9888; WARNING </font>\n",
     "<br>\n",
-    "<font size=\"4\" color=\"red\">This part of the notebook works correctly only on some advanced PyTorch versions (e.g. 0.4.0a0+410fd58), therefore is may not run correctly for you.</font>\n",
+    "<font size=\"4\" color=\"red\">This part of the notebook works correctly only on some advanced PyTorch versions (e.g. 0.4.0a0+410fd58), therefore is may not run correctly for you.</font><br><br>\n",
+    "Please also note that for generating a PNG image of the network (last cell of the notebook), you will need to have graphviz installed:\n",
+    "   ```\n",
+    "   $ sudo apt-get install graphviz\n",
+    "   ```\n",
     "<br>"
    ]
   },
-- 
GitLab