Skip to content
Snippets Groups Projects
distiller_jupyter_helpers.ipynb 13.96 KiB

Interpreting your pruning and regularization experiments

This notebook contains code to be included in your own notebooks by adding this line at the top of your notebook:
%run distiller_jupyter_helpers.ipynb

# Put the parent directory on the import path, so that distiller can be
# imported straight from the source tree without installing the package.
import os
import sys

module_path = os.path.abspath(os.pardir)
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

import distiller.utils
import distiller
import apputils.checkpoint 
import torch
import torchvision
import os
import collections
import matplotlib.pyplot as plt
import numpy as np


def to_np(x):
    """Convert a PyTorch tensor to a NumPy array.

    The tensor is detached from the autograd graph and moved to the CPU
    before the conversion, so that tensors which require gradients (e.g.
    model parameters) can be converted as well.

    Args:
        x: a torch.Tensor.

    Returns:
        A numpy.ndarray with the same shape and values as `x`.
    """
    return x.detach().cpu().numpy()

def flatten(weights):
    """Return the elements of a tensor as a 1-D NumPy array.

    Args:
        weights: a torch.Tensor of any shape.

    Returns:
        A 1-D numpy.ndarray with weights.numel() elements.  It is produced
        from a clone of `weights`, not from `weights` itself.
    """
    # reshape() instead of view(): the clone of a non-contiguous tensor (e.g. a transpose) may
    # itself be non-contiguous, in which case it cannot be viewed as a 1-D tensor.
    # detach() so that the conversion also works for tensors that require gradients.
    flat = weights.detach().clone().reshape(weights.numel())
    return to_np(flat)


import scipy.stats as stats
def plot_params_hist_single(name, weights_pytorch, remove_zeros=False, kmeans=None):
    """Plot a 200-bin histogram of a tensor's values and print summary statistics.

    Args:
        name: title of the plot.
        weights_pytorch: the tensor whose elements are plotted.
        remove_zeros: if True, zero-valued elements are excluded from the histogram and from
            the printed mean/stddev/min/max.
        kmeans: optional clustering result exposing `labels_` and `cluster_centers_`
            (presumably a fitted scikit-learn KMeans over scalar values - confirm with the
            caller).  When given, the cluster centroids are overlaid on the histogram.
    """
    weights = flatten(weights_pytorch)
    if remove_zeros:
        weights = weights[weights != 0]
    plt.hist(weights, bins=200)
    plt.title(name)

    if kmeans is not None:
        labels = kmeans.labels_
        centroids = kmeans.cluster_centers_
        # Count the members of every cluster (however many clusters there are)
        cnt_coefficients = [len(labels[labels == i]) for i in range(len(centroids))]
        # Normalize the coefficients so they display in the same range as the float32 histogram
        cnt_coefficients = [cnt / 5 for cnt in cnt_coefficients]
        # Order the centroids (and their counts) from the smallest to the largest centroid
        centroids, cnt_coefficients = zip(*sorted(zip(centroids, cnt_coefficients)))
        cnt_coefficients = list(cnt_coefficients)
        centroids = list(centroids)
        if remove_zeros:
            for i, centroid in enumerate(centroids):
                if abs(centroid) < 0.0001:  # almost zero
                    # Delete by position: list.remove() deletes the first *equal* element, which
                    # drops the wrong count when two clusters have the same size.
                    del centroids[i]
                    del cnt_coefficients[i]
                    break

        plt.plot(centroids, cnt_coefficients)
        # Mark the position of every centroid on the x-axis
        zeros = [0] * len(centroids)
        plt.plot(centroids, zeros, 'r+', markersize=15)

    plt.show()
    print("mean: %f\nstddev: %f" % (weights.mean(), weights.std()))
    # The format string takes two values: the shape string and the element count
    print("size=%s %d elements" % (distiller.size2str(weights_pytorch.size()),
                                   weights_pytorch.numel()))
    print("min: %.3f\nmax:%.3f" % (weights.min(), weights.max()))

    
def plot_params_hist(params, which='weight', remove_zeros=False):
    """Plot a histogram for every entry of `params` whose name contains `which`.

    Args:
        params: mapping of parameter name to tensor (anything providing items()).
        which: substring that a name must contain for its tensor to be plotted.
        remove_zeros: passed on to plot_params_hist_single.
    """
    selected = ((pname, tensor) for pname, tensor in params.items() if which in pname)
    for pname, tensor in selected:
        plot_params_hist_single(pname, tensor, remove_zeros)
        
def plot_params2d(classifier_weights, figsize, binary_mask=True,
                  gmin=None, gmax=None,
                  xlabel="", ylabel="", title=""):
    """Draw one or more weights tensors as 2D images and print their sparsity.

    4D tensors are displayed as a matrix with size(0)*size(1) rows; 2D tensors are displayed
    as they are.

    Args:
        classifier_weights: a 2D or 4D tensor, or a list of such tensors (one figure each).
        figsize: figure size handed to matplotlib.
        binary_mask: if True, draw non-zero elements in black and zeros in white; otherwise
            draw the element values using the 'seismic' colormap.
        gmin, gmax: color limits (numbers or single-element tensors).  They are used only when
            binary_mask is False and both are given; otherwise the color range is [0, 1].
        xlabel, ylabel, title: text decorations of each figure.
    """
    if not isinstance(classifier_weights, list):
        classifier_weights = [classifier_weights]

    use_custom_range = (not binary_mask) and (gmin is not None) and (gmax is not None)
    if use_custom_range:
        # Convert each limit on its own: one of them may be a tensor while the other is a
        # plain number.
        if isinstance(gmin, torch.Tensor):
            gmin = gmin.item()
        if isinstance(gmax, torch.Tensor):
            gmax = gmax.item()

    for weights in classifier_weights:
        assert weights.dim() in [2, 4], "something's wrong"

        shape_str = distiller.size2str(weights.size())
        # presumably the number of elements in the tensor - see distiller.volume
        volume = distiller.volume(weights)

        # Clone only when we are going to change the tensor values (binary mask)
        weights2d = weights.clone() if binary_mask else weights

        if weights.dim() == 4:
            weights2d = weights2d.view(weights.size()[0] * weights.size()[1], -1)

        # Must be computed before the mask below overwrites the values
        sparsity = len(weights2d[weights2d == 0]) / volume

        if binary_mask:
            # create a binary image (non-zero elements are black; zeros are white)
            cmap = 'binary'
            weights2d[weights2d != 0] = 1
        else:
            cmap = 'seismic'

        plt.figure(figsize=figsize)
        if use_custom_range:
            plt.imshow(weights2d, cmap=cmap, vmin=gmin, vmax=gmax)
        else:
            plt.imshow(weights2d, cmap=cmap, vmin=0, vmax=1)

        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.title(title)
        plt.colorbar(pad=0.01, fraction=0.01)
        plt.show()
        print("sparsity = %.1f%% (nnz=black)" % (sparsity*100))
        print("size=%s = %d elements" % (shape_str, volume))
        
        
def printk(k):
    """Print the elements of a kernel, flattened into a single list."""
    flat = k.view(k.numel())
    elements = list(flat)
    print(elements)

    
