Skip to content
Snippets Groups Projects
Commit 91f845a4 authored by Neta Zmora's avatar Neta Zmora
Browse files

Jupyter: straight-forward application of Truncated SVD on ResNet50

This is a simple application of Truncated SVD, just to get a feeling of
what happens to the accuracy if we use TruncatedSVD w/o fine-tuning.

We apply Truncated SVD on the linear layer found at the end of ResNet50,
and run a test over the validation dataset to measure the impact on the
classification accuracy.
parent 95546cd1
No related branches found
No related tags found
No related merge requests found
%% Cell type:markdown id: tags:
# Truncated SVD
This is a simple application of [Truncated SVD](http://langvillea.people.cofc.edu/DISSECTION-LAB/Emmie'sLSI-SVDModule/p5module.html), just to get a feeling of what happens to the accuracy if we use TruncatedSVD **w/o fine-tuning.**
We apply Truncated SVD on the linear layer found at the end of ResNet50, and run a test over the validation dataset to measure the impact on the classification accuracy.
Spoiler:
At k=800 (80%) we get
- Top1: 76.02
- Top5: 92.86
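Total weights: 1000 * 800 + 800 * 2048 = 2,438,400 (vs. 1000 * 2048 = 2,048,000 for the original layer, i.e. at this rank the factorized layer is actually larger than the original)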
At k=700 (70%) we get
- Top1: 76.03
- Top5: 92.85
Total weights: 1000 * 700 + 700 * 2048 = 2,133,600 (vs. 2,048,000)
At k=600 (60%) we get
- Top1: 75.98
- Top5: 92.82
Total weights: 1000 * 600 + 600 * 2048 = 1,828,800 (vs. 2,048,000)
At k=500 (50%) we get
- Top1: 75.78
- Top5: 92.77
Total weights: 1000 * 500 + 500 * 2048 = 1,524,000 (vs. 2,048,000)
At k=400 (40%) we get
- Top1: 75.65
- Top5: 92.75
Total weights: 1000 * 400 + 400 * 2048 = 1,219,200 (vs. 2,048,000)
%% Cell type:code id: tags:
``` python
import torch
import torchvision
import torch.nn as nn
from torch.autograd import Variable
import scipy.stats as ss
# Relative import of code from distiller, w/o installing the package
import os
import sys
import numpy as np
import scipy
import matplotlib.pyplot as plt
# Put the parent directory on the import path (once) so that the packages
# imported below can be found without being installed.
module_path = os.path.abspath(os.path.join('..'))
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)
import distiller
import apputils
import models
from apputils import *
plt.style.use('seaborn') # pretty matplotlib plots
```
%% Cell type:markdown id: tags:
## Utilities
%% Cell type:code id: tags:
``` python
def imagenet_load_data(data_dir, batch_size, num_workers, shuffle=True):
    """Create a DataLoader over the 'val' sub-directory of `data_dir`.

    Each image is resized to 256, center-cropped to 224, converted to a
    tensor and normalized with the per-channel mean/std given below.

    Parameters:
        data_dir: directory that contains the 'val' image folder
        batch_size: number of samples per batch
        num_workers: number of loader worker processes
        shuffle: whether the loader shuffles the samples

    Returns:
        a torch.utils.data.DataLoader (created with pin_memory=True)
    """
    preprocessing = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    validation_set = datasets.ImageFolder(os.path.join(data_dir, 'val'),
                                          preprocessing)
    return torch.utils.data.DataLoader(validation_set,
                                       batch_size=batch_size,
                                       shuffle=shuffle,
                                       num_workers=num_workers,
                                       pin_memory=True)
```
%% Cell type:code id: tags:
``` python
BATCH_SIZE = 4

# Data loader
test_loader = imagenet_load_data("../../data.imagenet",
                                 batch_size=BATCH_SIZE,
                                 num_workers=2)

# Undo the normalization that was applied when the data was loaded for
# consumption by our CNN: first multiply by the std (Normalize divides by its
# `std` argument, hence the reciprocals), then add the mean back.
# See: https://discuss.pytorch.org/t/simple-way-to-inverse-transform-normalization/4821/3
_undo_scale = transforms.Normalize(mean=[0., 0., 0.],
                                   std=[1/0.229, 1/0.224, 1/0.225])
_undo_shift = transforms.Normalize(mean=[-0.485, -0.456, -0.406],
                                   std=[1., 1., 1.])
invTrans = transforms.Compose([_undo_scale, _undo_shift])
```
%% Cell type:markdown id: tags:
## Load ResNet 50
%% Cell type:code id: tags:
``` python
# Create the pretrained ImageNet ResNet50 that the rest of the notebook
# factorizes and evaluates (parallel=False is passed through to create_model).
resnet50 = models.create_model(
    arch='resnet50',
    dataset='imagenet',
    pretrained=True,
    parallel=False,
)
# See Faster-RCNN: https://github.com/rbgirshick/py-faster-rcnn/blob/master/tools/compress_net.py
def truncated_svd(W, l):
    r"""Compress the weight matrix W of an inner product (fully connected) layer
    using truncated SVD.

    Parameters:
        W: N x M weights matrix (a torch tensor or anything NumPy can convert
           to an array)
        l: number of singular values to retain; must be >= 1. Values larger
           than min(N, M) simply retain every singular value.

    Returns:
        Ul, SV: torch tensors of shape (N, l) and (l, M) such that
        W \approx Ul * SV.  When W is a torch tensor the results live on W's
        device; otherwise they are placed on the GPU if one is available and
        on the CPU if not.

    Raises:
        ValueError: if l is smaller than 1.
    """
    if l < 1:
        # A zero or negative l would silently produce empty or mis-sliced factors.
        raise ValueError("l must be at least 1, got {}".format(l))

    if isinstance(W, torch.Tensor):
        device = W.device
        # NumPy cannot consume CUDA tensors or tensors that require grad.
        W = W.detach().cpu().numpy()
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        W = np.asarray(W)

    U, s, V = np.linalg.svd(W, full_matrices=False)
    Ul = np.ascontiguousarray(U[:, :l])
    # Scaling the rows of V by the singular values equals diag(s) @ V without
    # materializing the dense diagonal matrix.
    SV = s[:l, np.newaxis] * V[:l, :]
    return torch.from_numpy(Ul).to(device), torch.from_numpy(SV).to(device)
class TruncatedSVD(nn.Module):
    """Replacement for a fully connected layer, factorized with truncated SVD.

    The weight matrix W (out_features x in_features) is approximated by
    U * SV and applied as two consecutive linear layers: `fc_sv`
    (in_features -> rank, no bias) followed by `fc_u` (rank -> out_features).
    If the replaced layer has a bias it is copied onto `fc_u`, so that the
    module approximates the complete affine map of the original layer.

    Parameters:
        replaced_gemm: the layer being replaced; kept as an attribute.
        gemm_weights: torch tensor holding the weights of `replaced_gemm`
        ratio: fraction of gemm_weights.size(0) used as the number of
               retained singular values (default 0.4)

    Raises:
        ValueError: if `ratio` yields fewer than one singular value.
    """

    def __init__(self, replaced_gemm, gemm_weights, ratio=0.4):
        super(TruncatedSVD, self).__init__()
        # NOTE: assigning a module here registers it as a sub-module, so the
        # original weights remain part of this module's parameters.
        self.replaced_gemm = replaced_gemm
        print("W = {}".format(gemm_weights.shape))

        rank = int(ratio * gemm_weights.size(0))
        if rank < 1:
            raise ValueError("ratio {} retains no singular values".format(ratio))

        # Factorize on the CPU (NumPy cannot read CUDA tensors), then keep the
        # factors on the same device as the weights they were derived from.
        device = gemm_weights.device
        U, SV = truncated_svd(gemm_weights.detach().cpu(), rank)
        self.U, self.SV = U.to(device), SV.to(device)

        bias = getattr(replaced_gemm, 'bias', None)

        print("U = {}".format(self.U.shape))
        self.fc_u = nn.Linear(self.U.size(1), self.U.size(0),
                              bias=bias is not None).to(device)
        print("SV = {}".format(self.SV.shape))
        # No bias on the inner layer: a default nn.Linear would add a randomly
        # initialized bias that is not part of the factorization.
        self.fc_sv = nn.Linear(self.SV.size(1), self.SV.size(0),
                               bias=False).to(device)

        with torch.no_grad():
            self.fc_u.weight.copy_(self.U)
            self.fc_sv.weight.copy_(self.SV)
            if bias is not None:
                self.fc_u.bias.copy_(bias)

    def forward(self, x):
        """Apply the factorized layer: first SV, then U (plus the bias, if any)."""
        # Call the modules rather than .forward() so registered hooks still run.
        x = self.fc_sv(x)
        x = self.fc_u(x)
        return x
def replace(model):
    """Swap `model.fc` in place for a TruncatedSVD module.

    The replacement is built from the current `fc` layer and from the
    'fc.weight' entry of the model's state dict; the layer's input and output
    feature counts are printed first.
    """
    original_fc = model.fc
    print("fc_layer({}, {})".format(original_fc.in_features,
                                    original_fc.out_features))
    model.fc = TruncatedSVD(original_fc, model.state_dict()['fc.weight'])
# Rebind `resnet50` to a deep copy and factorize the copy's `fc` layer, so the
# module object returned by create_model is not the one being modified.
# NOTE(review): `deepcopy` is not imported explicitly in this notebook;
# presumably it comes in through `from apputils import *` - confirm.
resnet50 = deepcopy(resnet50)
replace(resnet50)
```
%% Cell type:code id: tags:
``` python
# Standard loop to test the accuracy of a model.
import time
import torchnet.meter as tnt

# One definition of the batch size, shared by the loader and by the progress
# print below, so the reported sample count cannot drift from the real one.
EVAL_BATCH_SIZE = 512

t0 = time.time()
test_loader = imagenet_load_data("../../data.imagenet",
                                 batch_size=EVAL_BATCH_SIZE,
                                 num_workers=4,
                                 shuffle=False)
t1 = time.time()

# Top-1 / Top-5 accuracy meter.
classerr = tnt.ClassErrorMeter(accuracy=True, topk=(1, 5))
resnet50.eval()

# Feed the data to whichever device the model lives on.
eval_device = next(resnet50.parameters()).device

with torch.no_grad():
    for validation_step, (inputs, target) in enumerate(test_loader):
        inputs, target = inputs.to(eval_device), target.to(eval_device)
        outputs = resnet50(inputs)
        classerr.add(outputs.data, target)
        # Report progress (approximate number of samples seen) every 10 batches.
        if (validation_step+1) % 10 == 0:
            print((validation_step+1) * EVAL_BATCH_SIZE)

print(classerr.value(1), classerr.value(5))
t2 = time.time()
# Total wall-clock time, including the creation of the data loader.
print(t2-t0)
```
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment