Skip to content
Snippets Groups Projects
Commit bc719c20 authored by Neta Zmora's avatar Neta Zmora
Browse files

Export trained (image classification) models to ONNX

parent 17242204
No related branches found
No related tags found
No related merge requests found
......@@ -20,6 +20,7 @@ This code is proven to work on CNN image classification models using PyTorch 04.
RNNs are currently not working well.
"""
import os
import re
import numpy as np
import collections
......@@ -583,14 +584,7 @@ def draw_img_classifier_to_file(model, png_fname, dataset, display_param_nodes=F
'style': 'rounded, filled'}
"""
try:
if dataset == 'imagenet':
dummy_input = Variable(torch.randn(1, 3, 224, 224), requires_grad=False)
elif dataset == 'cifar10':
dummy_input = Variable(torch.randn(1, 3, 32, 32))
else:
print("Unsupported dataset (%s) - aborting draw operation" % dataset)
return
dummy_input = dataset_dummy_input(dataset)
model = distiller.make_non_parallel_copy(model)
g = SummaryGraph(model, dummy_input)
draw_model_to_file(g, png_fname, display_param_nodes, rankdir, styles)
......@@ -601,6 +595,32 @@ def draw_img_classifier_to_file(model, png_fname, dataset, display_param_nodes=F
print("\t$ sudo apt-get install graphviz")
def dataset_dummy_input(dataset):
if dataset == 'imagenet':
dummy_input = Variable(torch.randn(1, 3, 224, 224), requires_grad=False)
elif dataset == 'cifar10':
dummy_input = Variable(torch.randn(1, 3, 32, 32))
else:
raise ValueError("Unsupported dataset (%s) - aborting draw operation" % dataset)
return dummy_input
def export_img_classifier_to_onnx(model, onnx_fname, dataset):
"""Export a PyTorch image classifier to ONNX.
"""
dummy_input = dataset_dummy_input(dataset)
#model.eval()
with torch.onnx.set_training(model, False):
# Pytorch 0.4 doesn't support exporting modules wrapped in DataParallel
if isinstance(model, torch.nn.DataParallel):
model = model.module
torch.onnx.export(model, dummy_input.to('cuda'), onnx_fname, verbose=False)
msglogger.info('Exported the model to ONNX format at %s' % os.path.realpath(onnx_fname))
def data_node_has_parent(g, id):
for edge in g.edges:
if edge.dst == id: return True
......
......@@ -122,7 +122,7 @@ parser.add_argument('--act-stats', dest='activation_stats', action='store_true',
help='collect activation statistics (WARNING: this slows down training)')
parser.add_argument('--param-hist', dest='log_params_histograms', action='store_true', default=False,
help='log the paramter tensors histograms to file (WARNING: this can use significant disk space)')
SUMMARY_CHOICES = ['sparsity', 'compute', 'model', 'modules', 'png', 'png_w_params']
SUMMARY_CHOICES = ['sparsity', 'compute', 'model', 'modules', 'png', 'png_w_params', 'onnx']
parser.add_argument('--summary', type=str, choices=SUMMARY_CHOICES,
help='print a summary of the model, and exit - options: ' +
' | '.join(SUMMARY_CHOICES))
......@@ -632,6 +632,8 @@ def evaluate_model(model, criterion, test_loader, loggers, args):
def summarize_model(model, dataset, which_summary):
if which_summary.startswith('png'):
apputils.draw_img_classifier_to_file(model, 'model.png', dataset, which_summary == 'png_w_params')
elif which_summary == 'onnx':
apputils.export_img_classifier_to_onnx(model, 'model.onnx', dataset)
else:
distiller.model_summary(model, which_summary, dataset)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment