diff --git a/jupyter/parameter_histograms.ipynb b/jupyter/parameter_histograms.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..a2c1f2a49089c2958925a9f22848a5394ac2fce0 --- /dev/null +++ b/jupyter/parameter_histograms.ipynb @@ -0,0 +1,145 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Parameter Histograms\n", + "\n", + "This notebook loads a model and draws the histograms of the parameters tensors." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torchvision\n", + "import torch.nn as nn\n", + "from torch.autograd import Variable\n", + "import scipy.stats as ss\n", + "\n", + "# Relative import of code from distiller, w/o installing the package\n", + "import os\n", + "import sys\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "module_path = os.path.abspath(os.path.join('..'))\n", + "if module_path not in sys.path:\n", + " sys.path.append(module_path)\n", + "\n", + "import distiller\n", + "import models\n", + "from apputils import *\n", + "\n", + "plt.style.use('seaborn') # pretty matplotlib plots" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load your model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "# It is interesting to compare the distribution of non-pretrained model (Normally-distributed)\n", + "# vs. the distribution of the pretrained model.\n", + "model = models.create_model(pretrained=True, dataset='imagenet', arch='resnet50', parallel=True)\n", + "\n", + "# Optionally load your compressed model \n", + "# load_checkpoint(model, <path-to-your-checkpoint-file>);" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Plot the distributions\n", + "\n", + "We plot the distributions of the weights of each convolution layer, and we also plot the fitted Gaussian and Laplacian distributions." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "def flatten(weights):\n", + " weights = weights.view(weights.numel())\n", + " weights = weights.data.cpu().numpy()\n", + " return weights\n", + "\n", + "REMOVE_ZEROS = False\n", + "nbins = 500\n", + "for name, weights in model.named_parameters():\n", + " if weights.dim() == 4:\n", + " size_str = \"x\".join([str(s) for s in weights.size()])\n", + " weights = flatten(weights)\n", + " \n", + " if REMOVE_ZEROS:\n", + " # Optionally remove zeros (lots of zeros will dominate the histogram and the \n", + " # other data will be hard to see\n", + " weights = weights[weights!=0]\n", + " \n", + " # Fit the data to the Normal distribution\n", + " (mean_fitted, std_fitted) = ss.norm.fit(weights)\n", + " x = np.linspace(min(weights), max(weights), nbins)\n", + " weights_gauss_fitted = ss.norm.pdf(x, loc=mean_fitted, scale=std_fitted)\n", + "\n", + " # Fit the data to the Laplacian distribution\n", + " (mean_fitted, std_fitted) = ss.laplace.fit(weights)\n", + " weights_laplace_fitted = ss.laplace.pdf(x, loc=mean_fitted, scale=std_fitted)\n", + "\n", + " n, bins, patches = plt.hist(weights, histtype='stepfilled', \n", + " cumulative=False, bins=nbins, normed=1)\n", + " plt.plot(x, weights_gauss_fitted, label='gauss')\n", + " plt.plot(x, weights_laplace_fitted, label='laplace')\n", + " plt.title(name + \" - \" +size_str)\n", + " #plt.figure(figsize=(10,5))\n", + " plt.legend()\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.5.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}