diff --git a/apputils/model_summaries.py b/apputils/model_summaries.py index deff50344e58225759a48d8b16a7f30262d7cb5f..e634d681c3e4d318aa7862960241ab488f68de51 100755 --- a/apputils/model_summaries.py +++ b/apputils/model_summaries.py @@ -20,6 +20,7 @@ This code is proven to work on CNN image classification models using PyTorch 04. RNNs are currently not working well. """ +import os import re import numpy as np import collections @@ -583,14 +584,7 @@ def draw_img_classifier_to_file(model, png_fname, dataset, display_param_nodes=F 'style': 'rounded, filled'} """ try: - if dataset == 'imagenet': - dummy_input = Variable(torch.randn(1, 3, 224, 224), requires_grad=False) - elif dataset == 'cifar10': - dummy_input = Variable(torch.randn(1, 3, 32, 32)) - else: - print("Unsupported dataset (%s) - aborting draw operation" % dataset) - return - + dummy_input = dataset_dummy_input(dataset) model = distiller.make_non_parallel_copy(model) g = SummaryGraph(model, dummy_input) draw_model_to_file(g, png_fname, display_param_nodes, rankdir, styles) @@ -601,6 +595,32 @@ def draw_img_classifier_to_file(model, png_fname, dataset, display_param_nodes=F print("\t$ sudo apt-get install graphviz") +def dataset_dummy_input(dataset): + if dataset == 'imagenet': + dummy_input = Variable(torch.randn(1, 3, 224, 224), requires_grad=False) + elif dataset == 'cifar10': + dummy_input = Variable(torch.randn(1, 3, 32, 32)) + else: + raise ValueError("Unsupported dataset (%s) - aborting draw operation" % dataset) + return dummy_input + + +def export_img_classifier_to_onnx(model, onnx_fname, dataset): + """Export a PyTorch image classifier to ONNX. + + """ + dummy_input = dataset_dummy_input(dataset) + + #model.eval() + with torch.onnx.set_training(model, False): + # Pytorch 0.4 doesn't support exporting modules wrapped in DataParallel + if isinstance(model, torch.nn.DataParallel): + model = model.module + torch.onnx.export(model, dummy_input.to('cuda'), onnx_fname, verbose=False) + msglogger.info('Exported the model to ONNX format at %s' % os.path.realpath(onnx_fname)) + + + def data_node_has_parent(g, id): for edge in g.edges: if edge.dst == id: return True diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py index 943a97048af76ac63989c3f94bb7a73e454f9677..e0687f1babe67352cf6a6b07f7307f66b9b604ab 100755 --- a/examples/classifier_compression/compress_classifier.py +++ b/examples/classifier_compression/compress_classifier.py @@ -122,7 +122,7 @@ parser.add_argument('--act-stats', dest='activation_stats', action='store_true', help='collect activation statistics (WARNING: this slows down training)') parser.add_argument('--param-hist', dest='log_params_histograms', action='store_true', default=False, help='log the paramter tensors histograms to file (WARNING: this can use significant disk space)') -SUMMARY_CHOICES = ['sparsity', 'compute', 'model', 'modules', 'png', 'png_w_params'] +SUMMARY_CHOICES = ['sparsity', 'compute', 'model', 'modules', 'png', 'png_w_params', 'onnx'] parser.add_argument('--summary', type=str, choices=SUMMARY_CHOICES, help='print a summary of the model, and exit - options: ' + ' | '.join(SUMMARY_CHOICES)) @@ -632,6 +632,8 @@ def evaluate_model(model, criterion, test_loader, loggers, args): def summarize_model(model, dataset, which_summary): if which_summary.startswith('png'): apputils.draw_img_classifier_to_file(model, 'model.png', dataset, which_summary == 'png_w_params') + elif which_summary == 'onnx': + apputils.export_img_classifier_to_onnx(model, 'model.onnx', dataset) else: distiller.model_summary(model, which_summary, dataset)