def plot_param_kernels(weights, layout, size_ctrl, binary_mask=False, color_normalization='Model',
                       gmin=None, gmax=None, interpolation=None, first_kernel=0):
    """Display a grid of the 2D kernels of a 4D weights tensor.

    The tensor is treated as size(0)*size(1) kernels, each of size(2) x size(3) elements.
    The tensor's min/max and size are printed before plotting.

    Args:
        weights: 4D tensor (presumably (output channels, input channels, height, width) -
            confirm with the caller).
        layout: (nrow, ncol) of the grid; nrow*ncol consecutive kernels are drawn.
        size_ctrl: factor multiplying the layout to produce the figure size.
        binary_mask: if True, draw non-zero elements in black and zeros in white.
        color_normalization: how the color limits of the kernel images are chosen:
            'Tensor' - min/max over the whole tensor;
            'Group'  - min/max over the kernels that are displayed;
            anything else (default 'Model') - the caller's gmin/gmax; a limit that was not
            supplied falls back to the tensor-wide value.
        gmin, gmax: color limits (numbers or single-element tensors) for 'Model' normalization.
        interpolation: passed to matshow (not used when binary_mask is True).
        first_kernel: index of the first kernel to display.
    """
    ofms, ifms = weights.size()[0], weights.size()[1]
    # Keep dimensions 2 and 3 in their original order: swapping them in the view below would
    # scramble the elements of non-square kernels.
    kh, kw = weights.size()[2], weights.size()[3]

    print("min=%.4f\tmax=%.4f" % (weights.min(), weights.max()))
    shape_str = distiller.size2str(weights.size())
    volume = distiller.volume(weights)
    print("size=%s = %d elements" % (shape_str, volume))

    # Clone because we are going to change the tensor values
    weights = weights.clone()
    if binary_mask:
        weights[weights != 0] = 1

    kernels = weights.view(ofms * ifms, kh, kw)
    nrow, ncol = layout[0], layout[1]

    # Plot the graph
    plt.gray()
    # NOTE(review): matplotlib's figsize is (width, height) while layout is (nrow, ncol), so
    # this looks transposed for non-square layouts - left unchanged, confirm before changing.
    fig = plt.figure(figsize=(layout[0]*size_ctrl, layout[1]*size_ctrl))

    # We want to normalize the color levels for all of the images we display (group),
    # otherwise each image is normalized separately and this causes distortion between the
    # different filter images we display.
    # We don't necessarily normalize across all of the filters, because outliers cause the
    # image of each filter to be very muted: each group of filters we display usually has low
    # variance between the element values of that group.
    if color_normalization == 'Tensor':
        gmin = weights.min()
        gmax = weights.max()
    elif color_normalization == 'Group':
        # The group is exactly the set of kernels drawn by the loop below
        group = kernels[first_kernel:first_kernel + nrow * ncol]
        gmin = group.min()
        gmax = group.max()
    else:
        # Caller-supplied limits; without this fallback the print below fails on None
        if gmin is None:
            gmin = weights.min()
        if gmax is None:
            gmax = weights.max()
    print("gmin=%.4f\tgmax=%.4f" % (gmin, gmax))
    # Convert each limit on its own: one may be a tensor while the other is a plain number
    if isinstance(gmin, torch.Tensor):
        gmin = gmin.item()
    if isinstance(gmax, torch.Tensor):
        gmax = gmax.item()

    for i in range(nrow * ncol):
        ax = fig.add_subplot(nrow, ncol, i + 1)
        kernel = kernels[first_kernel + i]
        if binary_mask:
            ax.matshow(kernel, cmap='binary', vmin=0, vmax=1)
        else:
            # Use seismic so that colors around the center are lighter.  Red and blue are used
            # to represent (and visually separate) negative and positive weights
            ax.matshow(kernel, cmap='seismic', vmin=gmin, vmax=gmax, interpolation=interpolation)
        ax.set(xticks=[], yticks=[])
    
    
def l1_norm_histogram(weights):
    """Collect the L1-norm of every kernel of a 4D weights tensor.

    The tensor is treated as size(0)*size(1) kernels.  The L1-norm of a kernel is one way
    to quantify the "magnitude" of the coefficients making up this kernel.

    Another interesting look at filters is to compute a histogram per filter.

    Returns:
        A list with one entry per kernel, each being the result of norm(1) on that kernel.
        Binning the values into a histogram is left to the caller.
    """
    dims = weights.size()
    num_kernels = dims[0] * dims[1]
    kernels = weights.view(num_kernels, dims[3], dims[2])
    return [kernels[idx].norm(1) for idx in range(num_kernels)]

def plot_l1_norm_hist(weights):
    """Plot a 200-bin histogram of the L1-norms returned by l1_norm_histogram(weights)."""
    kernel_norms = l1_norm_histogram(weights)
    plt.hist(kernel_norms, bins=200)
    plt.xlabel('Kernel L1-norm')
    plt.ylabel('Frequency')
    plt.title('Kernel L1-norm histograms')
    plt.show()
    

