# Relative import of code from distiller, w/o installing the package
import os
import sys
# Make the parent of the current working directory importable, so the
# `distiller` / `apputils` imports that follow resolve without installing them.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)
import distiller.utils
import distiller
import apputils.checkpoint
-
Neta Zmora authored
Sometimes gmin/gmax in group color normalization end up as zero-dimensional tensors, which need to be accessed using .item()
Neta Zmora authored: Sometimes gmin/gmax in group color normalization end up as zero-dimensional tensors, which need to be accessed using .item()
distiller_jupyter_helpers.ipynb 13.96 KiB
Interpreting your pruning and regularization experiments
This notebook contains code to be included in your own notebooks by adding this line at the top of your notebook:
%run distiller_jupyter_helpers.ipynb
import torch
import torchvision
import os
import collections
import matplotlib.pyplot as plt
import numpy as np
def to_np(x):
    """Return a NumPy array with the contents of tensor *x*.

    The tensor is detached first because ``Tensor.numpy()`` refuses tensors
    that require grad (e.g. entries of ``model.parameters()``), and it is
    moved to host memory so CUDA tensors work too.  For a CPU tensor the
    returned array may share memory with *x*.
    """
    return x.detach().cpu().numpy()


def flatten(weights):
    """Return a 1-D NumPy copy of all elements of *weights*, in row-major order.

    ``reshape`` is used instead of ``view`` because ``clone()`` preserves the
    strides of a non-contiguous tensor (e.g. a transposed weight), and ``view``
    would then raise.  The clone guarantees the result never aliases *weights*.
    """
    return to_np(weights.detach().clone().reshape(-1))
import scipy.stats as stats
def _plot_kmeans_centroids(kmeans, remove_zeros):
    """Overlay k-means centroids and their (scaled) cluster sizes on the current plot."""
    centers = np.asarray(kmeans.cluster_centers_, dtype=float)
    centroids = centers.reshape(-1)
    if centroids.size != len(centers):
        # Each centroid must be a single value to be drawn on the histogram's x-axis.
        raise ValueError("expected single-feature (1-D) k-means centroids")
    # Members per cluster, for any number of clusters; minlength keeps empty
    # clusters (count 0) aligned with their centroid.
    # The counts are divided by 5 so the line is drawn in roughly the same range
    # as the histogram bars (empirical scale factor).
    counts = np.bincount(np.asarray(kmeans.labels_), minlength=len(centroids)) / 5
    order = np.argsort(centroids, kind="stable")
    centroids, counts = centroids[order], counts[order]
    if remove_zeros:
        # Drop the first (smallest) centroid that is practically zero, since the
        # zero elements were removed from the histogram as well.
        near_zero = np.flatnonzero(np.abs(centroids) < 0.0001)
        if near_zero.size:
            centroids = np.delete(centroids, near_zero[0])
            counts = np.delete(counts, near_zero[0])
    plt.plot(centroids, counts)
    plt.plot(centroids, np.zeros_like(centroids), 'r+', markersize=15)


def plot_params_hist_single(name, weights_pytorch, remove_zeros=False, kmeans=None):
    """Plot a 200-bin histogram of the values of one tensor and print its statistics.

    Args:
        name: title of the plot (typically the parameter's name).
        weights_pytorch: the tensor whose elements are plotted.
        remove_zeros: if True, zero-valued elements are excluded from the
            histogram and the printed statistics, and the near-zero k-means
            centroid (if any) is not drawn.
        kmeans: optional fitted clustering object exposing ``labels_`` and
            ``cluster_centers_`` (the attribute names of scikit-learn's KMeans).
            Its sorted centroids are overlaid as a line of scaled cluster sizes
            plus red '+' markers on the x-axis.
    """
    weights = flatten(weights_pytorch)
    if remove_zeros:
        weights = weights[weights != 0]
    plt.hist(weights, bins=200)
    plt.title(name)
    if kmeans is not None:
        _plot_kmeans_centroids(kmeans, remove_zeros)
    plt.show()

    size_msg = "size=%s = %d elements" % (distiller.size2str(weights_pytorch.size()),
                                         weights_pytorch.numel())
    if weights.size == 0:
        # min()/max() raise on an empty array (e.g. an all-zero tensor with
        # remove_zeros=True), so report it instead of crashing.
        print(size_msg)
        print("no elements to summarize")
        return
    print("mean: %f\nstddev: %f" % (weights.mean(), weights.std()))
    print(size_msg)
    print("min: %.3f\nmax:%.3f" % (weights.min(), weights.max()))
def plot_params_hist(params, which='weight', remove_zeros=False):
    """Plot a histogram for every entry of *params* whose name contains *which*.

    Args:
        params: mapping of name -> tensor (for example a model's state_dict()).
        which: substring a name must contain for its tensor to be plotted.
        remove_zeros: forwarded to plot_params_hist_single.
    """
    selected = ((key, tensor) for key, tensor in params.items() if which in key)
    for key, tensor in selected:
        plot_params_hist_single(key, tensor, remove_zeros)
def plot_params2d(classifier_weights, figsize, binary_mask=True,
                  gmin=None, gmax=None,
                  xlabel="", ylabel="", title=""):
    """Display one or more 2-D or 4-D tensors as images, one figure per tensor.

    A 4-D tensor is reshaped to (dim0 * dim1, remaining elements) before
    plotting.  After each figure, the sparsity (fraction of zero elements) and
    the size are printed.

    Args:
        classifier_weights: a tensor or a list of tensors.
        figsize: figure size passed to ``plt.figure``.
        binary_mask: if True, non-zero elements are drawn black and zeros white.
        gmin, gmax: color limits, used only when binary_mask is False and both
            are given (Python numbers, NumPy scalars or 0-dim tensors);
            otherwise the color range is [0, 1].
        xlabel, ylabel, title: plot labels.
    """
    if not isinstance(classifier_weights, list):
        classifier_weights = [classifier_weights]

    for weights in classifier_weights:
        assert weights.dim() in [2, 4], "something's wrong"
        shape_str = distiller.size2str(weights.size())
        volume = distiller.volume(weights)

        weights2d = weights.detach()
        if weights2d.dim() == 4:
            # reshape (not view) so that non-contiguous tensors are accepted too
            weights2d = weights2d.reshape(weights.size(0) * weights.size(1), -1)
        # presumably distiller.volume() is the element count - TODO confirm
        sparsity = (weights2d == 0).sum().item() / volume

        if binary_mask:
            # Binary image (non-zero elements are black; zeros are white),
            # built as a new tensor so the caller's weights are never modified.
            image = (weights2d != 0).float()
            cmap = 'binary'
        else:
            image = weights2d
            cmap = 'seismic'
        image = image.cpu().numpy()

        plt.figure(figsize=figsize)
        if (not binary_mask) and (gmin is not None) and (gmax is not None):
            # gmin/gmax may be 0-dim tensors (e.g. results of Tensor.min());
            # float() handles tensors, NumPy scalars and numbers, each bound
            # independently of the other's type.
            plt.imshow(image, cmap=cmap, vmin=float(gmin), vmax=float(gmax))
        else:
            plt.imshow(image, cmap=cmap, vmin=0, vmax=1)
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.title(title)
        plt.colorbar(pad=0.01, fraction=0.01)
        plt.show()
        print("sparsity = %.1f%% (nnz=black)" % (sparsity*100))
        print("size=%s = %d elements" % (shape_str, volume))
def printk(k):
    """Print the values of the elements of a kernel as a flat list.

    Elements are printed as plain Python numbers in row-major order
    (e.g. ``[1.0, 2.0, ...]``) rather than as a list of 0-d tensor reprs;
    ``reshape`` also accepts non-contiguous kernels, where ``view`` raises.
    """
    print(k.reshape(-1).tolist())
