From 032b1f743a50f69527e272806aa4db5f457e5012 Mon Sep 17 00:00:00 2001
From: Guy Jacob <guy.jacob@intel.com>
Date: Thu, 4 Jul 2019 17:45:51 +0300
Subject: [PATCH] Switch to PyTorch 1.1.0 (#306)

* PyTorch 1.1.0 now required
  - Update other dependencies to current versions as well
* Adapt LR scheduler to PyTorch 1.1 API changes:
  - Change lr_scheduler.step() calls to come after the validate
    calls during training (see the sketch below)
  - Pass both loss and top1 from the training loop down to
    lr_scheduler.step() (Resolves issue #240)
* Adapt thinning for PyTorch 1.1 semantic changes
  - **KNOWN ISSUE**: When a thinning recipe is applied, in certain
    cases PyTorch displays this warning:
    "UserWarning: non-inplace resize is deprecated".
    To be fixed later
* SummaryGraph: Workaround for new scope name issue from PyTorch 1.1.0
* Adapt to updated PyTest version:
  - Stop using deprecated 'message' parameter of pytest.raises(),
    use pytest.fail() instead
  - Ensure there is only a single test case per pytest.raises context
    (example below)
* Move PyTorch version check to root __init__.py
  - This means the version is checked when Distiller is first
    imported. A RuntimeError is raised if the version is wrong.
* Updates to parameter_histograms notebook:
  - Replace deprecated normed argument with density
  - Add sparsity rate to plot title
  - Load the model on the CPU
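
Below is a minimal sketch (not Distiller code) of the epoch-loop ordering
the LR scheduler changes above adopt for PyTorch 1.1; the train_one_epoch()
helper and the placeholder validation loss are illustrative assumptions:

```python
import torch

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min')

for epoch in range(10):
    # train_one_epoch(model, optimizer)  # optimizer.step() is called inside
    val_loss = 1.0 / (epoch + 1)         # stand-in for a real validate() pass
    # Since PyTorch 1.1, scheduler.step() must run after optimizer.step();
    # ReduceLROnPlateau additionally takes the validation metric.
    scheduler.step(val_loss)
```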
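
The updated pytest usage follows the same pattern as the
pytest_raises_wrapper() helper added to tests/common.py; a rough
stand-alone example, where might_raise() is a hypothetical function:

```python
import pytest


def might_raise(x):
    if x > 1:
        raise ValueError("x too large")


def test_raises_single_case():
    # A single call per pytest.raises context; if might_raise() does not
    # raise, pytest.fail() reports the message that the removed 'message'
    # parameter used to carry.
    with pytest.raises(ValueError):
        might_raise(2)
        pytest.fail("Expecting ValueError when x > 1")
```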
---
 README.md                                     |  7 +-
 distiller/__init__.py                         | 16 +++
 distiller/knowledge_distillation.py           |  2 +-
 distiller/models/__init__.py                  | 32 ++++--
 distiller/policy.py                           | 13 +--
 distiller/scheduler.py                        | 14 +--
 distiller/summary_graph.py                    | 57 ++++++++---
 distiller/thinning.py                         | 99 ++++++++++---------
 .../compress_classifier.py                    | 22 +----
 jupyter/parameter_histograms.ipynb            | 85 +++++++---------
 requirements.txt                              | 30 +++---
 tests/common.py                               |  8 ++
 tests/full_flow_tests.py                      |  6 +-
 tests/test_infra.py                           |  5 +-
 tests/test_quant_utils.py                     | 27 +++--
 tests/test_quantizer.py                       | 47 ++++-----
 tests/test_summarygraph.py                    | 55 ++++++++---
 17 files changed, 300 insertions(+), 225 deletions(-)

diff --git a/README.md b/README.md
index b798bd3..fd4a556 100755
--- a/README.md
+++ b/README.md
@@ -37,7 +37,7 @@ Network compression can reduce the memory footprint of a neural network, increas
 
 #### Note on Release 0.3 - Possible BREAKING Changes
 
-As of release 0.3, we've moved some code around to enable proper packaging and installation of Distiller. In addition, we updated Distiller to support PyTorch 1.0.1, which might also cause older code to break due to some API changes.  
+As of release 0.3, we've moved some code around to enable proper packaging and installation of Distiller. In addition, we updated Distiller to support PyTorch 1.X, which might also cause older code to break due to some API changes.  
 If updating from an earlier revision of the code, please make sure to follow the instructions in the [install](#install-the-package) section to make sure proper installation of Distiller and all dependencies.
 <details><summary><b>What's New in November?</b></summary>
 <p>
@@ -222,6 +222,8 @@ $ source env/bin/activate
 ```
 
 ### Install the package
+If you do not use CUDA 9 in your environment, please refer to the [PyTorch website](https://pytorch.org/get-started/locally/) to install compatible builds of PyTorch 1.1 and torchvision 0.3 before installing the package.
+
 Finally, install the Distiller package and its dependencies using ```pip3```:
 ```
 $ cd distiller
@@ -229,8 +231,7 @@ $ pip3 install -e .
 ```
 This installs Distiller in "development mode", meaning any changes made in the code are reflected in the environment without re-running the install command (so no need to re-install after pulling changes from the Git repository).
 
-PyTorch is included in the ```requirements.txt``` file, and will currently download PyTorch version 1.0.1 for CUDA 9.0.  This is the setup we've used for testing Distiller.
-If you do not use CUDA 9 in your environment, please refer to [Pytorch website](https://pytorch.org/get-started/locally/) to install the compatible build of Pytorch 1.0.1. Use `pip3 install --force` to reinstall.
+PyTorch is included in the ```requirements.txt``` file, and will currently download PyTorch version 1.1.0 for CUDA 9.0.  This is the setup we've used for testing Distiller.
 
 ## Getting Started
 
diff --git a/distiller/__init__.py b/distiller/__init__.py
index fff740e..62556f4 100755
--- a/distiller/__init__.py
+++ b/distiller/__init__.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 #
 
+import torch
 from .utils import *
 from .thresholding import GroupThresholdMixin, threshold_mask, group_threshold_mask
 from .config import file_config, dict_config, config_component_from_file_by_class
@@ -98,3 +99,18 @@ def model_find_module(model, module_to_find):
         if name == module_to_find:
             return m
     return None
+
+
+def check_pytorch_version():
+    from pkg_resources import parse_version
+    if parse_version(torch.__version__) < parse_version('1.1.0'):
+        msg = "\n\nWRONG PYTORCH VERSION\n"\
+              "The Distiller \'master\' branch now requires at least PyTorch version 1.1.0 due to "\
+              "PyTorch API changes which are not backward-compatible. Version detected is {}.\n"\
+              "To make sure PyTorch and all other dependencies are installed with their correct versions, " \
+              "go to the Distiller repo root directory and run:\n\n"\
+              "pip install -e .\n".format(torch.__version__)
+        raise RuntimeError(msg)
+
+
+check_pytorch_version()
diff --git a/distiller/knowledge_distillation.py b/distiller/knowledge_distillation.py
index 6f544d4..acba5ce 100644
--- a/distiller/knowledge_distillation.py
+++ b/distiller/knowledge_distillation.py
@@ -129,7 +129,7 @@ class KnowledgeDistillationPolicy(ScheduledTrainingPolicy):
     def on_epoch_begin(self, model, zeros_mask_dict, meta, **kwargs):
         self.active = True
 
-    def on_epoch_end(self, model, zeros_mask_dict, meta):
+    def on_epoch_end(self, model, zeros_mask_dict, meta, **kwargs):
         self.active = False
 
     def before_backward_pass(self, model, epoch, minibatch_id, minibatches_per_epoch, loss, zeros_mask_dict,
diff --git a/distiller/models/__init__.py b/distiller/models/__init__.py
index 6eea6fd..712e5e5 100755
--- a/distiller/models/__init__.py
+++ b/distiller/models/__init__.py
@@ -16,6 +16,8 @@
 
 """This package contains ImageNet and CIFAR image classification models for pytorch"""
 
+import copy
+
 import torch
 import torchvision.models as torch_models
 from . import cifar10 as cifar10_models
@@ -32,9 +34,12 @@ msglogger = logging.getLogger()
 # TorchVision's version.
 RESNET_SYMS = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152']
 
-IMAGENET_MODEL_NAMES = sorted(name for name in torch_models.__dict__
-                              if name.islower() and not name.startswith("__")
-                              and callable(torch_models.__dict__[name]))
+TORCHVISION_MODEL_NAMES = sorted(
+                            name for name in torch_models.__dict__
+                            if name.islower() and not name.startswith("__")
+                            and callable(torch_models.__dict__[name]))
+
+IMAGENET_MODEL_NAMES = copy.deepcopy(TORCHVISION_MODEL_NAMES)
 IMAGENET_MODEL_NAMES.extend(sorted(name for name in imagenet_extra_models.__dict__
                                    if name.islower() and not name.startswith("__")
                                    and callable(imagenet_extra_models.__dict__[name])))
@@ -72,15 +77,22 @@ def create_model(pretrained, dataset, arch, parallel=True, device_ids=None):
     if dataset == 'imagenet':
         if arch in RESNET_SYMS:
             model = imagenet_extra_models.__dict__[arch](pretrained=pretrained)
-        elif arch in torch_models.__dict__:
-            model = torch_models.__dict__[arch](pretrained=pretrained)
-        elif (arch in imagenet_extra_models.__dict__) and not pretrained:
+        elif arch in TORCHVISION_MODEL_NAMES:
+            try:
+                model = getattr(torch_models, arch)(pretrained=pretrained)
+            except NotImplementedError:
+                # In torchvision 0.3, trying to download a model that has no
+                # pretrained image available will raise NotImplementedError
+                if not pretrained:
+                    raise
+        if model is None and (arch in imagenet_extra_models.__dict__) and not pretrained:
             model = imagenet_extra_models.__dict__[arch]()
-        elif arch in pretrainedmodels.model_names:
+        if model is None and (arch in pretrainedmodels.model_names):
             cadene = True
-            model = pretrainedmodels.__dict__[arch](num_classes=1000,
-                                                    pretrained=(dataset if pretrained else None))
-        else:
+            model = pretrainedmodels.__dict__[arch](
+                        num_classes=1000,
+                        pretrained=(dataset if pretrained else None))
+        if model is None:
             error_message = ''
             if arch not in IMAGENET_MODEL_NAMES:
                 error_message = "Model {} is not supported for dataset ImageNet".format(arch)
diff --git a/distiller/policy.py b/distiller/policy.py
index b918ffa..d460d24 100755
--- a/distiller/policy.py
+++ b/distiller/policy.py
@@ -19,11 +19,11 @@
 - PruningPolicy: prunning policy
 - RegularizationPolicy: regulization scheduling
 - LRPolicy: learning-rate decay scheduling
+- QuantizationPolicy: quantization scheduling
 """
 import torch
 import torch.optim.lr_scheduler
 from collections import namedtuple
-#from functools import partial
 import logging
 msglogger = logging.getLogger()
 
@@ -75,7 +75,7 @@ class ScheduledTrainingPolicy(object):
         """The mini-batch training pass has ended"""
         pass
 
-    def on_epoch_end(self, model, zeros_mask_dict, meta):
+    def on_epoch_end(self, model, zeros_mask_dict, meta, **kwargs):
         """The current epoch has ended"""
         pass
 
@@ -177,7 +177,7 @@ class PruningPolicy(ScheduledTrainingPolicy):
             for param_name, param in model.named_parameters():
                 zeros_mask_dict[param_name].mask = None
 
-    def on_epoch_end(self, model, zeros_mask_dict, meta):
+    def on_epoch_end(self, model, zeros_mask_dict, meta, **kwargs):
         """The current epoch has ended"""
         if self.is_last_epoch:
             for param_name, param in model.named_parameters():
@@ -241,12 +241,13 @@ class LRPolicy(ScheduledTrainingPolicy):
         super(LRPolicy, self).__init__()
         self.lr_scheduler = lr_scheduler
 
-    def on_epoch_begin(self, model, zeros_mask_dict, meta, **kwargs):
+    def on_epoch_end(self, model, zeros_mask_dict, meta, **kwargs):
         if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
             # Note: ReduceLROnPlateau doesn't inherit from _LRScheduler
-            self.lr_scheduler.step(kwargs['metrics'], epoch=meta['current_epoch'])
+            self.lr_scheduler.step(kwargs['metrics'][self.lr_scheduler.mode],
+                                   epoch=meta['current_epoch'] + 1)
         else:
-            self.lr_scheduler.step(epoch=meta['current_epoch'])
+            self.lr_scheduler.step(epoch=meta['current_epoch'] + 1)
 
 
 class QuantizationPolicy(ScheduledTrainingPolicy):
diff --git a/distiller/scheduler.py b/distiller/scheduler.py
index bf0b937..4bd04d4 100755
--- a/distiller/scheduler.py
+++ b/distiller/scheduler.py
@@ -170,13 +170,13 @@ class CompressionScheduler(object):
                 policy.on_minibatch_end(self.model, epoch, minibatch_id, minibatches_per_epoch,
                                         self.zeros_mask_dict, optimizer)
 
-    def on_epoch_end(self, epoch, optimizer=None):
-        if epoch in self.policies:
-            for policy in self.policies[epoch]:
-                meta = self.sched_metadata[policy]
-                meta['current_epoch'] = epoch
-                meta['optimizer'] = optimizer
-                policy.on_epoch_end(self.model, self.zeros_mask_dict, meta)
+    def on_epoch_end(self, epoch, optimizer=None, **kwargs):
+        for policy in self.policies.get(epoch, list()):
+            meta = self.sched_metadata[policy]
+            meta['current_epoch'] = epoch
+            meta['optimizer'] = optimizer
+            policy.on_epoch_end(self.model, self.zeros_mask_dict, meta,
+                                **kwargs)
 
     def mask_all_weights(self, is_forward=True):
         for name, param in self.model.named_parameters():
diff --git a/distiller/summary_graph.py b/distiller/summary_graph.py
index f876152..bea53e5 100755
--- a/distiller/summary_graph.py
+++ b/distiller/summary_graph.py
@@ -68,7 +68,7 @@ class SummaryGraph(object):
     """
     Edge = collections.namedtuple('Edge', 'src dst')
 
-    def __init__(self, model, dummy_input):
+    def __init__(self, model, dummy_input, apply_scope_name_workarounds=True):
         self._src_model = model
         model_clone = distiller.make_non_parallel_copy(model)
         with torch.onnx.set_training(model_clone, False):
@@ -77,16 +77,38 @@ class SummaryGraph(object):
             dummy_input = distiller.convert_tensors_recursively_to(dummy_input, device=device)
             trace, _ = jit.get_trace_graph(model_clone, dummy_input, _force_outplace=True)
 
-            # ONNX trace optimization has issues with Gemm ops (aka "Linear" / "addmm" / "FC"), where
-            # Gemm nodes get the scope name of the last non-Gemm node that came before them. This can make
-            # it impossible, in some cases, to derive the connectivity of the model using the original
-            # module names. So we save the scope names for these nodes from the un-optimized trace.
-            #
-            # Note that if the node prior to the Gemm node isn't the result of a dedicated module call,
-            # then this issue doesn't occur. For simplicity we just track all Gemms.
-            aten_addmm_nodes_scope_names = [n.scopeName() for n in trace.graph().nodes() if n.kind() == 'aten::addmm']
+            # As of PyTorch 1.1.0, ONNX trace optimization has two issues that result in incorrect scope names
+            # of nodes in the trace graph.
+            # These can make it impossible, in some cases, to derive the connectivity of the model using the original
+            # module names. So we try to detect these cases and apply workarounds.
+
+            # Issue #1:
+            #   Gemm ops (aka "Linear" / "addmm" / "FC") get the scope name of the last non-Gemm node
+            #   that came before them.
+            #   Note that if the node prior to the Gemm node isn't the result of a dedicated module call,
+            #   then this issue doesn't occur. For simplicity we just track all Gemms.
+            # TODO: This should be fixed in PyTorch 1.2.0, revisit when it's released
+            aten_addmm_nodes_scope_names = []
             onnx_gemm_count = 0
 
+            # Issue #2:
+            #   Dropout ops are removed by ONNX trace optimization. However, the op BEFORE the original dropout op
+            #   gets the scope name of the dropout op
+            pre_dropout_nodes_scope_names = OrderedDict()
+
+            prev_non_dropout_op = None
+            for node in trace.graph().nodes():
+                kind = node.kind()
+                if 'aten' not in kind:
+                    continue
+                if kind == 'aten::dropout':
+                    if prev_non_dropout_op:
+                        pre_dropout_nodes_scope_names[node.scopeName()] = prev_non_dropout_op.scopeName()
+                else:
+                    prev_non_dropout_op = node
+                    if kind == 'aten::addmm':
+                        aten_addmm_nodes_scope_names.append(node.scopeName())
+
             # Let ONNX do the heavy lifting: fusing the convolution nodes; fusing the nodes
             # composing a GEMM operation; etc.
             torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX)
@@ -105,11 +127,18 @@ class SummaryGraph(object):
             for node in graph.nodes():
                 new_op = self.__create_op(node)
 
-                # Here we apply the workaround to the Gemm nodes scope name issue mentioned above
-                if new_op['type'] == 'Gemm':
-                    new_op['orig-name'] = aten_addmm_nodes_scope_names[onnx_gemm_count]
-                    new_op['name'] = new_op['orig-name']
-                    onnx_gemm_count += 1
+                if apply_scope_name_workarounds:
+                    # Here we apply the workaround to the Gemm nodes scope name issue mentioned above
+                    if new_op['type'] == 'Gemm':
+                        new_op['orig-name'] = aten_addmm_nodes_scope_names[onnx_gemm_count]
+                        new_op['name'] = new_op['orig-name']
+                        onnx_gemm_count += 1
+
+                    # Here we apply the workaround to the issue of dropout op scope name overriding previous op's
+                    # scope name
+                    if new_op['name'] in pre_dropout_nodes_scope_names:
+                        new_op['orig-name'] = pre_dropout_nodes_scope_names[new_op['name']]
+                        new_op['name'] = new_op['orig-name']
 
                 # Convert the graph node's scope name to a PyTorch module name
                 module_name = onnx_name_2_pytorch_name(new_op['orig-name'])
diff --git a/distiller/thinning.py b/distiller/thinning.py
index 43608b5..239bdbc 100755
--- a/distiller/thinning.py
+++ b/distiller/thinning.py
@@ -422,7 +422,7 @@ class StructureRemover(ScheduledTrainingPolicy):
             return
         self.__apply(model, zeros_mask_dict, optimizer)
 
-    def on_epoch_end(self, model, zeros_mask_dict, meta):
+    def on_epoch_end(self, model, zeros_mask_dict, meta, **kwargs):
         # The epoch has ended and we reset the 'done' flag, so that the FilterRemover instance can be reused
         self.done = False
 
@@ -502,55 +502,56 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, optimizer, loaded_fr
 
     assert len(recipe.parameters) > 0
 
-    for param_name, param_directives in recipe.parameters.items():
-        msglogger.debug("{} : {}".format(param_name, param_directives))
-        param = distiller.model_find_param(model, param_name)
-        assert param is not None
-        for directive in param_directives:
-            dim = directive[0]
-            indices = directive[1].to(device)
-            len_indices = indices.nelement()
-            if len(directive) == 4:  # TODO: this code is hard to follow
-                msglogger.debug("{}-{}-{}: SHAPE = {}".format(param_name, param.shape, id(param), list(directive[2])))
-                selection_view = param.view(*directive[2])
-                # Check that we're not trying to trim a parameter that is already "thin"
-                if param.data.size(dim) != len_indices:
-                    param.data = torch.index_select(selection_view, dim, indices)
+    with torch.no_grad():
+        for param_name, param_directives in recipe.parameters.items():
+            msglogger.debug("{} : {}".format(param_name, param_directives))
+            param = distiller.model_find_param(model, param_name)
+            assert param is not None
+            for directive in param_directives:
+                dim = directive[0]
+                indices = directive[1].to(device)
+                len_indices = indices.nelement()
+                if len(directive) == 4:  # TODO: this code is hard to follow
+                    msglogger.debug("{}-{}-{}: SHAPE = {}".format(param_name, param.shape, id(param), list(directive[2])))
+                    selection_view = param.view(*directive[2])
+                    # Check that we're not trying to trim a parameter that is already "thin"
+                    if param.data.size(dim) != len_indices:
+                        param.data = torch.index_select(selection_view, dim, indices)
+                        if param.grad is not None:
+                            # We also need to change the dimensions of the gradient tensor.
+                            grad_selection_view = param.grad.resize(*directive[2])
+                            if grad_selection_view.size(dim) != len_indices:
+                                param.grad = torch.index_select(grad_selection_view, dim, indices)
+                                # update optimizer
+                                if optimizer_thinning(optimizer, param, dim, indices, directive[3]):
+                                    msglogger.debug("Updated [4D] velocity buffer for {} (dim={},size={},shape={})".
+                                                    format(param_name, dim, len_indices, directive[3]))
+
+                    param.data = param.view(*directive[3])
                     if param.grad is not None:
-                        # We also need to change the dimensions of the gradient tensor.
-                        grad_selection_view = param.grad.resize_(*directive[2])
-                        if grad_selection_view.size(dim) != len_indices:
-                            param.grad = torch.index_select(grad_selection_view, dim, indices)
-                            # update optimizer
-                            if optimizer_thinning(optimizer, param, dim, indices, directive[3]):
-                                msglogger.debug("Updated [4D] velocity buffer for {} (dim={},size={},shape={})".
-                                                format(param_name, dim, len_indices, directive[3]))
-
-                param.data = param.view(*directive[3])
-                if param.grad is not None:
-                    param.grad = param.grad.resize_(*directive[3])
-            else:
-                if param.data.size(dim) != len_indices:
-                    msglogger.debug("[thinning] changing param {} ({})  dim:{}  new len: {}".format(
-                        param_name, param.shape, dim, len_indices))
-                    assert param.size(dim) > len_indices
-                    param.data = torch.index_select(param.data, dim, indices.to(param.device))
-                    msglogger.debug("[thinning] changed param {}".format(param_name))
-                # We also need to change the dimensions of the gradient tensor.
-                # If have not done a backward-pass thus far, then the gradient will
-                # not exist, and therefore won't need to be re-dimensioned.
-                if param.grad is not None and param.grad.size(dim) != len_indices:
-                    param.grad = torch.index_select(param.grad, dim, indices.to(param.device))
-                    # update optimizer
-                    if optimizer_thinning(optimizer, param, dim, indices):
-                         msglogger.debug("Updated velocity buffer %s" % param_name)
-
-            if not loaded_from_file:
-                # If the masks are loaded from a checkpoint file, then we don't need to change
-                # their shape, because they are already correctly shaped
-                mask = zeros_mask_dict[param_name].mask
-                if mask is not None and (mask.size(dim) != len_indices):
-                    zeros_mask_dict[param_name].mask = torch.index_select(mask, dim, indices)
+                        param.grad = param.grad.resize_(*directive[3])
+                else:
+                    if param.data.size(dim) != len_indices:
+                        msglogger.debug("[thinning] changing param {} ({})  dim:{}  new len: {}".format(
+                            param_name, param.shape, dim, len_indices))
+                        assert param.size(dim) > len_indices
+                        param.data = torch.index_select(param.data, dim, indices.to(param.device))
+                        msglogger.debug("[thinning] changed param {}".format(param_name))
+                    # We also need to change the dimensions of the gradient tensor.
+                    # If we have not done a backward pass thus far, then the gradient will
+                    # not exist, and therefore won't need to be re-dimensioned.
+                    if param.grad is not None and param.grad.size(dim) != len_indices:
+                        param.grad = torch.index_select(param.grad, dim, indices.to(param.device))
+                        # update optimizer
+                        if optimizer_thinning(optimizer, param, dim, indices):
+                            msglogger.debug("Updated velocity buffer %s" % param_name)
+
+                if not loaded_from_file:
+                    # If the masks are loaded from a checkpoint file, then we don't need to change
+                    # their shape, because they are already correctly shaped
+                    mask = zeros_mask_dict[param_name].mask
+                    if mask is not None and (mask.size(dim) != len_indices):
+                        zeros_mask_dict[param_name].mask = torch.index_select(mask, dim, indices)
 
 # Todo: consider removing this function
 def resnet_cifar_remove_layers(model):
diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py
index 93f9298..5c26cd6 100755
--- a/examples/classifier_compression/compress_classifier.py
+++ b/examples/classifier_compression/compress_classifier.py
@@ -29,8 +29,8 @@ For each epoch:
     compression_scheduler.on_epoch_begin(epoch)
     train()
     validate()
-    save_checkpoint()
     compression_scheduler.on_epoch_end(epoch)
+    save_checkpoint()
 
 train():
     For each training step:
@@ -278,8 +278,7 @@ def main():
         # This is the main training loop.
         msglogger.info('\n')
         if compression_scheduler:
-            compression_scheduler.on_epoch_begin(epoch,
-                metrics=(vloss if (epoch != start_epoch) else 10**6))
+            compression_scheduler.on_epoch_begin(epoch)
 
         # Train for one epoch
         with collectors_context(activations_collectors["train"]) as collectors:
@@ -306,7 +305,7 @@ def main():
                                         loggers=[tflogger])
 
         if compression_scheduler:
-            compression_scheduler.on_epoch_end(epoch, optimizer)
+            compression_scheduler.on_epoch_end(epoch, optimizer, metrics={'min': vloss, 'max': top1})
 
         # Update the list of top scores achieved so far, and save the checkpoint
         update_training_scores_history(perf_scores_history, model, top1, top5, epoch, args.num_best_scores)
@@ -775,23 +774,8 @@ def save_collectors_data(collectors, directory):
         msglogger.info("Saved to {}".format(file_path))
 
 
-def check_pytorch_version():
-    from pkg_resources import parse_version
-    if parse_version(torch.__version__) < parse_version('1.0.1'):
-        print("\nNOTICE:")
-        print("The Distiller \'master\' branch now requires at least PyTorch version 1.0.1 due to "
-              "PyTorch API changes which are not backward-compatible.\n"
-              "Please install PyTorch 1.0.1 or its derivative.\n"
-              "If you are using a virtual environment, do not forget to update it:\n"
-              "  1. Deactivate the old environment\n"
-              "  2. Install the new environment\n"
-              "  3. Activate the new environment")
-        exit(1)
-
-
 if __name__ == '__main__':
     try:
-        check_pytorch_version()
         main()
     except KeyboardInterrupt:
         print("\n-- KeyboardInterrupt --")
diff --git a/jupyter/parameter_histograms.ipynb b/jupyter/parameter_histograms.ipynb
index e3aec72..63198c7 100644
--- a/jupyter/parameter_histograms.ipynb
+++ b/jupyter/parameter_histograms.ipynb
@@ -16,19 +16,12 @@
    "outputs": [],
    "source": [
     "import torch\n",
-    "import torchvision\n",
     "import torch.nn as nn\n",
-    "from torch.autograd import Variable\n",
     "import scipy.stats as ss\n",
-    "\n",
-    "# Relative import of code from distiller, w/o installing the package\n",
-    "import os\n",
-    "import sys\n",
     "import numpy as np\n",
     "import matplotlib.pyplot as plt\n",
     "import distiller\n",
     "import distiller.models as models\n",
-    "from distiller.apputils import *\n",
     "\n",
     "plt.style.use('seaborn') # pretty matplotlib plots"
    ]
@@ -44,16 +37,17 @@
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
-    "scrolled": false
+    "scrolled": true
    },
    "outputs": [],
    "source": [
     "# It is interesting to compare the distribution of non-pretrained model (Normally-distributed)\n",
     "# vs. the distribution of the pretrained model.\n",
-    "model = models.create_model(pretrained=True, dataset='imagenet', arch='resnet50', parallel=True)\n",
+    "model = models.create_model(pretrained=True, dataset='imagenet', arch='resnet50',\n",
+    "                            device_ids=-1)  # load to CPU\n",
     "\n",
     "# Optionally load your compressed model \n",
-    "# load_checkpoint(model, <path-to-your-checkpoint-file>);"
+    "# distiller.apputils.load_checkpoint(model, <path-to-your-checkpoint-file>)"
    ]
   },
   {
@@ -69,52 +63,49 @@
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
-    "scrolled": false
+    "scrolled": true
    },
    "outputs": [],
    "source": [
-    "def flatten(weights):\n",
-    "    weights = weights.view(weights.numel())\n",
-    "    weights = weights.data.cpu().numpy()\n",
-    "    return weights\n",
+    "def getSparsity(x):\n",
+    "    return 1 - (x[x.nonzero()].size / x.size)\n",
     "\n",
     "REMOVE_ZEROS = False\n",
     "nbins = 500\n",
     "for name, weights in model.named_parameters():\n",
-    "    if weights.dim() == 4:\n",
-    "        size_str = \"x\".join([str(s) for s in weights.size()])\n",
-    "        weights = flatten(weights)\n",
-    "        \n",
-    "        if REMOVE_ZEROS:\n",
-    "            # Optionally remove zeros (lots of zeros will dominate the histogram and the \n",
-    "            # other data will be hard to see\n",
-    "            weights = weights[weights!=0]\n",
-    "        \n",
-    "        # Fit the data to the Normal distribution\n",
-    "        (mean_fitted, std_fitted) = ss.norm.fit(weights)\n",
-    "        x = np.linspace(min(weights), max(weights), nbins)\n",
-    "        weights_gauss_fitted = ss.norm.pdf(x, loc=mean_fitted, scale=std_fitted)\n",
+    "    if weights.dim() != 4:\n",
+    "        # not convolution layer\n",
+    "        continue\n",
+    "\n",
+    "    shape_str = \"x\".join(map(str, weights.shape))\n",
+    "    weights = weights.cpu().detach().numpy().flatten()\n",
+    "    sparsity = getSparsity(weights)\n",
+    "\n",
+    "    if REMOVE_ZEROS:\n",
+    "        # Optionally remove zeros (lots of zeros will dominate the histogram and the \n",
+    "        # other data will be hard to see\n",
+    "        weights = weights[weights.nonzero()]\n",
     "\n",
-    "        # Fit the data to the Laplacian distribution\n",
-    "        (mean_fitted, std_fitted) = ss.laplace.fit(weights)\n",
-    "        weights_laplace_fitted = ss.laplace.pdf(x, loc=mean_fitted, scale=std_fitted)\n",
+    "    # Fit the data to the Normal distribution\n",
+    "    (mean_fitted, std_fitted) = ss.norm.fit(weights)\n",
+    "    x = np.linspace(min(weights), max(weights), nbins)\n",
+    "    weights_gauss_fitted = ss.norm.pdf(x, loc=mean_fitted, scale=std_fitted)\n",
     "\n",
-    "        n, bins, patches = plt.hist(weights, histtype='stepfilled', \n",
-    "                                    cumulative=False, bins=nbins, normed=1)\n",
-    "        plt.plot(x, weights_gauss_fitted, label='gauss')\n",
-    "        plt.plot(x, weights_laplace_fitted, label='laplace')\n",
-    "        plt.title(name + \" - \" +size_str)\n",
-    "        #plt.figure(figsize=(10,5))\n",
-    "        plt.legend()\n",
-    "        plt.show()"
+    "    # Fit the data to the Laplacian distribution\n",
+    "    (mean_fitted, std_fitted) = ss.laplace.fit(weights)\n",
+    "    weights_laplace_fitted = ss.laplace.pdf(x, loc=mean_fitted, scale=std_fitted)\n",
+    "\n",
+    "    n, bins, patches = plt.hist(weights, histtype='stepfilled', \n",
+    "                                cumulative=False, bins=nbins, density=True)\n",
+    "    plt.plot(x, weights_gauss_fitted, label='gauss')\n",
+    "    plt.plot(x, weights_laplace_fitted, label='laplace')\n",
+    "\n",
+    "    plt.title(name + \" - \" + shape_str + (\n",
+    "        ' - sparsity: {:.0%}'.format(sparsity) if REMOVE_ZEROS else ''))\n",
+    "    #plt.figure(figsize=(10,5))\n",
+    "    plt.legend()\n",
+    "    plt.show()"
    ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
   }
  ],
  "metadata": {
@@ -133,7 +124,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.5.2"
+   "version": "3.6.8"
   }
  },
  "nbformat": 4,
diff --git a/requirements.txt b/requirements.txt
index b8deb5c..8eb594d 100755
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,22 +1,22 @@
-torch==1.0.1
-numpy>=1.14.3
-torchvision==0.2.1
-scipy>=1.1.0
+torch==1.1.0
+numpy>=1.16
+torchvision==0.3.0
+scipy>=1.3.0
 gitpython==2.1.11
 torchnet==0.0.4
-tensorflow>=1.7.0
-pydot==1.2.4
-tabulate==0.8.2
+tensorflow>=1.13
+pydot==1.4.1
+tabulate==0.8.3
 pandas>=0.22.0
 jupyter>=1.0.0
-matplotlib>=2.2.2
-qgrid==1.0.2
-graphviz==0.8.2
-ipywidgets==7.1.2
-bqplot==0.10.5
+matplotlib~=3.0  # 3.0 is the last release to support Py3.5
+qgrid==1.1.1
+graphviz==0.10.1
+ipywidgets==7.4.2
+bqplot==0.11.5
 pyyaml
-pytest==3.5.1
+pytest~=4.6.1
 xlsxwriter>=1.1.1
-pretrainedmodels
-scikit-learn
+pretrainedmodels==0.7.4
+scikit-learn==0.21.2
 gym==0.12.5
diff --git a/tests/common.py b/tests/common.py
index 1afabf6..1ad80b1 100755
--- a/tests/common.py
+++ b/tests/common.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 #
 import torch
+import pytest
 import os
 import errno
 import distiller
@@ -49,3 +50,10 @@ def find_module_by_name(model, module_to_find):
 
 def almost_equal(a , b, max_diff=0.000001):
     return abs(a - b) <= max_diff
+
+
+def pytest_raises_wrapper(exc_type, msg, func, *args, **kwargs):
+    with pytest.raises(exc_type):
+        func(*args, **kwargs)
+        if msg:
+            pytest.fail(msg)
diff --git a/tests/full_flow_tests.py b/tests/full_flow_tests.py
index 64fa60b..3666a18 100755
--- a/tests/full_flow_tests.py
+++ b/tests/full_flow_tests.py
@@ -116,13 +116,13 @@ def collateral_checker(log, run_dir, *collateral_list):
 TestConfig = namedtuple('TestConfig', ['args', 'dataset', 'checker_fn', 'checker_args'])
 
 test_configs = [
-    TestConfig('--arch simplenet_cifar --epochs 2', DS_CIFAR, accuracy_checker, [44.610, 92.080]),
+    TestConfig('--arch simplenet_cifar --epochs 2', DS_CIFAR, accuracy_checker, [44.460, 91.230]),
     TestConfig('-a resnet20_cifar --resume {0} --quantize-eval --evaluate --qe-clip-acts avg --qe-no-clip-layers {1}'.
                format(os.path.join(examples_root, 'ssl', 'checkpoints', 'checkpoint_trained_dense.pth.tar'), 'fc'),
-               DS_CIFAR, accuracy_checker, [91.55, 99.63]),
+               DS_CIFAR, accuracy_checker, [91.58, 99.63]),
     TestConfig('-a preact_resnet20_cifar --epochs 2 --compress {0}'.
                format(os.path.join('full_flow_tests', 'preact_resnet20_cifar_pact_test.yaml')),
-               DS_CIFAR, accuracy_checker, [54.590, 94.810]),
+               DS_CIFAR, accuracy_checker, [44.370, 89.640]),
     TestConfig('-a resnet20_cifar --resume {0} --sense=filter --sense-range 0 0.10 0.05'.
                format(os.path.join(examples_root, 'ssl', 'checkpoints', 'checkpoint_trained_dense.pth.tar')),
                DS_CIFAR, collateral_checker, [('sensitivity.csv', 3175), ('sensitivity.png', 96158)])
diff --git a/tests/test_infra.py b/tests/test_infra.py
index f1c3786..c778e51 100755
--- a/tests/test_infra.py
+++ b/tests/test_infra.py
@@ -130,10 +130,9 @@ def test_load_dumb_checkpoint():
 
 
 def test_load_negative():
+    model = create_model(False, 'cifar10', 'resnet20_cifar')
     with pytest.raises(FileNotFoundError):
-        model = create_model(False, 'cifar10', 'resnet20_cifar')
-        model, compression_scheduler, optimizer, start_epoch = load_checkpoint(model,
-            'THIS_IS_AN_ERROR/checkpoint_trained_dense.pth.tar')
+        load_checkpoint(model, 'THIS_IS_AN_ERROR/checkpoint_trained_dense.pth.tar')
 
 
 def test_load_gpu_model_on_cpu():
diff --git a/tests/test_quant_utils.py b/tests/test_quant_utils.py
index c1dad32..4243438 100644
--- a/tests/test_quant_utils.py
+++ b/tests/test_quant_utils.py
@@ -17,6 +17,7 @@
 import torch
 import pytest
 from distiller.quantization import q_utils as qu
+from common import pytest_raises_wrapper
 
 
 def test_symmetric_qparams():
@@ -161,11 +162,12 @@ test_tensor = torch.tensor([-93, 33, -77, -42, -89, -55, 79, -19, -94,
 test_tensor_4d = test_tensor.reshape(2, 2, 3, 3)
 test_tensor_2d = test_tensor.reshape(6, 6)
 
+too_large_dim_msg = "Expecting ValueError when passing too large dim"
+
 
 def test_get_tensor_min_max():
-    with pytest.raises(ValueError, message="Expecting ValueError when passing too large dim"):
-        qu.get_tensor_min_max(test_tensor_2d, per_dim=2)
-        qu.get_tensor_min_max(test_tensor_2d, per_dim=6)
+    pytest_raises_wrapper(ValueError, too_large_dim_msg, qu.get_tensor_min_max, test_tensor_2d, per_dim=2)
+    pytest_raises_wrapper(ValueError, too_large_dim_msg, qu.get_tensor_min_max, test_tensor_2d, per_dim=6)
 
     t_min, t_max = qu.get_tensor_min_max(test_tensor_4d)
     assert torch.equal(t_min, torch.tensor(-95.))
@@ -181,9 +183,8 @@ def test_get_tensor_min_max():
 
 
 def test_get_tensor_avg_min_max():
-    with pytest.raises(ValueError, message="Expecting ValueError when passing too large dim"):
-        qu.get_tensor_avg_min_max(test_tensor_2d, across_dim=2)
-        qu.get_tensor_avg_min_max(test_tensor_2d, across_dim=6)
+    pytest_raises_wrapper(ValueError, too_large_dim_msg, qu.get_tensor_avg_min_max, test_tensor_2d, across_dim=2)
+    pytest_raises_wrapper(ValueError, too_large_dim_msg, qu.get_tensor_avg_min_max, test_tensor_2d, across_dim=6)
 
     t_min, t_max = qu.get_tensor_avg_min_max(test_tensor_2d)
     assert torch.equal(t_min, torch.tensor(-95.))
@@ -199,9 +200,8 @@ def test_get_tensor_avg_min_max():
 
 
 def test_get_tensor_max_abs():
-    with pytest.raises(ValueError, message="Expecting ValueError when passing too large dim"):
-        qu.get_tensor_min_max(test_tensor_2d, per_dim=2)
-        qu.get_tensor_min_max(test_tensor_2d, per_dim=6)
+    pytest_raises_wrapper(ValueError, too_large_dim_msg, qu.get_tensor_max_abs, test_tensor_2d, per_dim=2)
+    pytest_raises_wrapper(ValueError, too_large_dim_msg, qu.get_tensor_max_abs, test_tensor_2d, per_dim=6)
 
     t_abs = qu.get_tensor_max_abs(test_tensor_4d)
     assert torch.equal(t_abs, torch.tensor(95.))
@@ -214,9 +214,8 @@ def test_get_tensor_max_abs():
 
 
 def test_get_tensor_avg_max_abs():
-    with pytest.raises(ValueError, message="Expecting ValueError when passing too large dim"):
-        qu.get_tensor_min_max(test_tensor_2d, per_dim=2)
-        qu.get_tensor_min_max(test_tensor_2d, per_dim=6)
+    pytest_raises_wrapper(ValueError, too_large_dim_msg, qu.get_tensor_avg_max_abs, test_tensor_2d, across_dim=2)
+    pytest_raises_wrapper(ValueError, too_large_dim_msg, qu.get_tensor_avg_max_abs, test_tensor_2d, across_dim=6)
 
     t_abs = qu.get_tensor_avg_max_abs(test_tensor_2d)
     assert torch.equal(t_abs, torch.tensor(95.))
@@ -229,8 +228,8 @@ def test_get_tensor_avg_max_abs():
 
 
 def test_get_tensor_mean_n_stds_min_max():
-    with pytest.raises(ValueError, message='Expecting ValueError with n_stds = 0'):
-        qu.get_tensor_mean_n_stds_min_max(test_tensor, n_stds=0)
+    pytest_raises_wrapper(ValueError, 'Expecting ValueError with n_stds = 0',
+                          qu.get_tensor_mean_n_stds_min_max, test_tensor, n_stds=0)
 
     mean = torch.tensor(-16.)
     std = torch.tensor(62.87447738647461)
diff --git a/tests/test_quantizer.py b/tests/test_quantizer.py
index d1f3763..eecfa3f 100644
--- a/tests/test_quantizer.py
+++ b/tests/test_quantizer.py
@@ -25,6 +25,7 @@ from distiller.quantization import Quantizer
 from distiller.quantization.quantizer import QBits, _ParamToQuant
 from distiller.quantization.quantizer import FP_BKP_PREFIX
 from distiller import has_children
+from common import pytest_raises_wrapper
 
 
 #############################
@@ -216,13 +217,15 @@ def test_no_quantization(model):
 
 
 def test_overrides_ordered_dict(model):
-    with pytest.raises(TypeError, message='Expecting TypeError when overrides is not an OrderedDict'):
-        DummyQuantizer(model, overrides={'testing': '123'})
+    pytest_raises_wrapper(TypeError, 'Expecting TypeError when overrides is not an OrderedDict',
+                          DummyQuantizer, model, overrides={'testing': {'testing': '123'}})
+
 
 acts_key = 'bits_activations'
 wts_key = 'bits_weights'
 bias_key = 'bits_bias'
 
+
 @pytest.mark.parametrize(
     "qbits, overrides, explicit_expected_overrides",
     [
@@ -386,27 +389,25 @@ def test_param_quantization(model, optimizer, qbits, overrides, explicit_expecte
 
 
 def test_overridable_args(model, optimizer, train_with_fp_copy):
-    with pytest.raises(ValueError, message='Expecting ValueError when overriding args without overriding bits.'):
-        model_copy = deepcopy(model)
-        conv_override = OrderedDict([(acts_key, None), (wts_key, None), (bias_key, None), ('prop', 123)])
-        overrides = OrderedDict([('conv1', conv_override)])
-        q = DummyQuantizer(model_copy, optimizer=optimizer, overrides=overrides, train_with_fp_copy=train_with_fp_copy)
-        q.prepare_model()
-
-    with pytest.raises(TypeError, message='Expecting TypeError when overrides contains unexpected args.'):
-        model_copy = deepcopy(model)
-        conv_override = OrderedDict([(acts_key, 8), (wts_key, 8), (bias_key, 32), ('prop', 123), ('unexpetcted_prop', 456)])
-        overrides = OrderedDict([('conv1', conv_override)])
-        q = DummyQuantizer(model_copy, optimizer=optimizer, overrides=overrides, train_with_fp_copy=train_with_fp_copy)
-        q.prepare_model()
-
-    with pytest.raises(TypeError, message='Expecting TypeError when overrides contains unexpected args.'):
-        model_copy = deepcopy(model)
-        relu_override = OrderedDict([(acts_key, 8), (wts_key, None), (bias_key, None),
-                                     ('overridable_prop', 123), ('unexpetcted_prop', 456)])
-        overrides = OrderedDict([('relu1', relu_override)])
-        q = DummyQuantizer(model_copy, optimizer=optimizer, overrides=overrides, train_with_fp_copy=train_with_fp_copy)
-        q.prepare_model()
+    model_copy = deepcopy(model)
+    conv_override = OrderedDict([(acts_key, None), (wts_key, None), (bias_key, None), ('prop', 123)])
+    overrides = OrderedDict([('conv1', conv_override)])
+    q = DummyQuantizer(model_copy, optimizer=optimizer, overrides=overrides, train_with_fp_copy=train_with_fp_copy)
+    pytest_raises_wrapper(ValueError, 'Expecting ValueError when overriding args without overriding bits',
+                          q.prepare_model)
+
+    model_copy = deepcopy(model)
+    conv_override = OrderedDict([(acts_key, 8), (wts_key, 8), (bias_key, 32), ('prop', 123), ('unexpected_prop', 456)])
+    overrides = OrderedDict([('conv1', conv_override)])
+    q = DummyQuantizer(model_copy, optimizer=optimizer, overrides=overrides, train_with_fp_copy=train_with_fp_copy)
+    pytest_raises_wrapper(TypeError, 'Expecting TypeError when overrides contains unexpected args', q.prepare_model)
+
+    model_copy = deepcopy(model)
+    relu_override = OrderedDict([(acts_key, 8), (wts_key, None), (bias_key, None),
+                                 ('overridable_prop', 123), ('unexpected_prop', 456)])
+    overrides = OrderedDict([('relu1', relu_override)])
+    q = DummyQuantizer(model_copy, optimizer=optimizer, overrides=overrides, train_with_fp_copy=train_with_fp_copy)
+    pytest_raises_wrapper(TypeError, 'Expecting TypeError when overrides contains unexpected args', q.prepare_model)
 
     model_copy = deepcopy(model)
     conv_override = OrderedDict([(acts_key, 8), (wts_key, 8), (bias_key, 32), ('prop', 123)])
diff --git a/tests/test_summarygraph.py b/tests/test_summarygraph.py
index 0e767aa..927ccba 100755
--- a/tests/test_summarygraph.py
+++ b/tests/test_summarygraph.py
@@ -251,7 +251,7 @@ def test_merge_pad_avgpool():
     assert sg.ops[avgpool_ops[0]]['type'] == 'AveragePool'
 
 
-def test_gemm_nodes_scope_names():
+def test_scope_name_workarounds():
     class ModelWithGemms(nn.Module):
         def __init__(self):
             super(ModelWithGemms, self).__init__()
@@ -262,22 +262,55 @@ def test_gemm_nodes_scope_names():
             self.fc2 = nn.Linear(50, 25)
             self.relu2 = nn.ReLU(inplace=True)
             self.fc3 = nn.Linear(25, 1)
+            self.drop3 = nn.Dropout()
 
         def forward(self, x):
-            # Isn't this pretty...
-            return self.fc3(self.relu2(self.fc2(self.drop2(self.relu1(self.fc1(self.drop1(x)))))))
+            x = self.drop1(x)
+            x = self.fc1(x)
+            x = self.relu1(x)
+            x = self.drop2(x)
+            x = self.fc2(x)
+            x = self.relu2(x)
+            x = self.fc3(x)
+            x = self.drop3(x)
+            return x
 
     m = ModelWithGemms()
-    sg = SummaryGraph(m, distiller.get_dummy_input(input_shape=(1, 100)))
+    dummy_input = distiller.get_dummy_input(input_shape=(1, 100))
+    expected_types = ('Gemm', 'Relu', 'Gemm', 'Relu', 'Gemm')
+
+    # We have workarounds for 2 issues:
+    #   1. GEMM ops get the scope name of the op that came before them
+    #   2. Ops that come before a dropout op get the scope name of the dropout op
+    # If both conditions apply, empirically issue #2 is the one that manifests
+
+    # For the model above we expect the ops in the graph to be named (in order):
+    #   'fc1', 'relu1', 'fc2', 'relu2', 'fc3'
+    # (note that dropout ops are dropped)
+    #
+    # But without our workarounds in place, we'll get:
+    #   'drop1', 'drop2', 'drop2__1', 'relu2', 'drop3'
+    #
+    # What happens is:
+    #   * 'fc1' - issue #1 applies, so 'fc1' --> 'drop1'
+    #   * 'relu1' - issue #2 applies, so 'relu1' --> 'drop2'
+    #   * 'fc2' - issue #1 applies, so 'fc2' --> 'drop2__1' ('__1' suffix because 'drop2' already exists)
+    #   * 'relu2' should be ok as-is
+    #   * 'fc3' is susceptible to both issues - it's a GEMM op AND it comes before a dropout. As mentioned above,
+    #     issue #2 "wins", so 'fc3' --> 'drop3'
+
+    # We test without the workarounds as a means to see if the issues still exist. New PyTorch versions
+    # may fix them, in which case we can remove the workarounds
+    sg = SummaryGraph(m, dummy_input, apply_scope_name_workarounds=False)
+    names, types = zip(*[(op_name, op['type']) for op_name, op in sg.ops.items()])
+    assert names == ('drop1', 'drop2', 'drop2__1', 'relu2', 'drop3')
+    assert types == expected_types
 
-    # For the model above we expect the ops to be named (in order):
-    #   'drop1', 'fc1', 'relu1', 'drop2', 'fc2', 'relu2', 'fc3'
-    # But without our workaround in place, they'll be named:
-    #   'drop1', 'drop1__1', 'relu1', 'drop2', 'drop2__1', 'relu2', 'relu2__1'
-    # (that is - each FC node gets the name of the node before)
+    # Now test with the workarounds
+    sg = SummaryGraph(m, dummy_input)
     names, types = zip(*[(op_name, op['type']) for op_name, op in sg.ops.items()])
-    assert names == ('drop1', 'fc1', 'relu1', 'drop2', 'fc2', 'relu2', 'fc3')
-    assert types == ('Dropout', 'Gemm', 'Relu', 'Dropout', 'Gemm', 'Relu', 'Gemm')
+    assert names == ('fc1', 'relu1', 'fc2', 'relu2', 'fc3')
+    assert types == expected_types
 
 
 @pytest.fixture(params=[False, True], ids=['dedicated_modules_off', 'dedicated_modules_on'])
-- 
GitLab