Commit 0f5c82f8 authored by Neta Zmora's avatar Neta Zmora

Truncated SVD: update notebook documentation and add customized module

Updated documentation per issue #359
parent 61fab0d0
#
# Copyright (c) 2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Truncated-SVD module.
For an example of how truncated-SVD can be used, see this Jupyter notebook:
https://github.com/NervanaSystems/distiller/blob/master/jupyter/truncated_svd.ipynb
"""
import torch
import torch.nn as nn


def truncated_svd(W, l):
    """Compress the weight matrix W of an inner product (fully connected) layer using truncated SVD.

    For the original implementation (MIT license), see Faster-RCNN:
    https://github.com/rbgirshick/py-faster-rcnn/blob/master/tools/compress_net.py
    We replaced numpy operations with pytorch operations (so that we can leverage the GPU).

    Arguments:
        W: N x M weights matrix
        l: number of singular values to retain
    Returns:
        Ul, SV: matrices such that W ~= Ul * SV
    """
    U, s, V = torch.svd(W, some=True)
    Ul = U[:, :l]
    sl = s[:l]
    Vl = V.t()[:l, :]
    SV = torch.mm(torch.diag(sl), Vl)
    return Ul, SV
class TruncatedSVD(nn.Module):
    def __init__(self, replaced_gemm, gemm_weights, preserve_ratio):
        super().__init__()
        self.replaced_gemm = replaced_gemm
        print("W = {}".format(gemm_weights.shape))
        self.U, self.SV = truncated_svd(gemm_weights.data, int(preserve_ratio * gemm_weights.size(0)))
        print("U = {}".format(self.U.shape))
        self.fc_u = nn.Linear(self.U.size(1), self.U.size(0)).cuda()
        self.fc_u.weight.data = self.U
        print("SV = {}".format(self.SV.shape))
        self.fc_sv = nn.Linear(self.SV.size(1), self.SV.size(0)).cuda()
        self.fc_sv.weight.data = self.SV

    def forward(self, x):
        x = self.fc_sv(x)
        x = self.fc_u(x)
        return x
%% Cell type:markdown id: tags:
# Truncated SVD
This is a simple application of [Truncated SVD](http://langvillea.people.cofc.edu/DISSECTION-LAB/Emmie'sLSI-SVDModule/p5module.html), just to get a feel for what happens to the accuracy if we use Truncated SVD **w/o fine-tuning.**
We apply Truncated SVD on the linear layer found at the end of ResNet50, and run a test over the validation dataset to measure the impact on the classification accuracy.
Spoiler:
At k=800 (80%) we get
- Top1: 76.02
- Top5: 92.86
Total weights: 1000 * 800 + 800 * 2048 = 2,438,400 (vs. 1000 * 2048 = 2,048,000)
At k=700 (70%) we get
- Top1: 76.03
- Top5: 92.85
Total weights: 1000 * 700 + 700 * 2048 = 2,133,600 (vs. 2,048,000)
At k=600 (60%) we get
- Top1: 75.98
- Top5: 92.82
Total weights: 1000 * 600 + 600 * 2048 = 1,828,800 (vs. 2,048,000)
At k=500 (50%) we get
- Top1: 75.78
- Top5: 92.77
Total weights: 1000 * 500 + 500 * 2048 = 1,524,000 (vs. 2,048,000)
At k=400 (40%) we get
- Top1: 75.65
- Top5: 92.75
Total weights: 1000 * 400 + 400 * 2048 = 1,219,200 (vs. 2,048,000)
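The weight counts quoted above can be reproduced with a few lines of Python (m and n here are the fc layer's output and input features, as in the Details section below):

``` python
# Sanity-check the weight counts quoted above for ResNet50's final fc layer.
m, n = 1000, 2048                # fc out_features x in_features
original = m * n                 # dense W: 2,048,000 weights

for k in (800, 700, 600, 500, 400):
    tsvd = m * k + k * n         # U' (m x k) plus A = S'V'^T (k x n)
    print("k=%d: %d weights (vs. %d)" % (k, tsvd, original))
```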
## Details
[SVD (Singular Value Decomposition)](https://en.wikipedia.org/wiki/Singular_value_decomposition) is an exact factorization of a matrix, W (of shape m x n), to the form USV$^T$ (U is m x m, S is m x n, V$^T$ is n x n; V$^T$ is the transpose of V). Every matrix has an SVD.
A Linear (fully-connected) layer performs: y = Wx + b (or y = xW$^T$ + b)<br>
We can use SVD to refactor W to rewrite this as: y = (USV$^T$)x + b
So far, we haven’t done any compression, so let’s get to it using Truncated SVD.<br> TSVD provides an approximate decomposition of W with lower rank. We want an approximation of W that is “good enough” and that also accelerates the computation of Wx.
We choose some lower rank, k, such that k<m (preferably k<<m).<br>
TSVD is straightforward: keep the largest k singular values of S and discard the rest (truncate S).
After TSVD we have:
U’ is m x k, S’ is k x k, V’$^T$ is k x n.<br>
y ~ (U’S’V’$^T$)x + b<br>
y ~ (U’(S’V’$^T$))x + b<br>
We'll replace S’V’$^T$ with A, because we can pre-compute it once. A has shape k x n:<br>
y ~ (U’A)x + b<br>
y ~ U’(Ax) + b<br>
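As a quick numeric sketch of the chain above (using numpy instead of the notebook's pytorch, with a random matrix standing in for W): truncate the SVD, pre-compute A = S’V’$^T$ once, then evaluate U’(Ax):

``` python
import numpy as np

rng = np.random.default_rng(0)
m, n, k = 100, 200, 20
W = rng.standard_normal((m, n))   # stand-in for the fc weights
x = rng.standard_normal(n)

# Thin SVD: W = U @ diag(s) @ Vt, exactly
U, s, Vt = np.linalg.svd(W, full_matrices=False)

# Truncate to rank k; pre-compute A = S' V'^T once (shape k x n)
U_k = U[:, :k]
A = np.diag(s[:k]) @ Vt[:k, :]

y_full = W @ x                    # m*n multiplies
y_tsvd = U_k @ (A @ x)            # k*(m+n) multiplies
rel_err = np.linalg.norm(y_full - y_tsvd) / np.linalg.norm(y_full)
```

The relative error shrinks as k grows, and vanishes at k = min(m, n), where the SVD is exact.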
Let’s ignore the bias and calculate the number of parameters and FLOPs (floating point operations) for the original y:
- m * n weights coefficients<br>
- m * n FLOPs (for batch size = 1)<br>
After TSVD we have:
- mk + kn = k*(m+n) weights coefficients<br>
- kn + mk = k*(m+n) FLOPs (for batch size = 1)<br>
To actually compress the weights after TSVD, we want: m * n > k*(m+n)<br>
Let’s rewrite k in terms of m: k = tm<br>
m * n > tm*(m+n)<br>
n > t*(m+n)<br>
t < n / (m+n)<br>
This is the math, but for an actual performance increase, we should strive for m * n >> k*(m+n).
In the example notebook: m = 1000; n = 2048<br>
So at t = 2048/(1000+2048) ≈ 0.672 (that is, k ≈ 672) we are at equilibrium. When t < 0.672 (i.e. k is smaller than 672), the sum of the sizes of the weights of A and U’ is smaller than the size of W.
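The break-even point can be double-checked directly from m and n (assuming nothing beyond the notebook's dimensions):

``` python
m, n = 1000, 2048
t_eq = n / (m + n)                 # ~ 0.672
k_eq = round(t_eq * m)             # ~ 672

# Compression iff k*(m+n) < m*n, i.e. k below the break-even rank
assert (k_eq - 1) * (m + n) < m * n   # k=671: fewer weights than W
assert (k_eq + 1) * (m + n) > m * n   # k=673: more weights than W
```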
%% Cell type:code id: tags:
``` python
import torch
import torchvision
import torch.nn as nn
from torch.autograd import Variable
import scipy.stats as ss

# Relative import of code from distiller, w/o installing the package
import os
import sys
import numpy as np
import scipy
import matplotlib.pyplot as plt

module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

import distiller
import distiller.apputils as apputils
import distiller.models as models
from distiller.apputils import *
from torchvision import datasets, transforms

plt.style.use('seaborn')  # pretty matplotlib plots
```
%% Cell type:markdown id: tags:
## Utilities
%% Cell type:code id: tags:
``` python
def imagenet_load_data(data_dir, batch_size, num_workers, shuffle=True):
    """Load the ImageNet dataset"""
    test_dir = os.path.join(data_dir, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    test_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(test_dir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=batch_size, shuffle=shuffle,
        num_workers=num_workers, pin_memory=True)
    return test_loader
```
%% Cell type:code id: tags:
``` python
BATCH_SIZE = 4

# Data loader
test_loader = imagenet_load_data("/datasets/imagenet/",
                                 batch_size=BATCH_SIZE,
                                 num_workers=2)

# Reverse the normalization transformations we performed when we loaded the data
# for consumption by our CNN.
# See: https://discuss.pytorch.org/t/simple-way-to-inverse-transform-normalization/4821/3
invTrans = transforms.Compose([transforms.Normalize(mean=[0., 0., 0.],
                                                    std=[1/0.229, 1/0.224, 1/0.225]),
                               transforms.Normalize(mean=[-0.485, -0.456, -0.406],
                                                    std=[1., 1., 1.]),
                               ])
```
%% Cell type:markdown id: tags:
## Load ResNet 50
%% Cell type:code id: tags:
``` python
# Load the various models
resnet50 = models.create_model(pretrained=True, dataset='imagenet', arch='resnet50', parallel=False)

# See Faster-RCNN: https://github.com/rbgirshick/py-faster-rcnn/blob/master/tools/compress_net.py
# Replaced numpy operations with pytorch operations (so that we can leverage the GPU).
def truncated_svd(W, l):
    """Compress the weight matrix W of an inner product (fully connected) layer
    using truncated SVD.

    Parameters:
        W: N x M weights matrix
        l: number of singular values to retain
    Returns:
        Ul, SV: matrices such that W ~= Ul * SV
    """
    U, s, V = torch.svd(W, some=True)
    Ul = U[:, :l]
    sl = s[:l]
    Vl = V.t()[:l, :]
    SV = torch.mm(torch.diag(sl), Vl)
    return Ul, SV


class TruncatedSVD(nn.Module):
    def __init__(self, replaced_gemm, gemm_weights, preserve_ratio):
        super().__init__()
        self.replaced_gemm = replaced_gemm
        print("W = {}".format(gemm_weights.shape))
        self.U, self.SV = truncated_svd(gemm_weights.data, int(preserve_ratio * gemm_weights.size(0)))
        print("U = {}".format(self.U.shape))
        self.fc_u = nn.Linear(self.U.size(1), self.U.size(0)).cuda()
        self.fc_u.weight.data = self.U
        print("SV = {}".format(self.SV.shape))
        self.fc_sv = nn.Linear(self.SV.size(1), self.SV.size(0)).cuda()
        self.fc_sv.weight.data = self.SV

    def forward(self, x):
        x = self.fc_sv(x)
        x = self.fc_u(x)
        return x


def replace(model):
    fc_weights = model.state_dict()['fc.weight']
    fc_layer = model.fc
    print("fc_layer({}, {})".format(fc_layer.in_features, fc_layer.out_features))
    model.fc = TruncatedSVD(fc_layer, fc_weights, 0.4)

from copy import deepcopy
resnet50 = deepcopy(resnet50)
replace(resnet50)
```
%% Output
fc_layer(2048, 1000)
W = torch.Size([1000, 2048])
U = torch.Size([1000, 400])
SV = torch.Size([400, 2048])
%% Cell type:code id: tags:
``` python
# Standard loop to test the accuracy of a model.
import time
import torchnet.meter as tnt

t0 = time.time()
test_loader = imagenet_load_data("/datasets/imagenet",
                                 batch_size=64,
                                 num_workers=4,
                                 shuffle=False)
t1 = time.time()

classerr = tnt.ClassErrorMeter(accuracy=True, topk=(1, 5))
resnet50.eval()
for validation_step, (inputs, target) in enumerate(test_loader):
    with torch.no_grad():
        inputs, target = inputs.to('cuda'), target.to('cuda')
        outputs = resnet50(inputs)
        classerr.add(outputs.data, target)
    if (validation_step+1) % 100 == 0:
        print("progress: %d images" % ((validation_step+1) * 512))
print("Top1: %.2f Top5: %.2f" % (classerr.value(1), classerr.value(5)))
t2 = time.time()
print("Duration: ", t2-t0)
```
%% Output
progress: 51200 images
progress: 102400 images
progress: 153600 images
progress: 204800 images
progress: 256000 images
progress: 307200 images
progress: 358400 images
Top1: 75.70 Top5: 92.76
Duration: 168.111492395401
%% Cell type:code id: tags:
``` python
```
......
Faster R-CNN
The MIT License (MIT)
Copyright (c) 2015 Microsoft Corporation
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
************************************************************************
THIRD-PARTY SOFTWARE NOTICES AND INFORMATION
This project, Faster R-CNN, incorporates material from the project(s)
listed below (collectively, "Third Party Code"). Microsoft is not the
original author of the Third Party Code. The original copyright notice
and license under which Microsoft received such Third Party Code are set
out below. This Third Party Code is licensed to you under their original
license terms set forth below. Microsoft reserves all other rights not
expressly granted, whether by implication, estoppel or otherwise.
1. Caffe, (https://github.com/BVLC/caffe/)
COPYRIGHT
All contributions by the University of California:
Copyright (c) 2014, 2015, The Regents of the University of California (Regents)
All rights reserved.
All other contributions:
Copyright (c) 2014, 2015, the respective contributors
All rights reserved.
Caffe uses a shared copyright model: each contributor holds copyright
over their contributions to Caffe. The project versioning records all
such contribution and copyright details. If a contributor wants to
further mark their specific copyright on a particular contribution,
they should indicate their copyright solely in the commit message of
the change when it is committed.
The BSD 2-Clause License
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
1. Redistributions of source code must retain the above copyright notice,
this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
************END OF THIRD-PARTY SOFTWARE NOTICES AND INFORMATION**********