Skip to content
Snippets Groups Projects
Commit aa8862bd authored by Neta Zmora's avatar Neta Zmora
Browse files

Fix PyTorch 0.4 compatability issue

Sometimes the gmin/gmax in group color-normalization ends up with a zero
dimensional tensor, which needs to be accessed using .item()
parent ac1235a5
No related branches found
No related tags found
No related merge requests found
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
## Interpreting your pruning and regularization experiments ## Interpreting your pruning and regularization experiments
This notebook contains code to be included in your own notebooks by adding this line at the top of your notebook:<br> This notebook contains code to be included in your own notebooks by adding this line at the top of your notebook:<br>
```%run distiller_jupyter_helpers.ipynb``` ```%run distiller_jupyter_helpers.ipynb```
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
# Relative import of code from distiller, w/o installing the package # Relative import of code from distiller, w/o installing the package
import os import os
import sys import sys
module_path = os.path.abspath(os.path.join('..')) module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path: if module_path not in sys.path:
sys.path.append(module_path) sys.path.append(module_path)
import distiller.utils import distiller.utils
import distiller import distiller
import apputils.checkpoint import apputils.checkpoint
``` ```
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
import torch import torch
import torchvision import torchvision
import os import os
import collections import collections
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
def to_np(x): def to_np(x):
return x.cpu().numpy() return x.cpu().numpy()
def flatten(weights): def flatten(weights):
weights = weights.clone().view(weights.numel()) weights = weights.clone().view(weights.numel())
weights = to_np(weights) weights = to_np(weights)
return weights return weights
import scipy.stats as stats import scipy.stats as stats
def plot_params_hist_single(name, weights_pytorch, remove_zeros=False, kmeans=None): def plot_params_hist_single(name, weights_pytorch, remove_zeros=False, kmeans=None):
weights = flatten(weights_pytorch) weights = flatten(weights_pytorch)
if remove_zeros: if remove_zeros:
weights = weights[weights!=0] weights = weights[weights!=0]
n, bins, patches = plt.hist(weights, bins=200) n, bins, patches = plt.hist(weights, bins=200)
plt.title(name) plt.title(name)
if kmeans is not None: if kmeans is not None:
labels = kmeans.labels_ labels = kmeans.labels_
centroids = kmeans.cluster_centers_ centroids = kmeans.cluster_centers_
cnt_coefficients = [len(labels[labels==i]) for i in range(16)] cnt_coefficients = [len(labels[labels==i]) for i in range(16)]
# Normalize the coefficients so they display in the same range as the float32 histogram # Normalize the coefficients so they display in the same range as the float32 histogram
cnt_coefficients = [cnt / 5 for cnt in cnt_coefficients] cnt_coefficients = [cnt / 5 for cnt in cnt_coefficients]
centroids, cnt_coefficients = zip(*sorted(zip(centroids, cnt_coefficients))) centroids, cnt_coefficients = zip(*sorted(zip(centroids, cnt_coefficients)))
cnt_coefficients = list(cnt_coefficients) cnt_coefficients = list(cnt_coefficients)
centroids = list(centroids) centroids = list(centroids)
if remove_zeros: if remove_zeros:
for i in range(len(centroids)): for i in range(len(centroids)):
if abs(centroids[i]) < 0.0001: # almost zero if abs(centroids[i]) < 0.0001: # almost zero
centroids.remove(centroids[i]) centroids.remove(centroids[i])
cnt_coefficients.remove(cnt_coefficients[i]) cnt_coefficients.remove(cnt_coefficients[i])
break break
plt.plot(centroids, cnt_coefficients) plt.plot(centroids, cnt_coefficients)
zeros = [0] * len(centroids) zeros = [0] * len(centroids)
plt.plot(centroids, zeros, 'r+', markersize=15) plt.plot(centroids, zeros, 'r+', markersize=15)
h = cnt_coefficients h = cnt_coefficients
hmean = np.mean(h) hmean = np.mean(h)
hstd = np.std(h) hstd = np.std(h)
pdf = stats.norm.pdf(h, hmean, hstd) pdf = stats.norm.pdf(h, hmean, hstd)
#plt.plot(h, pdf) #plt.plot(h, pdf)
plt.show() plt.show()
print("mean: %f\nstddev: %f" % (weights.mean(), weights.std())) print("mean: %f\nstddev: %f" % (weights.mean(), weights.std()))
print("size=%s %d elements" % distiller.size2str(weights_pytorch.size())) print("size=%s %d elements" % distiller.size2str(weights_pytorch.size()))
print("min: %.3f\nmax:%.3f" % (weights.min(), weights.max())) print("min: %.3f\nmax:%.3f" % (weights.min(), weights.max()))
def plot_params_hist(params, which='weight', remove_zeros=False): def plot_params_hist(params, which='weight', remove_zeros=False):
for name, weights_pytorch in params.items(): for name, weights_pytorch in params.items():
if which not in name: if which not in name:
continue continue
plot_params_hist_single(name, weights_pytorch, remove_zeros) plot_params_hist_single(name, weights_pytorch, remove_zeros)
def plot_params2d(classifier_weights, figsize, binary_mask=True, def plot_params2d(classifier_weights, figsize, binary_mask=True,
gmin=None, gmax=None, gmin=None, gmax=None,
xlabel="", ylabel="", title=""): xlabel="", ylabel="", title=""):
if not isinstance(classifier_weights, list): if not isinstance(classifier_weights, list):
classifier_weights = [classifier_weights] classifier_weights = [classifier_weights]
for weights in classifier_weights: for weights in classifier_weights:
assert weights.dim() in [2,4], "something's wrong" assert weights.dim() in [2,4], "something's wrong"
shape_str = distiller.size2str(weights.size()) shape_str = distiller.size2str(weights.size())
volume = distiller.volume(weights) volume = distiller.volume(weights)
# Clone because we are going to change the tensor values # Clone because we are going to change the tensor values
if binary_mask: if binary_mask:
weights2d = weights.clone() weights2d = weights.clone()
else: else:
weights2d = weights weights2d = weights
if weights.dim() == 4: if weights.dim() == 4:
weights2d = weights2d.view(weights.size()[0] * weights.size()[1], -1) weights2d = weights2d.view(weights.size()[0] * weights.size()[1], -1)
sparsity = len(weights2d[weights2d==0]) / volume sparsity = len(weights2d[weights2d==0]) / volume
cmap='seismic' cmap='seismic'
# create a binary image (non-zero elements are black; zeros are white) # create a binary image (non-zero elements are black; zeros are white)
if binary_mask: if binary_mask:
cmap='binary' cmap='binary'
weights2d[weights2d!=0] = 1 weights2d[weights2d!=0] = 1
fig = plt.figure(figsize=figsize) fig = plt.figure(figsize=figsize)
if (not binary_mask) and (gmin is not None) and (gmax is not None): if (not binary_mask) and (gmin is not None) and (gmax is not None):
if isinstance(gmin, torch.Tensor):
gmin = gmin.item()
gmax = gmax.item()
plt.imshow(weights2d, cmap=cmap, vmin=gmin, vmax=gmax) plt.imshow(weights2d, cmap=cmap, vmin=gmin, vmax=gmax)
else: else:
plt.imshow(weights2d, cmap=cmap, vmin=0, vmax=1) plt.imshow(weights2d, cmap=cmap, vmin=0, vmax=1)
#plt.figure(figsize=(20,40)) #plt.figure(figsize=(20,40))
plt.xlabel(xlabel) plt.xlabel(xlabel)
plt.ylabel(ylabel) plt.ylabel(ylabel)
plt.title(title) plt.title(title)
plt.colorbar( pad=0.01, fraction=0.01) plt.colorbar( pad=0.01, fraction=0.01)
plt.show() plt.show()
print("sparsity = %.1f%% (nnz=black)" % (sparsity*100)) print("sparsity = %.1f%% (nnz=black)" % (sparsity*100))
print("size=%s = %d elements" % (shape_str, volume)) print("size=%s = %d elements" % (shape_str, volume))
def printk(k): def printk(k):
"""Print the values of the elements of a kernel as a list""" """Print the values of the elements of a kernel as a list"""
print(list(k.view(k.numel()))) print(list(k.view(k.numel())))
def plot_param_kernels(weights, layout, size_ctrl, binary_mask=False, color_normalization='Model', def plot_param_kernels(weights, layout, size_ctrl, binary_mask=False, color_normalization='Model',
gmin=None, gmax=None, interpolation=None, first_kernel=0): gmin=None, gmax=None, interpolation=None, first_kernel=0):
ofms, ifms = weights.size()[0], weights.size()[1] ofms, ifms = weights.size()[0], weights.size()[1]
kw, kh = weights.size()[2], weights.size()[3] kw, kh = weights.size()[2], weights.size()[3]
print("min=%.4f\tmax=%.4f" % (weights.min(), weights.max())) print("min=%.4f\tmax=%.4f" % (weights.min(), weights.max()))
shape_str = distiller.size2str(weights.size()) shape_str = distiller.size2str(weights.size())
volume = distiller.volume(weights) volume = distiller.volume(weights)
print("size=%s = %d elements" % (shape_str, volume)) print("size=%s = %d elements" % (shape_str, volume))
# Clone because we are going to change the tensor values # Clone because we are going to change the tensor values
weights = weights.clone() weights = weights.clone()
if binary_mask: if binary_mask:
weights[weights!=0] = 1 weights[weights!=0] = 1
# Take the inverse of the pixels, because we want zeros to appear white # Take the inverse of the pixels, because we want zeros to appear white
#weights = 1 - weights #weights = 1 - weights
kernels = weights.view(ofms * ifms, kh, kw) kernels = weights.view(ofms * ifms, kh, kw)
nrow, ncol = layout[0], layout[1] nrow, ncol = layout[0], layout[1]
# Plot the graph # Plot the graph
plt.gray() plt.gray()
#plt.tight_layout() #plt.tight_layout()
fig = plt.figure( figsize=(layout[0]*size_ctrl, layout[1]*size_ctrl) ) fig = plt.figure( figsize=(layout[0]*size_ctrl, layout[1]*size_ctrl) );
# We want to normalize the grayscale brightness levels for all of the images we display (group), # We want to normalize the grayscale brightness levels for all of the images we display (group),
# otherwise, each image is normalized separately and this causes distortion between the different # otherwise, each image is normalized separately and this causes distortion between the different
# filters images we ddisplay. # filters images we ddisplay.
# We don't normalize across all of the filters images, because the outliers cause the image of each # We don't normalize across all of the filters images, because the outliers cause the image of each
# filter to be very muted. This is because each group of filters we display usually has low variance # filter to be very muted. This is because each group of filters we display usually has low variance
# between the element values of that group. # between the element values of that group.
if color_normalization=='Tensor': if color_normalization=='Tensor':
gmin = weights.min() gmin = weights.min()
gmax = weights.max() gmax = weights.max()
elif color_normalization=='Group': elif color_normalization=='Group':
gmin = weights[0:nrow, 0:ncol].min() gmin = weights[0:nrow, 0:ncol].min()
gmax = weights[0:nrow, 0:ncol].max() gmax = weights[0:nrow, 0:ncol].max()
print("gmin=%.4f\tgmax=%.4f" % (gmin, gmax)) print("gmin=%.4f\tgmax=%.4f" % (gmin, gmax))
if isinstance(gmin, torch.Tensor):
gmin = gmin.item()
gmax = gmax.item()
i = 0 i = 0
for row in range(0, nrow): for row in range(0, nrow):
for col in range (0, ncol): for col in range (0, ncol):
ax = fig.add_subplot(layout[0], layout[1], i+1) ax = fig.add_subplot(layout[0], layout[1], i+1)
if binary_mask: if binary_mask:
ax.matshow(kernels[first_kernel+i], cmap='binary', vmin=0, vmax=1); ax.matshow(kernels[first_kernel+i], cmap='binary', vmin=0, vmax=1);
else: else:
# Use siesmic so that colors around the center are lighter. Red and blue are used # Use siesmic so that colors around the center are lighter. Red and blue are used
# to represent (and visually separate) negative and positive weights # to represent (and visually separate) negative and positive weights
ax.matshow(kernels[first_kernel+i], cmap='seismic', vmin=gmin, vmax=gmax, interpolation=interpolation); ax.matshow(kernels[first_kernel+i], cmap='seismic', vmin=gmin, vmax=gmax, interpolation=interpolation);
ax.set(xticks=[], yticks=[]) ax.set(xticks=[], yticks=[])
i += 1 i += 1
#plt.show();
#return fig
def l1_norm_histogram(weights): def l1_norm_histogram(weights):
"""Compute a histogram of the L1-norms of the kernels of a weights tensor. """Compute a histogram of the L1-norms of the kernels of a weights tensor.
The L1-norm of a kernel is one way to quantify the "magnitude" of the total coeffiecients The L1-norm of a kernel is one way to quantify the "magnitude" of the total coeffiecients
making up this kernel. making up this kernel.
Another interesting look at filters is to compute a histogram per filter. Another interesting look at filters is to compute a histogram per filter.
""" """
ofms, ifms = weights.size()[0], weights.size()[1] ofms, ifms = weights.size()[0], weights.size()[1]
kw, kh = weights.size()[2], weights.size()[3] kw, kh = weights.size()[2], weights.size()[3]
kernels = weights.view(ofms * ifms, kh, kw) kernels = weights.view(ofms * ifms, kh, kw)
l1_hist = [] l1_hist = []
for kernel in range(ofms*ifms): for kernel in range(ofms*ifms):
l1_hist.append(kernels[kernel].norm(1)) l1_hist.append(kernels[kernel].norm(1))
return l1_hist return l1_hist
def plot_l1_norm_hist(weights): def plot_l1_norm_hist(weights):
l1_hist = l1_norm_histogram(weights) l1_hist = l1_norm_histogram(weights)
n, bins, patches = plt.hist(l1_hist, bins=200) n, bins, patches = plt.hist(l1_hist, bins=200)
plt.title('Kernel L1-norm histograms') plt.title('Kernel L1-norm histograms')
plt.ylabel('Frequency') plt.ylabel('Frequency')
plt.xlabel('Kernel L1-norm') plt.xlabel('Kernel L1-norm')
plt.show() plt.show()
def plot_layer_sizes(which, sparse_model, dense_model): def plot_layer_sizes(which, sparse_model, dense_model):
dense = [] dense = []
sparse = [] sparse = []
names = [] names = []
for name, sparse_weights in sparse_model.state_dict().items(): for name, sparse_weights in sparse_model.state_dict().items():
if ('weight' not in name) or (which!='*' and which not in name): if ('weight' not in name) or (which!='*' and which not in name):
continue continue
sparse.append(len(sparse_weights[sparse_weights!=0])) sparse.append(len(sparse_weights[sparse_weights!=0]))
names.append(name) names.append(name)
for name, dense_weights in dense_model.state_dict().items(): for name, dense_weights in dense_model.state_dict().items():
if ('weight' not in name) or (which!='*' and which not in name): if ('weight' not in name) or (which!='*' and which not in name):
continue continue
dense.append(dense_weights.numel()) dense.append(dense_weights.numel())
N = len(sparse) N = len(sparse)
ind = np.arange(N) # the x locations for the groups ind = np.arange(N) # the x locations for the groups
fig, ax = plt.subplots() fig, ax = plt.subplots()
width = .47 width = .47
p1 = plt.bar(ind, dense, width = .47, color = '#278DBC') p1 = plt.bar(ind, dense, width = .47, color = '#278DBC')
p2 = plt.bar(ind, sparse, width = 0.35, color = '#000099') p2 = plt.bar(ind, sparse, width = 0.35, color = '#000099')
plt.ylabel('Size') plt.ylabel('Size')
plt.title('Layer sizes') plt.title('Layer sizes')
plt.xticks(rotation='vertical') plt.xticks(rotation='vertical')
plt.xticks(ind, names) plt.xticks(ind, names)
#plt.yticks(np.arange(0, 100, 150)) #plt.yticks(np.arange(0, 100, 150))
plt.legend((p1[0], p2[0]), ('Dense', 'Sparse')) plt.legend((p1[0], p2[0]), ('Dense', 'Sparse'))
#Remove plot borders #Remove plot borders
for location in ['right', 'left', 'top', 'bottom']: for location in ['right', 'left', 'top', 'bottom']:
ax.spines[location].set_visible(False) ax.spines[location].set_visible(False)
#Fix grid to be horizontal lines only and behind the plots #Fix grid to be horizontal lines only and behind the plots
ax.yaxis.grid(color='gray', linestyle='solid') ax.yaxis.grid(color='gray', linestyle='solid')
ax.set_axisbelow(True) ax.set_axisbelow(True)
plt.show() plt.show()
def conv_param_names(model): def conv_param_names(model):
return [param_name for param_name, p in model.state_dict().items() return [param_name for param_name, p in model.state_dict().items()
if (p.dim()>2) and ("weight" in param_name)] if (p.dim()>2) and ("weight" in param_name)]
def conv_fc_param_names(model): def conv_fc_param_names(model):
return [param_name for param_name, p in model.state_dict().items() return [param_name for param_name, p in model.state_dict().items()
if (p.dim()>1) and ("weight" in param_name)] if (p.dim()>1) and ("weight" in param_name)]
def conv_fc_params(model): def conv_fc_params(model):
return [(param_name,p) for (param_name, p) in model.state_dict() return [(param_name,p) for (param_name, p) in model.state_dict()
if (p.dim()>1) and ("weight" in param_name)] if (p.dim()>1) and ("weight" in param_name)]
def fc_param_names(model): def fc_param_names(model):
return [param_name for param_name, p in model.state_dict().items() return [param_name for param_name, p in model.state_dict().items()
if (p.dim()==2) and ("weight" in param_name)] if (p.dim()==2) and ("weight" in param_name)]
``` ```
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
def plot_bars(which, setA, setAName, setB, setBName, names, title): def plot_bars(which, setA, setAName, setB, setBName, names, title):
N = len(setA) N = len(setA)
ind = np.arange(N) # the x locations for the groups ind = np.arange(N) # the x locations for the groups
fig, ax = plt.subplots(figsize=(20,10)) fig, ax = plt.subplots(figsize=(20,10))
width = .47 width = .47
p1 = plt.bar(ind, setA, width = .47, color = '#278DBC') p1 = plt.bar(ind, setA, width = .47, color = '#278DBC')
p2 = plt.bar(ind, setB, width = 0.35, color = '#000099') p2 = plt.bar(ind, setB, width = 0.35, color = '#000099')
plt.ylabel('Size') plt.ylabel('Size')
plt.title(title) plt.title(title)
plt.xticks(rotation='vertical') plt.xticks(rotation='vertical')
plt.xticks(ind, names) plt.xticks(ind, names)
#plt.yticks(np.arange(0, 100, 150)) #plt.yticks(np.arange(0, 100, 150))
plt.legend((p1[0], p2[0]), (setAName, setBName)) plt.legend((p1[0], p2[0]), (setAName, setBName))
#Remove plot borders #Remove plot borders
for location in ['right', 'left', 'top', 'bottom']: for location in ['right', 'left', 'top', 'bottom']:
ax.spines[location].set_visible(False) ax.spines[location].set_visible(False)
#Fix grid to be horizontal lines only and behind the plots #Fix grid to be horizontal lines only and behind the plots
ax.yaxis.grid(color='gray', linestyle='solid') ax.yaxis.grid(color='gray', linestyle='solid')
ax.set_axisbelow(True) ax.set_axisbelow(True)
plt.show() plt.show()
``` ```
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
``` ```
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment