Commit 97d5e48c authored by Neta Zmora

Filter pruning: rank filters by mean value of feature-map channels

A small change that adds support for ranking weight filters by the mean
mean-value of their feature-map channels.
"Mean mean-value" refers to the average, computed across many input
images, of the mean value of each feature-map channel.
parent b476d028
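As a rough illustration of the new criterion (not code from this commit; the helper name and shapes below are hypothetical), the score of each filter is the mean of its feature-map channel, averaged again over a batch of input images:

import torch

def mean_mean_value(activations: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: activations has shape (N, C, H, W).
    Returns one score per channel: each image's per-channel mean, averaged over the batch."""
    per_image_channel_means = activations.mean(dim=(2, 3))   # shape (N, C)
    return per_image_channel_means.mean(dim=0)               # shape (C,)

# Filters whose feature-map channels have the smallest mean activation are pruned first.
scores = mean_mean_value(torch.randn(8, 16, 32, 32))
prune_order = torch.argsort(scores)  # low-to-high; the bottom-k filters would be removed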
@@ -20,16 +20,22 @@
 from .magnitude_pruner import MagnitudeParameterPruner
 from .automated_gradual_pruner import AutomatedGradualPruner, \
-    L1RankedStructureParameterPruner_AGP, L2RankedStructureParameterPruner_AGP, \
-    ActivationAPoZRankedFilterPruner_AGP, GradientRankedFilterPruner_AGP, \
+    L1RankedStructureParameterPruner_AGP, \
+    L2RankedStructureParameterPruner_AGP, \
+    ActivationAPoZRankedFilterPruner_AGP, \
+    ActivationMeanRankedFilterPruner_AGP, \
+    GradientRankedFilterPruner_AGP, \
     RandomRankedFilterPruner_AGP
 from .level_pruner import SparsityLevelParameterPruner
 from .sensitivity_pruner import SensitivityPruner
 from .splicing_pruner import SplicingPruner
 from .structure_pruner import StructureParameterPruner
-from .ranked_structures_pruner import L1RankedStructureParameterPruner, L2RankedStructureParameterPruner, \
+from .ranked_structures_pruner import L1RankedStructureParameterPruner, \
+    L2RankedStructureParameterPruner, \
     ActivationAPoZRankedFilterPruner, \
-    RandomRankedFilterPruner, GradientRankedFilterPruner
+    ActivationMeanRankedFilterPruner, \
+    GradientRankedFilterPruner, \
+    RandomRankedFilterPruner
 from .baidu_rnn_pruner import BaiduRNNPruner
 from .greedy_filter_pruning import greedy_pruner
...
@@ -120,6 +120,13 @@ class ActivationAPoZRankedFilterPruner_AGP(StructuredAGP):
                                                        weights=weights, group_dependency=group_dependency)
 
 
+class ActivationMeanRankedFilterPruner_AGP(StructuredAGP):
+    def __init__(self, name, initial_sparsity, final_sparsity, group_type, weights, group_dependency=None):
+        assert group_type in ['3D', 'Filters']
+        super().__init__(name, initial_sparsity, final_sparsity)
+        self.pruner = ActivationMeanRankedFilterPruner(name, group_type, desired_sparsity=0,
+                                                       weights=weights, group_dependency=group_dependency)
+
+
 class GradientRankedFilterPruner_AGP(StructuredAGP):
     def __init__(self, name, initial_sparsity, final_sparsity, group_type, weights, group_dependency=None):
         assert group_type in ['3D', 'Filters']
...
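For context, this wrapper follows the same pattern as the other *_AGP classes: the underlying filter pruner is constructed with desired_sparsity=0 because the automated gradual pruning (AGP) schedule supplies the target sparsity at every pruning step. A minimal sketch of that schedule, assuming the cubic ramp from Zhu & Gupta's "To prune, or not to prune" (distiller's exact step bookkeeping may differ):

# Illustrative AGP target-sparsity schedule; names and step handling are assumptions.
def agp_target_sparsity(step, total_steps, initial_sparsity, final_sparsity):
    """Cubic ramp from initial_sparsity up to final_sparsity over total_steps pruning steps."""
    progress = min(step / float(total_steps), 1.0)
    return final_sparsity + (initial_sparsity - final_sparsity) * (1.0 - progress) ** 3

# Example: ramp filter sparsity from 5% to 50% over 30 pruning steps.
targets = [agp_target_sparsity(s, 30, 0.05, 0.50) for s in range(31)]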
@@ -341,17 +341,17 @@ def mask_from_filter_order(filters_ordered_by_criterion, param, num_filters, bin
     return expanded.view(param.shape), binary_map
 
 
-class ActivationAPoZRankedFilterPruner(RankedStructureParameterPruner):
-    """Uses mean APoZ (average percentage of zeros) activation channels to rank structures
-    and prune a specified percentage of structures.
-
-    "Network Trimming: A Data-Driven Neuron Pruning Approach towards Efficient Deep Architectures",
-    Hengyuan Hu, Rui Peng, Yu-Wing Tai, Chi-Keung Tang, ICLR 2016
-    https://arxiv.org/abs/1607.03250
+class ActivationRankedFilterPruner(RankedStructureParameterPruner):
+    """Base class for pruners ranking convolution filters by some quality criterion of the
+    corresponding feature-map channels (e.g. mean channel activation L1 value).
     """
     def __init__(self, name, group_type, desired_sparsity, weights, group_dependency=None):
         super().__init__(name, group_type, desired_sparsity, weights, group_dependency)
 
+    @property
+    def activation_rank_criterion(self):
+        raise NotImplementedError
+
     def prune_group(self, fraction_to_prune, param, param_name, zeros_mask_dict, model=None, binary_map=None):
         if fraction_to_prune == 0:
             return
@@ -368,11 +368,12 @@ class ActivationAPoZRankedFilterPruner(RankedStructureParameterPruner):
         if module is None:
             raise ValueError("Could not find a layer named %s in the model."
                              "\nMake sure to use assign_layer_fq_names()" % fq_name)
-        if not hasattr(module, 'apoz_channels'):
-            raise ValueError("Could not find attribute \'apoz_channels\' in module %s"
-                             "\nMake sure to use SummaryActivationStatsCollector(\"apoz_channels\")" % fq_name)
+        if not hasattr(module, self.activation_rank_criterion):
+            raise ValueError("Could not find attribute \"{0}\" in module {1}"
+                             "\nMake sure to use SummaryActivationStatsCollector(\"{0}\")".format(
+                                 self.activation_rank_criterion, fq_name))
 
-        apoz, std = module.apoz_channels.value()
+        quality_criterion, std = getattr(module, self.activation_rank_criterion).value()
         num_filters = param.size(0)
         num_filters_to_prune = int(fraction_to_prune * num_filters)
         if num_filters_to_prune == 0:
@@ -380,8 +381,8 @@ class ActivationAPoZRankedFilterPruner(RankedStructureParameterPruner):
             return
 
         # Sort from low to high, and remove the bottom 'num_filters_to_prune' filters
-        filters_ordered_by_apoz = np.argsort(apoz)[:-num_filters_to_prune]
-        mask, binary_map = mask_from_filter_order(filters_ordered_by_apoz, param, num_filters, binary_map)
+        filters_ordered_by_criterion = np.argsort(quality_criterion)[:-num_filters_to_prune]
+        mask, binary_map = mask_from_filter_order(filters_ordered_by_criterion, param, num_filters, binary_map)
         zeros_mask_dict[param_name].mask = mask
 
         msglogger.info("ActivationL1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
@@ -391,8 +392,33 @@ class ActivationAPoZRankedFilterPruner(RankedStructureParameterPruner):
         return binary_map
 
 
+class ActivationAPoZRankedFilterPruner(ActivationRankedFilterPruner):
+    """Uses mean APoZ (average percentage of zeros) activation channels to rank filters
+    and prune a specified percentage of filters.
+
+    "Network Trimming: A Data-Driven Neuron Pruning Approach towards Efficient Deep Architectures,"
+    Hengyuan Hu, Rui Peng, Yu-Wing Tai, Chi-Keung Tang. ICLR 2016.
+    https://arxiv.org/abs/1607.03250
+    """
+    @property
+    def activation_rank_criterion(self):
+        return 'apoz_channels'
+
+
+class ActivationMeanRankedFilterPruner(ActivationRankedFilterPruner):
+    """Uses mean value of activation channels to rank filters and prune a specified percentage of filters.
+
+    "Pruning Convolutional Neural Networks for Resource Efficient Inference,"
+    Pavlo Molchanov, Stephen Tyree, Tero Karras, Timo Aila, Jan Kautz. ICLR 2017.
+    https://arxiv.org/abs/1611.06440
+    """
+    @property
+    def activation_rank_criterion(self):
+        return 'mean_channels'
+
+
 class RandomRankedFilterPruner(RankedStructureParameterPruner):
-    """A Random raanking of filters.
+    """A Random ranking of filters.
     This is used for sanity testing of other algorithms.
     """
...
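The new base class only assumes that each module carries an attribute named by activation_rank_criterion (e.g. module.mean_channels) whose value() returns a per-channel score array plus a spread estimate. A rough sketch of how such per-channel mean statistics could be gathered with a forward hook follows; this is an illustrative stand-in, not distiller's SummaryActivationStatsCollector implementation:

import torch

class ChannelMeanTracker:
    """Hypothetical accumulator of per-channel mean activations (illustration only)."""
    def __init__(self):
        self.sum = None
        self.count = 0

    def hook(self, module, inputs, output):
        # Average each channel over the batch and spatial dims, then accumulate.
        channel_means = output.detach().mean(dim=(0, 2, 3))   # shape (C,)
        self.sum = channel_means if self.sum is None else self.sum + channel_means
        self.count += 1

    def value(self):
        means = (self.sum / self.count).cpu().numpy()
        return means, float(means.std())   # per-channel scores, plus a crude spread value

conv = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1)
tracker = ChannelMeanTracker()
conv.register_forward_hook(tracker.hook)
_ = conv(torch.randn(4, 3, 32, 32))
conv.mean_channels = tracker  # the pruner looks the criterion up by this attribute name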
@@ -704,6 +704,8 @@ def create_activation_stats_collectors(model, *phases):
                                                          distiller.utils.activation_channels_l1),
         "apoz_channels": SummaryActivationStatsCollector(model, "apoz_channels",
                                                          distiller.utils.activation_channels_apoz),
+        "mean_channels": SummaryActivationStatsCollector(model, "mean_channels",
+                                                         distiller.utils.activation_channels_means),
         "records": RecordsActivationStatsCollector(model, classes=[torch.nn.Conv2d])
     })
...
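With the "mean_channels" collector registered, the intended flow is: run an evaluation pass with the collector active so each convolution module is annotated with a mean_channels statistic, then let ActivationMeanRankedFilterPruner (or its _AGP wrapper, driven by the compression schedule) rank and prune filters from those statistics. A hedged sketch of a schedule entry for the wrapper, written as the parsed-dict equivalent rather than the exact YAML keys distiller expects; the layer name is hypothetical:

# Hypothetical pruner configuration; keys mirror the constructor arguments shown above.
pruner_config = {
    'class': 'ActivationMeanRankedFilterPruner_AGP',
    'initial_sparsity': 0.05,
    'final_sparsity': 0.50,
    'group_type': 'Filters',
    'weights': ['module.layer1.0.conv1.weight'],   # hypothetical layer name
}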