Unverified commit 62862a08 authored by Lev Zlotnik, committed by GitHub
PyTorch 1.0.0 support + Proper Packaging (Release 0.3) (#144)

Not backward compatible - re-installation is required

* Fixes for PyTorch==1.0.0
* Refactoring folder structure
* Update installation section in docs
parent ee1d160f
Showing 272 additions and 11 deletions
@@ -11,3 +11,6 @@ env/
logs/
.DS_Store
.vscode/
distiller.egg-info/
latest_log_dir
latest_log_file
@@ -34,6 +34,11 @@
**Distiller** is an open-source Python package for neural network compression research.
Network compression can reduce the memory footprint of a neural network, increase its inference speed and save energy. Distiller provides a [PyTorch](http://pytorch.org/) environment for prototyping and analyzing compression algorithms, such as sparsity-inducing methods and low-precision arithmetic.
#### Note on Release 0.3 - Possible BREAKING Changes
As of release 0.3, we've moved some code around to enable proper packaging and installation of Distiller. In addition, we updated Distiller to support PyTorch 1.0.0, which might also cause older code to break due to some API changes.
If you are updating from an earlier revision of the code, please follow the instructions in the [install](#install-the-package) section to ensure that Distiller and all of its dependencies are properly installed.
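For an existing clone, the upgrade amounts to pulling the new code and installing the package in place. A condensed sketch, assuming an activated virtual environment (see the install section below for the full steps):

```
$ cd distiller
$ git pull
$ pip3 install -e .
```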
<details><summary><b>What's New in November?</b></summary>
<p>
@@ -109,7 +114,7 @@ Beware.
- [Using virtualenv](#using-virtualenv)
- [Using venv](#using-venv)
- [Activate the environment](#activate-the-environment)
- [Install dependencies](#install-dependencies)
- [Install the package](#install-the-package)
- [Getting Started](#getting-started)
- [Example invocations of the sample application](#example-invocations-of-the-sample-application)
- [Training-only](#training-only)
@@ -216,12 +221,15 @@ The environment activation and deactivation commands for ```venv``` and ```virtu
$ source env/bin/activate
```
### Install dependencies
Finally, install Distiller's dependency packages using ```pip3```:
### Install the package
Finally, install the Distiller package and its dependencies using ```pip3```:
```
$ pip3 install -r requirements.txt
$ cd distiller
$ pip3 install -e .
```
PyTorch is included in the ```requirements.txt``` file, and will currently download PyTorch version 0.4.0 for CUDA 8.0. This is the setup we've used for testing Distiller.
This installs Distiller in "development mode", meaning any changes made in the code are reflected in the environment without re-running the install command (so no need to re-install after pulling changes from the Git repository).
PyTorch is included in the ```requirements.txt``` file, and will currently download PyTorch version 1.0.1 for CUDA 9.0. This is the setup we've used for testing Distiller.
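To verify the editable install, a minimal sanity check (assuming the virtual environment created above is still active) is to import the package from outside the source directory:

```
$ python3 -c "import distiller; print(distiller.__file__)"
```

If the import succeeds, Distiller and its dependencies are available in the environment.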
## Getting Started
@@ -24,6 +24,7 @@ from .directives import *
from .policy import *
from .thinning import *
from .knowledge_distillation import KnowledgeDistillationPolicy, DistillationLossWeights
from .summary_graph import SummaryGraph, onnx_name_2_pytorch_name
del dict_config
@@ -19,13 +19,11 @@ when working with distiller.
"""
from .data_loaders import *
from .model_summaries import *
from .checkpoint import *
from .execution_env import *
from .dataset_summaries import *
del data_loaders
del model_summaries
del checkpoint
del execution_env
del dataset_summaries
File moved
File moved
@@ -32,7 +32,7 @@ def dataset_summary(data_loader):
from statistics import mean
print('Dataset contains {} items'.format(len(data_loader.sampler)))
print('Found {} classes'.format(nclasses))
for data_class, size in hist.iteritems():
for data_class, size in hist.items():
print('\tClass {} = {}'.format(data_class, size))
print('mean: ', mean(list(hist.values())))
File moved
@@ -20,6 +20,8 @@
- optimizer state
- model details
"""
import os
import pydot
from functools import partial
import pandas as pd
from tabulate import tabulate
@@ -28,13 +30,17 @@ import torch
from torch.autograd import Variable
import torch.optim
import distiller
from .summary_graph import SummaryGraph
from .data_loggers import PythonLogger, CsvLogger
msglogger = logging.getLogger()
__all__ = ['model_summary',
'weights_sparsity_summary', 'weights_sparsity_tbl_summary',
'model_performance_summary', 'model_performance_tbl_summary', 'masks_sparsity_tbl_summary']
'model_performance_summary', 'model_performance_tbl_summary', 'masks_sparsity_tbl_summary',
'attributes_summary', 'attributes_summary_tbl', 'connectivity_summary',
'connectivity_summary_verbose', 'connectivity_tbl_summary', 'create_png', 'create_pydot_graph',
'draw_model_to_file', 'draw_img_classifier_to_file']
def model_summary(model, what, dataset=None):
@@ -226,3 +232,248 @@ def model_performance_tbl_summary(model, dummy_input, batch_size):
    df = model_performance_summary(model, dummy_input, batch_size)
    t = tabulate(df, headers='keys', tablefmt='psql', floatfmt=".5f")
    return t


def attributes_summary(sgraph, ignore_attrs):
    """Generate a summary of a graph's attributes.

    Args:
        sgraph: a SummaryGraph instance
        ignore_attrs: a list of attributes to ignore in the output dataframe
    Output:
        A Pandas dataframe
    """
    def pretty_val(val):
        if type(val) == int:
            return format(val, ",d")
        return str(val)

    def pretty_attrs(attrs, ignore_attrs):
        ret = ''
        for key, val in attrs.items():
            if key in ignore_attrs:
                continue
            ret += key + ': ' + pretty_val(val) + '\n'
        return ret

    df = pd.DataFrame(columns=['Name', 'Type', 'Attributes'])
    pd.set_option('precision', 5)
    for i, op in enumerate(sgraph.ops.values()):
        df.loc[i] = [op['name'], op['type'], pretty_attrs(op['attrs'], ignore_attrs)]
    return df

def attributes_summary_tbl(sgraph, ignore_attrs):
    df = attributes_summary(sgraph, ignore_attrs)
    return tabulate(df, headers='keys', tablefmt='psql')


def connectivity_summary(sgraph):
    """Generate a summary of each node's connectivity.

    Args:
        sgraph: a SummaryGraph instance
    """
    df = pd.DataFrame(columns=['Name', 'Type', 'Inputs', 'Outputs'])
    pd.set_option('precision', 5)
    for i, op in enumerate(sgraph.ops.values()):
        df.loc[i] = [op['name'], op['type'], op['inputs'], op['outputs']]
    return df


def connectivity_summary_verbose(sgraph):
    """Generate a summary of each node's connectivity, with details
    about the parameters.

    Args:
        sgraph: a SummaryGraph instance
    """
    def format_list(l):
        ret = ''
        for i in l:
            ret += str(i) + '\n'
        return ret[:-1]

    df = pd.DataFrame(columns=['Name', 'Type', 'Inputs', 'Outputs'])
    pd.set_option('precision', 5)
    for i, op in enumerate(sgraph.ops.values()):
        outputs = []
        for blob in op['outputs']:
            if blob in sgraph.params:
                outputs.append(blob + ": " + str(sgraph.params[blob]['shape']))
        inputs = []
        for blob in op['inputs']:
            if blob in sgraph.params:
                inputs.append(blob + ": " + str(sgraph.params[blob]['shape']))
        inputs = format_list(inputs)
        outputs = format_list(outputs)
        df.loc[i] = [op['name'], op['type'], inputs, outputs]
    return df


def connectivity_tbl_summary(sgraph, verbose=False):
    if verbose:
        df = connectivity_summary_verbose(sgraph)
    else:
        df = connectivity_summary(sgraph)
    return tabulate(df, headers='keys', tablefmt='psql')

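The table helpers above can be driven from a SummaryGraph built on any model. A minimal usage sketch, not part of this commit; the torchvision model and the module import path are assumptions made for illustration only:

```
import torch
import torchvision.models
import distiller
from distiller.apputils.model_summaries import attributes_summary_tbl, connectivity_tbl_summary  # import path is an assumption

model = distiller.make_non_parallel_copy(torchvision.models.resnet18())
sgraph = distiller.SummaryGraph(model, torch.randn(1, 3, 224, 224))
print(attributes_summary_tbl(sgraph, ignore_attrs=[]))   # list attribute names here to omit them
print(connectivity_tbl_summary(sgraph, verbose=True))    # per-op inputs/outputs, with parameter shapes
```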
def create_pydot_graph(op_nodes, data_nodes, param_nodes, edges, rankdir='TB', styles=None):
    """Low-level API to create a PyDot graph (dot formatted).
    """
    pydot_graph = pydot.Dot('Net', graph_type='digraph', rankdir=rankdir)

    op_node_style = {'shape': 'record',
                     'fillcolor': '#6495ED',
                     'style': 'rounded, filled'}

    for op_node in op_nodes:
        style = op_node_style
        # Check if we should override the style of this node.
        if styles is not None and op_node[0] in styles:
            style = styles[op_node[0]]
        pydot_graph.add_node(pydot.Node(op_node[0], **style, label="\n".join(op_node)))

    for data_node in data_nodes:
        pydot_graph.add_node(pydot.Node(data_node[0], label="\n".join(data_node[1:])))

    node_style = {'shape': 'oval',
                  'fillcolor': 'gray',
                  'style': 'rounded, filled'}

    if param_nodes is not None:
        for param_node in param_nodes:
            pydot_graph.add_node(pydot.Node(param_node[0], **node_style, label="\n".join(param_node[1:])))

    for edge in edges:
        pydot_graph.add_edge(pydot.Edge(edge[0], edge[1]))

    return pydot_graph

def create_png(sgraph, display_param_nodes=False, rankdir='TB', styles=None):
    """Create a PNG object containing a graphviz-dot graph of the network,
    as represented by SummaryGraph 'sgraph'.

    Args:
        sgraph (SummaryGraph): the SummaryGraph instance to draw.
        display_param_nodes (boolean): if True, draw the parameter nodes
        rankdir: diagram direction.  'TB'/'BT' is Top-to-Bottom/Bottom-to-Top;
                 'LR'/'RL' is Left-to-Right/Right-to-Left
        styles: a dictionary of styles.  Key is module name.  Value is
                a legal pydot style dictionary.  For example:
                styles['conv1'] = {'shape': 'oval',
                                   'fillcolor': 'gray',
                                   'style': 'rounded, filled'}
    """
    op_nodes = [op['name'] for op in sgraph.ops.values()]
    data_nodes = []
    param_nodes = []
    for id, param in sgraph.params.items():
        n_data = (id, str(distiller.volume(param['shape'])), str(param['shape']))
        if data_node_has_parent(sgraph, id):
            data_nodes.append(n_data)
        else:
            param_nodes.append(n_data)
    edges = sgraph.edges

    if not display_param_nodes:
        # Use only the edges that don't have a parameter source
        non_param_ids = op_nodes + [dn[0] for dn in data_nodes]
        edges = [edge for edge in sgraph.edges if edge.src in non_param_ids]
        param_nodes = None

    op_nodes = [(op['name'], op['type']) for op in sgraph.ops.values()]
    pydot_graph = create_pydot_graph(op_nodes, data_nodes, param_nodes, edges, rankdir, styles)
    png = pydot_graph.create_png()
    return png

def draw_model_to_file(sgraph, png_fname, display_param_nodes=False, rankdir='TB', styles=None):
    """Create a PNG file containing a graphviz-dot graph of the network represented
    by SummaryGraph 'sgraph'.

    Args:
        sgraph (SummaryGraph): the SummaryGraph instance to draw.
        png_fname (string): PNG file name
        display_param_nodes (boolean): if True, draw the parameter nodes
        rankdir: diagram direction.  'TB'/'BT' is Top-to-Bottom/Bottom-to-Top;
                 'LR'/'RL' is Left-to-Right/Right-to-Left
        styles: a dictionary of styles.  Key is module name.  Value is
                a legal pydot style dictionary.  For example:
                styles['conv1'] = {'shape': 'oval',
                                   'fillcolor': 'gray',
                                   'style': 'rounded, filled'}
    """
    png = create_png(sgraph, display_param_nodes=display_param_nodes)
    with open(png_fname, 'wb') as fid:
        fid.write(png)

def draw_img_classifier_to_file(model, png_fname, dataset, display_param_nodes=False,
                                rankdir='TB', styles=None):
    """Draw a PyTorch image classifier to a PNG file.  This is a helper function that
    simplifies the interface of draw_model_to_file().

    Args:
        model: PyTorch model instance
        png_fname (string): PNG file name
        dataset (string): one of 'imagenet' or 'cifar10'.  This is required in order to
                          create a dummy input of the correct shape.
        display_param_nodes (boolean): if True, draw the parameter nodes
        rankdir: diagram direction.  'TB'/'BT' is Top-to-Bottom/Bottom-to-Top;
                 'LR'/'RL' is Left-to-Right/Right-to-Left
        styles: a dictionary of styles.  Key is module name.  Value is
                a legal pydot style dictionary.  For example:
                styles['conv1'] = {'shape': 'oval',
                                   'fillcolor': 'gray',
                                   'style': 'rounded, filled'}
    """
    try:
        dummy_input = dataset_dummy_input(dataset)
        model = distiller.make_non_parallel_copy(model)
        g = SummaryGraph(model, dummy_input)
        draw_model_to_file(g, png_fname, display_param_nodes, rankdir, styles)
        print("Network PNG image generation completed")
    except FileNotFoundError:
        print("An error has occurred while generating the network PNG image.")
        print("Please check that you have graphviz installed.")
        print("\t$ sudo apt-get install graphviz")

def dataset_dummy_input(dataset):
    if dataset == 'imagenet':
        dummy_input = Variable(torch.randn(1, 3, 224, 224), requires_grad=False)
    elif dataset == 'cifar10':
        dummy_input = Variable(torch.randn(1, 3, 32, 32))
    else:
        raise ValueError("Unsupported dataset (%s) - aborting draw operation" % dataset)
    return dummy_input

def export_img_classifier_to_onnx(model, onnx_fname, dataset, export_params=True, add_softmax=True):
    """Export a PyTorch image classifier to ONNX.
    """
    dummy_input = dataset_dummy_input(dataset).to('cuda')
    # Pytorch 0.4 doesn't support exporting modules wrapped in DataParallel
    model = distiller.make_non_parallel_copy(model)

    with torch.onnx.set_training(model, False):
        if add_softmax:
            # Explicitly add a softmax layer, because it is needed for the ONNX inference phase.
            model.original_forward = model.forward
            softmax = torch.nn.Softmax(dim=-1)
            model.forward = lambda input: softmax(model.original_forward(input))
        torch.onnx.export(model, dummy_input, onnx_fname, verbose=False, export_params=export_params)
        msglogger.info('Exported the model to ONNX format at %s' % os.path.realpath(onnx_fname))

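A usage sketch for the ONNX export helper; a CUDA device is required, since the dummy input is explicitly moved to 'cuda', and the model, file name, and import path below are assumptions:

```
import torchvision.models
from distiller.apputils.model_summaries import export_img_classifier_to_onnx  # import path is an assumption

model = torchvision.models.resnet18().to('cuda')
export_img_classifier_to_onnx(model, 'resnet18.onnx', dataset='imagenet', add_softmax=True)
```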
def data_node_has_parent(g, id):
    for edge in g.edges:
        if edge.dst == id:
            return True
    return False

@@ -18,8 +18,8 @@
import torch
import torchvision.models as torch_models
import models.cifar10 as cifar10_models
import models.imagenet as imagenet_extra_models
from . import cifar10 as cifar10_models
from . import imagenet as imagenet_extra_models
import logging
msglogger = logging.getLogger()
File moved
File moved
File moved