diff --git a/jupyter/truncated_svd.ipynb b/jupyter/truncated_svd.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..85a8a954c0cda8fe1fd72709ee28acba4537f179
--- /dev/null
+++ b/jupyter/truncated_svd.ipynb
@@ -0,0 +1,262 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Truncated SVD\n",
+    "\n",
+    "This is a simple application of [Truncated SVD](http://langvillea.people.cofc.edu/DISSECTION-LAB/Emmie'sLSI-SVDModule/p5module.html), meant to give a feel for what happens to the accuracy when we use Truncated SVD **w/o fine-tuning**.\n",
+    "\n",
+    "We apply Truncated SVD to the linear layer found at the end of ResNet50, and run a test over the validation dataset to measure the impact on the classification accuracy.\n",
+    "\n",
+    "Replacing the N x M weights matrix with its rank-k factors stores k * (N + M) weights instead of N * M, so the factorization only saves parameters when k < N*M / (N+M), which is about 672 here; at k=700 and k=800 the factored layer actually holds *more* weights than the original.\n",
+    "\n",
+    "Spoiler:\n",
+    "\n",
+    "At k=800 (80%) we get\n",
+    " - Top1: 76.02\n",
+    " - Top5: 92.86\n",
+    "\n",
+    " Total weights: 1000 * 800 + 800 * 2048 = 2,438,400 (vs. 1000 * 2048 = 2,048,000)\n",
+    "\n",
+    "At k=700 (70%) we get\n",
+    " - Top1: 76.03\n",
+    " - Top5: 92.85\n",
+    "\n",
+    " Total weights: 1000 * 700 + 700 * 2048 = 2,133,600 (vs. 2,048,000)\n",
+    "\n",
+    "At k=600 (60%) we get\n",
+    " - Top1: 75.98\n",
+    " - Top5: 92.82\n",
+    "\n",
+    " Total weights: 1000 * 600 + 600 * 2048 = 1,828,800 (vs. 2,048,000)\n",
+    "\n",
+    "At k=500 (50%) we get\n",
+    " - Top1: 75.78\n",
+    " - Top5: 92.77\n",
+    "\n",
+    " Total weights: 1000 * 500 + 500 * 2048 = 1,524,000 (vs. 2,048,000)\n",
+    "\n",
+    "At k=400 (40%) we get\n",
+    " - Top1: 75.65\n",
+    " - Top5: 92.75\n",
+    "\n",
+    " Total weights: 1000 * 400 + 400 * 2048 = 1,219,200 (vs. 2,048,000)"
+   ]
+  },
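+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The cell below is a minimal sketch that double-checks the weight-count arithmetic quoted above: N and M are the output and input dimensions of ResNet50's final linear layer, and the break-even rank is where k * (N + M) stops being smaller than N * M."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Weight counts of the rank-k factorization vs. the original fc layer\n",
+    "N, M = 1000, 2048  # fc.out_features, fc.in_features in ResNet50\n",
+    "for k in (400, 500, 600, 700, 800):\n",
+    "    print(\"k={}: {:,} weights (vs. {:,})\".format(k, k * (N + M), N * M))\n",
+    "\n",
+    "# The factorization saves parameters only when k*(N+M) < N*M:\n",
+    "print(\"break-even rank: {:.1f}\".format(N * M / (N + M)))"
+   ]
+  },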
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "import torchvision\n",
+    "import torch.nn as nn\n",
+    "from torchvision import datasets, transforms\n",
+    "\n",
+    "# Relative import of code from distiller, w/o installing the package\n",
+    "import os\n",
+    "import sys\n",
+    "import numpy as np\n",
+    "import matplotlib.pyplot as plt\n",
+    "\n",
+    "module_path = os.path.abspath(os.path.join('..'))\n",
+    "if module_path not in sys.path:\n",
+    "    sys.path.append(module_path)\n",
+    "\n",
+    "import distiller\n",
+    "import apputils\n",
+    "import models\n",
+    "from apputils import *\n",
+    "\n",
+    "plt.style.use('seaborn')  # pretty matplotlib plots"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Utilities"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def imagenet_load_data(data_dir, batch_size, num_workers, shuffle=True):\n",
+    "    \"\"\"Load the ImageNet validation dataset\"\"\"\n",
+    "    test_dir = os.path.join(data_dir, 'val')\n",
+    "    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
+    "                                     std=[0.229, 0.224, 0.225])\n",
+    "\n",
+    "    test_loader = torch.utils.data.DataLoader(\n",
+    "        datasets.ImageFolder(test_dir, transforms.Compose([\n",
+    "            transforms.Resize(256),\n",
+    "            transforms.CenterCrop(224),\n",
+    "            transforms.ToTensor(),\n",
+    "            normalize,\n",
+    "        ])),\n",
+    "        batch_size=batch_size, shuffle=shuffle,\n",
+    "        num_workers=num_workers, pin_memory=True)\n",
+    "\n",
+    "    return test_loader"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "BATCH_SIZE = 4\n",
+    "\n",
+    "# Data loader\n",
+    "test_loader = imagenet_load_data(\"../../data.imagenet\",\n",
+    "                                 batch_size=BATCH_SIZE,\n",
+    "                                 num_workers=2)\n",
+    "\n",
+    "\n",
+    "# Reverse the normalization transformations we performed when we loaded the data\n",
+    "# for consumption by our CNN.\n",
+    "# See: https://discuss.pytorch.org/t/simple-way-to-inverse-transform-normalization/4821/3\n",
+    "invTrans = transforms.Compose([transforms.Normalize(mean=[0., 0., 0.],\n",
+    "                                                    std=[1/0.229, 1/0.224, 1/0.225]),\n",
+    "                               transforms.Normalize(mean=[-0.485, -0.456, -0.406],\n",
+    "                                                    std=[1., 1., 1.])])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Load ResNet 50"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "scrolled": false
+   },
+   "outputs": [],
+   "source": [
+    "# Load the pretrained ResNet50 model\n",
+    "resnet50 = models.create_model(pretrained=True, dataset='imagenet', arch='resnet50', parallel=False)\n",
+    "\n",
+    "\n",
+    "# See Faster-RCNN: https://github.com/rbgirshick/py-faster-rcnn/blob/master/tools/compress_net.py\n",
+    "def truncated_svd(W, l):\n",
+    "    \"\"\"Compress the weight matrix W of an inner product (fully connected) layer\n",
+    "    using truncated SVD.\n",
+    "    Parameters:\n",
+    "    W: N x M weights matrix\n",
+    "    l: number of singular values to retain\n",
+    "    Returns:\n",
+    "    Ul, SV: matrices such that W ~ Ul*SV\n",
+    "    \"\"\"\n",
+    "\n",
+    "    # Move the weights to the CPU for numpy, then return the factors as CUDA tensors\n",
+    "    U, s, V = np.linalg.svd(W.cpu().numpy(), full_matrices=False)\n",
+    "\n",
+    "    Ul = U[:, :l]\n",
+    "    sl = s[:l]\n",
+    "    Vl = V[:l, :]\n",
+    "\n",
+    "    SV = np.dot(np.diag(sl), Vl)\n",
+    "    SV = torch.from_numpy(SV).cuda()\n",
+    "    Ul = torch.from_numpy(Ul).cuda()\n",
+    "    return Ul, SV\n",
+    "\n",
+    "\n",
+    "class TruncatedSVD(nn.Module):\n",
+    "    def __init__(self, replaced_gemm, gemm_weights):\n",
+    "        super(TruncatedSVD, self).__init__()\n",
+    "        self.replaced_gemm = replaced_gemm\n",
+    "        print(\"W = {}\".format(gemm_weights.shape))\n",
+    "        # Retain 40% of the singular values; change the ratio to reproduce the other spoiler rows\n",
+    "        self.U, self.SV = truncated_svd(gemm_weights, int(0.4 * gemm_weights.size(0)))\n",
+    "        print(\"U = {}\".format(self.U.shape))\n",
+    "\n",
+    "        # y = W*x + b is approximated by y = U*(SV*x) + b: the SV stage carries\n",
+    "        # no bias, and the U stage keeps the bias of the replaced layer.\n",
+    "        self.fc_u = nn.Linear(self.U.size(1), self.U.size(0)).cuda()\n",
+    "        self.fc_u.weight.data = self.U\n",
+    "        self.fc_u.bias.data = replaced_gemm.bias.data\n",
+    "\n",
+    "        print(\"SV = {}\".format(self.SV.shape))\n",
+    "        self.fc_sv = nn.Linear(self.SV.size(1), self.SV.size(0), bias=False).cuda()\n",
+    "        self.fc_sv.weight.data = self.SV\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        x = self.fc_sv(x)\n",
+    "        x = self.fc_u(x)\n",
+    "        return x\n",
+    "\n",
+    "\n",
+    "def replace(model):\n",
+    "    fc_weights = model.state_dict()['fc.weight']\n",
+    "    fc_layer = model.fc\n",
+    "    print(\"fc_layer({}, {})\".format(fc_layer.in_features, fc_layer.out_features))\n",
+    "    model.fc = TruncatedSVD(fc_layer, fc_weights)\n",
+    "\n",
+    "\n",
+    "replace(resnet50)"
+   ]
+  },
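+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Before evaluating on ImageNet, we can sanity-check the factorization itself. The cell below is a minimal sketch, reusing `truncated_svd()` from the cell above: it assumes the original fc weights are still reachable through the `replaced_gemm` layer stored by `TruncatedSVD`, and it reports the relative Frobenius reconstruction error for each rank used in the spoiler."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Relative Frobenius error ||W - Ul*SV|| / ||W|| of the rank-k approximation.\n",
+    "# W is the original (pre-replacement) fc weight matrix, kept by TruncatedSVD.\n",
+    "W = resnet50.fc.replaced_gemm.weight.data\n",
+    "for k in (400, 500, 600, 700, 800):\n",
+    "    Ul, SV = truncated_svd(W, k)\n",
+    "    rel_err = torch.norm(W - torch.mm(Ul, SV)) / torch.norm(W)\n",
+    "    print(\"k={}: relative reconstruction error = {:.4f}\".format(k, rel_err.item()))"
+   ]
+  },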
" print((validation_step+1) * 512)\n", + " \n", + "print(classerr.value(1), classerr.value(5))\n", + "t2 = time.time()\n", + "print(t2-t0)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.5.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}