Skip to content
Snippets Groups Projects
Commit a89dfe6b authored by Neta Zmora's avatar Neta Zmora
Browse files

Jupyter notebook: Fix PyTorch 0.4 compatability issue when rendering

Sometimes the gmin/gmax in group color-normalization ends up with a zero
dimensional tensor, which needs to be accessed using .item()
parent aa8862bd
No related branches found
No related tags found
No related merge requests found
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
# Alexnet insights: visualizing the pruning process # Alexnet insights: visualizing the pruning process
This notebook examines the results of pruning Alexnet using sensitivity pruning, through a few chosen visualizations created from checkpoints created during pruning. We also compare the results of an element-wise pruning session, with 2D (kernel) regularization. This notebook examines the results of pruning Alexnet using sensitivity pruning, through a few chosen visualizations created from checkpoints created during pruning. We also compare the results of an element-wise pruning session, with 2D (kernel) regularization.
For the notebook, we pruned Alexnet using sensitivity pruning and captured the checkpoints after the first epoch ends (epoch 0) and after the last epoch ends (epoch 89). For the notebook, we pruned Alexnet using sensitivity pruning and captured the checkpoints after the first epoch ends (epoch 0) and after the last epoch ends (epoch 89).
You can download these checkpoints from here: You can download these checkpoints from here:
* https://s3-us-west-1.amazonaws.com/nndistiller/sensitivity-pruning/alexnet.checkpoint.0.pth.tar * https://s3-us-west-1.amazonaws.com/nndistiller/sensitivity-pruning/alexnet.checkpoint.0.pth.tar
* https://s3-us-west-1.amazonaws.com/nndistiller/sensitivity-pruning/alexnet.checkpoint.89.pth.tar * https://s3-us-west-1.amazonaws.com/nndistiller/sensitivity-pruning/alexnet.checkpoint.89.pth.tar
* https://s3-us-west-1.amazonaws.com/nndistiller/hybrid/checkpoint.alexnet.schedule_sensitivity_2D-reg.pth.tar * https://s3-us-west-1.amazonaws.com/nndistiller/hybrid/checkpoint.alexnet.schedule_sensitivity_2D-reg.pth.tar
## Table of Contents ## Table of Contents
1. [Load the training checkpoints](#Load-the-training-checkpoints) 1. [Load the training checkpoints](#Load-the-training-checkpoints)
2. [Let's see some statistics](#Let's-see-some-statistics) 2. [Let's see some statistics](#Let's-see-some-statistics)
3. [Compare weights distributions](#Compare-weights-distributions) 3. [Compare weights distributions](#Compare-weights-distributions)
4. [Visualize the weights](#Visualize-the-weights)<br> 4. [Visualize the weights](#Visualize-the-weights)<br>
4.1. [Fully-connected layers](#Fully-connected-layers)<br> 4.1. [Fully-connected layers](#Fully-connected-layers)<br>
4.2. [Convolutional layers](#Convolutional-layers)<br> 4.2. [Convolutional layers](#Convolutional-layers)<br>
4.3. [Kernel pruning](#Kernel-pruning)<br> 4.3. [Kernel pruning](#Kernel-pruning)<br>
4.4. [Let's isolate just the 2D kernels](#Let's-isolate-just-the-2D-kernels) 4.4. [Let's isolate just the 2D kernels](#Let's-isolate-just-the-2D-kernels)
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
import matplotlib import matplotlib
# Load some common jupyter code # Load some common jupyter code
%run distiller_jupyter_helpers.ipynb %run distiller_jupyter_helpers.ipynb
from models import create_model from models import create_model
from apputils import * from apputils import *
import qgrid import qgrid
from ipywidgets import * from ipywidgets import *
from bqplot import * from bqplot import *
import bqplot.pyplot as bqplt import bqplot.pyplot as bqplt
from functools import partial from functools import partial
``` ```
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
## Load the training checkpoints ## Load the training checkpoints
Load the checkpoint captured after one pruning event, and fine-tuning for one epoch: Load the checkpoint captured after one pruning event, and fine-tuning for one epoch:
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
epoch0_model = create_model(False, 'imagenet', 'alexnet', parallel=True) epoch0_model = create_model(False, 'imagenet', 'alexnet', parallel=True)
checkpoint_file = "../examples/classifier_compression/alexnet.checkpoint.0.pth.tar" checkpoint_file = "../examples/classifier_compression/alexnet.checkpoint.0.pth.tar"
try: try:
load_checkpoint(epoch0_model, checkpoint_file); load_checkpoint(epoch0_model, checkpoint_file);
except NameError as e: except NameError as e:
print("Did you forget to download the checkpoint file?") print("Did you forget to download the checkpoint file?")
raise e raise e
``` ```
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
Load the checkpoint captured at the end of the pruning and fine-tuning process: Load the checkpoint captured at the end of the pruning and fine-tuning process:
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
epoch89_model = create_model(False, 'imagenet', 'alexnet', parallel=True) epoch89_model = create_model(False, 'imagenet', 'alexnet', parallel=True)
checkpoint_file = "../examples/classifier_compression/alexnet.checkpoint.89.pth.tar" checkpoint_file = "../examples/classifier_compression/alexnet.checkpoint.89.pth.tar"
try: try:
load_checkpoint(epoch89_model, checkpoint_file); load_checkpoint(epoch89_model, checkpoint_file);
except Exception as e: except Exception as e:
print("Did you forget to download the checkpoint file?") print("Did you forget to download the checkpoint file?")
raise e raise e
sparse_model = epoch89_model sparse_model = epoch89_model
``` ```
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
Load a pre-trained dense Alexnet: Load a pre-trained dense Alexnet:
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
pretrained_model = create_model(True, 'imagenet', 'alexnet', parallel=True) pretrained_model = create_model(True, 'imagenet', 'alexnet', parallel=True)
dense_model = pretrained_model dense_model = pretrained_model
``` ```
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
We want to compare the output of element-wise pruning, to a similar schedule but which also adds 2D (kernel-wise) Lasso regularization, so we load the last checkpoint of that pruning session. We want to compare the output of element-wise pruning, to a similar schedule but which also adds 2D (kernel-wise) Lasso regularization, so we load the last checkpoint of that pruning session.
The schedule is available at: ```distiller/examples/hybrid/alexnet.schedule_sensitivity_2D-reg.yaml```. The schedule is available at: ```distiller/examples/hybrid/alexnet.schedule_sensitivity_2D-reg.yaml```.
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
reg2D_model = create_model(False, 'imagenet', 'alexnet', parallel=True) reg2D_model = create_model(False, 'imagenet', 'alexnet', parallel=True)
checkpoint_file = "../examples/classifier_compression/checkpoint.alexnet.schedule_sensitivity_2D-reg.pth.tar" checkpoint_file = "../examples/classifier_compression/checkpoint.alexnet.schedule_sensitivity_2D-reg.pth.tar"
try: try:
load_checkpoint(reg2D_model, checkpoint_file); load_checkpoint(reg2D_model, checkpoint_file);
except Exception as e: except Exception as e:
print("Did you forget to download the checkpoint file?") print("Did you forget to download the checkpoint file?")
raise e raise e
``` ```
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
Create a dictionary of the models, with name as key, so that we can refer to it later: Create a dictionary of the models, with name as key, so that we can refer to it later:
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
models_dict = {'Dense': dense_model, 'Sparse': sparse_model, 'Epoch 0': epoch0_model, '2D-Sparse': reg2D_model} models_dict = {'Dense': dense_model, 'Sparse': sparse_model, 'Epoch 0': epoch0_model, '2D-Sparse': reg2D_model}
``` ```
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
## Let's see some statistics ## Let's see some statistics
You can use the dropwdown widget to choose which model to display. You can use the dropwdown widget to choose which model to display.
You can also choose to display the sparsity or the density of the tensors. These are reported for various granularities (structures) of sparsities. You can also choose to display the sparsity or the density of the tensors. These are reported for various granularities (structures) of sparsities.
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
def view_data(what, model_choice): def view_data(what, model_choice):
df_sparsity = distiller.weights_sparsity_summary(models_dict[model_choice]) df_sparsity = distiller.weights_sparsity_summary(models_dict[model_choice])
if what == 'Density': if what == 'Density':
for granularity in ['Fine (%)', 'Ch (%)', '2D (%)', '3D (%)']: for granularity in ['Fine (%)', 'Ch (%)', '2D (%)', '3D (%)']:
df_sparsity[granularity] = 100 - df_sparsity[granularity] df_sparsity[granularity] = 100 - df_sparsity[granularity]
display(df_sparsity) display(df_sparsity)
model_dropdown = Dropdown(description='Model:', options=models_dict.keys()) model_dropdown = Dropdown(description='Model:', options=models_dict.keys())
display_radio = widgets.RadioButtons(options=['Sparsity', 'Density'], value='Sparsity', description='Display:') display_radio = widgets.RadioButtons(options=['Sparsity', 'Density'], value='Sparsity', description='Display:')
interact(view_data, what=display_radio, model_choice=model_dropdown); interact(view_data, what=display_radio, model_choice=model_dropdown);
``` ```
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
## Compare weights distributions ## Compare weights distributions
Compare the distributions of the weight tensors in the sparse and dense models Compare the distributions of the weight tensors in the sparse and dense models
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
nbins = 100 nbins = 100
def get_hist2(model, nbins, param_name, remove_zeros): def get_hist2(model, nbins, param_name, remove_zeros):
tensor = flatten(model.state_dict()[param_name]) tensor = flatten(model.state_dict()[param_name])
if remove_zeros: if remove_zeros:
tensor = tensor[tensor != 0] tensor = tensor[tensor != 0]
hist, edges = np.histogram(tensor, bins=nbins, density=False) hist, edges = np.histogram(tensor, bins=nbins, density=False)
return hist, edges return hist, edges
def graph_setup(models, titles, param_name, remove_zeros): def graph_setup(models, titles, param_name, remove_zeros):
#xs, ys = [LinearScale(), LinearScale()], [LinearScale(), LinearScale()] #xs, ys = [LinearScale(), LinearScale()], [LinearScale(), LinearScale()]
xs = [LinearScale() for i in range(len(models))] xs = [LinearScale() for i in range(len(models))]
ys = [LinearScale() for i in range(len(models))] ys = [LinearScale() for i in range(len(models))]
xax = [Axis(scale=xs[0])] * len(models) xax = [Axis(scale=xs[0])] * len(models)
yax = [Axis(scale=ys[0], orientation='vertical', grid_lines='solid')] * len(models) yax = [Axis(scale=ys[0], orientation='vertical', grid_lines='solid')] * len(models)
bars = [] bars = []
funcs = [] funcs = []
for i in range(len(models)): for i in range(len(models)):
hist, edges = get_hist2(models[i], nbins, param_name, remove_zeros) hist, edges = get_hist2(models[i], nbins, param_name, remove_zeros)
bars.append(Bars(x=edges, y=[hist], scales={'x': xs[i], 'y': ys[i]}, padding=0.2, type='grouped')) bars.append(Bars(x=edges, y=[hist], scales={'x': xs[i], 'y': ys[i]}, padding=0.2, type='grouped'))
funcs.append(Figure(marks=[bars[i]], axes=[xax[i], yax[i]], animation_duration=1000, title=titles[i])) funcs.append(Figure(marks=[bars[i]], axes=[xax[i], yax[i]], animation_duration=1000, title=titles[i]))
shape = distiller.size2str(next (iter (models[0].state_dict().values())).size()) shape = distiller.size2str(next (iter (models[0].state_dict().values())).size())
param_info = widgets.Text(value=shape, description='shape:', disabled=True) param_info = widgets.Text(value=shape, description='shape:', disabled=True)
return bars, funcs, param_info return bars, funcs, param_info
params_names = conv_fc_param_names(sparse_model) params_names = conv_fc_param_names(sparse_model)
weights_dropdown = Dropdown(description='weights', options=params_names) weights_dropdown = Dropdown(description='weights', options=params_names)
def update_models(stam, bars, funcs, param_shape_desc, models): def update_models(stam, bars, funcs, param_shape_desc, models):
param_name = weights_dropdown.value param_name = weights_dropdown.value
for i in range(len(models)): for i in range(len(models)):
bars[i].y, bars[i].x = get_hist2(models[i], nbins, param_name, remove_zeros.value) bars[i].y, bars[i].x = get_hist2(models[i], nbins, param_name, remove_zeros.value)
shape = distiller.size2str(models[0].state_dict()[param_name].size()) shape = distiller.size2str(models[0].state_dict()[param_name].size())
param_shape_desc.value = shape param_shape_desc.value = shape
``` ```
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
titles = ['Dense', 'Epoch 0', 'Sparse', '2D-Sparse'] titles = ['Dense', 'Epoch 0', 'Sparse', '2D-Sparse']
models = [models_dict[title] for title in titles] models = [models_dict[title] for title in titles]
bars, funcs, param_shape_desc = graph_setup(models, titles, param_name=weights_dropdown.value, remove_zeros=False) bars, funcs, param_shape_desc = graph_setup(models, titles, param_name=weights_dropdown.value, remove_zeros=False)
update1 = partial(update_models, bars=bars, funcs=funcs, param_shape_desc = param_shape_desc, models=models) update1 = partial(update_models, bars=bars, funcs=funcs, param_shape_desc = param_shape_desc, models=models)
weights_dropdown.observe(update1) weights_dropdown.observe(update1)
remove_zeros = widgets.Checkbox(value=False, description='Remove zeros') remove_zeros = widgets.Checkbox(value=False, description='Remove zeros')
remove_zeros.observe(update1) remove_zeros.observe(update1)
def draw_graph(models): def draw_graph(models):
if len(models) > 2: if len(models) > 2:
return (VBox([ return (VBox([
HBox([weights_dropdown, param_shape_desc, remove_zeros] ), HBox([weights_dropdown, param_shape_desc, remove_zeros] ),
VBox([ VBox([
HBox([funcs[i] for i in range(len(models)//2)]), HBox([funcs[i] for i in range(len(models)//2)]),
HBox([funcs[i+2] for i in range(len(models)//2)]) HBox([funcs[i+2] for i in range(len(models)//2)])
]) ])
])) ]))
else: else:
return (VBox([ return (VBox([
HBox([weights_dropdown, param_shape_desc, remove_zeros] ), HBox([weights_dropdown, param_shape_desc, remove_zeros] ),
HBox([funcs[i] for i in range(len(models))]) HBox([funcs[i] for i in range(len(models))])
])) ]))
draw_graph(models) draw_graph(models)
``` ```
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
## Visualize the weights ## Visualize the weights
### Fully-connected layers ### Fully-connected layers
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
def view_weights(pname, model_choice): def view_weights(pname, model_choice):
model = models_dict[model_choice] model = models_dict[model_choice]
weights = model.state_dict()[pname] weights = model.state_dict()[pname]
# # Color normalization - we want all parameters to share the same color ranges in the kernel plots, # # Color normalization - we want all parameters to share the same color ranges in the kernel plots,
# # so we need to find the min and max across all weight tensors in the model. # # so we need to find the min and max across all weight tensors in the model.
# # As a last step, we also center the colorbar so that 0 is white - this makes it easier to see the sparsity. # # As a last step, we also center the colorbar so that 0 is white - this makes it easier to see the sparsity.
# extreme_vals = [list((p.max(), p.min())) for param_name, p in model.state_dict().items() # extreme_vals = [list((p.max(), p.min())) for param_name, p in model.state_dict().items()
# if (p.dim()>1) and ("weight" in param_name)] # if (p.dim()>1) and ("weight" in param_name)]
# flat = [item for sublist in extreme_vals for item in sublist] # flat = [item for sublist in extreme_vals for item in sublist]
# center = (max(flat) + min(flat)) / 2 # center = (max(flat) + min(flat)) / 2
# model_max = max(flat) - center # model_max = max(flat) - center
# model_min = min(flat) - center # model_min = min(flat) - center
params_names = fc_param_names(model) params_names = fc_param_names(model)
aspect_ratio = weights.size(0) / weights.size(1) aspect_ratio = weights.size(0) / weights.size(1)
size = 100 size = 100
plot_params2d([weights], figsize=(int(size*aspect_ratio),size), binary_mask=True); plot_params2d([weights], figsize=(int(size*aspect_ratio),size), binary_mask=True);
model_dropdown = Dropdown(description='Model:', options=models_dict.keys(), value='Sparse') model_dropdown = Dropdown(description='Model:', options=models_dict.keys(), value='Sparse')
params_names = fc_param_names(sparse_model) params_names = fc_param_names(sparse_model)
params_dropdown = widgets.Dropdown(description='Weights:', options=params_names) params_dropdown = widgets.Dropdown(description='Weights:', options=params_names)
interact(view_weights, pname=params_dropdown, model_choice=model_dropdown); interact(view_weights, pname=params_dropdown, model_choice=model_dropdown);
``` ```
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
import matplotlib.gridspec as gridspec import matplotlib.gridspec as gridspec
def color_min_max(model, weights, color_normalization, nrow=-1, ncol=-1): def color_min_max(model, weights, color_normalization, nrow=-1, ncol=-1):
gmin, gmax = None, None gmin, gmax = None, None
if color_normalization=='Model': if color_normalization=='Model':
# Color normalization - we want all parameters to share the same color ranges in the kernel plots, # Color normalization - we want all parameters to share the same color ranges in the kernel plots,
# so we need to find the min and max across all weight tensors in the model. # so we need to find the min and max across all weight tensors in the model.
# As a last step, we also center the colorbar so that 0 is white - this makes it easier to see the sparsity. # As a last step, we also center the colorbar so that 0 is white - this makes it easier to see the sparsity.
extreme_vals = [list((p.max(), p.min())) for param_name, p in model.state_dict().items() extreme_vals = [list((p.max(), p.min())) for param_name, p in model.state_dict().items()
if (p.dim()>1) and ("weight" in param_name)] if (p.dim()>1) and ("weight" in param_name)]
flat = [item for sublist in extreme_vals for item in sublist] flat = [item for sublist in extreme_vals for item in sublist]
center = (max(flat) + min(flat)) / 2 center = (max(flat) + min(flat)) / 2
gmax = model_max = max(flat) - center gmax = model_max = max(flat) - center
gmin = model_min = min(flat) - center gmin = model_min = min(flat) - center
elif color_normalization=='Tensor': elif color_normalization=='Tensor':
# We want to normalize the grayscale brightness levels for all of the images we display (group), # We want to normalize the grayscale brightness levels for all of the images we display (group),
# otherwise, each image is normalized separately and this causes distortion between the different # otherwise, each image is normalized separately and this causes distortion between the different
# filters images we ddisplay. # filters images we ddisplay.
# We don't normalize across all of the filters images, because the outliers cause the image of each # We don't normalize across all of the filters images, because the outliers cause the image of each
# filter to be very muted. This is because each group of filters we display usually has low variance # filter to be very muted. This is because each group of filters we display usually has low variance
# between the element values of that group. # between the element values of that group.
gmin = weights.min() gmin = weights.min()
gmax = weights.max() gmax = weights.max()
elif color_normalization=='Group': elif color_normalization=='Group':
gmin = weights[0:nrow, 0:ncol].min() gmin = weights[0:nrow, 0:ncol].min()
gmax = weights[0:nrow, 0:ncol].max() gmax = weights[0:nrow, 0:ncol].max()
if isinstance(gmin, torch.Tensor):
gmin = gmin.item()
gmax = gmax.item()
return gmin, gmax return gmin, gmax
def plot_param_kernels(model, weights, layout, size_ctrl, binary_mask=False, color_normalization='Model', def plot_param_kernels(model, weights, layout, size_ctrl, binary_mask=False, color_normalization='Model',
interpolation=None, first_kernel=0): interpolation=None, first_kernel=0):
ofms, ifms = weights.size()[0], weights.size()[1] ofms, ifms = weights.size()[0], weights.size()[1]
kw, kh = weights.size()[2], weights.size()[3] kw, kh = weights.size()[2], weights.size()[3]
print("min=%.4f\tmax=%.4f" % (weights.min(), weights.max())) print("min=%.4f\tmax=%.4f" % (weights.min(), weights.max()))
shape_str = distiller.size2str(weights.size()) shape_str = distiller.size2str(weights.size())
volume = distiller.volume(weights) volume = distiller.volume(weights)
print("size=%s = %d elements" % (shape_str, volume)) print("size=%s = %d elements" % (shape_str, volume))
# Clone because we are going to change the tensor values # Clone because we are going to change the tensor values
weights = weights.clone() weights = weights.clone()
if binary_mask: if binary_mask:
weights[weights!=0] = 1 weights[weights!=0] = 1
kernels = weights.view(ofms * ifms, kh, kw) kernels = weights.view(ofms * ifms, kh, kw)
nrow, ncol = layout[0], layout[1] nrow, ncol = layout[0], layout[1]
gmin, gmax = color_min_max(model, weights, color_normalization, nrow, ncol) gmin, gmax = color_min_max(model, weights, color_normalization, nrow, ncol)
print("gmin=%.4f\tgmax=%.4f" % (gmin, gmax)) print("gmin=%.4f\tgmax=%.4f" % (gmin, gmax))
fig = plt.figure(figsize=(size_ctrl*8, size_ctrl*8)) fig = plt.figure(figsize=(size_ctrl*8, size_ctrl*8))
# gridspec inside gridspec # gridspec inside gridspec
outer_grid = gridspec.GridSpec(4, 4, wspace=0.05, hspace=0.05) outer_grid = gridspec.GridSpec(4, 4, wspace=0.05, hspace=0.05)
for i in range(4*4): for i in range(4*4):
inner_grid = gridspec.GridSpecFromSubplotSpec(3, 3, subplot_spec=outer_grid[i], wspace=0.0, hspace=0.0) inner_grid = gridspec.GridSpecFromSubplotSpec(3, 3, subplot_spec=outer_grid[i], wspace=0.0, hspace=0.0)
for j in range(3*3): for j in range(3*3):
ax = plt.Subplot(fig, inner_grid[j]) ax = plt.Subplot(fig, inner_grid[j])
if binary_mask: if binary_mask:
ax.matshow(kernels[first_kernel+i*4*3+j], cmap='binary', vmin=0, vmax=1); ax.matshow(kernels[first_kernel+i*4*3+j], cmap='binary', vmin=0, vmax=1);
else: else:
# Use siesmic so that colors around the center are lighter. Red and blue are used # Use siesmic so that colors around the center are lighter. Red and blue are used
# to represent (and visually separate) negative and positive weights # to represent (and visually separate) negative and positive weights
ax.matshow(kernels[first_kernel+i*4*3+j], cmap='seismic', vmin=gmin, vmax=gmax, interpolation=interpolation); ax.matshow(kernels[first_kernel+i*4*3+j], cmap='seismic', vmin=gmin, vmax=gmax, interpolation=interpolation);
ax.set_xticks([]) ax.set_xticks([])
ax.set_yticks([]) ax.set_yticks([])
fig.add_subplot(ax); fig.add_subplot(ax);
all_axes = fig.get_axes() all_axes = fig.get_axes()
#show only the outside spines #show only the outside spines
for ax in all_axes: for ax in all_axes:
for sp in ax.spines.values(): for sp in ax.spines.values():
sp.set_visible(False) sp.set_visible(False)
if ax.is_first_row(): if ax.is_first_row():
ax.spines['top'].set_visible(True) ax.spines['top'].set_visible(True)
if ax.is_last_row(): if ax.is_last_row():
ax.spines['bottom'].set_visible(True) ax.spines['bottom'].set_visible(True)
if ax.is_first_col(): if ax.is_first_col():
ax.spines['left'].set_visible(True) ax.spines['left'].set_visible(True)
if ax.is_last_col(): if ax.is_last_col():
ax.spines['right'].set_visible(True) ax.spines['right'].set_visible(True)
``` ```
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
### Convolutional layers ### Convolutional layers
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
# Some models have long node names and require longer lines # Some models have long node names and require longer lines
from IPython.core.display import display, HTML from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>")) display(HTML("<style>.container { width:100% !important; }</style>"))
import math import math
params_names = conv_param_names(sparse_model) params_names = conv_param_names(sparse_model)
def view_weights(pname, model_choice, apply_mask, color_normalization, interpolation): def view_weights(pname, model_choice, apply_mask, color_normalization, interpolation):
weights = models_dict[model_choice].state_dict()[pname] weights = models_dict[model_choice].state_dict()[pname]
num_kernels = weights.size(0) * weights.size(1) num_kernels = weights.size(0) * weights.size(1)
first_kernel = 0 first_kernel = 0
width = 15 width = 15
size = int(min((num_kernels-first_kernel)//width, width)) size = int(min((num_kernels-first_kernel)//width, width))
layout=(size,width) layout=(size,width)
plot_param_kernels(model=models_dict[model_choice], weights=weights, layout=layout, size_ctrl=2, plot_param_kernels(model=models_dict[model_choice], weights=weights, layout=layout, size_ctrl=2,
binary_mask=apply_mask, color_normalization=color_normalization, binary_mask=apply_mask, color_normalization=color_normalization,
interpolation=interpolation, first_kernel=first_kernel); interpolation=interpolation, first_kernel=first_kernel);
interpolations = [None, 'none', 'nearest', 'bilinear', 'bicubic', 'spline16', interpolations = [None, 'none', 'nearest', 'bilinear', 'bicubic', 'spline16',
'spline36', 'hanning', 'hamming', 'hermite', 'kaiser', 'quadric', 'spline36', 'hanning', 'hamming', 'hermite', 'kaiser', 'quadric',
'catrom', 'gaussian', 'bessel', 'mitchell', 'sinc', 'lanczos'] 'catrom', 'gaussian', 'bessel', 'mitchell', 'sinc', 'lanczos']
#model_radio = widgets.RadioButtons(options=['Sparse', 'Dense'], value='Sparse', description='Model:') #model_radio = widgets.RadioButtons(options=['Sparse', 'Dense'], value='Sparse', description='Model:')
model_dropdown = Dropdown(description='Model:', options=models_dict.keys()) model_dropdown = Dropdown(description='Model:', options=models_dict.keys())
normalize_radio = widgets.RadioButtons(options=['Group', 'Tensor', 'Model'], value='Model', description='Normalize:') normalize_radio = widgets.RadioButtons(options=['Group', 'Tensor', 'Model'], value='Model', description='Normalize:')
params_dropdown = widgets.Dropdown(description='Weights:', options=params_names) params_dropdown = widgets.Dropdown(description='Weights:', options=params_names)
interpolation_dropdown = widgets.Dropdown(description='Interploation:', options=interpolations, value='bilinear') interpolation_dropdown = widgets.Dropdown(description='Interploation:', options=interpolations, value='bilinear')
mask_choice = widgets.Checkbox(value=False, description='Binary mask') mask_choice = widgets.Checkbox(value=False, description='Binary mask')
interact(view_weights, pname=params_dropdown, interact(view_weights, pname=params_dropdown,
model_choice=model_dropdown, apply_mask=mask_choice, model_choice=model_dropdown, apply_mask=mask_choice,
color_normalization=normalize_radio, color_normalization=normalize_radio,
interpolation=interpolation_dropdown); interpolation=interpolation_dropdown);
``` ```
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
### Kernel pruning ### Kernel pruning
Look how 2D (kernel) pruning removes kernels. Look how 2D (kernel) pruning removes kernels.
Each row is a flattened view of the kernels that generate one OFM (output feature map). Each row is a flattened view of the kernels that generate one OFM (output feature map).
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
# The default font size is too small, so let's increase it # The default font size is too small, so let's increase it
matplotlib.rcParams.update({'font.size': 32}) matplotlib.rcParams.update({'font.size': 32})
params_names = conv_param_names(sparse_model) params_names = conv_param_names(sparse_model)
def view_weights(pname, unused, binary_mask, model_choice): def view_weights(pname, unused, binary_mask, model_choice):
model = models_dict[model_choice] model = models_dict[model_choice]
weights = model.state_dict()[pname] weights = model.state_dict()[pname]
weights = weights.view(weights.size(0), -1) weights = weights.view(weights.size(0), -1)
gmin, gmax = color_min_max(model, weights, color_normalization="Model") gmin, gmax = color_min_max(model, weights, color_normalization="Model")
print("gmin=%.4f\tgmax=%.4f" % (gmin, gmax)) print("gmin=%.4f\tgmax=%.4f" % (gmin, gmax))
plot_params2d([weights], figsize=(50,50), binary_mask=binary_mask, xlabel="#channels * k * k", ylabel="OFM", gmin=gmin, gmax=gmax); plot_params2d([weights], figsize=(50,50), binary_mask=binary_mask, xlabel="#channels * k * k", ylabel="OFM", gmin=gmin, gmax=gmax);
shape = distiller.size2str(model.state_dict()[pname].size()) shape = distiller.size2str(model.state_dict()[pname].size())
param_info.value = shape param_info.value = shape
shape = distiller.size2str(next (iter (dense_model.state_dict().values())).size()) shape = distiller.size2str(next (iter (dense_model.state_dict().values())).size())
param_info = widgets.Text(value=shape, description='shape:', disabled=True) param_info = widgets.Text(value=shape, description='shape:', disabled=True)
mask_choice = widgets.Checkbox(value=True, description='Binary mask') mask_choice = widgets.Checkbox(value=True, description='Binary mask')
params_dropdown = widgets.Dropdown(description='Weights:', options=params_names, value='features.module.6.weight') params_dropdown = widgets.Dropdown(description='Weights:', options=params_names, value='features.module.6.weight')
model_dropdown = Dropdown(description='Model:', options=models_dict.keys(), value='2D-Sparse') model_dropdown = Dropdown(description='Model:', options=models_dict.keys(), value='2D-Sparse')
interact(view_weights, pname=params_dropdown, unused=param_info, binary_mask=mask_choice, model_choice=model_dropdown); interact(view_weights, pname=params_dropdown, unused=param_info, binary_mask=mask_choice, model_choice=model_dropdown);
``` ```
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
### Let's isolate just the 2D kernels ### Let's isolate just the 2D kernels
Now let's try something slightly different: in the diagram below, we fold each kernel (k * k) into a single pixel. If the value of all of the elements in the kernel is zero, then the 2D kernel is colored white (100% sparse); otherwise, it is colored black (has at least one non-zero element in it). Now let's try something slightly different: in the diagram below, we fold each kernel (k * k) into a single pixel. If the value of all of the elements in the kernel is zero, then the 2D kernel is colored white (100% sparse); otherwise, it is colored black (has at least one non-zero element in it).
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
matplotlib.rcParams.update({'font.size': 22}) matplotlib.rcParams.update({'font.size': 22})
params_names = conv_param_names(sparse_model) params_names = conv_param_names(sparse_model)
def view_weights(pname, unused, model_choice): def view_weights(pname, unused, model_choice):
model = models_dict[model_choice] model = models_dict[model_choice]
weights = model.state_dict()[pname] weights = model.state_dict()[pname]
k_view = weights.view(weights.size(0) * weights.size(1), -1).abs().sum(dim=1) k_view = weights.view(weights.size(0) * weights.size(1), -1).abs().sum(dim=1)
weights = k_view.view(weights.size(0), weights.size(1)) weights = k_view.view(weights.size(0), weights.size(1))
#gmin, gmax = color_min_max(model, weights, color_normalization="Model") #gmin, gmax = color_min_max(model, weights, color_normalization="Model")
plot_params2d([weights], figsize=(10,10), binary_mask=True, xlabel="#channels", ylabel="OFM"); plot_params2d([weights], figsize=(10,10), binary_mask=True, xlabel="#channels", ylabel="OFM");
shape = distiller.size2str(sparse_model.state_dict()[pname].size()) shape = distiller.size2str(sparse_model.state_dict()[pname].size())
param_info.value = shape param_info.value = shape
shape = distiller.size2str(next (iter (sparse_model.state_dict().values())).size()) shape = distiller.size2str(next (iter (sparse_model.state_dict().values())).size())
param_info = widgets.Text(value=shape, description='shape:', disabled=True) param_info = widgets.Text(value=shape, description='shape:', disabled=True)
params_dropdown = widgets.Dropdown(description='Weights:', options=params_names, value='features.module.6.weight') params_dropdown = widgets.Dropdown(description='Weights:', options=params_names, value='features.module.6.weight')
model_dropdown = Dropdown(description='Model:', options=models_dict.keys(), value='2D-Sparse') model_dropdown = Dropdown(description='Model:', options=models_dict.keys(), value='2D-Sparse')
interact(view_weights, pname=params_dropdown, unused=param_info, model_choice=model_dropdown); interact(view_weights, pname=params_dropdown, unused=param_info, model_choice=model_dropdown);
``` ```
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment