{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Automated Gradual Pruning Schedule\n",
    "\n",
    "Michael Zhu and Suyog Gupta, [\"To prune, or not to prune: exploring the efficacy of pruning for model compression\"](https://arxiv.org/pdf/1710.01878), 2017 NIPS Workshop on Machine Learning of Phones and other Consumer Devices<br>\n",
    "<br>\n",
    "After completing sensitivity analysis, decide on your pruning schedule.\n",
    "\n",
    "## Table of Contents\n",
    "1. [Implementation of the gradual sparsity function](#Implementation-of-the-gradual-sparsity-function)\n",
    "2. [Visualize pruning schedule](#Visualize-pruning-schedule)\n",
    "3. [References](#References)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy\n",
    "import matplotlib.pyplot as plt\n",
    "from functools import partial\n",
    "import torch\n",
    "from torch.autograd import Variable\n",
    "from ipywidgets import widgets, interact"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Implementation of the gradual sparsity function\n",
    "\n",
    "The function ```sparsity_target``` implements the gradual sparsity schedule from [[1]](#zhu-gupta):<br><br>\n",
    "<b><i>\"We introduce a new automated gradual pruning algorithm in which the sparsity is increased from an initial sparsity value $s_i$ (usually 0) to a final sparsity value $s_f$ over a span of $n$ pruning steps, starting at training step $t_0$ and with pruning frequency $\\Delta t$.\"</i></b><br>\n",
    "<br>\n",
    "\n",
    "<div id=\"eq:zhu_gupta_schedule\"></div>\n",
    "<center>\n",
    "$\\large\n",
    "\\begin{align}\n",
    "s_t = s_f + (s_i - s_f) \\left(1- \\frac{t-t_0}{n\\Delta t}\\right)^3\n",
    "\\end{align}\n",
    "\\ \\ for\n",
    "\\large \\ \\ t \\in \\{t_0, t_0+\\Delta t, ..., t_0+n\\Delta t\\}\n",
    "$\n",
    "</center>\n",
    "<br>\n",
    "Pruning happens once at the beginning of each epoch, until the duration of the pruning (the number of epochs to prune) is exceeded.  After pruning ends, the training continues without pruning, but the pruned weights are kept at zero."
   ]
  },
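  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a worked example of the schedule above, take the illustrative values $s_i = 0$, $s_f = 80$ (both in percent), $t_0 = 0$ and $n\\Delta t = 28$ epochs. These numbers are assumptions chosen for the example, not values taken from [[1]](#zhu-gupta). Halfway through the schedule, at $t = 14$:\n",
    "\n",
    "$$\n",
    "s_{14} = 80 + (0 - 80)\\left(1 - \\frac{14}{28}\\right)^3 = 80 - 80 \\cdot 0.125 = 70\n",
    "$$\n",
    "\n",
    "So half of the pruning period already yields 70% sparsity, which is 87.5% of the final target: the cubic term prunes quickly in the early steps and slows down as the sparsity approaches $s_f$."
   ]
  },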
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sparsity_target(starting_epoch, ending_epoch, initial_sparsity, final_sparsity, current_epoch):\n",
    "    if final_sparsity < initial_sparsity:\n",
    "        return current_epoch \n",
    "    if current_epoch < starting_epoch:\n",
    "        return current_epoch\n",
    "    \n",
    "    span = ending_epoch - starting_epoch\n",
    "    target_sparsity = ( final_sparsity +\n",
    "                        (initial_sparsity - final_sparsity) *\n",
    "                        (1.0 - ((current_epoch-starting_epoch)/span))**3)\n",
    "    return target_sparsity"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualize pruning schedule\n",
    "When using the Automated Gradual Pruning (AGP) schedule, you may want to visualize how the pruning schedule will look as a function of the epoch number.  This is called the *sparsity function*.  The widget below will help you do this.<br>\n",
    "There are three knobs you can use to change the schedule:\n",
    "- ```duration```: this is the number of epochs over which to use the AGP schedule ($n\\Delta t$).\n",
    "- ```initial_sparsity```: $s_i$\n",
    "- ```final_sparsity```: $s_f$\n",
    "- ```frequency```: this is the pruning frequency ($\\Delta t$)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def draw_pruning(duration, initial_sparsity, final_sparsity, frequency):\n",
    "    epochs = []\n",
    "    sparsity_levels = []\n",
    "    # The derivative of the sparsity (i.e. sparsity rate of change)\n",
    "    d_sparsity = []\n",
    "\n",
    "    if frequency=='':\n",
    "        frequency = 1 \n",
    "    else:\n",
    "        frequency = int(frequency)\n",
    "    for epoch in range(0,40):\n",
    "        epochs.append(epoch)\n",
    "        current_epoch=Variable(torch.FloatTensor([epoch]), requires_grad=True)\n",
    "        if epoch<duration and epoch%frequency == 0:\n",
    "            sparsity = sparsity_target(\n",
    "                     starting_epoch=0, \n",
    "                     ending_epoch=duration, \n",
    "                     initial_sparsity=initial_sparsity, \n",
    "                     final_sparsity=final_sparsity,\n",
    "                current_epoch=current_epoch\n",
    "            )\n",
    "            \n",
    "            sparsity_levels.append(sparsity)\n",
    "            sparsity.backward()\n",
    "            d_sparsity.append(current_epoch.grad.item())\n",
    "            current_epoch.grad.data.zero_()\n",
    "        else:\n",
    "            sparsity_levels.append(sparsity)\n",
    "            d_sparsity.append(0)\n",
    "            \n",
    "\n",
    "    plt.plot(epochs, sparsity_levels, epochs, d_sparsity)\n",
    "    plt.ylabel('sparsity (%)')\n",
    "    plt.xlabel('epoch')\n",
    "    plt.title('Pruning Rate')\n",
    "    plt.ylim(0, 100)\n",
    "    plt.draw()\n",
    "\n",
    "\n",
    "duration_widget = widgets.IntSlider(min=0, max=100, step=1, value=28)\n",
    "si_widget = widgets.IntSlider(min=0, max=100, step=1, value=0)\n",
    "interact(draw_pruning, \n",
    "         duration=duration_widget, \n",
    "         initial_sparsity=si_widget, \n",
    "         final_sparsity=(0,100,1),\n",
    "         frequency='2');"
   ]
  },
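  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The second curve in the plot is the derivative of the sparsity with respect to the epoch, which the code obtains through autograd. As a cross-check, differentiating the schedule from [[1]](#zhu-gupta) by hand gives a closed form (a derivation sketch that treats $t$ as continuous):\n",
    "\n",
    "$$\n",
    "\\frac{ds_t}{dt} = \\frac{3\\,(s_f - s_i)}{n\\Delta t} \\left(1 - \\frac{t - t_0}{n\\Delta t}\\right)^2\n",
    "$$\n",
    "\n",
    "The rate is largest at $t = t_0$, where it equals $3(s_f - s_i)/(n\\Delta t)$, and it falls to zero at $t = t_0 + n\\Delta t$. With the illustrative values $s_i = 0$, $s_f = 80$ and $n\\Delta t = 28$ (assumptions chosen for this example), the initial rate is $240/28 \\approx 8.57$ percentage points per epoch."
   ]
  },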
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div id=\"toc\"></div>\n",
    "## References\n",
    "1. <div id=\"zhu-gupta\"></div> **Michael Zhu and Suyog Gupta**. \n",
    "    [*To prune, or not to prune: exploring the efficacy of pruning for model compression*](https://arxiv.org/pdf/1710.01878),\n",
    "    NIPS Workshop on Machine Learning of Phones and other Consumer Devices,\n",
    "    2017."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}