Skip to content
Snippets Groups Projects
Commit 671038d5 authored by Neta Zmora's avatar Neta Zmora
Browse files

Add some protection around the experimental features not supported in PyTorch 3.1

parent c8752ac6
No related branches found
No related tags found
No related merge requests found
...@@ -257,17 +257,27 @@ def draw_model_to_file(sgraph, png_fname): ...@@ -257,17 +257,27 @@ def draw_model_to_file(sgraph, png_fname):
fid.write(png) fid.write(png)
def draw_img_classifier_to_file(model, png_fname, dataset): def draw_img_classifier_to_file(model, png_fname, dataset):
if dataset == 'imagenet': try:
dummy_input = Variable(torch.randn(1, 3, 224, 224), requires_grad=False) if dataset == 'imagenet':
elif dataset == 'cifar10': dummy_input = Variable(torch.randn(1, 3, 224, 224), requires_grad=False)
dummy_input = Variable(torch.randn(1, 3, 32, 32)) elif dataset == 'cifar10':
else: dummy_input = Variable(torch.randn(1, 3, 32, 32))
print("Unsupported dataset (%s) - aborting draw operation" % dataset) else:
return print("Unsupported dataset (%s) - aborting draw operation" % dataset)
return
g = SummaryGraph(model, dummy_input)
draw_model_to_file(g, png_fname) g = SummaryGraph(model, dummy_input)
draw_model_to_file(g, png_fname)
print("Network PNG image generation completed")
except TypeError as e:
print("An error has occured while generating the network PNG image.")
print("This feature is not supported on official PyTorch releases.")
print("Please check that you are using a valid PyTorch version.")
print("You are using pytorch version %s" %torch.__version__)
except FileNotFoundError:
print("An error has occured while generating the network PNG image.")
print("Please check that you have graphviz installed.")
print("\t$ sudo apt-get install graphviz")
def create_png(sgraph): def create_png(sgraph):
"""Create a PNG object containing a graphiz-dot graph of the netowrk represented """Create a PNG object containing a graphiz-dot graph of the netowrk represented
......
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
# Experimental # Experimental
<br> <br>
<font size="6" color="red"> &#9888; WARNING </font> <font size="6" color="red"> &#9888; WARNING </font>
<br> <br>
<font size="4" color="red">This part of the notebook works correctly only on some advanced PyTorch versions (e.g. 0.4.0a0+410fd58), therefore is may not run correctly for you.</font> <font size="4" color="red">This part of the notebook works correctly only on some advanced PyTorch versions (e.g. 0.4.0a0+410fd58), therefore is may not run correctly for you.</font><br><br>
Please also note that for generating a PNG image of the network (last cell of the notebook), you will need to have graphviz installed:
```
$ sudo apt-get install graphviz
```
<br> <br>
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
# Relative import of code from distiller, w/o installing the package # Relative import of code from distiller, w/o installing the package
import os import os
import sys import sys
module_path = os.path.abspath(os.path.join('..')) module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path: if module_path not in sys.path:
sys.path.append(module_path) sys.path.append(module_path)
%matplotlib inline %matplotlib inline
import matplotlib import matplotlib
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from distiller.model_summaries import * from distiller.model_summaries import *
from models import create_model from models import create_model
from apputils import * from apputils import *
import torch import torch
import torchvision import torchvision
import qgrid import qgrid
# Load some common jupyter code # Load some common jupyter code
%run distiller_jupyter_helpers.ipynb %run distiller_jupyter_helpers.ipynb
import ipywidgets as widgets import ipywidgets as widgets
from ipywidgets import interactive, interact, Layout from ipywidgets import interactive, interact, Layout
# Some models have long node names and require longer lines # Some models have long node names and require longer lines
from IPython.core.display import display, HTML from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>")) display(HTML("<style>.container { width:100% !important; }</style>"))
print("You are using pytorch version %s" %torch.__version__) print("You are using pytorch version %s" %torch.__version__)
``` ```
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
## Choose which model you want to examine ## Choose which model you want to examine
If you are studying the structure of a neural network model, you probably don't need a pruned model, although you can use one. If you are studying the structure of a neural network model, you probably don't need a pruned model, although you can use one.
<br> <br>
In this example, we look at a pretrained ResNet18 model. In this example, we look at a pretrained ResNet18 model.
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
dataset = 'imagenet' dataset = 'imagenet'
dummy_input = Variable(torch.randn(1, 3, 224, 224), requires_grad=False) dummy_input = Variable(torch.randn(1, 3, 224, 224), requires_grad=False)
arch = 'resnet18' arch = 'resnet18'
#arch = 'alexnet' #arch = 'alexnet'
checkpoint_file = None checkpoint_file = None
if checkpoint_file is not None: if checkpoint_file is not None:
model = create_model(pretrained=False, dataset=dataset, arch=arch) model = create_model(pretrained=False, dataset=dataset, arch=arch)
load_checkpoint(model, checkpoint_file) load_checkpoint(model, checkpoint_file)
else: else:
model = create_model(pretrained=False, dataset=dataset, arch=arch, parallel=False) model = create_model(pretrained=False, dataset=dataset, arch=arch, parallel=False)
``` ```
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
You can examine layer connectivity: You can examine layer connectivity:
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
dummy_imagenet_input = Variable(torch.randn(1, 3, 224, 224), requires_grad=False) dummy_imagenet_input = Variable(torch.randn(1, 3, 224, 224), requires_grad=False)
dummy_cifar_input = Variable(torch.randn(1, 3, 32, 32), requires_grad=False) dummy_cifar_input = Variable(torch.randn(1, 3, 32, 32), requires_grad=False)
dummy_input = dummy_imagenet_input if dataset=='imagenet' else dummy_cifar_input dummy_input = dummy_imagenet_input if dataset=='imagenet' else dummy_cifar_input
g = SummaryGraph(model, dummy_input) g = SummaryGraph(model, dummy_input)
df = connectivity_summary(g) df = connectivity_summary(g)
#qgrid.set_grid_option('defaultColumnWidth', 10) #qgrid.set_grid_option('defaultColumnWidth', 10)
qgrid.show_grid(df) qgrid.show_grid(df)
``` ```
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
You can also print the shapes of the various tensors. You can also print the shapes of the various tensors.
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
df = connectivity_summary_verbose(g) df = connectivity_summary_verbose(g)
qgrid.show_grid(df) qgrid.show_grid(df)
``` ```
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
And you can discover the attributes of each layer: And you can discover the attributes of each layer:
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
def volume(dims): def volume(dims):
vol = 1 vol = 1
for d in range(len(dims)): vol *= dims[d] for d in range(len(dims)): vol *= dims[d]
return vol return vol
def param_shape(sgraph, param_id): def param_shape(sgraph, param_id):
return sgraph.params[param_id]['shape'] return sgraph.params[param_id]['shape']
def param_volume(sgraph, param_id): def param_volume(sgraph, param_id):
return volume(param_shape(sgraph, param_id)) return volume(param_shape(sgraph, param_id))
def add_macs_attr(sgraph): def add_macs_attr(sgraph):
for op in sgraph.ops: for op in sgraph.ops:
op['attrs']['MACs'] = 0 op['attrs']['MACs'] = 0
if op['type'] == 'Conv': if op['type'] == 'Conv':
conv_out = op['outputs'][0] conv_out = op['outputs'][0]
conv_in = op['inputs'][0] conv_in = op['inputs'][0]
conv_w = op['attrs']['kernel_shape'] conv_w = op['attrs']['kernel_shape']
ofm_vol = param_volume(sgraph, conv_out) ofm_vol = param_volume(sgraph, conv_out)
# MACs = volume(OFM) * (#IFM * K^2) # MACs = volume(OFM) * (#IFM * K^2)
op['attrs']['MACs'] = ofm_vol * volume(conv_w) * sgraph.params[conv_in]['shape'][1] op['attrs']['MACs'] = ofm_vol * volume(conv_w) * sgraph.params[conv_in]['shape'][1]
elif op['type'] == 'Gemm': elif op['type'] == 'Gemm':
conv_out = op['outputs'][0] conv_out = op['outputs'][0]
conv_in = op['inputs'][0] conv_in = op['inputs'][0]
n_ifm = param_shape(sgraph, conv_in)[1] n_ifm = param_shape(sgraph, conv_in)[1]
n_ofm = param_shape(sgraph, conv_out)[1] n_ofm = param_shape(sgraph, conv_out)[1]
# MACs = #IFM * #OFM # MACs = #IFM * #OFM
op['attrs']['MACs'] = n_ofm * n_ifm op['attrs']['MACs'] = n_ofm * n_ifm
def add_footprint_attr(sgraph): def add_footprint_attr(sgraph):
for op in sgraph.ops: for op in sgraph.ops:
op['attrs']['footprint'] = 0 op['attrs']['footprint'] = 0
if op['type'] in ['Conv', 'Gemm', 'MaxPool']: if op['type'] in ['Conv', 'Gemm', 'MaxPool']:
conv_out = op['outputs'][0] conv_out = op['outputs'][0]
conv_in = op['inputs'][0] conv_in = op['inputs'][0]
ofm_vol = param_volume(sgraph, conv_out) ofm_vol = param_volume(sgraph, conv_out)
ifm_vol = param_volume(sgraph, conv_in) ifm_vol = param_volume(sgraph, conv_in)
if op['type'] == 'Conv' or op['type'] == 'Gemm': if op['type'] == 'Conv' or op['type'] == 'Gemm':
conv_w = op['inputs'][1] conv_w = op['inputs'][1]
weights_vol = param_volume(sgraph, conv_w) weights_vol = param_volume(sgraph, conv_w)
#print(ofm_vol , ifm_vol , weights_vol) #print(ofm_vol , ifm_vol , weights_vol)
op['attrs']['footprint'] = ofm_vol + ifm_vol + weights_vol op['attrs']['footprint'] = ofm_vol + ifm_vol + weights_vol
op['attrs']['fm_vol'] = ofm_vol + ifm_vol op['attrs']['fm_vol'] = ofm_vol + ifm_vol
op['attrs']['weights_vol'] = weights_vol op['attrs']['weights_vol'] = weights_vol
elif op['type'] == 'MaxPool': elif op['type'] == 'MaxPool':
op['attrs']['footprint'] = ofm_vol + ifm_vol op['attrs']['footprint'] = ofm_vol + ifm_vol
def add_arithmetic_intensity_attr(sgraph): def add_arithmetic_intensity_attr(sgraph):
for op in sgraph.ops: for op in sgraph.ops:
if op['attrs']['footprint'] == 0: if op['attrs']['footprint'] == 0:
op['attrs']['ai'] = 0 op['attrs']['ai'] = 0
else: else:
# integers are enough, and note that we also round up # integers are enough, and note that we also round up
op['attrs']['ai'] = ((op['attrs']['MACs']+0.5*op['attrs']['footprint']) // op['attrs']['footprint']) op['attrs']['ai'] = ((op['attrs']['MACs']+0.5*op['attrs']['footprint']) // op['attrs']['footprint'])
def get_attr(sgraph, attr, f = lambda op: True): def get_attr(sgraph, attr, f = lambda op: True):
return [op['attrs'][attr] for op in sgraph.ops if attr in op['attrs'] and f(op)] return [op['attrs'][attr] for op in sgraph.ops if attr in op['attrs'] and f(op)]
def get_ops(sgraph, attr, f = lambda op: True): def get_ops(sgraph, attr, f = lambda op: True):
return [op for op in sgraph.ops if attr in op['attrs'] and f(op)] return [op for op in sgraph.ops if attr in op['attrs'] and f(op)]
add_macs_attr(g) add_macs_attr(g)
add_footprint_attr(g) add_footprint_attr(g)
add_arithmetic_intensity_attr(g) add_arithmetic_intensity_attr(g)
ignore_attrs = ['group', 'is_test', 'consumed_inputs', 'alpha', 'beta', 'MACs', 'footprint', 'ai', 'fm_vol', 'weights_vol'] ignore_attrs = ['group', 'is_test', 'consumed_inputs', 'alpha', 'beta', 'MACs', 'footprint', 'ai', 'fm_vol', 'weights_vol']
df = attributes_summary(g, ignore_attrs) df = attributes_summary(g, ignore_attrs)
df['MAC'] = get_attr(g, 'MACs') df['MAC'] = get_attr(g, 'MACs')
df['BW'] = get_attr(g, 'footprint') df['BW'] = get_attr(g, 'footprint')
df['AI'] = get_attr(g, 'ai') df['AI'] = get_attr(g, 'ai')
#df = df.assign([5]*len(df)).values #df = df.assign([5]*len(df)).values
qgrid.show_grid(df) qgrid.show_grid(df)
``` ```
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
def plot_bars(which, setA, setAName, setB, setBName, names, title): def plot_bars(which, setA, setAName, setB, setBName, names, title):
N = len(setA) N = len(setA)
ind = np.arange(N) # the x locations for the groups ind = np.arange(N) # the x locations for the groups
fig, ax = plt.subplots(figsize=(20,10)) fig, ax = plt.subplots(figsize=(20,10))
width = .47 width = .47
p1 = plt.bar(ind, setA, width = .47, color = '#278DBC') p1 = plt.bar(ind, setA, width = .47, color = '#278DBC')
p2 = plt.bar(ind, setB, width = 0.35, color = '#000099') p2 = plt.bar(ind, setB, width = 0.35, color = '#000099')
plt.ylabel('Size') plt.ylabel('Size')
plt.title(title) plt.title(title)
plt.xticks(rotation='vertical') plt.xticks(rotation='vertical')
plt.xticks(ind, names) plt.xticks(ind, names)
#plt.yticks(np.arange(0, 100, 150)) #plt.yticks(np.arange(0, 100, 150))
plt.legend((p1[0], p2[0]), (setAName, setBName)) plt.legend((p1[0], p2[0]), (setAName, setBName))
#Remove plot borders #Remove plot borders
for location in ['right', 'left', 'top', 'bottom']: for location in ['right', 'left', 'top', 'bottom']:
ax.spines[location].set_visible(False) ax.spines[location].set_visible(False)
#Fix grid to be horizontal lines only and behind the plots #Fix grid to be horizontal lines only and behind the plots
ax.yaxis.grid(color='gray', linestyle='solid') ax.yaxis.grid(color='gray', linestyle='solid')
ax.set_axisbelow(True) ax.set_axisbelow(True)
plt.show() plt.show()
sgraph = g sgraph = g
names = [op['name'] for op in sgraph.ops] names = [op['name'] for op in sgraph.ops]
setA = get_attr(g, 'fm_vol') setA = get_attr(g, 'fm_vol')
setB = get_attr(g, 'weights_vol') setB = get_attr(g, 'weights_vol')
plot_bars(None, setA, 'Feature maps', setB, 'Weights', names, 'Weights footprint vs. feature-maps footprint\n(Normalized)') plot_bars(None, setA, 'Feature maps', setB, 'Weights', names, 'Weights footprint vs. feature-maps footprint\n(Normalized)')
``` ```
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
names = [op['name'] for op in sgraph.ops if 'MACs' in op['attrs'] and op['attrs']['MACs']>0] names = [op['name'] for op in sgraph.ops if 'MACs' in op['attrs'] and op['attrs']['MACs']>0]
macs = get_attr(g, 'MACs', lambda op: op['attrs']['MACs']>0) macs = get_attr(g, 'MACs', lambda op: op['attrs']['MACs']>0)
y_pos = np.arange(len(names)) y_pos = np.arange(len(names))
fig, ax = plt.subplots(figsize=(20,10)) fig, ax = plt.subplots(figsize=(20,10))
barlist = plt.bar(y_pos, macs, align='center', alpha=0.5, color = '#278DBC') barlist = plt.bar(y_pos, macs, align='center', alpha=0.5, color = '#278DBC')
plt.xticks(y_pos, names) plt.xticks(y_pos, names)
plt.ylabel('MACs') plt.ylabel('MACs')
plt.title('MACs per layer') plt.title('MACs per layer')
#Fix grid to be horizontal lines only and behind the plots #Fix grid to be horizontal lines only and behind the plots
ax.yaxis.grid(color='gray', linestyle='solid') ax.yaxis.grid(color='gray', linestyle='solid')
ax.set_axisbelow(True) ax.set_axisbelow(True)
plt.xticks(rotation='vertical') plt.xticks(rotation='vertical')
#Remove plot borders #Remove plot borders
for location in ['right', 'left', 'top', 'bottom']: for location in ['right', 'left', 'top', 'bottom']:
ax.spines[location].set_visible(False) ax.spines[location].set_visible(False)
ops = get_ops(g, 'MACs', lambda op: op['attrs']['MACs']>0) ops = get_ops(g, 'MACs', lambda op: op['attrs']['MACs']>0)
for bar,op in zip(barlist, ops): for bar,op in zip(barlist, ops):
kernel = op['attrs'].get('kernel_shape', None) kernel = op['attrs'].get('kernel_shape', None)
if str(kernel) == '[7, 7]': if str(kernel) == '[7, 7]':
bar.set_color('r') bar.set_color('r')
if str(kernel) == '[3, 3]': if str(kernel) == '[3, 3]':
bar.set_color('g') bar.set_color('g')
plt.show() plt.show()
``` ```
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
ops = sgraph.ops ops = sgraph.ops
positive_mac = lambda op: op['attrs']['MACs']>0 positive_mac = lambda op: op['attrs']['MACs']>0
names = get_attr(g, 'name', positive_mac) names = get_attr(g, 'name', positive_mac)
macs = get_attr(g, 'MACs', positive_mac) macs = get_attr(g, 'MACs', positive_mac)
norm_macs = [float(i)/np.sum(macs) for i in macs] norm_macs = [float(i)/np.sum(macs) for i in macs]
footprint = get_attr(g, 'footprint', positive_mac) footprint = get_attr(g, 'footprint', positive_mac)
norm_footprint = [float(i)/np.sum(footprint) for i in footprint] norm_footprint = [float(i)/np.sum(footprint) for i in footprint]
plot_bars(None, norm_macs, 'MACs', norm_footprint, 'footprint', names, "MACs vs footprint") plot_bars(None, norm_macs, 'MACs', norm_footprint, 'footprint', names, "MACs vs footprint")
#norm = [float(i)/sum(raw) for i in raw] #norm = [float(i)/sum(raw) for i in raw]
``` ```
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
## Create a PNG image of the model ## Create a PNG image of the model
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
from IPython.display import Image from IPython.display import Image
g = SummaryGraph(model, dummy_input) g = SummaryGraph(model, dummy_input)
DRAW_TO_FILE = True DRAW_TO_FILE = True
if DRAW_TO_FILE: if DRAW_TO_FILE:
draw_model_to_file(g, 'graph.png') draw_model_to_file(g, 'graph.png')
# Draw on notebook # Draw on notebook
png = create_png(g) png = create_png(g)
Image(png) Image(png)
``` ```
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
## References ## References
<div id="Gray-et-al-2015"></div> **Andrew Lavin and Scott Gray**. <div id="Gray-et-al-2015"></div> **Andrew Lavin and Scott Gray**.
[*Fast Algorithms for Convolutional Neural Networks*](https://arxiv.org/pdf/1509.09308.pdf), [*Fast Algorithms for Convolutional Neural Networks*](https://arxiv.org/pdf/1509.09308.pdf),
2015. 2015.
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
``` ```
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment