From 543048101f3317e1becc6db3ea642ad3f679a809 Mon Sep 17 00:00:00 2001 From: Bar <29775567+barrh@users.noreply.github.com> Date: Thu, 16 May 2019 10:24:34 +0300 Subject: [PATCH] Refactor export to ONNX functionality (#258) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduced a new utility function to export image-classifiers to ONNX: export_img_classifier_to_onnx. The functionality is not new, just refactored. In the sample application compress_classifier.py added --export-onnx as a stand-alone cmd-line flag for specifically exporting ONNX models. This new flag can take an optional argument which is used to name the exported onnx model file. The option to export models was removed from the –summary argument. Now we allow multiple --summary options be called together. Added a basic test for exporting ONNX. --- distiller/model_summaries.py | 53 +++++++++++-------- .../compress_classifier.py | 18 +++---- examples/classifier_compression/parser.py | 9 ++-- tests/test_model_summary.py | 30 +++-------- tests/test_onnx.py | 41 ++++++++++++++ 5 files changed, 93 insertions(+), 58 deletions(-) create mode 100644 tests/test_onnx.py diff --git a/distiller/model_summaries.py b/distiller/model_summaries.py index 5793513..5a89685 100755 --- a/distiller/model_summaries.py +++ b/distiller/model_summaries.py @@ -27,7 +27,6 @@ import pandas as pd from tabulate import tabulate import logging import torch -from torch.autograd import Variable import torch.optim import distiller from .summary_graph import SummaryGraph @@ -40,21 +39,21 @@ __all__ = ['model_summary', 'model_performance_summary', 'model_performance_tbl_summary', 'masks_sparsity_tbl_summary', 'attributes_summary', 'attributes_summary_tbl', 'connectivity_summary', 'connectivity_summary_verbose', 'connectivity_tbl_summary', 'create_png', 'create_pydot_graph', - 'draw_model_to_file', 'draw_img_classifier_to_file'] + 'draw_model_to_file', 'draw_img_classifier_to_file', 'export_img_classifier_to_onnx'] def model_summary(model, what, dataset=None): - if what == 'sparsity': + if what.startswith('png'): + draw_img_classifier_to_file(model, 'model.png', dataset, what == 'png_w_params') + elif what == 'sparsity': pylogger = PythonLogger(msglogger) csvlogger = CsvLogger() distiller.log_weights_sparsity(model, -1, loggers=[pylogger, csvlogger]) elif what == 'compute': - if dataset == 'imagenet': - dummy_input = Variable(torch.randn(1, 3, 224, 224)) - elif dataset == 'cifar10': - dummy_input = Variable(torch.randn(1, 3, 32, 32)) - else: - print("Unsupported dataset (%s) - aborting compute operation" % dataset) + try: + dummy_input = dataset_dummy_input(dataset) + except ValueError as e: + print(e) return df = model_performance_summary(model, dummy_input, 1) t = tabulate(df, headers='keys', tablefmt='psql', floatfmt=".5f") @@ -432,44 +431,54 @@ def draw_img_classifier_to_file(model, png_fname, dataset, display_param_nodes=F 'fillcolor': 'gray', 'style': 'rounded, filled'} """ + dummy_input = dataset_dummy_input(dataset) try: - dummy_input = dataset_dummy_input(dataset) - model = distiller.make_non_parallel_copy(model) - g = SummaryGraph(model, dummy_input) + non_para_model = distiller.make_non_parallel_copy(model) + g = SummaryGraph(non_para_model, dummy_input) + draw_model_to_file(g, png_fname, display_param_nodes, rankdir, styles) print("Network PNG image generation completed") except FileNotFoundError: print("An error has occured while generating the network PNG image.") print("Please check that you have graphviz installed.") print("\t$ sudo apt-get install graphviz") + finally: + del non_para_model def dataset_dummy_input(dataset): if dataset == 'imagenet': - dummy_input = Variable(torch.randn(1, 3, 224, 224), requires_grad=False) + dummy_input = torch.randn(1, 3, 224, 224) elif dataset == 'cifar10': - dummy_input = Variable(torch.randn(1, 3, 32, 32)) + dummy_input = torch.randn(1, 3, 32, 32) else: - raise ValueError("Unsupported dataset (%s) - aborting draw operation" % dataset) + raise ValueError("Unsupported dataset (%s) - aborting operation" % dataset) return dummy_input -def export_img_classifier_to_onnx(model, onnx_fname, dataset, export_params=True, add_softmax=True): +def export_img_classifier_to_onnx(model, onnx_fname, dataset, add_softmax=True, **kwargs): """Export a PyTorch image classifier to ONNX. + Args: + add_softmax: when True, adds softmax layer to the output model. + kwargs: arguments to be passed to torch.onnx.export """ dummy_input = dataset_dummy_input(dataset).to('cuda') - # Pytorch 0.4 doesn't support exporting modules wrapped in DataParallel - model = distiller.make_non_parallel_copy(model) + # Pytorch doesn't support exporting modules wrapped in DataParallel + non_para_model = distiller.make_non_parallel_copy(model) - with torch.onnx.set_training(model, False): + try: if add_softmax: # Explicitly add a softmax layer, because it is needed for the ONNX inference phase. - model.original_forward = model.forward + # TorchVision models use nn.CrossEntropyLoss for computing the loss, + # instead of adding a softmax layer + non_para_model.original_forward = non_para_model.forward softmax = torch.nn.Softmax(dim=-1) - model.forward = lambda input: softmax(model.original_forward(input)) - torch.onnx.export(model, dummy_input, onnx_fname, verbose=False, export_params=export_params) + non_para_model.forward = lambda input: softmax(non_para_model.original_forward(input)) + torch.onnx.export(non_para_model, dummy_input, onnx_fname, **kwargs) msglogger.info('Exported the model to ONNX format at %s' % os.path.realpath(onnx_fname)) + finally: + del non_para_model def data_node_has_parent(g, id): diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py index 9429381..cc7965e 100755 --- a/examples/classifier_compression/compress_classifier.py +++ b/examples/classifier_compression/compress_classifier.py @@ -196,7 +196,14 @@ def main(): # This sample application can be invoked to produce various summary reports. if args.summary: - return summarize_model(model, args.dataset, which_summary=args.summary) + for s in which_summary: + distiller.model_summary(model, s, dataset) + return + + if args.export_onnx is not None: + return distiller.export_img_classifier_to_onnx(model, + os.path.join(msglogger.logdir, args.export_onnx), + args.dataset, add_softmax=True, verbose=False) if args.qe_calibration: return acts_quant_stats_collection(model, criterion, pylogger, args) @@ -642,15 +649,6 @@ def evaluate_model(model, criterion, test_loader, loggers, activations_collector dir=msglogger.logdir, extras={'quantized_top1': top1}) -def summarize_model(model, dataset, which_summary): - if which_summary.startswith('png'): - model_summaries.draw_img_classifier_to_file(model, 'model.png', dataset, which_summary == 'png_w_params') - elif which_summary == 'onnx': - model_summaries.export_img_classifier_to_onnx(model, 'model.onnx', dataset) - else: - distiller.model_summary(model, which_summary, dataset) - - def sensitivity_analysis(model, criterion, data_loader, loggers, args, sparsities): # This sample application can be invoked to execute Sensitivity Analysis on your # model. The ouptut is saved to CSV and PNG. diff --git a/examples/classifier_compression/parser.py b/examples/classifier_compression/parser.py index 803facf..5bc4870 100755 --- a/examples/classifier_compression/parser.py +++ b/examples/classifier_compression/parser.py @@ -23,7 +23,7 @@ from distiller.utils import float_range_argparse_checker as float_range import distiller.models as models -SUMMARY_CHOICES = ['sparsity', 'compute', 'model', 'modules', 'png', 'png_w_params', 'onnx'] +SUMMARY_CHOICES = ['sparsity', 'compute', 'model', 'modules', 'png', 'png_w_params'] def get_parser(): @@ -82,9 +82,10 @@ def get_parser(): help='print masks sparsity table at end of each epoch') parser.add_argument('--param-hist', dest='log_params_histograms', action='store_true', default=False, help='log the parameter tensors histograms to file (WARNING: this can use significant disk space)') - parser.add_argument('--summary', type=lambda s: s.lower(), choices=SUMMARY_CHOICES, - help='print a summary of the model, and exit - options: ' + - ' | '.join(SUMMARY_CHOICES)) + parser.add_argument('--summary', type=lambda s: s.lower(), choices=SUMMARY_CHOICES, action='append', + help='print a summary of the model, and exit - options: | '.join(SUMMARY_CHOICES)) + parser.add_argument('--export-onnx', action='store', nargs='?', type=str, const='model.onnx', default=None, + help='export model to ONNX format') parser.add_argument('--compress', dest='compress', type=str, nargs='?', action='store', help='configuration file for pruning the model (default is to use hard-coded schedule)') parser.add_argument('--sense', dest='sensitivity', choices=['element', 'filter', 'channel'], type=lambda s: s.lower(), diff --git a/tests/test_model_summary.py b/tests/test_model_summary.py index e5271a8..b15badc 100755 --- a/tests/test_model_summary.py +++ b/tests/test_model_summary.py @@ -27,23 +27,15 @@ logger = logging.getLogger() logger.addHandler(fh) -def test_png_generation(): - dataset = "cifar10" - arch = "resnet20_cifar" - model, _ = common.setup_test(arch, dataset, parallel=True) - # 2 different ways to create a PNG - distiller.draw_img_classifier_to_file(model, 'model.png', dataset, True) - distiller.draw_img_classifier_to_file(model, 'model.png', dataset, False) - +SUMMARY_CHOICES = ['sparsity', 'compute', 'model', 'modules', 'png', 'png_w_params'] -def test_negative(): +@pytest.mark.parametrize('display_param_nodes', [True, False]) +def test_png_generation(display_param_nodes): dataset = "cifar10" arch = "resnet20_cifar" model, _ = common.setup_test(arch, dataset, parallel=True) - - with pytest.raises(ValueError): - # png is not a supported summary type, so we expect this to fail with a ValueError - distiller.model_summary(model, what='png', dataset=dataset) + # 2 different ways to create a PNG + distiller.draw_img_classifier_to_file(model, 'model.png', dataset, display_param_nodes) def test_compute_summary(): @@ -67,16 +59,10 @@ def test_compute_summary(): assert module_macs == expected_macs -def test_summary(): +@pytest.mark.parametrize('what', SUMMARY_CHOICES) +def test_summary(what): dataset = "cifar10" arch = "resnet20_cifar" model, _ = common.setup_test(arch, dataset, parallel=True) - distiller.model_summary(model, what='sparsity', dataset=dataset) - distiller.model_summary(model, what='compute', dataset=dataset) - distiller.model_summary(model, what='model', dataset=dataset) - distiller.model_summary(model, what='modules', dataset=dataset) - - -if __name__ == '__main__': - test_compute_summary() \ No newline at end of file + distiller.model_summary(model, what, dataset=dataset) diff --git a/tests/test_onnx.py b/tests/test_onnx.py new file mode 100644 index 0000000..d509312 --- /dev/null +++ b/tests/test_onnx.py @@ -0,0 +1,41 @@ +# +# Copyright (c) 2018 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import logging +import tempfile + +import distiller + +import pytest +import common # common test code + + +# Logging configuration +logging.basicConfig(level=logging.INFO) +fh = logging.FileHandler('test.log') +logger = logging.getLogger() +logger.addHandler(fh) + + +@pytest.mark.parametrize('arch', + ['resnet18', 'resnet20_cifar', 'alexnet', 'vgg19', 'resnext50_32x4d']) +@pytest.mark.parametrize('add_softmax', [True, False]) +def test_summary(arch, add_softmax): + dataset = 'cifar10' if arch.endswith('cifar') else 'imagenet' + model, _ = common.setup_test(arch, dataset, parallel=True) + + with tempfile.NamedTemporaryFile() as f: + distiller.export_img_classifier_to_onnx(model, f.name, dataset, add_softmax=add_softmax) -- GitLab