diff --git a/distiller/model_summaries.py b/distiller/model_summaries.py index 57935131288961e2d0d2f05db991eeda0d93e364..5a89685e22b7ff74fa3840fe1f7afab7ac5b07f2 100755 --- a/distiller/model_summaries.py +++ b/distiller/model_summaries.py @@ -27,7 +27,6 @@ import pandas as pd from tabulate import tabulate import logging import torch -from torch.autograd import Variable import torch.optim import distiller from .summary_graph import SummaryGraph @@ -40,21 +39,21 @@ __all__ = ['model_summary', 'model_performance_summary', 'model_performance_tbl_summary', 'masks_sparsity_tbl_summary', 'attributes_summary', 'attributes_summary_tbl', 'connectivity_summary', 'connectivity_summary_verbose', 'connectivity_tbl_summary', 'create_png', 'create_pydot_graph', - 'draw_model_to_file', 'draw_img_classifier_to_file'] + 'draw_model_to_file', 'draw_img_classifier_to_file', 'export_img_classifier_to_onnx'] def model_summary(model, what, dataset=None): - if what == 'sparsity': + if what.startswith('png'): + draw_img_classifier_to_file(model, 'model.png', dataset, what == 'png_w_params') + elif what == 'sparsity': pylogger = PythonLogger(msglogger) csvlogger = CsvLogger() distiller.log_weights_sparsity(model, -1, loggers=[pylogger, csvlogger]) elif what == 'compute': - if dataset == 'imagenet': - dummy_input = Variable(torch.randn(1, 3, 224, 224)) - elif dataset == 'cifar10': - dummy_input = Variable(torch.randn(1, 3, 32, 32)) - else: - print("Unsupported dataset (%s) - aborting compute operation" % dataset) + try: + dummy_input = dataset_dummy_input(dataset) + except ValueError as e: + print(e) return df = model_performance_summary(model, dummy_input, 1) t = tabulate(df, headers='keys', tablefmt='psql', floatfmt=".5f") @@ -432,44 +431,54 @@ def draw_img_classifier_to_file(model, png_fname, dataset, display_param_nodes=F 'fillcolor': 'gray', 'style': 'rounded, filled'} """ + dummy_input = dataset_dummy_input(dataset) try: - dummy_input = dataset_dummy_input(dataset) - model = distiller.make_non_parallel_copy(model) - g = SummaryGraph(model, dummy_input) + non_para_model = distiller.make_non_parallel_copy(model) + g = SummaryGraph(non_para_model, dummy_input) + draw_model_to_file(g, png_fname, display_param_nodes, rankdir, styles) print("Network PNG image generation completed") except FileNotFoundError: print("An error has occured while generating the network PNG image.") print("Please check that you have graphviz installed.") print("\t$ sudo apt-get install graphviz") + finally: + del non_para_model def dataset_dummy_input(dataset): if dataset == 'imagenet': - dummy_input = Variable(torch.randn(1, 3, 224, 224), requires_grad=False) + dummy_input = torch.randn(1, 3, 224, 224) elif dataset == 'cifar10': - dummy_input = Variable(torch.randn(1, 3, 32, 32)) + dummy_input = torch.randn(1, 3, 32, 32) else: - raise ValueError("Unsupported dataset (%s) - aborting draw operation" % dataset) + raise ValueError("Unsupported dataset (%s) - aborting operation" % dataset) return dummy_input -def export_img_classifier_to_onnx(model, onnx_fname, dataset, export_params=True, add_softmax=True): +def export_img_classifier_to_onnx(model, onnx_fname, dataset, add_softmax=True, **kwargs): """Export a PyTorch image classifier to ONNX. + Args: + add_softmax: when True, adds softmax layer to the output model. + kwargs: arguments to be passed to torch.onnx.export """ dummy_input = dataset_dummy_input(dataset).to('cuda') - # Pytorch 0.4 doesn't support exporting modules wrapped in DataParallel - model = distiller.make_non_parallel_copy(model) + # Pytorch doesn't support exporting modules wrapped in DataParallel + non_para_model = distiller.make_non_parallel_copy(model) - with torch.onnx.set_training(model, False): + try: if add_softmax: # Explicitly add a softmax layer, because it is needed for the ONNX inference phase. - model.original_forward = model.forward + # TorchVision models use nn.CrossEntropyLoss for computing the loss, + # instead of adding a softmax layer + non_para_model.original_forward = non_para_model.forward softmax = torch.nn.Softmax(dim=-1) - model.forward = lambda input: softmax(model.original_forward(input)) - torch.onnx.export(model, dummy_input, onnx_fname, verbose=False, export_params=export_params) + non_para_model.forward = lambda input: softmax(non_para_model.original_forward(input)) + torch.onnx.export(non_para_model, dummy_input, onnx_fname, **kwargs) msglogger.info('Exported the model to ONNX format at %s' % os.path.realpath(onnx_fname)) + finally: + del non_para_model def data_node_has_parent(g, id): diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py index 9429381bc9611d264b80eaaf11e4641391b2df4a..cc7965e717a06514530f45ae2859ab5d09ccd040 100755 --- a/examples/classifier_compression/compress_classifier.py +++ b/examples/classifier_compression/compress_classifier.py @@ -196,7 +196,14 @@ def main(): # This sample application can be invoked to produce various summary reports. if args.summary: - return summarize_model(model, args.dataset, which_summary=args.summary) + for s in which_summary: + distiller.model_summary(model, s, dataset) + return + + if args.export_onnx is not None: + return distiller.export_img_classifier_to_onnx(model, + os.path.join(msglogger.logdir, args.export_onnx), + args.dataset, add_softmax=True, verbose=False) if args.qe_calibration: return acts_quant_stats_collection(model, criterion, pylogger, args) @@ -642,15 +649,6 @@ def evaluate_model(model, criterion, test_loader, loggers, activations_collector dir=msglogger.logdir, extras={'quantized_top1': top1}) -def summarize_model(model, dataset, which_summary): - if which_summary.startswith('png'): - model_summaries.draw_img_classifier_to_file(model, 'model.png', dataset, which_summary == 'png_w_params') - elif which_summary == 'onnx': - model_summaries.export_img_classifier_to_onnx(model, 'model.onnx', dataset) - else: - distiller.model_summary(model, which_summary, dataset) - - def sensitivity_analysis(model, criterion, data_loader, loggers, args, sparsities): # This sample application can be invoked to execute Sensitivity Analysis on your # model. The ouptut is saved to CSV and PNG. diff --git a/examples/classifier_compression/parser.py b/examples/classifier_compression/parser.py index 803facf3cee7a2bbca15d76ed63a6f1613158cb8..5bc4870a9b7ed3ec816b1556cbe6e12762ebf6fc 100755 --- a/examples/classifier_compression/parser.py +++ b/examples/classifier_compression/parser.py @@ -23,7 +23,7 @@ from distiller.utils import float_range_argparse_checker as float_range import distiller.models as models -SUMMARY_CHOICES = ['sparsity', 'compute', 'model', 'modules', 'png', 'png_w_params', 'onnx'] +SUMMARY_CHOICES = ['sparsity', 'compute', 'model', 'modules', 'png', 'png_w_params'] def get_parser(): @@ -82,9 +82,10 @@ def get_parser(): help='print masks sparsity table at end of each epoch') parser.add_argument('--param-hist', dest='log_params_histograms', action='store_true', default=False, help='log the parameter tensors histograms to file (WARNING: this can use significant disk space)') - parser.add_argument('--summary', type=lambda s: s.lower(), choices=SUMMARY_CHOICES, - help='print a summary of the model, and exit - options: ' + - ' | '.join(SUMMARY_CHOICES)) + parser.add_argument('--summary', type=lambda s: s.lower(), choices=SUMMARY_CHOICES, action='append', + help='print a summary of the model, and exit - options: | '.join(SUMMARY_CHOICES)) + parser.add_argument('--export-onnx', action='store', nargs='?', type=str, const='model.onnx', default=None, + help='export model to ONNX format') parser.add_argument('--compress', dest='compress', type=str, nargs='?', action='store', help='configuration file for pruning the model (default is to use hard-coded schedule)') parser.add_argument('--sense', dest='sensitivity', choices=['element', 'filter', 'channel'], type=lambda s: s.lower(), diff --git a/tests/test_model_summary.py b/tests/test_model_summary.py index e5271a8bf081b54027d51f9870d637d3a94d7f4e..b15badc5ffa7cc5010b9341552422344ae7c4aea 100755 --- a/tests/test_model_summary.py +++ b/tests/test_model_summary.py @@ -27,23 +27,15 @@ logger = logging.getLogger() logger.addHandler(fh) -def test_png_generation(): - dataset = "cifar10" - arch = "resnet20_cifar" - model, _ = common.setup_test(arch, dataset, parallel=True) - # 2 different ways to create a PNG - distiller.draw_img_classifier_to_file(model, 'model.png', dataset, True) - distiller.draw_img_classifier_to_file(model, 'model.png', dataset, False) - +SUMMARY_CHOICES = ['sparsity', 'compute', 'model', 'modules', 'png', 'png_w_params'] -def test_negative(): +@pytest.mark.parametrize('display_param_nodes', [True, False]) +def test_png_generation(display_param_nodes): dataset = "cifar10" arch = "resnet20_cifar" model, _ = common.setup_test(arch, dataset, parallel=True) - - with pytest.raises(ValueError): - # png is not a supported summary type, so we expect this to fail with a ValueError - distiller.model_summary(model, what='png', dataset=dataset) + # 2 different ways to create a PNG + distiller.draw_img_classifier_to_file(model, 'model.png', dataset, display_param_nodes) def test_compute_summary(): @@ -67,16 +59,10 @@ def test_compute_summary(): assert module_macs == expected_macs -def test_summary(): +@pytest.mark.parametrize('what', SUMMARY_CHOICES) +def test_summary(what): dataset = "cifar10" arch = "resnet20_cifar" model, _ = common.setup_test(arch, dataset, parallel=True) - distiller.model_summary(model, what='sparsity', dataset=dataset) - distiller.model_summary(model, what='compute', dataset=dataset) - distiller.model_summary(model, what='model', dataset=dataset) - distiller.model_summary(model, what='modules', dataset=dataset) - - -if __name__ == '__main__': - test_compute_summary() \ No newline at end of file + distiller.model_summary(model, what, dataset=dataset) diff --git a/tests/test_onnx.py b/tests/test_onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..d509312dcab00f5b13b4eb86bdb8c60d3e01452d --- /dev/null +++ b/tests/test_onnx.py @@ -0,0 +1,41 @@ +# +# Copyright (c) 2018 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import logging +import tempfile + +import distiller + +import pytest +import common # common test code + + +# Logging configuration +logging.basicConfig(level=logging.INFO) +fh = logging.FileHandler('test.log') +logger = logging.getLogger() +logger.addHandler(fh) + + +@pytest.mark.parametrize('arch', + ['resnet18', 'resnet20_cifar', 'alexnet', 'vgg19', 'resnext50_32x4d']) +@pytest.mark.parametrize('add_softmax', [True, False]) +def test_summary(arch, add_softmax): + dataset = 'cifar10' if arch.endswith('cifar') else 'imagenet' + model, _ = common.setup_test(arch, dataset, parallel=True) + + with tempfile.NamedTemporaryFile() as f: + distiller.export_img_classifier_to_onnx(model, f.name, dataset, add_softmax=add_softmax)