From 68514d178b583d3f6d59f13bb9f34120ef51eb4c Mon Sep 17 00:00:00 2001
From: Neta Zmora <31280975+nzmora@users.noreply.github.com>
Date: Mon, 20 Apr 2020 15:02:22 +0300
Subject: [PATCH] small tensor masking API refactoring (#499)

Added masking primitives:
 -mask_tensor
 -create_mask_threshold_criterion
 -create_mask_level_criterion
 -create_mask_sensitivity_criterion

 These APIs have clearer names and communicate their
 responsibility better: create a tensor mask based on
 some criterion.  Previously,
 distiller.pruning.create_mask_threshold_criterion was
 named distiller.threshold_mask, which did not communicate
 well what the function does.
 Masking functionality is no longer hidden inside the
 Pruner instances, so these primitives can be used directly
 by an application, or composed into new Pruner classes.
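
 For example, a minimal usage sketch of the new primitives
 (the tensor shape, threshold and sparsity values below are
 illustrative only):

    import torch
    import distiller

    weights = torch.randn(64, 3, 3, 3)   # e.g. a conv layer's weight tensor

    # Element-wise threshold criterion: mask coefficients with |w| <= 0.1
    mask = distiller.create_mask_threshold_criterion(weights, threshold=0.1)

    # Level criterion: mask the smallest 50% of the coefficients
    mask = distiller.create_mask_level_criterion(weights, desired_sparsity=0.5)

    # Apply the chosen mask (element-wise multiplication)
    pruned = distiller.mask_tensor(weights, mask, inplace=False)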

Removed file distiller.pruning.pruner:
 -The base-class _ParameterPruner is useless and adds
 needless details to the implementation.

AGP: Separated the pruning-rate schedule from the
 rest of the logic.  This allows us to mix-and-match different
 pruning-rate schedules (just like LR schedulers).
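
 For instance, a hypothetical linear rate schedule (not part of
 this commit; the scheduler class and weight name below are
 illustrative only) could be plugged into AutomatedGradualPruner:

    from distiller.pruning import AutomatedGradualPruner

    class LinearPruningRate(object):
        """A hypothetical linear pruning-rate schedule, for illustration."""
        def __init__(self, initial_sparsity, final_sparsity):
            self.initial_sparsity = initial_sparsity
            self.final_sparsity = final_sparsity

        def step(self, current_epoch, starting_epoch, ending_epoch, freq):
            # Linearly interpolate the target sparsity between start and end epochs
            progress = (current_epoch - starting_epoch) / (ending_epoch - starting_epoch)
            progress = min(max(progress, 0.), 1.)
            return self.initial_sparsity + progress * (self.final_sparsity - self.initial_sparsity)

    pruner = AutomatedGradualPruner('linear_agp',
                                    initial_sparsity=0.05, final_sparsity=0.85,
                                    weights=['module.conv1.weight'],
                                    rate_scheduler_factory=LinearPruningRate)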
---
 distiller/__init__.py                         | 32 +++----
 distiller/norms.py                            |  1 +
 distiller/pruning/__init__.py                 | 86 +++++++++++++++++++
 distiller/pruning/automated_gradual_pruner.py | 64 +++++++-------
 distiller/pruning/baidu_rnn_pruner.py         | 11 +--
 distiller/pruning/level_pruner.py             | 18 +---
 distiller/pruning/magnitude_pruner.py         | 12 +--
 distiller/pruning/pruner.py                   | 50 -----------
 distiller/pruning/ranked_structures_pruner.py | 42 +++++----
 distiller/pruning/sensitivity_pruner.py       | 18 +---
 distiller/pruning/splicing_pruner.py          |  6 +-
 distiller/pruning/structure_pruner.py         |  4 +-
 distiller/regularization/l1_regularizer.py    |  6 +-
 distiller/scheduler.py                        |  2 +-
 distiller/thresholding.py                     | 15 +---
 tests/test_thresholding.py                    | 24 +++++-
 16 files changed, 204 insertions(+), 187 deletions(-)
 delete mode 100755 distiller/pruning/pruner.py

diff --git a/distiller/__init__.py b/distiller/__init__.py
index e687010..13ef908 100755
--- a/distiller/__init__.py
+++ b/distiller/__init__.py
@@ -16,10 +16,11 @@
 
 import torch
 from .utils import *
-from .thresholding import GroupThresholdMixin, threshold_mask, group_threshold_mask
+from .thresholding import GroupThresholdMixin, group_threshold_mask
 from .config import file_config, dict_config, config_component_from_file_by_class
 from .model_summaries import *
 from .scheduler import *
+from .pruning import *
 from .sensitivity import *
 from .directives import *
 from .policy import *
@@ -41,6 +42,21 @@ except pkg_resources.DistributionNotFound:
     __version__ = "Unknown"
 
 
+def __check_pytorch_version():
+    from pkg_resources import parse_version
+    required = '1.3.1'
+    actual = torch.__version__
+    if parse_version(actual) < parse_version(required):
+        msg = "\n\nWRONG PYTORCH VERSION\n"\
+              "Required:  {}\n" \
+              "Installed: {}\n"\
+              "Please run 'pip install -e .' from the Distiller repo root dir\n".format(required, actual)
+        raise RuntimeError(msg)
+
+
+__check_pytorch_version()
+
+
 def model_find_param_name(model, param_to_find):
     """Look up the name of a model parameter.
 
@@ -104,17 +120,3 @@ def model_find_module(model, module_to_find):
             return m
     return None
 
-
-def check_pytorch_version():
-    from pkg_resources import parse_version
-    required = '1.3.1'
-    actual = torch.__version__
-    if parse_version(actual) < parse_version(required):
-        msg = "\n\nWRONG PYTORCH VERSION\n"\
-              "Required:  {}\n" \
-              "Installed: {}\n"\
-              "Please run 'pip install -e .' from the Distiller repo root dir\n".format(required, actual)
-        raise RuntimeError(msg)
-
-
-check_pytorch_version()
diff --git a/distiller/norms.py b/distiller/norms.py
index a73a07f..c43019b 100644
--- a/distiller/norms.py
+++ b/distiller/norms.py
@@ -33,6 +33,7 @@ see: https://www.kaggle.com/residentmario/l1-norms-versus-l2-norms)
 import torch
 import numpy as np
 from functools import partial
+from random import uniform
 
 
 __all__ = ["kernels_lp_norm", "channels_lp_norm", "filters_lp_norm",
diff --git a/distiller/pruning/__init__.py b/distiller/pruning/__init__.py
index 710c7b5..25e115f 100755
--- a/distiller/pruning/__init__.py
+++ b/distiller/pruning/__init__.py
@@ -42,6 +42,7 @@ from .ranked_structures_pruner import L1RankedStructureParameterPruner, \
                                       FMReconstructionChannelPruner
 from .baidu_rnn_pruner import BaiduRNNPruner
 from .greedy_filter_pruning import greedy_pruner
+import torch
 
 del magnitude_pruner
 del automated_gradual_pruner
@@ -49,3 +50,88 @@ del level_pruner
 del sensitivity_pruner
 del structure_pruner
 del ranked_structures_pruner
+
+
+def mask_tensor(tensor, mask, inplace=True):
+    """Mask the provided tensor
+
+    Args:
+        tensor - the torch-tensor to mask
+        mask - binary coefficient-masking tensor.  Has the same type and shape as `tensor`
+    Returns:
+        tensor = tensor * mask, where * is the element-wise multiplication operator
+    """
+    if mask is not None:
+        assert tensor.type() == mask.type()
+        assert tensor.shape == mask.shape
+        return tensor.data.mul_(mask) if inplace else tensor.data.mul(mask)
+    return tensor
+
+
+def create_mask_threshold_criterion(tensor, threshold):
+    """Create a tensor mask using a threshold criterion.
+
+    All values whose absolute value is smaller than, or equal to, the threshold will be masked away.
+    Granularity: Element-wise
+    Args:
+        tensor - the tensor to threshold.
+        threshold - a floating-point threshold value.
+    Returns:
+        binary (0/1) mask tensor, having the same type and size as the input tensor.
+    """
+    with torch.no_grad():
+        mask = torch.gt(torch.abs(tensor), threshold).type(tensor.type())
+        return mask
+
+
+def create_mask_level_criterion(tensor, desired_sparsity):
+    """Create a tensor mask using a level criterion.
+
+    A specified fraction of the input tensor will be masked.  The tensor coefficients
+    are first sorted by their L1-norm (absolute value), and then the smallest
+    `desired_sparsity` fraction of the coefficients is masked.
+    Granularity: Element-wise
+
+    WARNING: due to the implementation details (speed over correctness), this will perform
+    incorrectly if "too many" of the coefficients have the same value. E.g. this will fail:
+        a = torch.ones(3, 64, 32, 32)
+        mask = distiller.create_mask_level_criterion(a, desired_sparsity=0.3)
+        assert math.isclose(distiller.sparsity(mask), 0.3)
+
+    Args:
+        tensor - the tensor to mask.
+        desired_sparsity - a floating-point value in the range (0..1) specifying what
+            fraction of the tensor's coefficients will be masked.
+    Returns:
+        binary (0/1) mask tensor, having the same type and size as the input tensor.
+    """
+    with torch.no_grad():
+        # partial sort
+        bottomk, _ = torch.topk(tensor.abs().view(-1),
+                                int(desired_sparsity * tensor.numel()),
+                                largest=False,
+                                sorted=True)
+        threshold = bottomk.data[-1]  # This is the largest element from the group of elements that we prune away
+        mask = create_mask_threshold_criterion(tensor, threshold)
+        return mask
+
+
+def create_mask_sensitivity_criterion(tensor, sensitivity):
+    """Create a tensor mask using a sensitivity criterion.
+
+    Mask an input tensor based on the standard-deviation of the distribution of its coefficients.
+    Coefficients whose absolute value is within `sensitivity` standard-deviations of zero will be masked.
+    Granularity: Element-wise
+    Args:
+        tensor - the tensor to mask.
+        sensitivity - a floating-point value specifying the sensitivity.  This is a simple
+            multiplier of the standard-deviation.
+    Returns:
+        binary (0/1) mask tensor, having the same type and size as the input tensor.
+    """
+    if not hasattr(tensor, 'stddev'):
+        tensor.stddev = torch.std(tensor).item()
+    with torch.no_grad():
+        threshold = tensor.stddev * sensitivity
+        mask = create_mask_threshold_criterion(tensor, threshold)
+        return mask
diff --git a/distiller/pruning/automated_gradual_pruner.py b/distiller/pruning/automated_gradual_pruner.py
index c2e6d4c..e499ae2 100755
--- a/distiller/pruning/automated_gradual_pruner.py
+++ b/distiller/pruning/automated_gradual_pruner.py
@@ -14,37 +14,19 @@
 # limitations under the License.
 #
 
-from .pruner import _ParameterPruner
-from .level_pruner import SparsityLevelParameterPruner
 from .ranked_structures_pruner import *
-from distiller.utils import *
-from functools import partial
+import distiller
 
 
-class AutomatedGradualPrunerBase(_ParameterPruner):
-    """Prune to an exact sparsity level specification using a prescribed sparsity
-    level schedule formula.
-
-    An automated gradual pruning algorithm that prunes the smallest magnitude
-    weights to achieve a preset level of network sparsity.
-
-    Michael Zhu and Suyog Gupta, "To prune, or not to prune: exploring the
-    efficacy of pruning for model compression", 2017 NIPS Workshop on Machine
-    Learning of Phones and other Consumer Devices,
-    (https://arxiv.org/pdf/1710.01878.pdf)
+class AgpPruningRate(object):
+    """A pruning-rate scheduler per https://arxiv.org/pdf/1710.01878.pdf.
     """
-
-    def __init__(self, name, initial_sparsity, final_sparsity):
-        super().__init__(name)
+    def __init__(self, initial_sparsity, final_sparsity):
         self.initial_sparsity = initial_sparsity
         self.final_sparsity = final_sparsity
         assert final_sparsity > initial_sparsity
 
-    def compute_target_sparsity(self, meta):
-        starting_epoch = meta['starting_epoch']
-        current_epoch = meta['current_epoch']
-        ending_epoch = meta['ending_epoch']
-        freq = meta['frequency']
+    def step(self, current_epoch, starting_epoch, ending_epoch, freq):
         span = ((ending_epoch - starting_epoch - 1) // freq) * freq
         assert span > 0
 
@@ -54,8 +36,32 @@ class AutomatedGradualPrunerBase(_ParameterPruner):
 
         return target_sparsity
 
+
+class AutomatedGradualPrunerBase(object):
+    """Prune to an exact sparsity level specification using a prescribed sparsity
+    level schedule formula.
+
+    An automated gradual pruning algorithm that prunes the smallest magnitude
+    weights to achieve a preset level of network sparsity.
+
+    Michael Zhu and Suyog Gupta, "To prune, or not to prune: exploring the
+    efficacy of pruning for model compression", 2017 NIPS Workshop on Machine
+    Learning of Phones and other Consumer Devices,
+    (https://arxiv.org/pdf/1710.01878.pdf)
+    """
+
+    def __init__(self, name, rate_scheduler):
+        """
+        Args:
+            rate_scheduler - schedules the pruning rate.  You can plug in any
+            rate scheduler.  We implemented the AGP paper's rate schedule in AgpPruningRate.
+        """
+        self.name = name
+        self.agp_pr = rate_scheduler
+
     def set_param_mask(self, param, param_name, zeros_mask_dict, meta):
-        target_sparsity = self.compute_target_sparsity(meta)
+        target_sparsity = self.agp_pr.step(meta['current_epoch'], meta['starting_epoch'],
+                                           meta['ending_epoch'], meta['frequency'])
         self._set_param_mask_by_sparsity_target(param, param_name, zeros_mask_dict, target_sparsity, meta['model'])
 
     def _set_param_mask_by_sparsity_target(self, param, param_name, zeros_mask_dict, target_sparsity, model=None):
@@ -69,8 +75,8 @@ class AutomatedGradualPruner(AutomatedGradualPrunerBase):
     An automated gradual pruning algorithm that prunes the smallest magnitude
     weights to achieve a preset level of network sparsity.
     """
-    def __init__(self, name, initial_sparsity, final_sparsity, weights):
-        super().__init__(name, initial_sparsity, final_sparsity)
+    def __init__(self, name, initial_sparsity, final_sparsity, weights, rate_scheduler_factory=AgpPruningRate):
+        super().__init__(name, rate_scheduler=rate_scheduler_factory(initial_sparsity, final_sparsity))
         self.params_names = weights
         assert self.params_names
 
@@ -80,7 +86,7 @@ class AutomatedGradualPruner(AutomatedGradualPrunerBase):
         super().set_param_mask(param, param_name, zeros_mask_dict, meta)
 
     def _set_param_mask_by_sparsity_target(self, param, param_name, zeros_mask_dict, target_sparsity, model=None):
-        zeros_mask_dict[param_name].mask = SparsityLevelParameterPruner.create_mask(param, target_sparsity)
+        zeros_mask_dict[param_name].mask = distiller.create_mask_level_criterion(param, target_sparsity)
 
 
 class StructuredAGP(AutomatedGradualPrunerBase):
@@ -89,8 +95,8 @@ class StructuredAGP(AutomatedGradualPrunerBase):
     This is a base-class for structured pruning with an AGP schedule.  It is an
     extension of the AGP concept introduced by Zhu et al.
     """
-    def __init__(self, name, initial_sparsity, final_sparsity):
-        super().__init__(name, initial_sparsity, final_sparsity)
+    def __init__(self, name, initial_sparsity, final_sparsity, rate_scheduler_factory=AgpPruningRate):
+        super().__init__(name, rate_scheduler=rate_scheduler_factory(initial_sparsity, final_sparsity))
         self.pruner = None
 
     def _set_param_mask_by_sparsity_target(self, param, param_name, zeros_mask_dict, target_sparsity, model):
diff --git a/distiller/pruning/baidu_rnn_pruner.py b/distiller/pruning/baidu_rnn_pruner.py
index f195b85..8ec7fac 100755
--- a/distiller/pruning/baidu_rnn_pruner.py
+++ b/distiller/pruning/baidu_rnn_pruner.py
@@ -14,13 +14,10 @@
 # limitations under the License.
 #
 
-from .pruner import _ParameterPruner
-from .level_pruner import SparsityLevelParameterPruner
-from distiller.utils import *
-
 import distiller
 
-class BaiduRNNPruner(_ParameterPruner):
+
+class BaiduRNNPruner(object):
     """An element-wise pruner for RNN networks.
 
     Narang, Sharan & Diamos, Gregory & Sengupta, Shubho & Elsen, Erich. (2017).
@@ -53,7 +50,7 @@ class BaiduRNNPruner(_ParameterPruner):
     def __init__(self, name, q, ramp_epoch_offset, ramp_slope_mult, weights):
         # Initialize the pruner, using a configuration that originates from the
         # schedule YAML file.
-        super(BaiduRNNPruner, self).__init__(name)
+        self.name = name
         self.params_names = weights
         assert self.params_names
 
@@ -93,4 +90,4 @@ class BaiduRNNPruner(_ParameterPruner):
                    self.ramp_slope  * (current_epoch  - ramp_epoch + 1)) / freq
 
         # After computing the threshold, we can create the mask
-        zeros_mask_dict[param_name].mask = distiller.threshold_mask(param.data, eps)
+        zeros_mask_dict[param_name].mask = distiller.create_mask_threshold_criterion(param, eps)
diff --git a/distiller/pruning/level_pruner.py b/distiller/pruning/level_pruner.py
index dc6449d..53a07cb 100755
--- a/distiller/pruning/level_pruner.py
+++ b/distiller/pruning/level_pruner.py
@@ -14,11 +14,10 @@
 # limitations under the License.
 #
 
-import torch
-from .pruner import _ParameterPruner
 import distiller
 
-class SparsityLevelParameterPruner(_ParameterPruner):
+
+class SparsityLevelParameterPruner(object):
     """Prune to an exact pruning level specification.
 
     This pruner is very similar to MagnitudeParameterPruner, but instead of
@@ -30,7 +29,7 @@ class SparsityLevelParameterPruner(_ParameterPruner):
     """
 
     def __init__(self, name, levels, **kwargs):
-        super(SparsityLevelParameterPruner, self).__init__(name)
+        self.name = name
         self.levels = levels
         assert self.levels
 
@@ -40,13 +39,4 @@ class SparsityLevelParameterPruner(_ParameterPruner):
         desired_sparsity = self.levels.get(param_name, self.levels.get('*', 0))
         if desired_sparsity == 0:
             return
-
-        zeros_mask_dict[param_name].mask = SparsityLevelParameterPruner.create_mask(param, desired_sparsity)
-
-    @staticmethod
-    def create_mask(param, desired_sparsity):
-        with torch.no_grad():
-            bottomk, _ = torch.topk(param.abs().view(-1), int(desired_sparsity * param.numel()), largest=False, sorted=True)
-            threshold = bottomk.data[-1]  # This is the largest element from the group of elements that we prune away
-            mask = distiller.threshold_mask(param.data, threshold)
-            return mask
\ No newline at end of file
+        zeros_mask_dict[param_name].mask = distiller.create_mask_level_criterion(param, desired_sparsity)
diff --git a/distiller/pruning/magnitude_pruner.py b/distiller/pruning/magnitude_pruner.py
index c46732e..c8d5cd6 100755
--- a/distiller/pruning/magnitude_pruner.py
+++ b/distiller/pruning/magnitude_pruner.py
@@ -14,12 +14,11 @@
 # limitations under the License.
 #
 
-from .pruner import _ParameterPruner
 import distiller
 import torch
 
 
-class MagnitudeParameterPruner(_ParameterPruner):
+class MagnitudeParameterPruner(object):
     """This is the most basic magnitude-based pruner.
 
     This pruner supports configuring a scalar threshold for each layer.
@@ -43,7 +42,7 @@ class MagnitudeParameterPruner(_ParameterPruner):
                value is used.
                Currently it is mandatory to include a '*' key in 'thresholds'.
         """
-        super(MagnitudeParameterPruner, self).__init__(name)
+        self.name = name
         assert thresholds is not None
         # Make sure there is a default threshold to use
         assert '*' in thresholds
@@ -51,13 +50,8 @@ class MagnitudeParameterPruner(_ParameterPruner):
 
     def set_param_mask(self, param, param_name, zeros_mask_dict, meta):
         threshold = self.thresholds.get(param_name, self.thresholds['*'])
-        zeros_mask_dict[param_name].mask = self.create_mask(param.data, threshold)
+        zeros_mask_dict[param_name].mask = distiller.create_mask_threshold_criterion(param, threshold)
 
-    @staticmethod
-    def create_mask(param, threshold):
-        with torch.no_grad():
-            mask = distiller.threshold_mask(param.data, threshold)
-            return mask
 
 
 
diff --git a/distiller/pruning/pruner.py b/distiller/pruning/pruner.py
deleted file mode 100755
index 7ba5f2f..0000000
--- a/distiller/pruning/pruner.py
+++ /dev/null
@@ -1,50 +0,0 @@
-#
-# Copyright (c) 2018 Intel Corporation
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#      http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import torch
-import distiller
-
-
-__all__ = ["mask_tensor"]
-
-
-class _ParameterPruner(object):
-    """Base class for all pruners.
-
-    Arguments:
-        name: pruner name is used mainly for debugging.
-    """
-    def __init__(self, name):
-        self.name = name
-
-    def set_param_mask(self, param, param_name, zeros_mask_dict, meta):
-        raise NotImplementedError
-
-
-def mask_tensor(tensor, mask):
-    """Mask the provided tensor
-
-    Args:
-        tensor - the torch-tensor to mask
-        mask - binary coefficient-masking tensor.  Has the same type and shape as `tensor`
-    Returns:
-        tensor = tensor * mask  ;where * is the element-wise multiplication operator
-    """
-    assert tensor.type() == mask.type()
-    assert tensor.shape == mask.shape
-    if mask:
-        tensor.data.mul_(mask)
-    return tensor
diff --git a/distiller/pruning/ranked_structures_pruner.py b/distiller/pruning/ranked_structures_pruner.py
index 629df73..ea11aa1 100755
--- a/distiller/pruning/ranked_structures_pruner.py
+++ b/distiller/pruning/ranked_structures_pruner.py
@@ -22,7 +22,6 @@ import torch
 from torch.nn import functional as f
 from random import uniform
 import distiller
-from .pruner import _ParameterPruner
 
 
 __all__ = ["LpRankedStructureParameterPruner",
@@ -39,12 +38,12 @@ __all__ = ["LpRankedStructureParameterPruner",
 msglogger = logging.getLogger(__name__)
 
 
-class _RankedStructureParameterPruner(_ParameterPruner):
+class _RankedStructureParameterPruner(object):
     """Base class for pruning structures by ranking them.
     """
     def __init__(self, name, group_type, desired_sparsity, weights, 
                  group_dependency=None, group_size=1, rounding_fn=math.floor, noise=0.):
-        super().__init__(name)
+        self.name = name
         self.group_type = group_type
         self.group_dependency = group_dependency
         self.params_names = weights
@@ -343,14 +342,15 @@ class ActivationRankedFilterPruner(_RankedStructureParameterPruner):
 
         # Use the parameter name to locate the module that has the activation sparsity statistics
         fq_name = param_name.replace(".conv", ".relu")[:-len(".weight")]
+        #distiller.assign_layer_fq_names(model)
         module = distiller.find_module_by_fq_name(model, fq_name)
         if module is None:
             raise ValueError("Could not find a layer named %s in the model."
                              "\nMake sure to use assign_layer_fq_names()" % fq_name)
         if not hasattr(module, self.activation_rank_criterion):
-            raise ValueError("Could not find attribute \"{}\" in module {}"
-                             "\nMake sure to use SummaryActivationStatsCollector(\"{}\")".
-                             format(self.activation_rank_criterion, fq_name, self.activation_rank_criterion))
+            raise ValueError("Could not find attribute \"%s\" in module %s"
+                             "\nMake sure to use SummaryActivationStatsCollector(\"%s\")" %
+                             (self.activation_rank_criterion, fq_name, self.activation_rank_criterion))
 
         quality_criterion, std = getattr(module, self.activation_rank_criterion).value()
         num_filters = param.size(0)
@@ -469,10 +469,7 @@ class BernoulliFilterPruner(_RankedStructureParameterPruner):
 
 
 class GradientRankedFilterPruner(_RankedStructureParameterPruner):
-    """Taylor expansion ranking.
-
-    Pavlo Molchanov, Stephen Tyree, Tero Karras, Timo Aila, and Jan Kautz. Pruning Convolutional Neural
-    Networks for Resource Efficient Inference. ArXiv, abs/1611.06440, 2016.
+    """Rank the importance of weight filters using the product of their gradients and the filter value.
     """
     def __init__(self, name, group_type, desired_sparsity, weights, group_dependency=None):
         super().__init__(name, group_type, desired_sparsity, weights, group_dependency)
@@ -495,22 +492,23 @@ class GradientRankedFilterPruner(_RankedStructureParameterPruner):
             msglogger.info("Too few filters - can't prune %.1f%% filters", 100*fraction_to_prune)
             return
 
-        # Compute the multiplication of the filters times the filter_gradienrs
+        # Compute the multiplication of the filters times the filter_gradients
         view_filters = param.view(param.size(0), -1)
         view_filter_grads = param.grad.view(param.size(0), -1)
-        weighted_gradients = view_filter_grads * view_filters
-        weighted_gradients = weighted_gradients.sum(dim=1)
+        with torch.no_grad():
+            weighted_gradients = view_filter_grads * view_filters
+            weighted_gradients = weighted_gradients.sum(dim=1)
 
-        # Sort from high to low, and remove the bottom 'num_filters_to_prune' filters
-        filters_ordered_by_gradient = np.argsort(-weighted_gradients.detach().cpu().numpy())[:-num_filters_to_prune]
-        mask, binary_map = _mask_from_filter_order(filters_ordered_by_gradient, param, num_filters, binary_map)
-        zeros_mask_dict[param_name].mask = mask
+            # Sort from high to low, and remove the bottom 'num_filters_to_prune' filters
+            filters_ordered_by_gradient = np.argsort(-weighted_gradients.detach().cpu().numpy())[:-num_filters_to_prune]
+            mask, binary_map = _mask_from_filter_order(filters_ordered_by_gradient, param, num_filters, binary_map)
+            zeros_mask_dict[param_name].mask = mask
 
-        msglogger.info("GradientRankedFilterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
-                       param_name,
-                       distiller.sparsity_3D(zeros_mask_dict[param_name].mask),
-                       fraction_to_prune, num_filters_to_prune, num_filters)
-        return binary_map
+            msglogger.info("GradientRankedFilterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
+                           param_name,
+                           distiller.sparsity_3D(zeros_mask_dict[param_name].mask),
+                           fraction_to_prune, num_filters_to_prune, num_filters)
+            return binary_map
 
 
 from sklearn.linear_model import LinearRegression
diff --git a/distiller/pruning/sensitivity_pruner.py b/distiller/pruning/sensitivity_pruner.py
index 5fc4f3f..5832399 100755
--- a/distiller/pruning/sensitivity_pruner.py
+++ b/distiller/pruning/sensitivity_pruner.py
@@ -14,12 +14,11 @@
 # limitations under the License.
 #
 
-from .pruner import _ParameterPruner
+
 import distiller
-import torch
 
 
-class SensitivityPruner(_ParameterPruner):
+class SensitivityPruner(object):
     """Use algorithm from "Learning both Weights and Connections for Efficient
     Neural Networks" - https://arxiv.org/pdf/1506.02626v3.pdf
 
@@ -41,7 +40,7 @@ class SensitivityPruner(_ParameterPruner):
     """
 
     def __init__(self, name, sensitivities, **kwargs):
-        super(SensitivityPruner, self).__init__(name)
+        self.name = name
         self.sensitivities = sensitivities
 
     def set_param_mask(self, param, param_name, zeros_mask_dict, meta):
@@ -53,13 +52,4 @@ class SensitivityPruner(_ParameterPruner):
         else:
             sensitivity = self.sensitivities[param_name]
 
-        zeros_mask_dict[param_name].mask = self.create_mask(param, sensitivity)
-
-    @staticmethod
-    def create_mask(param, sensitivity):
-        if not hasattr(param, 'stddev'):
-            param.stddev = torch.std(param).item()
-        with torch.no_grad():
-            threshold = param.stddev * sensitivity
-            mask = distiller.threshold_mask(param.data, threshold)
-            return mask
+        zeros_mask_dict[param_name].mask = distiller.create_mask_sensitivity_criterion(param, sensitivity)
diff --git a/distiller/pruning/splicing_pruner.py b/distiller/pruning/splicing_pruner.py
index 2a2edd3..4dbb0b2 100755
--- a/distiller/pruning/splicing_pruner.py
+++ b/distiller/pruning/splicing_pruner.py
@@ -14,14 +14,12 @@
 # limitations under the License.
 #
 
-
-from .pruner import _ParameterPruner
 import torch
 import logging
 msglogger = logging.getLogger()
 
 
-class SplicingPruner(_ParameterPruner):
+class SplicingPruner(object):
     """A pruner that both prunes and splices connections.
 
     The idea of pruning and splicing working in tandem was first proposed in the following
@@ -37,7 +35,7 @@ class SplicingPruner(_ParameterPruner):
     def __init__(self, name, sensitivities, low_thresh_mult, hi_thresh_mult, sensitivity_multiplier=0):
         """Arguments:
         """
-        super(SplicingPruner, self).__init__(name)
+        self.name = name
         self.sensitivities = sensitivities
         self.low_thresh_mult = low_thresh_mult
         self.hi_thresh_mult = hi_thresh_mult
diff --git a/distiller/pruning/structure_pruner.py b/distiller/pruning/structure_pruner.py
index efd98a3..0cbb42a 100755
--- a/distiller/pruning/structure_pruner.py
+++ b/distiller/pruning/structure_pruner.py
@@ -15,12 +15,11 @@
 #
 
 import logging
-from .pruner import _ParameterPruner
 import distiller
 msglogger = logging.getLogger()
 
 
-class StructureParameterPruner(distiller.GroupThresholdMixin, _ParameterPruner):
+class StructureParameterPruner(distiller.GroupThresholdMixin):
     """Prune parameter structures.
 
     Pruning criterion: average L1-norm.  If the average L1-norm (absolute value) of the elements
@@ -30,7 +29,6 @@ class StructureParameterPruner(distiller.GroupThresholdMixin, _ParameterPruner):
     the structure size.
     """
     def __init__(self, name, model, reg_regims, threshold_criteria):
-        super(StructureParameterPruner, self).__init__(name)
         self.name = name
         self.model = model
         self.reg_regims = reg_regims
diff --git a/distiller/regularization/l1_regularizer.py b/distiller/regularization/l1_regularizer.py
index fdd6007..a79c32d 100755
--- a/distiller/regularization/l1_regularizer.py
+++ b/distiller/regularization/l1_regularizer.py
@@ -16,9 +16,7 @@
 
 """L1-norm regularization"""
 
-import torch
-import math
-import numpy as np
+
 import distiller
 from .regularizer import _Regularizer, EPSILON
 
@@ -39,7 +37,7 @@ class L1Regularizer(_Regularizer):
             return
 
         strength = self.reg_regims[param_name]
-        zeros_mask_dict[param_name].mask = distiller.threshold_mask(param.data, threshold=strength)
+        zeros_mask_dict[param_name].mask = distiller.pruning.create_mask_threshold_criterion(param, threshold=strength)
         zeros_mask_dict[param_name].is_regularization_mask = True
 
     @staticmethod
diff --git a/distiller/scheduler.py b/distiller/scheduler.py
index 45a25e3..82266f0 100755
--- a/distiller/scheduler.py
+++ b/distiller/scheduler.py
@@ -263,7 +263,7 @@ class ParameterMasker(object):
 
     def revert_weights(self, parameter):
         if not self.use_double_copies or self.unmasked_copy is None:
-            msglogger.debug('Parameter {0} does not maintain double copies'.format(self.param_name))
+            # This parameter does not maintain double copies (this is OK)
             return
         parameter.data.copy_(self.unmasked_copy)
         self.unmasked_copy = None
diff --git a/distiller/thresholding.py b/distiller/thresholding.py
index cb38909..590adc3 100755
--- a/distiller/thresholding.py
+++ b/distiller/thresholding.py
@@ -23,23 +23,10 @@ import numpy as np
 from distiller.norms import *
 
 
-__all__ = ["threshold_mask", "GroupThresholdMixin",
+__all__ = ["GroupThresholdMixin",
            "group_threshold_binary_map", "group_threshold_mask"]
 
 
-def threshold_mask(param, threshold):
-    """Create a threshold mask for the provided parameter tensor using
-    magnitude thresholding.
-
-    Arguments:
-        param: a parameter tensor which should be pruned.
-        threshold: the pruning threshold.
-    Returns:
-        prune_mask: The pruning mask.
-    """
-    return torch.gt(torch.abs(param), threshold).type(param.type())
-
-
 class GroupThresholdMixin(object):
     """A mixin class to add group thresholding capabilities
 
diff --git a/tests/test_thresholding.py b/tests/test_thresholding.py
index f231abd..e014efa 100755
--- a/tests/test_thresholding.py
+++ b/tests/test_thresholding.py
@@ -77,12 +77,32 @@ def test_threshold_mask():
     # Change one element
     a[1, 4, 17, 31] = 0.2
     # Create and apply a mask
-    mask = distiller.threshold_mask(a, threshold=0.3)
+    mask = distiller.create_mask_threshold_criterion(a, threshold=0.3)
     assert np.sum(distiller.to_np(mask)) == (distiller.volume(a) - 1)
     assert mask[1, 4, 17, 31] == 0
     assert common.almost_equal(distiller.sparsity(mask), 1/distiller.volume(a))
 
 
+def test_level_mask():
+    # Create a 4-D tensor of uniformly-distributed random values
+    a = torch.rand(3, 64, 32, 32)
+
+    # Create a mask and check its sparsity
+    mask = distiller.create_mask_level_criterion(a, desired_sparsity=0.3)
+    assert common.almost_equal(distiller.sparsity(mask), 0.3, max_diff=0.0001)
+
+
+def test_sensitivity_mask():
+    # Create a 4-D tensor of normally-distributed coefficients
+    a = torch.randn(3, 64, 32, 32)
+
+    # Create a mask and check its sparsity
+    mask = distiller.create_mask_sensitivity_criterion(a, sensitivity=1)
+    # The width of 1-std on ~N(0,1) is about 68.27%.  In other words:
+    # Pr(mean - std <= X <= mean + std) is about 68.27%
+    assert common.almost_equal(distiller.sparsity(mask), 0.6827, max_diff=0.005)
+
+
 def test_kernel_thresholding():
     p = get_test_4d_tensor().cuda()
     mask, map = distiller.group_threshold_mask(p, '2D', 6, 'L1')
@@ -240,6 +260,8 @@ def test_row_thresholding():
                                         [ 0.,  0.,  0.],
                                         [ 1.,  1.,  1.],
                                         [ 1.,  1.,  1.]], device=mask.device)).all()
+    masked_tensor = distiller.mask_tensor(p, mask)
+    assert distiller.sparsity(masked_tensor) == distiller.sparsity(mask)
     return mask
 
 
-- 
GitLab