diff --git a/apputils/model_summaries.py b/apputils/model_summaries.py index f7beec7b429a455c8332686e9e424ec97357fe14..b1c04a7a8f67d1dec4bb0460986d2d20c0ad8ca8 100755 --- a/apputils/model_summaries.py +++ b/apputils/model_summaries.py @@ -257,17 +257,27 @@ def draw_model_to_file(sgraph, png_fname): fid.write(png) def draw_img_classifier_to_file(model, png_fname, dataset): - if dataset == 'imagenet': - dummy_input = Variable(torch.randn(1, 3, 224, 224), requires_grad=False) - elif dataset == 'cifar10': - dummy_input = Variable(torch.randn(1, 3, 32, 32)) - else: - print("Unsupported dataset (%s) - aborting draw operation" % dataset) - return - - g = SummaryGraph(model, dummy_input) - draw_model_to_file(g, png_fname) - + try: + if dataset == 'imagenet': + dummy_input = Variable(torch.randn(1, 3, 224, 224), requires_grad=False) + elif dataset == 'cifar10': + dummy_input = Variable(torch.randn(1, 3, 32, 32)) + else: + print("Unsupported dataset (%s) - aborting draw operation" % dataset) + return + + g = SummaryGraph(model, dummy_input) + draw_model_to_file(g, png_fname) + print("Network PNG image generation completed") + except TypeError as e: + print("An error has occured while generating the network PNG image.") + print("This feature is not supported on official PyTorch releases.") + print("Please check that you are using a valid PyTorch version.") + print("You are using pytorch version %s" %torch.__version__) + except FileNotFoundError: + print("An error has occured while generating the network PNG image.") + print("Please check that you have graphviz installed.") + print("\t$ sudo apt-get install graphviz") def create_png(sgraph): """Create a PNG object containing a graphiz-dot graph of the netowrk represented diff --git a/jupyter/experimental.ipynb b/jupyter/experimental.ipynb index 098ada8acb978af09c71ccc5d0b263de5a540ff3..d73e8768a8beb634e090993ff78813620ae5d03b 100644 --- a/jupyter/experimental.ipynb +++ b/jupyter/experimental.ipynb @@ -9,7 +9,11 @@ "<br>\n", "<font size=\"6\" color=\"red\"> ⚠ WARNING </font>\n", "<br>\n", - "<font size=\"4\" color=\"red\">This part of the notebook works correctly only on some advanced PyTorch versions (e.g. 0.4.0a0+410fd58), therefore is may not run correctly for you.</font>\n", + "<font size=\"4\" color=\"red\">This part of the notebook works correctly only on some advanced PyTorch versions (e.g. 0.4.0a0+410fd58), therefore is may not run correctly for you.</font><br><br>\n", + "Please also note that for generating a PNG image of the network (last cell of the notebook), you will need to have graphviz installed:\n", + " ```\n", + " $ sudo apt-get install graphviz\n", + " ```\n", "<br>" ] },