From aa8862bdb49b619ee8d547132512e575e1089b79 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Mon, 27 Aug 2018 18:31:34 +0300
Subject: [PATCH] Fix PyTorch 0.4 compatability issue

Sometimes the gmin/gmax in group color-normalization ends up with a zero
dimensional tensor, which needs to be accessed using .item()
---
 jupyter/distiller_jupyter_helpers.ipynb | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/jupyter/distiller_jupyter_helpers.ipynb b/jupyter/distiller_jupyter_helpers.ipynb
index c3d45c0..2c9252f 100755
--- a/jupyter/distiller_jupyter_helpers.ipynb
+++ b/jupyter/distiller_jupyter_helpers.ipynb
@@ -113,7 +113,7 @@
     "            weights2d = weights.clone()\n",
     "        else:\n",
     "            weights2d = weights\n",
-    "\n",
+    " \n",
     "        if weights.dim() == 4:\n",
     "            weights2d = weights2d.view(weights.size()[0] * weights.size()[1], -1)\n",
     "\n",
@@ -127,6 +127,9 @@
     "                    \n",
     "        fig = plt.figure(figsize=figsize)\n",
     "        if (not binary_mask) and (gmin is not None) and (gmax is not None):\n",
+    "            if isinstance(gmin, torch.Tensor):\n",
+    "                gmin = gmin.item()\n",
+    "                gmax = gmax.item()\n",
     "            plt.imshow(weights2d, cmap=cmap, vmin=gmin, vmax=gmax)\n",
     "        else:\n",
     "            plt.imshow(weights2d, cmap=cmap, vmin=0, vmax=1)\n",
@@ -169,7 +172,7 @@
     "    # Plot the graph\n",
     "    plt.gray()\n",
     "    #plt.tight_layout()\n",
-    "    fig = plt.figure( figsize=(layout[0]*size_ctrl, layout[1]*size_ctrl) )\n",
+    "    fig = plt.figure( figsize=(layout[0]*size_ctrl, layout[1]*size_ctrl) );\n",
     "\n",
     "    # We want to normalize the grayscale brightness levels for all of the images we display (group),\n",
     "    # otherwise, each image is normalized separately and this causes distortion between the different\n",
@@ -184,6 +187,9 @@
     "        gmin = weights[0:nrow, 0:ncol].min()\n",
     "        gmax = weights[0:nrow, 0:ncol].max()\n",
     "    print(\"gmin=%.4f\\tgmax=%.4f\" % (gmin, gmax))\n",
+    "    if isinstance(gmin, torch.Tensor):\n",
+    "        gmin = gmin.item()\n",
+    "        gmax = gmax.item()\n",
     "    \n",
     "    i = 0 \n",
     "    for row in range(0, nrow):\n",
@@ -197,8 +203,6 @@
     "                ax.matshow(kernels[first_kernel+i], cmap='seismic', vmin=gmin, vmax=gmax, interpolation=interpolation);\n",
     "            ax.set(xticks=[], yticks=[])\n",
     "            i += 1\n",
-    "    #plt.show();\n",
-    "    #return fig\n",
     "    \n",
     "    \n",
     "def l1_norm_histogram(weights):\n",
-- 
GitLab