Skip to content
Snippets Groups Projects
Commit 2daf616f authored by Neta Zmora's avatar Neta Zmora
Browse files

truncated_svd.ipynb: replace numpy with pytorch

Replaced numpy operations with pytorch operations
(so that we can leverage the GPU).
parent ba1ee25b
No related branches found
No related tags found
No related merge requests found
%% Cell type:markdown id: tags:
# Truncated SVD
This is a simple application of [Truncated SVD](http://langvillea.people.cofc.edu/DISSECTION-LAB/Emmie'sLSI-SVDModule/p5module.html), just to get a feel for what happens to the accuracy when we apply Truncated SVD **without fine-tuning**.
We apply Truncated SVD on the linear layer found at the end of ResNet50, and run a test over the validation dataset to measure the impact on the classification accuracy.
Spoiler:
At k=800 (80%) we get
- Top1: 76.02
- Top5: 92.86
Total weights: 1000 * 800 + 800 * 2048 = 2,438,400 (vs. 1000 * 2048 = 2,048,000, so at this rank the factorized layer is actually larger than the original)
At k=700 (70%) we get
- Top1: 76.03
- Top5: 92.85
Total weights: 1000 * 700 + 700 * 2048 = 2,133,600 (vs. 2,048,000, so still slightly larger than the original)
At k=600 (60%) we get
- Top1: 75.98
- Top5: 92.82
Total weights: 1000 * 600 + 600 * 2048 = 1,828,800 (vs. 2,048,000)
At k=500 (50%) we get
- Top1: 75.78
- Top5: 92.77
Total weights: 1000 * 500 + 500 * 2048 = 1,524,000 (vs. 2,048,000)
At k=400 (40%) we get
- Top1: 75.65
- Top5: 92.75
Total weights: 1000 * 400 + 400 * 2048 = 1,219,200 (vs. 2,048,000)
%% Cell type:code id: tags:
``` python
import torch
import torchvision
import torch.nn as nn
from torch.autograd import Variable
import scipy.stats as ss
# Relative import of code from distiller, w/o installing the package
import os
import sys
import numpy as np
import scipy
import matplotlib.pyplot as plt
# NOTE(review): module_path is computed here, but nothing in this cell adds it to
# sys.path, so the distiller imports below rely on the package already being importable.
module_path = os.path.abspath(os.path.join('..'))
import distiller
import distiller.apputils as apputils
import distiller.models as models
# NOTE(review): this star import is presumably what supplies the `transforms` and
# `datasets` names used by later cells -- confirm against distiller.apputils.
from distiller.apputils import *
plt.style.use('seaborn') # pretty matplotlib plots
```
%% Cell type:markdown id: tags:
## Utilities
%% Cell type:code id: tags:
``` python
def imagenet_load_data(data_dir, batch_size, num_workers, shuffle=True):
    """Build a DataLoader over the images found in ``<data_dir>/val``.

    Every image is resized to 256, center-cropped to 224, converted to a
    tensor and normalized with the per-channel mean/std listed below.

    Parameters:
        data_dir: directory that contains the 'val' sub-directory
        batch_size: number of samples per batch
        num_workers: number of loader worker processes
        shuffle: whether the loader shuffles the samples (default: True)
    Returns:
        a torch.utils.data.DataLoader (created with pin_memory=True)
    """
    # NOTE(review): `transforms` and `datasets` are not imported explicitly in the
    # visible code; presumably they come from `from distiller.apputils import *`.
    preprocessing = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    validation_set = datasets.ImageFolder(os.path.join(data_dir, 'val'),
                                          preprocessing)
    return torch.utils.data.DataLoader(
        validation_set,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
    )
```
%% Cell type:code id: tags:
``` python
# Small-batch loader over the validation images (default shuffling).
BATCH_SIZE = 4
test_loader = imagenet_load_data(
    "../../data.imagenet/",
    batch_size=BATCH_SIZE,
    num_workers=2,
)

# Reverse the normalization transformations we performed when we loaded the data
# for consumption by our CNN: first multiply by the std, then add the mean back.
# See: https://discuss.pytorch.org/t/simple-way-to-inverse-transform-normalization/4821/3
_undo_std = transforms.Normalize(mean=[0., 0., 0.],
                                 std=[1/0.229, 1/0.224, 1/0.225])
_undo_mean = transforms.Normalize(mean=[-0.485, -0.456, -0.406],
                                  std=[1., 1., 1.])
invTrans = transforms.Compose([_undo_std, _undo_mean])
```
%% Cell type:markdown id: tags:
## Load ResNet 50
%% Cell type:code id: tags:
``` python
# Build the pretrained ImageNet ResNet50 that the rest of this notebook works on.
resnet50 = models.create_model(
    arch='resnet50',
    dataset='imagenet',
    pretrained=True,
    parallel=False,
)
# See Faster-RCNN: https://github.com/rbgirshick/py-faster-rcnn/blob/master/tools/compress_net.py
# The numpy operations of that reference are replaced with pytorch operations
# (so that we can leverage the GPU).
def truncated_svd(W, l):
    """Compress the weight matrix W of an inner product (fully connected) layer
    using truncated SVD.

    The whole computation is done with torch operations, so it runs on
    whatever device W lives on and returns tensors on that same device.

    Parameters:
        W: N x M weights matrix (2-D torch tensor)
        l: number of singular values to retain (positive integer); values
           larger than min(N, M) simply keep every singular value
    Returns:
        Ul, SV: matrices of shape N x l and l x M such that W ~= Ul @ SV
    Raises:
        ValueError: if W is not 2-D or l is smaller than 1
    """
    if W.dim() != 2:
        raise ValueError("W must be a 2-D matrix, got {} dimension(s)".format(W.dim()))
    if l < 1:
        raise ValueError("l must be at least 1, got {}".format(l))

    # Prefer torch.linalg.svd where available (torch.svd is deprecated in newer
    # releases); fall back to torch.svd on old PyTorch versions.
    if hasattr(torch, "linalg") and hasattr(torch.linalg, "svd"):
        U, s, Vh = torch.linalg.svd(W, full_matrices=False)
    else:
        U, s, V = torch.svd(W, some=True)
        Vh = V.t()

    Ul = U[:, :l]
    sl = s[:l]
    Vl = Vh[:l, :]
    # Scaling row i of Vl by sl[i] equals diag(sl) @ Vl without building the
    # l x l diagonal matrix.
    SV = sl.unsqueeze(1) * Vl
    return Ul, SV
class TruncatedSVD(nn.Module):
    """Replacement for a fully-connected layer, factorized with truncated SVD.

    The weight matrix W is approximated as U @ SV (see truncated_svd) and the
    single layer is replaced by two stacked linear layers:

        x -> fc_sv (weight SV, no bias) -> fc_u (weight U, original bias)

    so the module computes U @ (SV @ x) + b ~= W @ x + b.

    Parameters:
        replaced_gemm: the layer being replaced; kept as an attribute, and its
            bias (if it has one) is carried over to the last linear layer
        gemm_weights: weight matrix of the replaced layer (out x in)
        ratio: fraction of gemm_weights.size(0) used as the number of retained
            singular values (default 0.4)
    Raises:
        ValueError: if ratio is not in the interval (0, 1]
    """

    def __init__(self, replaced_gemm, gemm_weights, ratio=0.4):
        super(TruncatedSVD, self).__init__()
        if not 0 < ratio <= 1:
            raise ValueError("ratio must be in (0, 1], got {}".format(ratio))
        # Assigning a module here registers it as a sub-module, so the original
        # layer (and its parameters) stays reachable through this object.
        self.replaced_gemm = replaced_gemm
        print("W = {}".format(gemm_weights.shape))

        weights = gemm_weights.detach()
        # Factorize once, on the device the weights already live on.
        rank = max(1, int(ratio * weights.size(0)))
        self.U, self.SV = truncated_svd(weights, rank)
        print("U = {}".format(self.U.shape))
        print("SV = {}".format(self.SV.shape))

        bias = getattr(replaced_gemm, "bias", None)
        # Only the last layer carries a bias: a bias on fc_sv, or a randomly
        # initialized bias on fc_u, would make the output deviate from W @ x + b.
        self.fc_sv = nn.Linear(self.SV.size(1), self.SV.size(0), bias=False)
        self.fc_u = nn.Linear(self.U.size(1), self.U.size(0), bias=bias is not None)
        self.fc_sv.to(device=weights.device, dtype=weights.dtype)
        self.fc_u.to(device=weights.device, dtype=weights.dtype)
        with torch.no_grad():
            self.fc_sv.weight.copy_(self.SV)
            self.fc_u.weight.copy_(self.U)
            if bias is not None:
                self.fc_u.bias.copy_(bias.detach())

    def forward(self, x):
        """Apply the two factor layers in sequence: fc_sv first, then fc_u."""
        # Call the modules (not .forward) so that registered hooks still run.
        return self.fc_u(self.fc_sv(x))
def replace(model):
    """Swap ``model.fc`` in place for a TruncatedSVD module.

    The new module is built from the current ``fc`` layer and the 'fc.weight'
    entry of the model's state dict. The layer's in/out feature counts are
    printed first. Nothing is returned; ``model`` is modified.
    """
    original_fc = model.fc
    weights = model.state_dict()['fc.weight']
    print("fc_layer({}, {})".format(original_fc.in_features,
                                    original_fc.out_features))
    model.fc = TruncatedSVD(original_fc, weights)
# Deep-copy the model before modifying it; note that the name `resnet50` is
# rebound to the copy, so the unmodified model is no longer reachable through it.
from copy import deepcopy
resnet50 = deepcopy(resnet50)
# Swap the copy's final fully-connected layer (model.fc) for its truncated-SVD version.
replace(resnet50)
```
%% Cell type:code id: tags:
``` python
# Standard loop to test the accuracy of a model.
import time
import torchnet.meter as tnt

EVAL_BATCH_SIZE = 64

t0 = time.time()
test_loader = imagenet_load_data("../../data.imagenet",
                                 batch_size=EVAL_BATCH_SIZE,
                                 num_workers=4,
                                 shuffle=False)
t1 = time.time()

classerr = tnt.ClassErrorMeter(accuracy=True, topk=(1, 5))
resnet50.eval()
# Feed the data to whatever device the model's parameters live on.
device = next(resnet50.parameters()).device
# Count the images actually processed, so the progress report stays correct
# for any batch size and for a smaller final batch.
images_seen = 0
with torch.no_grad():
    for validation_step, (inputs, target) in enumerate(test_loader):
        inputs, target = inputs.to(device), target.to(device)
        outputs = resnet50(inputs)
        classerr.add(outputs, target)
        images_seen += inputs.size(0)
        if (validation_step + 1) % 100 == 0:
            print("progress: %d images" % images_seen)

print("Top1: %.2f Top5: %.2f" % (classerr.value(1), classerr.value(5)))
t2 = time.time()
```
%% Cell type:code id: tags:
``` python
# Seconds elapsed between the t0 and t2 timestamps taken in the evaluation cell.
elapsed_seconds = t2 - t0
print("Duration: ", elapsed_seconds)
```
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment