From c849a25fd8d8719c1a81d43f85cf831d24f6da37 Mon Sep 17 00:00:00 2001
From: Neta Zmora <31280975+nzmora@users.noreply.github.com>
Date: Mon, 11 Nov 2019 22:21:15 +0200
Subject: [PATCH] Pruning with virtual Batch-norm statistics folding (#415)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* pruning: add an option to virtually fold BN into Conv2D for ranking

PruningPolicy can be configured using a new control argument, `fold_batchnorm`: when set to `True`, the weights of BatchNorm modules are folded into the weights of Conv2D modules (if Conv2D->BN edges exist in the model graph).  Each weight filter is attenuated using a different pair of (gamma, beta) coefficients, so `fold_batchnorm` is relevant for fine-grained and filter-ranking pruning methods.  We attenuate using the running values of the mean and variance, as is done in quantization.
This control argument is only supported for Conv2D modules (i.e. other convolution operation variants and Linear operations are not supported).
e.g.:
policies:
  - pruner:
      instance_name : low_pruner
      args:
        fold_batchnorm: True
    starting_epoch: 0
    ending_epoch: 30
    frequency: 2
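
For reference, the virtual folding used for ranking scales each filter by its BN
coefficients (this is what `get_bn_folded_weights` in distiller/policy.py computes):

    W_folded = W * gamma / sqrt(running_var + eps)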

* AGP: non-functional refactoring

distiller/pruning/automated_gradual_pruner.py – change `prune_to_target_sparsity`
to `_set_param_mask_by_sparsity_target`, which is a more appropriate function
name as we don’t really prune in this function

* Simplify GEMM weights input-channel ranking logic

Ranking weight-matrices by input channels is similar to ranking 4D
Conv weights by input channels, so there is no need for duplicate logic.
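
For example, a minimal sketch (the layer and variable names are illustrative):

    import torch.nn as nn

    fc = nn.Linear(in_features=8, out_features=4)
    # PyTorch stores W with shape (output_channels, input_channels), so ranking
    # the input channels of a Linear layer means ranking the columns of W:
    assert fc.weight.shape == (4, 8)
    col_l1 = fc.weight.abs().sum(dim=0)  # one L1 norm per input channel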

distiller/pruning/ranked_structures_pruner.py
-change `prune_to_target_sparsity` to `_set_param_mask_by_sparsity_target`,
which is a more appropriate function name as we don’t really prune in this
function
-remove the code handling ranking of matrix rows

distiller/norms.py – remove rank_cols.

distiller/thresholding.py – in expand_binary_map, treat the `Channels` group_type
the same as the `Cols` group_type when dealing with 2D weights
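
A minimal sketch of the new behavior for 2D weights (shapes are illustrative):

    import torch
    from distiller.thresholding import expand_binary_map

    W = torch.randn(4, 8)        # GEMM weights: (output_channels, input_channels)
    binary_map = torch.ones(8)   # one entry per input channel
    mask, _ = expand_binary_map(W, 'Channels', binary_map)  # now handled like 'Cols'
    assert mask.shape == W.shape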

* AGP: add example of ranking filters with virtual BN-folding

Also update resnet20 AGP examples
---
 distiller/norms.py                            |  26 ++--
 distiller/policy.py                           |  74 ++++++++-
 distiller/pruning/automated_gradual_pruner.py |  11 +-
 distiller/pruning/ranked_structures_pruner.py |  64 +-------
 distiller/thresholding.py                     |   9 +-
 ...resnet20_filters.bn_fold.schedule_agp.yaml | 140 ++++++++++++++++++
 .../resnet20_filters.schedule_agp.yaml        | 104 ++++++-------
 .../resnet20_filters.schedule_agp_2.yaml      |  89 ++++++-----
 .../resnet20_filters.schedule_agp_3.yaml      |  86 ++++++-----
 .../resnet20_filters.schedule_agp_4.yaml      |  85 ++++++-----
 tests/test_pruning.py                         |   4 +-
 11 files changed, 427 insertions(+), 265 deletions(-)
 create mode 100755 examples/agp-pruning/resnet20_filters.bn_fold.schedule_agp.yaml

diff --git a/distiller/norms.py b/distiller/norms.py
index bafda4e..a73a07f 100644
--- a/distiller/norms.py
+++ b/distiller/norms.py
@@ -40,7 +40,7 @@ __all__ = ["kernels_lp_norm", "channels_lp_norm", "filters_lp_norm",
            "rows_lp_norm", "cols_lp_norm",
            "rows_norm", "cols_norm",
            "l1_norm", "l2_norm", "max_norm",
-           "rank_channels", "rank_filters", "rank_cols"]
+           "rank_channels", "rank_filters"]
 
 
 class NamedFunction:
@@ -138,9 +138,12 @@ def channels_lp_norm(param, p=1, group_len=1, length_normalized=False):
 
 
 def channels_norm(param, norm_fn, group_len=1, length_normalized=False):
-    """Compute some norm of 3D channels of 4D parameter tensors.
+    """Compute some norm of parameter input-channels.
+
+    Computing the norms of weight-matrices input channels is logically similar to computing
+    the norms of 4D Conv weights input channels so we treat both cases.
+    Assumes 2D or 4D weights tensors.
 
-    Assumes 4D weights tensors.
     Args:
         param: shape (num_filters(0), num_channels(1), kernel_height(2), kernel_width(3))
         norm_fn: a callable that computes a norm
@@ -154,6 +157,13 @@ def channels_norm(param, norm_fn, group_len=1, length_normalized=False):
     """
     assert param.dim() in (2, 4), "param has invalid dimensions"
     if param.dim() == 2:
+        # For GEMM operations, PyTorch stores the weight matrices in a transposed format, i.e.
+        # the matrix is transposed before performing GEMM.  This is because the output is computed
+        # as follows (see torch.nn.functional.linear):
+        #   y = x(W^T) + b ; where W^T is the transpose of W
+        #
+        # Therefore, W is expected to have shape (output_channels, input_channels), and to compute
+        # the norms of input_channels, we compute the norms of W's columns.
         return cols_norm(param, norm_fn, group_len, length_normalized)
     param = param.transpose(0, 1).contiguous()
     group_size = group_len * np.prod(param.shape[1:])
@@ -316,13 +326,3 @@ def rank_filters(param, group_len, magnitude_fn, fraction_to_partition, rounding
     mags = filters_norm(param, magnitude_fn, group_len, length_normalized=True)
     return k_smallest_elems(mags, n_filters_to_prune, noise)
 
-
-def rank_cols(param, group_len, magnitude_fn, fraction_to_partition, rounding_fn, noise):
-    assert param.dim() == 2, "This ranking is only supported for 2D tensors"
-    COLS_DIM = 0
-    n_cols = param.size(COLS_DIM)
-    n_cols_to_prune = num_structs_to_prune(n_cols, group_len, fraction_to_partition, rounding_fn)
-    if n_cols_to_prune == 0:
-        return None, None
-    mags = cols_norm(param, magnitude_fn, group_len, length_normalized=True)
-    return k_smallest_elems(mags, n_cols_to_prune, noise)
diff --git a/distiller/policy.py b/distiller/policy.py
index d460d24..e040d65 100755
--- a/distiller/policy.py
+++ b/distiller/policy.py
@@ -22,14 +22,17 @@
 - QuantizationPolicy: quantization scheduling
 """
 import torch
+import torch.nn as nn
 import torch.optim.lr_scheduler
-from collections import namedtuple
+from collections import namedtuple, OrderedDict
 import logging
-msglogger = logging.getLogger()
+import distiller
+
 
 __all__ = ['PruningPolicy', 'RegularizationPolicy', 'QuantizationPolicy', 'LRPolicy', 'ScheduledTrainingPolicy',
            'PolicyLoss', 'LossComponent']
 
+msglogger = logging.getLogger()
 PolicyLoss = namedtuple('PolicyLoss', ['overall_loss', 'loss_components'])
 LossComponent = namedtuple('LossComponent', ['name', 'value'])
 
@@ -94,7 +97,7 @@ class PruningPolicy(ScheduledTrainingPolicy):
             disable this masking set:
                 pruner_args['mask_on_forward_only'] = False
 
-            use_double_copies: when set to 'True', two sets of weights are used. In the forward-pass we use
+            use_double_copies: when set to `True`, two sets of weights are used. In the forward-pass we use
             masked weights to compute the loss, but in the backward-pass we update the unmasked weights (using
             gradients computed from the masked-weights loss).
 
@@ -103,7 +106,15 @@ class PruningPolicy(ScheduledTrainingPolicy):
             fine-grained control over pruning than that provided by CompressionScheduler (epoch granularity).
             When setting 'mini_batch_pruning_frequency' to a value other than zero, make sure to configure the policy's
             schedule to once-every-epoch.
-        """
+
+            fold_batchnorm: when set to `True`, the weights of BatchNorm modules are folded into the weights of
+            Conv2D modules (if Conv2D->BN edges exist in the model graph).  Each weight filter is attenuated using
+            a different pair of (gamma, beta) coefficients, so `fold_batchnorm` is relevant for fine-grained and
+            filter-ranking pruning methods.  We attenuate using the running values of the mean and variance, as is
+            done in quantization.
+            This control argument is only supported for Conv2D modules (i.e. other convolution operation variants and
+            Linear operations are not supported).
+        """
         super(PruningPolicy, self).__init__(classes, layers)
         self.pruner = pruner
         # Copy external policy configuration, if available
@@ -120,10 +131,51 @@ class PruningPolicy(ScheduledTrainingPolicy):
         self.use_double_copies = pruner_args.get('use_double_copies', False)
         self.discard_masks_at_minibatch_end = pruner_args.get('discard_masks_at_minibatch_end', False)
         self.skip_first_minibatch = pruner_args.get('skip_first_minibatch', False)
-        # Initiliaze state
+        self.fold_bn = pruner_args.get('fold_batchnorm', False)
+        # These are required for BN-folding.  We cache them to improve performance
+        self.named_modules = None
+        self.sg = None
+        # Initialize state
         self.is_last_epoch = False
         self.is_initialized = False
 
+    @staticmethod
+    def _fold_batchnorm(model, param_name, param, named_modules, sg):
+        def _get_all_parameters(param_module, bn_module):
+            w, b, gamma, beta = param_module.weight, param_module.bias, bn_module.weight, bn_module.bias
+            if not bn_module.affine:
+                gamma = 1.
+                beta = 0.
+            return w, b, gamma, beta
+
+        def get_bn_folded_weights(conv_module, bn_module):
+            """Compute the weights of `conv_module` after folding successor BN layer.
+
+            In inference, DL frameworks and graph-compilers fold the batch normalization into
+            the weights as defined by equations 20 and 21 of https://arxiv.org/pdf/1806.08342.pdf
+
+            :param conv_module: nn.Conv2d module
+            :param bn_module: nn.BatchNorm2d module which succeeds `conv_module`
+            :return: Folded weights
+            """
+            w, b, gamma, beta = _get_all_parameters(conv_module, bn_module)
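+            # Note: only the weights are folded; the bias terms (b, beta) are not
+            # needed here, because the folded weights are used solely for ranking.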
+            with torch.no_grad():
+                sigma_running = torch.sqrt(bn_module.running_var + bn_module.eps)
+                w_corrected = w * (gamma / sigma_running).view(-1, 1, 1, 1)
+            return w_corrected
+
+        layer_name = distiller.utils.param_name_2_module_name(param_name)
+        if not isinstance(named_modules[layer_name], nn.Conv2d):
+            return param
+
+        bn_layers = sg.successors_f(layer_name, ['BatchNormalization'])
+        if bn_layers:
+            assert len(bn_layers) == 1
+            bn_module = named_modules[bn_layers[0]]
+            conv_module = named_modules[layer_name]
+            param = get_bn_folded_weights(conv_module, bn_module)
+        return param
+
     def on_epoch_begin(self, model, zeros_mask_dict, meta, **kwargs):
         msglogger.debug("Pruner {} is about to prune".format(self.pruner.name))
         self.is_last_epoch = meta['current_epoch'] == (meta['ending_epoch'] - 1)
@@ -132,7 +184,16 @@ class PruningPolicy(ScheduledTrainingPolicy):
 
         meta['model'] = model
         is_initialized = self.is_initialized
+
+        if self.fold_bn:
+            # Cache this information (required for BN-folding) to improve performance
+            self.named_modules = OrderedDict(model.named_modules())
+            dummy_input = torch.randn(model.input_shape)
+            self.sg = distiller.SummaryGraph(model, dummy_input)
+
         for param_name, param in model.named_parameters():
+            if self.fold_bn:
+                param = self._fold_batchnorm(model, param_name, param, self.named_modules, self.sg)
             if not is_initialized:
                 # Initialize the maskers
                 masker = zeros_mask_dict[param_name]
@@ -141,6 +202,7 @@ class PruningPolicy(ScheduledTrainingPolicy):
                 # register for the backward hook of the parameters
                 if self.mask_gradients:
                     masker.backward_hook_handle = param.register_hook(masker.mask_gradient)
+
                 self.is_initialized = True
                 if not self.skip_first_minibatch:
                     self.pruner.set_param_mask(param, param_name, zeros_mask_dict, meta)
@@ -164,6 +226,8 @@ class PruningPolicy(ScheduledTrainingPolicy):
 
         for param_name, param in model.named_parameters():
             if set_masks:
+                if self.fold_bn:
+                    param = self._fold_batchnorm(model, param_name, param, self.named_modules, self.sg)
                 self.pruner.set_param_mask(param, param_name, zeros_mask_dict, meta)
             zeros_mask_dict[param_name].apply_mask(param)
 
diff --git a/distiller/pruning/automated_gradual_pruner.py b/distiller/pruning/automated_gradual_pruner.py
index d906a12..c2e6d4c 100755
--- a/distiller/pruning/automated_gradual_pruner.py
+++ b/distiller/pruning/automated_gradual_pruner.py
@@ -56,9 +56,10 @@ class AutomatedGradualPrunerBase(_ParameterPruner):
 
     def set_param_mask(self, param, param_name, zeros_mask_dict, meta):
         target_sparsity = self.compute_target_sparsity(meta)
-        self.prune_to_target_sparsity(param, param_name, zeros_mask_dict, target_sparsity, meta['model'])
+        self._set_param_mask_by_sparsity_target(param, param_name, zeros_mask_dict, target_sparsity, meta['model'])
 
-    def prune_to_target_sparsity(self, param, param_name, zeros_mask_dict, target_sparsity, model=None):
+    def _set_param_mask_by_sparsity_target(self, param, param_name, zeros_mask_dict, target_sparsity, model=None):
+        """Set the parameter mask using a target sparsity. Override this in subclasses"""
         raise NotImplementedError
 
 
@@ -78,7 +79,7 @@ class AutomatedGradualPruner(AutomatedGradualPrunerBase):
             return
         super().set_param_mask(param, param_name, zeros_mask_dict, meta)
 
-    def prune_to_target_sparsity(self, param, param_name, zeros_mask_dict, target_sparsity, model=None):
+    def _set_param_mask_by_sparsity_target(self, param, param_name, zeros_mask_dict, target_sparsity, model=None):
         zeros_mask_dict[param_name].mask = SparsityLevelParameterPruner.create_mask(param, target_sparsity)
 
 
@@ -92,8 +93,8 @@ class StructuredAGP(AutomatedGradualPrunerBase):
         super().__init__(name, initial_sparsity, final_sparsity)
         self.pruner = None
 
-    def prune_to_target_sparsity(self, param, param_name, zeros_mask_dict, target_sparsity, model):
-        self.pruner.prune_to_target_sparsity(param, param_name, zeros_mask_dict, target_sparsity, model)
+    def _set_param_mask_by_sparsity_target(self, param, param_name, zeros_mask_dict, target_sparsity, model):
+        self.pruner._set_param_mask_by_sparsity_target(param, param_name, zeros_mask_dict, target_sparsity, model)
 
 
 class L1RankedStructureParameterPruner_AGP(StructuredAGP):
diff --git a/distiller/pruning/ranked_structures_pruner.py b/distiller/pruning/ranked_structures_pruner.py
index 1453031..ed66d78 100755
--- a/distiller/pruning/ranked_structures_pruner.py
+++ b/distiller/pruning/ranked_structures_pruner.py
@@ -74,9 +74,9 @@ class _RankedStructureParameterPruner(_ParameterPruner):
             model = meta['model']
         except TypeError:
             model = None
-        return self.prune_to_target_sparsity(param, param_name, zeros_mask_dict, fraction_to_prune, model)
+        return self._set_param_mask_by_sparsity_target(param, param_name, zeros_mask_dict, fraction_to_prune, model)
 
-    def prune_to_target_sparsity(self, param, param_name, zeros_mask_dict, target_sparsity, model):
+    def _set_param_mask_by_sparsity_target(self, param, param_name, zeros_mask_dict, target_sparsity, model):
         if not self.is_supported(param_name):
             return
 
@@ -127,12 +127,10 @@ class LpRankedStructureParameterPruner(_RankedStructureParameterPruner):
     def prune_group(self, fraction_to_prune, param, param_name, zeros_mask_dict, model=None, binary_map=None):
         if fraction_to_prune == 0:
             return
-        if self.group_type in ['3D', 'Filters']:
+        if self.group_type in ('3D', 'Filters'):
             group_pruning_fn = partial(self.rank_and_prune_filters, noise=self.noise)
-        elif self.group_type == 'Channels':
+        elif self.group_type in ('Channels', 'Rows'):
             group_pruning_fn = partial(self.rank_and_prune_channels, noise=self.noise)
-        elif self.group_type == 'Rows':
-            group_pruning_fn = self.rank_and_prune_rows
         elif self.group_type == 'Blocks':
             group_pruning_fn = partial(self.rank_and_prune_blocks, block_shape=self.block_shape)
 
@@ -147,12 +145,6 @@ class LpRankedStructureParameterPruner(_RankedStructureParameterPruner):
                                 model=None, binary_map=None, magnitude_fn=distiller.norms.l1_norm,
                                 noise=0.0, group_size=1, rounding_fn=math.floor):
         if binary_map is None:
-
-            if param.dim() == 2:
-                # 2D Linear parameters 
-                return LpRankedStructureParameterPruner.rank_and_prune_rows(fraction_to_prune, param, param_name,
-                                                                            zeros_mask_dict, model, binary_map,
-                                                                            magnitude_fn, group_size)
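+            # distiller.norms.rank_channels now handles both 2D (GEMM) and 4D (Conv) weights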
             bottomk_channels, channel_mags = distiller.norms.rank_channels(param, group_size, magnitude_fn,
                                                                            fraction_to_prune, rounding_fn, noise)
             if bottomk_channels is None:
@@ -194,43 +186,6 @@ class LpRankedStructureParameterPruner(_RankedStructureParameterPruner):
                            fraction_to_prune)
         return binary_map
 
-    @staticmethod
-    def rank_and_prune_rows(fraction_to_prune, param, param_name,
-                            zeros_mask_dict, model=None, binary_map=None,
-                            magnitude_fn=distiller.norms.l1_norm, group_size=1):
-        """Prune the rows of a matrix, based on ranked L1-norms of the matrix rows.
-
-        PyTorch stores the weights matrices in a transposed format.  I.e. before performing GEMM, a matrix is
-        transposed.  This is because the output is computed as follows:
-            y = x(W^T) + b ; where W^T is the transpose of W
-
-        Removing input_channels from W^T, is removing rows of W^T, which is removing columns of W.
-
-        To deal with this rotation, we can either transpose the matrix and then proceed to compute the masks
-        as usual, or we can treat columns as rows, and rows as columns :-(.
-        We choose the latter, because transposing very large matrices can be detrimental to performance.  Note
-        that computing mean L1-norm of columns is also not optimal, because consecutive column elements are far
-        away from each other in memory, and this means poor use of caches and system memory.
-        """
-        if binary_map is None:
-            bottomk_cols, cols_mags = distiller.norms.rank_cols(param, group_size, magnitude_fn, fraction_to_prune,
-                                                                rounding_fn=math.floor, noise=None)
-            if bottomk_cols is None:
-                # Empty list means that fraction_to_prune is too low to prune anything
-                msglogger.info("Too few cols - can't prune %.1f%% cols", 100 * fraction_to_prune)
-                return
-            threshold = bottomk_cols[-1]
-            binary_map = cols_mags.gt(threshold).type(param.data.type())
-
-        if zeros_mask_dict is not None:
-            mask, _ = distiller.thresholding.expand_binary_map(param, 'Cols', binary_map)
-            zeros_mask_dict[param_name].mask = mask
-            msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f",
-                           magnitude_fn, param_name,
-                           distiller.sparsity(mask),
-                           fraction_to_prune)
-        return binary_map
-
     @staticmethod
     def rank_and_prune_blocks(fraction_to_prune, param, param_name=None, zeros_mask_dict=None,
                               model=None, binary_map=None, block_shape=None,
@@ -691,14 +646,9 @@ class FMReconstructionChannelPruner(_RankedStructureParameterPruner):
                                 noise=0):
         assert binary_map is None
         if binary_map is None:
-            op_type = 'conv' if param.dim() == 4 else 'fc'
-            if op_type == 'conv':
-                bottomk_channels, channel_mags = distiller.norms.rank_channels(param, group_size, magnitude_fn,
-                                                                               fraction_to_prune, rounding_fn, noise)
-            else:
-                bottomk_channels, channel_mags = distiller.norms.rank_cols(param, group_size, magnitude_fn,
-                                                                           fraction_to_prune, rounding_fn=math.floor,
-                                                                           noise=None)
+            bottomk_channels, channel_mags = distiller.norms.rank_channels(param, group_size, magnitude_fn,
+                                                                           fraction_to_prune, rounding_fn, noise)
+
             # Todo: this little piece of code can be refactored
             if bottomk_channels is None:
                 # Empty list means that fraction_to_prune is too low to prune anything
diff --git a/distiller/thresholding.py b/distiller/thresholding.py
index 2f93662..cb38909 100755
--- a/distiller/thresholding.py
+++ b/distiller/thresholding.py
@@ -162,14 +162,17 @@ def expand_binary_map(param, group_type, binary_map):
     assert binary_map is not None
 
     # Now let's expand back up to a 4D mask
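+    # For 2D (GEMM) weights, pruning input channels means pruning matrix columns: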
+    if group_type == 'Channels' and param.dim() == 2:
+        group_type = 'Cols'
+
     if group_type == '2D':
         a = binary_map.expand(param.size(2) * param.size(3),
                               param.size(0) * param.size(1)).t()
-        return a.view(param.size(0), param.size(1), param.size(2), param.size(3)), binary_map
+        return a.view(*param.shape), binary_map
     elif group_type == 'Rows':
         return binary_map.expand(param.size(1), param.size(0)).t(), binary_map
     elif group_type == 'Cols':
-        return binary_map.expand(param.size(0), param.size(1)), binary_map
+        return binary_map.expand(*param.shape), binary_map
     elif group_type == '3D' or group_type == 'Filters':
         a = binary_map.expand(np.prod(param.shape[1:]), param.size(0)).t()
         return a.view(*param.shape), binary_map
@@ -178,5 +181,5 @@ def expand_binary_map(param, group_type, binary_map):
         a = binary_map.expand(num_filters, num_channels)
         c = a.unsqueeze(-1)
         d = c.expand(num_filters, num_channels, param.size(2) * param.size(3)).contiguous()
-        return d.view(param.size(0), param.size(1), param.size(2), param.size(3)), binary_map
+        return d.view(*param.shape), binary_map
 
diff --git a/examples/agp-pruning/resnet20_filters.bn_fold.schedule_agp.yaml b/examples/agp-pruning/resnet20_filters.bn_fold.schedule_agp.yaml
new file mode 100755
index 0000000..f23ef84
--- /dev/null
+++ b/examples/agp-pruning/resnet20_filters.bn_fold.schedule_agp.yaml
@@ -0,0 +1,140 @@
+# This is the same as examples/agp-pruning/resnet20_filters.schedule_agp.yaml, but with BN-folding for ranking
+#
+# Baseline results:
+#     Top1: 91.780    Top5: 99.710    Loss: 0.376
+#     Total MACs: 40,813,184
+#     # of parameters: 270,896
+#
+# Results:
+#     Top1: 91.34
+#     Total MACs: 30,655,104
+#     Total sparsity: 46.3%
+#     # of parameters: 120,000  (=44.3% of the baseline parameters)
+#
+# time python3 compress_classifier.py --arch resnet20_cifar  $CIFAR10_PATH -p=50 --lr=0.4 --epochs=180 --compress=../agp-pruning/resnet20_filters.bn_fold.schedule_agp.yaml  --resume-from=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --vs=0 --reset-optimizer --gpu=0
+#
+#  Parameters:
+#  +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
+#  |    | Name                                | Shape          |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
+#  |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
+#  |  0 | module.conv1.weight                 | (16, 3, 3, 3)  |           432 |            432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.41630 | -0.00526 |    0.29222 |
+#  |  1 | module.layer1.0.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.15170 | -0.01458 |    0.10325 |
+#  |  2 | module.layer1.0.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.14991 |  0.00199 |    0.10401 |
+#  |  3 | module.layer1.1.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13939 | -0.01859 |    0.10473 |
+#  |  4 | module.layer1.1.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13140 | -0.00861 |    0.10033 |
+#  |  5 | module.layer1.2.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.17184 | -0.00677 |    0.11853 |
+#  |  6 | module.layer1.2.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13944 | -0.00140 |    0.09709 |
+#  |  7 | module.layer2.0.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.17399 | -0.00476 |    0.13297 |
+#  |  8 | module.layer2.0.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.18404 |  0.00603 |    0.14192 |
+#  |  9 | module.layer2.0.downsample.0.weight | (16, 16, 1, 1) |           256 |            256 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.34235 | -0.01751 |    0.25170 |
+#  | 10 | module.layer2.1.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09837 | -0.00903 |    0.07080 |
+#  | 11 | module.layer2.1.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08172 | -0.00530 |    0.05841 |
+#  | 12 | module.layer2.2.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.12200 | -0.00766 |    0.09124 |
+#  | 13 | module.layer2.2.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08811 |  0.00425 |    0.06511 |
+#  | 14 | module.layer3.0.conv1.weight        | (64, 16, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13276 | -0.00406 |    0.10314 |
+#  | 15 | module.layer3.0.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09774 | -0.00467 |    0.07597 |
+#  | 16 | module.layer3.0.downsample.0.weight | (64, 16, 1, 1) |          1024 |           1024 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.16717 | -0.00613 |    0.13268 |
+#  | 17 | module.layer3.1.conv1.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 |  7.76367 |  1.56250 |   69.99783 | 0.07646 | -0.00394 |    0.03609 |
+#  | 18 | module.layer3.1.conv2.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  1.56250 | 10.59570 |  0.00000 |   69.99783 | 0.07015 | -0.00542 |    0.03294 |
+#  | 19 | module.layer3.2.conv1.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 | 13.37891 |  3.12500 |   69.99783 | 0.07038 | -0.00402 |    0.03317 |
+#  | 20 | module.layer3.2.conv2.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  3.12500 | 30.32227 |  0.00000 |   69.99783 | 0.04269 |  0.00014 |    0.01844 |
+#  | 21 | module.fc.weight                    | (10, 64)       |           640 |            320 |    0.00000 |   50.00000 |  0.00000 |  0.00000 |  0.00000 |   50.00000 | 0.55338 | -0.00001 |    0.31654 |
+#  | 22 | Total sparsity:                     | -              |        223536 |         120000 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |   46.31737 | 0.00000 |  0.00000 |    0.00000 |
+#  +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
+#  Total sparsity: 46.32
+#
+#  --- validate (epoch=179)-----------
+#  10000 samples (256 per mini-batch)
+#  ==> Top1: 91.120    Top5: 99.660    Loss: 0.374
+#
+#  ==> Best [Top1: 91.340   Top5: 99.650   Sparsity:46.32   NNZ-Params: 120000 on epoch: 109]
+#  Saving checkpoint to: logs/2019.10.31-200150/checkpoint.pth.tar
+#  --- test ---------------------
+#  10000 samples (256 per mini-batch)
+#  ==> Top1: 91.120    Top5: 99.660    Loss: 0.364
+#
+#
+#  real    33m39.771s
+#  user    289m4.681s
+#  sys     17m43.188s
+
+version: 1
+
+pruners:
+  low_pruner:
+    class: L1RankedStructureParameterPruner_AGP
+    initial_sparsity : 0.10
+    final_sparsity: 0.50
+    group_type: Filters
+    weights: [module.layer2.0.conv1.weight, module.layer2.0.conv2.weight,
+              module.layer2.0.downsample.0.weight,
+              module.layer2.1.conv2.weight, module.layer2.2.conv2.weight,
+              module.layer2.1.conv1.weight, module.layer2.2.conv1.weight]
+
+  fine_pruner:
+    class:  AutomatedGradualPruner
+    initial_sparsity : 0.05
+    final_sparsity: 0.70
+    weights: [module.layer3.1.conv1.weight,  module.layer3.1.conv2.weight,
+              module.layer3.2.conv1.weight,  module.layer3.2.conv2.weight]
+
+  fc_pruner:
+    class: L1RankedStructureParameterPruner_AGP
+    initial_sparsity : 0.05
+    final_sparsity: 0.50
+    group_type: Rows
+    weights: [module.fc.weight]
+
+
+lr_schedulers:
+  pruning_lr:
+    class: StepLR
+    step_size: 50
+    gamma: 0.10
+
+extensions:
+  net_thinner:
+      class: 'FilterRemover'
+      thinning_func_str: remove_filters
+      arch: 'resnet20_cifar'
+      dataset: 'cifar10'
+
+
+policies:
+  - pruner:
+      instance_name : low_pruner
+      args:
+        fold_batchnorm: True
+    starting_epoch: 0
+    ending_epoch: 30
+    frequency: 2
+
+# After completing the pruning, we perform network thinning and continue fine-tuning.
+# When there is ambiguity in the scheduling order of policies, Distiller follows the
+# order of declaration.  Because epoch 30 is the end of one pruner, and the beginning
+# of two others, and because we want the thinning to happen at the beginning of
+# epoch 30, it is important to declare the thinning policy here and not lower in the
+# file.
+  - extension:
+      instance_name: net_thinner
+    epochs: [30]
+
+  - pruner:
+      instance_name : fine_pruner
+      args:
+        fold_batchnorm: True
+    starting_epoch: 30
+    ending_epoch: 50
+    frequency: 2
+
+  - pruner:
+      instance_name : fc_pruner
+    starting_epoch: 30
+    ending_epoch: 50
+    frequency: 2
+
+  - lr_scheduler:
+      instance_name: pruning_lr
+    starting_epoch: 0
+    ending_epoch: 400
+    frequency: 1
diff --git a/examples/agp-pruning/resnet20_filters.schedule_agp.yaml b/examples/agp-pruning/resnet20_filters.schedule_agp.yaml
index 362e52c..eabf8c6 100755
--- a/examples/agp-pruning/resnet20_filters.schedule_agp.yaml
+++ b/examples/agp-pruning/resnet20_filters.schedule_agp.yaml
@@ -9,60 +9,57 @@
 #     # of parameters: 270,896
 #
 # Results:
-#     Top1: 91.73
+#     Top1: 91.34
 #     Total MACs: 30,655,104
-#     Total sparsity: 41.10
+#     Total sparsity: 46.3%
 #     # of parameters: 120,000  (=44.3% of the baseline parameters)
 #
-# time python3 compress_classifier.py --arch resnet20_cifar  ../../../data.cifar10 -p=50 --lr=0.3 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp.yaml  --resume-from=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --vs=0 --reset-optimizer
+# time python3 compress_classifier.py --arch resnet20_cifar  $CIFAR10_PATH -p=50 --lr=0.4 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp.yaml  --resume-from=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --vs=0 --reset-optimizer --gpu=0
 #
-# Parameters:
-# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
-# |    | Name                                | Shape          |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
-# |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
-# |  0 | module.conv1.weight                 | (16, 3, 3, 3)  |           432 |            432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.42267 | -0.01028 |    0.29903 |
-# |  1 | module.layer1.0.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.15895 | -0.01265 |    0.11210 |
-# |  2 | module.layer1.0.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.15610 |  0.00257 |    0.11472 |
-# |  3 | module.layer1.1.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13888 | -0.01590 |    0.10543 |
-# |  4 | module.layer1.1.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13052 | -0.00519 |    0.10135 |
-# |  5 | module.layer1.2.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.18351 | -0.01298 |    0.13564 |
-# |  6 | module.layer1.2.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.14909 | -0.00098 |    0.11435 |
-# |  7 | module.layer2.0.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.17438 | -0.00580 |    0.13427 |
-# |  8 | module.layer2.0.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.18654 | -0.00126 |    0.14499 |
-# |  9 | module.layer2.0.downsample.0.weight | (16, 16, 1, 1) |           256 |            256 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.34412 | -0.01243 |    0.24940 |
-# | 10 | module.layer2.1.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11833 | -0.00937 |    0.08865 |
-# | 11 | module.layer2.1.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09171 | -0.00197 |    0.06956 |
-# | 12 | module.layer2.2.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13403 | -0.01057 |    0.09999 |
-# | 13 | module.layer2.2.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09652 |  0.00544 |    0.07033 |
-# | 14 | module.layer3.0.conv1.weight        | (64, 16, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13635 | -0.00543 |    0.10654 |
-# | 15 | module.layer3.0.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09992 | -0.00600 |    0.07893 |
-# | 16 | module.layer3.0.downsample.0.weight | (64, 16, 1, 1) |          1024 |           1024 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.17133 | -0.00926 |    0.13503 |
-# | 17 | module.layer3.1.conv1.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 |  8.47168 |  1.56250 |   69.99783 | 0.07819 | -0.00423 |    0.03752 |
-# | 18 | module.layer3.1.conv2.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  1.56250 |  8.37402 |  0.00000 |   69.99783 | 0.07238 | -0.00539 |    0.03450 |
-# | 19 | module.layer3.2.conv1.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 | 11.93848 |  3.12500 |   69.99783 | 0.07195 | -0.00571 |    0.03462 |
-# | 20 | module.layer3.2.conv2.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  3.12500 | 28.75977 |  1.56250 |   69.99783 | 0.04405 |  0.00060 |    0.02004 |
-# | 21 | module.fc.weight                    | (10, 64)       |           640 |            320 |    0.00000 |   50.00000 |  0.00000 |  0.00000 |  0.00000 |   50.00000 | 0.57112 | -0.00001 |    0.36129 |
-# | 22 | Total sparsity:                     | -              |        223536 |         120000 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |   46.31737 | 0.00000 |  0.00000 |    0.00000 |
-# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
-# Total sparsity: 46.32
+#  Parameters:
+#  +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
+#  |    | Name                                | Shape          |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
+#  |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
+#  |  0 | module.conv1.weight                 | (16, 3, 3, 3)  |           432 |            432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.41726 | -0.00601 |    0.29649 |
+#  |  1 | module.layer1.0.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.15451 | -0.01086 |    0.10477 |
+#  |  2 | module.layer1.0.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.15074 | -0.00062 |    0.10780 |
+#  |  3 | module.layer1.1.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13604 | -0.01868 |    0.10372 |
+#  |  4 | module.layer1.1.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.12923 | -0.00463 |    0.09968 |
+#  |  5 | module.layer1.2.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.17818 | -0.01222 |    0.13184 |
+#  |  6 | module.layer1.2.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.14475 | -0.00069 |    0.11089 |
+#  |  7 | module.layer2.0.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.16991 |  0.00091 |    0.12894 |
+#  |  8 | module.layer2.0.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.18059 |  0.00199 |    0.14176 |
+#  |  9 | module.layer2.0.downsample.0.weight | (16, 16, 1, 1) |           256 |            256 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.34145 | -0.03631 |    0.25094 |
+#  | 10 | module.layer2.1.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13213 | -0.00809 |    0.10198 |
+#  | 11 | module.layer2.1.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10230 |  0.00805 |    0.07883 |
+#  | 12 | module.layer2.2.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11462 | -0.00682 |    0.08532 |
+#  | 13 | module.layer2.2.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08012 |  0.00611 |    0.05776 |
+#  | 14 | module.layer3.0.conv1.weight        | (64, 16, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13316 | -0.00256 |    0.10497 |
+#  | 15 | module.layer3.0.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09755 | -0.00598 |    0.07722 |
+#  | 16 | module.layer3.0.downsample.0.weight | (64, 16, 1, 1) |          1024 |           1024 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.16702 |  0.00251 |    0.12968 |
+#  | 17 | module.layer3.1.conv1.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 | 10.25391 |  4.68750 |   69.99783 | 0.07554 | -0.00373 |    0.03568 |
+#  | 18 | module.layer3.1.conv2.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  4.68750 | 11.49902 |  0.00000 |   69.99783 | 0.06968 | -0.00573 |    0.03275 |
+#  | 19 | module.layer3.2.conv1.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 | 15.57617 |  4.68750 |   69.99783 | 0.06895 | -0.00504 |    0.03245 |
+#  | 20 | module.layer3.2.conv2.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  4.68750 | 32.08008 |  0.00000 |   69.99783 | 0.04180 |  0.00053 |    0.01793 |
+#  | 21 | module.fc.weight                    | (10, 64)       |           640 |            320 |    0.00000 |   50.00000 |  0.00000 |  0.00000 |  0.00000 |   50.00000 | 0.55055 | -0.00001 |    0.31038 |
+#  | 22 | Total sparsity:                     | -              |        223536 |         120000 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |   46.31737 | 0.00000 |  0.00000 |    0.00000 |
+#  +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
+#  Total sparsity: 46.32
 #
-# --- validate (epoch=359)-----------
-# 10000 samples (256 per mini-batch)
-# ==> Top1: 91.490    Top5: 99.710    Loss: 0.346
+#  --- validate (epoch=179)-----------
+#  10000 samples (256 per mini-batch)
+#  ==> Top1: 91.110    Top5: 99.700    Loss: 0.379
 #
-# ==> Best Top1: 91.730   On Epoch: 344
+#  ==> Best [Top1: 91.340   Top5: 99.670   Sparsity:46.32   NNZ-Params: 120000 on epoch: 127]
+#  Saving checkpoint to: logs/2019.10.31-204215/checkpoint.pth.tar
+#  --- test ---------------------
+#  10000 samples (256 per mini-batch)
+#  ==> Top1: 91.110    Top5: 99.700    Loss: 0.363
 #
-# Saving checkpoint to: logs/2018.10.30-150931/checkpoint.pth.tar
-# --- test ---------------------
-# 10000 samples (256 per mini-batch)
-# ==> Top1: 91.490    Top5: 99.710    Loss: 0.346
 #
-#
-# Log file for this run: /home/cvds_lab/nzmora/sandbox_5/distiller/examples/classifier_compression/logs/2018.10.30-150931/2018.10.30-150931.log
-#
-# real    36m36.329s
-# user    82m32.685s
-# sys     10m8.746s
+#  real    31m38.549s
+#  user    293m36.410s
+#  sys     18m6.041s
 
 version: 1
 
@@ -113,6 +110,16 @@ policies:
     ending_epoch: 30
     frequency: 2
 
+# After completing the pruning, we perform network thinning and continue fine-tuning.
+# When there is ambiguity in the scheduling order of policies, Distiller follows the
+# order of declaration.  Because epoch 30 is the end of one pruner, and the beginning
+# of two others, and because we want the thinning to happen at the beginning of
+# epoch 30, it is important to declare the thinning policy here and not lower in the
+# file.
+  - extension:
+      instance_name: net_thinner
+    epochs: [30]
+
   - pruner:
       instance_name : fine_pruner
     starting_epoch: 30
@@ -125,11 +132,6 @@ policies:
     ending_epoch: 50
     frequency: 2
 
-# After completeing the pruning, we perform network thinning and continue fine-tuning.
-  - extension:
-      instance_name: net_thinner
-    epochs: [32]
-
   - lr_scheduler:
       instance_name: pruning_lr
     starting_epoch: 0
diff --git a/examples/agp-pruning/resnet20_filters.schedule_agp_2.yaml b/examples/agp-pruning/resnet20_filters.schedule_agp_2.yaml
index 1ea0661..cc07e90 100755
--- a/examples/agp-pruning/resnet20_filters.schedule_agp_2.yaml
+++ b/examples/agp-pruning/resnet20_filters.schedule_agp_2.yaml
@@ -9,60 +9,59 @@
 #     # of parameters: 270,896
 #
 # Results:
-#     Top1: 91.200    Top5: 99.660    Loss: 1.551
+#     Top1: 91.630   Top5: 99.670
 #     Total MACs: 30,638,720
 #     Total sparsity: 41.84
 #     # of parameters: 143,488 (=53% of the baseline parameters)
 #
-# time python3 compress_classifier.py --arch resnet20_cifar  ../../../data.cifar10 -p=50 --lr=0.1 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp_2.yaml -j=1 --deterministic --resume-from=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --reset-optimizer
+# time python3 compress_classifier.py --arch resnet20_cifar  ../../../data.cifar10 -p=50 --lr=0.1 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp_2.yaml -j=1 --deterministic --resume-from=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --reset-optimizer --gpus=0 --vs=0
 #
-# Parameters:
-# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
-# |    | Name                                | Shape          |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
-# |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
-# |  0 | module.conv1.weight                 | (16, 3, 3, 3)  |           432 |            432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.27315 | -0.00387 |    0.19394 |
-# |  1 | module.layer1.0.conv1.weight        | (10, 16, 3, 3) |          1440 |           1440 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.12038 | -0.01295 |    0.08811 |
-# |  2 | module.layer1.0.conv2.weight        | (16, 10, 3, 3) |          1440 |           1440 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11879 | -0.00031 |    0.08735 |
-# |  3 | module.layer1.1.conv1.weight        | (10, 16, 3, 3) |          1440 |           1440 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10293 | -0.01274 |    0.07795 |
-# |  4 | module.layer1.1.conv2.weight        | (16, 10, 3, 3) |          1440 |           1440 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09285 | -0.00276 |    0.07141 |
-# |  5 | module.layer1.2.conv1.weight        | (10, 16, 3, 3) |          1440 |           1440 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.12849 | -0.00345 |    0.09355 |
-# |  6 | module.layer1.2.conv2.weight        | (16, 10, 3, 3) |          1440 |           1440 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10689 | -0.00381 |    0.08038 |
-# |  7 | module.layer2.0.conv1.weight        | (20, 16, 3, 3) |          2880 |           2880 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10467 | -0.00371 |    0.08149 |
-# |  8 | module.layer2.0.conv2.weight        | (32, 20, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08897 | -0.00502 |    0.06938 |
-# |  9 | module.layer2.0.downsample.0.weight | (32, 16, 1, 1) |           512 |            512 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.17695 | -0.01111 |    0.12479 |
-# | 10 | module.layer2.1.conv1.weight        | (20, 32, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07736 | -0.00531 |    0.06118 |
-# | 11 | module.layer2.1.conv2.weight        | (32, 20, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06832 | -0.00404 |    0.05406 |
-# | 12 | module.layer2.2.conv1.weight        | (20, 32, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07965 | -0.00904 |    0.06278 |
-# | 13 | module.layer2.2.conv2.weight        | (32, 20, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06305 |  0.00122 |    0.04955 |
-# | 14 | module.layer3.0.conv1.weight        | (64, 32, 3, 3) |         18432 |          18432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06753 | -0.00459 |    0.05371 |
-# | 15 | module.layer3.0.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06379 | -0.00297 |    0.05078 |
-# | 16 | module.layer3.0.downsample.0.weight | (64, 32, 1, 1) |          2048 |           2048 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08779 | -0.00584 |    0.06956 |
-# | 17 | module.layer3.1.conv1.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 |  7.12891 |  0.00000 |   69.99783 | 0.05191 | -0.00319 |    0.02604 |
-# | 18 | module.layer3.1.conv2.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 |  7.98340 |  0.00000 |   69.99783 | 0.04658 | -0.00360 |    0.02330 |
-# | 19 | module.layer3.2.conv1.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 | 10.10742 |  0.00000 |   69.99783 | 0.04563 | -0.00393 |    0.02297 |
-# | 20 | module.layer3.2.conv2.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 | 31.20117 |  0.00000 |   69.99783 | 0.02453 |  0.00005 |    0.01240 |
-# | 21 | module.fc.weight                    | (10, 64)       |           640 |            640 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.50451 | -0.00001 |    0.43698 |
-# | 22 | Total sparsity:                     | -              |        246704 |         143488 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |   41.83799 | 0.00000 |  0.00000 |    0.00000 |
-# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
-# Total sparsity: 41.84
+#  Parameters:
+#  +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
+#  |    | Name                                | Shape          |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
+#  |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
+#  |  0 | module.conv1.weight                 | (16, 3, 3, 3)  |           432 |            432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.38465 | -0.00533 |    0.27349 |
+#  |  1 | module.layer1.0.conv1.weight        | (10, 16, 3, 3) |          1440 |           1440 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.17334 | -0.01720 |    0.12535 |
+#  |  2 | module.layer1.0.conv2.weight        | (16, 10, 3, 3) |          1440 |           1440 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.17280 |  0.00148 |    0.12660 |
+#  |  3 | module.layer1.1.conv1.weight        | (10, 16, 3, 3) |          1440 |           1440 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.14518 | -0.02108 |    0.11044 |
+#  |  4 | module.layer1.1.conv2.weight        | (16, 10, 3, 3) |          1440 |           1440 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13157 | -0.00240 |    0.09998 |
+#  |  5 | module.layer1.2.conv1.weight        | (10, 16, 3, 3) |          1440 |           1440 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.18724 | -0.00470 |    0.13594 |
+#  |  6 | module.layer1.2.conv2.weight        | (16, 10, 3, 3) |          1440 |           1440 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.15303 | -0.00564 |    0.11591 |
+#  |  7 | module.layer2.0.conv1.weight        | (20, 16, 3, 3) |          2880 |           2880 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.15062 | -0.00379 |    0.11690 |
+#  |  8 | module.layer2.0.conv2.weight        | (32, 20, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.12943 | -0.00739 |    0.10150 |
+#  |  9 | module.layer2.0.downsample.0.weight | (32, 16, 1, 1) |           512 |            512 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.25227 | -0.01490 |    0.17715 |
+#  | 10 | module.layer2.1.conv1.weight        | (20, 32, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11074 | -0.00783 |    0.08721 |
+#  | 11 | module.layer2.1.conv2.weight        | (32, 20, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09753 | -0.00582 |    0.07681 |
+#  | 12 | module.layer2.2.conv1.weight        | (20, 32, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11501 | -0.01363 |    0.09121 |
+#  | 13 | module.layer2.2.conv2.weight        | (32, 20, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09145 |  0.00280 |    0.07167 |
+#  | 14 | module.layer3.0.conv1.weight        | (64, 32, 3, 3) |         18432 |          18432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09772 | -0.00674 |    0.07769 |
+#  | 15 | module.layer3.0.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09316 | -0.00339 |    0.07396 |
+#  | 16 | module.layer3.0.downsample.0.weight | (64, 32, 1, 1) |          2048 |           2048 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.12438 | -0.00958 |    0.09868 |
+#  | 17 | module.layer3.1.conv1.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 |  6.71387 |  0.00000 |   69.99783 | 0.07404 | -0.00405 |    0.03694 |
+#  | 18 | module.layer3.1.conv2.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 |  7.25098 |  0.00000 |   69.99783 | 0.06739 | -0.00494 |    0.03356 |
+#  | 19 | module.layer3.2.conv1.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 | 10.37598 |  0.00000 |   69.99783 | 0.06739 | -0.00414 |    0.03368 |
+#  | 20 | module.layer3.2.conv2.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 | 28.49121 |  0.00000 |   69.99783 | 0.03788 |  0.00048 |    0.01900 |
+#  | 21 | module.fc.weight                    | (10, 64)       |           640 |            640 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.54585 | -0.00002 |    0.46076 |
+#  | 22 | Total sparsity:                     | -              |        246704 |         143488 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |   41.83799 | 0.00000 |  0.00000 |    0.00000 |
+#  +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
+#  Total sparsity: 41.84
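+#
+#  (Note: "Fine (%)" is 100*(1 - NNZ_sparse/NNZ_dense) per tensor; e.g. row 17:
+#   100*(1 - 11060/36864) ~= 69.998. The "Total sparsity" row applies the same
+#   formula to the column totals: 100*(1 - 143488/246704) ~= 41.84.)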
 #
-# --- validate (epoch=359)-----------
-# 5000 samples (256 per mini-batch)
-# ==> Top1: 93.460    Top5: 99.800    Loss: 1.530
+#  --- validate (epoch=179)-----------
+#  10000 samples (256 per mini-batch)
+#  ==> Top1: 91.430    Top5: 99.640    Loss: 0.365
 #
-# ==> Best Top1: 97.320   On Epoch: 180
+#  ==> Best [Top1: 91.630   Top5: 99.670   Sparsity:41.84   NNZ-Params: 143488 on epoch: 74]
+#  Saving checkpoint to: logs/2019.10.31-235045/checkpoint.pth.tar
+#  --- test ---------------------
+#  10000 samples (256 per mini-batch)
+#  ==> Top1: 91.430    Top5: 99.640    Loss: 0.379
 #
-# Saving checkpoint to: logs/2018.10.15-115941/checkpoint.pth.tar
-# --- test ---------------------
-# 10000 samples (256 per mini-batch)
-# ==> Top1: 91.200    Top5: 99.660    Loss: 1.551
 #
+#  Log file for this run: /home/cvds_lab/nzmora/pytorch_workspace/distiller_remote/examples/classifier_compression/logs/2019.10.31-235045/2019.10.31-235045.log
 #
-# Log file for this run: /home/cvds_lab/nzmora/pytorch_workspace/distiller/examples/classifier_compression/logs/2018.10.15-115941/2018.10.15-115941.log
-#
-# real    32m31.997s
-# user    72m58.813s
-# sys     9m1.245s
+#  real    52m57.688s
+#  user    304m20.353s
+#  sys     10m56.498s
 
 version: 1
 pruners:
@@ -112,7 +111,7 @@ policies:
     ending_epoch: 40
     frequency: 2
 
-# After completeing the pruning, we perform network thinning and continue fine-tuning.
+# After completing the pruning, we perform network thinning and continue fine-tuning.
   - extension:
       instance_name: net_thinner
     epochs: [22]
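 #
 # (Context: "net_thinner" is an extension instance declared in this schedule's
 #  extensions: section, which this hunk does not show. In distiller's resnet20
 #  examples that declaration typically looks like the following sketch --
 #  illustrative, not necessarily this file's exact definition:
 #
 #    extensions:
 #      net_thinner:
 #        class: 'FilterRemover'
 #        thinning_func_str: remove_filters
 #        arch: 'resnet20_cifar'
 #        dataset: 'cifar10'
 #
 #  Scheduling it once, at epochs: [22], physically removes the filters that the
 #  pruner zeroed, so the shapes and MAC counts reported above are those of the
 #  thinned network.)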
diff --git a/examples/agp-pruning/resnet20_filters.schedule_agp_3.yaml b/examples/agp-pruning/resnet20_filters.schedule_agp_3.yaml
index f958ec1..2f965af 100755
--- a/examples/agp-pruning/resnet20_filters.schedule_agp_3.yaml
+++ b/examples/agp-pruning/resnet20_filters.schedule_agp_3.yaml
@@ -9,59 +9,57 @@
 #     # of parameters: 270,896
 #
 # Results:
-#     Top1: 91.470 on Epoch: 288
+#     Top1: 91.29
 #     Total MACs: 30,433,920 (74.6% of the original compute)
 #     Total sparsity: 56.62%
 #     # of parameters: 96960  (=35.8% of the baseline parameters ==> 64.2% sparsity)
 #
 # time python3 compress_classifier.py --arch resnet20_cifar  ../../../data.cifar10 -p=50 --lr=0.4 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp_3.yaml  --resume-from=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --reset-optimizer --vs=0
 #
-# Parameters:
-# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
-# |    | Name                                | Shape          |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
-# |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
-# |  0 | module.conv1.weight                 | (16, 3, 3, 3)  |           432 |            432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.41372 | -0.00535 |    0.29289 |
-# |  1 | module.layer1.0.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.15610 | -0.01373 |    0.11096 |
-# |  2 | module.layer1.0.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.15429 |  0.00180 |    0.11294 |
-# |  3 | module.layer1.1.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13297 | -0.01580 |    0.10052 |
-# |  4 | module.layer1.1.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.12638 | -0.00556 |    0.09699 |
-# |  5 | module.layer1.2.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.17940 | -0.01313 |    0.13183 |
-# |  6 | module.layer1.2.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.14671 | -0.00056 |    0.11065 |
-# |  7 | module.layer2.0.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.16872 | -0.00380 |    0.12838 |
-# |  8 | module.layer2.0.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.18371 |  0.00119 |    0.14401 |
-# |  9 | module.layer2.0.downsample.0.weight | (16, 16, 1, 1) |           256 |            256 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.33976 |  0.00148 |    0.24721 |
-# | 10 | module.layer2.1.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.12741 | -0.00734 |    0.09754 |
-# | 11 | module.layer2.1.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10207 |  0.00286 |    0.07914 |
-# | 12 | module.layer2.2.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13480 | -0.00943 |    0.10174 |
-# | 13 | module.layer2.2.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09721 |  0.00049 |    0.07094 |
-# | 14 | module.layer3.0.conv1.weight        | (64, 16, 3, 3) |          9216 |           4608 |    0.00000 |    0.00000 |  0.00000 |  2.63672 |  1.56250 |   50.00000 | 0.11758 | -0.00484 |    0.07093 |
-# | 15 | module.layer3.0.conv2.weight        | (64, 64, 3, 3) |         36864 |          18432 |    0.00000 |    0.00000 |  1.56250 |  2.00195 |  0.00000 |   50.00000 | 0.08720 | -0.00522 |    0.05143 |
-# | 16 | module.layer3.0.downsample.0.weight | (64, 16, 1, 1) |          1024 |           1024 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.16003 | -0.01049 |    0.12534 |
-# | 17 | module.layer3.1.conv1.weight        | (63, 64, 3, 3) |         36288 |          10887 |    0.00000 |    0.00000 |  0.00000 |  9.20139 |  1.58730 |   69.99835 | 0.07613 | -0.00415 |    0.03605 |
-# | 18 | module.layer3.1.conv2.weight        | (64, 63, 3, 3) |         36288 |          10887 |    0.00000 |    0.00000 |  1.58730 |  9.10218 |  0.00000 |   69.99835 | 0.07025 | -0.00544 |    0.03305 |
-# | 19 | module.layer3.2.conv1.weight        | (62, 64, 3, 3) |         35712 |          10714 |    0.00000 |    0.00000 |  0.00000 | 13.33165 |  3.22581 |   69.99888 | 0.07118 | -0.00550 |    0.03367 |
-# | 20 | module.layer3.2.conv2.weight        | (64, 62, 3, 3) |         35712 |          10714 |    0.00000 |    0.00000 |  3.22581 | 28.80544 |  0.00000 |   69.99888 | 0.04353 |  0.00071 |    0.01894 |
-# | 21 | module.fc.weight                    | (10, 64)       |           640 |            320 |    0.00000 |   50.00000 |  0.00000 |  0.00000 |  0.00000 |   50.00000 | 0.57334 | -0.00001 |    0.35840 |
-# | 22 | Total sparsity:                     | -              |        220080 |          95922 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |   56.41494 | 0.00000 |  0.00000 |    0.00000 |
-# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
-# Total sparsity: 56.41
+#  Parameters:
+#  +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
+#  |    | Name                                | Shape          |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
+#  |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
+#  |  0 | module.conv1.weight                 | (16, 3, 3, 3)  |           432 |            432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.41066 | -0.01044 |    0.28684 |
+#  |  1 | module.layer1.0.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.15311 | -0.01381 |    0.10292 |
+#  |  2 | module.layer1.0.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.15199 |  0.00052 |    0.10768 |
+#  |  3 | module.layer1.1.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13703 | -0.01426 |    0.10459 |
+#  |  4 | module.layer1.1.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.12956 | -0.00363 |    0.10109 |
+#  |  5 | module.layer1.2.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.18044 | -0.01347 |    0.13435 |
+#  |  6 | module.layer1.2.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.14809 | -0.00015 |    0.11282 |
+#  |  7 | module.layer2.0.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.16746 | -0.00541 |    0.13009 |
+#  |  8 | module.layer2.0.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.18162 | -0.00449 |    0.14069 |
+#  |  9 | module.layer2.0.downsample.0.weight | (16, 16, 1, 1) |           256 |            256 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.34670 |  0.00145 |    0.23974 |
+#  | 10 | module.layer2.1.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11932 | -0.00470 |    0.08793 |
+#  | 11 | module.layer2.1.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09387 |  0.00694 |    0.06755 |
+#  | 12 | module.layer2.2.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11968 | -0.00680 |    0.08725 |
+#  | 13 | module.layer2.2.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08844 |  0.00324 |    0.06366 |
+#  | 14 | module.layer3.0.conv1.weight        | (64, 16, 3, 3) |          9216 |           4608 |    0.00000 |    0.00000 |  0.00000 |  1.26953 |  0.00000 |   50.00000 | 0.11985 | -0.00543 |    0.07131 |
+#  | 15 | module.layer3.0.conv2.weight        | (64, 64, 3, 3) |         36864 |          18432 |    0.00000 |    0.00000 |  0.00000 |  0.48828 |  0.00000 |   50.00000 | 0.08819 | -0.00512 |    0.05183 |
+#  | 16 | module.layer3.0.downsample.0.weight | (64, 16, 1, 1) |          1024 |           1024 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.15923 | -0.01150 |    0.12339 |
+#  | 17 | module.layer3.1.conv1.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 |  7.03125 |  0.00000 |   69.99783 | 0.07690 | -0.00368 |    0.03622 |
+#  | 18 | module.layer3.1.conv2.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 |  8.78906 |  0.00000 |   69.99783 | 0.07113 | -0.00574 |    0.03346 |
+#  | 19 | module.layer3.2.conv1.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 | 18.65234 |  9.37500 |   69.99783 | 0.06893 | -0.00443 |    0.03236 |
+#  | 20 | module.layer3.2.conv2.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  9.37500 | 33.39844 |  0.00000 |   69.99783 | 0.04186 |  0.00078 |    0.01802 |
+#  | 21 | module.fc.weight                    | (10, 64)       |           640 |            320 |    0.00000 |   50.00000 |  0.00000 |  0.00000 |  0.00000 |   50.00000 | 0.54733 | -0.00001 |    0.30963 |
+#  | 22 | Total sparsity:                     | -              |        223536 |          96960 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |   56.62444 | 0.00000 |  0.00000 |    0.00000 |
+#  +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
+#  Total sparsity: 56.62
 #
-# --- validate (epoch=359)-----------
-# 10000 samples (256 per mini-batch)
-# ==> Top1: 91.140    Top5: 99.750    Loss: 0.331
+#  --- validate (epoch=179)-----------
+#  10000 samples (256 per mini-batch)
+#  ==> Top1: 91.290    Top5: 99.760    Loss: 0.344
 #
-# ==> Best Top1: 91.470 on Epoch: 288
-# Saving checkpoint to: logs/2018.11.08-232134/checkpoint.pth.tar
-# --- test ---------------------
-# 10000 samples (256 per mini-batch)
-# ==> Top1: 91.140    Top5: 99.750    Loss: 0.331
+#  ==> Best [Top1: 91.290   Top5: 99.760   Sparsity:56.62   NNZ-Params: 96960 on epoch: 179]
+#  Saving checkpoint to: logs/2019.11.01-005544/checkpoint.pth.tar
+#  --- test ---------------------
+#  10000 samples (256 per mini-batch)
+#  ==> Top1: 91.290    Top5: 99.760    Loss: 0.344
 #
 #
-# Log file for this run: /home/cvds_lab/nzmora/sandbox_5/distiller/examples/classifier_compression/logs/2018.11.08-232134/2018.11.08-232134.log
-#
-# real    37m51.274s
-# user    85m48.506s
-# sys     10m35.410s
+#  real    31m53.439s
+#  user    264m58.198s
+#  sys     16m57.435s
 
 version: 1
 
@@ -147,7 +145,7 @@ policies:
   #   ending_epoch: 20
   #   frequency: 2
 
-# After completeing the pruning, we perform network thinning and continue fine-tuning.
+# After completing the pruning, we perform network thinning and continue fine-tuning.
   - extension:
       instance_name: net_thinner
     epochs: [32]
diff --git a/examples/agp-pruning/resnet20_filters.schedule_agp_4.yaml b/examples/agp-pruning/resnet20_filters.schedule_agp_4.yaml
index 147d7b8..954752c 100755
--- a/examples/agp-pruning/resnet20_filters.schedule_agp_4.yaml
+++ b/examples/agp-pruning/resnet20_filters.schedule_agp_4.yaml
@@ -9,52 +9,57 @@
 #     # of parameters: 270,896
 #
 # Results:
-#     Top1: 90.86
+#     Top1: 90.89
 #     Total MACs: 24,723,776 (=1.65x MACs)
 #     Total sparsity: 39.66
 #     # of parameters: 78,776  (=29.1% of the baseline parameters)
 #
-# time python3 compress_classifier.py --arch resnet20_cifar  ../../../data.cifar10 -p=50 --lr=0.3 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp_4.yaml --resume-from=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --reset-optimizer --vs=0
+# time python3 compress_classifier.py --arch resnet20_cifar  ../../../data.cifar10 -p=50 --lr=0.3 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp_4.yaml --resume-from=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --reset-optimizer --vs=0 --gpu=0
 #
-# Parameters:
-# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
-# |    | Name                                | Shape          |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
-# |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
-# |  0 | module.conv1.weight                 | (16, 3, 3, 3)  |           432 |            432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.43875 | -0.01073 |    0.30863 |
-# |  1 | module.layer1.0.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.16813 | -0.01061 |    0.11996 |
-# |  2 | module.layer1.0.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.16491 |  0.00070 |    0.12291 |
-# |  3 | module.layer1.1.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.14755 | -0.01732 |    0.11317 |
-# |  4 | module.layer1.1.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13875 | -0.00517 |    0.10758 |
-# |  5 | module.layer1.2.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.18923 | -0.01326 |    0.14210 |
-# |  6 | module.layer1.2.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.15583 | -0.00106 |    0.11905 |
-# |  7 | module.layer2.0.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.18225 | -0.00390 |    0.14086 |
-# |  8 | module.layer2.0.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.19739 | -0.00785 |    0.15657 |
-# |  9 | module.layer2.0.downsample.0.weight | (16, 16, 1, 1) |           256 |            256 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.34375 | -0.01899 |    0.24210 |
-# | 10 | module.layer2.1.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13369 | -0.00798 |    0.10439 |
-# | 11 | module.layer2.1.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10876 | -0.00311 |    0.08349 |
-# | 12 | module.layer2.2.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13521 | -0.01331 |    0.09820 |
-# | 13 | module.layer2.2.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09794 |  0.00274 |    0.07076 |
-# | 14 | module.layer3.0.conv1.weight        | (64, 16, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.14004 | -0.00417 |    0.11131 |
-# | 15 | module.layer3.0.conv2.weight        | (32, 64, 3, 3) |         18432 |          18432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.12559 | -0.00528 |    0.09919 |
-# | 16 | module.layer3.0.downsample.0.weight | (32, 16, 1, 1) |           512 |            512 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.20064 |  0.00941 |    0.15694 |
-# | 17 | module.layer3.1.conv1.weight        | (64, 32, 3, 3) |         18432 |           5530 |    0.00000 |    0.00000 |  0.00000 |  7.76367 |  1.56250 |   69.99783 | 0.10111 | -0.00460 |    0.04898 |
-# | 18 | module.layer3.1.conv2.weight        | (32, 64, 3, 3) |         18432 |           5530 |    0.00000 |    0.00000 |  1.56250 |  8.83789 |  0.00000 |   69.99783 | 0.09000 | -0.00652 |    0.04327 |
-# | 19 | module.layer3.2.conv1.weight        | (64, 32, 3, 3) |         18432 |           5530 |    0.00000 |    0.00000 |  0.00000 |  9.37500 |  0.00000 |   69.99783 | 0.09371 | -0.00490 |    0.04521 |
-# | 20 | module.layer3.2.conv2.weight        | (32, 64, 3, 3) |         18432 |           5530 |    0.00000 |    0.00000 |  0.00000 | 24.65820 |  0.00000 |   69.99783 | 0.06228 |  0.00012 |    0.02827 |
-# | 21 | module.fc.weight                    | (10, 32)       |           320 |            160 |    0.00000 |   50.00000 |  0.00000 |  0.00000 |  0.00000 |   50.00000 | 0.78238 | -0.00000 |    0.48823 |
-# | 22 | Total sparsity:                     | -              |        130544 |          78776 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |   39.65560 | 0.00000 |  0.00000 |    0.00000 |
-# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
-# Total sparsity: 39.66
+#  Parameters:
+#  +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
+#  |    | Name                                | Shape          |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
+#  |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
+#  |  0 | module.conv1.weight                 | (16, 3, 3, 3)  |           432 |            432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.43641 | -0.01077 |    0.30848 |
+#  |  1 | module.layer1.0.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.16754 | -0.01429 |    0.11853 |
+#  |  2 | module.layer1.0.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.16230 |  0.00232 |    0.12090 |
+#  |  3 | module.layer1.1.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.14305 | -0.01696 |    0.10940 |
+#  |  4 | module.layer1.1.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13429 | -0.00717 |    0.10360 |
+#  |  5 | module.layer1.2.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.19055 | -0.01061 |    0.14304 |
+#  |  6 | module.layer1.2.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.15439 |  0.00421 |    0.11755 |
+#  |  7 | module.layer2.0.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.18164 | -0.00175 |    0.14038 |
+#  |  8 | module.layer2.0.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.20045 | -0.01430 |    0.15603 |
+#  |  9 | module.layer2.0.downsample.0.weight | (16, 16, 1, 1) |           256 |            256 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.33919 | -0.02323 |    0.25281 |
+#  | 10 | module.layer2.1.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13992 | -0.00223 |    0.10972 |
+#  | 11 | module.layer2.1.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11157 |  0.00652 |    0.08719 |
+#  | 12 | module.layer2.2.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.14406 | -0.00417 |    0.10841 |
+#  | 13 | module.layer2.2.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10744 |  0.00565 |    0.08125 |
+#  | 14 | module.layer3.0.conv1.weight        | (64, 16, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13871 | -0.00321 |    0.10984 |
+#  | 15 | module.layer3.0.conv2.weight        | (32, 64, 3, 3) |         18432 |          18432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.12541 | -0.00566 |    0.09928 |
+#  | 16 | module.layer3.0.downsample.0.weight | (32, 16, 1, 1) |           512 |            512 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.19916 | -0.00610 |    0.15371 |
+#  | 17 | module.layer3.1.conv1.weight        | (64, 32, 3, 3) |         18432 |           5530 |    0.00000 |    0.00000 |  0.00000 |  7.61719 |  1.56250 |   69.99783 | 0.10042 | -0.00397 |    0.04862 |
+#  | 18 | module.layer3.1.conv2.weight        | (32, 64, 3, 3) |         18432 |           5530 |    0.00000 |    0.00000 |  1.56250 |  9.71680 |  0.00000 |   69.99783 | 0.09041 | -0.00711 |    0.04355 |
+#  | 19 | module.layer3.2.conv1.weight        | (64, 32, 3, 3) |         18432 |           5530 |    0.00000 |    0.00000 |  0.00000 | 12.30469 |  3.12500 |   69.99783 | 0.09180 | -0.00314 |    0.04391 |
+#  | 20 | module.layer3.2.conv2.weight        | (32, 64, 3, 3) |         18432 |           5530 |    0.00000 |    0.00000 |  3.12500 | 29.29688 |  0.00000 |   69.99783 | 0.06059 |  0.00070 |    0.02763 |
+#  | 21 | module.fc.weight                    | (10, 32)       |           320 |            160 |    0.00000 |   50.00000 |  0.00000 |  0.00000 |  0.00000 |   50.00000 | 0.73275 | -0.00001 |    0.41788 |
+#  | 22 | Total sparsity:                     | -              |        130544 |          78776 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |   39.65560 | 0.00000 |  0.00000 |    0.00000 |
+#  +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
+#  Total sparsity: 39.66
 #
-# --- validate (epoch=359)-----------
-# 10000 samples (256 per mini-batch)
-# ==> Top1: 90.580    Top5: 99.650    Loss: 0.341
+#  --- validate (epoch=179)-----------
+#  10000 samples (256 per mini-batch)
+#  ==> Top1: 90.660    Top5: 99.700    Loss: 0.343
 #
-# ==> Best Top1: 90.860 on Epoch: 294
-# Saving checkpoint to: logs/2018.11.30-152150/checkpoint.pth.tar
-# --- test ---------------------
-# 10000 samples (256 per mini-batch)
-# ==> Top1: 90.580    Top5: 99.650    Loss: 0.341
+#  ==> Best [Top1: 90.890   Top5: 99.750   Sparsity:39.66   NNZ-Params: 78776 on epoch: 107]
+#  Saving checkpoint to: logs/2019.11.01-013919/checkpoint.pth.tar
+#  --- test ---------------------
+#  10000 samples (256 per mini-batch)
+#  ==> Top1: 90.660    Top5: 99.700    Loss: 0.358
+#
+#
+#  real    30m44.680s
+#  user    290m49.048s
+#  sys     17m17.928s
 
 version: 1
 
@@ -141,7 +146,7 @@ policies:
   #   ending_epoch: 20
   #   frequency: 2
 
-# After completeing the pruning, we perform network thinning and continue fine-tuning.
+# After completing the pruning, we perform network thinning and continue fine-tuning.
   - extension:
       instance_name: net_thinner
     epochs: [32]
diff --git a/tests/test_pruning.py b/tests/test_pruning.py
index 0862962..269d262 100755
--- a/tests/test_pruning.py
+++ b/tests/test_pruning.py
@@ -453,9 +453,9 @@ def test_row_pruning():
                           [7., 8., 9.]])
     from distiller.pruning import L1RankedStructureParameterPruner
 
-    masker = distiller.scheduler.ParameterMasker("why name")
+    masker = distiller.scheduler.ParameterMasker("debug name")
     zeros_mask_dict = {"some name": masker}
-    L1RankedStructureParameterPruner.rank_and_prune_rows(0.5, param, "some name", zeros_mask_dict)
+    L1RankedStructureParameterPruner.rank_and_prune_channels(0.5, param, "some name", zeros_mask_dict)
     print(distiller.sparsity_rows(masker.mask))
     assert math.isclose(distiller.sparsity_rows(masker.mask), 1/3)
     pass
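
For reference, this patch routes input-channel ranking of 2D (GEMM) weights through the same code path used for 4D Conv weights, which is why the test above now calls rank_and_prune_channels on a 2D parameter. Below is a minimal stand-alone sketch of the behavior the test asserts -- rank the rows of a weight matrix by L1 norm and mask the weakest fraction -- in plain PyTorch; the helper name and its arguments are invented for illustration and are not distiller's API:

    import torch

    def mask_lowest_l1_rows(weight, fraction):
        # Zero out the `fraction` of rows with the smallest L1 norm.
        n_prune = int(fraction * weight.size(0))
        mask = torch.ones_like(weight)
        if n_prune > 0:
            row_norms = weight.abs().sum(dim=1)   # L1 norm of each row
            prune_idx = row_norms.topk(n_prune, largest=False).indices
            mask[prune_idx, :] = 0.
        return mask

    w = torch.tensor([[1., 2., 3.],
                      [4., 5., 6.],
                      [7., 8., 9.]])
    m = mask_lowest_l1_rows(w, 0.5)               # int(0.5 * 3) == 1 row masked
    zero_rows = (m.abs().sum(dim=1) == 0).sum().item()
    print(zero_rows / m.size(0))                  # 0.333..., matching the 1/3 assert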
-- 
GitLab