From 60a4f44a11435cea757ca807788493b5d044506d Mon Sep 17 00:00:00 2001
From: Neta Zmora <31280975+nzmora@users.noreply.github.com>
Date: Mon, 5 Nov 2018 15:59:32 +0200
Subject: [PATCH] Dynamic Network Surgery (#69)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Added an implementation of:

Dynamic Network Surgery for Efficient DNNs, Yiwen Guo, Anbang Yao, Yurong Chen.
NIPS 2016, https://arxiv.org/abs/1608.04493.

- Added SplicingPruner: a pruner that both prunes and splices connections.
- Included an example schedule for ResNet20 on CIFAR-10.
- New features for compress_classifier.py:
   1. Added the "--masks-sparsity" flag which, when enabled, logs the sparsity
      of the weight masks during training.
   2. Added the "--num-best-scores" command-line argument to report the top N
      best accuracy scores, instead of just the highest score.
      This is sometimes useful when pruning a pre-trained model that reaches
      its best Top1 accuracy in the first few pruning epochs.
      An example invocation appears at the end of this message.
- New features for PruningPolicy:
   1. The pruning policy can use two copies of the weights: one is used during
      the forward-pass, the other during the backward-pass.
      This is controlled by the "use_double_copies" argument (used together with
      "mask_on_forward_only", which prevents re-masking after the backward pass).
   2. If we enable "mask_on_forward_only", we probably want to permanently apply
      the mask at some point (usually once the pruning phase is done).
      This is controlled by the "keep_mask" argument.
   3. We introduce a first implementation of scheduling at the training-iteration
      granularity (i.e. at the mini-batch granularity). Until now we could only
      schedule pruning at the epoch granularity. This is controlled by the
      "mini_batch_pruning_frequency" argument (disabled by setting it to zero).
      A minimal schedule sketch follows this list.
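
   The following is a minimal sketch of how these arguments appear in a schedule
   YAML. It mirrors the ResNet20 example added in this patch; the pruner instance
   name and the epoch range are illustrative only:

       policies:
         - pruner:
             instance_name: pruner1            # a SplicingPruner defined in the schedule's 'pruners' section
             args:
               keep_mask: True                   # permanently apply the mask at the end of the last pruning epoch
               mini_batch_pruning_frequency: 1   # prune every mini-batch (0 disables mini-batch scheduling)
               mask_on_forward_only: True        # do not re-apply the mask after the backward pass
               # use_double_copies: True         # keep an unmasked copy of the weights for the backward pass
           starting_epoch: 180
           ending_epoch: 280
           frequency: 1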

   Some of the abstractions may have leaked from PruningPolicy into CompressionScheduler.
   We need to re-examine this in the future.
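
   For reference, an invocation that exercises the new command-line switches
   (this mirrors the command recorded in the header of the example schedule;
   the dataset and checkpoint paths are illustrative):

       python3 compress_classifier.py --arch resnet20_cifar ../../../data.cifar10 -p=50 --lr=0.01 --epochs=180 \
           --compress=../network_surgery/resnet20.network_surgery.yaml -j=1 --deterministic --validation-size=0 \
           --resume=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --masks-sparsity --num-best-scores=10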
---
 distiller/model_summaries.py                  |  30 +++-
 distiller/policy.py                           |  74 +++++++++-
 distiller/pruning/__init__.py                 |   1 +
 distiller/pruning/splicing_pruner.py          |  88 +++++++++++
 distiller/scheduler.py                        |  28 +++-
 distiller/utils.py                            |  37 ++++-
 .../compress_classifier.py                    |  28 ++--
 .../resnet20.network_surgery.yaml             | 139 ++++++++++++++++++
 .../resnet56_cifar_filter_rank_v2.yaml        |   4 +-
 9 files changed, 402 insertions(+), 27 deletions(-)
 create mode 100755 distiller/pruning/splicing_pruner.py
 create mode 100755 examples/network_surgery/resnet20.network_surgery.yaml

diff --git a/distiller/model_summaries.py b/distiller/model_summaries.py
index 6e57421..e04ce06 100755
--- a/distiller/model_summaries.py
+++ b/distiller/model_summaries.py
@@ -34,7 +34,7 @@ msglogger = logging.getLogger()
 
 __all__ = ['model_summary',
            'weights_sparsity_summary', 'weights_sparsity_tbl_summary',
-           'model_performance_summary', 'model_performance_tbl_summary']
+           'model_performance_summary', 'model_performance_tbl_summary', 'masks_sparsity_tbl_summary']
 
 
 def model_summary(model, what, dataset=None):
@@ -126,6 +126,34 @@ def weights_sparsity_tbl_summary(model, return_total_sparsity=False, param_dims=
     return t
 
 
+def masks_sparsity_summary(model, scheduler, param_dims=[2, 4]):
+    df = pd.DataFrame(columns=['Name', 'Fine (%)'])
+    pd.set_option('precision', 2)
+    params_size = 0
+    sparse_params_size = 0
+    for name, param in model.state_dict().items():
+        # Consider only weight/bias tensors whose dimensionality is in param_dims (Conv and FC layers by default)
+        if param.dim() in param_dims and any(type in name for type in ['weight', 'bias']):
+            mask = scheduler.zeros_mask_dict[name].mask
+            if mask is None:
+                _density = 1
+            else:
+                _density = distiller.density(mask)
+            params_size += torch.numel(param)
+            sparse_params_size += param.numel() * _density
+            df.loc[len(df.index)] = ([name, (1-_density)*100])
+
+    assert params_size != 0
+    total_sparsity = (1 - sparse_params_size/params_size)*100
+    df.loc[len(df.index)] = (['Total sparsity:', total_sparsity])
+    return df
+
+
+def masks_sparsity_tbl_summary(model, scheduler, param_dims=[2, 4]):
+    df = masks_sparsity_summary(model, scheduler, param_dims=param_dims)
+    return tabulate(df, headers='keys', tablefmt='psql', floatfmt=".5f")
+
+
 # Performance data collection  code follows from here down
 
 def conv_visitor(self, input, output, df, model, memo):
diff --git a/distiller/policy.py b/distiller/policy.py
index dc85c60..f42219f 100755
--- a/distiller/policy.py
+++ b/distiller/policy.py
@@ -46,7 +46,8 @@ class ScheduledTrainingPolicy(object):
         """A new epcoh is about to begin"""
         pass
 
-    def on_minibatch_begin(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict, optimizer=None):
+    def on_minibatch_begin(self, model, epoch, minibatch_id, minibatches_per_epoch,
+                           zeros_mask_dict, meta, optimizer=None):
         """The forward-pass of a new mini-batch is about to begin"""
         pass
 
@@ -74,30 +75,89 @@ class ScheduledTrainingPolicy(object):
 
 class PruningPolicy(ScheduledTrainingPolicy):
     """Base class for pruning policies.
-
-    The current implementation restricts the pruning step to the beginning of
-    each epoch.  This can be easily changed.
     """
     def __init__(self, pruner, pruner_args, classes=None, layers=None):
+        """
+        Arguments:
+            mask_on_forward_only: controls what we do after the weights are updated by the backward pass.
+            In issue #53 (https://github.com/NervanaSystems/distiller/issues/53) we explain why in some
+            cases masked weights will be updated to a non-zero value, even if their gradients are masked
+            (e.g. when using SGD with momentum). Therefore, to circumvent this weights-update performed by
+            the backward pass, we usually mask the weights again - right after the backward pass.  This
+            re-masking is the default; to disable it (i.e. mask on the forward-pass only) set:
+                pruner_args['mask_on_forward_only'] = True
+
+            use_double_copies: when set to 'True', two sets of weights are used. In the forward-pass we use
+            masked weights to compute the loss, but in the backward-pass we update the unmasked weights (using
+            gradients computed from the masked-weights loss).
+
+            mini_batch_pruning_frequency: this controls pruning scheduling at the mini-batch granularity.  Every
+            mini_batch_pruning_frequency training steps (i.e. mini_batches) we perform pruning.  This provides more
+            fine-grained control over pruning than that provided by CompressionScheduler (epoch granularity).
+            When setting 'mini_batch_pruning_frequency' to a value other than zero, make sure to configure the
+            policy's epoch schedule with frequency=1, so the policy is active in every epoch of the pruning range.
+        """
         super(PruningPolicy, self).__init__(classes, layers)
         self.pruner = pruner
         self.levels = None
-        if pruner_args is not None and 'levels' in pruner_args:
-            self.levels = pruner_args['levels']
+        self.keep_mask = False
+        self.mini_batch_pruning_frequency = 0
+        self.mask_on_forward_only = False
+        self.use_double_copies = False
+        if pruner_args is not None:
+            if 'levels' in pruner_args:
+                self.levels = pruner_args['levels']
+            self.keep_mask = pruner_args.get('keep_mask', False)
+            self.mini_batch_pruning_frequency = pruner_args.get('mini_batch_pruning_frequency', 0)
+            self.mask_on_forward_only = pruner_args.get('mask_on_forward_only', False)
+            self.use_double_copies = pruner_args.get('use_double_copies', False)
+        self.is_last_epoch = False
+        self.mini_batch_id = 0          # The ID of the mini_batch within the present epoch
+        self.global_mini_batch_id = 0   # The ID of the mini_batch within the present training session
 
     def on_epoch_begin(self, model, zeros_mask_dict, meta):
         msglogger.debug("Pruner {} is about to prune".format(self.pruner.name))
+        self.mini_batch_id = 0
+        self.is_last_epoch = meta['current_epoch'] == (meta['ending_epoch'] - 1)
+        self.is_first_epoch = meta['current_epoch'] == meta['starting_epoch']
         if self.levels is not None:
             self.pruner.levels = self.levels
 
+        if self.is_first_epoch:
+            self.global_mini_batch_id = 0
+
         meta['model'] = model
         for param_name, param in model.named_parameters():
+            if self.mask_on_forward_only and self.is_first_epoch:
+                zeros_mask_dict[param_name].use_double_copies = self.use_double_copies
+                zeros_mask_dict[param_name].mask_on_forward_only = self.mask_on_forward_only
             self.pruner.set_param_mask(param, param_name, zeros_mask_dict, meta)
 
-    def on_minibatch_begin(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict, optimizer=None):
+    def on_minibatch_begin(self, model, epoch, minibatch_id, minibatches_per_epoch,
+                           zeros_mask_dict, meta, optimizer=None):
+        self.mini_batch_id += 1
+        self.global_mini_batch_id += 1
+        if (self.mini_batch_pruning_frequency != 0 and
+           self.global_mini_batch_id % self.mini_batch_pruning_frequency == 0):
+            for param_name, param in model.named_parameters():
+                self.pruner.set_param_mask(param, param_name, zeros_mask_dict, meta)
+
         for param_name, param in model.named_parameters():
             zeros_mask_dict[param_name].apply_mask(param)
 
+    def on_minibatch_end(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict, optimizer):
+        for param_name, param in model.named_parameters():
+            zeros_mask_dict[param_name].remove_mask(param)
+
+    def on_epoch_end(self, model, zeros_mask_dict, meta):
+        """The current epoch has ended"""
+        is_last_epoch = meta['current_epoch'] == (meta['ending_epoch'] - 1)
+        if self.keep_mask and is_last_epoch:
+            for param_name, param in model.named_parameters():
+                zeros_mask_dict[param_name].use_double_copies = False
+                zeros_mask_dict[param_name].mask_on_forward_only = False
+                zeros_mask_dict[param_name].apply_mask(param)
+
 
 class RegularizationPolicy(ScheduledTrainingPolicy):
     """Regularization policy.
diff --git a/distiller/pruning/__init__.py b/distiller/pruning/__init__.py
index 24fe7e5..2f576e9 100755
--- a/distiller/pruning/__init__.py
+++ b/distiller/pruning/__init__.py
@@ -24,6 +24,7 @@ from .automated_gradual_pruner import AutomatedGradualPruner, L1RankedStructureP
                                       RandomRankedFilterPruner_AGP
 from .level_pruner import SparsityLevelParameterPruner
 from .sensitivity_pruner import SensitivityPruner
+from .splicing_pruner import SplicingPruner
 from .structure_pruner import StructureParameterPruner
 from .ranked_structures_pruner import L1RankedStructureParameterPruner, ActivationAPoZRankedFilterPruner, \
                                       RandomRankedFilterPruner, GradientRankedFilterPruner
diff --git a/distiller/pruning/splicing_pruner.py b/distiller/pruning/splicing_pruner.py
new file mode 100755
index 0000000..dfe7e25
--- /dev/null
+++ b/distiller/pruning/splicing_pruner.py
@@ -0,0 +1,88 @@
+#
+# Copyright (c) 2018 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+from .pruner import _ParameterPruner
+import torch
+import logging
+msglogger = logging.getLogger()
+
+
+class SplicingPruner(_ParameterPruner):
+    """A pruner that both prunes and splices connections.
+
+    The idea of pruning and splicing working in tandem was first proposed in the following
+    NIPS paper from Intel Labs China in 2016:
+        Dynamic Network Surgery for Efficient DNNs, Yiwen Guo, Anbang Yao, Yurong Chen.
+        NIPS 2016, https://arxiv.org/abs/1608.04493.
+
+    A SplicingPruner works best with a Dynamic Network Surgery schedule.
+    The original Caffe code from the authors of the paper is available here:
+    https://github.com/yiwenguo/Dynamic-Network-Surgery/blob/master/src/caffe/layers/compress_conv_layer.cpp
+    """
+
+    def __init__(self, name, sensitivities, low_thresh_mult, hi_thresh_mult, sensitivity_multiplier=0):
+        """Arguments:
+        """
+        super(SplicingPruner, self).__init__(name)
+        self.sensitivities = sensitivities
+        self.low_thresh_mult = low_thresh_mult
+        self.hi_thresh_mult = hi_thresh_mult
+        self.sensitivity_multiplier = sensitivity_multiplier
+
+    def set_param_mask(self, param, param_name, zeros_mask_dict, meta):
+        if param_name not in self.sensitivities:
+            if '*' not in self.sensitivities:
+                return
+            else:
+                sensitivity = self.sensitivities['*']
+        else:
+            sensitivity = self.sensitivities[param_name]
+
+        if not hasattr(param, '_std'):
+            # Compute the mean and standard-deviation once, and cache them.
+            param._std = torch.std(param.abs()).item()
+            param._mean = torch.mean(param.abs()).item()
+
+        if self.sensitivity_multiplier > 0:
+            # Grow the sensitivity linearly with the epoch (only linear growth is currently supported)
+            starting_epoch = meta['starting_epoch']
+            current_epoch = meta['current_epoch']
+            sensitivity *= (current_epoch - starting_epoch) * self.sensitivity_multiplier + 1
+
+        threshold_low = (param._mean + param._std * sensitivity) * self.low_thresh_mult
+        threshold_hi = (param._mean + param._std * sensitivity) * self.hi_thresh_mult
+
+        if zeros_mask_dict[param_name].mask is None:
+            zeros_mask_dict[param_name].mask = torch.ones_like(param)
+
+        # This code implements equation (3) of the "Dynamic Network Surgery" paper:
+        #
+        #           0    if a  > |W|
+        # h(W) =    mask if a <= |W| < b
+        #           1    if b <= |W|
+        #
+        # h(W) is the so-called "network surgery function".
+        # mask is the mask used in the previous iteration.
+        # a and b are the low and high thresholds, respectively.
+        # We followed the example implementation from Yiwen Guo in Caffe, and used the
+        # weight tensor's starting mean and std.
+        # This is very similar to the initialization performed by distiller.SensitivityPruner.
+
+        masked_weights = param.mul(zeros_mask_dict[param_name].mask).abs()
+        a = masked_weights.ge(threshold_low)
+        b = a & zeros_mask_dict[param_name].mask.type(torch.cuda.ByteTensor)
+        zeros_mask_dict[param_name].mask = (b | masked_weights.ge(threshold_hi)).type(torch.cuda.FloatTensor)
diff --git a/distiller/scheduler.py b/distiller/scheduler.py
index d831c19..11b7e69 100755
--- a/distiller/scheduler.py
+++ b/distiller/scheduler.py
@@ -33,6 +33,9 @@ class ParameterMasker(object):
         self.mask = None                # Mask lazily initialized by pruners
         self.param_name = param_name    # For debug/logging purposes
         self.is_regularization_mask = False
+        self.use_double_copies = False
+        self.mask_on_forward_only = False
+        self.unmasked_copy = None
 
     def apply_mask(self, tensor):
         """Apply a mask on the weights tensor."""
@@ -40,11 +43,22 @@ class ParameterMasker(object):
             msglogger.debug('No mask for parameter {0}'.format(self.param_name))
             return
         msglogger.debug('Masking parameter {0}'.format(self.param_name))
+        if self.use_double_copies:
+            self.unmasked_copy = tensor.clone()
         tensor.data.mul_(self.mask)
         if self.is_regularization_mask:
             self.mask = None
         return tensor
 
+    def remove_mask(self, tensor):
+        if self.mask is None:
+            msglogger.debug('No mask for parameter {0}'.format(self.param_name))
+            return
+        if not self.use_double_copies:
+            msglogger.debug('Parameter {0} does not maintain double copies'.format(self.param_name))
+            return
+        tensor.data = self.unmasked_copy.data
+
 
 def create_model_masks_dict(model):
     """A convinience function to create a dictionary of paramter maskers for a model"""
@@ -64,7 +78,6 @@ class CompressionScheduler(object):
         self.device = device
         self.policies = {}
         self.sched_metadata = {}
-
         self.zeros_mask_dict = {}
         for name, param in self.model.named_parameters():
             masker = ParameterMasker(name)
@@ -101,8 +114,10 @@ class CompressionScheduler(object):
     def on_minibatch_begin(self, epoch, minibatch_id, minibatches_per_epoch, optimizer=None):
         if epoch in self.policies:
             for policy in self.policies[epoch]:
+                meta = self.sched_metadata[policy]
+                meta['current_epoch'] = epoch
                 policy.on_minibatch_begin(self.model, epoch, minibatch_id, minibatches_per_epoch,
-                                          self.zeros_mask_dict, optimizer)
+                                          self.zeros_mask_dict, meta, optimizer)
 
     def before_backward_pass(self, epoch, minibatch_id, minibatches_per_epoch, loss, optimizer=None,
                              return_loss_components=False):
@@ -131,7 +146,7 @@ class CompressionScheduler(object):
         #
         # Therefore we choose to always apply the pruning mask.  In the future we may optimize this by applying
         # the mask only if the some policy is actually using the mask.
-        self.apply_mask()
+        self.apply_mask(is_forward=False)
         if epoch in self.policies:
             for policy in self.policies[epoch]:
                 policy.on_minibatch_end(self.model, epoch, minibatch_id, minibatches_per_epoch,
@@ -145,10 +160,13 @@ class CompressionScheduler(object):
                 meta['optimizer'] = optimizer
                 policy.on_epoch_end(self.model, self.zeros_mask_dict, meta)
 
-    def apply_mask(self):
+    def apply_mask(self, is_forward=True):
         for name, param in self.model.named_parameters():
             try:
-                self.zeros_mask_dict[name].apply_mask(param)
+                if is_forward or not self.zeros_mask_dict[name].mask_on_forward_only:
+                    # When we mask on forward-pass only, we allow the gradients to change
+                    # the weights.
+                    self.zeros_mask_dict[name].apply_mask(param)
             except KeyError:
                 # Quantizers for training modify some model parameters by adding a prefix
                 # If this is the source of the error, workaround and move on
diff --git a/distiller/utils.py b/distiller/utils.py
index 9e93755..25ccc08 100755
--- a/distiller/utils.py
+++ b/distiller/utils.py
@@ -49,6 +49,21 @@ def pretty_int(i):
     return "{:,}".format(i)
 
 
+class MutableNamedTuple(dict):
+    def __init__(self, init_dict):
+        for k, v in init_dict.items():
+            self[k] = v
+
+    def __getattr__(self, key):
+        return self[key]
+
+    def __setattr__(self, key, val):
+        if key in self.__dict__:
+            self.__dict__[key] = val
+        else:
+            self[key] = val
+
+
 def assign_layer_fq_names(container, name=None):
     """Assign human-readable names to the modules (layers).
 
@@ -108,11 +123,13 @@ def denormalize_module_name(parallel_model, normalized_name):
         return normalized_name   # Did not find a module with the name <normalized_name>
 
 
-def volume(tensor_desc):
+def volume(tensor):
     """return the volume of a pytorch tensor"""
-    if isinstance(tensor_desc, tuple):
-        return np.prod(tensor_desc)
-    return np.prod(tensor_desc.shape)
+    if isinstance(tensor, torch.Tensor):
+        return np.prod(tensor.shape)
+    if isinstance(tensor, tuple):
+        return np.prod(tensor)
+    raise ValueError("volume() expects a torch.Tensor or a shape tuple")
 
 
 def density(tensor):
@@ -269,6 +286,18 @@ def density_rows(tensor, transposed=True):
     return 1 - sparsity_rows(tensor, transposed)
 
 
+def model_sparsity(model, param_dims=[2, 4]):
+    params_size = 0
+    sparse_params_size = 0
+    for name, param in model.state_dict().items():
+        if param.dim() in param_dims and any(type in name for type in ['weight', 'bias']):
+            _density = density(param)
+            params_size += torch.numel(param)
+            sparse_params_size += param.numel() * _density
+    total_sparsity = (1 - sparse_params_size/params_size)*100
+    return total_sparsity
+
+
 def norm_filters(weights, p=1):
     """Compute the p-norm of convolution filters.
 
diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py
index 588b58d..d0663c6 100755
--- a/examples/classifier_compression/compress_classifier.py
+++ b/examples/classifier_compression/compress_classifier.py
@@ -120,6 +120,8 @@ parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                     help='use pre-trained model')
 parser.add_argument('--act-stats', dest='activation_stats', choices=["train", "valid", "test"], default=None,
                     help='collect activation statistics (WARNING: this slows down training)')
+parser.add_argument('--masks-sparsity', dest='masks_sparsity', action='store_true', default=False,
+                    help='print the sparsity table of the weight masks at the end of each epoch')
 parser.add_argument('--param-hist', dest='log_params_histograms', action='store_true', default=False,
                     help='log the paramter tensors histograms to file (WARNING: this can use significant disk space)')
 SUMMARY_CHOICES = ['sparsity', 'compute', 'model', 'modules', 'png', 'png_w_params', 'onnx']
@@ -148,6 +150,8 @@ parser.add_argument('--earlyexit_lossweights', type=float, nargs='*', dest='earl
                     help='List of loss weights for early exits (e.g. --lossweights 0.1 0.3)')
 parser.add_argument('--earlyexit_thresholds', type=float, nargs='*', dest='earlyexit_thresholds', default=None,
                     help='List of EarlyExit thresholds (e.g. --earlyexit 1.2 0.9)')
+parser.add_argument('--num-best-scores', dest='num_best_scores', default=1, type=int,
+                    help='number of best scores to track and report (default: 1)')
 
 quant_group = parser.add_argument_group('Arguments controlling quantization at evaluation time'
                                         '("post-training quantization)')
@@ -238,8 +242,8 @@ def main():
     msglogger.debug("Distiller: %s", distiller.__version__)
 
     start_epoch = 0
-    best_top1 = 0
-    best_epoch = 0
+    best_epochs = [distiller.MutableNamedTuple({'epoch': 0, 'top1': 0, 'sparsity': 0})
+                   for i in range(args.num_best_scores)]
 
     if args.deterministic:
         # Experiment reproducibility is sometimes important.  Pete Warden expounded about this
@@ -372,6 +376,8 @@ def main():
             distiller.log_weights_sparsity(model, epoch, loggers=[tflogger, pylogger])
             distiller.log_activation_statsitics(epoch, "train", loggers=[tflogger],
                                                 collector=collectors["sparsity"])
+            if args.masks_sparsity:
+                msglogger.info(distiller.masks_sparsity_tbl_summary(model, compression_scheduler))
 
         # evaluate on validation set
         with collectors_context(activations_collectors["valid"]) as collectors:
@@ -391,13 +397,18 @@ def main():
             compression_scheduler.on_epoch_end(epoch, optimizer)
 
         # remember best top1 and save checkpoint
-        is_best = top1 > best_top1
+        #sparsity = distiller.model_sparsity(model)
+        is_best = top1 > best_epochs[0].top1
         if is_best:
-            best_epoch = epoch
-            best_top1 = top1
-        msglogger.info('==> Best Top1: %.3f   On Epoch: %d\n', best_top1, best_epoch)
-        apputils.save_checkpoint(epoch, args.arch, model, optimizer, compression_scheduler, best_top1, is_best,
-                                 args.name, msglogger.logdir)
+            best_epochs[0].epoch = epoch
+            best_epochs[0].top1 = top1
+            #best_epoch.sparsity = sparsity
+            best_epochs = sorted(best_epochs, key=lambda score: score.top1)
+        for score in reversed(best_epochs):
+            if score.top1 > 0:
+                msglogger.info('==> Best Top1: %.3f on Epoch: %d', score.top1, score.epoch)
+        apputils.save_checkpoint(epoch, args.arch, model, optimizer, compression_scheduler,
+                                 best_epochs[0].top1, is_best, args.name, msglogger.logdir)
 
     # Finally run results on the test set
     test(test_loader, model, criterion, [pylogger], activations_collectors, args=args)
@@ -525,7 +536,6 @@ def test(test_loader, model, criterion, loggers, activations_collectors, args):
     with collectors_context(activations_collectors["test"]) as collectors:
         top1, top5, lossses = _validate(test_loader, model, criterion, loggers, args)
         distiller.log_activation_statsitics(-1, "test", loggers, collector=collectors['sparsity'])
-
     return top1, top5, lossses
 
 
diff --git a/examples/network_surgery/resnet20.network_surgery.yaml b/examples/network_surgery/resnet20.network_surgery.yaml
new file mode 100755
index 0000000..ba8740e
--- /dev/null
+++ b/examples/network_surgery/resnet20.network_surgery.yaml
@@ -0,0 +1,139 @@
+# This schedule follows the methodology proposed by Intel Labs China in the paper:
+#   Dynamic Network Surgery for Efficient DNNs, Yiwen Guo, Anbang Yao, Yurong Chen.
+#   NIPS 2016, https://arxiv.org/abs/1608.04493.
+#
+# We have not been able to achieve particularly strong results with our implementation of the paper, as we understood
+# it.  In the paper two sets of weights are used: in the forward-pass Guo et al. use masked weights to compute the
+# loss, but in the backward-pass they update the unmasked weights (using gradients computed from the masked-weights loss).
+# To replicate this behavior set in the SplicingPruner policy:
+#   use_double_copies: True
+#   mask_on_forward_only: True
+#
+# We found that using two copies of the weights reduces accuracy, so we disable this configuration in the
+# example schedule below.
+#
+# The "mask_on_forward_only" configuration controls what we do after the weights are updated by the backward pass.
+# In issue #53 (https://github.com/NervanaSystems/distiller/issues/53) we explain why in some cases masked weights
+# will be updated to a non-zero value, even if their gradients are masked (e.g. when using SGD with momentum).
+# Therefore, to circumvent this weights-update performed by the backward pass, we usually mask the weights again -
+# right after the backward pass.  This re-masking is the default; to disable it (mask on the forward-pass only) set:
+#   mask_on_forward_only: True
+#
+#
+# Baseline results:
+#     Top1: 91.780    Top5: 99.710    Loss: 0.376
+#     Total MACs: 40,813,184
+#     # of parameters: 270,896
+#
+# Results:
+#     Best Top1: 91.490 (on Epoch: 339)
+#     Total MACs: 40,813,184
+#     Total sparsity: 69.1%
+#     # of parameters: 83,671
+#
+# time python3 compress_classifier.py --arch resnet20_cifar  ../../../data.cifar10 -p=50 --lr=0.01 --epochs=180 --compress=../network_surgery/resnet20.network_surgery.yaml -j=1 --deterministic  --validation-size=0 --resume=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --masks-sparsity --num-best-scores=10
+#
+# Parameters:
+# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
+# |    | Name                                | Shape          |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
+# |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
+# |  0 | module.conv1.weight                 | (16, 3, 3, 3)  |           432 |            432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.36322 | -0.00677 |    0.25811 |
+# |  1 | module.layer1.0.conv1.weight        | (16, 16, 3, 3) |          2304 |            641 |    0.00000 |    0.00000 |  6.25000 | 19.53125 |  0.00000 |   72.17882 | 0.12306 | -0.00732 |    0.05892 |
+# |  2 | module.layer1.0.conv2.weight        | (16, 16, 3, 3) |          2304 |            629 |    0.00000 |    0.00000 |  0.00000 | 19.14062 |  0.00000 |   72.69965 | 0.12015 | -0.00124 |    0.05719 |
+# |  3 | module.layer1.1.conv1.weight        | (16, 16, 3, 3) |          2304 |            670 |    0.00000 |    0.00000 |  0.00000 | 16.40625 |  0.00000 |   70.92014 | 0.10390 | -0.00880 |    0.05262 |
+# |  4 | module.layer1.1.conv2.weight        | (16, 16, 3, 3) |          2304 |            659 |    0.00000 |    0.00000 |  0.00000 | 14.84375 |  0.00000 |   71.39757 | 0.09757 | -0.00261 |    0.04895 |
+# |  5 | module.layer1.2.conv1.weight        | (16, 16, 3, 3) |          2304 |            714 |    0.00000 |    0.00000 |  0.00000 | 20.31250 |  0.00000 |   69.01042 | 0.13813 | -0.00660 |    0.07117 |
+# |  6 | module.layer1.2.conv2.weight        | (16, 16, 3, 3) |          2304 |            685 |    0.00000 |    0.00000 |  0.00000 | 19.14062 |  0.00000 |   70.26910 | 0.10906 |  0.00146 |    0.05552 |
+# |  7 | module.layer2.0.conv1.weight        | (32, 16, 3, 3) |          4608 |           1534 |    0.00000 |    0.00000 |  0.00000 | 14.64844 |  0.00000 |   66.71007 | 0.10999 |  0.00096 |    0.05983 |
+# |  8 | module.layer2.0.conv2.weight        | (32, 32, 3, 3) |          9216 |           2915 |    0.00000 |    0.00000 |  0.00000 |  8.30078 |  0.00000 |   68.37023 | 0.09245 | -0.00413 |    0.04926 |
+# |  9 | module.layer2.0.downsample.0.weight | (32, 16, 1, 1) |           512 |            137 |    0.00000 |    0.00000 |  0.00000 | 73.24219 |  0.00000 |   73.24219 | 0.20471 | -0.00300 |    0.09534 |
+# | 10 | module.layer2.1.conv1.weight        | (32, 32, 3, 3) |          9216 |           2738 |    0.00000 |    0.00000 |  0.00000 |  9.96094 |  0.00000 |   70.29080 | 0.08032 | -0.00450 |    0.04164 |
+# | 11 | module.layer2.1.conv2.weight        | (32, 32, 3, 3) |          9216 |           2791 |    0.00000 |    0.00000 |  0.00000 |  8.39844 |  0.00000 |   69.71571 | 0.07054 | -0.00283 |    0.03704 |
+# | 12 | module.layer2.2.conv1.weight        | (32, 32, 3, 3) |          9216 |           2770 |    0.00000 |    0.00000 |  0.00000 | 11.03516 |  0.00000 |   69.94358 | 0.08060 | -0.00651 |    0.04184 |
+# | 13 | module.layer2.2.conv2.weight        | (32, 32, 3, 3) |          9216 |           2745 |    0.00000 |    0.00000 |  0.00000 | 11.52344 |  0.00000 |   70.21484 | 0.06489 |  0.00081 |    0.03376 |
+# | 14 | module.layer3.0.conv1.weight        | (64, 32, 3, 3) |         18432 |           5852 |    0.00000 |    0.00000 |  0.00000 | 13.37891 |  0.00000 |   68.25087 | 0.07912 | -0.00226 |    0.04262 |
+# | 15 | module.layer3.0.conv2.weight        | (64, 64, 3, 3) |         36864 |          11772 |    0.00000 |    0.00000 |  0.00000 |  4.83398 |  0.00000 |   68.06641 | 0.07390 | -0.00186 |    0.03979 |
+# | 16 | module.layer3.0.downsample.0.weight | (64, 32, 1, 1) |          2048 |            667 |    0.00000 |    0.00000 |  0.00000 | 67.43164 |  0.00000 |   67.43164 | 0.11225 | -0.00579 |    0.06123 |
+# | 17 | module.layer3.1.conv1.weight        | (64, 64, 3, 3) |         36864 |          10906 |    0.00000 |    0.00000 |  0.00000 |  6.98242 |  0.00000 |   70.41558 | 0.07001 | -0.00362 |    0.03640 |
+# | 18 | module.layer3.1.conv2.weight        | (64, 64, 3, 3) |         36864 |          10850 |    0.00000 |    0.00000 |  0.00000 |  8.20312 |  0.00000 |   70.56749 | 0.06218 | -0.00419 |    0.03227 |
+# | 19 | module.layer3.2.conv1.weight        | (64, 64, 3, 3) |         36864 |          11053 |    0.00000 |    0.00000 |  0.00000 | 10.42480 |  0.00000 |   70.01682 | 0.06197 | -0.00483 |    0.03245 |
+# | 20 | module.layer3.2.conv2.weight        | (64, 64, 3, 3) |         36864 |          11871 |    0.00000 |    0.00000 |  0.00000 | 23.90137 |  0.00000 |   67.79785 | 0.03901 |  0.00004 |    0.02103 |
+# | 21 | module.fc.weight                    | (10, 64)       |           640 |            640 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.56654 | -0.00002 |    0.48838 |
+# | 22 | Total sparsity:                     | -              |        270896 |          83671 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |   69.11324 | 0.00000 |  0.00000 |    0.00000 |
+# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
+# 2018-11-03 12:20:32,056 - Total sparsity: 69.11
+# 2018-11-03 12:20:32,108 - --- validate (epoch=359)-----------
+# 2018-11-03 12:20:32,109 - 10000 samples (256 per mini-batch)
+# 2018-11-03 12:20:34,825 - ==> Top1: 91.430    Top5: 99.670    Loss: 0.368
+#
+# 2018-11-03 12:20:34,827 - ==> Best Top1: 91.490 on Epoch: 339   <====== BEST RESULT
+# 2018-11-03 12:20:34,827 - ==> Best Top1: 91.480 on Epoch: 222
+# 2018-11-03 12:20:34,827 - ==> Best Top1: 91.470 on Epoch: 280
+# 2018-11-03 12:20:34,827 - ==> Best Top1: 91.450 on Epoch: 328
+# 2018-11-03 12:20:34,827 - ==> Best Top1: 91.440 on Epoch: 298
+# 2018-11-03 12:20:34,828 - ==> Best Top1: 91.440 on Epoch: 314
+# 2018-11-03 12:20:34,828 - ==> Best Top1: 91.440 on Epoch: 319
+# 2018-11-03 12:20:34,828 - ==> Best Top1: 91.430 on Epoch: 283
+# 2018-11-03 12:20:34,828 - ==> Best Top1: 91.430 on Epoch: 304
+# 2018-11-03 12:20:34,828 - ==> Best Top1: 91.430 on Epoch: 359
+# 2018-11-03 12:20:34,828 - Saving checkpoint to: logs/2018.11.03-111011/checkpoint.pth.tar
+# 2018-11-03 12:20:34,859 - --- test ---------------------
+# 2018-11-03 12:20:34,859 - 10000 samples (256 per mini-batch)
+# 2018-11-03 12:20:37,610 - ==> Top1: 91.430    Top5: 99.670    Loss: 0.368
+
+
+version: 1
+pruners:
+  pruner1:
+    class: SplicingPruner
+    low_thresh_mult: 0.9 # 0.6
+    hi_thresh_mult: 1.1 # 0.7
+    sensitivity_multiplier: 0.015
+    sensitivities:
+      #module.conv1.weight: 0.50
+      module.layer1.0.conv1.weight: 0.050
+      module.layer1.0.conv2.weight: 0.050
+      module.layer1.1.conv1.weight: 0.050
+      module.layer1.1.conv2.weight: 0.050
+      module.layer1.2.conv1.weight: 0.010
+      module.layer1.2.conv2.weight: 0.050
+      module.layer2.0.conv1.weight: 0.010
+      module.layer2.0.conv2.weight: 0.030
+      module.layer2.0.downsample.0.weight: 0.050
+      module.layer2.1.conv1.weight: 0.050
+      module.layer2.1.conv2.weight: 0.050
+      module.layer2.2.conv1.weight: 0.040
+      module.layer2.2.conv2.weight: 0.050
+      module.layer3.0.conv1.weight: 0.040
+      module.layer3.0.conv2.weight: 0.020
+      module.layer3.0.downsample.0.weight: 0.050
+      module.layer3.1.conv1.weight: 0.050
+      module.layer3.1.conv2.weight: 0.050
+      module.layer3.2.conv1.weight: 0.050
+      module.layer3.2.conv2.weight: 0.050
+      #module.fc.weight
+
+lr_schedulers:
+  training_lr:
+    class: StepLR
+    step_size: 45
+    gamma: 0.10
+
+policies:
+  - pruner:
+      instance_name: pruner1
+      args:
+        keep_mask: True
+        mini_batch_pruning_frequency: 1
+        mask_on_forward_only: True
+        # use_double_copies: True
+    starting_epoch: 180
+    ending_epoch: 280
+    frequency: 1
+
+
+  - lr_scheduler:
+      instance_name: training_lr
+    starting_epoch: 225
+    ending_epoch: 400
+    frequency: 1
diff --git a/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank_v2.yaml b/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank_v2.yaml
index d0f58ce..5f67404 100755
--- a/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank_v2.yaml
+++ b/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank_v2.yaml
@@ -21,11 +21,13 @@
 #
 # Baseline results:
 #     Top1: 92.850    Top5: 99.780    Loss: 0.464
+#     Parameters: 851,504
 #     Total MACs: 125,747,840
 #
 # Results:
 #     Top1: 92.740    Top5: 99.640    Loss: 1.534
-#     Total MACs: 67,797,632
+#     Parameters: 570,704 (33% fewer parameters)
+#     Total MACs: 67,797,632 (1.85x fewer MACs)
 #
 # Parameters:
 # +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
-- 
GitLab