Skip to content
Snippets Groups Projects
Commit 957e6777 authored by Neta Zmora's avatar Neta Zmora
Browse files

pytorch 0.4: adjustments to API changes

Various small changes due to the chamnges in the semantics and syntax of the
PyTorch 0.4 API.

Note that currently distiller.model_performance_summary() returns wrong results
on graphs containing torch.nn.DataParallel layers.
parent ea580770
No related branches found
No related tags found
No related merge requests found
......@@ -37,18 +37,18 @@ del thinning
# Distiller version
__version__ = "0.1.0"
def model_find_param_name(model, tensor_to_find):
"""Look up the name of a model tensor.
def model_find_param_name(model, param_to_find):
"""Look up the name of a model parameter.
Arguments:
model: the model to search
tensor_to_find: the tensors who's name we want to look up
param_to_find: the parameter whose name we want to look up
Returns:
The parameter name (string) or None, if the paramter was not found.
The parameter name (string) or None, if the parameter was not found.
"""
for name, tensor in model.state_dict().items():
if tensor is tensor_to_find:
for name, param in model.named_parameters():
if param is param_to_find:
return name
return None
......
......@@ -43,7 +43,7 @@ def model_summary(model, optimizer, what, dataset=None):
distiller.log_weights_sparsity(model, -1, loggers=[pylogger, csvlogger])
elif what == 'compute':
if dataset == 'imagenet':
dummy_input = Variable(torch.randn(1, 3, 224, 224), requires_grad=False)
dummy_input = Variable(torch.randn(1, 3, 224, 224))
elif dataset == 'cifar10':
dummy_input = Variable(torch.randn(1, 3, 32, 32))
else:
......@@ -101,9 +101,9 @@ def weights_sparsity_summary(model, return_total_sparsity=False, param_dims=[2,4
distiller.sparsity_2D(param)*100,
distiller.sparsity_3D(param)*100,
(1-_density)*100,
param.std(),
param.mean(),
param.abs().mean()
param.std().item(),
param.mean().item(),
param.abs().mean().item()
])
total_sparsity = (1 - sparse_params_size/params_size)*100
......@@ -158,7 +158,7 @@ def module_visitor(self, input, output, df, model, weights_vol, macs, attrs=None
in_features_shape = input[0].size()
out_features_shape = output.size()
param_name = distiller.model_find_param_name(model, self.weight.data)
param_name = distiller.model_find_param_name(model, self.weight)
if param_name is None:
return
mod_name = param_name[:param_name.find(".weight")]
......@@ -168,8 +168,13 @@ def module_visitor(self, input, output, df, model, weights_vol, macs, attrs=None
distiller.size_to_str(out_features_shape), distiller.volume(output),
weights_vol, int(macs)])
def model_performance_summary(model, dummy_input, batch_size=1):
"""Collect performance data"""
"""Collect performance data
warning: in PyTorch 0.4 this function does not return correct values when
the graph contains torch.nn.DataParallel layers.
"""
def install_perf_collector(m):
if isinstance(m, torch.nn.Conv2d):
hook_handles.append(m.register_forward_hook(
......
%% Cell type:markdown id: tags:
# Model summary
Get familiar with your model, by examining its structure and properties.
Use this notebook to display statics and information about the weights, layers and connectivity of the model.<br>
## Table of Contents
1. [Choose which model you want to examine](#Choose-which-model-you-want-to-examine)
2. [Print a summary of the statistics of the model attributes in tabular format](#Print-a-summary-of-the-statistics-of-the-model-attributes-in-tabular-format)<br>
2.1. [Display some information about the layer types](#Display-some-information-about-the-layer-types)<br>
2.2. [Compare weights footprint to feature-map footprint](#Compare-weights-footprint-to-feature-map-footprint)<br>
2.3. [Compare data footprint to compute (MACs)](#Compare-data-footprint-to-compute-(MACs)
3. [Filter L1-norm](#Filter-L1-norm)
4. [References](#References)
%% Cell type:code id: tags:
``` python
# Relative import of code from distiller, w/o installing the package
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
sys.path.append(module_path)
%matplotlib inline
import matplotlib
import numpy as np
import matplotlib.pyplot as plt
from distiller.model_summaries import *
from models import create_model
from apputils import *
import torch
import torchvision
import qgrid
# Load some common jupyter code
%run distiller_jupyter_helpers.ipynb
import ipywidgets as widgets
from ipywidgets import interactive, interact, Layout
# Some models have long node names and require longer lines
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))
def pretty_int(i):
return "{:,}".format(i)
```
%% Cell type:markdown id: tags:
## Choose which model you want to examine
If you are studying the structure of a neural network model, you probably don't need a pruned model, although you can use one.
<br>
In this example, we look at a pretrained ResNet18 model.
%% Cell type:code id: tags:
``` python
dataset = 'imagenet'
dummy_input = Variable(torch.randn(1, 3, 224, 224), requires_grad=False)
dummy_input = torch.randn(1, 3, 224, 224)
arch = 'resnet18'
#arch = 'alexnet'
checkpoint_file = None
if checkpoint_file is not None:
model = create_model(pretrained=True, dataset=dataset, arch=arch)
load_checkpoint(model, checkpoint_file)
else:
model = create_model(pretrained=True, dataset=dataset, arch=arch, parallel=False)
```
%% Cell type:markdown id: tags:
## Print a summary of the statistics of the model attributes in tabular format
Distiller generates several different summary reports, which are returned as Pandas dataframes which you can slice, dice and sort using Pandas' rich API.<br>
<br>
MACs are multiply-accumulate operations: a MAC unit computes the product of two elements and adds the product to an accumulator. The MACs reported by distiller.model_performance_summary are for direct GEMM (General Matrix-Matrix Multiplication) and convolution. Different hardware uses specific algorithms at different times. For example, [Intel's MKL-DNN](https://intel.github.io/mkl-dnn/) uses [Winograd](https://arxiv.org/pdf/1509.09308.pdf) for 3x3 convolutions. As another example, [convolutions are sometimes computed using GEMM](https://petewarden.com/2015/04/20/why-gemm-is-at-the-heart-of-deep-learning/) for increased utilization of vectorized hardware.<br>
<br>
In the example below, we display some statistics about the sizes and shapes of the feature-maps and weight tensors, and some other goodies. :-)
%% Cell type:code id: tags:
``` python
df = distiller.model_performance_summary(model, dummy_input, 1)
# You can display summaries using several backends, and each has its advantages and disadvantages, so you will want to use them in different situations.
print("Weights shapes, sizes and statistics (showing only FC and convolution layers):")
print("\tTotal IFM footprint (elements): " + "{:,}".format(df['IFM volume'].sum()))
print("\tTotal OFM footprint (elements): " + "{:,}".format(df['OFM volume'].sum()))
print("\tTotal weights footprint (elements): " + "{:,}".format(df['Weights volume'].sum()))
# 1. As a textual table
#t = distiller.model_performance_tbl_summary(model, dummy_input, 1)
#print(t)
# 2. As a plain Pandas dataframe
# display(df)
# 3. As a QGrid table, which you can sort and filter.
qgrid.show_grid(df)
```
%% Cell type:markdown id: tags:
### Display some information about the layer types
Gleaning model statistics using Pandas dataframes, provides a painless way to query 2nd level details about the model, such as what layer types it uses.
%% Cell type:code id: tags:
``` python
conv7x7 = df[df['Attrs'] == 'k=(7, 7)']
conv3x3 = df[df['Attrs'] == 'k=(3, 3)']
conv1x1 = df[df['Attrs'] == 'k=(1, 1)']
print("There are %d Conv(7,7) layers with total MACs = %s" % (len(conv7x7), pretty_int(conv7x7['MACs'].sum())))
print("There are %d Conv(3,3) layers with total MACs = %s" % (len(conv3x3), pretty_int(conv3x3['MACs'].sum())))
print("There are %d Conv(1,1) layers with total MACs = %s" % (len(conv1x1), pretty_int(conv1x1['MACs'].sum())))
```
%% Cell type:markdown id: tags:
### Compare weights footprint to feature-map footprint
Memory footprint, bandwidth and throughput are different concepts. Footprint is the size amount of memory required to store a piece of data (e.g. measured as number of bytes). Bandwidth is the rate at which data can be read or written (stored) from/to memory by different hardware (e.g. measured as bytes/sec). Throughput is a measure of the data that actually moves (read/stored) in a period of time (bytes/sec).<br>
Because the amount of data required for a typical neural-network operation is often larger than the available working memory of the compute hardware (e.g. CPU registers and cache), data often needs to be sliced into tiles (blocks). The sizes of the tiles, together with the memory access pattern and the compute algorithm, determine the total amount of data that needs to move around (read/stored). Because of this hardware dependency, we provide below information regarding memory footprint and not throughput.
%% Cell type:code id: tags:
``` python
import matplotlib.pyplot as plt
fig, ax = plt.subplots(figsize=(15,7.5))
fig.suptitle("Foortprint Statistics by layer")
ax.set_ylabel("Feature Maps")
ax.set_xlabel("Layer")
ax2 = ax.twinx()
ax2.set_ylabel("Weights")
ax.set_xticklabels(df.Name, rotation=90);
df["FM volume"] = df["OFM volume"] + df["IFM volume"]
df[["Name","FM volume"]].plot(ax=ax, xticks=range(len(df.index)), style="b-", rot=90)
df[["Name","Weights volume"]].plot(ax=ax2, style="g-", use_index=True, rot=90);
```
%% Cell type:markdown id: tags:
### Compare data footprint to compute (MACs)
We measure Footprint in number of elements, not bytes. If, for example, the elements data type is FP32, then the real footprint is 4x the reported footprint.
%% Cell type:code id: tags:
``` python
fig, ax = plt.subplots(figsize=(15,7.5))
fig.suptitle("Foortprint vs. Compute")
ax.set_ylabel("MACs")
ax.set_xlabel("Layer")
ax2 = ax.twinx()
ax2.set_ylabel("Footprint")
df[["Name", "MACs"]].plot(ax=ax, kind='bar', rot=90, xticks=range(len(df.index)), figsize=[15,7.5])
ax.set_xticklabels(df.Name, rotation=90);
df2 = df["Weights volume"] + df["OFM volume"] + df["IFM volume"]
df2.plot(ax=ax2, style="g-", use_index=True, rot=90);
```
%% Cell type:markdown id: tags:
## Filter L1-norm
Draw the L1 norm of each filter, in a selected weight tensor.<br>
When ranking filters by L1-norm (as in [Pruning filters for efficient convnets](#Hao-et-al-2016)), this can provide some insight as to which filters will be removed.<br>
Make sure you've loaded a pretrained network, otherwise you will be looking at random data.
%% Cell type:code id: tags:
``` python
params_names = conv_param_names(model)
def view_weights(pname, sort):
param = model.state_dict()[pname]
view_filters = param.view(param.size(0), -1)
filter_mags = to_np(view_filters.abs().mean(dim=1))
if sort:
filter_mags = np.sort(filter_mags)
plt.figure(figsize=[15,7.5])
plt.plot(range(len(filter_mags)), filter_mags, label=pname, marker="o", markersize=5, markerfacecolor="C1")
plt.xlabel('Filter index (i.e. output feature-map channel)')
plt.ylabel('Fliter L1-norm')
sort_choice = widgets.Checkbox(value=True, description='Sort filters')
params_dropdown = widgets.Dropdown(description='Weights:', options=params_names, layout=Layout(width='40%'))
interact(view_weights, pname=params_dropdown, sort=sort_choice);
```
%% Cell type:markdown id: tags:
## References
<div id="Gray-et-al-2015"></div> **Andrew Lavin and Scott Gray**.
[*Fast Algorithms for Convolutional Neural Networks*](https://arxiv.org/pdf/1509.09308.pdf),
2015.
<div id="Hao-et-al-2016"></div> **Hao Li, Asim Kadav, Igor Durdanovic, Hanan Samet, and Hans Peter Graf**.
[*Pruning filters for efficient convnets*](https://arxiv.org/abs/1608.08710),
2016.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment