{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Interpreting your pruning and regularization experiments\n",
    "This notebook contains helper code that you can include in your own notebooks by adding this line at the top of your notebook:<br>\n",
    "```%run distiller_jupyter_helpers.ipynb```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relative import of code from distiller, w/o installing the package\n",
    "import os\n",
    "import sys\n",
    "module_path = os.path.abspath(os.path.join('..'))\n",
    "if module_path not in sys.path:\n",
    "    sys.path.append(module_path)\n",
    "\n",
    "import distiller.utils\n",
    "import distiller\n",
    "import apputils.checkpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "import os\n",
    "import collections\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import scipy.stats as stats"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Tensor helpers and parameter histograms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def to_np(x):\n",
    "    \"\"\"Return the contents of tensor `x` as a NumPy array.\n",
    "\n",
    "    The tensor is detached and moved to the CPU first, so tensors that require\n",
    "    gradients or live on another device can be converted too.\n",
    "    \"\"\"\n",
    "    return x.detach().cpu().numpy()\n",
    "\n",
    "\n",
    "def flatten(weights):\n",
    "    \"\"\"Return a flat (1-D) NumPy copy of the tensor `weights`.\"\"\"\n",
    "    # Clone so that the returned array never shares memory with the caller's tensor\n",
    "    return to_np(weights.clone().reshape(-1))\n",
    "\n",
    "\n",
    "def plot_params_hist_single(name, weights_pytorch, remove_zeros=False, kmeans=None):\n",
    "    \"\"\"Plot a 200-bin histogram of a tensor's element values and print its statistics.\n",
    "\n",
    "    Args:\n",
    "        name: title of the plot.\n",
    "        weights_pytorch: tensor whose element values are histogrammed.\n",
    "        remove_zeros: if True, zero-valued elements (and the first near-zero\n",
    "            centroid, when `kmeans` is given) are left out of the plot.\n",
    "        kmeans: optional fitted clustering object exposing `labels_` and\n",
    "            `cluster_centers_`; its cluster populations are overlaid on the\n",
    "            histogram.  Assumes one scalar centroid per cluster - TODO confirm.\n",
    "    \"\"\"\n",
    "    weights = flatten(weights_pytorch)\n",
    "    if remove_zeros:\n",
    "        weights = weights[weights != 0]\n",
    "    plt.hist(weights, bins=200)\n",
    "    plt.title(name)\n",
    "\n",
    "    if kmeans is not None:\n",
    "        centroids = np.asarray(kmeans.cluster_centers_, dtype=float).ravel()\n",
    "        labels = np.asarray(kmeans.labels_)\n",
    "        # Population of each cluster; the number of clusters comes from the model itself\n",
    "        cnt_coefficients = np.array(\n",
    "            [np.count_nonzero(labels == i) for i in range(len(centroids))], dtype=float)\n",
    "        # Normalize the coefficients so they display in the same range as the float32 histogram\n",
    "        cnt_coefficients /= 5\n",
    "        # Sort the clusters by centroid value so that the line plot runs left to right\n",
    "        order = np.argsort(centroids)\n",
    "        centroids, cnt_coefficients = centroids[order], cnt_coefficients[order]\n",
    "        if remove_zeros:\n",
    "            almost_zero = np.flatnonzero(np.abs(centroids) < 0.0001)\n",
    "            if almost_zero.size > 0:\n",
    "                # Drop the same position from both sequences to keep them aligned\n",
    "                centroids = np.delete(centroids, almost_zero[0])\n",
    "                cnt_coefficients = np.delete(cnt_coefficients, almost_zero[0])\n",
    "\n",
    "        plt.plot(centroids, cnt_coefficients)\n",
    "        plt.plot(centroids, np.zeros_like(centroids), 'r+', markersize=15)\n",
    "\n",
    "    plt.show()\n",
    "    has_values = weights.size > 0  # False when remove_zeros left nothing to report\n",
    "    if has_values:\n",
    "        print('mean: %f\\nstddev: %f' % (weights.mean(), weights.std()))\n",
    "    print('size=%s = %d elements' % (distiller.size2str(weights_pytorch.size()),\n",
    "                                     distiller.volume(weights_pytorch)))\n",
    "    if has_values:\n",
    "        print('min: %.3f\\nmax:%.3f' % (weights.min(), weights.max()))\n",
    "\n",
    "\n",
    "def plot_params_hist(params, which='weight', remove_zeros=False):\n",
    "    \"\"\"Plot a histogram for every tensor in `params` whose name contains `which`.\n",
    "\n",
    "    Args:\n",
    "        params: mapping of parameter name to tensor (anything with `.items()`).\n",
    "        which: substring that a parameter name must contain to be plotted.\n",
    "        remove_zeros: forwarded to `plot_params_hist_single`.\n",
    "    \"\"\"\n",
    "    for name, weights_pytorch in params.items():\n",
    "        if which in name:\n",
    "            plot_params_hist_single(name, weights_pytorch, remove_zeros=remove_zeros)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2D maps of weight tensors and kernel grids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _to_scalar(value):\n",
    "    \"\"\"Convert a single-element tensor to a Python number; pass anything else through.\"\"\"\n",
    "    return value.item() if isinstance(value, torch.Tensor) else value\n",
    "\n",
    "\n",
    "def plot_params2d(classifier_weights, figsize, binary_mask=True,\n",
    "                  gmin=None, gmax=None,\n",
    "                  xlabel='', ylabel='', title=''):\n",
    "    \"\"\"Display 2D or 4D weight tensors as images and print their sparsity.\n",
    "\n",
    "    A 4D tensor is displayed as a 2D matrix with one row per (dim 0, dim 1) pair.\n",
    "\n",
    "    Args:\n",
    "        classifier_weights: a tensor, or a list of tensors; one figure is drawn per tensor.\n",
    "        figsize: figure size passed to `plt.figure`.\n",
    "        binary_mask: if True, draw non-zero elements black and zeros white;\n",
    "            otherwise draw the element values using the 'seismic' colormap.\n",
    "        gmin, gmax: color limits (numbers or single-element tensors).  They are used\n",
    "            only when `binary_mask` is False and both are given; otherwise the\n",
    "            color range is 0..1.\n",
    "        xlabel, ylabel, title: text decorations of the figure.\n",
    "    \"\"\"\n",
    "    if not isinstance(classifier_weights, list):\n",
    "        classifier_weights = [classifier_weights]\n",
    "\n",
    "    # Each limit is converted on its own: either of them may or may not be a tensor\n",
    "    gmin, gmax = _to_scalar(gmin), _to_scalar(gmax)\n",
    "    use_color_limits = (not binary_mask) and (gmin is not None) and (gmax is not None)\n",
    "\n",
    "    for weights in classifier_weights:\n",
    "        assert weights.dim() in [2,4], \"something's wrong\"\n",
    "\n",
    "        shape_str = distiller.size2str(weights.size())\n",
    "        volume = distiller.volume(weights)\n",
    "\n",
    "        weights2d = weights\n",
    "        if weights.dim() == 4:\n",
    "            weights2d = weights2d.reshape(weights.size(0) * weights.size(1), -1)\n",
    "\n",
    "        sparsity = len(weights2d[weights2d==0]) / volume\n",
    "\n",
    "        image = to_np(weights2d)\n",
    "        cmap = 'seismic'\n",
    "        # create a binary image (non-zero elements are black; zeros are white)\n",
    "        if binary_mask:\n",
    "            cmap = 'binary'\n",
    "            # Builds a new array, so the caller's tensor is never modified\n",
    "            image = (image != 0).astype(np.float32)\n",
    "\n",
    "        fig = plt.figure(figsize=figsize)\n",
    "        if use_color_limits:\n",
    "            plt.imshow(image, cmap=cmap, vmin=gmin, vmax=gmax)\n",
    "        else:\n",
    "            plt.imshow(image, cmap=cmap, vmin=0, vmax=1)\n",
    "\n",
    "        plt.xlabel(xlabel)\n",
    "        plt.ylabel(ylabel)\n",
    "        plt.title(title)\n",
    "        plt.colorbar(pad=0.01, fraction=0.01)\n",
    "        plt.show()\n",
    "        print('sparsity = %.1f%% (nnz=black)' % (sparsity*100))\n",
    "        print('size=%s = %d elements' % (shape_str, volume))\n",
    "\n",
    "\n",
    "def printk(k):\n",
    "    \"\"\"Print the values of the elements of a kernel as a list\"\"\"\n",
    "    print(k.reshape(-1).tolist())\n",
    "\n",
    "\n",
    "def plot_param_kernels(weights, layout, size_ctrl, binary_mask=False, color_normalization='Model',\n",
    "                       gmin=None, gmax=None, interpolation=None, first_kernel=0):\n",
    "    \"\"\"Display a grid of the 2D kernels of a 4D weights tensor.\n",
    "\n",
    "    The tensor is viewed as a flat sequence of size(0)*size(1) kernels, each of\n",
    "    size(2) rows by size(3) columns.  Kernels `first_kernel`, `first_kernel+1`, ...\n",
    "    fill the grid row by row; grid cells past the last kernel are left empty.\n",
    "\n",
    "    Args:\n",
    "        weights: 4D tensor.\n",
    "        layout: (number of rows, number of columns) of the grid.\n",
    "        size_ctrl: figure size per grid cell.\n",
    "        binary_mask: if True, draw non-zero elements black and zeros white.\n",
    "        color_normalization: 'Tensor' takes the color limits from the whole tensor,\n",
    "            'Group' takes them from the displayed kernels only; any other value\n",
    "            (including the default 'Model') uses the caller-supplied `gmin`/`gmax`,\n",
    "            and when those are None each kernel image is scaled separately.\n",
    "        gmin, gmax: color limits (numbers or single-element tensors).\n",
    "        interpolation: forwarded to `matshow` (not used with `binary_mask`).\n",
    "        first_kernel: index of the first kernel to display.\n",
    "    \"\"\"\n",
    "    ofms, ifms = weights.size(0), weights.size(1)\n",
    "    # Dim 2 holds the kernel rows (height) and dim 3 the kernel columns (width)\n",
    "    kh, kw = weights.size(2), weights.size(3)\n",
    "\n",
    "    print('min=%.4f\\tmax=%.4f' % (weights.min().item(), weights.max().item()))\n",
    "    shape_str = distiller.size2str(weights.size())\n",
    "    volume = distiller.volume(weights)\n",
    "    print('size=%s = %d elements' % (shape_str, volume))\n",
    "\n",
    "    # Clone because we are going to change the tensor values\n",
    "    weights = weights.clone()\n",
    "    if binary_mask:\n",
    "        weights[weights!=0] = 1\n",
    "\n",
    "    kernels = weights.reshape(ofms * ifms, kh, kw)\n",
    "    nrow, ncol = layout[0], layout[1]\n",
    "    # Index one past the last kernel that is actually displayed\n",
    "    last_kernel = min(first_kernel + nrow * ncol, kernels.size(0))\n",
    "\n",
    "    # Plot the graph\n",
    "    plt.gray()\n",
    "    # figsize is (width, height): the width follows the columns, the height follows the rows\n",
    "    fig = plt.figure(figsize=(ncol*size_ctrl, nrow*size_ctrl))\n",
    "\n",
    "    # We want to normalize the grayscale brightness levels for all of the images we display (group),\n",
    "    # otherwise, each image is normalized separately and this causes distortion between the different\n",
    "    # filters images we display.\n",
    "    # We don't normalize across all of the filters images, because the outliers cause the image of each\n",
    "    # filter to be very muted.  This is because each group of filters we display usually has low variance\n",
    "    # between the element values of that group.\n",
    "    if color_normalization == 'Tensor':\n",
    "        gmin, gmax = weights.min(), weights.max()\n",
    "    elif color_normalization == 'Group':\n",
    "        group = kernels[first_kernel:last_kernel]\n",
    "        if group.numel() > 0:\n",
    "            gmin, gmax = group.min(), group.max()\n",
    "\n",
    "    # Each limit is converted on its own: either of them may or may not be a tensor\n",
    "    gmin, gmax = _to_scalar(gmin), _to_scalar(gmax)\n",
    "    if (gmin is not None) and (gmax is not None):\n",
    "        print('gmin=%.4f\\tgmax=%.4f' % (gmin, gmax))\n",
    "\n",
    "    images = to_np(kernels)\n",
    "    for i in range(nrow * ncol):\n",
    "        kernel_idx = first_kernel + i\n",
    "        if kernel_idx >= len(images):\n",
    "            break\n",
    "        ax = fig.add_subplot(nrow, ncol, i+1)\n",
    "        if binary_mask:\n",
    "            ax.matshow(images[kernel_idx], cmap='binary', vmin=0, vmax=1)\n",
    "        else:\n",
    "            # Use seismic so that colors around the center are lighter.  Red and blue are used\n",
    "            # to represent (and visually separate) negative and positive weights\n",
    "            ax.matshow(images[kernel_idx], cmap='seismic', vmin=gmin, vmax=gmax,\n",
    "                       interpolation=interpolation)\n",
    "        ax.set(xticks=[], yticks=[])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Kernel L1-norms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def l1_norm_histogram(weights):\n",
    "    \"\"\"Compute the L1-norm of each of the kernels of a weights tensor.\n",
    "\n",
    "    The L1-norm of a kernel is one way to quantify the \"magnitude\" of the total coefficients\n",
    "    making up this kernel.\n",
    "\n",
    "    Another interesting look at filters is to compute a histogram per filter.\n",
    "\n",
    "    Returns:\n",
    "        A list with one single-element tensor per kernel, ordered by (dim 0, dim 1).\n",
    "    \"\"\"\n",
    "    ofms, ifms = weights.size(0), weights.size(1)\n",
    "    # One row per kernel: the L1-norm does not depend on the shape of the kernel\n",
    "    kernels = weights.reshape(ofms * ifms, -1)\n",
    "    return list(kernels.norm(1, dim=1))\n",
    "\n",
    "\n",
    "def plot_l1_norm_hist(weights):\n",
    "    \"\"\"Plot a 200-bin histogram of the L1-norms of the kernels of `weights`.\"\"\"\n",
    "    # Plain Python floats, so that matplotlib does not have to deal with tensors\n",
    "    l1_hist = [float(norm) for norm in l1_norm_histogram(weights)]\n",
    "    plt.hist(l1_hist, bins=200)\n",
    "    plt.title('Kernel L1-norm histograms')\n",
    "    plt.ylabel('Frequency')\n",
    "    plt.xlabel('Kernel L1-norm')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Bar charts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_bars(which, setA, setAName, setB, setBName, names, title, figsize=(20,10)):\n",
    "    \"\"\"Plot two series of values as overlaid bar charts.\n",
    "\n",
    "    The bars of `setB` are narrower and are drawn on top of the bars of `setA`.\n",
    "\n",
    "    Args:\n",
    "        which: not used; kept so that existing callers continue to work.\n",
    "        setA, setB: the two series of values; they must have the same length.\n",
    "        setAName, setBName: legend labels of the two series.\n",
    "        names: x-axis tick labels, one per bar.\n",
    "        title: title of the figure.\n",
    "        figsize: figure size passed to `plt.subplots` (None selects matplotlib's default).\n",
    "    \"\"\"\n",
    "    N = len(setA)\n",
    "    ind = np.arange(N)    # the x locations for the groups\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=figsize)\n",
    "    p1 = ax.bar(ind, setA, width=.47, color='#278DBC')\n",
    "    p2 = ax.bar(ind, setB, width=0.35, color='#000099')\n",
    "\n",
    "    ax.set_ylabel('Size')\n",
    "    ax.set_title(title)\n",
    "    # Set the labels and their rotation in one call, so that the rotation is not lost\n",
    "    plt.xticks(ind, names, rotation='vertical')\n",
    "    if N > 0:\n",
    "        # With no bars there is nothing to take the legend colors from\n",
    "        ax.legend((p1[0], p2[0]), (setAName, setBName))\n",
    "\n",
    "    #Remove plot borders\n",
    "    for location in ['right', 'left', 'top', 'bottom']:\n",
    "        ax.spines[location].set_visible(False)\n",
    "\n",
    "    #Fix grid to be horizontal lines only and behind the plots\n",
    "    ax.yaxis.grid(color='gray', linestyle='solid')\n",
    "    ax.set_axisbelow(True)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "def _is_selected_weight(name, which):\n",
    "    \"\"\"True if `name` is a weights tensor that matches the `which` filter ('*' matches all).\"\"\"\n",
    "    return ('weight' in name) and (which == '*' or which in name)\n",
    "\n",
    "\n",
    "def plot_layer_sizes(which, sparse_model, dense_model):\n",
    "    \"\"\"Plot, per weights tensor, the element count in the dense model and the\n",
    "    non-zero element count in the sparse model.\n",
    "\n",
    "    Args:\n",
    "        which: substring that a tensor name must contain, or '*' for all weights tensors.\n",
    "        sparse_model, dense_model: models exposing `state_dict()`.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the two models do not yield the same number of selected tensors.\n",
    "    \"\"\"\n",
    "    names = []\n",
    "    sparse = []\n",
    "    for name, sparse_weights in sparse_model.state_dict().items():\n",
    "        if _is_selected_weight(name, which):\n",
    "            sparse.append(len(sparse_weights[sparse_weights!=0]))\n",
    "            names.append(name)\n",
    "\n",
    "    dense = [dense_weights.numel()\n",
    "             for name, dense_weights in dense_model.state_dict().items()\n",
    "             if _is_selected_weight(name, which)]\n",
    "\n",
    "    if len(dense) != len(sparse):\n",
    "        raise ValueError('the dense model has %d matching tensors, but the sparse model has %d'\n",
    "                         % (len(dense), len(sparse)))\n",
    "\n",
    "    plot_bars(which, dense, 'Dense', sparse, 'Sparse', names, 'Layer sizes', figsize=None)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Selecting parameters by kind"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def conv_param_names(model):\n",
    "    \"\"\"Names of the weights tensors of `model` that have more than 2 dimensions.\"\"\"\n",
    "    return [param_name for param_name, p in model.state_dict().items()\n",
    "            if (p.dim()>2) and ('weight' in param_name)]\n",
    "\n",
    "\n",
    "def conv_fc_param_names(model):\n",
    "    \"\"\"Names of the weights tensors of `model` that have more than 1 dimension.\"\"\"\n",
    "    return [param_name for param_name, p in model.state_dict().items()\n",
    "            if (p.dim()>1) and ('weight' in param_name)]\n",
    "\n",
    "\n",
    "def conv_fc_params(model):\n",
    "    \"\"\"(name, tensor) pairs of the weights tensors of `model` that have more than 1 dimension.\"\"\"\n",
    "    return [(param_name, p) for (param_name, p) in model.state_dict().items()\n",
    "            if (p.dim()>1) and ('weight' in param_name)]\n",
    "\n",
    "\n",
    "def fc_param_names(model):\n",
    "    \"\"\"Names of the weights tensors of `model` that have exactly 2 dimensions.\"\"\"\n",
    "    return [param_name for param_name, p in model.state_dict().items()\n",
    "            if (p.dim()==2) and ('weight' in param_name)]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}