From 68514d178b583d3f6d59f13bb9f34120ef51eb4c Mon Sep 17 00:00:00 2001 From: Neta Zmora <31280975+nzmora@users.noreply.github.com> Date: Mon, 20 Apr 2020 15:02:22 +0300 Subject: [PATCH] small tensor masking API refactoring (#499) Added masking primitives: -mask_tensor -create_mask_threshold_criterion -create_mask_level_criterion -create_mask_sensitivity_criterion These APIs have a clearer name and communicate their responsibility better: create a tensor mask, based on some criterion. Previously, distiller.pruning.create_mask_threshold_criterion was named distiller.threshold_mask which did not communicate well what this function did. Masking functionality is no longer hidden inside the Pruner instances, so they can be used directly by an application, or to compose new Pruner classes. Removed file distiller.pruning.pruner: -The base-class _ParameterPruner is useless and adds needless details to the implementation. AGP: Separated the pruning-rate schedule from the rest of the logic. This allows us to mix-and-match different pruning-rate schedules (just like LR schedulers). --- distiller/__init__.py | 32 +++---- distiller/norms.py | 1 + distiller/pruning/__init__.py | 86 +++++++++++++++++++ distiller/pruning/automated_gradual_pruner.py | 64 +++++++------- distiller/pruning/baidu_rnn_pruner.py | 11 +-- distiller/pruning/level_pruner.py | 18 +--- distiller/pruning/magnitude_pruner.py | 12 +-- distiller/pruning/pruner.py | 50 ----------- distiller/pruning/ranked_structures_pruner.py | 42 +++++---- distiller/pruning/sensitivity_pruner.py | 18 +--- distiller/pruning/splicing_pruner.py | 6 +- distiller/pruning/structure_pruner.py | 4 +- distiller/regularization/l1_regularizer.py | 6 +- distiller/scheduler.py | 2 +- distiller/thresholding.py | 15 +--- tests/test_thresholding.py | 24 +++++- 16 files changed, 204 insertions(+), 187 deletions(-) delete mode 100755 distiller/pruning/pruner.py diff --git a/distiller/__init__.py b/distiller/__init__.py index e687010..13ef908 100755 --- a/distiller/__init__.py +++ b/distiller/__init__.py @@ -16,10 +16,11 @@ import torch from .utils import * -from .thresholding import GroupThresholdMixin, threshold_mask, group_threshold_mask +from .thresholding import GroupThresholdMixin, group_threshold_mask from .config import file_config, dict_config, config_component_from_file_by_class from .model_summaries import * from .scheduler import * +from .pruning import * from .sensitivity import * from .directives import * from .policy import * @@ -41,6 +42,21 @@ except pkg_resources.DistributionNotFound: __version__ = "Unknown" +def __check_pytorch_version(): + from pkg_resources import parse_version + required = '1.3.1' + actual = torch.__version__ + if parse_version(actual) < parse_version(required): + msg = "\n\nWRONG PYTORCH VERSION\n"\ + "Required: {}\n" \ + "Installed: {}\n"\ + "Please run 'pip install -e .' from the Distiller repo root dir\n".format(required, actual) + raise RuntimeError(msg) + + +__check_pytorch_version() + + def model_find_param_name(model, param_to_find): """Look up the name of a model parameter. @@ -104,17 +120,3 @@ def model_find_module(model, module_to_find): return m return None - -def check_pytorch_version(): - from pkg_resources import parse_version - required = '1.3.1' - actual = torch.__version__ - if parse_version(actual) < parse_version(required): - msg = "\n\nWRONG PYTORCH VERSION\n"\ - "Required: {}\n" \ - "Installed: {}\n"\ - "Please run 'pip install -e .' 
from the Distiller repo root dir\n".format(required, actual) - raise RuntimeError(msg) - - -check_pytorch_version() diff --git a/distiller/norms.py b/distiller/norms.py index a73a07f..c43019b 100644 --- a/distiller/norms.py +++ b/distiller/norms.py @@ -33,6 +33,7 @@ see: https://www.kaggle.com/residentmario/l1-norms-versus-l2-norms) import torch import numpy as np from functools import partial +from random import uniform __all__ = ["kernels_lp_norm", "channels_lp_norm", "filters_lp_norm", diff --git a/distiller/pruning/__init__.py b/distiller/pruning/__init__.py index 710c7b5..25e115f 100755 --- a/distiller/pruning/__init__.py +++ b/distiller/pruning/__init__.py @@ -42,6 +42,7 @@ from .ranked_structures_pruner import L1RankedStructureParameterPruner, \ FMReconstructionChannelPruner from .baidu_rnn_pruner import BaiduRNNPruner from .greedy_filter_pruning import greedy_pruner +import torch del magnitude_pruner del automated_gradual_pruner @@ -49,3 +50,88 @@ del level_pruner del sensitivity_pruner del structure_pruner del ranked_structures_pruner + + +def mask_tensor(tensor, mask, inplace=True): + """Mask the provided tensor + + Args: + tensor - the torch-tensor to mask + mask - binary coefficient-masking tensor. Has the same type and shape as `tensor` + Returns: + tensor = tensor * mask ;where * is the element-wise multiplication operator + """ + assert tensor.type() == mask.type() + assert tensor.shape == mask.shape + if mask is not None: + return tensor.data.mul_(mask) if inplace else tensor.data.mul(mask) + return tensor + + +def create_mask_threshold_criterion(tensor, threshold): + """Create a tensor mask using a threshold criterion. + + All values smaller or equal to the threshold will be masked-away. + Granularity: Element-wise + Args: + tensor - the tensor to threshold. + threshold - a floating-point threshold value. + Returns: + boolean mask tensor, having the same size as the input tensor. + """ + with torch.no_grad(): + mask = torch.gt(torch.abs(tensor), threshold).type(tensor.type()) + return mask + + +def create_mask_level_criterion(tensor, desired_sparsity): + """Create a tensor mask using a level criterion. + + A specified fraction of the input tensor will be masked. The tensor coefficients + are first sorted by their L1-norm (absolute value), and then the lower `desired_sparsity` + coefficients are masked. + Granularity: Element-wise + + WARNING: due to the implementation details (speed over correctness), this will perform + incorrectly if "too many" of the coefficients have the same value. E.g. this will fail: + a = torch.ones(3, 64, 32, 32) + mask = distiller.create_mask_level_criterion(a, desired_sparsity=0.3) + assert math.isclose(distiller.sparsity(mask), 0.3) + + Args: + tensor - the tensor to mask. + desired_sparsity - a floating-point value in the range (0..1) specifying what + percentage of the tensor will be masked. + Returns: + boolean mask tensor, having the same size as the input tensor. + """ + with torch.no_grad(): + # partial sort + bottomk, _ = torch.topk(tensor.abs().view(-1), + int(desired_sparsity * tensor.numel()), + largest=False, + sorted=True) + threshold = bottomk.data[-1] # This is the largest element from the group of elements that we prune away + mask = create_mask_threshold_criterion(tensor, threshold) + return mask + + +def create_mask_sensitivity_criterion(tensor, sensitivity): + """Create a tensor mask using a sensitivity criterion. + + Mask an input tensor based on the variance of the distribution of the tensor coefficients. 
+ Coefficients in the distribution's specified band around the mean will be masked (symmetrically). + Granularity: Element-wise + Args: + tensor - the tensor to mask. + sensitivity - a floating-point value specifying the sensitivity. This is a simple + multiplier of the standard-deviation. + Returns: + boolean mask tensor, having the same size as the input tensor. + """ + if not hasattr(tensor, 'stddev'): + tensor.stddev = torch.std(tensor).item() + with torch.no_grad(): + threshold = tensor.stddev * sensitivity + mask = create_mask_threshold_criterion(tensor, threshold) + return mask diff --git a/distiller/pruning/automated_gradual_pruner.py b/distiller/pruning/automated_gradual_pruner.py index c2e6d4c..e499ae2 100755 --- a/distiller/pruning/automated_gradual_pruner.py +++ b/distiller/pruning/automated_gradual_pruner.py @@ -14,37 +14,19 @@ # limitations under the License. # -from .pruner import _ParameterPruner -from .level_pruner import SparsityLevelParameterPruner from .ranked_structures_pruner import * -from distiller.utils import * -from functools import partial +import distiller -class AutomatedGradualPrunerBase(_ParameterPruner): - """Prune to an exact sparsity level specification using a prescribed sparsity - level schedule formula. - - An automated gradual pruning algorithm that prunes the smallest magnitude - weights to achieve a preset level of network sparsity. - - Michael Zhu and Suyog Gupta, "To prune, or not to prune: exploring the - efficacy of pruning for model compression", 2017 NIPS Workshop on Machine - Learning of Phones and other Consumer Devices, - (https://arxiv.org/pdf/1710.01878.pdf) +class AgpPruningRate(object): + """A pruning-rate scheduler per https://arxiv.org/pdf/1710.01878.pdf. """ - - def __init__(self, name, initial_sparsity, final_sparsity): - super().__init__(name) + def __init__(self, initial_sparsity, final_sparsity): self.initial_sparsity = initial_sparsity self.final_sparsity = final_sparsity assert final_sparsity > initial_sparsity - def compute_target_sparsity(self, meta): - starting_epoch = meta['starting_epoch'] - current_epoch = meta['current_epoch'] - ending_epoch = meta['ending_epoch'] - freq = meta['frequency'] + def step(self, current_epoch, starting_epoch, ending_epoch, freq): span = ((ending_epoch - starting_epoch - 1) // freq) * freq assert span > 0 @@ -54,8 +36,32 @@ class AutomatedGradualPrunerBase(_ParameterPruner): return target_sparsity + +class AutomatedGradualPrunerBase(object): + """Prune to an exact sparsity level specification using a prescribed sparsity + level schedule formula. + + An automated gradual pruning algorithm that prunes the smallest magnitude + weights to achieve a preset level of network sparsity. + + Michael Zhu and Suyog Gupta, "To prune, or not to prune: exploring the + efficacy of pruning for model compression", 2017 NIPS Workshop on Machine + Learning of Phones and other Consumer Devices, + (https://arxiv.org/pdf/1710.01878.pdf) + """ + + def __init__(self, name, rate_scheduler): + """ + Args: + rate_scheduler - schedules the pruning rate. You can plug in any + rate scheduler. We implemented AGP paper's rate control in AgpPruningRate. 
+ """ + self.name = name + self.agp_pr = rate_scheduler + def set_param_mask(self, param, param_name, zeros_mask_dict, meta): - target_sparsity = self.compute_target_sparsity(meta) + target_sparsity = self.agp_pr.step(meta['current_epoch'], meta['starting_epoch'], + meta['ending_epoch'], meta['frequency']) self._set_param_mask_by_sparsity_target(param, param_name, zeros_mask_dict, target_sparsity, meta['model']) def _set_param_mask_by_sparsity_target(self, param, param_name, zeros_mask_dict, target_sparsity, model=None): @@ -69,8 +75,8 @@ class AutomatedGradualPruner(AutomatedGradualPrunerBase): An automated gradual pruning algorithm that prunes the smallest magnitude weights to achieve a preset level of network sparsity. """ - def __init__(self, name, initial_sparsity, final_sparsity, weights): - super().__init__(name, initial_sparsity, final_sparsity) + def __init__(self, name, initial_sparsity, final_sparsity, weights, rate_scheduler_factory=AgpPruningRate): + super().__init__(name, rate_scheduler=rate_scheduler_factory(initial_sparsity, final_sparsity)) self.params_names = weights assert self.params_names @@ -80,7 +86,7 @@ class AutomatedGradualPruner(AutomatedGradualPrunerBase): super().set_param_mask(param, param_name, zeros_mask_dict, meta) def _set_param_mask_by_sparsity_target(self, param, param_name, zeros_mask_dict, target_sparsity, model=None): - zeros_mask_dict[param_name].mask = SparsityLevelParameterPruner.create_mask(param, target_sparsity) + zeros_mask_dict[param_name].mask = distiller.create_mask_level_criterion(param, target_sparsity) class StructuredAGP(AutomatedGradualPrunerBase): @@ -89,8 +95,8 @@ class StructuredAGP(AutomatedGradualPrunerBase): This is a base-class for structured pruning with an AGP schedule. It is an extension of the AGP concept introduced by Zhu et. al. """ - def __init__(self, name, initial_sparsity, final_sparsity): - super().__init__(name, initial_sparsity, final_sparsity) + def __init__(self, name, initial_sparsity, final_sparsity, rate_scheduler_factory=AgpPruningRate): + super().__init__(name, rate_scheduler=rate_scheduler_factory(initial_sparsity, final_sparsity)) self.pruner = None def _set_param_mask_by_sparsity_target(self, param, param_name, zeros_mask_dict, target_sparsity, model): diff --git a/distiller/pruning/baidu_rnn_pruner.py b/distiller/pruning/baidu_rnn_pruner.py index f195b85..8ec7fac 100755 --- a/distiller/pruning/baidu_rnn_pruner.py +++ b/distiller/pruning/baidu_rnn_pruner.py @@ -14,13 +14,10 @@ # limitations under the License. # -from .pruner import _ParameterPruner -from .level_pruner import SparsityLevelParameterPruner -from distiller.utils import * - import distiller -class BaiduRNNPruner(_ParameterPruner): + +class BaiduRNNPruner(object): """An element-wise pruner for RNN networks. Narang, Sharan & Diamos, Gregory & Sengupta, Shubho & Elsen, Erich. (2017). @@ -53,7 +50,7 @@ class BaiduRNNPruner(_ParameterPruner): def __init__(self, name, q, ramp_epoch_offset, ramp_slope_mult, weights): # Initialize the pruner, using a configuration that originates from the # schedule YAML file. 
- super(BaiduRNNPruner, self).__init__(name) + self.name = name self.params_names = weights assert self.params_names @@ -93,4 +90,4 @@ class BaiduRNNPruner(_ParameterPruner): self.ramp_slope * (current_epoch - ramp_epoch + 1)) / freq # After computing the threshold, we can create the mask - zeros_mask_dict[param_name].mask = distiller.threshold_mask(param.data, eps) + zeros_mask_dict[param_name].mask = distiller.create_mask_threshold_criterion(param, eps) diff --git a/distiller/pruning/level_pruner.py b/distiller/pruning/level_pruner.py index dc6449d..53a07cb 100755 --- a/distiller/pruning/level_pruner.py +++ b/distiller/pruning/level_pruner.py @@ -14,11 +14,10 @@ # limitations under the License. # -import torch -from .pruner import _ParameterPruner import distiller -class SparsityLevelParameterPruner(_ParameterPruner): + +class SparsityLevelParameterPruner(object): """Prune to an exact pruning level specification. This pruner is very similar to MagnitudeParameterPruner, but instead of @@ -30,7 +29,7 @@ class SparsityLevelParameterPruner(_ParameterPruner): """ def __init__(self, name, levels, **kwargs): - super(SparsityLevelParameterPruner, self).__init__(name) + self.name = name self.levels = levels assert self.levels @@ -40,13 +39,4 @@ class SparsityLevelParameterPruner(_ParameterPruner): desired_sparsity = self.levels.get(param_name, self.levels.get('*', 0)) if desired_sparsity == 0: return - - zeros_mask_dict[param_name].mask = SparsityLevelParameterPruner.create_mask(param, desired_sparsity) - - @staticmethod - def create_mask(param, desired_sparsity): - with torch.no_grad(): - bottomk, _ = torch.topk(param.abs().view(-1), int(desired_sparsity * param.numel()), largest=False, sorted=True) - threshold = bottomk.data[-1] # This is the largest element from the group of elements that we prune away - mask = distiller.threshold_mask(param.data, threshold) - return mask \ No newline at end of file + zeros_mask_dict[param_name].mask = distiller.create_mask_level_criterion(param, desired_sparsity) diff --git a/distiller/pruning/magnitude_pruner.py b/distiller/pruning/magnitude_pruner.py index c46732e..c8d5cd6 100755 --- a/distiller/pruning/magnitude_pruner.py +++ b/distiller/pruning/magnitude_pruner.py @@ -14,12 +14,11 @@ # limitations under the License. # -from .pruner import _ParameterPruner import distiller import torch -class MagnitudeParameterPruner(_ParameterPruner): +class MagnitudeParameterPruner(object): """This is the most basic magnitude-based pruner. This pruner supports configuring a scalar threshold for each layer. @@ -43,7 +42,7 @@ class MagnitudeParameterPruner(_ParameterPruner): value is used. Currently it is mandatory to include a '*' key in 'thresholds'. 
""" - super(MagnitudeParameterPruner, self).__init__(name) + self.name = name assert thresholds is not None # Make sure there is a default threshold to use assert '*' in thresholds @@ -51,13 +50,8 @@ class MagnitudeParameterPruner(_ParameterPruner): def set_param_mask(self, param, param_name, zeros_mask_dict, meta): threshold = self.thresholds.get(param_name, self.thresholds['*']) - zeros_mask_dict[param_name].mask = self.create_mask(param.data, threshold) + zeros_mask_dict[param_name].mask = distiller.create_mask_threshold_criterion(param, threshold) - @staticmethod - def create_mask(param, threshold): - with torch.no_grad(): - mask = distiller.threshold_mask(param.data, threshold) - return mask diff --git a/distiller/pruning/pruner.py b/distiller/pruning/pruner.py deleted file mode 100755 index 7ba5f2f..0000000 --- a/distiller/pruning/pruner.py +++ /dev/null @@ -1,50 +0,0 @@ -# -# Copyright (c) 2018 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import torch -import distiller - - -__all__ = ["mask_tensor"] - - -class _ParameterPruner(object): - """Base class for all pruners. - - Arguments: - name: pruner name is used mainly for debugging. - """ - def __init__(self, name): - self.name = name - - def set_param_mask(self, param, param_name, zeros_mask_dict, meta): - raise NotImplementedError - - -def mask_tensor(tensor, mask): - """Mask the provided tensor - - Args: - tensor - the torch-tensor to mask - mask - binary coefficient-masking tensor. Has the same type and shape as `tensor` - Returns: - tensor = tensor * mask ;where * is the element-wise multiplication operator - """ - assert tensor.type() == mask.type() - assert tensor.shape == mask.shape - if mask: - tensor.data.mul_(mask) - return tensor diff --git a/distiller/pruning/ranked_structures_pruner.py b/distiller/pruning/ranked_structures_pruner.py index 629df73..ea11aa1 100755 --- a/distiller/pruning/ranked_structures_pruner.py +++ b/distiller/pruning/ranked_structures_pruner.py @@ -22,7 +22,6 @@ import torch from torch.nn import functional as f from random import uniform import distiller -from .pruner import _ParameterPruner __all__ = ["LpRankedStructureParameterPruner", @@ -39,12 +38,12 @@ __all__ = ["LpRankedStructureParameterPruner", msglogger = logging.getLogger(__name__) -class _RankedStructureParameterPruner(_ParameterPruner): +class _RankedStructureParameterPruner(object): """Base class for pruning structures by ranking them. 
""" def __init__(self, name, group_type, desired_sparsity, weights, group_dependency=None, group_size=1, rounding_fn=math.floor, noise=0.): - super().__init__(name) + self.name = name self.group_type = group_type self.group_dependency = group_dependency self.params_names = weights @@ -343,14 +342,15 @@ class ActivationRankedFilterPruner(_RankedStructureParameterPruner): # Use the parameter name to locate the module that has the activation sparsity statistics fq_name = param_name.replace(".conv", ".relu")[:-len(".weight")] + #distiller.assign_layer_fq_names(model) module = distiller.find_module_by_fq_name(model, fq_name) if module is None: raise ValueError("Could not find a layer named %s in the model." "\nMake sure to use assign_layer_fq_names()" % fq_name) if not hasattr(module, self.activation_rank_criterion): - raise ValueError("Could not find attribute \"{}\" in module {}" - "\nMake sure to use SummaryActivationStatsCollector(\"{}\")". - format(self.activation_rank_criterion, fq_name, self.activation_rank_criterion)) + raise ValueError("Could not find attribute \"%s\" in module %s" + "\nMake sure to use SummaryActivationStatsCollector(\"%s\")" % + (self.activation_rank_criterion, fq_name, self.activation_rank_criterion)) quality_criterion, std = getattr(module, self.activation_rank_criterion).value() num_filters = param.size(0) @@ -469,10 +469,7 @@ class BernoulliFilterPruner(_RankedStructureParameterPruner): class GradientRankedFilterPruner(_RankedStructureParameterPruner): - """Taylor expansion ranking. - - Pavlo Molchanov, Stephen Tyree, Tero Karras, Timo Aila, and Jan Kautz. Pruning Convolutional Neural - Networks for Resource Efficient Inference. ArXiv, abs/1611.06440, 2016. + """Rank the importance of weight filters using the product of their gradients and the filter value. 
""" def __init__(self, name, group_type, desired_sparsity, weights, group_dependency=None): super().__init__(name, group_type, desired_sparsity, weights, group_dependency) @@ -495,22 +492,23 @@ class GradientRankedFilterPruner(_RankedStructureParameterPruner): msglogger.info("Too few filters - can't prune %.1f%% filters", 100*fraction_to_prune) return - # Compute the multiplication of the filters times the filter_gradienrs + # Compute the multiplication of the filters times the filter_gradients view_filters = param.view(param.size(0), -1) view_filter_grads = param.grad.view(param.size(0), -1) - weighted_gradients = view_filter_grads * view_filters - weighted_gradients = weighted_gradients.sum(dim=1) + with torch.no_grad(): + weighted_gradients = view_filter_grads * view_filters + weighted_gradients = weighted_gradients.sum(dim=1) - # Sort from high to low, and remove the bottom 'num_filters_to_prune' filters - filters_ordered_by_gradient = np.argsort(-weighted_gradients.detach().cpu().numpy())[:-num_filters_to_prune] - mask, binary_map = _mask_from_filter_order(filters_ordered_by_gradient, param, num_filters, binary_map) - zeros_mask_dict[param_name].mask = mask + # Sort from high to low, and remove the bottom 'num_filters_to_prune' filters + filters_ordered_by_gradient = np.argsort(-weighted_gradients.detach().cpu().numpy())[:-num_filters_to_prune] + mask, binary_map = _mask_from_filter_order(filters_ordered_by_gradient, param, num_filters, binary_map) + zeros_mask_dict[param_name].mask = mask - msglogger.info("GradientRankedFilterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)", - param_name, - distiller.sparsity_3D(zeros_mask_dict[param_name].mask), - fraction_to_prune, num_filters_to_prune, num_filters) - return binary_map + msglogger.info("GradientRankedFilterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)", + param_name, + distiller.sparsity_3D(zeros_mask_dict[param_name].mask), + fraction_to_prune, num_filters_to_prune, num_filters) + return binary_map from sklearn.linear_model import LinearRegression diff --git a/distiller/pruning/sensitivity_pruner.py b/distiller/pruning/sensitivity_pruner.py index 5fc4f3f..5832399 100755 --- a/distiller/pruning/sensitivity_pruner.py +++ b/distiller/pruning/sensitivity_pruner.py @@ -14,12 +14,11 @@ # limitations under the License. 
# -from .pruner import _ParameterPruner + import distiller -import torch -class SensitivityPruner(_ParameterPruner): +class SensitivityPruner(object): """Use algorithm from "Learning both Weights and Connections for Efficient Neural Networks" - https://arxiv.org/pdf/1506.02626v3.pdf @@ -41,7 +40,7 @@ class SensitivityPruner(_ParameterPruner): """ def __init__(self, name, sensitivities, **kwargs): - super(SensitivityPruner, self).__init__(name) + self.name = name self.sensitivities = sensitivities def set_param_mask(self, param, param_name, zeros_mask_dict, meta): @@ -53,13 +52,4 @@ class SensitivityPruner(_ParameterPruner): else: sensitivity = self.sensitivities[param_name] - zeros_mask_dict[param_name].mask = self.create_mask(param, sensitivity) - - @staticmethod - def create_mask(param, sensitivity): - if not hasattr(param, 'stddev'): - param.stddev = torch.std(param).item() - with torch.no_grad(): - threshold = param.stddev * sensitivity - mask = distiller.threshold_mask(param.data, threshold) - return mask + zeros_mask_dict[param_name].mask = distiller.create_mask_sensitivity_criterion(param, sensitivity) diff --git a/distiller/pruning/splicing_pruner.py b/distiller/pruning/splicing_pruner.py index 2a2edd3..4dbb0b2 100755 --- a/distiller/pruning/splicing_pruner.py +++ b/distiller/pruning/splicing_pruner.py @@ -14,14 +14,12 @@ # limitations under the License. # - -from .pruner import _ParameterPruner import torch import logging msglogger = logging.getLogger() -class SplicingPruner(_ParameterPruner): +class SplicingPruner(object): """A pruner that both prunes and splices connections. The idea of pruning and splicing working in tandem was first proposed in the following @@ -37,7 +35,7 @@ class SplicingPruner(_ParameterPruner): def __init__(self, name, sensitivities, low_thresh_mult, hi_thresh_mult, sensitivity_multiplier=0): """Arguments: """ - super(SplicingPruner, self).__init__(name) + self.name = name self.sensitivities = sensitivities self.low_thresh_mult = low_thresh_mult self.hi_thresh_mult = hi_thresh_mult diff --git a/distiller/pruning/structure_pruner.py b/distiller/pruning/structure_pruner.py index efd98a3..0cbb42a 100755 --- a/distiller/pruning/structure_pruner.py +++ b/distiller/pruning/structure_pruner.py @@ -15,12 +15,11 @@ # import logging -from .pruner import _ParameterPruner import distiller msglogger = logging.getLogger() -class StructureParameterPruner(distiller.GroupThresholdMixin, _ParameterPruner): +class StructureParameterPruner(distiller.GroupThresholdMixin): """Prune parameter structures. Pruning criterion: average L1-norm. If the average L1-norm (absolute value) of the eleements @@ -30,7 +29,6 @@ class StructureParameterPruner(distiller.GroupThresholdMixin, _ParameterPruner): the structure size. 
""" def __init__(self, name, model, reg_regims, threshold_criteria): - super(StructureParameterPruner, self).__init__(name) self.name = name self.model = model self.reg_regims = reg_regims diff --git a/distiller/regularization/l1_regularizer.py b/distiller/regularization/l1_regularizer.py index fdd6007..a79c32d 100755 --- a/distiller/regularization/l1_regularizer.py +++ b/distiller/regularization/l1_regularizer.py @@ -16,9 +16,7 @@ """L1-norm regularization""" -import torch -import math -import numpy as np + import distiller from .regularizer import _Regularizer, EPSILON @@ -39,7 +37,7 @@ class L1Regularizer(_Regularizer): return strength = self.reg_regims[param_name] - zeros_mask_dict[param_name].mask = distiller.threshold_mask(param.data, threshold=strength) + zeros_mask_dict[param_name].mask = distiller.pruning.create_mask_threshold_criterion(param, threshold=strength) zeros_mask_dict[param_name].is_regularization_mask = True @staticmethod diff --git a/distiller/scheduler.py b/distiller/scheduler.py index 45a25e3..82266f0 100755 --- a/distiller/scheduler.py +++ b/distiller/scheduler.py @@ -263,7 +263,7 @@ class ParameterMasker(object): def revert_weights(self, parameter): if not self.use_double_copies or self.unmasked_copy is None: - msglogger.debug('Parameter {0} does not maintain double copies'.format(self.param_name)) + # This parameter does not maintain double copies (this is OK) return parameter.data.copy_(self.unmasked_copy) self.unmasked_copy = None diff --git a/distiller/thresholding.py b/distiller/thresholding.py index cb38909..590adc3 100755 --- a/distiller/thresholding.py +++ b/distiller/thresholding.py @@ -23,23 +23,10 @@ import numpy as np from distiller.norms import * -__all__ = ["threshold_mask", "GroupThresholdMixin", +__all__ = ["GroupThresholdMixin", "group_threshold_binary_map", "group_threshold_mask"] -def threshold_mask(param, threshold): - """Create a threshold mask for the provided parameter tensor using - magnitude thresholding. - - Arguments: - param: a parameter tensor which should be pruned. - threshold: the pruning threshold. - Returns: - prune_mask: The pruning mask. - """ - return torch.gt(torch.abs(param), threshold).type(param.type()) - - class GroupThresholdMixin(object): """A mixin class to add group thresholding capabilities diff --git a/tests/test_thresholding.py b/tests/test_thresholding.py index f231abd..e014efa 100755 --- a/tests/test_thresholding.py +++ b/tests/test_thresholding.py @@ -77,12 +77,32 @@ def test_threshold_mask(): # Change one element a[1, 4, 17, 31] = 0.2 # Create and apply a mask - mask = distiller.threshold_mask(a, threshold=0.3) + mask = distiller.create_mask_threshold_criterion(a, threshold=0.3) assert np.sum(distiller.to_np(mask)) == (distiller.volume(a) - 1) assert mask[1, 4, 17, 31] == 0 assert common.almost_equal(distiller.sparsity(mask), 1/distiller.volume(a)) +def test_level_mask(): + # Create a 4-D tensor of 1s + a = torch.rand(3, 64, 32, 32) + + # Create and apply a mask + mask = distiller.create_mask_level_criterion(a, desired_sparsity=0.3) + assert common.almost_equal(distiller.sparsity(mask), 0.3, max_diff=0.0001) + + +def test_sensitivity_mask(): + # Create a 4-D tensor of normally-distributed coefficients + a = torch.randn(3, 64, 32, 32) + + # Create and apply a mask + mask = distiller.create_mask_sensitivity_criterion(a, sensitivity=1) + # The width of 1-std on ~N(0,1) is about 68.27%. 
In other words: + # Pr(mean - std <= X <= mean + std) is about 68.27% + assert common.almost_equal(distiller.sparsity(mask), 0.6827, max_diff=0.005) + + def test_kernel_thresholding(): p = get_test_4d_tensor().cuda() mask, map = distiller.group_threshold_mask(p, '2D', 6, 'L1') @@ -240,6 +260,8 @@ def test_row_thresholding(): [ 0., 0., 0.], [ 1., 1., 1.], [ 1., 1., 1.]], device=mask.device)).all() + masked_tensor = distiller.mask_tensor(p, mask) + assert distiller.sparsity(masked_tensor) == distiller.sparsity(mask) return mask -- GitLab
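
As a rough usage sketch of the masking primitives this patch adds (the tensor shape and numeric values below are illustrative, not taken from the patch):

    import torch
    import distiller

    param = torch.randn(64, 3, 3, 3)   # e.g. a convolution weight tensor

    # Threshold criterion: zero every coefficient whose magnitude is <= 0.1
    mask_thr = distiller.create_mask_threshold_criterion(param, threshold=0.1)

    # Level criterion: zero the 50% smallest-magnitude coefficients
    mask_lvl = distiller.create_mask_level_criterion(param, desired_sparsity=0.5)

    # Sensitivity criterion: zero coefficients whose magnitude is at most
    # sensitivity * std of the tensor's coefficient distribution
    mask_sens = distiller.create_mask_sensitivity_criterion(param, sensitivity=1.0)

    # Apply a mask; with inplace=False the original tensor is left untouched
    pruned = distiller.mask_tensor(param, mask_lvl, inplace=False)
    assert distiller.sparsity(pruned) == distiller.sparsity(mask_lvl)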
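
Because the masking logic is no longer hidden inside the Pruner classes, an application can also compose its own pruner from these primitives. A minimal sketch, mirroring how SparsityLevelParameterPruner now delegates to create_mask_level_criterion (the ConstantLevelPruner class itself is hypothetical and not part of the patch):

    import distiller

    class ConstantLevelPruner(object):
        """Hypothetical pruner: prune every listed weight tensor to one fixed sparsity."""
        def __init__(self, name, desired_sparsity, weights):
            self.name = name
            self.desired_sparsity = desired_sparsity
            self.params_names = weights

        def set_param_mask(self, param, param_name, zeros_mask_dict, meta):
            if param_name not in self.params_names:
                return
            zeros_mask_dict[param_name].mask = distiller.create_mask_level_criterion(
                param, self.desired_sparsity)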
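
The new AgpPruningRate class separates the pruning-rate schedule from the masking logic. The body of the cubic ramp is not visible in this hunk (the diff shows only the edges of step()), but per the referenced Zhu & Gupta paper the schedule has roughly the following shape; a standalone sketch, assuming the paper's formula:

    def agp_target_sparsity(current_epoch, starting_epoch, ending_epoch, freq,
                            initial_sparsity, final_sparsity):
        # Cubic sparsity ramp from the AGP paper (https://arxiv.org/pdf/1710.01878.pdf):
        # sparsity rises quickly at first and flattens as it approaches final_sparsity.
        span = ((ending_epoch - starting_epoch - 1) // freq) * freq
        assert span > 0
        return (final_sparsity +
                (initial_sparsity - final_sparsity) *
                (1.0 - (current_epoch - starting_epoch) / span) ** 3)

    # e.g. ramp from 5% to 80% sparsity between epochs 0 and 30, pruning every 2 epochs
    for epoch in range(0, 30, 2):
        print(epoch, round(agp_target_sparsity(epoch, 0, 30, 2, 0.05, 0.80), 3))

Plugging a different rate scheduler into AutomatedGradualPruner (via rate_scheduler_factory) only requires an object exposing the same step(current_epoch, starting_epoch, ending_epoch, freq) interface.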