def plot_param_kernels(weights, layout, size_ctrl, binary_mask=False, color_normalization='Model',
                       gmin=None, gmax=None, interpolation=None, first_kernel=0):
    """Plot a grid of the 2-D kernels of a 4-D weights tensor.

    The tensor is split into dim0 * dim1 kernels of shape (dim2, dim3).  The
    kernels with flat indices first_kernel, first_kernel + 1, ... are drawn row
    by row into a layout[0] x layout[1] grid (only as many as exist).

    Args:
        weights: 4-D weights tensor (not modified).
        layout: (rows, columns) of the subplot grid.
        size_ctrl: size of one grid cell, in inches.
        binary_mask: draw non-zero elements black and zeros white.
        color_normalization: how the color range is chosen for non-binary plots:
            'Tensor' - min/max over the whole tensor;
            'Group'  - min/max over the kernels actually displayed;
            anything else (default 'Model') - the caller's gmin/gmax; when
            they are None, matplotlib scales each kernel independently.
        gmin, gmax: color limits (numbers or 0-dim tensors).
        interpolation: forwarded to ``matshow`` for non-binary plots.
        first_kernel: flat index of the first kernel to display.
    """
    ofms, ifms = weights.size(0), weights.size(1)
    kh, kw = weights.size(2), weights.size(3)
    print("min=%.4f\tmax=%.4f" % (weights.min(), weights.max()))
    shape_str = distiller.size2str(weights.size())
    volume = distiller.volume(weights)
    print("size=%s = %d elements" % (shape_str, volume))

    # Work on a detached host copy: the binary mask changes values, and
    # matshow needs NumPy-convertible data.
    weights = weights.detach().cpu().clone()
    if binary_mask:
        weights[weights != 0] = 1
    # Dims 2 and 3 become the rows and columns of each kernel image.
    kernels = weights.reshape(ofms * ifms, kh, kw)

    nrow, ncol = layout[0], layout[1]
    num_shown = max(0, min(nrow * ncol, kernels.size(0) - first_kernel))
    shown = kernels[first_kernel:first_kernel + num_shown]

    # We want to normalize the color levels over all of the images we display
    # (the group); otherwise each image is normalized separately, which distorts
    # the comparison between the different kernel images we display.
    # We don't always normalize across the whole tensor, because its outliers
    # can make the image of each kernel very muted: each group of kernels we
    # display usually has low variance between the element values of the group.
    if color_normalization == 'Tensor':
        gmin, gmax = weights.min(), weights.max()
    elif color_normalization == 'Group' and num_shown > 0:
        gmin, gmax = shown.min(), shown.max()
    # Bounds may be 0-dim tensors; convert each one independently.
    gmin = None if gmin is None else float(gmin)
    gmax = None if gmax is None else float(gmax)
    if gmin is not None and gmax is not None:
        print("gmin=%.4f\tgmax=%.4f" % (gmin, gmax))

    plt.gray()
    # figsize is (width, height): width follows the columns, height the rows.
    fig = plt.figure(figsize=(ncol * size_ctrl, nrow * size_ctrl))
    for i, kernel in enumerate(shown):
        ax = fig.add_subplot(nrow, ncol, i + 1)
        if binary_mask:
            ax.matshow(kernel.numpy(), cmap='binary', vmin=0, vmax=1)
        else:
            # 'seismic' is a diverging colormap (blue - white - red): values
            # near the middle of [vmin, vmax] are light, the extremes are blue
            # and red, which visually separates negative and positive weights
            # when the range is centered around zero.
            ax.matshow(kernel.numpy(), cmap='seismic', vmin=gmin, vmax=gmax,
                       interpolation=interpolation)
        ax.set(xticks=[], yticks=[])
def l1_norm_histogram(weights):
    """Compute the L1-norm of every kernel of a weights tensor.

    A kernel is the slice ``weights[i, j]`` (all remaining dimensions).  Its
    L1-norm (sum of absolute values) is one way to quantify the "magnitude" of
    the total coefficients making up this kernel.  Another interesting look at
    filters is to compute a histogram per filter.

    Returns:
        A list with one 0-d tensor per kernel, in (dim0, dim1) row-major order,
        ready to be passed to a histogram function.
    """
    ofms, ifms = weights.size(0), weights.size(1)
    # One row per kernel; reshape (unlike view) also handles non-contiguous
    # tensors and works for any rank >= 2.
    kernels = weights.reshape(ofms * ifms, -1)
    # All kernel norms computed in one vectorized call instead of one per kernel.
    return list(kernels.abs().sum(dim=1))
def plot_l1_norm_hist(weights):
    """Show a 200-bin histogram of the L1-norms of the kernels of *weights*."""
    norms = l1_norm_histogram(weights)
    plt.hist(norms, bins=200)
    ax = plt.gca()
    ax.set(title='Kernel L1-norm histograms',
           xlabel='Kernel L1-norm',
           ylabel='Frequency')
    plt.show()
def plot_layer_sizes(which, sparse_model, dense_model):
    """Bar-plot, per weight tensor, its dense size next to its number of non-zeros.

    Args:
        which: substring selecting the weight tensors (by state_dict key);
            '*' selects every tensor whose key contains 'weight'.
        sparse_model: model whose non-zero weight elements are counted.
        dense_model: model whose total weight elements are counted.  Its
            selected tensors are paired with sparse_model's by position, so the
            two models are expected to share the same architecture.

    Raises:
        ValueError: if nothing is selected, or if the two models yield a
            different number of selected tensors.
    """
    def selected(model):
        # (name, tensor) pairs of weight tensors matching `which`, in state_dict order.
        return [(name, tensor) for name, tensor in model.state_dict().items()
                if 'weight' in name and (which == '*' or which in name)]

    sparse_params = selected(sparse_model)
    dense_params = selected(dense_model)
    if len(sparse_params) != len(dense_params):
        raise ValueError("models have a different number of matching weight tensors: "
                         "%d (sparse) vs %d (dense)" % (len(sparse_params), len(dense_params)))
    if not sparse_params:
        raise ValueError("no weight tensors match %r" % (which,))

    names = [name for name, _ in sparse_params]
    sparse = [(tensor != 0).sum().item() for _, tensor in sparse_params]
    dense = [tensor.numel() for _, tensor in dense_params]

    ind = np.arange(len(sparse))  # the x locations for the groups
    _, ax = plt.subplots()
    p1 = plt.bar(ind, dense, width=0.47, color='#278DBC')
    p2 = plt.bar(ind, sparse, width=0.35, color='#000099')
    plt.ylabel('Size')
    plt.title('Layer sizes')
    plt.xticks(ind, names, rotation='vertical')
    plt.legend((p1[0], p2[0]), ('Dense', 'Sparse'))
    # Remove the plot borders
    for location in ['right', 'left', 'top', 'bottom']:
        ax.spines[location].set_visible(False)
    # Horizontal grid lines only, drawn behind the bars
    ax.yaxis.grid(color='gray', linestyle='solid')
    ax.set_axisbelow(True)
    plt.show()
def conv_param_names(model):
    """Return the state_dict keys of weight tensors with more than 2 dimensions
    (e.g. convolution kernels)."""
    names = []
    for key, tensor in model.state_dict().items():
        if "weight" in key and tensor.dim() > 2:
            names.append(key)
    return names
def conv_fc_param_names(model):
    """Return the state_dict keys of weight tensors with more than 1 dimension
    (e.g. convolution and fully-connected weights)."""
    state = model.state_dict()
    return [key for key in state if "weight" in key and state[key].dim() > 1]
def conv_fc_params(model):
    """Return (name, tensor) pairs for weight tensors with more than 1 dimension.

    Iterates ``state_dict().items()``: iterating the dict itself yields only
    the key strings, which cannot be unpacked into (name, tensor) pairs.
    """
    return [(param_name, p) for param_name, p in model.state_dict().items()
            if p.dim() > 1 and "weight" in param_name]
def fc_param_names(model):
    """Return the state_dict keys of 2-D weight tensors (e.g. fully-connected layers)."""
    matches = filter(lambda item: "weight" in item[0] and item[1].dim() == 2,
                     model.state_dict().items())
    return [key for key, _ in matches]
def plot_bars(which, setA, setAName, setB, setBName, names, title, ylabel='Size'):
    """Draw two overlaid bar series with one bar pair per name.

    Args:
        which: unused; kept so existing callers keep working.
        setA: heights of the wider bars, drawn first (legend label setAName).
        setB: heights of the narrower bars, drawn on top (legend label setBName).
        names: x-tick label of each bar pair.
        title: figure title.
        ylabel: y-axis label (defaults to 'Size', the previous fixed label).
    """
    ind = np.arange(len(setA))  # the x locations for the groups
    _, ax = plt.subplots(figsize=(20, 10))
    p1 = plt.bar(ind, setA, width=0.47, color='#278DBC')
    p2 = plt.bar(ind, setB, width=0.35, color='#000099')
    plt.ylabel(ylabel)
    plt.title(title)
    plt.xticks(ind, names, rotation='vertical')
    plt.legend((p1[0], p2[0]), (setAName, setBName))
    # Remove the plot borders
    for location in ['right', 'left', 'top', 'bottom']:
        ax.spines[location].set_visible(False)
    # Horizontal grid lines only, drawn behind the bars
    ax.yaxis.grid(color='gray', linestyle='solid')
    ax.set_axisbelow(True)
    plt.show()