def plot_layer_sizes(which, sparse_model, dense_model):
    """Plot, per weights tensor, the dense element count against the sparse non-zero count.

    Args:
        which: substring that a parameter name must contain to be included; '*' includes
            every name that contains 'weight'.
        sparse_model: model whose non-zero weight elements are counted.
        dense_model: model whose total weight elements are counted.  Its selected entries are
            paired with those of sparse_model by position; the tick labels are the names
            taken from sparse_model.
    """
    def is_selected(param_name):
        return ('weight' in param_name) and (which == '*' or which in param_name)

    names = []
    sparse = []
    for name, sparse_weights in sparse_model.state_dict().items():
        if is_selected(name):
            sparse.append(len(sparse_weights[sparse_weights != 0]))
            names.append(name)

    dense = [dense_weights.numel()
             for name, dense_weights in dense_model.state_dict().items()
             if is_selected(name)]

    # The x locations for the groups
    ind = np.arange(len(sparse))

    _, ax = plt.subplots()
    dense_bars = plt.bar(ind, dense, width=.47, color='#278DBC')
    sparse_bars = plt.bar(ind, sparse, width=0.35, color='#000099')

    plt.ylabel('Size')
    plt.title('Layer sizes')
    plt.xticks(rotation='vertical')
    plt.xticks(ind, names)
    plt.legend((dense_bars[0], sparse_bars[0]), ('Dense', 'Sparse'))

    # Remove plot borders
    for side in ('right', 'left', 'top', 'bottom'):
        ax.spines[side].set_visible(False)

    # Fix grid to be horizontal lines only and behind the plots
    ax.yaxis.grid(color='gray', linestyle='solid')
    ax.set_axisbelow(True)
    plt.show()
    
    
def conv_param_names(model):
    """List the state_dict names that contain "weight" and whose tensors have more than 2 dims."""
    selected = []
    for param_name, p in model.state_dict().items():
        if p.dim() <= 2:
            continue
        if "weight" in param_name:
            selected.append(param_name)
    return selected

def conv_fc_param_names(model):
    """List the state_dict names that contain "weight" and whose tensors have more than 1 dim."""
    selected = []
    for param_name, p in model.state_dict().items():
        if p.dim() <= 1:
            continue
        if "weight" in param_name:
            selected.append(param_name)
    return selected

def conv_fc_params(model):
    """Return (name, tensor) pairs for the multi-dimensional weights of a model.

    Args:
        model: an object providing state_dict().

    Returns:
        A list of (param_name, tensor) tuples, in state_dict order, for every entry whose
        name contains "weight" and whose tensor has more than one dimension.
    """
    # items() is required: iterating the state_dict itself yields only the names
    return [(param_name, p) for (param_name, p) in model.state_dict().items()
            if (p.dim() > 1) and ("weight" in param_name)]

def fc_param_names(model):
    """List the state_dict names that contain "weight" and whose tensors have exactly 2 dims."""
    selected = []
    for param_name, p in model.state_dict().items():
        if p.dim() != 2:
            continue
        if "weight" in param_name:
            selected.append(param_name)
    return selected
def plot_bars(which, setA, setAName, setB, setBName, names, title):
    """Draw two overlaid bar series that share the same x positions.

    Args:
        which: accepted but not used by this function.
        setA, setB: the bar heights; setB is drawn with narrower bars on top of setA.
        setAName, setBName: legend labels of the two series.
        names: tick labels, one per bar position.
        title: title of the plot.
    """
    # The x locations for the groups
    positions = np.arange(len(setA))

    _, ax = plt.subplots(figsize=(20,10))
    bars_a = plt.bar(positions, setA, width=.47, color='#278DBC')
    bars_b = plt.bar(positions, setB, width=0.35, color='#000099')

    plt.ylabel('Size')
    plt.title(title)
    plt.xticks(rotation='vertical')
    plt.xticks(positions, names)
    plt.legend((bars_a[0], bars_b[0]), (setAName, setBName))

    # Remove plot borders
    for side in ('right', 'left', 'top', 'bottom'):
        ax.spines[side].set_visible(False)

    # Fix grid to be horizontal lines only and behind the plots
    ax.yaxis.grid(color='gray', linestyle='solid')
    ax.set_axisbelow(True)
    plt.show()