Skip to content
Snippets Groups Projects
Commit a89dfe6b authored by Neta Zmora's avatar Neta Zmora
Browse files

Jupyter notebook: Fix PyTorch 0.4 compatability issue when rendering

Sometimes the gmin/gmax in group color-normalization ends up with a zero
dimensional tensor, which needs to be accessed using .item()
parent aa8862bd
No related branches found
No related tags found
No related merge requests found
%% Cell type:markdown id: tags:
# Alexnet insights: visualizing the pruning process
This notebook examines the results of pruning Alexnet using sensitivity pruning, through a few chosen visualizations created from checkpoints created during pruning. We also compare the results of an element-wise pruning session, with 2D (kernel) regularization.
For the notebook, we pruned Alexnet using sensitivity pruning and captured the checkpoints after the first epoch ends (epoch 0) and after the last epoch ends (epoch 89).
You can download these checkpoints from here:
* https://s3-us-west-1.amazonaws.com/nndistiller/sensitivity-pruning/alexnet.checkpoint.0.pth.tar
* https://s3-us-west-1.amazonaws.com/nndistiller/sensitivity-pruning/alexnet.checkpoint.89.pth.tar
* https://s3-us-west-1.amazonaws.com/nndistiller/hybrid/checkpoint.alexnet.schedule_sensitivity_2D-reg.pth.tar
## Table of Contents
1. [Load the training checkpoints](#Load-the-training-checkpoints)
2. [Let's see some statistics](#Let's-see-some-statistics)
3. [Compare weights distributions](#Compare-weights-distributions)
4. [Visualize the weights](#Visualize-the-weights)<br>
4.1. [Fully-connected layers](#Fully-connected-layers)<br>
4.2. [Convolutional layers](#Convolutional-layers)<br>
4.3. [Kernel pruning](#Kernel-pruning)<br>
4.4. [Let's isolate just the 2D kernels](#Let's-isolate-just-the-2D-kernels)
%% Cell type:code id: tags:
``` python
import matplotlib
# Load some common jupyter code
%run distiller_jupyter_helpers.ipynb
from models import create_model
from apputils import *
import qgrid
from ipywidgets import *
from bqplot import *
import bqplot.pyplot as bqplt
from functools import partial
```
%% Cell type:markdown id: tags:
## Load the training checkpoints
Load the checkpoint captured after one pruning event, and fine-tuning for one epoch:
%% Cell type:code id: tags:
``` python
epoch0_model = create_model(False, 'imagenet', 'alexnet', parallel=True)
checkpoint_file = "../examples/classifier_compression/alexnet.checkpoint.0.pth.tar"
try:
load_checkpoint(epoch0_model, checkpoint_file);
except NameError as e:
print("Did you forget to download the checkpoint file?")
raise e
```
%% Cell type:markdown id: tags:
Load the checkpoint captured at the end of the pruning and fine-tuning process:
%% Cell type:code id: tags:
``` python
epoch89_model = create_model(False, 'imagenet', 'alexnet', parallel=True)
checkpoint_file = "../examples/classifier_compression/alexnet.checkpoint.89.pth.tar"
try:
load_checkpoint(epoch89_model, checkpoint_file);
except Exception as e:
print("Did you forget to download the checkpoint file?")
raise e
sparse_model = epoch89_model
```
%% Cell type:markdown id: tags:
Load a pre-trained dense Alexnet:
%% Cell type:code id: tags:
``` python
pretrained_model = create_model(True, 'imagenet', 'alexnet', parallel=True)
dense_model = pretrained_model
```
%% Cell type:markdown id: tags:
We want to compare the output of element-wise pruning, to a similar schedule but which also adds 2D (kernel-wise) Lasso regularization, so we load the last checkpoint of that pruning session.
The schedule is available at: ```distiller/examples/hybrid/alexnet.schedule_sensitivity_2D-reg.yaml```.
%% Cell type:code id: tags:
``` python
reg2D_model = create_model(False, 'imagenet', 'alexnet', parallel=True)
checkpoint_file = "../examples/classifier_compression/checkpoint.alexnet.schedule_sensitivity_2D-reg.pth.tar"
try:
load_checkpoint(reg2D_model, checkpoint_file);
except Exception as e:
print("Did you forget to download the checkpoint file?")
raise e
```
%% Cell type:markdown id: tags:
Create a dictionary of the models, with name as key, so that we can refer to it later:
%% Cell type:code id: tags:
``` python
models_dict = {'Dense': dense_model, 'Sparse': sparse_model, 'Epoch 0': epoch0_model, '2D-Sparse': reg2D_model}
```
%% Cell type:markdown id: tags:
## Let's see some statistics
You can use the dropwdown widget to choose which model to display.
You can also choose to display the sparsity or the density of the tensors. These are reported for various granularities (structures) of sparsities.
%% Cell type:code id: tags:
``` python
def view_data(what, model_choice):
df_sparsity = distiller.weights_sparsity_summary(models_dict[model_choice])
if what == 'Density':
for granularity in ['Fine (%)', 'Ch (%)', '2D (%)', '3D (%)']:
df_sparsity[granularity] = 100 - df_sparsity[granularity]
display(df_sparsity)
model_dropdown = Dropdown(description='Model:', options=models_dict.keys())
display_radio = widgets.RadioButtons(options=['Sparsity', 'Density'], value='Sparsity', description='Display:')
interact(view_data, what=display_radio, model_choice=model_dropdown);
```
%% Cell type:markdown id: tags:
## Compare weights distributions
Compare the distributions of the weight tensors in the sparse and dense models
%% Cell type:code id: tags:
``` python
nbins = 100
def get_hist2(model, nbins, param_name, remove_zeros):
tensor = flatten(model.state_dict()[param_name])
if remove_zeros:
tensor = tensor[tensor != 0]
hist, edges = np.histogram(tensor, bins=nbins, density=False)
return hist, edges
def graph_setup(models, titles, param_name, remove_zeros):
#xs, ys = [LinearScale(), LinearScale()], [LinearScale(), LinearScale()]
xs = [LinearScale() for i in range(len(models))]
ys = [LinearScale() for i in range(len(models))]
xax = [Axis(scale=xs[0])] * len(models)
yax = [Axis(scale=ys[0], orientation='vertical', grid_lines='solid')] * len(models)
bars = []
funcs = []
for i in range(len(models)):
hist, edges = get_hist2(models[i], nbins, param_name, remove_zeros)
bars.append(Bars(x=edges, y=[hist], scales={'x': xs[i], 'y': ys[i]}, padding=0.2, type='grouped'))
funcs.append(Figure(marks=[bars[i]], axes=[xax[i], yax[i]], animation_duration=1000, title=titles[i]))
shape = distiller.size2str(next (iter (models[0].state_dict().values())).size())
param_info = widgets.Text(value=shape, description='shape:', disabled=True)
return bars, funcs, param_info
params_names = conv_fc_param_names(sparse_model)
weights_dropdown = Dropdown(description='weights', options=params_names)
def update_models(stam, bars, funcs, param_shape_desc, models):
param_name = weights_dropdown.value
for i in range(len(models)):
bars[i].y, bars[i].x = get_hist2(models[i], nbins, param_name, remove_zeros.value)
shape = distiller.size2str(models[0].state_dict()[param_name].size())
param_shape_desc.value = shape
```
%% Cell type:code id: tags:
``` python
titles = ['Dense', 'Epoch 0', 'Sparse', '2D-Sparse']
models = [models_dict[title] for title in titles]
bars, funcs, param_shape_desc = graph_setup(models, titles, param_name=weights_dropdown.value, remove_zeros=False)
update1 = partial(update_models, bars=bars, funcs=funcs, param_shape_desc = param_shape_desc, models=models)
weights_dropdown.observe(update1)
remove_zeros = widgets.Checkbox(value=False, description='Remove zeros')
remove_zeros.observe(update1)
def draw_graph(models):
if len(models) > 2:
return (VBox([
HBox([weights_dropdown, param_shape_desc, remove_zeros] ),
VBox([
HBox([funcs[i] for i in range(len(models)//2)]),
HBox([funcs[i+2] for i in range(len(models)//2)])
])
]))
else:
return (VBox([
HBox([weights_dropdown, param_shape_desc, remove_zeros] ),
HBox([funcs[i] for i in range(len(models))])
]))
draw_graph(models)
```
%% Cell type:markdown id: tags:
## Visualize the weights
### Fully-connected layers
%% Cell type:code id: tags:
``` python
def view_weights(pname, model_choice):
model = models_dict[model_choice]
weights = model.state_dict()[pname]
# # Color normalization - we want all parameters to share the same color ranges in the kernel plots,
# # so we need to find the min and max across all weight tensors in the model.
# # As a last step, we also center the colorbar so that 0 is white - this makes it easier to see the sparsity.
# extreme_vals = [list((p.max(), p.min())) for param_name, p in model.state_dict().items()
# if (p.dim()>1) and ("weight" in param_name)]
# flat = [item for sublist in extreme_vals for item in sublist]
# center = (max(flat) + min(flat)) / 2
# model_max = max(flat) - center
# model_min = min(flat) - center
params_names = fc_param_names(model)
aspect_ratio = weights.size(0) / weights.size(1)
size = 100
plot_params2d([weights], figsize=(int(size*aspect_ratio),size), binary_mask=True);
model_dropdown = Dropdown(description='Model:', options=models_dict.keys(), value='Sparse')
params_names = fc_param_names(sparse_model)
params_dropdown = widgets.Dropdown(description='Weights:', options=params_names)
interact(view_weights, pname=params_dropdown, model_choice=model_dropdown);
```
%% Cell type:code id: tags:
``` python
import matplotlib.gridspec as gridspec
def color_min_max(model, weights, color_normalization, nrow=-1, ncol=-1):
gmin, gmax = None, None
if color_normalization=='Model':
# Color normalization - we want all parameters to share the same color ranges in the kernel plots,
# so we need to find the min and max across all weight tensors in the model.
# As a last step, we also center the colorbar so that 0 is white - this makes it easier to see the sparsity.
extreme_vals = [list((p.max(), p.min())) for param_name, p in model.state_dict().items()
if (p.dim()>1) and ("weight" in param_name)]
flat = [item for sublist in extreme_vals for item in sublist]
center = (max(flat) + min(flat)) / 2
gmax = model_max = max(flat) - center
gmin = model_min = min(flat) - center
elif color_normalization=='Tensor':
# We want to normalize the grayscale brightness levels for all of the images we display (group),
# otherwise, each image is normalized separately and this causes distortion between the different
# filters images we ddisplay.
# We don't normalize across all of the filters images, because the outliers cause the image of each
# filter to be very muted. This is because each group of filters we display usually has low variance
# between the element values of that group.
gmin = weights.min()
gmax = weights.max()
elif color_normalization=='Group':
gmin = weights[0:nrow, 0:ncol].min()
gmax = weights[0:nrow, 0:ncol].max()
if isinstance(gmin, torch.Tensor):
gmin = gmin.item()
gmax = gmax.item()
return gmin, gmax
def plot_param_kernels(model, weights, layout, size_ctrl, binary_mask=False, color_normalization='Model',
interpolation=None, first_kernel=0):
ofms, ifms = weights.size()[0], weights.size()[1]
kw, kh = weights.size()[2], weights.size()[3]
print("min=%.4f\tmax=%.4f" % (weights.min(), weights.max()))
shape_str = distiller.size2str(weights.size())
volume = distiller.volume(weights)
print("size=%s = %d elements" % (shape_str, volume))
# Clone because we are going to change the tensor values
weights = weights.clone()
if binary_mask:
weights[weights!=0] = 1
kernels = weights.view(ofms * ifms, kh, kw)
nrow, ncol = layout[0], layout[1]
gmin, gmax = color_min_max(model, weights, color_normalization, nrow, ncol)
print("gmin=%.4f\tgmax=%.4f" % (gmin, gmax))
fig = plt.figure(figsize=(size_ctrl*8, size_ctrl*8))
# gridspec inside gridspec
outer_grid = gridspec.GridSpec(4, 4, wspace=0.05, hspace=0.05)
for i in range(4*4):
inner_grid = gridspec.GridSpecFromSubplotSpec(3, 3, subplot_spec=outer_grid[i], wspace=0.0, hspace=0.0)
for j in range(3*3):
ax = plt.Subplot(fig, inner_grid[j])
if binary_mask:
ax.matshow(kernels[first_kernel+i*4*3+j], cmap='binary', vmin=0, vmax=1);
else:
# Use siesmic so that colors around the center are lighter. Red and blue are used
# to represent (and visually separate) negative and positive weights
ax.matshow(kernels[first_kernel+i*4*3+j], cmap='seismic', vmin=gmin, vmax=gmax, interpolation=interpolation);
ax.set_xticks([])
ax.set_yticks([])
fig.add_subplot(ax);
all_axes = fig.get_axes()
#show only the outside spines
for ax in all_axes:
for sp in ax.spines.values():
sp.set_visible(False)
if ax.is_first_row():
ax.spines['top'].set_visible(True)
if ax.is_last_row():
ax.spines['bottom'].set_visible(True)
if ax.is_first_col():
ax.spines['left'].set_visible(True)
if ax.is_last_col():
ax.spines['right'].set_visible(True)
```
%% Cell type:markdown id: tags:
### Convolutional layers
%% Cell type:code id: tags:
``` python
# Some models have long node names and require longer lines
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))
import math
params_names = conv_param_names(sparse_model)
def view_weights(pname, model_choice, apply_mask, color_normalization, interpolation):
weights = models_dict[model_choice].state_dict()[pname]
num_kernels = weights.size(0) * weights.size(1)
first_kernel = 0
width = 15
size = int(min((num_kernels-first_kernel)//width, width))
layout=(size,width)
plot_param_kernels(model=models_dict[model_choice], weights=weights, layout=layout, size_ctrl=2,
binary_mask=apply_mask, color_normalization=color_normalization,
interpolation=interpolation, first_kernel=first_kernel);
interpolations = [None, 'none', 'nearest', 'bilinear', 'bicubic', 'spline16',
'spline36', 'hanning', 'hamming', 'hermite', 'kaiser', 'quadric',
'catrom', 'gaussian', 'bessel', 'mitchell', 'sinc', 'lanczos']
#model_radio = widgets.RadioButtons(options=['Sparse', 'Dense'], value='Sparse', description='Model:')
model_dropdown = Dropdown(description='Model:', options=models_dict.keys())
normalize_radio = widgets.RadioButtons(options=['Group', 'Tensor', 'Model'], value='Model', description='Normalize:')
params_dropdown = widgets.Dropdown(description='Weights:', options=params_names)
interpolation_dropdown = widgets.Dropdown(description='Interploation:', options=interpolations, value='bilinear')
mask_choice = widgets.Checkbox(value=False, description='Binary mask')
interact(view_weights, pname=params_dropdown,
model_choice=model_dropdown, apply_mask=mask_choice,
color_normalization=normalize_radio,
interpolation=interpolation_dropdown);
```
%% Cell type:markdown id: tags:
### Kernel pruning
Look how 2D (kernel) pruning removes kernels.
Each row is a flattened view of the kernels that generate one OFM (output feature map).
%% Cell type:code id: tags:
``` python
# The default font size is too small, so let's increase it
matplotlib.rcParams.update({'font.size': 32})
params_names = conv_param_names(sparse_model)
def view_weights(pname, unused, binary_mask, model_choice):
model = models_dict[model_choice]
weights = model.state_dict()[pname]
weights = weights.view(weights.size(0), -1)
gmin, gmax = color_min_max(model, weights, color_normalization="Model")
print("gmin=%.4f\tgmax=%.4f" % (gmin, gmax))
plot_params2d([weights], figsize=(50,50), binary_mask=binary_mask, xlabel="#channels * k * k", ylabel="OFM", gmin=gmin, gmax=gmax);
shape = distiller.size2str(model.state_dict()[pname].size())
param_info.value = shape
shape = distiller.size2str(next (iter (dense_model.state_dict().values())).size())
param_info = widgets.Text(value=shape, description='shape:', disabled=True)
mask_choice = widgets.Checkbox(value=True, description='Binary mask')
params_dropdown = widgets.Dropdown(description='Weights:', options=params_names, value='features.module.6.weight')
model_dropdown = Dropdown(description='Model:', options=models_dict.keys(), value='2D-Sparse')
interact(view_weights, pname=params_dropdown, unused=param_info, binary_mask=mask_choice, model_choice=model_dropdown);
```
%% Cell type:markdown id: tags:
### Let's isolate just the 2D kernels
Now let's try something slightly different: in the diagram below, we fold each kernel (k * k) into a single pixel. If the value of all of the elements in the kernel is zero, then the 2D kernel is colored white (100% sparse); otherwise, it is colored black (has at least one non-zero element in it).
%% Cell type:code id: tags:
``` python
matplotlib.rcParams.update({'font.size': 22})
params_names = conv_param_names(sparse_model)
def view_weights(pname, unused, model_choice):
model = models_dict[model_choice]
weights = model.state_dict()[pname]
k_view = weights.view(weights.size(0) * weights.size(1), -1).abs().sum(dim=1)
weights = k_view.view(weights.size(0), weights.size(1))
#gmin, gmax = color_min_max(model, weights, color_normalization="Model")
plot_params2d([weights], figsize=(10,10), binary_mask=True, xlabel="#channels", ylabel="OFM");
shape = distiller.size2str(sparse_model.state_dict()[pname].size())
param_info.value = shape
shape = distiller.size2str(next (iter (sparse_model.state_dict().values())).size())
param_info = widgets.Text(value=shape, description='shape:', disabled=True)
params_dropdown = widgets.Dropdown(description='Weights:', options=params_names, value='features.module.6.weight')
model_dropdown = Dropdown(description='Model:', options=models_dict.keys(), value='2D-Sparse')
interact(view_weights, pname=params_dropdown, unused=param_info, model_choice=model_dropdown);
```
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment