Skip to content
Snippets Groups Projects
Unverified Commit a32d2c96 authored by Neta Zmora's avatar Neta Zmora Committed by GitHub
Browse files

sensitivity_analysis.ipynb: fix example checkpoint path

fix path to the resnet20 checkpoint in one of the examples
parent ef2832b7
No related branches found
No related tags found
No related merge requests found
%% Cell type:markdown id: tags:
# Sensitivity Analysis
Some pruning algorthims tune their hyperparameters based on the results of pruning sensitivity analysis. Distiller support L1-norm element-wise pruning sensitivity analysis, and filter-wise pruning sensitivity analysis based on the mean L1-norm ranking of filters.
## Table of Contents
1. [Load a pruning sensitivity analysis file](#Load-a-pruning-sensitivity-analysis-file)
2. [Examine parameters sensitivities](#Examine-parameters-sensitivities)<br>
2.1. [Plot layer sensitivities at a selected sparsity level](#Plot-layer-sensitivities-at-a-selected-sparsity-level)<br>
2.2. [Compare layer sensitivities](#Compare-layer-sensitivities)
3. [Filter pruning sensitivity analysis](#Filter-pruning-sensitivity-analysis)
%% Cell type:markdown id: tags:
## Load a pruning sensitivity analysis file
You prepare a sensitivity analysis file by invoking ```distiller.perform_sensitivity_analysis()```. Checkout the documentation of ```distiller.perform_sensitivity_analysis()``` for more information.<br>
Alternatively, you can use the sample ```compress_classifier.py``` application to perform sensitivity analysis on one of the supported models. In the example below, we invoke sensitivity analysis on a pretrained Resnet18 from torchvision, using the ImageNet test dataset for evaluation.
```
$ python3 compress_classifier.py -a resnet18 ../../../data.imagenet -j 12 --pretrained --sense=element
```
The outputs of performing pruning sensitivity analysis on several different networks is available at ```../examples/sensitivity-analysis```
%% Cell type:code id: tags:
``` python
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import ipywidgets as widgets
from ipywidgets import interactive, interact, Layout
df = df = pd.read_csv('../examples/sensitivity-analysis/mobilenet-imagenet/sensitivity.csv')
#df = pd.read_csv('../examples/sensitivity-analysis/resnet18-imagenet/sensitivity.csv')
#df = pd.read_csv('../examples/sensitivity-analysis/resnet56-cifar/sensitivity_filter_wise.csv')
#df = pd.read_csv('../examples/sensitivity-analysis/resnet20-cifar/sensitivity_filter_wise.csv')
df['sparsity'] = round(df['sparsity'], 2)
```
%% Cell type:markdown id: tags:
The code below converts the sensitivities dataframe to a sensitivities dictionary. <br>
Using this dictionary makes it easier for us when we want to plot sensitivities.
%% Cell type:code id: tags:
``` python
from collections import OrderedDict
def get_param_names(df):
return list(set(df['parameter']))
def get_sensitivity_levels(df):
return list(set(df['sparsity']))
def df2sensitivities(df):
param_names = get_param_names(df)
sparsities = get_sensitivity_levels(df)
sensitivities = {}
for param_name in param_names:
sensitivities[param_name] = OrderedDict()
param_stats = df[(df.parameter == param_name)]
for row in range(len(param_stats.index)):
s = param_stats.iloc[[row]].sparsity
top1 = param_stats.iloc[[row]].top1
top5 = param_stats.iloc[[row]].top5
sensitivities[param_name][float(s)] = (float(top1), float(top5))
return sensitivities
```
%% Cell type:markdown id: tags:
## Examine parameters sensitivities
After loading the sensitivity analysis CSV file into a Pandas dataframe, we can examine it.
### Plot layer sensitivities at a selected sparsity level
Use the dropdown to choose the sparsity level, and select whether you choose to view the top1 accuracies or top5.<br>
Under the plot we display the numerical values of the accuracies, in case you want to have a closer look at the details.
%% Cell type:code id: tags:
``` python
def view2(level, acc):
filtered = df[df.sparsity == level]
s = filtered.style.apply(highlight_min_max)
param_names = filtered['parameter']
# Plot the sensitivities
x = range(filtered[acc].shape[0])
y = filtered[acc].values.tolist()
fig = plt.figure(figsize=(20,10))
plt.plot(x, y, label=param_names, marker="o", markersize=10, markerfacecolor="C1")
plt.ylabel(str(acc))
plt.xlabel('parameter')
plt.xticks(rotation='vertical')
plt.xticks(x, param_names)
plt.title('Pruning Sensitivity per layer %d' % level)
#return s
def highlight_min_max(s):
"""Highlight the max and min values in the series"""
if s.name not in ['top1', 'top5']:
return ['' for v in s]
is_max = s == s.max()
maxes = ['background-color: green' if v else '' for v in is_max]
is_min = s == s.min()
mins = ['background-color: red' if v else '' for v in is_min]
return [h1 if len(h1)>len(h2) else h2 for (h1,h2) in zip(maxes, mins)]
```
%% Cell type:code id: tags:
``` python
sparsities = np.sort(get_sensitivity_levels(df))
acc_radio = widgets.RadioButtons(options=['top1', 'top5'], value='top1', description='Accuracy:')
levels_dropdown = widgets.Dropdown(description='Sparsity:', options=sparsities)
interact(view2, level=levels_dropdown, acc=acc_radio);
```
%% Cell type:markdown id: tags:
Sometimes we want to look at the sensitivies of a specific weights tensor:
%% Cell type:code id: tags:
``` python
def view_sparsity(param_name):
display(df[df['parameter']==param_name])
param_names = sorted(df['parameter'].unique().tolist())
param_dropdown = widgets.Dropdown(description='Parameter:', options=param_names)
interact(view_sparsity, param_name=param_dropdown);
```
%% Cell type:markdown id: tags:
### Compare layer sensitivities
Plot the pruning sensitivities of selected layers.
<br>Select multiple parameters using SHIFT and CTRL.
%% Cell type:code id: tags:
``` python
# assign a different color to each parameter (otherwise, colors change on us as we make different selections)
param_names = sorted(df['parameter'].unique().tolist())
color_idx = np.linspace(0, 1, len(param_names))
colors = {}
for i, pname in zip(color_idx, param_names):
colors[pname] = color= plt.get_cmap('tab20')(i)
plt.rcParams.update({'font.size': 18})
def view(weights='', acc=0):
sensitivities= None
if weights[0]=='All':
sensitivities = df2sensitivities(df)
else:
mask = False
mask = [(df.parameter == pname) for pname in weights]
mask = np.logical_or.reduce(mask)
sensitivities = df2sensitivities(df[mask])
# Plot the sensitivities
fig = plt.figure(figsize=(20,10))
for param_name, sensitivity in sensitivities.items():
sense = [values[acc] for sparsity, values in sensitivity.items()]
sparsities = [sparsity for sparsity, values in sensitivity.items()]
plt.plot(sparsities, sense, label=param_name, marker="o", markersize=10, color=colors[param_name])
plt.ylabel('top1')
plt.xlabel('sparsity')
plt.title('Pruning Sensitivity')
#plt.legend(loc='lower center', ncol=2, mode="expand", borderaxespad=0.);
plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.15), fancybox=True, shadow=True, ncol=3)
items = ['All']+param_names
w = widgets.SelectMultiple(options=items, value=[items[1]], layout=Layout(width='50%'), description='Weights:')
acc_widget = widgets.RadioButtons(options={'top1': 0, 'top5': 1}, value=0, description='Accuracy:')
interactive(view, acc=acc_widget, weights=w)
```
%% Cell type:markdown id: tags:
## Filter pruning sensitivity analysis
Just as we perform element-wise pruning sensitivity analysis, we can also analyze a model's filter-wise pruning sensitivity. Although the sparsity levels are reported in percentage steps, the actual pruning level might be somewhat lower, because when we prune filters the minimum granularity of pruning is ```1/numer_of_filters```.
We performed a filter-wise pruning sensitivity analysis on ResNet20-Cifar using the following command:
```
python3 compress_classifier.py -a resnet20_cifar ../../../data.cifar10/ -j 12 --resume=../cifar10/resnet20/checkpoint_trained_dense.pth.tar --sense=filter
python3 compress_classifier.py -a resnet20_cifar ../../../data.cifar10/ -j 12 --resume=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --sense=filter
```
%% Cell type:code id: tags:
``` python
df_filter = pd.read_csv('../examples/sensitivity-analysis/resnet20-cifar/sensitivity_filter_wise.csv')
df_element = pd.read_csv('../examples/sensitivity-analysis/resnet20-cifar/sensitivity.csv')
```
%% Cell type:code id: tags:
``` python
def view_sparsity(param_name):
display(df_filter[df_filter['parameter']==param_name])
param_names = sorted(df_filter['parameter'].unique().tolist())
param_dropdown = widgets.Dropdown(description='Parameter:', options=param_names)
interact(view_sparsity, param_name=param_dropdown);
```
%% Cell type:markdown id: tags:
Now let's look at the sparsity vs. the compute:
%% Cell type:code id: tags:
``` python
def view_fliters(level, acc):
filtered = df_filter[df_filter.sparsity == level]
s = filtered.style.apply(highlight_min_max)
param_names = filtered['parameter']
# Plot the sensitivities
x = range(filtered[acc].shape[0])
y = filtered[acc].values.tolist()
fig = plt.figure(figsize=(20,10))
plt.plot(x, y, label=param_names, marker="o", markersize=10, markerfacecolor="C1")
plt.ylabel(str(acc))
plt.xlabel('parameter')
plt.xticks(rotation='vertical')
plt.xticks(x, param_names)
plt.title('Filter pruning sensitivity per layer %d' % level)
return s
df_filter['sparsity'] = round(df_filter['sparsity'], 2)
sparsities = np.sort(get_sensitivity_levels(df_filter))
acc_radio = widgets.RadioButtons(options=['top1', 'top5'], value='top1', description='Accuracy:')
levels_dropdown = widgets.Dropdown(description='Sparsity:', options=sparsities)
interact(view_fliters, level=levels_dropdown, acc=acc_radio);
```
%% Cell type:code id: tags:
``` python
```
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment