From 0f5c82f83976edbaa25fb167fab130fa42b05d0a Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Sat, 24 Aug 2019 19:52:51 +0300
Subject: [PATCH] Truncated SVD: update notebook documentation and add
 customized module

Updated documentation per issue wq#359
---
 distiller/modules/tsvd.py           |  67 +++++++++++++++++
 jupyter/truncated_svd.ipynb         | 109 ++++++++++++++++++++++++----
 licenses/py_faster_rcnn-license.txt |  81 +++++++++++++++++++++
 3 files changed, 241 insertions(+), 16 deletions(-)
 create mode 100755 distiller/modules/tsvd.py
 create mode 100755 licenses/py_faster_rcnn-license.txt

diff --git a/distiller/modules/tsvd.py b/distiller/modules/tsvd.py
new file mode 100755
index 0000000..aea20bc
--- /dev/null
+++ b/distiller/modules/tsvd.py
@@ -0,0 +1,67 @@
+#
+# Copyright (c) 2019 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Truncated-SVD module.
+
+For an example of how truncated-SVD can be used, see this Jupyter notebook:
+https://github.com/NervanaSystems/distiller/blob/master/jupyter/truncated_svd.ipynb
+
+"""
+
+def truncated_svd(W, l):
+    """Compress the weight matrix W of an inner product (fully connected) layer using truncated SVD.
+
+    For the original implementation (MIT license), see Faster-RCNN:
+    https://github.com/rbgirshick/py-faster-rcnn/blob/master/tools/compress_net.py
+    We replaced numpy operations with pytorch operations (so that we can leverage the GPU).
+
+    Arguments:
+      W: N x M weights matrix
+      l: number of singular values to retain
+    Returns:
+      Ul, L: matrices such that W \approx Ul*L
+    """
+
+    U, s, V = torch.svd(W, some=True)
+
+    Ul = U[:, :l]
+    sl = s[:l]
+    V = V.t()
+    Vl = V[:l, :]
+
+    SV = torch.mm(torch.diag(sl), Vl)
+    return Ul, SV
+
+
+class TruncatedSVD(nn.Module):
+    def __init__(self, replaced_gemm, gemm_weights, preserve_ratio):
+        super().__init__()
+        self.replaced_gemm = replaced_gemm
+        print("W = {}".format(gemm_weights.shape))
+        self.U, self.SV = truncated_svd(gemm_weights.data, int(preserve_ratio * gemm_weights.size(0)))
+        print("U = {}".format(self.U.shape))
+
+        self.fc_u = nn.Linear(self.U.size(1), self.U.size(0)).cuda()
+        self.fc_u.weight.data = self.U
+
+        print("SV = {}".format(self.SV.shape))
+        self.fc_sv = nn.Linear(self.SV.size(1), self.SV.size(0)).cuda()
+        self.fc_sv.weight.data = self.SV#.t()
+
+    def forward(self, x):
+        x = self.fc_sv.forward(x)
+        x = self.fc_u.forward(x)
+        return x
diff --git a/jupyter/truncated_svd.ipynb b/jupyter/truncated_svd.ipynb
index 7331d3c..07f41d6 100644
--- a/jupyter/truncated_svd.ipynb
+++ b/jupyter/truncated_svd.ipynb
@@ -40,12 +40,54 @@
     "    - Top1: 75.65  \n",
     "    - Top5: 92.75\n",
     "    \n",
-    "    Total weights: 1000 * 400 + 400 * 2048 = 1,219,200 (vs. 2,048,000) "
+    "    Total weights: 1000 * 400 + 400 * 2048 = 1,219,200 (vs. 2,048,000)\n",
+    "    \n",
+    "## Details\n",
+    "\n",
+    "[SVD (Singular Value Decomposition)](https://en.wikipedia.org/wiki/Singular_value_decomposition) is an exact factorization of a matrix, W (of shape m x n), to the form USV$^T$ (U is m x m, S is m x n, V$^T$ is n x n; V$^T$ is the transpose of V). Every matrix has an SVD.\n",
+    "\n",
+    "A Linear (fully-connected) layer performs: y = Wx + b   (or y = xW$^T$ + b)<br>\n",
+    "We can use SVD to refactor W to rewrite this as: y = (USV$^T$)x + b\n",
+    "\n",
+    "So far, we haven’t done any compression, so let’s get to it using Truncated SVD.<br> TSVD is a method to provide an approximated decomposition of W, in which S has a lower rank. We want to find an approximation of W that is “good enough” and also accelerates the computation of Wx.\n",
+    "\n",
+    "We choose some lower-rank, k, such that k<m (preferably k<<m).<br>\n",
+    "TSVD is straight-forward: keep the largest k singular values of S and discard the rest (truncate S).\n",
+    "\n",
+    "After TSVD we have:\n",
+    "U’ is m x t, S’ is k x k, V’$^T$ is k x n.<br>\n",
+    "y ~ (U’S’V’)x + b<br>\n",
+    "y ~ (U’(S’V’))x + b<br>\n",
+    "\n",
+    "We'll replace S’V’$^T$ with A, because we can pre-compute it once. A has shape k x n:<br>\n",
+    "y ~ (U’A)x + b<br>\n",
+    "y ~ U’(Ax) + b<br>\n",
+    "\n",
+    "Let’s ignore the bias and calculate the number of parameters and FLOPs (floating point operations) for the original y:\n",
+    "\n",
+    " - m * n weights coefficients<br>\n",
+    " - m * n FLOPs (for batch size = 1)<br>\n",
+    "\n",
+    "After TSVD we have:\n",
+    "\n",
+    " - mk + kn = k*(m+n) weights coefficients<br>\n",
+    " - kn + mk = k*(m+n) FLOPs (for batch size = 1)<br>\n",
+    "\n",
+    "To actually compress the weights after TSVD, we want: m * n > k*(m+n)<br>\n",
+    "Let’s rewrite k in terms of m: k = tm<br>\n",
+    "m * n > tm*(m+n)<br>\n",
+    "n > t*(m+n)<br>\n",
+    "n / (m+n) >= t<br>\n",
+    "\n",
+    "This is the math, but for an actual performance increase, we should strive for m * n >> k*(m+n)\n",
+    "\n",
+    "In the example notebook: m = 1000; n=2048<br>\n",
+    "So when t=2048/(1000+2048) (that is, k=2048/3048*1000=672), we have equilibrium. When 0.672>t (i.e. k is smaller than 672), the sum of the size of the weights of A and U’ is smaller than the size of W.\n"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 7,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -80,7 +122,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 8,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -105,14 +147,14 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 9,
    "metadata": {},
    "outputs": [],
    "source": [
     "BATCH_SIZE = 4\n",
     "\n",
     "# Data loader\n",
-    "test_loader = imagenet_load_data(\"../../data.imagenet/\", \n",
+    "test_loader = imagenet_load_data(\"/datasets/imagenet/\", \n",
     "                                 batch_size=BATCH_SIZE, \n",
     "                                 num_workers=2)\n",
     "    \n",
@@ -136,11 +178,22 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 13,
    "metadata": {
     "scrolled": false
    },
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "fc_layer(2048, 1000)\n",
+      "W = torch.Size([1000, 2048])\n",
+      "U = torch.Size([1000, 400])\n",
+      "SV = torch.Size([400, 2048])\n"
+     ]
+    }
+   ],
    "source": [
     "# Load the various models\n",
     "resnet50 = models.create_model(pretrained=True, dataset='imagenet', arch='resnet50', parallel=False)\n",
@@ -170,11 +223,11 @@
     "\n",
     "\n",
     "class TruncatedSVD(nn.Module):\n",
-    "    def __init__(self, replaced_gemm, gemm_weights):\n",
-    "        super(TruncatedSVD,self).__init__()\n",
+    "    def __init__(self, replaced_gemm, gemm_weights, preserve_ratio):\n",
+    "        super().__init__()\n",
     "        self.replaced_gemm = replaced_gemm\n",
     "        print(\"W = {}\".format(gemm_weights.shape))\n",
-    "        self.U, self.SV = truncated_svd(gemm_weights.data, int(0.4 * gemm_weights.size(0)))\n",
+    "        self.U, self.SV = truncated_svd(gemm_weights.data, int(preserve_ratio * gemm_weights.size(0)))\n",
     "        print(\"U = {}\".format(self.U.shape))\n",
     "        \n",
     "        self.fc_u = nn.Linear(self.U.size(1), self.U.size(0)).cuda()\n",
@@ -182,20 +235,21 @@
     "        \n",
     "        print(\"SV = {}\".format(self.SV.shape))\n",
     "        self.fc_sv = nn.Linear(self.SV.size(1), self.SV.size(0)).cuda()\n",
-    "        self.fc_sv.weight.data = self.SV#.t()\n",
-    "        \n",
+    "        self.fc_sv.weight.data = self.SV#.t()        \n",
     "\n",
     "    def forward(self, x):\n",
     "        x = self.fc_sv.forward(x)\n",
     "        x = self.fc_u.forward(x)\n",
     "        return x\n",
     "\n",
+    "    \n",
     "def replace(model):\n",
     "    fc_weights = model.state_dict()['fc.weight']\n",
     "    fc_layer = model.fc\n",
     "    print(\"fc_layer({}, {})\".format(fc_layer.in_features, fc_layer.out_features))\n",
-    "    model.fc = TruncatedSVD(fc_layer, fc_weights)\n",
+    "    model.fc = TruncatedSVD(fc_layer, fc_weights, 0.4)\n",
     "\n",
+    "    \n",
     "from copy import deepcopy\n",
     "resnet50 = deepcopy(resnet50)\n",
     "replace(resnet50)"
@@ -203,18 +257,34 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 14,
    "metadata": {
     "scrolled": false
    },
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "progress: 51200 images\n",
+      "progress: 102400 images\n",
+      "progress: 153600 images\n",
+      "progress: 204800 images\n",
+      "progress: 256000 images\n",
+      "progress: 307200 images\n",
+      "progress: 358400 images\n",
+      "Top1: 75.70   Top5: 92.76\n",
+      "Duration:  168.111492395401\n"
+     ]
+    }
+   ],
    "source": [
     "# Standard loop to test the accuracy of a model.\n",
     "\n",
     "import time\n",
     "import torchnet.meter as tnt\n",
     "t0 = time.time()\n",
-    "test_loader = imagenet_load_data(\"../../data.imagenet\", \n",
+    "test_loader = imagenet_load_data(\"/datasets/imagenet\", \n",
     "                                 batch_size=64, \n",
     "                                 num_workers=4,\n",
     "                                 shuffle=False)\n",
@@ -234,6 +304,13 @@
     "t2 = time.time()\n",
     "print(\"Duration: \", t2-t0)"
    ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
   }
  ],
  "metadata": {
diff --git a/licenses/py_faster_rcnn-license.txt b/licenses/py_faster_rcnn-license.txt
new file mode 100755
index 0000000..1ab42b2
--- /dev/null
+++ b/licenses/py_faster_rcnn-license.txt
@@ -0,0 +1,81 @@
+Faster R-CNN
+
+The MIT License (MIT)
+
+Copyright (c) 2015 Microsoft Corporation
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+
+************************************************************************
+
+THIRD-PARTY SOFTWARE NOTICES AND INFORMATION
+
+This project, Faster R-CNN, incorporates material from the project(s)
+listed below (collectively, "Third Party Code").  Microsoft is not the
+original author of the Third Party Code.  The original copyright notice
+and license under which Microsoft received such Third Party Code are set
+out below. This Third Party Code is licensed to you under their original
+license terms set forth below.  Microsoft reserves all other rights not
+expressly granted, whether by implication, estoppel or otherwise.
+
+1.	Caffe, (https://github.com/BVLC/caffe/)
+
+COPYRIGHT
+
+All contributions by the University of California:
+Copyright (c) 2014, 2015, The Regents of the University of California (Regents)
+All rights reserved.
+
+All other contributions:
+Copyright (c) 2014, 2015, the respective contributors
+All rights reserved.
+
+Caffe uses a shared copyright model: each contributor holds copyright
+over their contributions to Caffe. The project versioning records all
+such contribution and copyright details. If a contributor wants to
+further mark their specific copyright on a particular contribution,
+they should indicate their copyright solely in the commit message of
+the change when it is committed.
+
+The BSD 2-Clause License
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions
+are met:
+
+1. Redistributions of source code must retain the above copyright notice,
+this list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright
+notice, this list of conditions and the following disclaimer in the
+documentation and/or other materials provided with the distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
+TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+************END OF THIRD-PARTY SOFTWARE NOTICES AND INFORMATION**********
-- 
GitLab