From 0f5c82f83976edbaa25fb167fab130fa42b05d0a Mon Sep 17 00:00:00 2001 From: Neta Zmora <neta.zmora@intel.com> Date: Sat, 24 Aug 2019 19:52:51 +0300 Subject: [PATCH] Truncated SVD: update notebook documentation and add customized module Updated documentation per issue wq#359 --- distiller/modules/tsvd.py | 67 +++++++++++++++++ jupyter/truncated_svd.ipynb | 109 ++++++++++++++++++++++++---- licenses/py_faster_rcnn-license.txt | 81 +++++++++++++++++++++ 3 files changed, 241 insertions(+), 16 deletions(-) create mode 100755 distiller/modules/tsvd.py create mode 100755 licenses/py_faster_rcnn-license.txt diff --git a/distiller/modules/tsvd.py b/distiller/modules/tsvd.py new file mode 100755 index 0000000..aea20bc --- /dev/null +++ b/distiller/modules/tsvd.py @@ -0,0 +1,67 @@ +# +# Copyright (c) 2019 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Truncated-SVD module. + +For an example of how truncated-SVD can be used, see this Jupyter notebook: +https://github.com/NervanaSystems/distiller/blob/master/jupyter/truncated_svd.ipynb + +""" + +def truncated_svd(W, l): + """Compress the weight matrix W of an inner product (fully connected) layer using truncated SVD. + + For the original implementation (MIT license), see Faster-RCNN: + https://github.com/rbgirshick/py-faster-rcnn/blob/master/tools/compress_net.py + We replaced numpy operations with pytorch operations (so that we can leverage the GPU). 
+ + Arguments: + W: N x M weights matrix + l: number of singular values to retain + Returns: + Ul, L: matrices such that W \approx Ul*L + """ + + U, s, V = torch.svd(W, some=True) + + Ul = U[:, :l] + sl = s[:l] + V = V.t() + Vl = V[:l, :] + + SV = torch.mm(torch.diag(sl), Vl) + return Ul, SV + + +class TruncatedSVD(nn.Module): + def __init__(self, replaced_gemm, gemm_weights, preserve_ratio): + super().__init__() + self.replaced_gemm = replaced_gemm + print("W = {}".format(gemm_weights.shape)) + self.U, self.SV = truncated_svd(gemm_weights.data, int(preserve_ratio * gemm_weights.size(0))) + print("U = {}".format(self.U.shape)) + + self.fc_u = nn.Linear(self.U.size(1), self.U.size(0)).cuda() + self.fc_u.weight.data = self.U + + print("SV = {}".format(self.SV.shape)) + self.fc_sv = nn.Linear(self.SV.size(1), self.SV.size(0)).cuda() + self.fc_sv.weight.data = self.SV#.t() + + def forward(self, x): + x = self.fc_sv.forward(x) + x = self.fc_u.forward(x) + return x diff --git a/jupyter/truncated_svd.ipynb b/jupyter/truncated_svd.ipynb index 7331d3c..07f41d6 100644 --- a/jupyter/truncated_svd.ipynb +++ b/jupyter/truncated_svd.ipynb @@ -40,12 +40,54 @@ " - Top1: 75.65 \n", " - Top5: 92.75\n", " \n", - " Total weights: 1000 * 400 + 400 * 2048 = 1,219,200 (vs. 2,048,000) " + " Total weights: 1000 * 400 + 400 * 2048 = 1,219,200 (vs. 2,048,000)\n", + " \n", + "## Details\n", + "\n", + "[SVD (Singular Value Decomposition)](https://en.wikipedia.org/wiki/Singular_value_decomposition) is an exact factorization of a matrix, W (of shape m x n), to the form USV$^T$ (U is m x m, S is m x n, V$^T$ is n x n; V$^T$ is the transpose of V). 
Every matrix has an SVD.\n", + "\n", + "A Linear (fully-connected) layer performs: y = Wx + b (or y = xW$^T$ + b)<br>\n", + "We can use SVD to refactor W to rewrite this as: y = (USV$^T$)x + b\n", + "\n", + "So far, we haven’t done any compression, so let’s get to it using Truncated SVD.<br> TSVD is a method to provide an approximated decomposition of W, in which S has a lower rank. We want to find an approximation of W that is “good enough” and also accelerates the computation of Wx.\n", + "\n", + "We choose some lower-rank, k, such that k<m (preferably k<<m).<br>\n", + "TSVD is straight-forward: keep the largest k singular values of S and discard the rest (truncate S).\n", + "\n", + "After TSVD we have:\n", + "U’ is m x k, S’ is k x k, V’$^T$ is k x n.<br>\n", + "y ~ (U’S’V’$^T$)x + b<br>\n", + "y ~ (U’(S’V’$^T$))x + b<br>\n", + "\n", + "We'll replace S’V’$^T$ with A, because we can pre-compute it once. A has shape k x n:<br>\n", + "y ~ (U’A)x + b<br>\n", + "y ~ U’(Ax) + b<br>\n", + "\n", + "Let’s ignore the bias and calculate the number of parameters and FLOPs (floating point operations) for the original y:\n", + "\n", + " - m * n weights coefficients<br>\n", + " - m * n FLOPs (for batch size = 1)<br>\n", + "\n", + "After TSVD we have:\n", + "\n", + " - mk + kn = k*(m+n) weights coefficients<br>\n", + " - kn + mk = k*(m+n) FLOPs (for batch size = 1)<br>\n", + "\n", + "To actually compress the weights after TSVD, we want: m * n > k*(m+n)<br>\n", + "Let’s rewrite k in terms of m: k = tm<br>\n", + "m * n > tm*(m+n)<br>\n", + "n > t*(m+n)<br>\n", + "n / (m+n) > t<br>\n", + "\n", + "This is the math, but for an actual performance increase, we should strive for m * n >> k*(m+n)\n", + "\n", + "In the example notebook: m = 1000; n=2048<br>\n", + "So when t=2048/(1000+2048) (that is, k=2048/3048*1000=672), we have equilibrium. When 0.672>t (i.e. 
k is smaller than 672), the sum of the size of the weights of A and U’ is smaller than the size of W.\n" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -80,7 +122,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -105,14 +147,14 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "BATCH_SIZE = 4\n", "\n", "# Data loader\n", - "test_loader = imagenet_load_data(\"../../data.imagenet/\", \n", + "test_loader = imagenet_load_data(\"/datasets/imagenet/\", \n", " batch_size=BATCH_SIZE, \n", " num_workers=2)\n", " \n", @@ -136,11 +178,22 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": { "scrolled": false }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "fc_layer(2048, 1000)\n", + "W = torch.Size([1000, 2048])\n", + "U = torch.Size([1000, 400])\n", + "SV = torch.Size([400, 2048])\n" + ] + } + ], "source": [ "# Load the various models\n", "resnet50 = models.create_model(pretrained=True, dataset='imagenet', arch='resnet50', parallel=False)\n", @@ -170,11 +223,11 @@ "\n", "\n", "class TruncatedSVD(nn.Module):\n", - " def __init__(self, replaced_gemm, gemm_weights):\n", - " super(TruncatedSVD,self).__init__()\n", + " def __init__(self, replaced_gemm, gemm_weights, preserve_ratio):\n", + " super().__init__()\n", " self.replaced_gemm = replaced_gemm\n", " print(\"W = {}\".format(gemm_weights.shape))\n", - " self.U, self.SV = truncated_svd(gemm_weights.data, int(0.4 * gemm_weights.size(0)))\n", + " self.U, self.SV = truncated_svd(gemm_weights.data, int(preserve_ratio * gemm_weights.size(0)))\n", " print(\"U = {}\".format(self.U.shape))\n", " \n", " self.fc_u = nn.Linear(self.U.size(1), self.U.size(0)).cuda()\n", @@ -182,20 +235,21 @@ " \n", " print(\"SV = 
{}\".format(self.SV.shape))\n", " self.fc_sv = nn.Linear(self.SV.size(1), self.SV.size(0)).cuda()\n", - " self.fc_sv.weight.data = self.SV#.t()\n", - " \n", + " self.fc_sv.weight.data = self.SV#.t() \n", "\n", " def forward(self, x):\n", " x = self.fc_sv.forward(x)\n", " x = self.fc_u.forward(x)\n", " return x\n", "\n", + " \n", "def replace(model):\n", " fc_weights = model.state_dict()['fc.weight']\n", " fc_layer = model.fc\n", " print(\"fc_layer({}, {})\".format(fc_layer.in_features, fc_layer.out_features))\n", - " model.fc = TruncatedSVD(fc_layer, fc_weights)\n", + " model.fc = TruncatedSVD(fc_layer, fc_weights, 0.4)\n", "\n", + " \n", "from copy import deepcopy\n", "resnet50 = deepcopy(resnet50)\n", "replace(resnet50)" @@ -203,18 +257,34 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "metadata": { "scrolled": false }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "progress: 51200 images\n", + "progress: 102400 images\n", + "progress: 153600 images\n", + "progress: 204800 images\n", + "progress: 256000 images\n", + "progress: 307200 images\n", + "progress: 358400 images\n", + "Top1: 75.70 Top5: 92.76\n", + "Duration: 168.111492395401\n" + ] + } + ], "source": [ "# Standard loop to test the accuracy of a model.\n", "\n", "import time\n", "import torchnet.meter as tnt\n", "t0 = time.time()\n", - "test_loader = imagenet_load_data(\"../../data.imagenet\", \n", + "test_loader = imagenet_load_data(\"/datasets/imagenet\", \n", " batch_size=64, \n", " num_workers=4,\n", " shuffle=False)\n", @@ -234,6 +304,13 @@ "t2 = time.time()\n", "print(\"Duration: \", t2-t0)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/licenses/py_faster_rcnn-license.txt b/licenses/py_faster_rcnn-license.txt new file mode 100755 index 0000000..1ab42b2 --- /dev/null +++ b/licenses/py_faster_rcnn-license.txt @@ -0,0 
+1,81 @@ +Faster R-CNN + +The MIT License (MIT) + +Copyright (c) 2015 Microsoft Corporation + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. + +************************************************************************ + +THIRD-PARTY SOFTWARE NOTICES AND INFORMATION + +This project, Faster R-CNN, incorporates material from the project(s) +listed below (collectively, "Third Party Code"). Microsoft is not the +original author of the Third Party Code. The original copyright notice +and license under which Microsoft received such Third Party Code are set +out below. This Third Party Code is licensed to you under their original +license terms set forth below. Microsoft reserves all other rights not +expressly granted, whether by implication, estoppel or otherwise. + +1. Caffe, (https://github.com/BVLC/caffe/) + +COPYRIGHT + +All contributions by the University of California: +Copyright (c) 2014, 2015, The Regents of the University of California (Regents) +All rights reserved. 
+ +All other contributions: +Copyright (c) 2014, 2015, the respective contributors +All rights reserved. + +Caffe uses a shared copyright model: each contributor holds copyright +over their contributions to Caffe. The project versioning records all +such contribution and copyright details. If a contributor wants to +further mark their specific copyright on a particular contribution, +they should indicate their copyright solely in the commit message of +the change when it is committed. + +The BSD 2-Clause License + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + +1. Redistributions of source code must retain the above copyright notice, +this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in the +documentation and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED +TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +************END OF THIRD-PARTY SOFTWARE NOTICES AND INFORMATION********** -- GitLab