{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Interpreting your pruning and regularization experiments\n",
    "This notebook contains code to be included in your own notebooks by adding this line at the top of your notebook:<br>\n",
    "```%run distiller_jupyter_helpers.ipynb```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relative import of code from distiller, w/o installing the package\n",
    "import os\n",
    "import sys\n",
    "module_path = os.path.abspath(os.path.join('..'))\n",
    "if module_path not in sys.path:\n",
    "    sys.path.append(module_path)\n",
    "\n",
    "import distiller.utils\n",
    "import distiller\n",
    "import apputils.checkpoint "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "import os\n",
    "import collections\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def to_np(x):\n",
    "    return x.cpu().numpy()\n",
    "\n",
    "def flatten(weights):\n",
    "    weights = weights.clone().view(weights.numel())\n",
    "    weights = to_np(weights)\n",
    "    return weights\n",
    "\n",
    "\n",
    "import scipy.stats as stats\n",
    "def plot_params_hist_single(name, weights_pytorch, remove_zeros=False, kmeans=None):\n",
    "    weights = flatten(weights_pytorch)\n",
    "    if remove_zeros:\n",
    "        weights = weights[weights!=0]\n",
    "    n, bins, patches = plt.hist(weights, bins=200)\n",
    "    plt.title(name)\n",
    "    \n",
    "    if kmeans is not None:\n",
    "        labels = kmeans.labels_\n",
    "        centroids = kmeans.cluster_centers_\n",
    "        cnt_coefficients = [len(labels[labels==i]) for i in range(16)]\n",
    "        # Normalize the coefficients so they display in the same range as the float32 histogram\n",
    "        cnt_coefficients = [cnt / 5 for cnt in cnt_coefficients] \n",
    "        centroids, cnt_coefficients = zip(*sorted(zip(centroids, cnt_coefficients)))\n",
    "        cnt_coefficients = list(cnt_coefficients)\n",
    "        centroids = list(centroids)\n",
    "        if remove_zeros:\n",
    "            for i in range(len(centroids)):\n",
    "                if abs(centroids[i]) < 0.0001:  # almost zero\n",
    "                    centroids.remove(centroids[i])\n",
    "                    cnt_coefficients.remove(cnt_coefficients[i])\n",
    "                    break\n",
    "        \n",
    "        plt.plot(centroids, cnt_coefficients)\n",
    "        zeros = [0] * len(centroids)\n",
    "        plt.plot(centroids, zeros, 'r+', markersize=15)\n",
    "        \n",
    "        h = cnt_coefficients\n",
    "        hmean = np.mean(h)\n",
    "        hstd = np.std(h)\n",
    "        pdf = stats.norm.pdf(h, hmean, hstd)\n",
    "        #plt.plot(h, pdf)\n",
    "        \n",
    "    plt.show()\n",
    "    print(\"mean: %f\\nstddev: %f\" % (weights.mean(), weights.std()))\n",
    "    print(\"size=%s %d elements\" % distiller.size2str(weights_pytorch.size()))\n",
    "    print(\"min: %.3f\\nmax:%.3f\" % (weights.min(), weights.max()))\n",
    "\n",
    "    \n",
    "def plot_params_hist(params, which='weight', remove_zeros=False):      \n",
    "    for name, weights_pytorch in params.items():\n",
    "        if which not in name:\n",
    "            continue\n",
    "        plot_params_hist_single(name, weights_pytorch, remove_zeros)\n",
    "        \n",
    "def plot_params2d(classifier_weights, figsize, binary_mask=True, \n",
    "                  gmin=None, gmax=None,\n",
    "                  xlabel=\"\", ylabel=\"\", title=\"\"):\n",
    "    if not isinstance(classifier_weights, list):\n",
    "        classifier_weights = [classifier_weights]\n",
    "    \n",
    "    for weights in classifier_weights:\n",
    "        assert weights.dim() in [2,4], \"something's wrong\"\n",
    "        \n",
    "        shape_str = distiller.size2str(weights.size())\n",
    "        volume = distiller.volume(weights)\n",
    "        \n",
    "        # Clone because we are going to change the tensor values\n",
    "        if binary_mask:\n",
    "            weights2d = weights.clone()\n",
    "        else:\n",
    "            weights2d = weights\n",
    " \n",
    "        if weights.dim() == 4:\n",
    "            weights2d = weights2d.view(weights.size()[0] * weights.size()[1], -1)\n",
    "\n",
    "        sparsity = len(weights2d[weights2d==0]) / volume\n",
    "        \n",
    "        cmap='seismic'\n",
    "        # create a binary image (non-zero elements are black; zeros are white)\n",
    "        if binary_mask:\n",
    "            cmap='binary'\n",
    "            weights2d[weights2d!=0] = 1\n",
    "                    \n",
    "        fig = plt.figure(figsize=figsize)\n",
    "        if (not binary_mask) and (gmin is not None) and (gmax is not None):\n",
    "            if isinstance(gmin, torch.Tensor):\n",
    "                gmin = gmin.item()\n",
    "                gmax = gmax.item()\n",
    "            plt.imshow(weights2d, cmap=cmap, vmin=gmin, vmax=gmax)\n",
    "        else:\n",
    "            plt.imshow(weights2d, cmap=cmap, vmin=0, vmax=1)\n",
    "        #plt.figure(figsize=(20,40))\n",
    "        \n",
    "        plt.xlabel(xlabel)\n",
    "        plt.ylabel(ylabel)\n",
    "        plt.title(title)\n",
    "        plt.colorbar( pad=0.01, fraction=0.01)\n",
    "        plt.show()\n",
    "        print(\"sparsity = %.1f%% (nnz=black)\" % (sparsity*100))\n",
    "        print(\"size=%s = %d elements\" % (shape_str, volume))\n",
    "        \n",
    "        \n",
    "def printk(k):\n",
    "    \"\"\"Print the values of the elements of a kernel as a list\"\"\"\n",
    "    print(list(k.view(k.numel())))\n",
    "\n",
    "    \n",
    "def plot_param_kernels(weights, layout, size_ctrl, binary_mask=False, color_normalization='Model', \n",
    "                       gmin=None, gmax=None, interpolation=None, first_kernel=0):\n",
    "    ofms, ifms = weights.size()[0], weights.size()[1]\n",
    "    kw, kh = weights.size()[2], weights.size()[3]\n",
    "    \n",
    "    print(\"min=%.4f\\tmax=%.4f\" % (weights.min(), weights.max()))\n",
    "    shape_str = distiller.size2str(weights.size())\n",
    "    volume = distiller.volume(weights)\n",
    "    print(\"size=%s = %d elements\" % (shape_str, volume))\n",
    "    \n",
    "    # Clone because we are going to change the tensor values\n",
    "    weights = weights.clone()\n",
    "    if binary_mask:\n",
    "        weights[weights!=0] = 1\n",
    "        # Take the inverse of the pixels, because we want zeros to appear white\n",
    "        #weights = 1 - weights\n",
    "    \n",
    "    kernels = weights.view(ofms * ifms, kh, kw)\n",
    "    nrow, ncol = layout[0], layout[1]\n",
    "\n",
    "    # Plot the graph\n",
    "    plt.gray()\n",
    "    #plt.tight_layout()\n",
    "    fig = plt.figure( figsize=(layout[0]*size_ctrl, layout[1]*size_ctrl) );\n",
    "\n",
    "    # We want to normalize the grayscale brightness levels for all of the images we display (group),\n",
    "    # otherwise, each image is normalized separately and this causes distortion between the different\n",
    "    # filters images we ddisplay.\n",
    "    # We don't normalize across all of the filters images, because the outliers cause the image of each \n",
    "    # filter to be very muted.  This is because each group of filters we display usually has low variance\n",
    "    # between the element values of that group.\n",
    "    if color_normalization=='Tensor':\n",
    "        gmin = weights.min()\n",
    "        gmax = weights.max()\n",
    "    elif color_normalization=='Group':\n",
    "        gmin = weights[0:nrow, 0:ncol].min()\n",
    "        gmax = weights[0:nrow, 0:ncol].max()\n",
    "    print(\"gmin=%.4f\\tgmax=%.4f\" % (gmin, gmax))\n",
    "    if isinstance(gmin, torch.Tensor):\n",
    "        gmin = gmin.item()\n",
    "        gmax = gmax.item()\n",
    "    \n",
    "    i = 0 \n",
    "    for row in range(0, nrow):\n",
    "        for col in range (0, ncol):\n",
    "            ax = fig.add_subplot(layout[0], layout[1], i+1)\n",
    "            if binary_mask:\n",
    "                ax.matshow(kernels[first_kernel+i], cmap='binary', vmin=0, vmax=1);\n",
    "            else:\n",
    "                # Use siesmic so that colors around the center are lighter.  Red and blue are used\n",
    "                # to represent (and visually separate) negative and positive weights \n",
    "                ax.matshow(kernels[first_kernel+i], cmap='seismic', vmin=gmin, vmax=gmax, interpolation=interpolation);\n",
    "            ax.set(xticks=[], yticks=[])\n",
    "            i += 1\n",
    "    \n",
    "    \n",
    "def l1_norm_histogram(weights):\n",
    "    \"\"\"Compute a histogram of the L1-norms of the kernels of a weights tensor.\n",
    "    \n",
    "    The L1-norm of a kernel is one way to quantify the \"magnitude\" of the total coeffiecients\n",
    "    making up this kernel.\n",
    "    \n",
    "    Another interesting look at filters is to compute a histogram per filter.\n",
    "    \"\"\"\n",
    "    ofms, ifms = weights.size()[0], weights.size()[1]\n",
    "    kw, kh = weights.size()[2], weights.size()[3]\n",
    "    kernels = weights.view(ofms * ifms, kh, kw)\n",
    "    \n",
    "    l1_hist = []\n",
    "    for kernel in range(ofms*ifms):\n",
    "        l1_hist.append(kernels[kernel].norm(1))\n",
    "    return l1_hist\n",
    "\n",
    "def plot_l1_norm_hist(weights):    \n",
    "    l1_hist = l1_norm_histogram(weights)\n",
    "    n, bins, patches = plt.hist(l1_hist, bins=200)\n",
    "    plt.title('Kernel L1-norm histograms')\n",
    "    plt.ylabel('Frequency')\n",
    "    plt.xlabel('Kernel L1-norm')\n",
    "    plt.show()\n",
    "    \n",
    "\n",
    "def plot_layer_sizes(which, sparse_model, dense_model):\n",
    "    dense = []\n",
    "    sparse = []\n",
    "    names = []\n",
    "    for name, sparse_weights in sparse_model.state_dict().items():\n",
    "        if ('weight' not in name) or (which!='*' and which not in name):\n",
    "                continue    \n",
    "        sparse.append(len(sparse_weights[sparse_weights!=0]))\n",
    "        names.append(name)\n",
    "\n",
    "    for name, dense_weights in dense_model.state_dict().items():\n",
    "        if ('weight' not in name) or (which!='*' and which not in name):\n",
    "                continue\n",
    "        dense.append(dense_weights.numel())\n",
    "\n",
    "    N = len(sparse)\n",
    "    ind = np.arange(N)    # the x locations for the groups\n",
    "\n",
    "    fig, ax = plt.subplots()\n",
    "    width = .47\n",
    "    p1 = plt.bar(ind, dense,  width = .47, color = '#278DBC')\n",
    "    p2 = plt.bar(ind, sparse, width = 0.35, color = '#000099')\n",
    "\n",
    "    plt.ylabel('Size')\n",
    "    plt.title('Layer sizes')\n",
    "    plt.xticks(rotation='vertical')\n",
    "    plt.xticks(ind, names)\n",
    "    #plt.yticks(np.arange(0, 100, 150))\n",
    "    plt.legend((p1[0], p2[0]), ('Dense', 'Sparse'))\n",
    "\n",
    "    #Remove plot borders\n",
    "    for location in ['right', 'left', 'top', 'bottom']:\n",
    "        ax.spines[location].set_visible(False)  \n",
    "\n",
    "    #Fix grid to be horizontal lines only and behind the plots\n",
    "    ax.yaxis.grid(color='gray', linestyle='solid')\n",
    "    ax.set_axisbelow(True)\n",
    "    plt.show()\n",
    "    \n",
    "    \n",
    "def conv_param_names(model):\n",
    "    return [param_name for param_name, p in model.state_dict().items()  \n",
    "            if (p.dim()>2) and (\"weight\" in param_name)]\n",
    "\n",
    "def conv_fc_param_names(model):\n",
    "    return [param_name for param_name, p in model.state_dict().items()  \n",
    "            if (p.dim()>1) and (\"weight\" in param_name)]\n",
    "\n",
    "def conv_fc_params(model):\n",
    "    return [(param_name,p) for (param_name, p) in model.state_dict()\n",
    "            if (p.dim()>1) and (\"weight\" in param_name)]\n",
    "\n",
    "def fc_param_names(model):\n",
    "    return [param_name for param_name, p in model.state_dict().items()  \n",
    "            if (p.dim()==2) and (\"weight\" in param_name)]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_bars(which, setA, setAName, setB, setBName, names, title):\n",
    "    N = len(setA)\n",
    "    ind = np.arange(N)    # the x locations for the groups\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(20,10))\n",
    "    width = .47\n",
    "    p1 = plt.bar(ind, setA,  width = .47, color = '#278DBC')\n",
    "    p2 = plt.bar(ind, setB, width = 0.35, color = '#000099')\n",
    "\n",
    "    plt.ylabel('Size')\n",
    "    plt.title(title)\n",
    "    plt.xticks(rotation='vertical')\n",
    "    plt.xticks(ind, names)\n",
    "    #plt.yticks(np.arange(0, 100, 150))\n",
    "    plt.legend((p1[0], p2[0]), (setAName, setBName))\n",
    "\n",
    "    #Remove plot borders\n",
    "    for location in ['right', 'left', 'top', 'bottom']:\n",
    "        ax.spines[location].set_visible(False)  \n",
    "\n",
    "    #Fix grid to be horizontal lines only and behind the plots\n",
    "    ax.yaxis.grid(color='gray', linestyle='solid')\n",
    "    ax.set_axisbelow(True)\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}