diff --git a/distiller/__init__.py b/distiller/__init__.py index e68701064123cc83fe1c6c59e2d18ad5af92b7fb..13ef908e3cc8cfa7086004c3d97d47be4f4c5c6e 100755 --- a/distiller/__init__.py +++ b/distiller/__init__.py @@ -16,10 +16,11 @@ import torch from .utils import * -from .thresholding import GroupThresholdMixin, threshold_mask, group_threshold_mask +from .thresholding import GroupThresholdMixin, group_threshold_mask from .config import file_config, dict_config, config_component_from_file_by_class from .model_summaries import * from .scheduler import * +from .pruning import * from .sensitivity import * from .directives import * from .policy import * @@ -41,6 +42,21 @@ except pkg_resources.DistributionNotFound: __version__ = "Unknown" +def __check_pytorch_version(): + from pkg_resources import parse_version + required = '1.3.1' + actual = torch.__version__ + if parse_version(actual) < parse_version(required): + msg = "\n\nWRONG PYTORCH VERSION\n"\ + "Required: {}\n" \ + "Installed: {}\n"\ + "Please run 'pip install -e .' from the Distiller repo root dir\n".format(required, actual) + raise RuntimeError(msg) + + +__check_pytorch_version() + + def model_find_param_name(model, param_to_find): """Look up the name of a model parameter. @@ -104,17 +120,3 @@ def model_find_module(model, module_to_find): return m return None - -def check_pytorch_version(): - from pkg_resources import parse_version - required = '1.3.1' - actual = torch.__version__ - if parse_version(actual) < parse_version(required): - msg = "\n\nWRONG PYTORCH VERSION\n"\ - "Required: {}\n" \ - "Installed: {}\n"\ - "Please run 'pip install -e .' from the Distiller repo root dir\n".format(required, actual) - raise RuntimeError(msg) - - -check_pytorch_version() diff --git a/distiller/norms.py b/distiller/norms.py index a73a07fc1ee65d45fb5e5289e5446afda69f038a..c43019bc9557dd16f0a967f8225d74888f658249 100644 --- a/distiller/norms.py +++ b/distiller/norms.py @@ -33,6 +33,7 @@ see: https://www.kaggle.com/residentmario/l1-norms-versus-l2-norms) import torch import numpy as np from functools import partial +from random import uniform __all__ = ["kernels_lp_norm", "channels_lp_norm", "filters_lp_norm", diff --git a/distiller/pruning/__init__.py b/distiller/pruning/__init__.py index 710c7b5f7c040fce7d425a23399e7cbb40c04d61..25e115f922c0378b5327b4f66749179646171730 100755 --- a/distiller/pruning/__init__.py +++ b/distiller/pruning/__init__.py @@ -42,6 +42,7 @@ from .ranked_structures_pruner import L1RankedStructureParameterPruner, \ FMReconstructionChannelPruner from .baidu_rnn_pruner import BaiduRNNPruner from .greedy_filter_pruning import greedy_pruner +import torch del magnitude_pruner del automated_gradual_pruner @@ -49,3 +50,88 @@ del level_pruner del sensitivity_pruner del structure_pruner del ranked_structures_pruner + + +def mask_tensor(tensor, mask, inplace=True): + """Mask the provided tensor + + Args: + tensor - the torch-tensor to mask + mask - binary coefficient-masking tensor. Has the same type and shape as `tensor` + Returns: + tensor = tensor * mask ;where * is the element-wise multiplication operator + """ + assert tensor.type() == mask.type() + assert tensor.shape == mask.shape + if mask is not None: + return tensor.data.mul_(mask) if inplace else tensor.data.mul(mask) + return tensor + + +def create_mask_threshold_criterion(tensor, threshold): + """Create a tensor mask using a threshold criterion. + + All values smaller or equal to the threshold will be masked-away. 
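Illustrative usage of the two helpers added above (mask_tensor and create_mask_threshold_criterion, the latter replacing the threshold_mask helper removed from distiller/thresholding.py later in this patch); the tensor values below are made up:

import torch
import distiller

weights = torch.tensor([0.05, -0.4, 0.9, -0.02])
# Elements whose magnitude exceeds the threshold survive (mask value 1).
mask = distiller.create_mask_threshold_criterion(weights, threshold=0.1)
# mask == tensor([0., 1., 1., 0.])
pruned = distiller.mask_tensor(weights, mask, inplace=False)
# pruned == tensor([0., -0.4, 0.9, -0.]); with inplace=False, weights is untouched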
+ Granularity: Element-wise + Args: + tensor - the tensor to threshold. + threshold - a floating-point threshold value. + Returns: + boolean mask tensor, having the same size as the input tensor. + """ + with torch.no_grad(): + mask = torch.gt(torch.abs(tensor), threshold).type(tensor.type()) + return mask + + +def create_mask_level_criterion(tensor, desired_sparsity): + """Create a tensor mask using a level criterion. + + A specified fraction of the input tensor will be masked. The tensor coefficients + are first sorted by their L1-norm (absolute value), and then the lower `desired_sparsity` + coefficients are masked. + Granularity: Element-wise + + WARNING: due to the implementation details (speed over correctness), this will perform + incorrectly if "too many" of the coefficients have the same value. E.g. this will fail: + a = torch.ones(3, 64, 32, 32) + mask = distiller.create_mask_level_criterion(a, desired_sparsity=0.3) + assert math.isclose(distiller.sparsity(mask), 0.3) + + Args: + tensor - the tensor to mask. + desired_sparsity - a floating-point value in the range (0..1) specifying what + percentage of the tensor will be masked. + Returns: + boolean mask tensor, having the same size as the input tensor. + """ + with torch.no_grad(): + # partial sort + bottomk, _ = torch.topk(tensor.abs().view(-1), + int(desired_sparsity * tensor.numel()), + largest=False, + sorted=True) + threshold = bottomk.data[-1] # This is the largest element from the group of elements that we prune away + mask = create_mask_threshold_criterion(tensor, threshold) + return mask + + +def create_mask_sensitivity_criterion(tensor, sensitivity): + """Create a tensor mask using a sensitivity criterion. + + Mask an input tensor based on the variance of the distribution of the tensor coefficients. + Coefficients in the distribution's specified band around the mean will be masked (symmetrically). + Granularity: Element-wise + Args: + tensor - the tensor to mask. + sensitivity - a floating-point value specifying the sensitivity. This is a simple + multiplier of the standard-deviation. + Returns: + boolean mask tensor, having the same size as the input tensor. + """ + if not hasattr(tensor, 'stddev'): + tensor.stddev = torch.std(tensor).item() + with torch.no_grad(): + threshold = tensor.stddev * sensitivity + mask = create_mask_threshold_criterion(tensor, threshold) + return mask diff --git a/distiller/pruning/automated_gradual_pruner.py b/distiller/pruning/automated_gradual_pruner.py index c2e6d4ce44291895fc9dda0642915c0a86e18b9c..e499ae2a55d990d28d2f1e502502b4b1e88dfe32 100755 --- a/distiller/pruning/automated_gradual_pruner.py +++ b/distiller/pruning/automated_gradual_pruner.py @@ -14,37 +14,19 @@ # limitations under the License. # -from .pruner import _ParameterPruner -from .level_pruner import SparsityLevelParameterPruner from .ranked_structures_pruner import * -from distiller.utils import * -from functools import partial +import distiller -class AutomatedGradualPrunerBase(_ParameterPruner): - """Prune to an exact sparsity level specification using a prescribed sparsity - level schedule formula. - - An automated gradual pruning algorithm that prunes the smallest magnitude - weights to achieve a preset level of network sparsity. 
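The level and sensitivity criteria added above in distiller/pruning/__init__.py can be sanity-checked in isolation; a small sketch (the tensor shape and sparsity target are arbitrary):

import math
import torch
import distiller

t = torch.randn(64, 128)
# Level criterion: zero out the 40% of coefficients with the smallest magnitude.
level_mask = distiller.create_mask_level_criterion(t, desired_sparsity=0.4)
assert math.isclose(distiller.sparsity(level_mask), 0.4, abs_tol=0.001)

# Sensitivity criterion: zero out everything within one standard deviation of zero;
# for roughly normal data this masks about 68% of the coefficients.
sens_mask = distiller.create_mask_sensitivity_criterion(t, sensitivity=1)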
- - Michael Zhu and Suyog Gupta, "To prune, or not to prune: exploring the - efficacy of pruning for model compression", 2017 NIPS Workshop on Machine - Learning of Phones and other Consumer Devices, - (https://arxiv.org/pdf/1710.01878.pdf) +class AgpPruningRate(object): + """A pruning-rate scheduler per https://arxiv.org/pdf/1710.01878.pdf. """ - - def __init__(self, name, initial_sparsity, final_sparsity): - super().__init__(name) + def __init__(self, initial_sparsity, final_sparsity): self.initial_sparsity = initial_sparsity self.final_sparsity = final_sparsity assert final_sparsity > initial_sparsity - def compute_target_sparsity(self, meta): - starting_epoch = meta['starting_epoch'] - current_epoch = meta['current_epoch'] - ending_epoch = meta['ending_epoch'] - freq = meta['frequency'] + def step(self, current_epoch, starting_epoch, ending_epoch, freq): span = ((ending_epoch - starting_epoch - 1) // freq) * freq assert span > 0 @@ -54,8 +36,32 @@ class AutomatedGradualPrunerBase(_ParameterPruner): return target_sparsity + +class AutomatedGradualPrunerBase(object): + """Prune to an exact sparsity level specification using a prescribed sparsity + level schedule formula. + + An automated gradual pruning algorithm that prunes the smallest magnitude + weights to achieve a preset level of network sparsity. + + Michael Zhu and Suyog Gupta, "To prune, or not to prune: exploring the + efficacy of pruning for model compression", 2017 NIPS Workshop on Machine + Learning of Phones and other Consumer Devices, + (https://arxiv.org/pdf/1710.01878.pdf) + """ + + def __init__(self, name, rate_scheduler): + """ + Args: + rate_scheduler - schedules the pruning rate. You can plug in any + rate scheduler. We implemented AGP paper's rate control in AgpPruningRate. + """ + self.name = name + self.agp_pr = rate_scheduler + def set_param_mask(self, param, param_name, zeros_mask_dict, meta): - target_sparsity = self.compute_target_sparsity(meta) + target_sparsity = self.agp_pr.step(meta['current_epoch'], meta['starting_epoch'], + meta['ending_epoch'], meta['frequency']) self._set_param_mask_by_sparsity_target(param, param_name, zeros_mask_dict, target_sparsity, meta['model']) def _set_param_mask_by_sparsity_target(self, param, param_name, zeros_mask_dict, target_sparsity, model=None): @@ -69,8 +75,8 @@ class AutomatedGradualPruner(AutomatedGradualPrunerBase): An automated gradual pruning algorithm that prunes the smallest magnitude weights to achieve a preset level of network sparsity. """ - def __init__(self, name, initial_sparsity, final_sparsity, weights): - super().__init__(name, initial_sparsity, final_sparsity) + def __init__(self, name, initial_sparsity, final_sparsity, weights, rate_scheduler_factory=AgpPruningRate): + super().__init__(name, rate_scheduler=rate_scheduler_factory(initial_sparsity, final_sparsity)) self.params_names = weights assert self.params_names @@ -80,7 +86,7 @@ class AutomatedGradualPruner(AutomatedGradualPrunerBase): super().set_param_mask(param, param_name, zeros_mask_dict, meta) def _set_param_mask_by_sparsity_target(self, param, param_name, zeros_mask_dict, target_sparsity, model=None): - zeros_mask_dict[param_name].mask = SparsityLevelParameterPruner.create_mask(param, target_sparsity) + zeros_mask_dict[param_name].mask = distiller.create_mask_level_criterion(param, target_sparsity) class StructuredAGP(AutomatedGradualPrunerBase): @@ -89,8 +95,8 @@ class StructuredAGP(AutomatedGradualPrunerBase): This is a base-class for structured pruning with an AGP schedule. 
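For reference, the pruning rate returned by AgpPruningRate.step() follows the cubic schedule from the cited paper; a standalone sketch (the function name and example numbers are illustrative, not taken from the patch):

def agp_target_sparsity(initial_sparsity, final_sparsity,
                        current_epoch, starting_epoch, ending_epoch, freq):
    # Cubic ramp (Zhu & Gupta, 2017): sparsity rises quickly at first and then
    # levels off as it approaches final_sparsity.
    span = ((ending_epoch - starting_epoch - 1) // freq) * freq
    progress = (current_epoch - starting_epoch) / span
    return (final_sparsity +
            (initial_sparsity - final_sparsity) * (1.0 - progress) ** 3)

# Ramping 5% -> 90% sparsity over epochs 0..30 with freq=2 (span = 28):
# epoch 0 -> 0.05, epoch 14 -> ~0.794, epoch 28 -> 0.90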
It is an extension of the AGP concept introduced by Zhu et. al. """ - def __init__(self, name, initial_sparsity, final_sparsity): - super().__init__(name, initial_sparsity, final_sparsity) + def __init__(self, name, initial_sparsity, final_sparsity, rate_scheduler_factory=AgpPruningRate): + super().__init__(name, rate_scheduler=rate_scheduler_factory(initial_sparsity, final_sparsity)) self.pruner = None def _set_param_mask_by_sparsity_target(self, param, param_name, zeros_mask_dict, target_sparsity, model): diff --git a/distiller/pruning/baidu_rnn_pruner.py b/distiller/pruning/baidu_rnn_pruner.py index f195b8598cbc2614b9e4d7522e6672dc13bffd6c..8ec7fac1e24b55320246fc45a2dae44026bdfc1e 100755 --- a/distiller/pruning/baidu_rnn_pruner.py +++ b/distiller/pruning/baidu_rnn_pruner.py @@ -14,13 +14,10 @@ # limitations under the License. # -from .pruner import _ParameterPruner -from .level_pruner import SparsityLevelParameterPruner -from distiller.utils import * - import distiller -class BaiduRNNPruner(_ParameterPruner): + +class BaiduRNNPruner(object): """An element-wise pruner for RNN networks. Narang, Sharan & Diamos, Gregory & Sengupta, Shubho & Elsen, Erich. (2017). @@ -53,7 +50,7 @@ class BaiduRNNPruner(_ParameterPruner): def __init__(self, name, q, ramp_epoch_offset, ramp_slope_mult, weights): # Initialize the pruner, using a configuration that originates from the # schedule YAML file. - super(BaiduRNNPruner, self).__init__(name) + self.name = name self.params_names = weights assert self.params_names @@ -93,4 +90,4 @@ class BaiduRNNPruner(_ParameterPruner): self.ramp_slope * (current_epoch - ramp_epoch + 1)) / freq # After computing the threshold, we can create the mask - zeros_mask_dict[param_name].mask = distiller.threshold_mask(param.data, eps) + zeros_mask_dict[param_name].mask = distiller.create_mask_threshold_criterion(param, eps) diff --git a/distiller/pruning/level_pruner.py b/distiller/pruning/level_pruner.py index dc6449dea88c0acbda66d0cc9a494fa514407ff7..53a07cbf5c544a62547e6829bb468fbad0745512 100755 --- a/distiller/pruning/level_pruner.py +++ b/distiller/pruning/level_pruner.py @@ -14,11 +14,10 @@ # limitations under the License. # -import torch -from .pruner import _ParameterPruner import distiller -class SparsityLevelParameterPruner(_ParameterPruner): + +class SparsityLevelParameterPruner(object): """Prune to an exact pruning level specification. 
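SparsityLevelParameterPruner reads its per-parameter targets from the levels dict, with '*' acting as the default; a construction sketch (the layer names are hypothetical):

from distiller.pruning.level_pruner import SparsityLevelParameterPruner

levels = {'module.conv1.weight': 0.0,   # a level of 0 means the parameter is skipped
          'module.fc.weight': 0.7,      # prune the classifier harder
          '*': 0.5}                     # default for every other parameter
pruner = SparsityLevelParameterPruner(name='level_pruner', levels=levels)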
This pruner is very similar to MagnitudeParameterPruner, but instead of @@ -30,7 +29,7 @@ class SparsityLevelParameterPruner(_ParameterPruner): """ def __init__(self, name, levels, **kwargs): - super(SparsityLevelParameterPruner, self).__init__(name) + self.name = name self.levels = levels assert self.levels @@ -40,13 +39,4 @@ class SparsityLevelParameterPruner(_ParameterPruner): desired_sparsity = self.levels.get(param_name, self.levels.get('*', 0)) if desired_sparsity == 0: return - - zeros_mask_dict[param_name].mask = SparsityLevelParameterPruner.create_mask(param, desired_sparsity) - - @staticmethod - def create_mask(param, desired_sparsity): - with torch.no_grad(): - bottomk, _ = torch.topk(param.abs().view(-1), int(desired_sparsity * param.numel()), largest=False, sorted=True) - threshold = bottomk.data[-1] # This is the largest element from the group of elements that we prune away - mask = distiller.threshold_mask(param.data, threshold) - return mask \ No newline at end of file + zeros_mask_dict[param_name].mask = distiller.create_mask_level_criterion(param, desired_sparsity) diff --git a/distiller/pruning/magnitude_pruner.py b/distiller/pruning/magnitude_pruner.py index c46732edf76018885b0949f331e464b0c3e69311..c8d5cd692e7bfdfc499cb2845a6dcfa9f0b37a4f 100755 --- a/distiller/pruning/magnitude_pruner.py +++ b/distiller/pruning/magnitude_pruner.py @@ -14,12 +14,11 @@ # limitations under the License. # -from .pruner import _ParameterPruner import distiller import torch -class MagnitudeParameterPruner(_ParameterPruner): +class MagnitudeParameterPruner(object): """This is the most basic magnitude-based pruner. This pruner supports configuring a scalar threshold for each layer. @@ -43,7 +42,7 @@ class MagnitudeParameterPruner(_ParameterPruner): value is used. Currently it is mandatory to include a '*' key in 'thresholds'. """ - super(MagnitudeParameterPruner, self).__init__(name) + self.name = name assert thresholds is not None # Make sure there is a default threshold to use assert '*' in thresholds @@ -51,13 +50,8 @@ class MagnitudeParameterPruner(_ParameterPruner): def set_param_mask(self, param, param_name, zeros_mask_dict, meta): threshold = self.thresholds.get(param_name, self.thresholds['*']) - zeros_mask_dict[param_name].mask = self.create_mask(param.data, threshold) + zeros_mask_dict[param_name].mask = distiller.create_mask_threshold_criterion(param, threshold) - @staticmethod - def create_mask(param, threshold): - with torch.no_grad(): - mask = distiller.threshold_mask(param.data, threshold) - return mask diff --git a/distiller/pruning/pruner.py b/distiller/pruning/pruner.py deleted file mode 100755 index 7ba5f2f9f42fe4c2d4978cc2d895e8b2eda89f4d..0000000000000000000000000000000000000000 --- a/distiller/pruning/pruner.py +++ /dev/null @@ -1,50 +0,0 @@ -# -# Copyright (c) 2018 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import torch -import distiller - - -__all__ = ["mask_tensor"] - - -class _ParameterPruner(object): - """Base class for all pruners. 
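With the _ParameterPruner base class removed, a pruner is now any object that carries a name and exposes set_param_mask(param, param_name, zeros_mask_dict, meta); a hypothetical minimal pruner, for illustration only:

import distiller

class ConstantThresholdPruner(object):
    # Toy pruner showing the duck-typed interface; not part of the patch.
    def __init__(self, name, threshold):
        self.name = name
        self.threshold = threshold

    def set_param_mask(self, param, param_name, zeros_mask_dict, meta):
        zeros_mask_dict[param_name].mask = \
            distiller.create_mask_threshold_criterion(param, self.threshold)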
- - Arguments: - name: pruner name is used mainly for debugging. - """ - def __init__(self, name): - self.name = name - - def set_param_mask(self, param, param_name, zeros_mask_dict, meta): - raise NotImplementedError - - -def mask_tensor(tensor, mask): - """Mask the provided tensor - - Args: - tensor - the torch-tensor to mask - mask - binary coefficient-masking tensor. Has the same type and shape as `tensor` - Returns: - tensor = tensor * mask ;where * is the element-wise multiplication operator - """ - assert tensor.type() == mask.type() - assert tensor.shape == mask.shape - if mask: - tensor.data.mul_(mask) - return tensor diff --git a/distiller/pruning/ranked_structures_pruner.py b/distiller/pruning/ranked_structures_pruner.py index 629df73f4de52907b285eea4fa11e994a0bbe10b..ea11aa106cc140ecd101aa1985df3e940881a101 100755 --- a/distiller/pruning/ranked_structures_pruner.py +++ b/distiller/pruning/ranked_structures_pruner.py @@ -22,7 +22,6 @@ import torch from torch.nn import functional as f from random import uniform import distiller -from .pruner import _ParameterPruner __all__ = ["LpRankedStructureParameterPruner", @@ -39,12 +38,12 @@ __all__ = ["LpRankedStructureParameterPruner", msglogger = logging.getLogger(__name__) -class _RankedStructureParameterPruner(_ParameterPruner): +class _RankedStructureParameterPruner(object): """Base class for pruning structures by ranking them. """ def __init__(self, name, group_type, desired_sparsity, weights, group_dependency=None, group_size=1, rounding_fn=math.floor, noise=0.): - super().__init__(name) + self.name = name self.group_type = group_type self.group_dependency = group_dependency self.params_names = weights @@ -343,14 +342,15 @@ class ActivationRankedFilterPruner(_RankedStructureParameterPruner): # Use the parameter name to locate the module that has the activation sparsity statistics fq_name = param_name.replace(".conv", ".relu")[:-len(".weight")] + #distiller.assign_layer_fq_names(model) module = distiller.find_module_by_fq_name(model, fq_name) if module is None: raise ValueError("Could not find a layer named %s in the model." "\nMake sure to use assign_layer_fq_names()" % fq_name) if not hasattr(module, self.activation_rank_criterion): - raise ValueError("Could not find attribute \"{}\" in module {}" - "\nMake sure to use SummaryActivationStatsCollector(\"{}\")". - format(self.activation_rank_criterion, fq_name, self.activation_rank_criterion)) + raise ValueError("Could not find attribute \"%s\" in module %s" + "\nMake sure to use SummaryActivationStatsCollector(\"%s\")" % + (self.activation_rank_criterion, fq_name, self.activation_rank_criterion)) quality_criterion, std = getattr(module, self.activation_rank_criterion).value() num_filters = param.size(0) @@ -469,10 +469,7 @@ class BernoulliFilterPruner(_RankedStructureParameterPruner): class GradientRankedFilterPruner(_RankedStructureParameterPruner): - """Taylor expansion ranking. - - Pavlo Molchanov, Stephen Tyree, Tero Karras, Timo Aila, and Jan Kautz. Pruning Convolutional Neural - Networks for Resource Efficient Inference. ArXiv, abs/1611.06440, 2016. + """Rank the importance of weight filters using the product of their gradients and the filter value. 
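A self-contained sketch of this ranking criterion (the weights and the loss are toy stand-ins; in Distiller the pruner runs after a backward pass so that param.grad is populated):

import torch

weights = torch.randn(16, 8, 3, 3, requires_grad=True)   # 16 filters
loss = (weights ** 2).sum()                               # stand-in loss
loss.backward()                                           # populates weights.grad

with torch.no_grad():
    # Score each filter by the sum of (gradient * weight) over its elements.
    scores = (weights.grad * weights).view(weights.size(0), -1).sum(dim=1)
    num_filters_to_prune = 4
    # The lowest-scoring filters are the pruning candidates.
    prune_candidates = torch.argsort(scores)[:num_filters_to_prune]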
""" def __init__(self, name, group_type, desired_sparsity, weights, group_dependency=None): super().__init__(name, group_type, desired_sparsity, weights, group_dependency) @@ -495,22 +492,23 @@ class GradientRankedFilterPruner(_RankedStructureParameterPruner): msglogger.info("Too few filters - can't prune %.1f%% filters", 100*fraction_to_prune) return - # Compute the multiplication of the filters times the filter_gradienrs + # Compute the multiplication of the filters times the filter_gradients view_filters = param.view(param.size(0), -1) view_filter_grads = param.grad.view(param.size(0), -1) - weighted_gradients = view_filter_grads * view_filters - weighted_gradients = weighted_gradients.sum(dim=1) + with torch.no_grad(): + weighted_gradients = view_filter_grads * view_filters + weighted_gradients = weighted_gradients.sum(dim=1) - # Sort from high to low, and remove the bottom 'num_filters_to_prune' filters - filters_ordered_by_gradient = np.argsort(-weighted_gradients.detach().cpu().numpy())[:-num_filters_to_prune] - mask, binary_map = _mask_from_filter_order(filters_ordered_by_gradient, param, num_filters, binary_map) - zeros_mask_dict[param_name].mask = mask + # Sort from high to low, and remove the bottom 'num_filters_to_prune' filters + filters_ordered_by_gradient = np.argsort(-weighted_gradients.detach().cpu().numpy())[:-num_filters_to_prune] + mask, binary_map = _mask_from_filter_order(filters_ordered_by_gradient, param, num_filters, binary_map) + zeros_mask_dict[param_name].mask = mask - msglogger.info("GradientRankedFilterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)", - param_name, - distiller.sparsity_3D(zeros_mask_dict[param_name].mask), - fraction_to_prune, num_filters_to_prune, num_filters) - return binary_map + msglogger.info("GradientRankedFilterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)", + param_name, + distiller.sparsity_3D(zeros_mask_dict[param_name].mask), + fraction_to_prune, num_filters_to_prune, num_filters) + return binary_map from sklearn.linear_model import LinearRegression diff --git a/distiller/pruning/sensitivity_pruner.py b/distiller/pruning/sensitivity_pruner.py index 5fc4f3f4e9b363f9b2c984b062c8cafd18d7a870..5832399dfb1ceb97cbb01387d762fb40765a9be9 100755 --- a/distiller/pruning/sensitivity_pruner.py +++ b/distiller/pruning/sensitivity_pruner.py @@ -14,12 +14,11 @@ # limitations under the License. 
# -from .pruner import _ParameterPruner + import distiller -import torch -class SensitivityPruner(_ParameterPruner): +class SensitivityPruner(object): """Use algorithm from "Learning both Weights and Connections for Efficient Neural Networks" - https://arxiv.org/pdf/1506.02626v3.pdf @@ -41,7 +40,7 @@ class SensitivityPruner(_ParameterPruner): """ def __init__(self, name, sensitivities, **kwargs): - super(SensitivityPruner, self).__init__(name) + self.name = name self.sensitivities = sensitivities def set_param_mask(self, param, param_name, zeros_mask_dict, meta): @@ -53,13 +52,4 @@ class SensitivityPruner(_ParameterPruner): else: sensitivity = self.sensitivities[param_name] - zeros_mask_dict[param_name].mask = self.create_mask(param, sensitivity) - - @staticmethod - def create_mask(param, sensitivity): - if not hasattr(param, 'stddev'): - param.stddev = torch.std(param).item() - with torch.no_grad(): - threshold = param.stddev * sensitivity - mask = distiller.threshold_mask(param.data, threshold) - return mask + zeros_mask_dict[param_name].mask = distiller.create_mask_sensitivity_criterion(param, sensitivity) diff --git a/distiller/pruning/splicing_pruner.py b/distiller/pruning/splicing_pruner.py index 2a2edd3411fdd31b71f1bc95591ebc8563d8e70e..4dbb0b2e4a636848e2483b569a6aa76665121138 100755 --- a/distiller/pruning/splicing_pruner.py +++ b/distiller/pruning/splicing_pruner.py @@ -14,14 +14,12 @@ # limitations under the License. # - -from .pruner import _ParameterPruner import torch import logging msglogger = logging.getLogger() -class SplicingPruner(_ParameterPruner): +class SplicingPruner(object): """A pruner that both prunes and splices connections. The idea of pruning and splicing working in tandem was first proposed in the following @@ -37,7 +35,7 @@ class SplicingPruner(_ParameterPruner): def __init__(self, name, sensitivities, low_thresh_mult, hi_thresh_mult, sensitivity_multiplier=0): """Arguments: """ - super(SplicingPruner, self).__init__(name) + self.name = name self.sensitivities = sensitivities self.low_thresh_mult = low_thresh_mult self.hi_thresh_mult = hi_thresh_mult diff --git a/distiller/pruning/structure_pruner.py b/distiller/pruning/structure_pruner.py index efd98a36950ef887bd9816fac4d8756b93c6c899..0cbb42a3759e6c4111320ab629b175996c62f4b2 100755 --- a/distiller/pruning/structure_pruner.py +++ b/distiller/pruning/structure_pruner.py @@ -15,12 +15,11 @@ # import logging -from .pruner import _ParameterPruner import distiller msglogger = logging.getLogger() -class StructureParameterPruner(distiller.GroupThresholdMixin, _ParameterPruner): +class StructureParameterPruner(distiller.GroupThresholdMixin): """Prune parameter structures. Pruning criterion: average L1-norm. If the average L1-norm (absolute value) of the eleements @@ -30,7 +29,6 @@ class StructureParameterPruner(distiller.GroupThresholdMixin, _ParameterPruner): the structure size. 
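StructureParameterPruner builds its masks with the group-thresholding helpers that stay behind in distiller/thresholding.py; a small sketch mirroring the call exercised in tests/test_thresholding.py (the threshold value is arbitrary):

import torch
import distiller

p = torch.randn(16, 8, 3, 3)
# Threshold whole 2-D kernels as a unit: a kernel whose L1 norm falls below 6
# has all nine of its elements zeroed together.
mask, binary_map = distiller.group_threshold_mask(p, '2D', 6, 'L1')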
""" def __init__(self, name, model, reg_regims, threshold_criteria): - super(StructureParameterPruner, self).__init__(name) self.name = name self.model = model self.reg_regims = reg_regims diff --git a/distiller/regularization/l1_regularizer.py b/distiller/regularization/l1_regularizer.py index fdd6007ebc8562d5a0774c24118d0470fbfc4d43..a79c32d44cb7ebce408ed8a0264051ead6e1abb9 100755 --- a/distiller/regularization/l1_regularizer.py +++ b/distiller/regularization/l1_regularizer.py @@ -16,9 +16,7 @@ """L1-norm regularization""" -import torch -import math -import numpy as np + import distiller from .regularizer import _Regularizer, EPSILON @@ -39,7 +37,7 @@ class L1Regularizer(_Regularizer): return strength = self.reg_regims[param_name] - zeros_mask_dict[param_name].mask = distiller.threshold_mask(param.data, threshold=strength) + zeros_mask_dict[param_name].mask = distiller.pruning.create_mask_threshold_criterion(param, threshold=strength) zeros_mask_dict[param_name].is_regularization_mask = True @staticmethod diff --git a/distiller/scheduler.py b/distiller/scheduler.py index 45a25e379ddec155446fc0f63dd98bf503a5e627..82266f04a0aa5f0e9f2e1af671f2dda1a4ffec89 100755 --- a/distiller/scheduler.py +++ b/distiller/scheduler.py @@ -263,7 +263,7 @@ class ParameterMasker(object): def revert_weights(self, parameter): if not self.use_double_copies or self.unmasked_copy is None: - msglogger.debug('Parameter {0} does not maintain double copies'.format(self.param_name)) + # This parameter does not maintain double copies (this is OK) return parameter.data.copy_(self.unmasked_copy) self.unmasked_copy = None diff --git a/distiller/thresholding.py b/distiller/thresholding.py index cb38909667788baa6d5d360230a08b6031654ccc..590adc321814bfce85a3e742d036e51361fcdb1a 100755 --- a/distiller/thresholding.py +++ b/distiller/thresholding.py @@ -23,23 +23,10 @@ import numpy as np from distiller.norms import * -__all__ = ["threshold_mask", "GroupThresholdMixin", +__all__ = ["GroupThresholdMixin", "group_threshold_binary_map", "group_threshold_mask"] -def threshold_mask(param, threshold): - """Create a threshold mask for the provided parameter tensor using - magnitude thresholding. - - Arguments: - param: a parameter tensor which should be pruned. - threshold: the pruning threshold. - Returns: - prune_mask: The pruning mask. 
- """ - return torch.gt(torch.abs(param), threshold).type(param.type()) - - class GroupThresholdMixin(object): """A mixin class to add group thresholding capabilities diff --git a/tests/test_thresholding.py b/tests/test_thresholding.py index f231abda93eedd19f0f7c0b46f6dfdce8baf1ffe..e014efa2ba14054d19f2255759ae4bd59ee547ef 100755 --- a/tests/test_thresholding.py +++ b/tests/test_thresholding.py @@ -77,12 +77,32 @@ def test_threshold_mask(): # Change one element a[1, 4, 17, 31] = 0.2 # Create and apply a mask - mask = distiller.threshold_mask(a, threshold=0.3) + mask = distiller.create_mask_threshold_criterion(a, threshold=0.3) assert np.sum(distiller.to_np(mask)) == (distiller.volume(a) - 1) assert mask[1, 4, 17, 31] == 0 assert common.almost_equal(distiller.sparsity(mask), 1/distiller.volume(a)) +def test_level_mask(): + # Create a 4-D tensor of 1s + a = torch.rand(3, 64, 32, 32) + + # Create and apply a mask + mask = distiller.create_mask_level_criterion(a, desired_sparsity=0.3) + assert common.almost_equal(distiller.sparsity(mask), 0.3, max_diff=0.0001) + + +def test_sensitivity_mask(): + # Create a 4-D tensor of normally-distributed coefficients + a = torch.randn(3, 64, 32, 32) + + # Create and apply a mask + mask = distiller.create_mask_sensitivity_criterion(a, sensitivity=1) + # The width of 1-std on ~N(0,1) is about 68.27%. In other words: + # Pr(mean - std <= X <= mean + std) is about 68.27% + assert common.almost_equal(distiller.sparsity(mask), 0.6827, max_diff=0.005) + + def test_kernel_thresholding(): p = get_test_4d_tensor().cuda() mask, map = distiller.group_threshold_mask(p, '2D', 6, 'L1') @@ -240,6 +260,8 @@ def test_row_thresholding(): [ 0., 0., 0.], [ 1., 1., 1.], [ 1., 1., 1.]], device=mask.device)).all() + masked_tensor = distiller.mask_tensor(p, mask) + assert distiller.sparsity(masked_tensor) == distiller.sparsity(mask) return mask
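One behaviour the new tests do not exercise is the inplace flag of mask_tensor; a possible companion test (hypothetical, not part of the patch):

import torch
import distiller

def test_mask_tensor_not_inplace():
    a = torch.randn(4, 4)
    original = a.clone()
    mask = distiller.create_mask_threshold_criterion(a, threshold=0.5)
    masked = distiller.mask_tensor(a, mask, inplace=False)
    assert torch.equal(a, original)  # the input tensor is left untouched
    assert distiller.sparsity(masked) == distiller.sparsity(mask)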