diff --git a/distiller/__init__.py b/distiller/__init__.py index e68701064123cc83fe1c6c59e2d18ad5af92b7fb..13ef908e3cc8cfa7086004c3d97d47be4f4c5c6e 100755 --- a/distiller/__init__.py +++ b/distiller/__init__.py @@ -16,10 +16,11 @@ import torch from .utils import * -from .thresholding import GroupThresholdMixin, threshold_mask, group_threshold_mask +from .thresholding import GroupThresholdMixin, group_threshold_mask from .config import file_config, dict_config, config_component_from_file_by_class from .model_summaries import * from .scheduler import * +from .pruning import * from .sensitivity import * from .directives import * from .policy import * @@ -41,6 +42,21 @@ except pkg_resources.DistributionNotFound: __version__ = "Unknown" +def __check_pytorch_version(): + from pkg_resources import parse_version + required = '1.3.1' + actual = torch.__version__ + if parse_version(actual) < parse_version(required): + msg = "\n\nWRONG PYTORCH VERSION\n"\ + "Required: {}\n" \ + "Installed: {}\n"\ + "Please run 'pip install -e .' from the Distiller repo root dir\n".format(required, actual) + raise RuntimeError(msg) + + +__check_pytorch_version() + + def model_find_param_name(model, param_to_find): """Look up the name of a model parameter. @@ -104,17 +120,3 @@ def model_find_module(model, module_to_find): return m return None - -def check_pytorch_version(): - from pkg_resources import parse_version - required = '1.3.1' - actual = torch.__version__ - if parse_version(actual) < parse_version(required): - msg = "\n\nWRONG PYTORCH VERSION\n"\ - "Required: {}\n" \ - "Installed: {}\n"\ - "Please run 'pip install -e .' from the Distiller repo root dir\n".format(required, actual) - raise RuntimeError(msg) - - -check_pytorch_version() diff --git a/distiller/norms.py b/distiller/norms.py index a73a07fc1ee65d45fb5e5289e5446afda69f038a..c43019bc9557dd16f0a967f8225d74888f658249 100644 --- a/distiller/norms.py +++ b/distiller/norms.py @@ -33,6 +33,7 @@ see: https://www.kaggle.com/residentmario/l1-norms-versus-l2-norms) import torch import numpy as np from functools import partial +from random import uniform __all__ = ["kernels_lp_norm", "channels_lp_norm", "filters_lp_norm", diff --git a/distiller/pruning/__init__.py b/distiller/pruning/__init__.py index 710c7b5f7c040fce7d425a23399e7cbb40c04d61..25e115f922c0378b5327b4f66749179646171730 100755 --- a/distiller/pruning/__init__.py +++ b/distiller/pruning/__init__.py @@ -42,6 +42,7 @@ from .ranked_structures_pruner import L1RankedStructureParameterPruner, \ FMReconstructionChannelPruner from .baidu_rnn_pruner import BaiduRNNPruner from .greedy_filter_pruning import greedy_pruner +import torch del magnitude_pruner del automated_gradual_pruner @@ -49,3 +50,88 @@ del level_pruner del sensitivity_pruner del structure_pruner del ranked_structures_pruner + + +def mask_tensor(tensor, mask, inplace=True): + """Mask the provided tensor + + Args: + tensor - the torch-tensor to mask + mask - binary coefficient-masking tensor. Has the same type and shape as `tensor` + Returns: + tensor = tensor * mask ;where * is the element-wise multiplication operator + """ + assert tensor.type() == mask.type() + assert tensor.shape == mask.shape + if mask is not None: + return tensor.data.mul_(mask) if inplace else tensor.data.mul(mask) + return tensor + + +def create_mask_threshold_criterion(tensor, threshold): + """Create a tensor mask using a threshold criterion. + + All values smaller or equal to the threshold will be masked-away. 
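Illustrative usage of the two helpers added above (mask_tensor and create_mask_threshold_criterion, the latter replacing the threshold_mask helper removed from distiller/thresholding.py later in this patch); the tensor values below are made up:

import torch
import distiller

weights = torch.tensor([0.05, -0.4, 0.9, -0.02])
# Elements whose magnitude exceeds the threshold survive (mask value 1).
mask = distiller.create_mask_threshold_criterion(weights, threshold=0.1)
# mask == tensor([0., 1., 1., 0.])
pruned = distiller.mask_tensor(weights, mask, inplace=False)
# pruned == tensor([0., -0.4, 0.9, -0.]); with inplace=False, weights is untouched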
+ Granularity: Element-wise + Args: + tensor - the tensor to threshold. + threshold - a floating-point threshold value. + Returns: + boolean mask tensor, having the same size as the input tensor. + """ + with torch.no_grad(): + mask = torch.gt(torch.abs(tensor), threshold).type(tensor.type()) + return mask + + +def create_mask_level_criterion(tensor, desired_sparsity): + """Create a tensor mask using a level criterion. + + A specified fraction of the input tensor will be masked. The tensor coefficients + are first sorted by their L1-norm (absolute value), and then the lower `desired_sparsity` + coefficients are masked. + Granularity: Element-wise + + WARNING: due to the implementation details (speed over correctness), this will perform + incorrectly if "too many" of the coefficients have the same value. E.g. this will fail: + a = torch.ones(3, 64, 32, 32) + mask = distiller.create_mask_level_criterion(a, desired_sparsity=0.3) + assert math.isclose(distiller.sparsity(mask), 0.3) + + Args: + tensor - the tensor to mask. + desired_sparsity - a floating-point value in the range (0..1) specifying what + percentage of the tensor will be masked. + Returns: + boolean mask tensor, having the same size as the input tensor. + """ + with torch.no_grad(): + # partial sort + bottomk, _ = torch.topk(tensor.abs().view(-1), + int(desired_sparsity * tensor.numel()), + largest=False, + sorted=True) + threshold = bottomk.data[-1] # This is the largest element from the group of elements that we prune away + mask = create_mask_threshold_criterion(tensor, threshold) + return mask + + +def create_mask_sensitivity_criterion(tensor, sensitivity): + """Create a tensor mask using a sensitivity criterion. + + Mask an input tensor based on the variance of the distribution of the tensor coefficients. + Coefficients in the distribution's specified band around the mean will be masked (symmetrically). + Granularity: Element-wise + Args: + tensor - the tensor to mask. + sensitivity - a floating-point value specifying the sensitivity. This is a simple + multiplier of the standard-deviation. + Returns: + boolean mask tensor, having the same size as the input tensor. + """ + if not hasattr(tensor, 'stddev'): + tensor.stddev = torch.std(tensor).item() + with torch.no_grad(): + threshold = tensor.stddev * sensitivity + mask = create_mask_threshold_criterion(tensor, threshold) + return mask diff --git a/distiller/pruning/automated_gradual_pruner.py b/distiller/pruning/automated_gradual_pruner.py index c2e6d4ce44291895fc9dda0642915c0a86e18b9c..e499ae2a55d990d28d2f1e502502b4b1e88dfe32 100755 --- a/distiller/pruning/automated_gradual_pruner.py +++ b/distiller/pruning/automated_gradual_pruner.py @@ -14,37 +14,19 @@ # limitations under the License. # -from .pruner import _ParameterPruner -from .level_pruner import SparsityLevelParameterPruner from .ranked_structures_pruner import * -from distiller.utils import * -from functools import partial +import distiller -class AutomatedGradualPrunerBase(_ParameterPruner): - """Prune to an exact sparsity level specification using a prescribed sparsity - level schedule formula. - - An automated gradual pruning algorithm that prunes the smallest magnitude - weights to achieve a preset level of network sparsity. 
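The level and sensitivity criteria added above in distiller/pruning/__init__.py can be sanity-checked in isolation; a small sketch (the tensor shape and sparsity target are arbitrary):

import math
import torch
import distiller

t = torch.randn(64, 128)
# Level criterion: zero out the 40% of coefficients with the smallest magnitude.
level_mask = distiller.create_mask_level_criterion(t, desired_sparsity=0.4)
assert math.isclose(distiller.sparsity(level_mask), 0.4, abs_tol=0.001)

# Sensitivity criterion: zero out everything within one standard deviation of zero;
# for roughly normal data this masks about 68% of the coefficients.
sens_mask = distiller.create_mask_sensitivity_criterion(t, sensitivity=1)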
- - Michael Zhu and Suyog Gupta, "To prune, or not to prune: exploring the - efficacy of pruning for model compression", 2017 NIPS Workshop on Machine - Learning of Phones and other Consumer Devices, - (https://arxiv.org/pdf/1710.01878.pdf) +class AgpPruningRate(object): + """A pruning-rate scheduler per https://arxiv.org/pdf/1710.01878.pdf. """ - - def __init__(self, name, initial_sparsity, final_sparsity): - super().__init__(name) + def __init__(self, initial_sparsity, final_sparsity): self.initial_sparsity = initial_sparsity self.final_sparsity = final_sparsity assert final_sparsity > initial_sparsity - def compute_target_sparsity(self, meta): - starting_epoch = meta['starting_epoch'] - current_epoch = meta['current_epoch'] - ending_epoch = meta['ending_epoch'] - freq = meta['frequency'] + def step(self, current_epoch, starting_epoch, ending_epoch, freq): span = ((ending_epoch - starting_epoch - 1) // freq) * freq assert span > 0 @@ -54,8 +36,32 @@ class AutomatedGradualPrunerBase(_ParameterPruner): return target_sparsity + +class AutomatedGradualPrunerBase(object): + """Prune to an exact sparsity level specification using a prescribed sparsity + level schedule formula. + + An automated gradual pruning algorithm that prunes the smallest magnitude + weights to achieve a preset level of network sparsity. + + Michael Zhu and Suyog Gupta, "To prune, or not to prune: exploring the + efficacy of pruning for model compression", 2017 NIPS Workshop on Machine + Learning of Phones and other Consumer Devices, + (https://arxiv.org/pdf/1710.01878.pdf) + """ + + def __init__(self, name, rate_scheduler): + """ + Args: + rate_scheduler - schedules the pruning rate. You can plug in any + rate scheduler. We implemented AGP paper's rate control in AgpPruningRate. + """ + self.name = name + self.agp_pr = rate_scheduler + def set_param_mask(self, param, param_name, zeros_mask_dict, meta): - target_sparsity = self.compute_target_sparsity(meta) + target_sparsity = self.agp_pr.step(meta['current_epoch'], meta['starting_epoch'], + meta['ending_epoch'], meta['frequency']) self._set_param_mask_by_sparsity_target(param, param_name, zeros_mask_dict, target_sparsity, meta['model']) def _set_param_mask_by_sparsity_target(self, param, param_name, zeros_mask_dict, target_sparsity, model=None): @@ -69,8 +75,8 @@ class AutomatedGradualPruner(AutomatedGradualPrunerBase): An automated gradual pruning algorithm that prunes the smallest magnitude weights to achieve a preset level of network sparsity. """ - def __init__(self, name, initial_sparsity, final_sparsity, weights): - super().__init__(name, initial_sparsity, final_sparsity) + def __init__(self, name, initial_sparsity, final_sparsity, weights, rate_scheduler_factory=AgpPruningRate): + super().__init__(name, rate_scheduler=rate_scheduler_factory(initial_sparsity, final_sparsity)) self.params_names = weights assert self.params_names @@ -80,7 +86,7 @@ class AutomatedGradualPruner(AutomatedGradualPrunerBase): super().set_param_mask(param, param_name, zeros_mask_dict, meta) def _set_param_mask_by_sparsity_target(self, param, param_name, zeros_mask_dict, target_sparsity, model=None): - zeros_mask_dict[param_name].mask = SparsityLevelParameterPruner.create_mask(param, target_sparsity) + zeros_mask_dict[param_name].mask = distiller.create_mask_level_criterion(param, target_sparsity) class StructuredAGP(AutomatedGradualPrunerBase): @@ -89,8 +95,8 @@ class StructuredAGP(AutomatedGradualPrunerBase): This is a base-class for structured pruning with an AGP schedule. 
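For reference, the pruning rate returned by AgpPruningRate.step() follows the cubic schedule from the cited paper; a standalone sketch (the function name and example numbers are illustrative, not taken from the patch):

def agp_target_sparsity(initial_sparsity, final_sparsity,
                        current_epoch, starting_epoch, ending_epoch, freq):
    # Cubic ramp (Zhu & Gupta, 2017): sparsity rises quickly at first and then
    # levels off as it approaches final_sparsity.
    span = ((ending_epoch - starting_epoch - 1) // freq) * freq
    progress = (current_epoch - starting_epoch) / span
    return (final_sparsity +
            (initial_sparsity - final_sparsity) * (1.0 - progress) ** 3)

# Ramping 5% -> 90% sparsity over epochs 0..30 with freq=2 (span = 28):
# epoch 0 -> 0.05, epoch 14 -> ~0.794, epoch 28 -> 0.90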
It is an extension of the AGP concept introduced by Zhu et. al. """ - def __init__(self, name, initial_sparsity, final_sparsity): - super().__init__(name, initial_sparsity, final_sparsity) + def __init__(self, name, initial_sparsity, final_sparsity, rate_scheduler_factory=AgpPruningRate): + super().__init__(name, rate_scheduler=rate_scheduler_factory(initial_sparsity, final_sparsity)) self.pruner = None def _set_param_mask_by_sparsity_target(self, param, param_name, zeros_mask_dict, target_sparsity, model): diff --git a/distiller/pruning/baidu_rnn_pruner.py b/distiller/pruning/baidu_rnn_pruner.py index f195b8598cbc2614b9e4d7522e6672dc13bffd6c..8ec7fac1e24b55320246fc45a2dae44026bdfc1e 100755 --- a/distiller/pruning/baidu_rnn_pruner.py +++ b/distiller/pruning/baidu_rnn_pruner.py @@ -14,13 +14,10 @@ # limitations under the License. # -from .pruner import _ParameterPruner -from .level_pruner import SparsityLevelParameterPruner -from distiller.utils import * - import distiller -class BaiduRNNPruner(_ParameterPruner): + +class BaiduRNNPruner(object): """An element-wise pruner for RNN networks. Narang, Sharan & Diamos, Gregory & Sengupta, Shubho & Elsen, Erich. (2017). @@ -53,7 +50,7 @@ class BaiduRNNPruner(_ParameterPruner): def __init__(self, name, q, ramp_epoch_offset, ramp_slope_mult, weights): # Initialize the pruner, using a configuration that originates from the # schedule YAML file. - super(BaiduRNNPruner, self).__init__(name) + self.name = name self.params_names = weights assert self.params_names @@ -93,4 +90,4 @@ class BaiduRNNPruner(_ParameterPruner): self.ramp_slope * (current_epoch - ramp_epoch + 1)) / freq # After computing the threshold, we can create the mask - zeros_mask_dict[param_name].mask = distiller.threshold_mask(param.data, eps) + zeros_mask_dict[param_name].mask = distiller.create_mask_threshold_criterion(param, eps) diff --git a/distiller/pruning/level_pruner.py b/distiller/pruning/level_pruner.py index dc6449dea88c0acbda66d0cc9a494fa514407ff7..53a07cbf5c544a62547e6829bb468fbad0745512 100755 --- a/distiller/pruning/level_pruner.py +++ b/distiller/pruning/level_pruner.py @@ -14,11 +14,10 @@ # limitations under the License. # -import torch -from .pruner import _ParameterPruner import distiller -class SparsityLevelParameterPruner(_ParameterPruner): + +class SparsityLevelParameterPruner(object): """Prune to an exact pruning level specification. 
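SparsityLevelParameterPruner reads its per-parameter targets from the levels dict, with '*' acting as the default; a construction sketch (the layer names are hypothetical):

from distiller.pruning.level_pruner import SparsityLevelParameterPruner

levels = {'module.conv1.weight': 0.0,   # a level of 0 means the parameter is skipped
          'module.fc.weight': 0.7,      # prune the classifier harder
          '*': 0.5}                     # default for every other parameter
pruner = SparsityLevelParameterPruner(name='level_pruner', levels=levels)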
This pruner is very similar to MagnitudeParameterPruner, but instead of @@ -30,7 +29,7 @@ class SparsityLevelParameterPruner(_ParameterPruner): """ def __init__(self, name, levels, **kwargs): - super(SparsityLevelParameterPruner, self).__init__(name) + self.name = name self.levels = levels assert self.levels @@ -40,13 +39,4 @@ class SparsityLevelParameterPruner(_ParameterPruner): desired_sparsity = self.levels.get(param_name, self.levels.get('*', 0)) if desired_sparsity == 0: return - - zeros_mask_dict[param_name].mask = SparsityLevelParameterPruner.create_mask(param, desired_sparsity) - - @staticmethod - def create_mask(param, desired_sparsity): - with torch.no_grad(): - bottomk, _ = torch.topk(param.abs().view(-1), int(desired_sparsity * param.numel()), largest=False, sorted=True) - threshold = bottomk.data[-1] # This is the largest element from the group of elements that we prune away - mask = distiller.threshold_mask(param.data, threshold) - return mask \ No newline at end of file + zeros_mask_dict[param_name].mask = distiller.create_mask_level_criterion(param, desired_sparsity) diff --git a/distiller/pruning/magnitude_pruner.py b/distiller/pruning/magnitude_pruner.py index c46732edf76018885b0949f331e464b0c3e69311..c8d5cd692e7bfdfc499cb2845a6dcfa9f0b37a4f 100755 --- a/distiller/pruning/magnitude_pruner.py +++ b/distiller/pruning/magnitude_pruner.py @@ -14,12 +14,11 @@ # limitations under the License. # -from .pruner import _ParameterPruner import distiller import torch -class MagnitudeParameterPruner(_ParameterPruner): +class MagnitudeParameterPruner(object): """This is the most basic magnitude-based pruner. This pruner supports configuring a scalar threshold for each layer. @@ -43,7 +42,7 @@ class MagnitudeParameterPruner(_ParameterPruner): value is used. Currently it is mandatory to include a '*' key in 'thresholds'. """ - super(MagnitudeParameterPruner, self).__init__(name) + self.name = name assert thresholds is not None # Make sure there is a default threshold to use assert '*' in thresholds @@ -51,13 +50,8 @@ class MagnitudeParameterPruner(_ParameterPruner): def set_param_mask(self, param, param_name, zeros_mask_dict, meta): threshold = self.thresholds.get(param_name, self.thresholds['*']) - zeros_mask_dict[param_name].mask = self.create_mask(param.data, threshold) + zeros_mask_dict[param_name].mask = distiller.create_mask_threshold_criterion(param, threshold) - @staticmethod - def create_mask(param, threshold): - with torch.no_grad(): - mask = distiller.threshold_mask(param.data, threshold) - return mask diff --git a/distiller/pruning/pruner.py b/distiller/pruning/pruner.py deleted file mode 100755 index 7ba5f2f9f42fe4c2d4978cc2d895e8b2eda89f4d..0000000000000000000000000000000000000000 --- a/distiller/pruning/pruner.py +++ /dev/null @@ -1,50 +0,0 @@ -# -# Copyright (c) 2018 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import torch -import distiller - - -__all__ = ["mask_tensor"] - - -class _ParameterPruner(object): - """Base class for all pruners. 
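With the _ParameterPruner base class removed, a pruner is now any object that carries a name and exposes set_param_mask(param, param_name, zeros_mask_dict, meta); a hypothetical minimal pruner, for illustration only:

import distiller

class ConstantThresholdPruner(object):
    # Toy pruner showing the duck-typed interface; not part of the patch.
    def __init__(self, name, threshold):
        self.name = name
        self.threshold = threshold

    def set_param_mask(self, param, param_name, zeros_mask_dict, meta):
        zeros_mask_dict[param_name].mask = \
            distiller.create_mask_threshold_criterion(param, self.threshold)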
- - Arguments: - name: pruner name is used mainly for debugging. - """ - def __init__(self, name): - self.name = name - - def set_param_mask(self, param, param_name, zeros_mask_dict, meta): - raise NotImplementedError - - -def mask_tensor(tensor, mask): - """Mask the provided tensor - - Args: - tensor - the torch-tensor to mask - mask - binary coefficient-masking tensor. Has the same type and shape as `tensor` - Returns: - tensor = tensor * mask ;where * is the element-wise multiplication operator - """ - assert tensor.type() == mask.type() - assert tensor.shape == mask.shape - if mask: - tensor.data.mul_(mask) - return tensor diff --git a/distiller/pruning/ranked_structures_pruner.py b/distiller/pruning/ranked_structures_pruner.py index 629df73f4de52907b285eea4fa11e994a0bbe10b..ea11aa106cc140ecd101aa1985df3e940881a101 100755 --- a/distiller/pruning/ranked_structures_pruner.py +++ b/distiller/pruning/ranked_structures_pruner.py @@ -22,7 +22,6 @@ import torch from torch.nn import functional as f from random import uniform import distiller -from .pruner import _ParameterPruner __all__ = ["LpRankedStructureParameterPruner", @@ -39,12 +38,12 @@ __all__ = ["LpRankedStructureParameterPruner", msglogger = logging.getLogger(__name__) -class _RankedStructureParameterPruner(_ParameterPruner): +class _RankedStructureParameterPruner(object): """Base class for pruning structures by ranking them. """ def __init__(self, name, group_type, desired_sparsity, weights, group_dependency=None, group_size=1, rounding_fn=math.floor, noise=0.): - super().__init__(name) + self.name = name self.group_type = group_type self.group_dependency = group_dependency self.params_names = weights @@ -343,14 +342,15 @@ class ActivationRankedFilterPruner(_RankedStructureParameterPruner): # Use the parameter name to locate the module that has the activation sparsity statistics fq_name = param_name.replace(".conv", ".relu")[:-len(".weight")] + #distiller.assign_layer_fq_names(model) module = distiller.find_module_by_fq_name(model, fq_name) if module is None: raise ValueError("Could not find a layer named %s in the model." "\nMake sure to use assign_layer_fq_names()" % fq_name) if not hasattr(module, self.activation_rank_criterion): - raise ValueError("Could not find attribute \"{}\" in module {}" - "\nMake sure to use SummaryActivationStatsCollector(\"{}\")". - format(self.activation_rank_criterion, fq_name, self.activation_rank_criterion)) + raise ValueError("Could not find attribute \"%s\" in module %s" + "\nMake sure to use SummaryActivationStatsCollector(\"%s\")" % + (self.activation_rank_criterion, fq_name, self.activation_rank_criterion)) quality_criterion, std = getattr(module, self.activation_rank_criterion).value() num_filters = param.size(0) @@ -469,10 +469,7 @@ class BernoulliFilterPruner(_RankedStructureParameterPruner): class GradientRankedFilterPruner(_RankedStructureParameterPruner): - """Taylor expansion ranking. - - Pavlo Molchanov, Stephen Tyree, Tero Karras, Timo Aila, and Jan Kautz. Pruning Convolutional Neural - Networks for Resource Efficient Inference. ArXiv, abs/1611.06440, 2016. + """Rank the importance of weight filters using the product of their gradients and the filter value. 
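A self-contained sketch of this ranking criterion (the weights and the loss are toy stand-ins; in Distiller the pruner runs after a backward pass so that param.grad is populated):

import torch

weights = torch.randn(16, 8, 3, 3, requires_grad=True)   # 16 filters
loss = (weights ** 2).sum()                               # stand-in loss
loss.backward()                                           # populates weights.grad

with torch.no_grad():
    # Score each filter by the sum of (gradient * weight) over its elements.
    scores = (weights.grad * weights).view(weights.size(0), -1).sum(dim=1)
    num_filters_to_prune = 4
    # The lowest-scoring filters are the pruning candidates.
    prune_candidates = torch.argsort(scores)[:num_filters_to_prune]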
""" def __init__(self, name, group_type, desired_sparsity, weights, group_dependency=None): super().__init__(name, group_type, desired_sparsity, weights, group_dependency) @@ -495,22 +492,23 @@ class GradientRankedFilterPruner(_RankedStructureParameterPruner): msglogger.info("Too few filters - can't prune %.1f%% filters", 100*fraction_to_prune) return - # Compute the multiplication of the filters times the filter_gradienrs + # Compute the multiplication of the filters times the filter_gradients view_filters = param.view(param.size(0), -1) view_filter_grads = param.grad.view(param.size(0), -1) - weighted_gradients = view_filter_grads * view_filters - weighted_gradients = weighted_gradients.sum(dim=1) + with torch.no_grad(): + weighted_gradients = view_filter_grads * view_filters + weighted_gradients = weighted_gradients.sum(dim=1) - # Sort from high to low, and remove the bottom 'num_filters_to_prune' filters - filters_ordered_by_gradient = np.argsort(-weighted_gradients.detach().cpu().numpy())[:-num_filters_to_prune] - mask, binary_map = _mask_from_filter_order(filters_ordered_by_gradient, param, num_filters, binary_map) - zeros_mask_dict[param_name].mask = mask + # Sort from high to low, and remove the bottom 'num_filters_to_prune' filters + filters_ordered_by_gradient = np.argsort(-weighted_gradients.detach().cpu().numpy())[:-num_filters_to_prune] + mask, binary_map = _mask_from_filter_order(filters_ordered_by_gradient, param, num_filters, binary_map) + zeros_mask_dict[param_name].mask = mask - msglogger.info("GradientRankedFilterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)", - param_name, - distiller.sparsity_3D(zeros_mask_dict[param_name].mask), - fraction_to_prune, num_filters_to_prune, num_filters) - return binary_map + msglogger.info("GradientRankedFilterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)", + param_name, + distiller.sparsity_3D(zeros_mask_dict[param_name].mask), + fraction_to_prune, num_filters_to_prune, num_filters) + return binary_map from sklearn.linear_model import LinearRegression diff --git a/distiller/pruning/sensitivity_pruner.py b/distiller/pruning/sensitivity_pruner.py index 5fc4f3f4e9b363f9b2c984b062c8cafd18d7a870..5832399dfb1ceb97cbb01387d762fb40765a9be9 100755 --- a/distiller/pruning/sensitivity_pruner.py +++ b/distiller/pruning/sensitivity_pruner.py @@ -14,12 +14,11 @@ # limitations under the License. 
# -from .pruner import _ParameterPruner + import distiller -import torch -class SensitivityPruner(_ParameterPruner): +class SensitivityPruner(object): """Use algorithm from "Learning both Weights and Connections for Efficient Neural Networks" - https://arxiv.org/pdf/1506.02626v3.pdf @@ -41,7 +40,7 @@ class SensitivityPruner(_ParameterPruner): """ def __init__(self, name, sensitivities, **kwargs): - super(SensitivityPruner, self).__init__(name) + self.name = name self.sensitivities = sensitivities def set_param_mask(self, param, param_name, zeros_mask_dict, meta): @@ -53,13 +52,4 @@ class SensitivityPruner(_ParameterPruner): else: sensitivity = self.sensitivities[param_name] - zeros_mask_dict[param_name].mask = self.create_mask(param, sensitivity) - - @staticmethod - def create_mask(param, sensitivity): - if not hasattr(param, 'stddev'): - param.stddev = torch.std(param).item() - with torch.no_grad(): - threshold = param.stddev * sensitivity - mask = distiller.threshold_mask(param.data, threshold) - return mask + zeros_mask_dict[param_name].mask = distiller.create_mask_sensitivity_criterion(param, sensitivity) diff --git a/distiller/pruning/splicing_pruner.py b/distiller/pruning/splicing_pruner.py index 2a2edd3411fdd31b71f1bc95591ebc8563d8e70e..4dbb0b2e4a636848e2483b569a6aa76665121138 100755 --- a/distiller/pruning/splicing_pruner.py +++ b/distiller/pruning/splicing_pruner.py @@ -14,14 +14,12 @@ # limitations under the License. # - -from .pruner import _ParameterPruner import torch import logging msglogger = logging.getLogger() -class SplicingPruner(_ParameterPruner): +class SplicingPruner(object): """A pruner that both prunes and splices connections. The idea of pruning and splicing working in tandem was first proposed in the following @@ -37,7 +35,7 @@ class SplicingPruner(_ParameterPruner): def __init__(self, name, sensitivities, low_thresh_mult, hi_thresh_mult, sensitivity_multiplier=0): """Arguments: """ - super(SplicingPruner, self).__init__(name) + self.name = name self.sensitivities = sensitivities self.low_thresh_mult = low_thresh_mult self.hi_thresh_mult = hi_thresh_mult diff --git a/distiller/pruning/structure_pruner.py b/distiller/pruning/structure_pruner.py index efd98a36950ef887bd9816fac4d8756b93c6c899..0cbb42a3759e6c4111320ab629b175996c62f4b2 100755 --- a/distiller/pruning/structure_pruner.py +++ b/distiller/pruning/structure_pruner.py @@ -15,12 +15,11 @@ # import logging -from .pruner import _ParameterPruner import distiller msglogger = logging.getLogger() -class StructureParameterPruner(distiller.GroupThresholdMixin, _ParameterPruner): +class StructureParameterPruner(distiller.GroupThresholdMixin): """Prune parameter structures. Pruning criterion: average L1-norm. If the average L1-norm (absolute value) of the eleements @@ -30,7 +29,6 @@ class StructureParameterPruner(distiller.GroupThresholdMixin, _ParameterPruner): the structure size. 
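StructureParameterPruner builds its masks with the group-thresholding helpers that stay behind in distiller/thresholding.py; a small sketch mirroring the call exercised in tests/test_thresholding.py (the threshold value is arbitrary):

import torch
import distiller

p = torch.randn(16, 8, 3, 3)
# Threshold whole 2-D kernels as a unit: a kernel whose L1 norm falls below 6
# has all nine of its elements zeroed together.
mask, binary_map = distiller.group_threshold_mask(p, '2D', 6, 'L1')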
""" def __init__(self, name, model, reg_regims, threshold_criteria): - super(StructureParameterPruner, self).__init__(name) self.name = name self.model = model self.reg_regims = reg_regims diff --git a/distiller/regularization/l1_regularizer.py b/distiller/regularization/l1_regularizer.py index fdd6007ebc8562d5a0774c24118d0470fbfc4d43..a79c32d44cb7ebce408ed8a0264051ead6e1abb9 100755 --- a/distiller/regularization/l1_regularizer.py +++ b/distiller/regularization/l1_regularizer.py @@ -16,9 +16,7 @@ """L1-norm regularization""" -import torch -import math -import numpy as np + import distiller from .regularizer import _Regularizer, EPSILON @@ -39,7 +37,7 @@ class L1Regularizer(_Regularizer): return strength = self.reg_regims[param_name] - zeros_mask_dict[param_name].mask = distiller.threshold_mask(param.data, threshold=strength) + zeros_mask_dict[param_name].mask = distiller.pruning.create_mask_threshold_criterion(param, threshold=strength) zeros_mask_dict[param_name].is_regularization_mask = True @staticmethod diff --git a/distiller/scheduler.py b/distiller/scheduler.py index 45a25e379ddec155446fc0f63dd98bf503a5e627..82266f04a0aa5f0e9f2e1af671f2dda1a4ffec89 100755 --- a/distiller/scheduler.py +++ b/distiller/scheduler.py @@ -263,7 +263,7 @@ class ParameterMasker(object): def revert_weights(self, parameter): if not self.use_double_copies or self.unmasked_copy is None: - msglogger.debug('Parameter {0} does not maintain double copies'.format(self.param_name)) + # This parameter does not maintain double copies (this is OK) return parameter.data.copy_(self.unmasked_copy) self.unmasked_copy = None diff --git a/distiller/thresholding.py b/distiller/thresholding.py index cb38909667788baa6d5d360230a08b6031654ccc..590adc321814bfce85a3e742d036e51361fcdb1a 100755 --- a/distiller/thresholding.py +++ b/distiller/thresholding.py @@ -23,23 +23,10 @@ import numpy as np from distiller.norms import * -__all__ = ["threshold_mask", "GroupThresholdMixin", +__all__ = ["GroupThresholdMixin", "group_threshold_binary_map", "group_threshold_mask"] -def threshold_mask(param, threshold): - """Create a threshold mask for the provided parameter tensor using - magnitude thresholding. - - Arguments: - param: a parameter tensor which should be pruned. - threshold: the pruning threshold. - Returns: - prune_mask: The pruning mask. 
- """ - return torch.gt(torch.abs(param), threshold).type(param.type()) - - class GroupThresholdMixin(object): """A mixin class to add group thresholding capabilities diff --git a/tests/test_thresholding.py b/tests/test_thresholding.py index f231abda93eedd19f0f7c0b46f6dfdce8baf1ffe..e014efa2ba14054d19f2255759ae4bd59ee547ef 100755 --- a/tests/test_thresholding.py +++ b/tests/test_thresholding.py @@ -77,12 +77,32 @@ def test_threshold_mask(): # Change one element a[1, 4, 17, 31] = 0.2 # Create and apply a mask - mask = distiller.threshold_mask(a, threshold=0.3) + mask = distiller.create_mask_threshold_criterion(a, threshold=0.3) assert np.sum(distiller.to_np(mask)) == (distiller.volume(a) - 1) assert mask[1, 4, 17, 31] == 0 assert common.almost_equal(distiller.sparsity(mask), 1/distiller.volume(a)) +def test_level_mask(): + # Create a 4-D tensor of 1s + a = torch.rand(3, 64, 32, 32) + + # Create and apply a mask + mask = distiller.create_mask_level_criterion(a, desired_sparsity=0.3) + assert common.almost_equal(distiller.sparsity(mask), 0.3, max_diff=0.0001) + + +def test_sensitivity_mask(): + # Create a 4-D tensor of normally-distributed coefficients + a = torch.randn(3, 64, 32, 32) + + # Create and apply a mask + mask = distiller.create_mask_sensitivity_criterion(a, sensitivity=1) + # The width of 1-std on ~N(0,1) is about 68.27%. In other words: + # Pr(mean - std <= X <= mean + std) is about 68.27% + assert common.almost_equal(distiller.sparsity(mask), 0.6827, max_diff=0.005) + + def test_kernel_thresholding(): p = get_test_4d_tensor().cuda() mask, map = distiller.group_threshold_mask(p, '2D', 6, 'L1') @@ -240,6 +260,8 @@ def test_row_thresholding(): [ 0., 0., 0.], [ 1., 1., 1.], [ 1., 1., 1.]], device=mask.device)).all() + masked_tensor = distiller.mask_tensor(p, mask) + assert distiller.sparsity(masked_tensor) == distiller.sparsity(mask) return mask
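One behaviour the new tests do not exercise is the inplace flag of mask_tensor; a possible companion test (hypothetical, not part of the patch):

import torch
import distiller

def test_mask_tensor_not_inplace():
    a = torch.randn(4, 4)
    original = a.clone()
    mask = distiller.create_mask_threshold_criterion(a, threshold=0.5)
    masked = distiller.mask_tensor(a, mask, inplace=False)
    assert torch.equal(a, original)  # the input tensor is left untouched
    assert distiller.sparsity(masked) == distiller.sparsity(mask)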