From 05d5592e2275db5e5085d63882d89462c741c512 Mon Sep 17 00:00:00 2001
From: Neta Zmora <31280975+nzmora@users.noreply.github.com>
Date: Mon, 7 Oct 2019 00:32:35 +0300
Subject: [PATCH] Low-level pruning API refactor (#401)

Some refactoring of the low-level pruning API

Added distiller/norms.py - for calculating norms of various sub-tensors.

ranked_structures_pruner.py:
-Removed l1_magnitude and l2_magnitude; use distiller.norms.l1_norm and
distiller.norms.l2_norm instead (see the sketch below).
-Lots of refactoring.
-Replaced LpRankedStructureParameterPruner.ch_binary_map_to_mask with
distiller.thresholding.expand_binary_map.
-FMReconstructionChannelPruner.rank_and_prune_channels used the L2-norm by
default and now uses the L1-norm (i.e. magnitude_fn=l2_magnitude was replaced
with magnitude_fn=distiller.norms.l1_norm).
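
For illustration, a sketch of the new call pattern (informal; view_filters is
an assumed 2D view of a weights tensor, one filter per row):

    # old: l1_magnitude(view_filters, dim=1)
    # new:
    filter_mags = distiller.norms.l1_norm(view_filters, dim=1)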

thresholding.py:
-Delegated much of the work to the new norms.py.
-Removed support for 4D structures (entire convolution layers), since that has
not been maintained for a long time.  This may break some old scripts that
remove entire layers.
-Added expand_binary_map() as a public function so that others can use it (see
the sketch below).  It might need to move to a different file.
-Removed threshold_policy().
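
For illustration, an informal sketch of the new public thresholding API (names
as defined in this patch):

    binary_map = distiller.thresholding.group_threshold_binary_map(
        param, 'Filters', threshold, 'Mean_Abs')
    mask, _ = distiller.thresholding.expand_binary_map(param, 'Filters', binary_map)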

utils.py:
-Use distiller.norms functions for the sparsity statistics.
---
 distiller/norms.py                            | 326 ++++++++++++++++++
 distiller/pruning/automated_gradual_pruner.py |   4 +-
 distiller/pruning/level_pruner.py             |  12 +-
 distiller/pruning/magnitude_pruner.py         |  12 +-
 distiller/pruning/pruner.py                   |  24 +-
 distiller/pruning/ranked_structures_pruner.py | 193 ++++-------
 distiller/pruning/sensitivity_pruner.py       |  16 +-
 distiller/pruning/splicing_pruner.py          |  61 ++--
 distiller/pruning/structure_pruner.py         |   1 +
 distiller/thresholding.py                     | 181 +++++-----
 distiller/utils.py                            |  32 +-
 examples/greedy_pruning/greedy_pruning.ipynb  |  13 +-
 .../resnet20.network_surgery.yaml             |   4 +-
 tests/test_infra.py                           |   6 +-
 tests/test_pruning.py                         |  12 -
 tests/test_thresholding.py                    | 227 +++++++++++-
 16 files changed, 790 insertions(+), 334 deletions(-)
 create mode 100644 distiller/norms.py

diff --git a/distiller/norms.py b/distiller/norms.py
new file mode 100644
index 0000000..46cca86
--- /dev/null
+++ b/distiller/norms.py
@@ -0,0 +1,326 @@
+#
+# Copyright (c) 2019 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Norm functions.
+
+Norm functions map a tensor to a single real-valued scalar that represents
+the tensor's magnitude according to some definition.  p-norms (Lp norms)
+are the most common magnitude functions.
+
+Often we want to divide a large 4D/3D/2D tensor into groups of equal-sized
+sub-tensors and compute the norm of each sub-tensor.  The most common
+use-case is ranking sub-tensors according to some norm.
+
+For an interesting comparison of the characteristics of the L1-norm vs. the
+L2-norm, see: https://www.kaggle.com/residentmario/l1-norms-versus-l2-norms
+
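+Example (an illustrative sketch; the shapes are arbitrary):
+
+    import torch
+    import distiller.norms as norms
+
+    w = torch.randn(64, 32, 3, 3)            # 4D convolution weights
+    f_mags = norms.filters_lp_norm(w, p=1)   # L1 norm of each of the 64 filters
+    c_mags = norms.channels_lp_norm(w, p=2)  # L2 norm of each of the 32 channels
+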
+"""
+import torch
+import numpy as np
+from functools import partial
+from random import uniform
+
+
+__all__ = ["kernels_lp_norm", "channels_lp_norm", "filters_lp_norm",
+           "kernels_norm", "channels_norm", "filters_norm", "sub_matrix_norm",
+           "rows_lp_norm", "cols_lp_norm",
+           "rows_norm", "cols_norm",
+           "l1_norm", "l2_norm", "max_norm",
+           "rank_channels", "rank_filters", "rank_cols"]
+
+
+class NamedFunction:
+    def __init__(self, f, name):
+        self.f = f
+        self.name = name
+
+    def __call__(self, *args, **kwargs):
+        return self.f(*args, **kwargs)
+
+    def __str__(self):
+        return self.name
+
+
+""" Norm (magnitude) functions.
+
+These functions are named-functions because it's convenient
+to refer to them when logging.
+"""
+
+
+def _max_norm(t, dim=1):
+    """Maximum norm.
+
+    If t is a vector t = (t1, t2, ..., tn), then
+        max_norm(t) = max(|t1|, |t2|, ..., |tn|)
+    """
+    maxv, _ = t.abs().max(dim=dim)
+    return maxv
+
+
+l1_norm = NamedFunction(partial(torch.norm, p=1, dim=1), "L1")
+l2_norm = NamedFunction(partial(torch.norm, p=2, dim=1), "L2")
+max_norm = NamedFunction(_max_norm, "Max")
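+
+
+# Illustrative: because NamedFunction.__str__ returns the name, these functions
+# can be interpolated directly into log messages, e.g. "%s norm" % l1_norm
+# renders as "L1 norm".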
+
+
+def kernels_lp_norm(param, p=1, group_len=1, length_normalized=False):
+    """L1/L2 norm of kernel sub-tensors in a 4D tensor.
+
+    A kernel is an m x n matrix used for convolving a feature-map to extract features.
+
+    Args:
+        param: shape (num_filters(0), num_channels(1), kernel_height(2), kernel_width(3))
+        p: the exponent value in the norm formulation
+        group_len: the number of (adjacent) kernels in each group.  Norms are calculated
+           on the entire group.
+        length_normalized: if True then normalize the norm.  I.e.
+           norm = group_norm / num_elements_in_group
+
+    Returns:
+        1D tensor with norms of the groups
+    """
+    assert p in (1, 2)
+    norm_fn = l1_norm if p == 1 else l2_norm
+    return kernels_norm(param, norm_fn, group_len, length_normalized)
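+
+
+# Example (illustrative shapes): for w of shape (64, 32, 3, 3), kernels_lp_norm(w)
+# returns 64*32 = 2048 kernel norms; with group_len=2, each group covers two
+# adjacent kernels (group_size = 2*3*3 = 18 elements), yielding 1024 norms.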
+
+
+def kernels_norm(param, norm_fn, group_len=1, length_normalized=False):
+    """Compute some norm of 2D kernels of 4D parameter tensors.
+
+    Assumes 4D weights tensors.
+    Args:
+        param: shape (num_filters(0), num_channels(1), kernel_height(2), kernel_width(3))
+        norm_fn: a callable that computes a norm
+        group_len: the number of (adjacent) kernels in each group.  Norms are calculated
+           on the entire group.
+        length_normalized: if True then normalize the norm.  I.e.
+           norm = group_norm / num_elements_in_group
+
+    Returns:
+        1D tensor with lp-norms of the groups
+    """
+    assert param.dim() == 4, "param has invalid dimensions"
+    group_size = group_len * np.prod(param.shape[2:])
+    return generic_norm(param.view(-1, group_size), norm_fn, group_size, length_normalized, dim=1)
+
+
+def channels_lp_norm(param, p=1, group_len=1, length_normalized=False):
+    """L1/L2 norm of channels sub-tensors in a 4D tensor
+
+    Args:
+        param: shape (num_filters(0), num_channels(1), kernel_height(2), kernel_width(3))
+        p: the exponent value in the norm formulation
+        group_len: the number of (adjacent) channels in each group.  Norms are calculated
+           on the entire group.
+        length_normalized: if True then normalize the norm.  I.e.
+           norm = group_norm / num_elements_in_group
+
+    Returns:
+        1D tensor with norms of the groups
+    """
+    assert p in (1, 2)
+    norm_fn = l1_norm if p == 1 else l2_norm
+    return channels_norm(param, norm_fn, group_len, length_normalized)
+
+
+def channels_norm(param, norm_fn, group_len=1, length_normalized=False):
+    """Compute some norm of 3D channels of 4D parameter tensors.
+
+    Assumes 4D weights tensors.
+    Args:
+        param: shape (num_filters(0), num_channels(1), kernel_height(2), kernel_width(3))
+        norm_fn: a callable that computes a norm
+        group_len: the number of (adjacent) channels in each group.  Norms are calculated
+           on the entire group.
+        length_normalized: if True then normalize the norm.  I.e.
+           norm = group_norm / num_elements_in_group
+
+    Returns:
+        1D tensor with lp-norms of the groups
+    """
+    assert param.dim() == 4, "param has invalid dimensions"
+    param = param.transpose(0, 1).contiguous()
+    group_size = group_len * np.prod(param.shape[1:])
+    return generic_norm(param.view(-1, group_size), norm_fn, group_size, length_normalized, dim=1)
+
+
+def filters_lp_norm(param, p=1, group_len=1, length_normalized=False):
+    """L1/L2 norm of filters sub-tensors in a 4D tensor
+
+    Args:
+        param: shape (num_filters(0), num_channels(1), kernel_height(2), kernel_width(3))
+        p: the exponent value in the norm formulation
+        group_len: the number of (adjacent) filters in each group.  Norms are calculated
+           on the entire group.
+        length_normalized: if True then normalize the norm.  I.e.
+           norm = group_norm / num_elements_in_group
+
+    Returns:
+        1D tensor with norms of the groups
+    """
+    assert p in (1, 2)
+    norm_fn = l1_norm if p == 1 else l2_norm
+    return filters_norm(param, norm_fn, group_len, length_normalized)
+
+
+def filters_norm(param, norm_fn, group_len=1, length_normalized=False):
+    """Compute some norm of 3D filters of 4D parameter tensors.
+
+    Assumes 4D weights tensors.
+    Args:
+        param: shape (num_filters(0), num_channels(1), kernel_height(2), kernel_width(3))
+        norm_fn: a callable that computes a norm
+        group_len: the number of (adjacent) filters in each group.  Norms are calculated
+           on the entire group.
+        length_normalized: if True then normalize the norm.  I.e.
+           norm = group_norm / num_elements_in_group
+
+    Returns:
+        1D tensor with lp-norms of the groups
+    """
+    assert param.dim() == 4, "param has invalid dimensions"
+    group_size = group_len * np.prod(param.shape[1:])
+    return generic_norm(param.view(-1, group_size), norm_fn, group_size, length_normalized, dim=1)
+
+
+def sub_matrix_norm(param, norm_fn, group_len, length_normalized, dim):
+    """Compute some norm of rows/cols of 2D parameter tensors.
+
+    Assumes 2D weights tensors.
+    Args:
+        param: shape (num_rows(0), num_cols(1))
+        norm_fn: a callable that computes a norm
+        group_len: the number of (adjacent) rows/cols in each group.  Norms are
+           calculated on the entire group.
+        length_normalized: if True then normalize the norm.  I.e.
+           norm = group_norm / num_elements_in_group
+        dim: the dimension to reduce; dim=1 yields per-row norms, dim=0 yields
+           per-column norms
+
+    Returns:
+        1D tensor with lp-norms of the groups
+    """
+    assert param.dim() == 2, "param has invalid dimensions"
+    # A structure's length is its size along the reduced dimension (this matches
+    # the normalization performed by the removed threshold_policy()).
+    group_size = group_len * param.size(dim)
+    return generic_norm(param, norm_fn, group_size, length_normalized, dim)
+
+
+def rows_lp_norm(param, p=1, group_len=1, length_normalized=False):
+    assert p in (1, 2)
+    norm_fn = l1_norm if p == 1 else l2_norm
+    return sub_matrix_norm(param, norm_fn, group_len, length_normalized, dim=1)
+
+
+def rows_norm(param, norm_fn, group_len=1, length_normalized=False):
+    return sub_matrix_norm(param, norm_fn, group_len, length_normalized, dim=1)
+
+
+def cols_lp_norm(param, p=1, group_len=1, length_normalized=False):
+    assert p in (1, 2)
+    norm_fn = l1_norm if p == 1 else l2_norm
+    return sub_matrix_norm(param, norm_fn, group_len, length_normalized, dim=0)
+
+
+def cols_norm(param, norm_fn, group_len=1, length_normalized=False):
+    return sub_matrix_norm(param, norm_fn, group_len, length_normalized, dim=0)
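+
+
+# Note (illustrative): for a 2D param of shape (R, C), rows_norm reduces along
+# dim=1 and returns R per-row norms, while cols_norm reduces along dim=0 and
+# returns C per-column norms.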
+
+
+def generic_norm(param, norm_fn, group_size, length_normalized, dim):
+    with torch.no_grad():
+        if dim is not None:
+            norm = norm_fn(param, dim=dim)
+        else:
+            # The reduction dim may already be bound inside norm_fn (e.g. via functools.partial)
+            norm = norm_fn(param)
+        if length_normalized:
+            norm = norm / group_size
+        return norm
+
+
+"""
+Ranking functions
+"""
+
+
+def num_structs_to_prune(n_elems, group_len, fraction_to_prune, rounding_fn):
+    n_structs_to_prune = rounding_fn(fraction_to_prune * n_elems)
+    n_structs_to_prune = int(rounding_fn(n_structs_to_prune * 1. / group_len) * group_len)
+
+    # We can't allow removing all of the structures in a layer,
+    # except when fraction_to_prune explicitly instructs us to do so.
+    if n_structs_to_prune == n_elems and fraction_to_prune != 1.0:
+        n_structs_to_prune = n_elems - group_len
+    return n_structs_to_prune
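+
+
+# Worked example: n_elems=64, fraction_to_prune=0.3, group_len=4 and
+# rounding_fn=math.floor give floor(0.3*64)=19, then floor(19/4)*4=16, i.e. the
+# count is rounded down to a whole number of groups.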
+
+
+def e_greedy_normal_noise(mags, e):
+    """Epsilon-greedy noise
+
+    If e>0 then with probability(adding noise) = e, multiply mags by a normally-distributed
+    noise.
+    :param mags: input magnitude tensor
+    :param e: epsilon (real scalar s.t. 0 <= e <=1)
+    :return: noise-multiplier.
+    """
+    if e and uniform(0, 1) <= e:
+        # msglogger.info("%sRankedStructureParameterPruner - param: %s - randomly choosing channels",
+        #                threshold_type, param_name)
+        return torch.randn_like(mags)
+    return 1
+
+
+def k_smallest_elems(mags, k, noise):
+    """Partial sort of tensor `mags` returning a list of the k smallest elements in order.
+
+    :param mags: tensor of magnitudes to partially sort
+    :param k: partition point
+    :param noise: probability
+    :return:
+    """
+    mags *= e_greedy_normal_noise(mags, noise)
+    k_smallest_elements, _ = torch.topk(mags, k, largest=False, sorted=True)
+    return k_smallest_elements, mags
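+
+
+# Example: k_smallest_elems(torch.tensor([3., 1., 2.]), 2, noise=0) returns
+# (tensor([1., 2.]), tensor([3., 1., 2.])) - the two smallest magnitudes in
+# ascending order, plus the (possibly noise-multiplied) magnitudes.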
+
+
+def rank_channels(param, group_len, magnitude_fn, fraction_to_partition, rounding_fn, noise):
+    assert param.dim() == 4, "This ranking is only supported for 4D tensors"
+    n_channels = param.size(1)
+    n_ch_to_prune = num_structs_to_prune(n_channels, group_len, fraction_to_partition, rounding_fn)
+    if n_ch_to_prune == 0:
+        return None, None
+    mags = channels_norm(param, magnitude_fn, group_len, length_normalized=True)
+    return k_smallest_elems(mags, n_ch_to_prune, noise)
+
+
+def rank_filters(param, group_len, magnitude_fn, fraction_to_partition, rounding_fn, noise):
+    assert param.dim() == 4, "This ranking is only supported for 4D tensors"
+    n_filters = param.size(0)
+    n_filters_to_prune = num_structs_to_prune(n_filters, group_len, fraction_to_partition, rounding_fn)
+    if n_filters_to_prune == 0:
+        return None, None
+    mags = filters_norm(param, magnitude_fn, group_len, length_normalized=True)
+    return k_smallest_elems(mags, n_filters_to_prune, noise)
+
+
+def rank_cols(param, group_len, magnitude_fn, fraction_to_partition, rounding_fn, noise):
+    assert param.dim() == 2, "This ranking is only supported for 2D tensors"
+    COLS_DIM = 1
+    n_cols = param.size(COLS_DIM)
+    n_cols_to_prune = num_structs_to_prune(n_cols, group_len, fraction_to_partition, rounding_fn)
+    if n_cols_to_prune == 0:
+        return None, None
+    mags = cols_norm(param, magnitude_fn, group_len, length_normalized=True)
+    return k_smallest_elems(mags, n_cols_to_prune, noise)
diff --git a/distiller/pruning/automated_gradual_pruner.py b/distiller/pruning/automated_gradual_pruner.py
index 51a9abb..d906a12 100755
--- a/distiller/pruning/automated_gradual_pruner.py
+++ b/distiller/pruning/automated_gradual_pruner.py
@@ -79,7 +79,7 @@ class AutomatedGradualPruner(AutomatedGradualPrunerBase):
         super().set_param_mask(param, param_name, zeros_mask_dict, meta)
 
     def prune_to_target_sparsity(self, param, param_name, zeros_mask_dict, target_sparsity, model=None):
-        return SparsityLevelParameterPruner.prune_level(param, param_name, zeros_mask_dict, target_sparsity)
+        zeros_mask_dict[param_name].mask = SparsityLevelParameterPruner.create_mask(param, target_sparsity)
 
 
 class StructuredAGP(AutomatedGradualPrunerBase):
@@ -96,8 +96,6 @@ class StructuredAGP(AutomatedGradualPrunerBase):
         self.pruner.prune_to_target_sparsity(param, param_name, zeros_mask_dict, target_sparsity, model)
 
 
-# TODO: this class parameterization is cumbersome: the ranking functions (per structure)
-# should come from the YAML schedule
 class L1RankedStructureParameterPruner_AGP(StructuredAGP):
     def __init__(self, name, initial_sparsity, final_sparsity, group_type, weights, group_dependency=None, kwargs=None):
         super().__init__(name, initial_sparsity, final_sparsity)
diff --git a/distiller/pruning/level_pruner.py b/distiller/pruning/level_pruner.py
index 3779b11..dc6449d 100755
--- a/distiller/pruning/level_pruner.py
+++ b/distiller/pruning/level_pruner.py
@@ -41,10 +41,12 @@ class SparsityLevelParameterPruner(_ParameterPruner):
         if desired_sparsity == 0:
             return
 
-        self.prune_level(param, param_name, zeros_mask_dict, desired_sparsity)
+        zeros_mask_dict[param_name].mask = SparsityLevelParameterPruner.create_mask(param, desired_sparsity)
 
     @staticmethod
-    def prune_level(param, param_name, zeros_mask_dict, desired_sparsity):
-        bottomk, _ = torch.topk(param.abs().view(-1), int(desired_sparsity * param.numel()), largest=False, sorted=True)
-        threshold = bottomk.data[-1] # This is the largest element from the group of elements that we prune away
-        zeros_mask_dict[param_name].mask = distiller.threshold_mask(param.data, threshold)
+    def create_mask(param, desired_sparsity):
+        with torch.no_grad():
+            bottomk, _ = torch.topk(param.abs().view(-1), int(desired_sparsity * param.numel()), largest=False, sorted=True)
+            threshold = bottomk.data[-1]  # This is the largest element from the group of elements that we prune away
+            mask = distiller.threshold_mask(param.data, threshold)
+            return mask
\ No newline at end of file
diff --git a/distiller/pruning/magnitude_pruner.py b/distiller/pruning/magnitude_pruner.py
index fdbfa20..c46732e 100755
--- a/distiller/pruning/magnitude_pruner.py
+++ b/distiller/pruning/magnitude_pruner.py
@@ -16,6 +16,7 @@
 
 from .pruner import _ParameterPruner
 import distiller
+import torch
 
 
 class MagnitudeParameterPruner(_ParameterPruner):
@@ -50,4 +51,13 @@ class MagnitudeParameterPruner(_ParameterPruner):
 
     def set_param_mask(self, param, param_name, zeros_mask_dict, meta):
         threshold = self.thresholds.get(param_name, self.thresholds['*'])
-        zeros_mask_dict[param_name].mask = distiller.threshold_mask(param.data, threshold)
+        zeros_mask_dict[param_name].mask = self.create_mask(param.data, threshold)
+
+    @staticmethod
+    def create_mask(param, threshold):
+        with torch.no_grad():
+            mask = distiller.threshold_mask(param.data, threshold)
+            return mask
+
+
+
diff --git a/distiller/pruning/pruner.py b/distiller/pruning/pruner.py
index 1dec718..7ba5f2f 100755
--- a/distiller/pruning/pruner.py
+++ b/distiller/pruning/pruner.py
@@ -17,6 +17,10 @@
 import torch
 import distiller
 
+
+__all__ = ["mask_tensor"]
+
+
 class _ParameterPruner(object):
     """Base class for all pruners.
 
@@ -29,12 +33,18 @@ class _ParameterPruner(object):
     def set_param_mask(self, param, param_name, zeros_mask_dict, meta):
         raise NotImplementedError
 
-def threshold_model(model, threshold):
-    """Threshold an entire model using the provided threshold
 
-    This function prunes weights only (biases are left untouched).
+def mask_tensor(tensor, mask):
+    """Mask the provided tensor
+
+    Args:
+        tensor - the torch-tensor to mask
+        mask - binary coefficient-masking tensor.  Has the same type and shape as `tensor`
+    Returns:
+        tensor * mask, where * is the element-wise multiplication operator
     """
-    for name, p in model.named_parameters():
-       if 'weight' in name:
-           mask = distiller.threshold_mask(p.data, threshold)
-           p.data = p.data.mul_(mask)
+    assert tensor.type() == mask.type()
+    assert tensor.shape == mask.shape
+    if mask is not None:
+        tensor.data.mul_(mask)
+    return tensor
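+
+
+# Example (illustrative): zero, in-place, all coefficients whose magnitude is
+# at most 0.1:
+#   mask = (t.abs() > 0.1).type(t.type())
+#   mask_tensor(t, mask)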
diff --git a/distiller/pruning/ranked_structures_pruner.py b/distiller/pruning/ranked_structures_pruner.py
index 84c96b7..92336dc 100755
--- a/distiller/pruning/ranked_structures_pruner.py
+++ b/distiller/pruning/ranked_structures_pruner.py
@@ -98,10 +98,6 @@ class _RankedStructureParameterPruner(_ParameterPruner):
         raise NotImplementedError
 
 
-l1_magnitude = partial(torch.norm, p=1)
-l2_magnitude = partial(torch.norm, p=2)
-
-
 class LpRankedStructureParameterPruner(_RankedStructureParameterPruner):
     """Uses Lp-norm to rank and prune structures.
 
@@ -146,139 +142,58 @@ class LpRankedStructureParameterPruner(_RankedStructureParameterPruner):
                                       group_size=self.group_size)
         return binary_map
 
-    @staticmethod
-    def rank_channels(magnitude_fn, fraction_to_prune, param, group_size, rounding_fn, noise):
-        assert len(param.shape) == 4
-        num_filters, num_channels = param.size(0), param.size(1)
-        kernel_size = param.size(2) * param.size(3)
-
-        # First, reshape the weights tensor such that each channel (kernel) in the original
-        # tensor, is now a row in the 2D tensor.
-        view_2d = param.view(-1, kernel_size)
-        # Next, compute the sums of each kernel
-        kernel_mags = magnitude_fn(view_2d, dim=1)
-        # Now group by channels
-        k_sums_mat = kernel_mags.view(num_filters, num_channels).t()
-        channel_mags = k_sums_mat.mean(dim=1)
-
-        # TODO: the code below computes L1 in a simple manner - extend this for other norms
-        #channel_mags = torch.abs(param).sum((0, 2, 3))
-
-        # Round the number of channels to prune, (floor/ceil) to the nearest integer.
-        k = rounding_fn(fraction_to_prune * num_channels)
-        k = int(rounding_fn(k * 1. / group_size) * group_size)
-        
-        # We can't allow removing all of the channels! --
-        # Except when the fraction_to_prune is explicitly instructing us to do so.
-        if k == num_channels and fraction_to_prune != 1.0:
-            k = num_channels - group_size
-        if k == 0:
-            msglogger.info("Too few channels (%d)- can't prune %.1f%% channels",
-                            num_channels, 100*fraction_to_prune)
-            return None, None
-
-        if noise and uniform(0, 1) <= noise:
-            #msglogger.info("%sRankedStructureParameterPruner - param: %s - randomly choosing channels", 
-            #                threshold_type, param_name)
-            channel_mags *= torch.randn_like(channel_mags)
-
-        bottomk, _ = torch.topk(channel_mags, k, largest=False, sorted=True)
-        return bottomk, channel_mags
-
-    @staticmethod
-    def ch_binary_map_to_mask(binary_map, param):
-        num_filters, num_channels = param.size(0), param.size(1)
-        a = binary_map.expand(num_filters, num_channels)
-        c = a.unsqueeze(-1)
-        d = c.expand(num_filters, num_channels, param.size(2) * param.size(3)).contiguous()
-        return d.view(num_filters, num_channels, param.size(2), param.size(3))
-
     @staticmethod
     def rank_and_prune_channels(fraction_to_prune, param, param_name=None, zeros_mask_dict=None, 
-                                model=None, binary_map=None, magnitude_fn=l1_magnitude,
+                                model=None, binary_map=None, magnitude_fn=distiller.norms.l1_norm,
                                 noise=0.0, group_size=1, rounding_fn=math.floor):
-
         if binary_map is None:
-            bottomk_channels, channel_mags = LpRankedStructureParameterPruner.rank_channels(
-                magnitude_fn, fraction_to_prune, param, group_size, rounding_fn, noise)
+            bottomk_channels, channel_mags = distiller.norms.rank_channels(param, group_size, magnitude_fn,
+                                                                           fraction_to_prune, rounding_fn, noise)
             if bottomk_channels is None:
                 # Empty list means that fraction_to_prune is too low to prune anything
                 return
             threshold = bottomk_channels[-1]
             binary_map = channel_mags.gt(threshold).type(param.data.type())
 
-        threshold_type = 'L1' if magnitude_fn == l1_magnitude else 'L2'
         if zeros_mask_dict is not None:
-            zeros_mask_dict[param_name].mask = LpRankedStructureParameterPruner.ch_binary_map_to_mask(binary_map, param)
+            mask, _ = distiller.thresholding.expand_binary_map(param, 'Channels', binary_map)
+            zeros_mask_dict[param_name].mask = mask
             msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
-                           threshold_type, param_name,
+                           magnitude_fn, param_name,
                            distiller.sparsity_ch(zeros_mask_dict[param_name].mask),
                            fraction_to_prune, binary_map.sum().item(), param.size(1))
         return binary_map
 
     @staticmethod
-    def rank_and_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict, 
-                               model=None, binary_map=None, magnitude_fn=l1_magnitude, 
+    def rank_and_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict,
+                               model=None, binary_map=None, magnitude_fn=distiller.norms.l1_norm,
                                noise=0.0, group_size=1, rounding_fn=math.floor):
         assert param.dim() == 4 or param.dim() == 3, "This pruning is only supported for 3D and 4D weights"
-
-        threshold = None
-        threshold_type = 'L1' if magnitude_fn == l1_magnitude else 'L2'
-        num_filters = param.size(0)
-        num_filters_to_prune = rounding_fn(fraction_to_prune * num_filters)
-        num_filters_to_prune = int(rounding_fn(num_filters_to_prune * 1. / group_size) * group_size)
-        # We can't allow removing all of the filters! --
-        # Except when the fraction_to_prune is explicitly instructing us to do so.
-        if num_filters_to_prune == num_filters and fraction_to_prune != 1.0:
-            num_filters_to_prune = num_filters - group_size  # We can't allow removing all of the filters!
-
         if binary_map is None:
-            # First we rank the filters
-            view_filters = param.view(num_filters, -1)
-            filter_mags = magnitude_fn(view_filters, dim=1)
-
-            if noise and uniform(0, 1) <= noise:
-                msglogger.info("%sRankedStructureParameterPruner - param: %s - randomly choosing filters", 
-                               threshold_type, param_name)
-                filter_mags *= torch.randn_like(filter_mags)
-
-            if num_filters_to_prune == 0:
-                msglogger.info("Too few filters - can't prune %.1f%% filters", 100*fraction_to_prune)
+            bottomk_filters, filter_mags = distiller.norms.rank_filters(param, group_size, magnitude_fn,
+                                                                        fraction_to_prune, rounding_fn, noise)
+            if bottomk_filters is None:
+                # Empty list means that fraction_to_prune is too low to prune anything
+                msglogger.info("Too few filters - can't prune %.1f%% filters", 100 * fraction_to_prune)
                 return
-            bottomk, _ = torch.topk(filter_mags, num_filters_to_prune, largest=False, sorted=True)
-            threshold = bottomk[-1]
-            msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=(%d/%d)",
-                           threshold_type, param_name,
-                           num_filters_to_prune, filter_mags.size(0))
-
-        # Now apply a threshold
-        mask, binary_map = distiller.group_threshold_mask(param, 'Filters', threshold, threshold_type, binary_map)
+            threshold = bottomk_filters[-1]
+            binary_map = filter_mags.gt(threshold).type(param.data.type())
 
         if zeros_mask_dict is not None:
+            mask, _ = distiller.thresholding.expand_binary_map(param, 'Filters', binary_map)
             zeros_mask_dict[param_name].mask = mask
-        msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f",
-                       threshold_type, param_name,
-                       distiller.sparsity(mask),
-                       fraction_to_prune)
-        # param.data = torch.randn_like(param)
+            msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f",
+                           magnitude_fn, param_name,
+                           distiller.sparsity(mask),
+                           fraction_to_prune)
         return binary_map
 
-    @staticmethod
-    def rank_rows(magnitude_fn, fraction_to_prune, param): # , group_size, rounding_fn, noise):
-        assert param.dim() == 2, "This pruning is only supported for 2D weights"
-        ROWS_DIM = 0
-        cols_mags = magnitude_fn(param, dim=ROWS_DIM)
-        num_cols_to_prune = int(fraction_to_prune * cols_mags.size(ROWS_DIM))
-        if num_cols_to_prune == 0:
-            msglogger.info("Too few filters - can't prune %.1f%% rows", 100*fraction_to_prune)
-            return None, None
-        bottomk_cols, _ = torch.topk(cols_mags, num_cols_to_prune, largest=False, sorted=True)
-        return bottomk_cols, cols_mags
-
     @staticmethod
     def rank_and_prune_rows(fraction_to_prune, param, param_name,
                             zeros_mask_dict, model=None, binary_map=None,
-                            magnitude_fn=l1_magnitude, group_size=1):
+                            magnitude_fn=distiller.norms.l1_norm, group_size=1):
         """Prune the rows of a matrix, based on ranked L1-norms of the matrix rows.
 
         PyTorch stores the weights matrices in a transposed format.  I.e. before performing GEMM, a matrix is
@@ -293,24 +208,29 @@ class LpRankedStructureParameterPruner(_RankedStructureParameterPruner):
         that computing mean L1-norm of columns is also not optimal, because consecutive column elements are far
         away from each other in memory, and this means poor use of caches and system memory.
         """
-        bottomk_cols, cols_mags = LpRankedStructureParameterPruner.rank_rows(magnitude_fn, fraction_to_prune, param)
-        THRESHOLD_DIM = 'Cols'
-        threshold = bottomk_cols[-1]
-        threshold_type = 'L1' if magnitude_fn == l1_magnitude else 'L2'
-        zeros_mask_dict[param_name].mask, binary_map = distiller.group_threshold_mask(param, THRESHOLD_DIM,
-                                                                                      threshold, threshold_type)
-        ROWS_DIM = 0
-        num_cols_to_prune = int(fraction_to_prune * cols_mags.size(ROWS_DIM))
-        msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
-                       threshold_type, param_name,
-                       distiller.sparsity(zeros_mask_dict[param_name].mask),
-                       fraction_to_prune, num_cols_to_prune, cols_mags.size(ROWS_DIM))
+        if binary_map is None:
+            bottomk_cols, cols_mags = distiller.norms.rank_cols(param, group_size, magnitude_fn, fraction_to_prune,
+                                                                rounding_fn=math.floor, noise=None)
+            if bottomk_cols is None:
+                # Empty list means that fraction_to_prune is too low to prune anything
+                msglogger.info("Too few cols - can't prune %.1f%% cols", 100 * fraction_to_prune)
+                return
+            threshold = bottomk_cols[-1]
+            binary_map = cols_mags.gt(threshold).type(param.data.type())
+
+        if zeros_mask_dict is not None:
+            mask, _ = distiller.thresholding.expand_binary_map(param, 'Cols', binary_map)
+            zeros_mask_dict[param_name].mask = mask
+            msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f",
+                           magnitude_fn, param_name,
+                           distiller.sparsity(mask),
+                           fraction_to_prune)
         return binary_map
 
     @staticmethod
     def rank_and_prune_blocks(fraction_to_prune, param, param_name=None, zeros_mask_dict=None,
                               model=None, binary_map=None, block_shape=None,
-                              magnitude_fn=l1_magnitude, group_size=1):
+                              magnitude_fn=distiller.norms.l1_norm, group_size=1):
         """Block-wise pruning for 4D tensors.
 
         The block shape is specified using a tuple: [block_repetitions, block_depth, block_height, block_width].
@@ -380,11 +300,10 @@ class LpRankedStructureParameterPruner(_RankedStructureParameterPruner):
             threshold = bottomk_blocks[-1]
             binary_map = block_mags.gt(threshold).type(param.data.type())
 
-        threshold_type = 'L1' if magnitude_fn == l1_magnitude else 'L2'
         if zeros_mask_dict is not None:
             zeros_mask_dict[param_name].mask = binary_map_to_mask(binary_map, param)
             msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
-                           threshold_type, param_name,
+                           magnitude_fn, param_name,
                            distiller.sparsity_blocks(zeros_mask_dict[param_name].mask, block_shape=block_shape),
                            fraction_to_prune, binary_map.sum().item(), num_super_blocks)
         return binary_map
@@ -399,7 +318,7 @@ class L1RankedStructureParameterPruner(LpRankedStructureParameterPruner):
                  group_dependency=None, kwargs=None, noise=0.0,
                  group_size=1, rounding_fn=math.floor):
         super().__init__(name, group_type, desired_sparsity, weights, group_dependency, 
-                         kwargs, magnitude_fn=l1_magnitude, noise=noise,
+                         kwargs, magnitude_fn=distiller.norms.l1_norm, noise=noise,
                          group_size=group_size, rounding_fn=rounding_fn)
 
 
@@ -412,7 +331,7 @@ class L2RankedStructureParameterPruner(LpRankedStructureParameterPruner):
                  group_dependency=None, kwargs=None, noise=0.0,
                  group_size=1, rounding_fn=math.floor):
         super().__init__(name, group_type, desired_sparsity, weights, group_dependency, 
-                         kwargs, magnitude_fn=l2_magnitude, noise=noise,
+                         kwargs, magnitude_fn=distiller.norms.l2_norm, noise=noise,
                          group_size=group_size, rounding_fn=rounding_fn)
 
 
@@ -437,9 +356,9 @@ def _mask_from_filter_order(filters_ordered_by_criterion, param, num_filters, bi
     if binary_map is None:
         binary_map = torch.zeros(num_filters).cuda()
         binary_map[filters_ordered_by_criterion] = 1
-    binary_map = binary_map.detach()
-    expanded = binary_map.expand(param.size(1) * param.size(2) * param.size(3), param.size(0)).t().contiguous()
-    return expanded.view(param.shape), binary_map
+
+    return distiller.thresholding.expand_binary_map(param, "Filters", binary_map)
 
 
 class ActivationRankedFilterPruner(_RankedStructureParameterPruner):
@@ -591,7 +510,10 @@ class BernoulliFilterPruner(_RankedStructureParameterPruner):
 
 
 class GradientRankedFilterPruner(_RankedStructureParameterPruner):
-    """
+    """Taylor expansion ranking.
+
+    Pavlo Molchanov, Stephen Tyree, Tero Karras, Timo Aila, and Jan Kautz. Pruning Convolutional Neural
+    Networks for Resource Efficient Inference. ArXiv, abs/1611.06440, 2016.
     """
     def __init__(self, name, group_type, desired_sparsity, weights, group_dependency=None):
         super().__init__(name, group_type, desired_sparsity, weights, group_dependency)
@@ -736,7 +658,7 @@ class FMReconstructionChannelPruner(_RankedStructureParameterPruner):
         intermediate_fms['input_fms'][module.distiller_name].append(X)
 
     def __init__(self, name, group_type, desired_sparsity, weights,
-                 group_dependency=None, kwargs=None, magnitude_fn=l1_magnitude, 
+                 group_dependency=None, kwargs=None, magnitude_fn=distiller.norms.l1_norm,
                  group_size=1, rounding_fn=math.floor, ranking_noise=0.):
         super().__init__(name, group_type, desired_sparsity, weights, group_dependency,
                          group_size=group_size, rounding_fn=rounding_fn, noise=ranking_noise)
@@ -761,7 +683,7 @@ class FMReconstructionChannelPruner(_RankedStructureParameterPruner):
     @staticmethod
     def rank_and_prune_channels(fraction_to_prune, param, param_name=None,
                                 zeros_mask_dict=None, model=None, binary_map=None, 
-                                magnitude_fn=l2_magnitude, group_size=1, rounding_fn=math.floor,
+                                magnitude_fn=distiller.norms.l1_norm, group_size=1, rounding_fn=math.floor,
                                 noise=0):
         assert binary_map is None
         if binary_map is None:
@@ -771,7 +693,7 @@ class FMReconstructionChannelPruner(_RankedStructureParameterPruner):
                     magnitude_fn, fraction_to_prune, param, group_size, rounding_fn, noise)
 
             else:
-                bottomk_channels, channel_mags = LpRankedStructureParameterPruner.rank_rows(
-                     magnitude_fn, fraction_to_prune, param)
+                bottomk_channels, channel_mags = distiller.norms.rank_cols(param, group_size, magnitude_fn,
+                                                                           fraction_to_prune, rounding_fn, noise)
 
             # Todo: this little piece of code can be refactored
@@ -812,7 +734,7 @@ class FMReconstructionChannelPruner(_RankedStructureParameterPruner):
                 # X is (batch, ck^2, num_pts)
                 # we want:   (batch, c, k^2, num_pts)
                 X = X.view(X.size(0), -1, np.prod(conv.kernel_size), X.size(2))
-                X = X[:,binary_map,:,:]
+                X = X[:, binary_map, :, :]
                 X = X.view(X.size(0), -1, X.size(3))
                 X = X.transpose(1, 2)
                 X = X.contiguous().view(-1, X.size(2))
@@ -828,14 +750,15 @@ class FMReconstructionChannelPruner(_RankedStructureParameterPruner):
 
                 # Copy the weights that we learned from minimizing the feature-maps least squares error,
                 # to our actual weights tensor.
-                param.detach()[:,indices,:,:] = new_w.type(param.type())
+                param.detach()[:, indices, :, :] = new_w.type(param.type())
             else:
                 param.detach()[:, indices] = new_w.type(param.type())
 
         if zeros_mask_dict is not None:
             binary_map = binary_map.type(param.type())
             if op_type == 'conv':
-                zeros_mask_dict[param_name].mask = LpRankedStructureParameterPruner.ch_binary_map_to_mask(binary_map, param)
+                zeros_mask_dict[param_name].mask, _ = distiller.thresholding.expand_binary_map(param,
+                                                                                               'Channels', binary_map)
                 msglogger.info("FMReconstructionChannelPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
                                param_name,
                                distiller.sparsity_ch(zeros_mask_dict[param_name].mask),
diff --git a/distiller/pruning/sensitivity_pruner.py b/distiller/pruning/sensitivity_pruner.py
index 624b84d..5fc4f3f 100755
--- a/distiller/pruning/sensitivity_pruner.py
+++ b/distiller/pruning/sensitivity_pruner.py
@@ -18,6 +18,7 @@ from .pruner import _ParameterPruner
 import distiller
 import torch
 
+
 class SensitivityPruner(_ParameterPruner):
     """Use algorithm from "Learning both Weights and Connections for Efficient
     Neural Networks" - https://arxiv.org/pdf/1506.02626v3.pdf
@@ -44,9 +45,6 @@ class SensitivityPruner(_ParameterPruner):
         self.sensitivities = sensitivities
 
     def set_param_mask(self, param, param_name, zeros_mask_dict, meta):
-        if not hasattr(param, 'stddev'):
-            param.stddev = torch.std(param).item()
-
         if param_name not in self.sensitivities:
             if '*' not in self.sensitivities:
                 return
@@ -55,7 +53,13 @@ class SensitivityPruner(_ParameterPruner):
         else:
             sensitivity = self.sensitivities[param_name]
 
-        threshold = param.stddev * sensitivity
+        zeros_mask_dict[param_name].mask = self.create_mask(param, sensitivity)
 
-        # After computing the threshold, we can create the mask
-        zeros_mask_dict[param_name].mask = distiller.threshold_mask(param.data, threshold)
+    @staticmethod
+    def create_mask(param, sensitivity):
+        if not hasattr(param, 'stddev'):
+            param.stddev = torch.std(param).item()
+        with torch.no_grad():
+            threshold = param.stddev * sensitivity
+            mask = distiller.threshold_mask(param.data, threshold)
+            return mask
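+
+
+# Example: for a parameter tensor with stddev 0.02 and sensitivity 2.0, the
+# threshold is 0.04, and every weight whose magnitude is at most 0.04 is masked.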
diff --git a/distiller/pruning/splicing_pruner.py b/distiller/pruning/splicing_pruner.py
index 25d57db..2a2edd3 100755
--- a/distiller/pruning/splicing_pruner.py
+++ b/distiller/pruning/splicing_pruner.py
@@ -52,39 +52,46 @@ class SplicingPruner(_ParameterPruner):
         else:
             sensitivity = self.sensitivities[param_name]
 
-        if not hasattr(param, '_std'):
-            # Compute the mean and standard-deviation once, and cache them.
-            param._std = torch.std(param.abs()).item()
-            param._mean = torch.mean(param.abs()).item()
-
         if self.sensitivity_multiplier > 0:
             # Linearly growing sensitivity - for now this is hard-coded
             starting_epoch = meta['starting_epoch']
             current_epoch = meta['current_epoch']
             sensitivity *= (current_epoch - starting_epoch) * self.sensitivity_multiplier + 1
 
-        threshold_low = (param._mean + param._std * sensitivity) * self.low_thresh_mult
-        threshold_hi = (param._mean + param._std * sensitivity) * self.hi_thresh_mult
-
         if zeros_mask_dict[param_name].mask is None:
             zeros_mask_dict[param_name].mask = torch.ones_like(param)
+        zeros_mask_dict[param_name].mask = self.create_mask(param,
+                                                            zeros_mask_dict[param_name].mask,
+                                                            sensitivity,
+                                                            self.low_thresh_mult,
+                                                            self.hi_thresh_mult)
+
+    @staticmethod
+    def create_mask(param, current_mask, sensitivity, low_thresh_mult, hi_thresh_mult):
+        with torch.no_grad():
+            if not hasattr(param, '_std'):
+                # Compute the mean and standard-deviation once, and cache them.
+                param._std = torch.std(param.abs()).item()
+                param._mean = torch.mean(param.abs()).item()
+
+            threshold_low = (param._mean + param._std * sensitivity) * low_thresh_mult
+            threshold_hi = (param._mean + param._std * sensitivity) * hi_thresh_mult
+
+            # This code performs the code in equation (3) of the "Dynamic Network Surgery" paper:
+            #
+            #           0    if a  > |W|
+            # h(W) =    mask if a <= |W| < b
+            #           1    if b <= |W|
+            #
+            # h(W) is the so-called "network surgery function".
+            # mask is the mask used in the previous iteration.
+            # a and b are the low and high thresholds, respectively.
+            # We followed the example implementation from Yiwen Guo in Caffe, and used the
+            # weight tensor's starting mean and std.
+            # This is very similar to the initialization performed by distiller.SensitivityPruner.
 
-        # This code performs the code in equation (3) of the "Dynamic Network Surgery" paper:
-        #
-        #           0    if a  > |W|
-        # h(W) =    mask if a <= |W| < b
-        #           1    if b <= |W|
-        #
-        # h(W) is the so-called "network surgery function".
-        # mask is the mask used in the previous iteration.
-        # a and b are the low and high thresholds, respectively.
-        # We followed the example implementation from Yiwen Guo in Caffe, and used the
-        # weight tensor's starting mean and std.
-        # This is very similar to the initialization performed by distiller.SensitivityPruner.
-    
-        mask = zeros_mask_dict[param_name].mask
-        zeros, ones = torch.tensor([0]).type(mask.type()), torch.tensor([1]).type(mask.type())
-        weights_abs = param.abs()
-        new_mask = torch.where(threshold_low > weights_abs, zeros, mask)
-        new_mask = torch.where(threshold_hi <= weights_abs, ones, new_mask)
-        zeros_mask_dict[param_name].mask = new_mask
+            zeros, ones = torch.zeros_like(current_mask), torch.ones_like(current_mask)
+            weights_abs = param.abs()
+            new_mask = torch.where(threshold_low > weights_abs, zeros, current_mask)
+            new_mask = torch.where(threshold_hi <= weights_abs, ones, new_mask)
+            return new_mask
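+
+
+# Worked example (illustrative numbers): with mean=0.1, std=0.02, sensitivity=1,
+# low_thresh_mult=0.9 and hi_thresh_mult=1.1, the thresholds are a=0.108 and
+# b=0.132; weights with |W| < 0.108 are masked to zero, weights with
+# |W| >= 0.132 are spliced back in, and the rest keep their previous mask value.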
diff --git a/distiller/pruning/structure_pruner.py b/distiller/pruning/structure_pruner.py
index fe6803a..efd98a3 100755
--- a/distiller/pruning/structure_pruner.py
+++ b/distiller/pruning/structure_pruner.py
@@ -19,6 +19,7 @@ from .pruner import _ParameterPruner
 import distiller
 msglogger = logging.getLogger()
 
+
 class StructureParameterPruner(distiller.GroupThresholdMixin, _ParameterPruner):
     """Prune parameter structures.
 
diff --git a/distiller/thresholding.py b/distiller/thresholding.py
index 55be4a1..2f93662 100755
--- a/distiller/thresholding.py
+++ b/distiller/thresholding.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2018 Intel Corporation
+# Copyright (c) 2019 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -20,19 +20,24 @@ The code below supports fine-grained tensor thresholding and group-wise threshol
 """
 import torch
 import numpy as np
+from distiller.norms import *
 
 
-def threshold_mask(weights, threshold):
+__all__ = ["threshold_mask", "GroupThresholdMixin",
+           "group_threshold_binary_map", "group_threshold_mask"]
+
+
+def threshold_mask(param, threshold):
     """Create a threshold mask for the provided parameter tensor using
     magnitude thresholding.
 
     Arguments:
-        weights: a parameter tensor which should be pruned.
+        param: a parameter tensor which should be pruned.
         threshold: the pruning threshold.
     Returns:
         prune_mask: The pruning mask.
     """
-    return torch.gt(torch.abs(weights), threshold).type(weights.type())
+    return torch.gt(torch.abs(param), threshold).type(param.type())
 
 
 class GroupThresholdMixin(object):
@@ -48,74 +53,71 @@ class GroupThresholdMixin(object):
 
 
 def group_threshold_binary_map(param, group_type, threshold, threshold_criteria):
-    """Return a threshold mask for the provided parameter and group type.
+    """Return a threshold binary map for the provided parameter and group type.
+
+    This function thresholds a parameter tensor, using the provided threshold.
+    Thresholding is performed by breaking the parameter tensor into groups as
+    specified by group_type, computing the norm of each group instance using
+    threshold_criteria, and then thresholding that norm.  The result is called
+    binary_map and contains 1s where the group norm was larger than the threshold
+    value, zero otherwise.
 
     Args:
         param: The parameter to mask
         group_type: The elements grouping type (structure).
-          One of:2D, 3D, 4D, Channels, Row, Cols
+          One of: 2D, 3D, Filters, Channels, Rows, Cols
         threshold: The threshold
         threshold_criteria: The thresholding criteria.
-          'Mean_Abs' thresholds the entire element group using the mean of the
-          absolute values of the tensor elements.
-          'Max' thresholds the entire group using the magnitude of the largest
-          element in the group.
+          ('Mean_Abs', 'Mean_L1', 'L1') - thresholds the entire element group
+            using the (mean of the) L1-norm of the group's elements.
+          ('Mean_L2', 'L2') - thresholds the entire element group using the
+            (mean of the) L2-norm of the group's elements.
+          'Max' - thresholds the entire group using the magnitude of the
+            largest element in the group.
+
+    Returns:
+        binary_map
     """
+    if isinstance(threshold, torch.Tensor):
+        threshold = threshold.item()
+    length_normalized = 'Mean' in threshold_criteria
+    if threshold_criteria in ('Mean_Abs', 'Mean_L1', 'L1'):
+        norm_fn = l1_norm
+    elif threshold_criteria in ('Mean_L2', 'L2'):
+        norm_fn = l2_norm
+    elif threshold_criteria == 'Max':
+        norm_fn = max_norm
+    else:
+        raise ValueError("Illegal threshold_criteria %s", threshold_criteria)
+
     if group_type == '2D':
         assert param.dim() == 4, "This thresholding is only supported for 4D weights"
-        view_2d = param.view(-1, param.size(2) * param.size(3))
-        # 1. Determine if the kernel "value" is below the threshold, by creating a 1D
-        #    thresholds tensor with length = #IFMs * # OFMs
-        thresholds = torch.Tensor([threshold] * param.size(0) * param.size(1)).to(param.device)
-        # 2. Create a binary thresholds mask, where we use the mean of the abs values of the
-        #    elements in each channel as the threshold filter.
-        # 3. Apply the threshold filter
-        binary_map = threshold_policy(view_2d, thresholds, threshold_criteria)
-        return binary_map
+        thresholds = param.new_full((param.size(0) * param.size(1),), threshold)
+        norms = kernels_norm(param, norm_fn, length_normalized=length_normalized)
 
     elif group_type == 'Rows':
         assert param.dim() == 2, "This regularization is only supported for 2D weights"
-        thresholds = torch.Tensor([threshold] * param.size(0)).to(param.device)
-        binary_map = threshold_policy(param, thresholds, threshold_criteria)
-        return binary_map
+        thresholds = param.new_full((param.size(0),), threshold)
+        norms = sub_matrix_norm(param, norm_fn, group_len=1, length_normalized=length_normalized, dim=1)
 
     elif group_type == 'Cols':
         assert param.dim() == 2, "This regularization is only supported for 2D weights"
-        thresholds = torch.Tensor([threshold] * param.size(1)).to(param.device)
-        binary_map = threshold_policy(param, thresholds, threshold_criteria, dim=0)
-        return binary_map
+        thresholds = param.new_full((param.size(1),), threshold)
+        norms = sub_matrix_norm(param, norm_fn, group_len=1, length_normalized=length_normalized, dim=0)
 
     elif group_type == '3D' or group_type == 'Filters':
         assert param.dim() == 4 or param.dim() == 3, "This pruning is only supported for 3D and 4D weights"
-        view_filters = param.view(param.size(0), -1)
-        thresholds = torch.Tensor([threshold] * param.size(0)).to(param.device)
-        binary_map = threshold_policy(view_filters, thresholds, threshold_criteria)
-        return binary_map
-
-    elif group_type == '4D':
-        assert param.dim() == 4, "This thresholding is only supported for 4D weights"
-        if threshold_criteria == 'Mean_Abs':
-            if param.data.abs().mean() > threshold:
-                return None
-            return torch.zeros_like(param.data)
-        elif threshold_criteria == 'Max':
-            if param.data.abs().max() > threshold:
-                return None
-            return torch.zeros_like(param.data)
-        raise ValueError("Invalid threshold_criteria {}".format(threshold_criteria))
+        n_filters = param.size(0)
+        thresholds = param.new_full((n_filters,), threshold)
+        norms = filters_norm(param, norm_fn, length_normalized=length_normalized)
 
     elif group_type == 'Channels':
         assert param.dim() == 4, "This thresholding is only supported for 4D weights"
-        num_filters = param.size(0)
-        num_kernels_per_filter = param.size(1)
+        n_channels = param.size(1)
+        thresholds = param.new_full((n_channels,), threshold)
+        norms = channels_norm(param, norm_fn, length_normalized=length_normalized)
 
-        view_2d = param.view(-1, param.size(2) * param.size(3))
-        # Next, compute the sum of the squares (of the elements in each row/kernel)
-        kernel_means = view_2d.abs().mean(dim=1)
-        k_means_mat = kernel_means.view(num_filters, num_kernels_per_filter).t()
-        thresholds = torch.Tensor([threshold] * num_kernels_per_filter).to(param.device)
-        binary_map = k_means_mat.data.mean(dim=1).gt(thresholds).type(param.type())
-        return binary_map
+    else:
+        raise ValueError("Invalid group_type {}".format(group_type))
+
+    binary_map = norms.gt(thresholds).type(param.type())
+    return binary_map
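+
+
+# Example (illustrative): compute a per-filter keep/prune map using the mean
+# of the absolute values of each filter's elements, then expand it to a mask:
+#   binary_map = group_threshold_binary_map(param, 'Filters', 0.01, 'Mean_Abs')
+#   mask, _ = expand_binary_map(param, 'Filters', binary_map)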
 
 
 def group_threshold_mask(param, group_type, threshold, threshold_criteria, binary_map=None):
@@ -124,76 +126,57 @@ def group_threshold_mask(param, group_type, threshold, threshold_criteria, binar
     Args:
         param: The parameter to mask
         group_type: The elements grouping type (structure).
-          One of:2D, 3D, 4D, Channels, Row, Cols
+          One of: 2D, 3D, Filters, Channels, Rows, Cols
         threshold: The threshold
         threshold_criteria: The thresholding criteria.
           'Mean_Abs' thresholds the entire element group using the mean of the
           absolute values of the tensor elements.
           'Max' thresholds the entire group using the magnitude of the largest
           element in the group.
+        binary_map: an optional pre-computed binary map (as returned by
+          group_threshold_binary_map).  If None, it is computed here from the
+          threshold.
+
+    Returns:
+        (mask, binary_map)
     """
-    if group_type == '2D':
-        if binary_map is None:
-            binary_map = group_threshold_binary_map(param, group_type, threshold, threshold_criteria)
+    assert group_type in ('2D', 'Rows', 'Cols', '3D', 'Filters', 'Channels')
+    if binary_map is None:
+        binary_map = group_threshold_binary_map(param, group_type, threshold, threshold_criteria)
+
+    # Now let's expand back up to a 4D mask
+    return expand_binary_map(param, group_type, binary_map)
+
+
+def expand_binary_map(param, group_type, binary_map):
+    """Expands a binary_map to the shape of the provided parameter.
 
-        # 3. Finally, expand the thresholds and view as a 4D tensor
+    Args:
+        param: The parameter to mask
+        group_type: The elements grouping type (structure).
+          One of: 2D, 3D, Filters, Channels, Rows, Cols
+        binary_map: the binary map that matches the specified `group_type`
+
+    Returns:
+        (mask, binary_map)
+    """
+    assert group_type in ('2D', 'Rows', 'Cols', '3D', 'Filters', 'Channels')
+    assert binary_map is not None
+
+    # Now let's expand back up to a 4D mask
+    if group_type == '2D':
         a = binary_map.expand(param.size(2) * param.size(3),
                               param.size(0) * param.size(1)).t()
         return a.view(param.size(0), param.size(1), param.size(2), param.size(3)), binary_map
-
     elif group_type == 'Rows':
-        if binary_map is None:
-            binary_map = group_threshold_binary_map(param, group_type, threshold, threshold_criteria)
         return binary_map.expand(param.size(1), param.size(0)).t(), binary_map
-
     elif group_type == 'Cols':
-        if binary_map is None:
-            binary_map = group_threshold_binary_map(param, group_type, threshold, threshold_criteria)
         return binary_map.expand(param.size(0), param.size(1)), binary_map
-
     elif group_type == '3D' or group_type == 'Filters':
-        if binary_map is None:
-            binary_map = group_threshold_binary_map(param, group_type, threshold, threshold_criteria)
         a = binary_map.expand(np.prod(param.shape[1:]), param.size(0)).t()
         return a.view(*param.shape), binary_map
-
-    elif group_type == '4D':
-        assert param.dim() == 4, "This thresholding is only supported for 4D weights"
-        if threshold_criteria == 'Mean_Abs':
-            if param.data.abs().mean() > threshold:
-                return None
-            return torch.zeros_like(param.data)
-        elif threshold_criteria == 'Max':
-            if param.data.abs().max() > threshold:
-                return None
-            return torch.zeros_like(param.data)
-        raise ValueError("Invalid threshold_criteria {}".format(threshold_criteria))
-
     elif group_type == 'Channels':
-        if binary_map is None:
-            binary_map = group_threshold_binary_map(param, group_type, threshold, threshold_criteria)
-        num_filters = param.size(0)
-        num_kernels_per_filter = param.size(1)
-
-        # Now let's expand back up to a 4D mask
-        a = binary_map.expand(num_filters, num_kernels_per_filter)
+        num_filters, num_channels = param.size(0), param.size(1)
+        a = binary_map.expand(num_filters, num_channels)
         c = a.unsqueeze(-1)
-        d = c.expand(num_filters, num_kernels_per_filter, param.size(2) * param.size(3)).contiguous()
+        d = c.expand(num_filters, num_channels, param.size(2) * param.size(3)).contiguous()
         return d.view(param.size(0), param.size(1), param.size(2), param.size(3)), binary_map
 
-
-def threshold_policy(weights, thresholds, threshold_criteria, dim=1):
-    """
-    """
-    if threshold_criteria in ['Mean_Abs', 'Mean_L1']:
-        return weights.data.norm(p=1, dim=dim).div(weights.size(dim)).gt(thresholds).type(weights.type())
-    if threshold_criteria == 'Mean_L2':
-        return weights.data.norm(p=2, dim=dim).div(weights.size(dim)).gt(thresholds).type(weights.type())
-    elif threshold_criteria == 'L1':
-        return weights.data.norm(p=1, dim=dim).gt(thresholds).type(weights.type())
-    elif threshold_criteria == 'L2':
-        return weights.data.norm(p=2, dim=dim).gt(thresholds).type(weights.type())
-    elif threshold_criteria == 'Max':
-        maxv, _ = weights.data.abs().max(dim=dim)
-        return maxv.gt(thresholds).type(weights.type())
-    raise ValueError("Invalid threshold_criteria {}".format(threshold_criteria))
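
A minimal usage sketch of the refactored thresholding API above (illustrative only;
the tensor shape and threshold values are made up):

    import torch
    import distiller
    from distiller.thresholding import expand_binary_map

    # Hypothetical 4D convolution weights: 2 filters, 3 input channels, 3x3 kernels.
    param = torch.randn(2, 3, 3, 3)

    # Zero out whole 2D kernels whose L1 norm is at or below 6.0.
    mask, binary_map = distiller.group_threshold_mask(param, '2D', 6.0, 'L1')

    # The binary map can be cached and re-expanded later, skipping the norm computation.
    mask_again, _ = expand_binary_map(param, '2D', binary_map)
    assert torch.eq(mask, mask_again).all()
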
diff --git a/distiller/utils.py b/distiller/utils.py
index ac0158e..d1c1feb 100755
--- a/distiller/utils.py
+++ b/distiller/utils.py
@@ -58,7 +58,7 @@ def to_np(var):
 def size2str(torch_size):
     if isinstance(torch_size, torch.Size):
         return size_to_str(torch_size)
-    if isinstance(torch_size, torch.FloatTensor) or isinstance(torch_size, torch.cuda.FloatTensor):
+    if isinstance(torch_size, (torch.FloatTensor, torch.cuda.FloatTensor)):
         return size_to_str(torch_size.size())
     if isinstance(torch_size, torch.autograd.Variable):
         return size_to_str(torch_size.data.size())
@@ -198,10 +198,10 @@ def sparsity_3D(tensor):
     """Filter-wise sparsity for 4D tensors"""
     if tensor.dim() != 4:
         return 0
-    view_3d = tensor.view(-1, tensor.size(1) * tensor.size(2) * tensor.size(3))
-    num_filters = view_3d.size()[0]
-    nonzero_filters = len(torch.nonzero(view_3d.abs().sum(dim=1)))
-    return 1 - nonzero_filters/num_filters
+    l1_norms = distiller.norms.filters_lp_norm(tensor, p=1, length_normalized=False)
+    num_nonzero_filters = len(torch.nonzero(l1_norms))
+    num_filters = tensor.size(0)
+    return 1 - num_nonzero_filters / num_filters
 
 
 def density_3D(tensor):
@@ -255,16 +255,8 @@ def non_zero_channels(tensor):
     if tensor.dim() != 4:
         raise ValueError("Expecting a 4D tensor")
 
-    n_filters, n_channels, k_h, k_w = (tensor.size(i) for i in range(4))
-
-    # First, reshape the weights tensor such that each channel (kernel) in
-    # the original tensor, is now a row in a 2D tensor.
-    view_2d = tensor.view(-1, k_h * k_w)
-    # Next, compute the sums of each kernel
-    kernel_sums = view_2d.abs().sum(dim=1)
-    # Now group by channels
-    k_sums_mat = kernel_sums.view(n_filters, n_channels).t()
-    nonzero_channels = torch.nonzero(k_sums_mat.abs().sum(dim=1))
+    norms = distiller.norms.channels_lp_norm(tensor, p=1)
+    nonzero_channels = torch.nonzero(norms)
     return nonzero_channels
 
 
@@ -400,15 +392,7 @@ def model_params_stats(model, param_dims=[2, 4], param_types=['weight', 'bias'])
 
 
 def norm_filters(weights, p=1):
-    """Compute the p-norm of convolution filters.
-
-    Args:
-        weights - a 4D convolution weights tensor.
-                  Has shape = (#filters, #channels, k_w, k_h)
-        p - the exponent value in the norm formulation
-    """
-    assert weights.dim() == 4
-    return weights.view(weights.size(0), -1).norm(p=p, dim=1)
+    return distiller.norms.filters_lp_norm(weights, p)
 
 
 def model_numel(model, param_dims=[2, 4], param_types=['weight', 'bias']):
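
The sparsity statistics above now delegate to distiller.norms; a quick sketch of
what these calls compute (random weights for illustration; `length_normalized=False`
mirrors the call in sparsity_3D):

    import torch
    import distiller

    w = torch.randn(16, 8, 3, 3)  # 16 filters, 8 input channels, 3x3 kernels

    # Filter-wise L1 norms: equivalent to the reshape that norm_filters() used to do.
    f_norms = distiller.norms.filters_lp_norm(w, p=1, length_normalized=False)
    assert torch.allclose(f_norms, w.view(w.size(0), -1).norm(p=1, dim=1))

    # Channel-wise L1 norms, as non_zero_channels() now uses them.
    c_norms = distiller.norms.channels_lp_norm(w, p=1)
    nonzero_channels = torch.nonzero(c_norms)  # indices of channels with any nonzero weight
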
diff --git a/examples/greedy_pruning/greedy_pruning.ipynb b/examples/greedy_pruning/greedy_pruning.ipynb
index 069736c..5c39e3e 100644
--- a/examples/greedy_pruning/greedy_pruning.ipynb
+++ b/examples/greedy_pruning/greedy_pruning.ipynb
@@ -9,7 +9,7 @@
     "\n",
     "In this notebook we review a couple of experiments performed on Plain20 using Distiller's version of greedy filter pruning.  This implementation is very similar to the greed algorithms defined in [1] and [2].\n",
     "\n",
-    "This is another means to explore the network sub-space around a pre-trained model: by small geedy and iterative greedy filter-subtraction operations.  \n",
+    "This is another means to explore the network sub-space around a pre-trained model: by small greedy and iterative greedy filter-subtraction operations.  \n",
     "\n",
     "We want to answer the question: *how important is the duration of the short-term fine-tuning (see NetAdapt for definition of short-term FT)?*\n",
     "\n",
@@ -262,8 +262,17 @@
    "display_name": "Python 3",
    "language": "python",
    "name": "python3"
+  },
+  "pycharm": {
+   "stem_cell": {
+    "cell_type": "raw",
+    "source": [],
+    "metadata": {
+     "collapsed": false
+    }
+   }
   }
  },
  "nbformat": 4,
  "nbformat_minor": 2
-}
+}
\ No newline at end of file
diff --git a/examples/network_surgery/resnet20.network_surgery.yaml b/examples/network_surgery/resnet20.network_surgery.yaml
index e77b980..c48bcd3 100755
--- a/examples/network_surgery/resnet20.network_surgery.yaml
+++ b/examples/network_surgery/resnet20.network_surgery.yaml
@@ -26,7 +26,7 @@
 #     # of parameters: 270,896
 #
 # Results:
-#     Best Top1: 91.490 (on Epoch: 339)
+#     Best Top1: 91.43
 #     Total MACs: 40,813,184
 #     Total sparsity: 69.1%
 #     # of parameters: 83,671
@@ -126,7 +126,7 @@ policies:
         keep_mask: True
         mini_batch_pruning_frequency: 1
         mask_on_forward_only: True
-        # use_double_copies: True
+        #use_double_copies: True
     starting_epoch: 0
     ending_epoch: 100
     frequency: 1
diff --git a/tests/test_infra.py b/tests/test_infra.py
index 2a266de..f47b969 100755
--- a/tests/test_infra.py
+++ b/tests/test_infra.py
@@ -295,4 +295,8 @@ def test_load_checkpoint_without_model():
         assert model
         assert model.arch == "resnet20_cifar"
         assert model.dataset == "cifar10"
-        os.remove(temp_checkpoint)
\ No newline at end of file
+        os.remove(temp_checkpoint)
+
+
+if __name__ == '__main__':
+    test_load_gpu_model_on_cpu_with_thinning()
\ No newline at end of file
diff --git a/tests/test_pruning.py b/tests/test_pruning.py
index dd1b04f..23a795c 100755
--- a/tests/test_pruning.py
+++ b/tests/test_pruning.py
@@ -414,18 +414,6 @@ def test_mobilenet_conv_fc_interface(is_parallel=parallel, model=None, zeros_mas
                            zeros_mask_dict=zeros_mask_dict)
 
 
-def test_threshold_mask():
-    # Create a 4-D tensor of 1s
-    a = torch.ones(3, 64, 32, 32)
-    # Change one element
-    a[1, 4, 17, 31] = 0.2
-    # Create and apply a mask
-    mask = distiller.threshold_mask(a, threshold=0.3)
-    assert np.sum(distiller.to_np(mask)) == (distiller.volume(a) - 1)
-    assert mask[1, 4, 17, 31] == 0
-    assert common.almost_equal(distiller.sparsity(mask), 1/distiller.volume(a))
-
-
 def test_magnitude_pruning():
     # Create a 4-D tensor of 1s
     a = torch.ones(3, 64, 32, 32)
diff --git a/tests/test_thresholding.py b/tests/test_thresholding.py
index 6034c02..f231abd 100755
--- a/tests/test_thresholding.py
+++ b/tests/test_thresholding.py
@@ -13,22 +13,229 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+
 import torch
-import pytest
+import numpy as np
 import distiller
+import common
 
 
-def get_test_tensor():
+def get_test_2d_tensor():
     return torch.tensor([[1.0, 2.0, 3.0],
                          [4.0, 5.0, 6.0],
                          [7.0, 8.0, 9.0],
                          [10., 11., 12.]])
 
 
+def get_test_4d_tensor():
+    # Channel normalized L1 norms:
+    # 0.8362   torch.norm(a[:,0,:,:], p=1) / 18
+    # 0.7625   torch.norm(a[:,1,:,:], p=1) / 18
+    # 0.6832   torch.norm(a[:,2,:,:], p=1) / 18
+
+    # Channel L2 norms: tensor([4.3593, 3.6394, 3.9037])
+    #   a.transpose(0,1).contiguous().view(3,-1).norm(2, dim=1)
+
+    return torch.tensor(
+       [
+        # Filter L2 = 4.5039   torch.norm(a[0,:,:,:], p=2)
+        [[[-1.2982,  0.7574,  0.7962],  # Kernel L1 = 6.5997   torch.norm(a[0,0,:,:], p=1)
+          [-0.6695,  1.5907,  0.2659],
+          [ 0.1423,  0.3165, -0.7629]],
+
+         [[-0.5480, -1.2718,  0.8286],  # Kernel L1 = 7.7756   torch.norm(a[0,1,:,:], p=1)
+          [-0.6427,  0.3814, -0.7988],
+          [ 1.0346,  1.3023, -0.9674]],
+
+         [[-0.7951,  1.8784, -0.5654],  # Kernel L1 = 5.8073   torch.norm(a[0,2,:,:], p=1)
+          [ 0.0456, -0.2849, -0.3332],
+          [-0.2367,  0.7467,  0.9212]]],
+
+        # Filter L2 = 5.2156   torch.norm(a[1,:,:,:], p=2)
+        [[[ 1.3672,  0.2993, -0.0619],  # Kernel L1 = 8.4522   torch.norm(a[1,0,:,:], p=1)
+          [ 1.8156,  0.7599,  0.1815],
+          [ 0.4136,  1.8316,  1.7214]],
+
+         [[ 0.5125, -1.5329,  0.9257],  # Kernel L1 = 5.9498   torch.norm(a[1,1,:,:], p=1)
+          [ 0.9200,  0.4376,  0.5743],
+          [-0.0097,  0.9473, -0.0899]],
+
+         [[ 0.2372,  2.4369, -0.3410],  # Kernel L1 = 6.4908  torch.norm(a[1,2,:,:], p=1)
+          [-1.0595,  0.8056, -0.0357],
+          [-1.0105, -0.1451, -0.4194]]]])
+
+
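The norms annotated in get_test_4d_tensor() can be reproduced with a few reshapes;
a sanity sketch (values match the comments above up to rounding):

    a = get_test_4d_tensor()  # shape (2, 3, 3, 3)
    kernel_l1 = a.view(-1, 9).norm(p=1, dim=1)   # ~[6.5997, 7.7756, 5.8073, 8.4522, 5.9498, 6.4908]
    filter_l2 = a.view(2, -1).norm(p=2, dim=1)   # ~[4.5039, 5.2156]
    channel_l2 = a.transpose(0, 1).contiguous().view(3, -1).norm(p=2, dim=1)            # ~[4.3593, 3.6394, 3.9037]
    channel_mean_l1 = a.transpose(0, 1).contiguous().view(3, -1).norm(p=1, dim=1) / 18  # ~[0.8362, 0.7625, 0.6832]
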
+def test_norm_names():
+    assert str(distiller.norms.l1_norm) == "L1"
+    assert str(distiller.norms.l2_norm) == "L2"
+    assert str(distiller.norms.max_norm) == "Max"
+
+
+def test_threshold_mask():
+    # Create a 4-D tensor of 1s
+    a = torch.ones(3, 64, 32, 32)
+    # Change one element
+    a[1, 4, 17, 31] = 0.2
+    # Create and apply a mask
+    mask = distiller.threshold_mask(a, threshold=0.3)
+    assert np.sum(distiller.to_np(mask)) == (distiller.volume(a) - 1)
+    assert mask[1, 4, 17, 31] == 0
+    assert common.almost_equal(distiller.sparsity(mask), 1/distiller.volume(a))
+
+
+def test_kernel_thresholding():
+    p = get_test_4d_tensor().cuda()
+    mask, map = distiller.group_threshold_mask(p, '2D', 6, 'L1')
+
+    # Test the binary map: 1s indicate 2D-kernels that have an L1 above 6
+    assert map.shape == torch.Size([6])
+    assert torch.eq(map, torch.tensor([1., 1., 0.,
+                                       1., 0., 1.], device=map.device)).all()
+    # Test the full mask
+    expected_mask = torch.tensor(
+       [[[[1., 1., 1.],
+          [1., 1., 1.],
+          [1., 1., 1.]],
+
+         [[1., 1., 1.],
+          [1., 1., 1.],
+          [1., 1., 1.]],
+
+         [[0., 0., 0.],
+          [0., 0., 0.],
+          [0., 0., 0.]]],
+
+
+        [[[1., 1., 1.],
+          [1., 1., 1.],
+          [1., 1., 1.]],
+
+         [[0., 0., 0.],
+          [0., 0., 0.],
+          [0., 0., 0.]],
+
+         [[1., 1., 1.],
+          [1., 1., 1.],
+          [1., 1., 1.]]]], device=mask.device)
+    assert torch.eq(mask, expected_mask).all()
+    return mask
+
+
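The expected map follows directly from the kernel L1 norms annotated in
get_test_4d_tensor(); a derivation sketch:

    kernel_l1 = torch.tensor([6.5997, 7.7756, 5.8073, 8.4522, 5.9498, 6.4908])
    assert torch.eq(kernel_l1.gt(6.).float(), torch.tensor([1., 1., 0., 1., 0., 1.])).all()
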
+def test_filter_thresholding():
+    p = get_test_4d_tensor().cuda()
+    mask, map = distiller.group_threshold_mask(p, '3D', 4.7, 'L2')
+
+    # Test the binary map: 1s indicate 3D-filters that have an L2 above 4.7
+    assert map.shape == torch.Size([2])
+    assert torch.eq(map, torch.tensor([0., 1.], device=map.device)).all()
+    # Test the full mask
+    expected_mask = torch.tensor(
+       [[[[0., 0., 0.],
+          [0., 0., 0.],
+          [0., 0., 0.]],
+
+         [[0., 0., 0.],
+          [0., 0., 0.],
+          [0., 0., 0.]],
+
+         [[0., 0., 0.],
+          [0., 0., 0.],
+          [0., 0., 0.]]],
+
+
+        [[[1., 1., 1.],
+          [1., 1., 1.],
+          [1., 1., 1.]],
+
+         [[1., 1., 1.],
+          [1., 1., 1.],
+          [1., 1., 1.]],
+
+         [[1., 1., 1.],
+          [1., 1., 1.],
+          [1., 1., 1.]]]], device=mask.device)
+    assert torch.eq(mask, expected_mask).all()
+    return mask
+
+
+def test_channel_thresholding_1():
+    p = get_test_4d_tensor().cuda()
+    mask, map = distiller.group_threshold_mask(p, 'Channels', 3.7, 'L2')
+
+    # Test the binary map: 1s indicate channels whose L2 norm is above 3.7
+    assert map.shape == torch.Size([3])
+    assert torch.eq(map, torch.tensor([1., 0., 1.], device=map.device)).all()
+    # Test the full mask
+    expected_mask = torch.tensor(
+       [[[[1., 1., 1.],
+          [1., 1., 1.],
+          [1., 1., 1.]],
+
+         [[0., 0., 0.],
+          [0., 0., 0.],
+          [0., 0., 0.]],
+
+         [[1., 1., 1.],
+          [1., 1., 1.],
+          [1., 1., 1.]]],
+
+
+        [[[1., 1., 1.],
+          [1., 1., 1.],
+          [1., 1., 1.]],
+
+         [[0., 0., 0.],
+          [0., 0., 0.],
+          [0., 0., 0.]],
+
+         [[1., 1., 1.],
+          [1., 1., 1.],
+          [1., 1., 1.]]]], device=mask.device)
+    assert torch.eq(mask, expected_mask).all()
+    return mask
+
+
+def test_channel_thresholding_2():
+    p = get_test_4d_tensor().cuda()
+    mask, map = distiller.group_threshold_mask(p, 'Channels', 0.7, 'Mean_L1')
+
+    # Test the binary map: 1s indicate channels whose length-normalized (mean) L1 norm is above 0.7
+    assert map.shape == torch.Size([3])
+    assert torch.eq(map, torch.tensor([1., 1., 0.], device=map.device)).all()
+    # Test the full mask
+    expected_mask = torch.tensor(
+       [[[[1., 1., 1.],
+          [1., 1., 1.],
+          [1., 1., 1.]],
+
+         [[1., 1., 1.],
+          [1., 1., 1.],
+          [1., 1., 1.]],
+
+         [[0., 0., 0.],
+          [0., 0., 0.],
+          [0., 0., 0.]]],
+
+        [[[1., 1., 1.],
+          [1., 1., 1.],
+          [1., 1., 1.]],
+
+         [[1., 1., 1.],
+          [1., 1., 1.],
+          [1., 1., 1.]],
+
+         [[0., 0., 0.],
+          [0., 0., 0.],
+          [0., 0., 0.]]]], device=mask.device)
+    assert torch.eq(mask, expected_mask).all()
+    return mask
+
+
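Here the map follows from the length-normalized channel L1 norms annotated in
get_test_4d_tensor(): 0.8362 and 0.7625 exceed the 0.7 threshold, 0.6832 does not:

    channel_mean_l1 = torch.tensor([0.8362, 0.7625, 0.6832])
    assert torch.eq(channel_mean_l1.gt(0.7).float(), torch.tensor([1., 1., 0.])).all()
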
 def test_row_thresholding():
-    p = get_test_tensor().cuda()
-    group_th = distiller.GroupThresholdMixin()
-    mask = group_th.group_threshold_mask(p, 'Rows', 7, 'Max')
+    p = get_test_2d_tensor().cuda()
+    mask, map = distiller.group_threshold_mask(p, 'Rows', 7, 'Max')
+
+    assert torch.eq(map, torch.tensor([0., 0., 1., 1.], device=map.device)).all()
     assert torch.eq(mask, torch.tensor([[ 0.,  0.,  0.],
                                         [ 0.,  0.,  0.],
                                         [ 1.,  1.,  1.],
@@ -37,15 +244,15 @@ def test_row_thresholding():
 
 
 def test_col_thresholding():
-    p = get_test_tensor().cuda()
-    group_th = distiller.GroupThresholdMixin()
-    mask = group_th.group_threshold_mask(p, 'Cols', 11, 'Max')
+    p = get_test_2d_tensor().cuda()
+    mask, map = distiller.group_threshold_mask(p, 'Cols', 11, 'Max')
     assert torch.eq(mask, torch.tensor([[ 0.,  0.,  1.],
                                         [ 0.,  0.,  1.],
                                         [ 0.,  0.,  1.],
                                         [ 0.,  0.,  1.]], device=mask.device)).all()
     return mask
 
+
 if __name__ == '__main__':
-    m = test_col_thresholding()
-    print(m)
+    m = test_channel_thresholding_2()
+    #print(m)
-- 
GitLab