{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Experimental\n",
    "\n",
    "<br>\n",
    "<font size=\"6\" color=\"red\"> &#9888; WARNING </font>\n",
    "<br>\n",
    "<font size=\"4\" color=\"red\">This part of the notebook works correctly only on some advanced PyTorch versions (e.g. 0.4.0a0+410fd58), therefore is may not run correctly for you.</font><br><br>\n",
    "Please also note that for generating a PNG image of the network (last cell of the notebook), you will need to have graphviz installed:\n",
    "   ```\n",
    "   $ sudo apt-get install graphviz\n",
    "   ```\n",
    "<br>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "import matplotlib\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "    \n",
    "from distiller import SummaryGraph\n",
    "from distiller.model_summaries import *\n",
    "from distiller.models import create_model\n",
    "from distiller.apputils import *\n",
    "import torch\n",
    "import torchvision\n",
    "from torch.autograd import Variable\n",
    "import qgrid\n",
    "\n",
    "# Load some common jupyter code\n",
    "%run distiller_jupyter_helpers.ipynb\n",
    "import ipywidgets as widgets\n",
    "from ipywidgets import interactive, interact, Layout\n",
    "\n",
    "# Some models have long node names and require longer lines\n",
    "from IPython.core.display import display, HTML\n",
    "display(HTML(\"<style>.container { width:100% !important; }</style>\"))\n",
    "\n",
    "print(\"You are using pytorch version %s\" %torch.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Choose which model you want to examine\n",
    "\n",
    "If you are studying the structure of a neural network model, you probably don't need a pruned model, although you can use one.\n",
    "<br>\n",
    "In this example, we look at a pretrained ResNet18 model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = 'imagenet'\n",
    "dummy_input = torch.randn((1, 3, 224, 224), requires_grad=False)\n",
    "arch = 'resnet18'\n",
    "#arch = 'alexnet'\n",
    "checkpoint_file = None \n",
    "\n",
    "if checkpoint_file is not None:\n",
    "    model = create_model(pretrained=False, dataset=dataset, arch=arch)\n",
    "    load_checkpoint(model, checkpoint_file)\n",
    "else:\n",
    "    model = create_model(pretrained=False, dataset=dataset, arch=arch, parallel=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can examine layer connectivity:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "dummy_imagenet_input = Variable(torch.randn(1, 3, 224, 224), requires_grad=False)\n",
    "dummy_cifar_input = Variable(torch.randn(1, 3, 32, 32), requires_grad=False)\n",
    "\n",
    "dummy_input = dummy_imagenet_input if dataset=='imagenet' else dummy_cifar_input\n",
    "\n",
    "g = SummaryGraph(model, dummy_input)\n",
    "df = connectivity_summary(g)\n",
    "#qgrid.set_grid_option('defaultColumnWidth', 10)\n",
    "qgrid.show_grid(df)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can also print the shapes of the various tensors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "df = connectivity_summary_verbose(g)\n",
    "qgrid.show_grid(df)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And you can discover the attributes of each layer:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "\n",
    "# add_macs_attr(g)\n",
    "# add_footprint_attr(g)\n",
    "# add_arithmetic_intensity_attr(g)\n",
    "ignore_attrs = ['group', 'is_test', 'consumed_inputs', 'alpha', 'beta', 'MACs', 'footprint', 'ai', 'fm_vol', 'weights_vol']\n",
    "df = attributes_summary(g, ignore_attrs)\n",
    "df['MAC'] = g.get_attr('MACs')\n",
    "df['BW'] = g.get_attr('footprint')\n",
    "df['AI'] = g.get_attr('ai')\n",
    "#df = df.assign([5]*len(df)).values\n",
    "\n",
    "qgrid.show_grid(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_bars(which, setA, setAName, setB, setBName, names, title):\n",
    "    N = len(setA)\n",
    "    ind = np.arange(N)    # the x locations for the groups\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(20,10))\n",
    "    width = .47\n",
    "    p1 = plt.bar(ind, setA,  width = .47, color = '#278DBC')\n",
    "    p2 = plt.bar(ind, setB, width = 0.35, color = '#000099')\n",
    "\n",
    "    plt.ylabel('Size')\n",
    "    plt.title(title)\n",
    "    plt.xticks(rotation='vertical')\n",
    "    plt.xticks(ind, names)\n",
    "    #plt.yticks(np.arange(0, 100, 150))\n",
    "    plt.legend((p1[0], p2[0]), (setAName, setBName))\n",
    "\n",
    "    #Remove plot borders\n",
    "    for location in ['right', 'left', 'top', 'bottom']:\n",
    "        ax.spines[location].set_visible(False)  \n",
    "\n",
    "    #Fix grid to be horizontal lines only and behind the plots\n",
    "    ax.yaxis.grid(color='gray', linestyle='solid')\n",
    "    ax.set_axisbelow(True)\n",
    "    plt.show()\n",
    "\n",
    "sgraph = g\n",
    "names = [op['name'] for op in sgraph.ops.values()]\n",
    "setA = g.get_attr('fm_vol') \n",
    "setB = g.get_attr('weights_vol') \n",
    "plot_bars(None, setA, 'Feature maps', setB, 'Weights', names, 'Weights footprint vs. feature-maps footprint\\n(Normalized)')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "names = [op['name'] for op in sgraph.ops.values() if 'MACs' in op['attrs'] and op['attrs']['MACs']>0]\n",
    "macs = g.get_attr('MACs', lambda op: op['attrs']['MACs']>0)\n",
    "\n",
    "y_pos = np.arange(len(names))\n",
    "fig, ax = plt.subplots(figsize=(20,10))\n",
    "barlist = plt.bar(y_pos, macs, align='center', alpha=0.5, color = '#278DBC')\n",
    "plt.xticks(y_pos, names)\n",
    "plt.ylabel('MACs')\n",
    "plt.title('MACs per layer')\n",
    "\n",
    "\n",
    "#Fix grid to be horizontal lines only and behind the plots\n",
    "ax.yaxis.grid(color='gray', linestyle='solid')\n",
    "ax.set_axisbelow(True)\n",
    "\n",
    "plt.xticks(rotation='vertical')\n",
    "\n",
    "#Remove plot borders\n",
    "for location in ['right', 'left', 'top', 'bottom']:\n",
    "    ax.spines[location].set_visible(False) \n",
    "\n",
    "ops = g.get_ops('MACs', lambda op: op['attrs']['MACs']>0)\n",
    "for bar,op in zip(barlist, ops):\n",
    "    kernel = op['attrs'].get('kernel_shape', None)\n",
    "    if str(kernel) == '[7, 7]':\n",
    "        bar.set_color('r') \n",
    "    if str(kernel) == '[3, 3]':\n",
    "        bar.set_color('g') \n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ops = sgraph.ops\n",
    "positive_mac = lambda op: op['attrs']['MACs']>0\n",
    "names = g.get_attr('name', positive_mac) \n",
    "\n",
    "macs = g.get_attr('MACs', positive_mac)\n",
    "norm_macs = [float(i)/np.sum(macs) for i in macs]\n",
    "\n",
    "footprint = g.get_attr('footprint', positive_mac)\n",
    "norm_footprint = [float(i)/np.sum(footprint) for i in footprint]\n",
    "\n",
    "plot_bars(None, norm_macs, 'MACs', norm_footprint, 'footprint', names, \"MACs vs footprint\")\n",
    "\n",
    "#norm = [float(i)/sum(raw) for i in raw]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create a PNG image of the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "from IPython.display import Image\n",
    "\n",
    "    \n",
    "g = SummaryGraph(model, dummy_input)\n",
    "\n",
    "DRAW_TO_FILE = True\n",
    "if DRAW_TO_FILE:\n",
    "    draw_model_to_file(g, 'graph.png')\n",
    "\n",
    "# Draw on notebook\n",
    "png = create_png(g)\n",
    "Image(png)\n",
    "   "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## References\n",
    "\n",
    "<div id=\"Gray-et-al-2015\"></div> **Andrew Lavin and Scott Gray**. \n",
    "    [*Fast Algorithms for Convolutional Neural Networks*](https://arxiv.org/pdf/1509.09308.pdf),\n",
    "    2015.\n",
    "\n",
    "\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}