diff --git a/distiller/data_loggers/collector.py b/distiller/data_loggers/collector.py
index fc96aa0800072ae1c64d224b50d4bb77cf5629b6..c859924ff9345287fc0842a7f357acc5da054d98 100755
--- a/distiller/data_loggers/collector.py
+++ b/distiller/data_loggers/collector.py
@@ -26,6 +26,8 @@ import torch
 from torchnet.meter import AverageValueMeter
 import logging
 from math import sqrt
+import matplotlib
+matplotlib.use('Agg')
 import matplotlib.pyplot as plt
 import distiller
 msglogger = logging.getLogger()
diff --git a/distiller/model_summaries.py b/distiller/model_summaries.py
index 57935131288961e2d0d2f05db991eeda0d93e364..5a89685e22b7ff74fa3840fe1f7afab7ac5b07f2 100755
--- a/distiller/model_summaries.py
+++ b/distiller/model_summaries.py
@@ -27,7 +27,6 @@ import pandas as pd
 from tabulate import tabulate
 import logging
 import torch
-from torch.autograd import Variable
 import torch.optim
 import distiller
 from .summary_graph import SummaryGraph
@@ -40,21 +39,21 @@ __all__ = ['model_summary',
            'model_performance_summary', 'model_performance_tbl_summary', 'masks_sparsity_tbl_summary',
            'attributes_summary', 'attributes_summary_tbl', 'connectivity_summary',
            'connectivity_summary_verbose', 'connectivity_tbl_summary', 'create_png', 'create_pydot_graph',
-           'draw_model_to_file', 'draw_img_classifier_to_file']
+           'draw_model_to_file', 'draw_img_classifier_to_file', 'export_img_classifier_to_onnx']
 
 
 def model_summary(model, what, dataset=None):
-    if what == 'sparsity':
+    if what.startswith('png'):
+        draw_img_classifier_to_file(model, 'model.png', dataset, what == 'png_w_params')
+    elif what == 'sparsity':
         pylogger = PythonLogger(msglogger)
         csvlogger = CsvLogger()
         distiller.log_weights_sparsity(model, -1, loggers=[pylogger, csvlogger])
     elif what == 'compute':
-        if dataset == 'imagenet':
-            dummy_input = Variable(torch.randn(1, 3, 224, 224))
-        elif dataset == 'cifar10':
-            dummy_input = Variable(torch.randn(1, 3, 32, 32))
-        else:
-            print("Unsupported dataset (%s) - aborting compute operation" % dataset)
+        try:
+            dummy_input = dataset_dummy_input(dataset)
+        except ValueError as e:
+            print(e)
             return
         df = model_performance_summary(model, dummy_input, 1)
         t = tabulate(df, headers='keys', tablefmt='psql', floatfmt=".5f")
@@ -432,44 +431,54 @@ def draw_img_classifier_to_file(model, png_fname, dataset, display_param_nodes=F
                                    'fillcolor': 'gray',
                                    'style': 'rounded, filled'}
     """
+    dummy_input = dataset_dummy_input(dataset)
+    # Copy before the try: the finally clause below deletes this name unconditionally.
+    non_para_model = distiller.make_non_parallel_copy(model)
     try:
-        dummy_input = dataset_dummy_input(dataset)
-        model = distiller.make_non_parallel_copy(model)
-        g = SummaryGraph(model, dummy_input)
+        g = SummaryGraph(non_para_model, dummy_input)
         draw_model_to_file(g, png_fname, display_param_nodes, rankdir, styles)
         print("Network PNG image generation completed")
     except FileNotFoundError:
         print("An error has occured while generating the network PNG image.")
         print("Please check that you have graphviz installed.")
         print("\t$ sudo apt-get install graphviz")
+    finally:
+        del non_para_model
 
 
 def dataset_dummy_input(dataset):
     if dataset == 'imagenet':
-        dummy_input = Variable(torch.randn(1, 3, 224, 224), requires_grad=False)
+        dummy_input = torch.randn(1, 3, 224, 224)
     elif dataset == 'cifar10':
-        dummy_input = Variable(torch.randn(1, 3, 32, 32))
+        dummy_input = torch.randn(1, 3, 32, 32)
     else:
-        raise ValueError("Unsupported dataset (%s) - aborting draw operation" % dataset)
+        raise ValueError("Unsupported dataset (%s) - aborting operation" % dataset)
     return dummy_input
 
 
-def export_img_classifier_to_onnx(model, onnx_fname, dataset, export_params=True, add_softmax=True):
+def export_img_classifier_to_onnx(model, onnx_fname, dataset, add_softmax=True, **kwargs):
     """Export a PyTorch image classifier to ONNX.
 
+    Args:
+        add_softmax: when True, adds softmax layer to the output model.
+        kwargs: arguments to be passed to torch.onnx.export
     """
     dummy_input = dataset_dummy_input(dataset).to('cuda')
-    # Pytorch 0.4 doesn't support exporting modules wrapped in DataParallel
-    model = distiller.make_non_parallel_copy(model)
+    # Pytorch doesn't support exporting modules wrapped in DataParallel
+    non_para_model = distiller.make_non_parallel_copy(model)
 
-    with torch.onnx.set_training(model, False):
+    try:
         if add_softmax:
             # Explicitly add a softmax layer, because it is needed for the ONNX inference phase.
-            model.original_forward = model.forward
+            # TorchVision models use nn.CrossEntropyLoss for computing the loss,
+            # instead of adding a softmax layer
+            non_para_model.original_forward = non_para_model.forward
             softmax = torch.nn.Softmax(dim=-1)
-            model.forward = lambda input: softmax(model.original_forward(input))
-        torch.onnx.export(model, dummy_input, onnx_fname, verbose=False, export_params=export_params)
+            non_para_model.forward = lambda input: softmax(non_para_model.original_forward(input))
+        torch.onnx.export(non_para_model, dummy_input, onnx_fname, **kwargs)
         msglogger.info('Exported the model to ONNX format at %s' % os.path.realpath(onnx_fname))
+    finally:
+        del non_para_model
 
 
 def data_node_has_parent(g, id):
diff --git a/distiller/modules/__init__.py b/distiller/modules/__init__.py
index 282d7fde0c4bf0f387a19d731d932444508abf47..5bd7d5cfd5a2bcd8e688e9901f7cf7d4224c5469 100644
--- a/distiller/modules/__init__.py
+++ b/distiller/modules/__init__.py
@@ -16,8 +16,8 @@
 
 from .eltwise import EltwiseAdd, EltwiseMult
 from .grouping import *
-from .rnn import DistillerLSTM, DistillerLSTMCell
+from .rnn import DistillerLSTM, DistillerLSTMCell, convert_model_to_distiller_lstm
 
 __all__ = ['EltwiseAdd', 'EltwiseMult',
            'Concat', 'Chunk', 'Split', 'Stack',
-           'DistillerLSTMCell', 'DistillerLSTM']
+           'DistillerLSTMCell', 'DistillerLSTM', 'convert_model_to_distiller_lstm']
diff --git a/distiller/modules/rnn.py b/distiller/modules/rnn.py
index 633c088f4def2e635324d4fd25e476a1e7779d47..37886696d908834b0a892cb987aa2016561b71c2 100644
--- a/distiller/modules/rnn.py
+++ b/distiller/modules/rnn.py
@@ -20,7 +20,7 @@ import numpy as np
 from .eltwise import EltwiseAdd, EltwiseMult
 from itertools import product
 
-__all__ = ['DistillerLSTMCell', 'DistillerLSTM']
+__all__ = ['DistillerLSTMCell', 'DistillerLSTM', 'convert_model_to_distiller_lstm']
 
 
 class DistillerLSTMCell(nn.Module):
@@ -430,3 +430,20 @@ class DistillerLSTM(nn.Module):
                 self.num_layers,
                 self.dropout_factor,
                 self.bidirectional)
+
+
+def convert_model_to_distiller_lstm(model: nn.Module):
+    """Replaces all `nn.LSTM`s and `nn.LSTMCell`s in the model with distiller versions.
+    Args:
+        model (nn.Module): the model; matching child modules are swapped in place
+    Returns: the distiller replacement if `model` is itself an LSTM/LSTMCell, else `model`
+    """
+    if isinstance(model, nn.LSTMCell):
+        return DistillerLSTMCell.from_pytorch_impl(model)
+    if isinstance(model, nn.LSTM):
+        return DistillerLSTM.from_pytorch_impl(model)
+    for name, module in model.named_children():
+        module = convert_model_to_distiller_lstm(module)
+        setattr(model, name, module)
+
+    return model
diff --git a/distiller/summary_graph.py b/distiller/summary_graph.py
index d246fd58dc13ea4c6e2f10399e015a719f830890..a5f8ce41d3762395a02c3da786aeca003912b3f7 100755
--- a/distiller/summary_graph.py
+++ b/distiller/summary_graph.py
@@ -21,6 +21,7 @@ import collections
 import torch
 import torch.jit as jit
 import logging
+from collections import OrderedDict
 msglogger = logging.getLogger()
 
 
@@ -99,17 +100,17 @@ class SummaryGraph(object):
             
             device = next(model_clone.parameters()).device
             dummy_input = distiller.convert_tensors_recursively_to(dummy_input, device=device)
-            trace, _ = jit.get_trace_graph(model_clone, dummy_input)
+            trace, _ = jit.get_trace_graph(model_clone, dummy_input, _force_outplace=True)
 
             # Let ONNX do the heavy lifting: fusing the convolution nodes; fusing the nodes
             # composing a GEMM operation; etc.
             torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX)
 
             graph = trace.graph()
-            self.ops = {}
-            self.params = {}
+            self.ops = OrderedDict()
+            self.params = OrderedDict()
             self.edges = []
-            self.temp = {}
+            self.temp = OrderedDict()
 
             in_out = list(graph.inputs()) + list(graph.outputs())
             for param in in_out:
@@ -148,7 +149,7 @@ class SummaryGraph(object):
                     self.__add_output(new_op, output)
                     self.edges.append(SummaryGraph.Edge(new_op['name'], output.uniqueName()))
 
-                new_op['attrs'] = {attr_name: node[attr_name] for attr_name in node.attributeNames()}
+                new_op['attrs'] = OrderedDict([(attr_name, node[attr_name]) for attr_name in node.attributeNames()])
 
         self.add_macs_attr()
         self.add_footprint_attr()
@@ -156,7 +157,7 @@ class SummaryGraph(object):
         del model_clone
 
     def __create_op(self, onnx_node):
-        op = {}
+        op = OrderedDict()
         op['name'] = onnx_node.scopeName()
         op['orig-name'] = onnx_node.scopeName()
         op['type'] = onnx_node.kind().lstrip('::onnx')
@@ -188,7 +189,7 @@ class SummaryGraph(object):
         return param
 
     def __tensor_desc(self, n):
-        tensor = {}
+        tensor = OrderedDict()
         tensor['id'] = n.uniqueName()
         try:
             # try parsing the FM tensor type.  For example: Float(1, 64, 8, 8)
diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py
index 9429381bc9611d264b80eaaf11e4641391b2df4a..cc7965e717a06514530f45ae2859ab5d09ccd040 100755
--- a/examples/classifier_compression/compress_classifier.py
+++ b/examples/classifier_compression/compress_classifier.py
@@ -196,7 +196,14 @@ def main():
 
     # This sample application can be invoked to produce various summary reports.
     if args.summary:
-        return summarize_model(model, args.dataset, which_summary=args.summary)
+        for s in args.summary:
+            distiller.model_summary(model, s, args.dataset)
+        return
+
+    if args.export_onnx is not None:
+        return distiller.export_img_classifier_to_onnx(model,
+            os.path.join(msglogger.logdir, args.export_onnx),
+            args.dataset, add_softmax=True, verbose=False)
 
     if args.qe_calibration:
         return acts_quant_stats_collection(model, criterion, pylogger, args)
@@ -642,15 +649,6 @@ def evaluate_model(model, criterion, test_loader, loggers, activations_collector
                                  dir=msglogger.logdir, extras={'quantized_top1': top1})
 
 
-def summarize_model(model, dataset, which_summary):
-    if which_summary.startswith('png'):
-        model_summaries.draw_img_classifier_to_file(model, 'model.png', dataset, which_summary == 'png_w_params')
-    elif which_summary == 'onnx':
-        model_summaries.export_img_classifier_to_onnx(model, 'model.onnx', dataset)
-    else:
-        distiller.model_summary(model, which_summary, dataset)
-
-
 def sensitivity_analysis(model, criterion, data_loader, loggers, args, sparsities):
     # This sample application can be invoked to execute Sensitivity Analysis on your
     # model.  The ouptut is saved to CSV and PNG.
diff --git a/examples/classifier_compression/parser.py b/examples/classifier_compression/parser.py
index 803facf3cee7a2bbca15d76ed63a6f1613158cb8..5bc4870a9b7ed3ec816b1556cbe6e12762ebf6fc 100755
--- a/examples/classifier_compression/parser.py
+++ b/examples/classifier_compression/parser.py
@@ -23,7 +23,7 @@ from distiller.utils import float_range_argparse_checker as float_range
 import distiller.models as models
 
 
-SUMMARY_CHOICES = ['sparsity', 'compute', 'model', 'modules', 'png', 'png_w_params', 'onnx']
+SUMMARY_CHOICES = ['sparsity', 'compute', 'model', 'modules', 'png', 'png_w_params']
 
 
 def get_parser():
@@ -82,9 +82,10 @@ def get_parser():
                         help='print masks sparsity table at end of each epoch')
     parser.add_argument('--param-hist', dest='log_params_histograms', action='store_true', default=False,
                         help='log the parameter tensors histograms to file (WARNING: this can use significant disk space)')
-    parser.add_argument('--summary', type=lambda s: s.lower(), choices=SUMMARY_CHOICES,
-                        help='print a summary of the model, and exit - options: ' +
-                        ' | '.join(SUMMARY_CHOICES))
+    parser.add_argument('--summary', type=lambda s: s.lower(), choices=SUMMARY_CHOICES, action='append',
+                        help='print a summary of the model, and exit - options: ' + ' | '.join(SUMMARY_CHOICES))
+    parser.add_argument('--export-onnx', action='store', nargs='?', type=str, const='model.onnx', default=None,
+                        help='export model to ONNX format')
     parser.add_argument('--compress', dest='compress', type=str, nargs='?', action='store',
                         help='configuration file for pruning the model (default is to use hard-coded schedule)')
     parser.add_argument('--sense', dest='sensitivity', choices=['element', 'filter', 'channel'], type=lambda s: s.lower(),
diff --git a/tests/test_model_summary.py b/tests/test_model_summary.py
index e5271a8bf081b54027d51f9870d637d3a94d7f4e..b15badc5ffa7cc5010b9341552422344ae7c4aea 100755
--- a/tests/test_model_summary.py
+++ b/tests/test_model_summary.py
@@ -27,23 +27,15 @@ logger = logging.getLogger()
 logger.addHandler(fh)
 
 
-def test_png_generation():
-    dataset = "cifar10"
-    arch = "resnet20_cifar"
-    model, _ = common.setup_test(arch, dataset, parallel=True)
-    # 2 different ways to create a PNG
-    distiller.draw_img_classifier_to_file(model, 'model.png', dataset, True)
-    distiller.draw_img_classifier_to_file(model, 'model.png', dataset, False)
-    
+SUMMARY_CHOICES = ['sparsity', 'compute', 'model', 'modules', 'png', 'png_w_params']
 
-def test_negative():
+@pytest.mark.parametrize('display_param_nodes', [True, False])
+def test_png_generation(display_param_nodes):
     dataset = "cifar10"
     arch = "resnet20_cifar"
     model, _ = common.setup_test(arch, dataset, parallel=True)
-
-    with pytest.raises(ValueError):
-        # png is not a supported summary type, so we expect this to fail with a ValueError
-        distiller.model_summary(model, what='png', dataset=dataset)
+    # Draw the model to a PNG, with or without parameter nodes (one call per parametrized run)
+    distiller.draw_img_classifier_to_file(model, 'model.png', dataset, display_param_nodes)
 
 
 def test_compute_summary():
@@ -67,16 +59,10 @@ def test_compute_summary():
     assert module_macs == expected_macs
 
 
-def test_summary():
+@pytest.mark.parametrize('what', SUMMARY_CHOICES)
+def test_summary(what):
     dataset = "cifar10"
     arch = "resnet20_cifar"
     model, _ = common.setup_test(arch, dataset, parallel=True)
 
-    distiller.model_summary(model, what='sparsity', dataset=dataset)
-    distiller.model_summary(model, what='compute', dataset=dataset)
-    distiller.model_summary(model, what='model', dataset=dataset)
-    distiller.model_summary(model, what='modules', dataset=dataset)
-
-
-if __name__ == '__main__':
-    test_compute_summary()
\ No newline at end of file
+    distiller.model_summary(model, what, dataset=dataset)
diff --git a/tests/test_onnx.py b/tests/test_onnx.py
new file mode 100644
index 0000000000000000000000000000000000000000..d509312dcab00f5b13b4eb86bdb8c60d3e01452d
--- /dev/null
+++ b/tests/test_onnx.py
@@ -0,0 +1,41 @@
+#
+# Copyright (c) 2018 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import tempfile
+
+import distiller
+
+import pytest
+import common  # common test code
+
+
+# Logging configuration
+logging.basicConfig(level=logging.INFO)
+fh = logging.FileHandler('test.log')
+logger = logging.getLogger()
+logger.addHandler(fh)
+
+
+@pytest.mark.parametrize('arch',
+    ['resnet18', 'resnet20_cifar', 'alexnet', 'vgg19', 'resnext50_32x4d'])
+@pytest.mark.parametrize('add_softmax', [True, False])
+def test_summary(arch, add_softmax):
+    dataset = 'cifar10' if arch.endswith('cifar') else 'imagenet'
+    model, _ = common.setup_test(arch, dataset, parallel=True)
+
+    with tempfile.NamedTemporaryFile() as f:
+        distiller.export_img_classifier_to_onnx(model, f.name, dataset, add_softmax=add_softmax)
diff --git a/tests/test_summarygraph.py b/tests/test_summarygraph.py
index f74bfdaafab546091b58a44fb39e54cbc7c43494..46fedb3b6a99734b46d0bc801247e93502a1016d 100755
--- a/tests/test_summarygraph.py
+++ b/tests/test_summarygraph.py
@@ -116,24 +116,6 @@ def test_layer_search():
     assert preds == ['layer1.0.conv2', 'conv1']
 
 
-def test_weights_size_attr():
-    def test(dataset, arch, dataparallel:bool):
-        model = create_model(False, dataset, arch, parallel=False)
-        sgraph = SummaryGraph(model, get_input(dataset))
-
-        distiller.assign_layer_fq_names(model)
-        for name, mod in model.named_modules():
-            if isinstance(mod, torch.nn.Conv2d) or isinstance(mod, torch.nn.Linear):
-                op = sgraph.find_op(name)
-                assert op is not None
-                assert op['attrs']['weights_vol'] == distiller.volume(mod.weight)
-
-    for data_parallel in (True, False):
-        test('cifar10', 'resnet20_cifar', data_parallel)
-        test('imagenet', 'alexnet', data_parallel)
-        test('imagenet', 'resnext101_32x4d', data_parallel)
-
-
 def test_vgg():
     g = create_graph('imagenet', 'vgg19')
     assert g is not None
@@ -188,11 +170,11 @@ def named_params_layers_test_aux(dataset, arch, dataparallel:bool):
 
 
 def test_named_params_layers():
-    for data_parallel in (True, False):
-        named_params_layers_test_aux('imagenet', 'vgg19', data_parallel)
-        named_params_layers_test_aux('cifar10', 'resnet20_cifar', data_parallel)
-        named_params_layers_test_aux('imagenet', 'alexnet', data_parallel)
-        named_params_layers_test_aux('imagenet', 'resnext101_32x4d', data_parallel)
+    for dataParallelModel in (True, False):
+        named_params_layers_test_aux('imagenet', 'vgg19', dataParallelModel)
+        named_params_layers_test_aux('cifar10', 'resnet20_cifar', dataParallelModel)
+        named_params_layers_test_aux('imagenet', 'alexnet', dataParallelModel)
+        named_params_layers_test_aux('imagenet', 'resnext101_32x4d', dataParallelModel)
 
 
 def test_onnx_name_2_pytorch_name():
@@ -229,6 +211,24 @@ def test_sg_macs():
             assert summary_macs == sg_macs
  
 
+def test_weights_size_attr():
+    def test(dataset, arch, dataparallel:bool):
+        model = create_model(False, dataset, arch, parallel=dataparallel)
+        sgraph = SummaryGraph(model, get_input(dataset))
+
+        distiller.assign_layer_fq_names(model)
+        for name, mod in model.named_modules():
+            if isinstance(mod, torch.nn.Conv2d) or isinstance(mod, torch.nn.Linear):
+                op = sgraph.find_op(name)
+                assert op is not None
+                assert op['attrs']['weights_vol'] == distiller.volume(mod.weight)
+
+    for data_parallel in (True, False):
+        test('cifar10', 'resnet20_cifar', data_parallel)
+        test('imagenet', 'alexnet', data_parallel)
+        test('imagenet', 'resnext101_32x4d', data_parallel)
+
+
 if __name__ == '__main__':
     #test_connectivity_summary()
-    test_sg_macs()
+    test_sg_macs()
\ No newline at end of file