From aa8862bdb49b619ee8d547132512e575e1089b79 Mon Sep 17 00:00:00 2001 From: Neta Zmora <neta.zmora@intel.com> Date: Mon, 27 Aug 2018 18:31:34 +0300 Subject: [PATCH] Fix PyTorch 0.4 compatability issue Sometimes the gmin/gmax in group color-normalization ends up with a zero dimensional tensor, which needs to be accessed using .item() --- jupyter/distiller_jupyter_helpers.ipynb | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/jupyter/distiller_jupyter_helpers.ipynb b/jupyter/distiller_jupyter_helpers.ipynb index c3d45c0..2c9252f 100755 --- a/jupyter/distiller_jupyter_helpers.ipynb +++ b/jupyter/distiller_jupyter_helpers.ipynb @@ -113,7 +113,7 @@ " weights2d = weights.clone()\n", " else:\n", " weights2d = weights\n", - "\n", + " \n", " if weights.dim() == 4:\n", " weights2d = weights2d.view(weights.size()[0] * weights.size()[1], -1)\n", "\n", @@ -127,6 +127,9 @@ " \n", " fig = plt.figure(figsize=figsize)\n", " if (not binary_mask) and (gmin is not None) and (gmax is not None):\n", + " if isinstance(gmin, torch.Tensor):\n", + " gmin = gmin.item()\n", + " gmax = gmax.item()\n", " plt.imshow(weights2d, cmap=cmap, vmin=gmin, vmax=gmax)\n", " else:\n", " plt.imshow(weights2d, cmap=cmap, vmin=0, vmax=1)\n", @@ -169,7 +172,7 @@ " # Plot the graph\n", " plt.gray()\n", " #plt.tight_layout()\n", - " fig = plt.figure( figsize=(layout[0]*size_ctrl, layout[1]*size_ctrl) )\n", + " fig = plt.figure( figsize=(layout[0]*size_ctrl, layout[1]*size_ctrl) );\n", "\n", " # We want to normalize the grayscale brightness levels for all of the images we display (group),\n", " # otherwise, each image is normalized separately and this causes distortion between the different\n", @@ -184,6 +187,9 @@ " gmin = weights[0:nrow, 0:ncol].min()\n", " gmax = weights[0:nrow, 0:ncol].max()\n", " print(\"gmin=%.4f\\tgmax=%.4f\" % (gmin, gmax))\n", + " if isinstance(gmin, torch.Tensor):\n", + " gmin = gmin.item()\n", + " gmax = gmax.item()\n", " \n", " i = 0 \n", " for row in range(0, nrow):\n", @@ -197,8 +203,6 @@ " ax.matshow(kernels[first_kernel+i], cmap='seismic', vmin=gmin, vmax=gmax, interpolation=interpolation);\n", " ax.set(xticks=[], yticks=[])\n", " i += 1\n", - " #plt.show();\n", - " #return fig\n", " \n", " \n", "def l1_norm_histogram(weights):\n", -- GitLab