diff --git a/distiller/pruning/automated_gradual_pruner.py b/distiller/pruning/automated_gradual_pruner.py index dfb80e58bc4f59def418f3ebeb3b001a7fdc9ae2..bede4887de3ee225eb8d798df37a955b056b372e 100755 --- a/distiller/pruning/automated_gradual_pruner.py +++ b/distiller/pruning/automated_gradual_pruner.py @@ -18,12 +18,12 @@ from .pruner import _ParameterPruner from .level_pruner import SparsityLevelParameterPruner from .ranked_structures_pruner import * from distiller.utils import * -# import logging -# msglogger = logging.getLogger() +from functools import partial -class AutomatedGradualPruner(_ParameterPruner): - """Prune to an exact pruning level specification. +class AutomatedGradualPrunerBase(_ParameterPruner): + """Prune to an exact sparsity level specification using a prescribed sparsity + level schedule formula. An automated gradual pruning algorithm that prunes the smallest magnitude weights to achieve a preset level of network sparsity. @@ -34,23 +34,13 @@ class AutomatedGradualPruner(_ParameterPruner): (https://arxiv.org/pdf/1710.01878.pdf) """ - def __init__(self, name, initial_sparsity, final_sparsity, weights, - pruning_fn=None): - super(AutomatedGradualPruner, self).__init__(name) + def __init__(self, name, initial_sparsity, final_sparsity): + super().__init__(name) self.initial_sparsity = initial_sparsity self.final_sparsity = final_sparsity assert final_sparsity > initial_sparsity - self.params_names = weights - assert self.params_names - if pruning_fn is None: - self.pruning_fn = self.prune_to_target_sparsity - else: - self.pruning_fn = pruning_fn - - def set_param_mask(self, param, param_name, zeros_mask_dict, meta): - if param_name not in self.params_names: - return + def compute_target_sparsity(self, meta): starting_epoch = meta['starting_epoch'] current_epoch = meta['current_epoch'] ending_epoch = meta['ending_epoch'] @@ -61,56 +51,79 @@ class AutomatedGradualPruner(_ParameterPruner): target_sparsity = (self.final_sparsity + (self.initial_sparsity-self.final_sparsity) * (1.0 - ((current_epoch-starting_epoch)/span))**3) - self.pruning_fn(param, param_name, zeros_mask_dict, target_sparsity, meta['model']) - @staticmethod - def prune_to_target_sparsity(param, param_name, zeros_mask_dict, target_sparsity, model=None): - return SparsityLevelParameterPruner.prune_level(param, param_name, zeros_mask_dict, target_sparsity) + return target_sparsity + def set_param_mask(self, param, param_name, zeros_mask_dict, meta): + target_sparsity = self.compute_target_sparsity(meta) + self.prune_to_target_sparsity(param, param_name, zeros_mask_dict, target_sparsity, meta['model']) -class CriterionParameterizedAGP(AutomatedGradualPruner): - def __init__(self, name, initial_sparsity, final_sparsity, reg_regims): - self.reg_regims = reg_regims - weights = [weight for weight in reg_regims.keys()] - if not all([group in ['3D', 'Filters', 'Channels', 'Rows'] for group in reg_regims.values()]): - raise ValueError("Unsupported group structure") - super(CriterionParameterizedAGP, self).__init__(name, initial_sparsity, - final_sparsity, weights, - pruning_fn=self.prune_to_target_sparsity) + def prune_to_target_sparsity(self, param, param_name, zeros_mask_dict, target_sparsity, model=None): + raise NotImplementedError - def prune_to_target_sparsity(self, param, param_name, zeros_mask_dict, target_sparsity, model): - if self.reg_regims[param_name] in ['3D', 'Filters']: - self.filters_ranking_fn(target_sparsity, param, param_name, zeros_mask_dict, model) - elif self.reg_regims[param_name] == 
'Channels': - self.channels_ranking_fn(target_sparsity, param, param_name, zeros_mask_dict, model) - elif self.reg_regims[param_name] == 'Rows': - self.rows_ranking_fn(target_sparsity, param, param_name, zeros_mask_dict, model) +class AutomatedGradualPruner(AutomatedGradualPrunerBase): + """Fine-grained pruning with an AGP sparsity schedule. -# TODO: this class parameterization is cumbersome: the ranking functions (per structure) -# should come from the YAML schedule + An automated gradual pruning algorithm that prunes the smallest magnitude + weights to achieve a preset level of network sparsity. + """ + def __init__(self, name, initial_sparsity, final_sparsity, weights): + super().__init__(name, initial_sparsity, final_sparsity) + self.params_names = weights + assert self.params_names + + def set_param_mask(self, param, param_name, zeros_mask_dict, meta): + if param_name not in self.params_names: + return + super().set_param_mask(param, param_name, zeros_mask_dict, meta) -class L1RankedStructureParameterPruner_AGP(CriterionParameterizedAGP): - def __init__(self, name, initial_sparsity, final_sparsity, reg_regims): - super(L1RankedStructureParameterPruner_AGP, self).__init__(name, initial_sparsity, final_sparsity, reg_regims) - self.filters_ranking_fn = L1RankedStructureParameterPruner.rank_prune_filters - self.channels_ranking_fn = L1RankedStructureParameterPruner.rank_prune_channels - self.rows_ranking_fn = L1RankedStructureParameterPruner.rank_prune_rows + def prune_to_target_sparsity(self, param, param_name, zeros_mask_dict, target_sparsity, model=None): + return SparsityLevelParameterPruner.prune_level(param, param_name, zeros_mask_dict, target_sparsity) -class ActivationAPoZRankedFilterPruner_AGP(CriterionParameterizedAGP): - def __init__(self, name, initial_sparsity, final_sparsity, reg_regims): - super(ActivationAPoZRankedFilterPruner_AGP, self).__init__(name, initial_sparsity, final_sparsity, reg_regims) - self.filters_ranking_fn = ActivationAPoZRankedFilterPruner.rank_prune_filters +class StructuredAGP(AutomatedGradualPrunerBase): + """Structured pruning with an AGP sparsity schedule. + This is a base-class for structured pruning with an AGP schedule. It is an + extension of the AGP concept introduced by Zhu et. al. 
+ """ + def __init__(self, name, initial_sparsity, final_sparsity): + super().__init__(name, initial_sparsity, final_sparsity) + self.pruner = None -class GradientRankedFilterPruner_AGP(CriterionParameterizedAGP): - def __init__(self, name, initial_sparsity, final_sparsity, reg_regims): - super(GradientRankedFilterPruner_AGP, self).__init__(name, initial_sparsity, final_sparsity, reg_regims) - self.filters_ranking_fn = GradientRankedFilterPruner.rank_prune_filters + def prune_to_target_sparsity(self, param, param_name, zeros_mask_dict, target_sparsity, model): + self.pruner.prune_to_target_sparsity(param, param_name, zeros_mask_dict, target_sparsity, model) -class RandomRankedFilterPruner_AGP(CriterionParameterizedAGP): - def __init__(self, name, initial_sparsity, final_sparsity, reg_regims): - super(RandomRankedFilterPruner_AGP, self).__init__(name, initial_sparsity, final_sparsity, reg_regims) - self.filters_ranking_fn = RandomRankedFilterPruner.rank_prune_filters +# TODO: this class parameterization is cumbersome: the ranking functions (per structure) +# should come from the YAML schedule +class L1RankedStructureParameterPruner_AGP(StructuredAGP): + def __init__(self, name, initial_sparsity, final_sparsity, group_type, weights, group_dependency=None): + super().__init__(name, initial_sparsity, final_sparsity) + self.pruner = L1RankedStructureParameterPruner(name, group_type, desired_sparsity=0, + weights=weights, group_dependency=group_dependency) + + +class ActivationAPoZRankedFilterPruner_AGP(StructuredAGP): + def __init__(self, name, initial_sparsity, final_sparsity, group_type, weights, group_dependency=None): + assert group_type in ['3D', 'Filters'] + super().__init__(name, initial_sparsity, final_sparsity) + self.pruner = ActivationAPoZRankedFilterPruner(name, group_type, desired_sparsity=0, + weights=weights, group_dependency=group_dependency) + + +class GradientRankedFilterPruner_AGP(StructuredAGP): + def __init__(self, name, initial_sparsity, final_sparsity, group_type, weights, group_dependency=None): + assert group_type in ['3D', 'Filters'] + super().__init__(name, initial_sparsity, final_sparsity) + self.pruner = GradientRankedFilterPruner(name, group_type, desired_sparsity=0, + weights=weights, group_dependency=group_dependency) + + +class RandomRankedFilterPruner_AGP(StructuredAGP): + def __init__(self, name, initial_sparsity, final_sparsity, group_type, weights, group_dependency=None): + assert group_type in ['3D', 'Filters'] + super().__init__(name, initial_sparsity, final_sparsity) + self.pruner = RandomRankedFilterPruner(name, group_type, desired_sparsity=0, + weights=weights, group_dependency=group_dependency) diff --git a/distiller/pruning/ranked_structures_pruner.py b/distiller/pruning/ranked_structures_pruner.py index 8697d231295bb542fb3128672533f488f9173394..82d2c3bb480f1098d72365880bac694e01047391 100755 --- a/distiller/pruning/ranked_structures_pruner.py +++ b/distiller/pruning/ranked_structures_pruner.py @@ -22,103 +22,166 @@ from .pruner import _ParameterPruner msglogger = logging.getLogger() -# TODO: support different policies for ranking structures class RankedStructureParameterPruner(_ParameterPruner): - """Uses mean L1-norm to rank structures and prune a specified percentage of structures + """Base class for pruning structures by ranking them. 
""" - def __init__(self, name, reg_regims): + def __init__(self, name, group_type, desired_sparsity, weights, group_dependency=None): super().__init__(name) - self.reg_regims = reg_regims + self.group_type = group_type + self.group_dependency = group_dependency + self.params_names = weights + assert self.params_names + self.leader_binary_map = None + self.last_target_sparsity = None + self.desired_sparsity = desired_sparsity + def leader(self): + # The "leader" is the first weights-tensor in the list + return self.params_names[0] -class L1RankedStructureParameterPruner(RankedStructureParameterPruner): - """Uses mean L1-norm to rank structures and prune a specified percentage of structures - """ - def __init__(self, name, reg_regims): - super().__init__(name, reg_regims) + def is_supported(self, param_name): + return param_name in self.params_names + + def fraction_to_prune(self, param_name): + return self.desired_sparsity def set_param_mask(self, param, param_name, zeros_mask_dict, meta): - if param_name not in self.reg_regims.keys(): + if not self.is_supported(param_name): return - - group_type = self.reg_regims[param_name][1] - fraction_to_prune = self.reg_regims[param_name][0] - if fraction_to_prune == 0: + fraction_to_prune = self.fraction_to_prune(param_name) + try: + model = meta['model'] + except TypeError: + model = None + return self.prune_to_target_sparsity(param, param_name, zeros_mask_dict, fraction_to_prune, model) + + def prune_to_target_sparsity(self, param, param_name, zeros_mask_dict, target_sparsity, model): + if not self.is_supported(param_name): return - if group_type in ['3D', 'Filters']: - return self.rank_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict) - elif group_type == 'Channels': - return self.rank_prune_channels(fraction_to_prune, param, param_name, zeros_mask_dict) - elif group_type == 'Rows': - return self.rank_prune_rows(fraction_to_prune, param, param_name, zeros_mask_dict) - else: - raise ValueError("Currently only filter (3D) and channel ranking is supported") + binary_map = None + if self.group_dependency == "Leader": + if target_sparsity != self.last_target_sparsity: + # Each time we change the target sparsity we need to compute and cache the leader's binary-map. + # We don't have control over the order that this function is invoked, so the only indication that + # we need to compute a new leader binary-map is the change of the target_sparsity. + self.last_target_sparsity = target_sparsity + self.leader_binary_map = self.prune_group(target_sparsity, model.state_dict()[self.leader()], + self.leader(), zeros_mask_dict=None) + assert self.leader_binary_map is not None + binary_map = self.leader_binary_map + # Delegate the actual pruning to a sub-class + self.prune_group(target_sparsity, param, param_name, zeros_mask_dict, model, binary_map) - @staticmethod - def rank_channels(fraction_to_prune, param): - num_filters = param.size(0) - num_channels = param.size(1) - kernel_size = param.size(2) * param.size(3) - - # First, reshape the weights tensor such that each channel (kernel) in the original - # tensor, is now a row in the 2D tensor. 
- view_2d = param.view(-1, kernel_size) - # Next, compute the sums of each kernel - kernel_sums = view_2d.abs().sum(dim=1) - # Now group by channels - k_sums_mat = kernel_sums.view(num_filters, num_channels).t() - channel_mags = k_sums_mat.mean(dim=1) - k = int(fraction_to_prune * channel_mags.size(0)) - if k == 0: - msglogger.info("Too few channels (%d)- can't prune %.1f%% channels", - num_channels, 100*fraction_to_prune) - return None, None - - bottomk, _ = torch.topk(channel_mags, k, largest=False, sorted=True) - return bottomk, channel_mags + def prune_group(self, fraction_to_prune, param, param_name, zeros_mask_dict, model=None, binary_map=None): + raise NotImplementedError - @staticmethod - def rank_prune_channels(fraction_to_prune, param, param_name, zeros_mask_dict, model=None): - bottomk_channels, channel_mags = L1RankedStructureParameterPruner.rank_channels(fraction_to_prune, param) - if bottomk_channels is None: - # Empty list means that fraction_to_prune is too low to prune anything - return - num_filters = param.size(0) - num_channels = param.size(1) +class L1RankedStructureParameterPruner(RankedStructureParameterPruner): + """Uses mean L1-norm to rank and prune structures. + + This class prunes to a prescribed percentage of structured-sparsity (level pruning). + """ + def __init__(self, name, group_type, desired_sparsity, weights, group_dependency=None): + super().__init__(name, group_type, desired_sparsity, weights, group_dependency) + if group_type not in ['3D', 'Filters', 'Channels', 'Rows']: + raise ValueError("Structure {} was requested but" + "currently only filter (3D) and channel ranking is supported". + format(group_type)) + + def prune_group(self, fraction_to_prune, param, param_name, zeros_mask_dict, model=None, binary_map=None): + if fraction_to_prune == 0: + return + if self.group_type in ['3D', 'Filters']: + group_pruning_fn = self.rank_and_prune_filters + elif self.group_type == 'Channels': + group_pruning_fn = self.rank_and_prune_channels + elif self.group_type == 'Rows': + group_pruning_fn = self.rank_and_prune_rows - threshold = bottomk_channels[-1] - binary_map = channel_mags.gt(threshold).type(param.data.type()) - a = binary_map.expand(num_filters, num_channels) - c = a.unsqueeze(-1) - d = c.expand(num_filters, num_channels, param.size(2) * param.size(3)).contiguous() - zeros_mask_dict[param_name].mask = d.view(num_filters, num_channels, param.size(2), param.size(3)) + binary_map = group_pruning_fn(fraction_to_prune, param, param_name, zeros_mask_dict, model, binary_map) + return binary_map - msglogger.info("L1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)", param_name, - distiller.sparsity_ch(zeros_mask_dict[param_name].mask), - fraction_to_prune, len(bottomk_channels), num_channels) + @staticmethod + def rank_and_prune_channels(fraction_to_prune, param, param_name=None, + zeros_mask_dict=None, model=None, binary_map=None): + def rank_channels(fraction_to_prune, param): + num_filters = param.size(0) + num_channels = param.size(1) + kernel_size = param.size(2) * param.size(3) + + # First, reshape the weights tensor such that each channel (kernel) in the original + # tensor, is now a row in the 2D tensor. 
+ view_2d = param.view(-1, kernel_size) + # Next, compute the sums of each kernel + kernel_sums = view_2d.abs().sum(dim=1) + # Now group by channels + k_sums_mat = kernel_sums.view(num_filters, num_channels).t() + channel_mags = k_sums_mat.mean(dim=1) + k = int(fraction_to_prune * channel_mags.size(0)) + if k == 0: + msglogger.info("Too few channels (%d)- can't prune %.1f%% channels", + num_channels, 100*fraction_to_prune) + return None, None + + bottomk, _ = torch.topk(channel_mags, k, largest=False, sorted=True) + return bottomk, channel_mags + + def binary_map_to_mask(binary_map, param): + num_filters = param.size(0) + num_channels = param.size(1) + a = binary_map.expand(num_filters, num_channels) + c = a.unsqueeze(-1) + d = c.expand(num_filters, num_channels, param.size(2) * param.size(3)).contiguous() + return d.view(num_filters, num_channels, param.size(2), param.size(3)) + + if binary_map is None: + bottomk_channels, channel_mags = rank_channels(fraction_to_prune, param) + if bottomk_channels is None: + # Empty list means that fraction_to_prune is too low to prune anything + return + threshold = bottomk_channels[-1] + binary_map = channel_mags.gt(threshold).type(param.data.type()) + + if zeros_mask_dict is not None: + zeros_mask_dict[param_name].mask = binary_map_to_mask(binary_map, param) + msglogger.info("L1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)", param_name, + distiller.sparsity_ch(zeros_mask_dict[param_name].mask), + fraction_to_prune, binary_map.sum().item(), param.size(1)) + return binary_map @staticmethod - def rank_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict, model=None): + def rank_and_prune_filters(fraction_to_prune, param, param_name, + zeros_mask_dict, model=None, binary_map=None): assert param.dim() == 4, "This thresholding is only supported for 4D weights" - # First we rank the filters - view_filters = param.view(param.size(0), -1) - filter_mags = view_filters.data.abs().mean(dim=1) - topk_filters = int(fraction_to_prune * filter_mags.size(0)) - if topk_filters == 0: - msglogger.info("Too few filters - can't prune %.1f%% filters", 100*fraction_to_prune) - return - bottomk, _ = torch.topk(filter_mags, topk_filters, largest=False, sorted=True) - threshold = bottomk[-1] + + threshold = None + if binary_map is None: + # First we rank the filters + view_filters = param.view(param.size(0), -1) + filter_mags = view_filters.data.abs().mean(dim=1) + topk_filters = int(fraction_to_prune * filter_mags.size(0)) + if topk_filters == 0: + msglogger.info("Too few filters - can't prune %.1f%% filters", 100*fraction_to_prune) + return + bottomk, _ = torch.topk(filter_mags, topk_filters, largest=False, sorted=True) + threshold = bottomk[-1] + msglogger.info("L1RankedStructureParameterPruner - param: %s pruned=(%d/%d)", + param_name, + topk_filters, filter_mags.size(0)) # Then we threshold - zeros_mask_dict[param_name].mask = distiller.group_threshold_mask(param, 'Filters', threshold, 'Mean_Abs') - msglogger.info("L1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)", param_name, - distiller.sparsity(zeros_mask_dict[param_name].mask), - fraction_to_prune, topk_filters, filter_mags.size(0)) + mask, binary_map = distiller.group_threshold_mask(param, 'Filters', threshold, 'Mean_Abs', binary_map) + if zeros_mask_dict is not None: + zeros_mask_dict[param_name].mask = mask + msglogger.info("L1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f", + param_name, + distiller.sparsity(mask), + 
fraction_to_prune) + return binary_map @staticmethod - def rank_prune_rows(fraction_to_prune, param, param_name, zeros_mask_dict, model=None): + def rank_and_prune_rows(fraction_to_prune, param, param_name, + zeros_mask_dict, model=None, binary_map=None): """Prune the rows of a matrix, based on ranked L1-norms of the matrix rows. PyTorch stores the weights matrices in a transposed format. I.e. before performing GEMM, a matrix is @@ -145,36 +208,15 @@ class L1RankedStructureParameterPruner(RankedStructureParameterPruner): fraction_to_prune, num_rows_to_prune, rows_mags.size(0)) -class RankedFiltersParameterPruner(RankedStructureParameterPruner): - """Base class for the special (but often-used) case of ranking filters - """ - def __init__(self, name, reg_regims): - super().__init__(name, reg_regims) - - def set_param_mask(self, param, param_name, zeros_mask_dict, meta): - if param_name not in self.reg_regims.keys(): - return - - group_type = self.reg_regims[param_name][1] - fraction_to_prune = self.reg_regims[param_name][0] - if fraction_to_prune == 0: - return - - if group_type in ['3D', 'Filters']: - return self.rank_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict, meta['model']) - else: - raise ValueError("Currently only filter (3D) ranking is supported") - - @staticmethod - def mask_from_filter_order(filters_ordered_by_criterion, param, num_filters): +def mask_from_filter_order(filters_ordered_by_criterion, param, num_filters, binary_map): + if binary_map is None: binary_map = torch.zeros(num_filters).cuda() binary_map[filters_ordered_by_criterion] = 1 - #msglogger.info("binary_map: {}".format(binary_map)) - expanded = binary_map.expand(param.size(1) * param.size(2) * param.size(3), param.size(0)).t().contiguous() - return expanded.view(param.shape) + expanded = binary_map.expand(param.size(1) * param.size(2) * param.size(3), param.size(0)).t().contiguous() + return expanded.view(param.shape), binary_map -class ActivationAPoZRankedFilterPruner(RankedFiltersParameterPruner): +class ActivationAPoZRankedFilterPruner(RankedStructureParameterPruner): """Uses mean APoZ (average percentage of zeros) activation channels to rank structures and prune a specified percentage of structures. 
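The refactored pruners above communicate through a per-filter binary map: a 1-D tensor with one 0/1 entry per filter that can be cached (e.g. for the "Leader" group-dependency) and later expanded into a full weights mask. Below is a minimal standalone sketch of that expansion, mirroring the mask_from_filter_order helper in this patch; the function name expand_filter_map and the tensor sizes are purely illustrative.

import torch

def expand_filter_map(binary_map, param):
    # binary_map: (num_filters,) tensor of 0/1 flags; param: (F, C, K, K) conv weights.
    # Repeat each filter's flag across all of that filter's C*K*K weights.
    expanded = binary_map.expand(param.size(1) * param.size(2) * param.size(3),
                                 param.size(0)).t().contiguous()
    return expanded.view(param.shape)

param = torch.randn(4, 2, 3, 3)
keep = torch.tensor([1., 0., 1., 0.])   # keep filters 0 and 2, prune filters 1 and 3
mask = expand_filter_map(keep, param)
assert mask.shape == param.shape and mask[1].sum() == 0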
@@ -182,11 +224,16 @@ class ActivationAPoZRankedFilterPruner(RankedFiltersParameterPruner): Hengyuan Hu, Rui Peng, Yu-Wing Tai, Chi-Keung Tang, ICLR 2016 https://arxiv.org/abs/1607.03250 """ - def __init__(self, name, reg_regims): - super().__init__(name, reg_regims) + def __init__(self, name, group_type, desired_sparsity, weights, group_dependency=None): + super().__init__(name, group_type, desired_sparsity, weights, group_dependency) - @staticmethod - def rank_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict, model): + def prune_group(self, fraction_to_prune, param, param_name, zeros_mask_dict, model=None, binary_map=None): + if fraction_to_prune == 0: + return + binary_map = self.rank_and_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict, model, binary_map) + return binary_map + + def rank_and_prune_filters(self, fraction_to_prune, param, param_name, zeros_mask_dict, model, binary_map=None): assert param.dim() == 4, "This thresholding is only supported for 4D weights" # Use the parameter name to locate the module that has the activation sparsity statistics @@ -208,25 +255,31 @@ class ActivationAPoZRankedFilterPruner(RankedFiltersParameterPruner): # Sort from low to high, and remove the bottom 'num_filters_to_prune' filters filters_ordered_by_apoz = np.argsort(apoz)[:-num_filters_to_prune] - zeros_mask_dict[param_name].mask = RankedFiltersParameterPruner.mask_from_filter_order(filters_ordered_by_apoz, - param, num_filters) + mask, binary_map = mask_from_filter_order(filters_ordered_by_apoz, param, num_filters, binary_map) + zeros_mask_dict[param_name].mask = mask msglogger.info("ActivationL1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)", param_name, distiller.sparsity_3D(zeros_mask_dict[param_name].mask), fraction_to_prune, num_filters_to_prune, num_filters) + return binary_map -class RandomRankedFilterPruner(RankedFiltersParameterPruner): +class RandomRankedFilterPruner(RankedStructureParameterPruner): """A Random raanking of filters. This is used for sanity testing of other algorithms. 
""" - def __init__(self, name, reg_regims): - super().__init__(name, reg_regims) + def __init__(self, name, group_type, desired_sparsity, weights, group_dependency=None): + super().__init__(name, group_type, desired_sparsity, weights, group_dependency) - @staticmethod - def rank_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict, model): + def prune_group(self, fraction_to_prune, param, param_name, zeros_mask_dict, model=None, binary_map=None): + if fraction_to_prune == 0: + return + binary_map = self.rank_and_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict, model, binary_map) + return binary_map + + def rank_and_prune_filters(self, fraction_to_prune, param, param_name, zeros_mask_dict, model, binary_map=None): assert param.dim() == 4, "This thresholding is only supported for 4D weights" num_filters = param.size(0) num_filters_to_prune = int(fraction_to_prune * num_filters) @@ -236,23 +289,29 @@ class RandomRankedFilterPruner(RankedFiltersParameterPruner): return filters_ordered_randomly = np.random.permutation(num_filters)[:-num_filters_to_prune] - zeros_mask_dict[param_name].mask = RankedFiltersParameterPruner.mask_from_filter_order(filters_ordered_randomly, - param, num_filters) + mask, binary_map = mask_from_filter_order(filters_ordered_randomly, param, num_filters) + zeros_mask_dict[param_name].mask = mask msglogger.info("RandomRankedFilterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)", param_name, distiller.sparsity_3D(zeros_mask_dict[param_name].mask), fraction_to_prune, num_filters_to_prune, num_filters) + return binary_map -class GradientRankedFilterPruner(RankedFiltersParameterPruner): +class GradientRankedFilterPruner(RankedStructureParameterPruner): """ """ - def __init__(self, name, reg_regims): - super().__init__(name, reg_regims) + def __init__(self, name, group_type, desired_sparsity, weights, group_dependency=None): + super().__init__(name, group_type, desired_sparsity, weights, group_dependency) - @staticmethod - def rank_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict, model): + def prune_group(self, fraction_to_prune, param, param_name, zeros_mask_dict, model=None, binary_map=None): + if fraction_to_prune == 0: + return + binary_map = self.rank_and_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict, model, binary_map) + return binary_map + + def rank_and_prune_filters(self, fraction_to_prune, param, param_name, zeros_mask_dict, model, binary_map=None): assert param.dim() == 4, "This thresholding is only supported for 4D weights" num_filters = param.size(0) num_filters_to_prune = int(fraction_to_prune * num_filters) @@ -268,9 +327,11 @@ class GradientRankedFilterPruner(RankedFiltersParameterPruner): # Sort from high to low, and remove the bottom 'num_filters_to_prune' filters filters_ordered_by_gradient = np.argsort(-weighted_gradients.detach().cpu().numpy())[:-num_filters_to_prune] - zeros_mask_dict[param_name].mask = RankedFiltersParameterPruner.mask_from_filter_order(filters_ordered_by_gradient, - param, num_filters) + mask, binary_map = mask_from_filter_order(filters_ordered_by_gradient, param, num_filters) + zeros_mask_dict[param_name].mask = mask + msglogger.info("GradientRankedFilterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)", param_name, distiller.sparsity_3D(zeros_mask_dict[param_name].mask), fraction_to_prune, num_filters_to_prune, num_filters) + return binary_map diff --git a/distiller/sensitivity.py b/distiller/sensitivity.py index 
8db2773f63cacb5cf31f5db9a39fbc0e2178dc8a..591f50977239d7e02eeeb413879f681ef3fe5c19 100755
--- a/distiller/sensitivity.py
+++ b/distiller/sensitivity.py
@@ -37,6 +37,7 @@ from .scheduler import CompressionScheduler
 
 msglogger = logging.getLogger()
 
+
 def perform_sensitivity_analysis(model, net_params, sparsities, test_func, group):
     """Perform a sensitivity test for a model's weights parameters.
@@ -86,19 +87,23 @@ def perform_sensitivity_analysis(model, net_params, sparsities, test_func, group
         if group == 'element':
             # Element-wise sparasity
             sparsity_levels = {param_name: sparsity_level}
-            pruner = distiller.pruning.SparsityLevelParameterPruner(name='sensitivity', levels=sparsity_levels)
+            pruner = distiller.pruning.SparsityLevelParameterPruner(name="sensitivity", levels=sparsity_levels)
         elif group == 'filter':
             # Filter ranking
             if model.state_dict()[param_name].dim() != 4:
                 continue
-            regims = {param_name: [sparsity_level, '3D']}
-            pruner = distiller.pruning.L1RankedStructureParameterPruner(name='sensitivity', reg_regims=regims)
+            pruner = distiller.pruning.L1RankedStructureParameterPruner("sensitivity",
+                                                                        group_type="Filters",
+                                                                        desired_sparsity=sparsity_level,
+                                                                        weights=param_name)
         elif group == 'channel':
             # Filter ranking
             if model.state_dict()[param_name].dim() != 4:
                 continue
-            regims = {param_name: [sparsity_level, 'Channels']}
-            pruner = distiller.pruning.L1RankedStructureParameterPruner(name='sensitivity', reg_regims=regims)
+            pruner = distiller.pruning.L1RankedStructureParameterPruner("sensitivity",
+                                                                        group_type="Channels",
+                                                                        desired_sparsity=sparsity_level,
+                                                                        weights=param_name)
 
         policy = distiller.PruningPolicy(pruner, pruner_args=None)
         scheduler = CompressionScheduler(model_cpy)
diff --git a/distiller/thinning.py b/distiller/thinning.py
index ab6cff157a1195cc068659eda6083cd37efd4f54..ad3ff26d3fb0909901b0021b2de70dadc48b9039 100755
--- a/distiller/thinning.py
+++ b/distiller/thinning.py
@@ -34,7 +34,7 @@ import distiller
 from distiller import normalize_module_name, denormalize_module_name
 from apputils import SummaryGraph
 from models import create_model
-msglogger = logging.getLogger()
+msglogger = logging.getLogger(__name__)
 
 ThinningRecipe = namedtuple('ThinningRecipe', ['modules', 'parameters'])
 """A ThinningRecipe is composed of two sets of instructions.
@@ -57,6 +57,7 @@ These tuples can have 2 values, or 4 values.
 """
 __all__ = ['ThinningRecipe', 'resnet_cifar_remove_layers',
+           'StructureRemover',
            'ChannelRemover', 'remove_channels',
            'FilterRemover', 'remove_filters',
            'find_nonzero_channels', 'find_nonzero_channels_list',
@@ -79,20 +80,41 @@ def param_name_2_layer_name(param_name):
     return param_name[:-len('weights')]
 
 
+def directives_equal(d1, d2):
+    """Test if two directives are equal"""
+    if len(d1) != len(d2):
+        return False
+    if len(d1) == 2:
+        return d1[0] == d2[0] and torch.equal(d1[1], d2[1])
+    if len(d1) == 4:
+        e = all(d1[i] == d2[i] for i in (0, 2, 3)) and torch.equal(d1[1], d2[1])
+        msglogger.info("{}: \n{}\n{}".format(e, d1, d2))
+        return e
+    raise ValueError("Unsupported directive length")
+
+
 def append_param_directive(thinning_recipe, param_name, directive):
-    param_directive = thinning_recipe.parameters.get(param_name, [])
-    param_directive.append(directive)
-    thinning_recipe.parameters[param_name] = param_directive
+    param_directives = thinning_recipe.parameters.get(param_name, [])
+    for d in param_directives:
+        # Duplicate parameter directives are rooted out because they can create erroneous conditions.
+        # For example, if the first directive changes the shape of the parameter, a second
+        # directive will cause an exception.
+        if directives_equal(d, directive):
+            return
+    msglogger.debug("\t[recipe] param_directive for {} = {}".format(param_name, directive))
+    param_directives.append(directive)
+    thinning_recipe.parameters[param_name] = param_directives
 
 
 def append_module_directive(model, thinning_recipe, module_name, key, val):
+    msglogger.debug("\t[recipe] setting {}.{} = {}".format(module_name, key, val))
     module_name = denormalize_module_name(model, module_name)
     mod_directive = thinning_recipe.modules.get(module_name, {})
     mod_directive[key] = val
     thinning_recipe.modules[module_name] = mod_directive
 
 
-def bn_thinning(thinning_recipe, layers, bn_name, len_thin_features, thin_features):
+def append_bn_thinning_directive(thinning_recipe, layers, bn_name, len_thin_features, thin_features):
     """Adjust the sizes of the parameters of a BatchNormalization layer
     This function is invoked after the Convolution layer preceeding a BN layer has changed dimensions (filters or channels were removed), and the BN layer also
@@ -100,6 +122,7 @@ def bn_thinning(thinning_recipe, layers, bn_name, len_thin_features, thin_featur
     """
     bn_module = layers[bn_name]
     assert isinstance(bn_module, torch.nn.modules.batchnorm.BatchNorm2d)
+    msglogger.debug("\t[recipe] bn_thinning {}".format(bn_name))
 
     bn_directive = thinning_recipe.modules.get(bn_name, {})
     bn_directive['num_features'] = len_thin_features
@@ -248,12 +271,13 @@ def create_thinning_recipe_channels(sgraph, model, zeros_mask_dict):
         # Convert the layers names to PyTorch's convoluted naming scheme (when DataParallel is used)
         predecessors = [normalize_module_name(predecessor) for predecessor in predecessors]
         if len(predecessors) == 0:
-            msglogger.info("Could not find predecessors for name={} normal={} {}".format(layer_name, normalize_module_name(layer_name), denormalize_module_name(model, layer_name)))
+            msglogger.info("Could not find predecessors for name={} normal={} {}".format(
+                layer_name, normalize_module_name(layer_name), denormalize_module_name(model, layer_name)))
 
         for predecessor in predecessors:
             # For each of the convolutional layers that preceed, we have to reduce the number of output channels.
append_module_directive(model, thinning_recipe, predecessor, key='out_channels', val=num_nnz_channels) - # Now remove channels from the weights tensor of the successor conv + # Now remove channels from the weights tensor of the predecessor conv append_param_directive(thinning_recipe, denormalize_module_name(model, predecessor)+'.weight', (0, indices)) if layers[denormalize_module_name(model, predecessor)].bias is not None: @@ -263,11 +287,16 @@ def create_thinning_recipe_channels(sgraph, model, zeros_mask_dict): # Now handle the BatchNormalization layer that follows the convolution bn_layers = sgraph.predecessors_f(normalize_module_name(layer_name), ['BatchNormalization']) if len(bn_layers) > 0: - assert len(bn_layers) == 1 - # Thinning of the BN layer that follows the convolution - bn_layer_name = denormalize_module_name(model, bn_layers[0]) - bn_thinning(thinning_recipe, layers, bn_layer_name, - len_thin_features=num_nnz_channels, thin_features=indices) + # if len(bn_layers) != 1: + # raise RuntimeError("{} should have exactly one BN predecessors, but has {}".format(layer_name, len(bn_layers))) + for bn_layer in bn_layers: + # Thinning of the BN layer that follows the convolution + bn_layer_name = denormalize_module_name(model, bn_layer) + msglogger.debug("[recipe] {}: predecessor BN module = {}".format(layer_name, bn_layer_name)) + append_bn_thinning_directive(thinning_recipe, layers, bn_layer_name, + len_thin_features=num_nnz_channels, thin_features=indices) + + msglogger.debug(thinning_recipe) return thinning_recipe @@ -300,7 +329,7 @@ def create_thinning_recipe_filters(sgraph, model, zeros_mask_dict): raise ValueError("Trying to set zero filters for parameter %s is not allowed" % param_name) # If there are non-zero filters in this tensor then continue to next tensor if num_filters <= num_nnz_filters: - msglogger.debug("SKipping {} shape={}".format(param_name_2_layer_name(param_name), param.shape)) + msglogger.debug("Skipping {} shape={}".format(param_name_2_layer_name(param_name), param.shape)) continue msglogger.info("In tensor %s found %d/%d zero filters", param_name, @@ -330,7 +359,6 @@ def create_thinning_recipe_filters(sgraph, model, zeros_mask_dict): if isinstance(layers[successor], torch.nn.modules.Conv2d): # For each of the convolutional layers that follow, we have to reduce the number of input channels. 
append_module_directive(model, thinning_recipe, successor, key='in_channels', val=num_nnz_filters) - msglogger.debug("[recipe] {}: setting in_channels = {}".format(successor, num_nnz_filters)) # Now remove channels from the weights tensor of the successor conv append_param_directive(thinning_recipe, denormalize_module_name(model, successor)+'.weight', (1, indices)) @@ -340,6 +368,8 @@ def create_thinning_recipe_filters(sgraph, model, zeros_mask_dict): fm_size = layers[successor].in_features // layers[layer_name].out_channels in_features = fm_size * num_nnz_filters append_module_directive(model, thinning_recipe, successor, key='in_features', val=in_features) + msglogger.debug("[recipe] Linear {}: fm_size = {} layers[{}].out_channels={}".format( + successor, in_features, layer_name, layers[layer_name].out_channels)) msglogger.debug("[recipe] {}: setting in_features = {}".format(successor, in_features)) # Now remove channels from the weights tensor of the successor FC layer: @@ -347,7 +377,9 @@ def create_thinning_recipe_filters(sgraph, model, zeros_mask_dict): fm_height = fm_width = int(math.sqrt(fm_size)) view_4D = (layers[successor].out_features, layers[layer_name].out_channels, fm_height, fm_width) view_2D = (layers[successor].out_features, in_features) - append_param_directive(thinning_recipe, denormalize_module_name(model, successor)+'.weight', (1, indices, view_4D, view_2D)) + append_param_directive(thinning_recipe, + denormalize_module_name(model, successor)+'.weight', + (1, indices, view_4D, view_2D)) # Now handle the BatchNormalization layer that follows the convolution bn_layers = sgraph.successors_f(normalize_module_name(layer_name), ['BatchNormalization']) @@ -355,23 +387,12 @@ def create_thinning_recipe_filters(sgraph, model, zeros_mask_dict): assert len(bn_layers) == 1 # Thinning of the BN layer that follows the convolution bn_layer_name = denormalize_module_name(model, bn_layers[0]) - bn_thinning(thinning_recipe, layers, bn_layer_name, - len_thin_features=num_nnz_filters, thin_features=indices) + append_bn_thinning_directive(thinning_recipe, layers, bn_layer_name, + len_thin_features=num_nnz_filters, thin_features=indices) return thinning_recipe -class ChannelRemover(ScheduledTrainingPolicy): - """A policy which applies a network thinning function""" - def __init__(self, thinning_func_str, arch, dataset): - self.thinning_func = globals()[thinning_func_str] - self.arch = arch - self.dataset = dataset - - def on_epoch_end(self, model, zeros_mask_dict, meta): - self.thinning_func(model, zeros_mask_dict, self.arch, self.dataset, meta.get('optimizer', None)) - - -class FilterRemover(ScheduledTrainingPolicy): +class StructureRemover(ScheduledTrainingPolicy): """A policy which applies a network thinning function""" def __init__(self, thinning_func_str, arch, dataset): self.thinning_func = globals()[thinning_func_str] @@ -404,6 +425,11 @@ class FilterRemover(ScheduledTrainingPolicy): self.done = False +# For backward-compatibility with some of the scripts, we assign aliases to StructureRemover +FilterRemover = StructureRemover +ChannelRemover = StructureRemover + + def execute_thinning_recipes_list(model, zeros_mask_dict, recipe_list): # Invoke this function when you want to use a list of thinning recipes to convert a programmed model # to a thinned model. For example, this is invoked when loading a model from a checkpoint. 
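The recipes built above now filter out duplicate parameter directives before they are executed. A self-contained sketch of that dedup behaviour for the simple 2-tuple directive form (dim_to_trim, indices) follows; it mirrors directives_equal/append_param_directive from this patch but is not the patched distiller.thinning code itself, and the parameter name and index tensor are illustrative.

import torch

def directives_equal(d1, d2):
    # Only the 2-tuple form (dim_to_trim, indices) is handled in this sketch.
    return len(d1) == len(d2) == 2 and d1[0] == d2[0] and torch.equal(d1[1], d2[1])

def append_param_directive(parameters, param_name, directive):
    directives = parameters.setdefault(param_name, [])
    if not any(directives_equal(d, directive) for d in directives):
        directives.append(directive)

parameters = {}
keep = torch.tensor([0, 2, 3])
append_param_directive(parameters, 'layer2.0.conv1.weight', (0, keep))
append_param_directive(parameters, 'layer2.0.conv1.weight', (0, keep))  # duplicate is dropped
assert len(parameters['layer2.0.conv1.weight']) == 1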
@@ -414,7 +440,7 @@ def execute_thinning_recipes_list(model, zeros_mask_dict, recipe_list): def optimizer_thinning(optimizer, param, dim, indices, new_shape=None): - """Adjust the size of the SGD vecolity-tracking tensors. + """Adjust the size of the SGD velocity-tracking tensors. The SGD momentum update (velocity) is dependent on the weights, and because during thinning we dynamically change the weights shapes, we need to make the apporpriate changes in the Optimizer, @@ -458,7 +484,7 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, optimizer, loaded_fr # Check if we're trying to trim a parameter that is already "thin" if running.size(dim_to_trim) != indices_to_select.nelement(): msglogger.debug("[thinning] {}: setting {} to {}". - format(layer_name, attr, indices_to_select.nelement())) + format(layer_name, attr, indices_to_select.nelement())) setattr(layers[layer_name], attr, torch.index_select(running, dim=dim_to_trim, index=indices_to_select)) else: @@ -468,6 +494,7 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, optimizer, loaded_fr assert len(recipe.parameters) > 0 for param_name, param_directives in recipe.parameters.items(): + msglogger.debug("{} : {}".format(param_name, param_directives)) param = distiller.model_find_param(model, param_name) assert param is not None for directive in param_directives: @@ -475,6 +502,7 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, optimizer, loaded_fr indices = directive[1] len_indices = indices.nelement() if len(directive) == 4: # TODO: this code is hard to follow + msglogger.debug("{}-{}-{}: SHAPE = {}".format(param_name, param.shape, id(param), list(directive[2]))) selection_view = param.view(*directive[2]) # Check if we're trying to trim a parameter that is already "thin" if param.data.size(dim) != len_indices: diff --git a/distiller/thresholding.py b/distiller/thresholding.py index 84ec73c55f28777e69df11e61543f36a2f0f64c7..da2c3b629446b721aecae93751f566fb2f35dc0a 100755 --- a/distiller/thresholding.py +++ b/distiller/thresholding.py @@ -43,7 +43,7 @@ class GroupThresholdMixin(object): return group_threshold_mask(param, group_type, threshold, threshold_criteria) -def group_threshold_mask(param, group_type, threshold, threshold_criteria): +def group_threshold_binary_map(param, group_type, threshold, threshold_criteria): """Return a threshold mask for the provided parameter and group type. Args: @@ -67,31 +67,26 @@ def group_threshold_mask(param, group_type, threshold, threshold_criteria): # elements in each channel as the threshold filter. # 3. Apply the threshold filter binary_map = threshold_policy(view_2d, thresholds, threshold_criteria) - - # 3. 
Finally, expand the thresholds and view as a 4D tensor - a = binary_map.expand(param.size(2) * param.size(3), - param.size(0) * param.size(1)).t() - return a.view(param.size(0), param.size(1), param.size(2), param.size(3)) + return binary_map elif group_type == 'Rows': assert param.dim() == 2, "This regularization is only supported for 2D weights" thresholds = torch.Tensor([threshold] * param.size(0)).cuda() binary_map = threshold_policy(param, thresholds, threshold_criteria) - return binary_map.expand(param.size(1), param.size(0)).t() + return binary_map elif group_type == 'Cols': assert param.dim() == 2, "This regularization is only supported for 2D weights" thresholds = torch.Tensor([threshold] * param.size(1)).cuda() binary_map = threshold_policy(param, thresholds, threshold_criteria, dim=0) - return binary_map.expand(param.size(0), param.size(1)) + return binary_map elif group_type == '3D' or group_type == 'Filters': assert param.dim() == 4, "This thresholding is only supported for 4D weights" view_filters = param.view(param.size(0), -1) thresholds = torch.Tensor([threshold] * param.size(0)).cuda() binary_map = threshold_policy(view_filters, thresholds, threshold_criteria) - a = binary_map.expand(param.size(1) * param.size(2) * param.size(3), param.size(0)).t() - return a.view(param.size(0), param.size(1), param.size(2), param.size(3)) + return binary_map elif group_type == '4D': assert param.dim() == 4, "This thresholding is only supported for 4D weights" @@ -116,6 +111,65 @@ def group_threshold_mask(param, group_type, threshold, threshold_criteria): k_means_mat = kernel_means.view(num_filters, num_kernels_per_filter).t() thresholds = torch.Tensor([threshold] * num_kernels_per_filter).cuda() binary_map = k_means_mat.data.mean(dim=1).gt(thresholds).type(param.type()) + return binary_map + + +def group_threshold_mask(param, group_type, threshold, threshold_criteria, binary_map=None): + """Return a threshold mask for the provided parameter and group type. + + Args: + param: The parameter to mask + group_type: The elements grouping type (structure). + One of:2D, 3D, 4D, Channels, Row, Cols + threshold: The threshold + threshold_criteria: The thresholding criteria. + 'Mean_Abs' thresholds the entire element group using the mean of the + absolute values of the tensor elements. + 'Max' thresholds the entire group using the magnitude of the largest + element in the group. + """ + if group_type == '2D': + if binary_map is None: + binary_map = group_threshold_binary_map(param, group_type, threshold, threshold_criteria) + + # 3. 
Finally, expand the thresholds and view as a 4D tensor + a = binary_map.expand(param.size(2) * param.size(3), + param.size(0) * param.size(1)).t() + return a.view(param.size(0), param.size(1), param.size(2), param.size(3)) + + elif group_type == 'Rows': + if binary_map is None: + binary_map = group_threshold_binary_map(param, group_type, threshold, threshold_criteria) + return binary_map.expand(param.size(1), param.size(0)).t() + + elif group_type == 'Cols': + if binary_map is None: + binary_map = group_threshold_binary_map(param, group_type, threshold, threshold_criteria) + return binary_map.expand(param.size(0), param.size(1)) + + elif group_type == '3D' or group_type == 'Filters': + if binary_map is None: + binary_map = group_threshold_binary_map(param, group_type, threshold, threshold_criteria) + a = binary_map.expand(param.size(1) * param.size(2) * param.size(3), param.size(0)).t() + return a.view(param.size(0), param.size(1), param.size(2), param.size(3)), binary_map + + elif group_type == '4D': + assert param.dim() == 4, "This thresholding is only supported for 4D weights" + if threshold_criteria == 'Mean_Abs': + if param.data.abs().mean() > threshold: + return None + return torch.zeros_like(param.data) + elif threshold_criteria == 'Max': + if param.data.abs().max() > threshold: + return None + return torch.zeros_like(param.data) + exit("Invalid threshold_criteria {}".format(threshold_criteria)) + + elif group_type == 'Channels': + if binary_map is None: + binary_map = group_threshold_binary_map(param, group_type, threshold, threshold_criteria) + num_filters = param.size(0) + num_kernels_per_filter = param.size(1) # Now let's expand back up to a 4D mask a = binary_map.expand(num_filters, num_kernels_per_filter) diff --git a/examples/agp-pruning/resnet20_filters.schedule_agp.yaml b/examples/agp-pruning/resnet20_filters.schedule_agp.yaml index 88c5fc53f36d9e77a3e3bda3773881e783d66325..708c1417038ca71c91107693e1ee1b5598d59ab3 100755 --- a/examples/agp-pruning/resnet20_filters.schedule_agp.yaml +++ b/examples/agp-pruning/resnet20_filters.schedule_agp.yaml @@ -71,22 +71,11 @@ pruners: class: L1RankedStructureParameterPruner_AGP initial_sparsity : 0.10 final_sparsity: 0.50 - reg_regims: - module.layer2.0.conv1.weight: Filters - - module.layer2.0.conv2.weight: Filters - module.layer2.0.downsample.0.weight: Filters - module.layer2.1.conv2.weight: Filters - module.layer2.2.conv2.weight: Filters # to balance the BN - - module.layer2.1.conv1.weight: Filters - module.layer2.2.conv1.weight: Filters - - #module.layer3.0.conv2.weight: Filters - #module.layer3.0.downsample.0.weight: Filters - #module.layer3.1.conv2.weight: Filters - #module.layer3.2.conv2.weight: Filters - + group_type: Filters + weights: [module.layer2.0.conv1.weight, module.layer2.0.conv2.weight, + module.layer2.0.downsample.0.weight, + module.layer2.1.conv2.weight, module.layer2.2.conv2.weight, + module.layer2.1.conv1.weight, module.layer2.2.conv1.weight] fine_pruner: class: AutomatedGradualPruner @@ -99,8 +88,9 @@ pruners: class: L1RankedStructureParameterPruner_AGP initial_sparsity : 0.05 final_sparsity: 0.50 - reg_regims: - module.fc.weight: Rows + group_type: Rows + weights: [module.fc.weight] + lr_schedulers: pruning_lr: @@ -135,15 +125,6 @@ policies: ending_epoch: 230 frequency: 2 - # Currently the thinner is disabled until the the structure pruner is done, because it interacts - # with the sparsity goals of the L1RankedStructureParameterPruner_AGP. - # This can be fixed rather easily. 
- # - extension: - # instance_name: net_thinner - # starting_epoch: 0 - # ending_epoch: 20 - # frequency: 2 - # After completeing the pruning, we perform network thinning and continue fine-tuning. - extension: instance_name: net_thinner diff --git a/examples/agp-pruning/resnet20_filters.schedule_agp_2.yaml b/examples/agp-pruning/resnet20_filters.schedule_agp_2.yaml index 730dd9a830a427f9e39a10dcaaf8d91ebdb6e45c..4c9d3b5005a40a7cf74a1f7d69323e221f80daf8 100755 --- a/examples/agp-pruning/resnet20_filters.schedule_agp_2.yaml +++ b/examples/agp-pruning/resnet20_filters.schedule_agp_2.yaml @@ -6,13 +6,15 @@ # Baseline results: # Top1: 91.780 Top5: 99.710 Loss: 0.376 # Total MACs: 40,813,184 +# # of parameters: 270,896 # # Results: # Top1: 91.200 Top5: 99.660 Loss: 1.551 # Total MACs: 30,638,720 # Total sparsity: 41.84 +# # of parameters: 143,488 (=53% of the baseline parameters) # -# time python3 compress_classifier.py --arch resnet20_cifar ../../../data.cifar10 -p=50 --lr=0.1 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp.yaml -j=1 --deterministic --resume=../ssl/checkpoints/checkpoint_trained_dense.pth.tar +# time python3 compress_classifier.py --arch resnet20_cifar ../../../data.cifar10 -p=50 --lr=0.1 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp_2.yaml -j=1 --deterministic --resume=../ssl/checkpoints/checkpoint_trained_dense.pth.tar # # # Parameters: @@ -69,13 +71,13 @@ pruners: class: L1RankedStructureParameterPruner_AGP initial_sparsity : 0.10 final_sparsity: 0.40 - reg_regims: - module.layer1.0.conv1.weight: Filters - module.layer1.1.conv1.weight: Filters - module.layer1.2.conv1.weight: Filters - module.layer2.0.conv1.weight: Filters - module.layer2.1.conv1.weight: Filters - module.layer2.2.conv1.weight: Filters + group_type: Filters + weights: [module.layer1.0.conv1.weight, + module.layer1.1.conv1.weight, + module.layer1.2.conv1.weight, + module.layer2.0.conv1.weight, + module.layer2.1.conv1.weight, + module.layer2.2.conv1.weight] fine_pruner: class: AutomatedGradualPruner @@ -110,14 +112,6 @@ policies: starting_epoch: 200 ending_epoch: 220 frequency: 2 - # Currently the thinner is disabled until the end, because it interacts with the sparsity - # goals of the L1RankedStructureParameterPruner_AGP. - # This can be fixed rather easily. - # - extension: - # instance_name: net_thinner - # starting_epoch: 0 - # ending_epoch: 20 - # frequency: 2 # After completeing the pruning, we perform network thinning and continue fine-tuning. 
- extension: diff --git a/examples/agp-pruning/resnet20_filters.schedule_agp_3.yaml b/examples/agp-pruning/resnet20_filters.schedule_agp_3.yaml index 6114552909aa1669e9fe242bd199a239d90d5687..9ac92a2e71f30d6dad2dcabb62e7bc91ce75fdab 100755 --- a/examples/agp-pruning/resnet20_filters.schedule_agp_3.yaml +++ b/examples/agp-pruning/resnet20_filters.schedule_agp_3.yaml @@ -70,21 +70,15 @@ pruners: class: L1RankedStructureParameterPruner_AGP initial_sparsity : 0.10 final_sparsity: 0.50 - reg_regims: - module.layer2.0.conv1.weight: Filters + group_type: Filters + weights: [module.layer2.0.conv1.weight, + module.layer2.0.conv2.weight, + module.layer2.0.downsample.0.weight, + module.layer2.1.conv2.weight, + module.layer2.2.conv2.weight, # to balance the BN + module.layer2.1.conv1.weight, + module.layer2.2.conv1.weight] - module.layer2.0.conv2.weight: Filters - module.layer2.0.downsample.0.weight: Filters - module.layer2.1.conv2.weight: Filters - module.layer2.2.conv2.weight: Filters # to balance the BN - - module.layer2.1.conv1.weight: Filters - module.layer2.2.conv1.weight: Filters - - #module.layer3.0.conv2.weight: Filters - #module.layer3.0.downsample.0.weight: Filters - #module.layer3.1.conv2.weight: Filters - #module.layer3.2.conv2.weight: Filters fine_pruner1: class: AutomatedGradualPruner initial_sparsity : 0.05 @@ -102,8 +96,8 @@ pruners: class: L1RankedStructureParameterPruner_AGP initial_sparsity : 0.05 final_sparsity: 0.50 - reg_regims: - module.fc.weight: Rows + group_type: Rows + weights: [module.fc.weight] lr_schedulers: pruning_lr: diff --git a/examples/agp-pruning/resnet20_filters.schedule_agp_4.yaml b/examples/agp-pruning/resnet20_filters.schedule_agp_4.yaml new file mode 100755 index 0000000000000000000000000000000000000000..816ac043852d6142fdc8571682d3823338e9c7cd --- /dev/null +++ b/examples/agp-pruning/resnet20_filters.schedule_agp_4.yaml @@ -0,0 +1,154 @@ +# This is a hybrid pruning schedule composed of several pruning techniques, all using AGP scheduling: +# 1. Filter pruning (and thinning) to reduce compute and activation sizes of some layers. +# 2. Fine grained pruning to reduce the parameter memory requirements of layers with large weights tensors. +# 3. Row pruning for the last linear (fully-connected) layer. 
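+#
+# For reference, every pruner below follows the AGP sparsity schedule implemented in
+# automated_gradual_pruner.py: the per-tensor target ramps from initial_sparsity to
+# final_sparsity as
+#     target = final + (initial - final) * (1 - (epoch - starting_epoch)/span)**3
+# Illustrative example (not part of the schedule itself): with initial_sparsity=0.10,
+# final_sparsity=0.50 and a span of roughly 30 epochs, the target is about 0.45 halfway
+# through; the exact span depends on each policy's starting/ending epochs and frequency.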
+# +# Baseline results: +# Top1: 91.780 Top5: 99.710 Loss: 0.376 +# Total MACs: 40,813,184 +# # of parameters: 270,896 +# +# Results: +# Top1: 90.86 +# Total MACs: 24,723,776 (=1.65x MACs) +# Total sparsity: 39.66 +# # of parameters: 78,776 (=29.1% of the baseline parameters) +# +# time python3 compress_classifier.py --arch resnet20_cifar ../../../data.cifar10 -p=50 --lr=0.3 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp_4.yaml -j=1 --deterministic --resume=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --validation-size=0 +# +# Parameters: +# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+ +# | | Name | Shape | NNZ (dense) | NNZ (sparse) | Cols (%) | Rows (%) | Ch (%) | 2D (%) | 3D (%) | Fine (%) | Std | Mean | Abs-Mean | +# |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------| +# | 0 | module.conv1.weight | (16, 3, 3, 3) | 432 | 432 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.43875 | -0.01073 | 0.30863 | +# | 1 | module.layer1.0.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.16813 | -0.01061 | 0.11996 | +# | 2 | module.layer1.0.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.16491 | 0.00070 | 0.12291 | +# | 3 | module.layer1.1.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.14755 | -0.01732 | 0.11317 | +# | 4 | module.layer1.1.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.13875 | -0.00517 | 0.10758 | +# | 5 | module.layer1.2.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.18923 | -0.01326 | 0.14210 | +# | 6 | module.layer1.2.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.15583 | -0.00106 | 0.11905 | +# | 7 | module.layer2.0.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.18225 | -0.00390 | 0.14086 | +# | 8 | module.layer2.0.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.19739 | -0.00785 | 0.15657 | +# | 9 | module.layer2.0.downsample.0.weight | (16, 16, 1, 1) | 256 | 256 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.34375 | -0.01899 | 0.24210 | +# | 10 | module.layer2.1.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.13369 | -0.00798 | 0.10439 | +# | 11 | module.layer2.1.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.10876 | -0.00311 | 0.08349 | +# | 12 | module.layer2.2.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.13521 | -0.01331 | 0.09820 | +# | 13 | module.layer2.2.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.09794 | 0.00274 | 0.07076 | +# | 14 | module.layer3.0.conv1.weight | (64, 16, 3, 3) | 9216 | 9216 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.14004 | -0.00417 | 0.11131 | +# | 15 | 
module.layer3.0.conv2.weight | (32, 64, 3, 3) | 18432 | 18432 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.12559 | -0.00528 | 0.09919 | +# | 16 | module.layer3.0.downsample.0.weight | (32, 16, 1, 1) | 512 | 512 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.20064 | 0.00941 | 0.15694 | +# | 17 | module.layer3.1.conv1.weight | (64, 32, 3, 3) | 18432 | 5530 | 0.00000 | 0.00000 | 0.00000 | 7.76367 | 1.56250 | 69.99783 | 0.10111 | -0.00460 | 0.04898 | +# | 18 | module.layer3.1.conv2.weight | (32, 64, 3, 3) | 18432 | 5530 | 0.00000 | 0.00000 | 1.56250 | 8.83789 | 0.00000 | 69.99783 | 0.09000 | -0.00652 | 0.04327 | +# | 19 | module.layer3.2.conv1.weight | (64, 32, 3, 3) | 18432 | 5530 | 0.00000 | 0.00000 | 0.00000 | 9.37500 | 0.00000 | 69.99783 | 0.09371 | -0.00490 | 0.04521 | +# | 20 | module.layer3.2.conv2.weight | (32, 64, 3, 3) | 18432 | 5530 | 0.00000 | 0.00000 | 0.00000 | 24.65820 | 0.00000 | 69.99783 | 0.06228 | 0.00012 | 0.02827 | +# | 21 | module.fc.weight | (10, 32) | 320 | 160 | 0.00000 | 50.00000 | 0.00000 | 0.00000 | 0.00000 | 50.00000 | 0.78238 | -0.00000 | 0.48823 | +# | 22 | Total sparsity: | - | 130544 | 78776 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 39.65560 | 0.00000 | 0.00000 | 0.00000 | +# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+ +# Total sparsity: 39.66 +# +# --- validate (epoch=359)----------- +# 10000 samples (256 per mini-batch) +# ==> Top1: 90.580 Top5: 99.650 Loss: 0.341 +# +# ==> Best Top1: 90.860 on Epoch: 294 +# Saving checkpoint to: logs/2018.11.30-152150/checkpoint.pth.tar +# --- test --------------------- +# 10000 samples (256 per mini-batch) +# ==> Top1: 90.580 Top5: 99.650 Loss: 0.341 + +version: 1 + +pruners: + low_pruner: + class: L1RankedStructureParameterPruner_AGP + initial_sparsity : 0.10 + final_sparsity: 0.50 + group_type: Filters + weights: [module.layer2.0.conv1.weight, module.layer2.0.conv2.weight, + module.layer2.0.downsample.0.weight, + module.layer2.1.conv2.weight, module.layer2.2.conv2.weight, + module.layer2.1.conv1.weight, module.layer2.2.conv1.weight] + + low_pruner_2: + class: L1RankedStructureParameterPruner_AGP + initial_sparsity : 0.10 + final_sparsity: 0.50 + group_type: Filters + group_dependency: Leader + weights: [module.layer3.0.conv2.weight, module.layer3.0.downsample.0.weight, + module.layer3.1.conv2.weight, module.layer3.2.conv2.weight] + + fine_pruner: + class: AutomatedGradualPruner + initial_sparsity : 0.05 + final_sparsity: 0.70 + weights: [module.layer3.1.conv1.weight, module.layer3.1.conv2.weight, + module.layer3.2.conv1.weight, module.layer3.2.conv2.weight] + + fc_pruner: + class: L1RankedStructureParameterPruner_AGP + initial_sparsity : 0.05 + final_sparsity: 0.50 + group_type: Rows + weights: [module.fc.weight] + + +lr_schedulers: + pruning_lr: + class: StepLR + step_size: 50 + gamma: 0.10 + +extensions: + net_thinner: + class: 'FilterRemover' + thinning_func_str: remove_filters + arch: 'resnet20_cifar' + dataset: 'cifar10' + + +policies: + - pruner: + instance_name : low_pruner + starting_epoch: 180 + ending_epoch: 210 + frequency: 2 + + - pruner: + instance_name : low_pruner_2 + starting_epoch: 180 + ending_epoch: 210 + frequency: 2 + + - pruner: + instance_name : fine_pruner + starting_epoch: 210 + ending_epoch: 230 + frequency: 2 + + - pruner: + instance_name : fc_pruner + starting_epoch: 210 + ending_epoch: 230 + 
frequency: 2 + + # Currently the thinner is disabled until the the structure pruner is done, because it interacts + # with the sparsity goals of the L1RankedStructureParameterPruner_AGP. + # This can be fixed rather easily. + # - extension: + # instance_name: net_thinner + # starting_epoch: 0 + # ending_epoch: 20 + # frequency: 2 + +# After completeing the pruning, we perform network thinning and continue fine-tuning. + - extension: + instance_name: net_thinner + #epochs: [181] + epochs: [212] + + - lr_scheduler: + instance_name: pruning_lr + starting_epoch: 180 + ending_epoch: 400 + frequency: 1 diff --git a/examples/classifier_compression/logging.conf b/examples/classifier_compression/logging.conf index 1bad269356a3f18b2915e874ee4511c548d0d61d..66b7739e815627a6480fa88e4e846336d4120a10 100755 --- a/examples/classifier_compression/logging.conf +++ b/examples/classifier_compression/logging.conf @@ -5,7 +5,7 @@ keys: simple, time_simple keys: console, file [loggers] -keys: root, app_cfg +keys: root, app_cfg, distiller.thinning [formatter_simple] format: %(message)s @@ -39,8 +39,8 @@ handlers: file # Example of adding a module-specific logger # Do not forget to add apputils.model_summaries to the list of keys in section [loggers] -# [logger_apputils.model_summaries] -# level: DEBUG -# qualname: apputils.model_summaries -# propagate: 0 -# handlers: console +[logger_distiller.thinning] +level: INFO +qualname: distiller.thinning +propagate: 0 +handlers: console, file diff --git a/examples/network_trimming/resnet56_cifar_activation_apoz.yaml b/examples/network_trimming/resnet56_cifar_activation_apoz.yaml index 1e9e455680e08980013cb05830038994c310ea85..cf6c7220d39d85ce95e1f933174ff774beeebbd7 100755 --- a/examples/network_trimming/resnet56_cifar_activation_apoz.yaml +++ b/examples/network_trimming/resnet56_cifar_activation_apoz.yaml @@ -100,67 +100,55 @@ version: 1 pruners: -# filter_pruner: -# class: 'ActivationAPoZRankedStructureParameterPruner' -# reg_regims: -# 'module.layer1.0.conv1.weight': Filters - - # filter_pruner: - # class: ActivationAPoZRankedFilterPruner_AGP - # initial_sparsity : 0.10 - # final_sparsity: 0.6 - # reg_regims: - # module.layer1.0.conv1.weight: Filters - filter_pruner_60: - #class: StructuredAutomatedGradualPruner class: ActivationAPoZRankedFilterPruner_AGP initial_sparsity : 0.10 final_sparsity: 0.6 - reg_regims: - module.layer1.0.conv1.weight: Filters - module.layer1.1.conv1.weight: Filters - module.layer1.2.conv1.weight: Filters - module.layer1.3.conv1.weight: Filters - module.layer1.4.conv1.weight: Filters - module.layer1.5.conv1.weight: Filters - module.layer1.6.conv1.weight: Filters - module.layer1.7.conv1.weight: Filters - module.layer1.8.conv1.weight: Filters + group_type: Filters + weights: [ + module.layer1.0.conv1.weight, + module.layer1.1.conv1.weight, + module.layer1.2.conv1.weight, + module.layer1.3.conv1.weight, + module.layer1.4.conv1.weight, + module.layer1.5.conv1.weight, + module.layer1.6.conv1.weight, + module.layer1.7.conv1.weight, + module.layer1.8.conv1.weight] filter_pruner_50: #class: StructuredAutomatedGradualPruner class: ActivationAPoZRankedFilterPruner_AGP initial_sparsity : 0.10 final_sparsity: 0.5 - reg_regims: - module.layer2.1.conv1.weight: Filters - module.layer2.2.conv1.weight: Filters - module.layer2.3.conv1.weight: Filters - module.layer2.4.conv1.weight: Filters - module.layer2.6.conv1.weight: Filters - module.layer2.7.conv1.weight: Filters + group_type: Filters + weights: [ + module.layer2.1.conv1.weight, + 
module.layer2.2.conv1.weight, + module.layer2.3.conv1.weight, + module.layer2.4.conv1.weight, + module.layer2.6.conv1.weight, + module.layer2.7.conv1.weight] filter_pruner_10: - #class: StructuredAutomatedGradualPruner class: ActivationAPoZRankedFilterPruner_AGP initial_sparsity : 0 final_sparsity: 0.1 - reg_regims: - module.layer3.1.conv1.weight: Filters + group_type: Filters + weights: [module.layer3.1.conv1.weight] filter_pruner_30: - #class: StructuredAutomatedGradualPruner class: ActivationAPoZRankedFilterPruner_AGP initial_sparsity : 0.10 final_sparsity: 0.3 - reg_regims: - module.layer3.2.conv1.weight: Filters - module.layer3.3.conv1.weight: Filters - module.layer3.5.conv1.weight: Filters - module.layer3.6.conv1.weight: Filters - module.layer3.7.conv1.weight: Filters - module.layer3.8.conv1.weight: Filters + group_type: Filters + weights: [ + module.layer3.2.conv1.weight, + module.layer3.3.conv1.weight, + module.layer3.5.conv1.weight, + module.layer3.6.conv1.weight, + module.layer3.7.conv1.weight, + module.layer3.8.conv1.weight] extensions: @@ -204,6 +192,7 @@ policies: - extension: instance_name: net_thinner epochs: [200] + #epochs: [181] - lr_scheduler: instance_name: exp_finetuning_lr diff --git a/examples/network_trimming/resnet56_cifar_activation_apoz_v2.yaml b/examples/network_trimming/resnet56_cifar_activation_apoz_v2.yaml index 1062520d51d39a72426a69edc6391244d8e88213..4f65e441bb6afa19cf1b8de8ece040fed2272c0e 100755 --- a/examples/network_trimming/resnet56_cifar_activation_apoz_v2.yaml +++ b/examples/network_trimming/resnet56_cifar_activation_apoz_v2.yaml @@ -106,47 +106,50 @@ pruners: class: ActivationAPoZRankedFilterPruner_AGP initial_sparsity : 0.10 final_sparsity: 0.7 - reg_regims: - module.layer1.0.conv1.weight: Filters - module.layer1.1.conv1.weight: Filters - module.layer1.2.conv1.weight: Filters - module.layer1.3.conv1.weight: Filters - module.layer1.4.conv1.weight: Filters - module.layer1.5.conv1.weight: Filters - module.layer1.6.conv1.weight: Filters - module.layer1.7.conv1.weight: Filters - module.layer1.8.conv1.weight: Filters + group_type: Filters + weights: [ + module.layer1.0.conv1.weight, + module.layer1.1.conv1.weight, + module.layer1.2.conv1.weight, + module.layer1.3.conv1.weight, + module.layer1.4.conv1.weight, + module.layer1.5.conv1.weight, + module.layer1.6.conv1.weight, + module.layer1.7.conv1.weight, + module.layer1.8.conv1.weight] filter_pruner_50: class: ActivationAPoZRankedFilterPruner_AGP initial_sparsity : 0.10 final_sparsity: 0.6 - reg_regims: - module.layer2.1.conv1.weight: Filters - module.layer2.2.conv1.weight: Filters - module.layer2.3.conv1.weight: Filters - module.layer2.4.conv1.weight: Filters - module.layer2.6.conv1.weight: Filters - module.layer2.7.conv1.weight: Filters + group_type: Filters + weights: [ + module.layer2.1.conv1.weight, + module.layer2.2.conv1.weight, + module.layer2.3.conv1.weight, + module.layer2.4.conv1.weight, + module.layer2.6.conv1.weight, + module.layer2.7.conv1.weight] filter_pruner_10: class: ActivationAPoZRankedFilterPruner_AGP initial_sparsity : 0 final_sparsity: 0.2 - reg_regims: - module.layer3.1.conv1.weight: Filters + group_type: Filters + weights: [module.layer3.1.conv1.weight] filter_pruner_30: class: ActivationAPoZRankedFilterPruner_AGP initial_sparsity : 0.10 final_sparsity: 0.4 - reg_regims: - module.layer3.2.conv1.weight: Filters - module.layer3.3.conv1.weight: Filters - module.layer3.5.conv1.weight: Filters - module.layer3.6.conv1.weight: Filters - module.layer3.7.conv1.weight: Filters - 
module.layer3.8.conv1.weight: Filters + group_type: Filters + weights: [ + module.layer3.2.conv1.weight, + module.layer3.3.conv1.weight, + module.layer3.5.conv1.weight, + module.layer3.6.conv1.weight, + module.layer3.7.conv1.weight, + module.layer3.8.conv1.weight] extensions: diff --git a/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_channel_rank.yaml b/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_channel_rank.yaml new file mode 100755 index 0000000000000000000000000000000000000000..4d934c5bd8867809d2aebaf240d6cb9853e10cbc --- /dev/null +++ b/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_channel_rank.yaml @@ -0,0 +1,189 @@ +# This is an example of one-shot channel pruning. +# It is very similar to the pruning schedule described in +# Pruning Filters for Efficient Convnets, H. Li, A. Kadav, I. Durdanovic, H. Samet, and H. P. Graf. +# ICLR 2017, arXiv:1608.087 +# However, instead of one-shot filter ranking and pruning, we perform one-shot channel ranking and +# pruning, using L1-magnitude of the structures. +# +# time python3 compress_classifier.py -a=resnet56_cifar -p=50 ../../../data.cifar10 --epochs=70 --lr=0.1 --compress=../pruning_filters_for_efficient_convnets/resnet56_cifar_channel_rank.yaml --resume=checkpoint.resnet56_cifar_baseline.pth.tar -j=1 --deterministic +# +# Baseline results: +# Top1: 92.850 Top5: 99.780 Loss: 0.464 +# Parameters: 851,504 +# Total MACs: 125,747,840 +# +# Results: +# Top1: 92.580 Top5: 99.670 Loss: 0.378 +# Parameters: 566,887 (=33.4% sparse) +# Total MACs: 66,592,384 (=1.89x less MACs) +# +# Parameters: +# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+ +# | | Name | Shape | NNZ (dense) | NNZ (sparse) | Cols (%) | Rows (%) | Ch (%) | 2D (%) | 3D (%) | Fine (%) | Std | Mean | Abs-Mean | +# |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------| +# | 0 | module.conv1.weight | (5, 3, 3, 3) | 135 | 135 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.51071 | 0.02961 | 0.34620 | +# | 1 | module.layer1.0.conv1.weight | (16, 5, 3, 3) | 720 | 720 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.10577 | -0.00027 | 0.06610 | +# | 2 | module.layer1.0.conv2.weight | (5, 16, 3, 3) | 720 | 720 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.09014 | -0.00534 | 0.05336 | +# | 3 | module.layer1.1.conv1.weight | (16, 5, 3, 3) | 720 | 720 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.09359 | -0.00072 | 0.05725 | +# | 4 | module.layer1.1.conv2.weight | (5, 16, 3, 3) | 720 | 720 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.06693 | -0.00967 | 0.04221 | +# | 5 | module.layer1.2.conv1.weight | (16, 5, 3, 3) | 720 | 720 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.11141 | -0.00108 | 0.07825 | +# | 6 | module.layer1.2.conv2.weight | (5, 16, 3, 3) | 720 | 720 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.08755 | 0.00565 | 0.06267 | +# | 7 | module.layer1.3.conv1.weight | (16, 5, 3, 3) | 720 | 720 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.12567 | -0.00626 | 0.09563 | +# | 8 | module.layer1.3.conv2.weight | (5, 16, 3, 3) | 720 | 720 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.10765 | 
-0.00790 | 0.08238 | +# | 9 | module.layer1.4.conv1.weight | (16, 5, 3, 3) | 720 | 720 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.11128 | -0.00219 | 0.07703 | +# | 10 | module.layer1.4.conv2.weight | (5, 16, 3, 3) | 720 | 720 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.09995 | -0.00546 | 0.06941 | +# | 11 | module.layer1.5.conv1.weight | (16, 5, 3, 3) | 720 | 720 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.13976 | -0.00127 | 0.09434 | +# | 12 | module.layer1.5.conv2.weight | (5, 16, 3, 3) | 720 | 720 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.11994 | 0.01183 | 0.08470 | +# | 13 | module.layer1.6.conv1.weight | (16, 5, 3, 3) | 720 | 720 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.14581 | -0.00892 | 0.10762 | +# | 14 | module.layer1.6.conv2.weight | (5, 16, 3, 3) | 720 | 720 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.11972 | 0.00398 | 0.08729 | +# | 15 | module.layer1.7.conv1.weight | (16, 5, 3, 3) | 720 | 720 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.13305 | 0.00098 | 0.09495 | +# | 16 | module.layer1.7.conv2.weight | (5, 16, 3, 3) | 720 | 720 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.10282 | -0.00396 | 0.07579 | +# | 17 | module.layer1.8.conv1.weight | (16, 5, 3, 3) | 720 | 720 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.14427 | 0.00264 | 0.10199 | +# | 18 | module.layer1.8.conv2.weight | (5, 16, 3, 3) | 720 | 720 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.10895 | 0.00739 | 0.07599 | +# | 19 | module.layer2.0.conv1.weight | (32, 5, 3, 3) | 1440 | 1440 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.19751 | -0.00411 | 0.14773 | +# | 20 | module.layer2.0.conv2.weight | (32, 32, 3, 3) | 9216 | 9216 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.10607 | -0.00486 | 0.07701 | +# | 21 | module.layer2.0.downsample.0.weight | (32, 5, 1, 1) | 160 | 160 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.25461 | 0.01664 | 0.19264 | +# | 22 | module.layer2.1.conv1.weight | (13, 32, 3, 3) | 3744 | 3744 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.09205 | -0.00378 | 0.06871 | +# | 23 | module.layer2.1.conv2.weight | (32, 13, 3, 3) | 3744 | 3744 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.08184 | -0.00368 | 0.06422 | +# | 24 | module.layer2.2.conv1.weight | (13, 32, 3, 3) | 3744 | 3744 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.08441 | -0.00480 | 0.06477 | +# | 25 | module.layer2.2.conv2.weight | (32, 13, 3, 3) | 3744 | 3744 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.07339 | -0.00748 | 0.05726 | +# | 26 | module.layer2.3.conv1.weight | (13, 32, 3, 3) | 3744 | 3744 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.08118 | -0.00391 | 0.06368 | +# | 27 | module.layer2.3.conv2.weight | (32, 13, 3, 3) | 3744 | 3744 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.06810 | -0.00177 | 0.05296 | +# | 28 | module.layer2.4.conv1.weight | (13, 32, 3, 3) | 3744 | 3744 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.07782 | -0.00768 | 0.06072 | +# | 29 | module.layer2.4.conv2.weight | (32, 13, 3, 3) | 3744 | 3744 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.06117 | -0.00520 | 0.04731 | +# | 30 | module.layer2.5.conv1.weight | (32, 32, 3, 3) | 9216 | 9216 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.05832 | 
-0.00430 | 0.04224 | +# | 31 | module.layer2.5.conv2.weight | (32, 32, 3, 3) | 9216 | 9216 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.04601 | -0.00230 | 0.03286 | +# | 32 | module.layer2.6.conv1.weight | (13, 32, 3, 3) | 3744 | 3744 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.06935 | -0.00572 | 0.05344 | +# | 33 | module.layer2.6.conv2.weight | (32, 13, 3, 3) | 3744 | 3744 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.05382 | -0.00365 | 0.04143 | +# | 34 | module.layer2.7.conv1.weight | (13, 32, 3, 3) | 3744 | 3744 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.07991 | -0.00900 | 0.06264 | +# | 35 | module.layer2.7.conv2.weight | (32, 13, 3, 3) | 3744 | 3744 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.06059 | -0.00253 | 0.04624 | +# | 36 | module.layer2.8.conv1.weight | (32, 32, 3, 3) | 9216 | 9216 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.04541 | -0.00436 | 0.02956 | +# | 37 | module.layer2.8.conv2.weight | (32, 32, 3, 3) | 9216 | 9216 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.03298 | -0.00051 | 0.02021 | +# | 38 | module.layer3.0.conv1.weight | (64, 32, 3, 3) | 18432 | 18432 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.07679 | -0.00169 | 0.05996 | +# | 39 | module.layer3.0.conv2.weight | (64, 64, 3, 3) | 36864 | 36864 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.06658 | -0.00063 | 0.04878 | +# | 40 | module.layer3.0.downsample.0.weight | (64, 32, 1, 1) | 2048 | 2048 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.11350 | 0.00252 | 0.07997 | +# | 41 | module.layer3.1.conv1.weight | (52, 64, 3, 3) | 29952 | 29952 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.05254 | -0.00156 | 0.03844 | +# | 42 | module.layer3.1.conv2.weight | (64, 52, 3, 3) | 29952 | 29952 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.05045 | -0.00537 | 0.03825 | +# | 43 | module.layer3.2.conv1.weight | (39, 64, 3, 3) | 22464 | 22464 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.05230 | -0.00151 | 0.03934 | +# | 44 | module.layer3.2.conv2.weight | (64, 39, 3, 3) | 22464 | 22464 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.04915 | -0.00644 | 0.03828 | +# | 45 | module.layer3.3.conv1.weight | (39, 64, 3, 3) | 22464 | 22464 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.05474 | -0.00361 | 0.04263 | +# | 46 | module.layer3.3.conv2.weight | (64, 39, 3, 3) | 22464 | 22464 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.04832 | -0.00569 | 0.03775 | +# | 47 | module.layer3.4.conv1.weight | (64, 64, 3, 3) | 36864 | 36864 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.05756 | -0.00462 | 0.04486 | +# | 48 | module.layer3.4.conv2.weight | (64, 64, 3, 3) | 36864 | 36864 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.04677 | -0.00323 | 0.03594 | +# | 49 | module.layer3.5.conv1.weight | (39, 64, 3, 3) | 22464 | 22464 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.06053 | -0.00528 | 0.04773 | +# | 50 | module.layer3.5.conv2.weight | (64, 39, 3, 3) | 22464 | 22464 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.04602 | -0.00363 | 0.03534 | +# | 51 | module.layer3.6.conv1.weight | (39, 64, 3, 3) | 22464 | 22464 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.04386 | -0.00286 | 0.03391 | +# | 52 | module.layer3.6.conv2.weight | (64, 39, 3, 3) | 22464 | 22464 | 
0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.03310 | -0.00052 | 0.02422 | +# | 53 | module.layer3.7.conv1.weight | (39, 64, 3, 3) | 22464 | 22464 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.03997 | -0.00282 | 0.03058 | +# | 54 | module.layer3.7.conv2.weight | (64, 39, 3, 3) | 22464 | 22464 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02942 | -0.00015 | 0.02128 | +# | 55 | module.layer3.8.conv1.weight | (39, 64, 3, 3) | 22464 | 22464 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.05219 | -0.00294 | 0.04075 | +# | 56 | module.layer3.8.conv2.weight | (64, 39, 3, 3) | 22464 | 22464 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.03520 | 0.00144 | 0.02545 | +# | 57 | module.fc.weight | (10, 64) | 640 | 640 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.49462 | -0.00002 | 0.39447 | +# | 58 | Total sparsity: | - | 566887 | 566887 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | +# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+ +# Total sparsity: 0.00 +# +# --- validate (epoch=249)----------- +# 10000 samples (256 per mini-batch) +# ==> Top1: 92.460 Top5: 99.670 Loss: 0.381 +# +# ==> Best Top1: 92.580 on Epoch: 248 +# Saving checkpoint to: logs/2018.11.29-145258/checkpoint.pth.tar +# --- test --------------------- +# 10000 samples (256 per mini-batch) +# ==> Top1: 92.460 Top5: 99.670 Loss: 0.381 + + + + +version: 1 +pruners: + filter_pruner_70: + class: 'L1RankedStructureParameterPruner' + group_type: Channels + desired_sparsity: 0.7 + group_dependency: Leader + weights: [ + module.layer1.1.conv1.weight, + module.layer1.0.conv1.weight, + module.layer1.2.conv1.weight, + module.layer1.3.conv1.weight, + module.layer1.4.conv1.weight, + module.layer1.5.conv1.weight, + module.layer1.6.conv1.weight, + module.layer1.7.conv1.weight, + module.layer1.8.conv1.weight, + module.layer2.0.conv1.weight, + module.layer2.0.downsample.0.weight + ] + + filter_pruner_60: + class: 'L1RankedStructureParameterPruner' + group_type: Channels + desired_sparsity: 0.6 + weights: [ + module.layer2.1.conv2.weight, + module.layer2.2.conv2.weight, + module.layer2.3.conv2.weight, + module.layer2.4.conv2.weight, + module.layer2.6.conv2.weight, + module.layer2.7.conv2.weight] + + filter_pruner_20: + class: 'L1RankedStructureParameterPruner' + group_type: Channels + desired_sparsity: 0.2 + weights: [module.layer3.1.conv2.weight] + + filter_pruner_40: + class: 'L1RankedStructureParameterPruner' + group_type: Channels + desired_sparsity: 0.4 + weights: [ + module.layer3.2.conv2.weight, + module.layer3.3.conv2.weight, + module.layer3.5.conv2.weight, + module.layer3.6.conv2.weight, + module.layer3.7.conv2.weight, + module.layer3.8.conv2.weight] + + +extensions: + net_thinner: + class: StructureRemover + thinning_func_str: remove_channels + arch: resnet56_cifar + dataset: cifar10 + +lr_schedulers: + exp_finetuning_lr: + class: ExponentialLR + gamma: 0.95 + + +policies: + - pruner: + instance_name: filter_pruner_70 + epochs: [180] + + - pruner: + instance_name: filter_pruner_60 + epochs: [180] + + - pruner: + instance_name: filter_pruner_40 + epochs: [180] + + - pruner: + instance_name: filter_pruner_20 + epochs: [180] + + - extension: + instance_name: net_thinner + epochs: [180] + + - lr_scheduler: + instance_name: exp_finetuning_lr + 
starting_epoch: 190 + ending_epoch: 300 + frequency: 1 diff --git a/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank.yaml b/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank.yaml index 3b33d1b9fdce89829f0553e00f4e1c7d595fda52..d4dce9622f98f6d4d664f0a241c336f52c4a1a80 100755 --- a/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank.yaml +++ b/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank.yaml @@ -113,69 +113,51 @@ version: 1 pruners: - filter_pruner: + filter_pruner_60: class: 'L1RankedStructureParameterPruner' - reg_regims: - #'module.conv1.weight': [0.2, '3D'] - 'module.layer1.0.conv1.weight': [0.6, '3D'] - #'module.layer1.0.conv2.weight': [0.4, '3D'] - 'module.layer1.1.conv1.weight': [0.6, '3D'] - #'module.layer1.1.conv2.weight': [0.6, '3D'] - 'module.layer1.2.conv1.weight': [0.6, '3D'] - #'module.layer1.2.conv2.weight': [0.6, '3D'] - 'module.layer1.3.conv1.weight': [0.6, '3D'] - #'module.layer1.3.conv2.weight': [0.6, '3D'] - 'module.layer1.4.conv1.weight': [0.6, '3D'] - #'module.layer1.4.conv2.weight': [0.6, '3D'] - 'module.layer1.5.conv1.weight': [0.6, '3D'] - #'module.layer1.5.conv2.weight': [0.6, '3D'] - 'module.layer1.6.conv1.weight': [0.6, '3D'] - #'module.layer1.6.conv2.weight': [0.6, '3D'] - 'module.layer1.7.conv1.weight': [0.6, '3D'] - #'module.layer1.7.conv2.weight': [0.6, '3D'] - 'module.layer1.8.conv1.weight': [0.6, '3D'] - #'module.layer1.8.conv2.weight': [0.2, '3D'] + group_type: Filters + desired_sparsity: 0.6 + weights: [ + module.layer1.0.conv1.weight, + module.layer1.1.conv1.weight, + module.layer1.2.conv1.weight, + module.layer1.3.conv1.weight, + module.layer1.4.conv1.weight, + module.layer1.5.conv1.weight, + module.layer1.6.conv1.weight, + module.layer1.7.conv1.weight, + module.layer1.8.conv1.weight] - ##'module.layer2.0.conv1.weight': [0.5, '3D'] - #'module.layer2.0.conv2.weight': [0.3, '3D'] - #'module.layer2.0.downsample.0.weight': [0.3, '3D'] - 'module.layer2.1.conv1.weight': [0.5, '3D'] - #'module.layer2.1.conv2.weight': [0.3, '3D'] - 'module.layer2.2.conv1.weight': [0.5, '3D'] - #'module.layer2.2.conv2.weight': [0.3, '3D'] - 'module.layer2.3.conv1.weight': [0.5, '3D'] - #'module.layer2.3.conv2.weight': [0.3, '3D'] + filter_pruner_50: + class: 'L1RankedStructureParameterPruner' + group_type: Filters + desired_sparsity: 0.5 + weights: [ + module.layer2.1.conv1.weight, + module.layer2.2.conv1.weight, + module.layer2.3.conv1.weight, + module.layer2.4.conv1.weight, + module.layer2.6.conv1.weight, + module.layer2.7.conv1.weight] + + filter_pruner_10: + class: 'L1RankedStructureParameterPruner' + group_type: Filters + desired_sparsity: 0.1 + weights: [module.layer3.1.conv1.weight] - 'module.layer2.4.conv1.weight': [0.5, '3D'] - # 'module.layer2.4.conv2.weight': [0.3, '3D'] - #'module.layer2.5.conv1.weight': [0.5, '3D'] - # 'module.layer2.5.conv2.weight': [0.3, '3D'] - 'module.layer2.6.conv1.weight': [0.5, '3D'] - # 'module.layer2.6.conv2.weight': [0.2, '3D'] - 'module.layer2.7.conv1.weight': [0.5, '3D'] - # 'module.layer2.7.conv2.weight': [0.2, '3D'] - ##'module.layer2.8.conv1.weight': [0.3, '3D'] - # 'module.layer2.8.conv2.weight': [0.2, '3D'] + filter_pruner_30: + class: 'L1RankedStructureParameterPruner' + group_type: Filters + desired_sparsity: 0.3 + weights: [ + module.layer3.2.conv1.weight, + module.layer3.3.conv1.weight, + module.layer3.5.conv1.weight, + module.layer3.6.conv1.weight, + module.layer3.7.conv1.weight, + module.layer3.8.conv1.weight] - 
#'module.layer3.0.conv1.weight': [0.1, '3D'] - # 'module.layer3.0.conv2.weight': [0.1, '3D'] - # 'module.layer3.0.downsample.0.weight': [0.1, '3D'] - 'module.layer3.1.conv1.weight': [0.1, '3D'] - # 'module.layer3.1.conv2.weight': [0.1, '3D'] - 'module.layer3.2.conv1.weight': [0.3, '3D'] - # 'module.layer3.2.conv2.weight': [0.1, '3D'] - 'module.layer3.3.conv1.weight': [0.3, '3D'] - # 'module.layer3.3.conv2.weight': [0.1, '3D'] - #'module.layer3.4.conv1.weight': [0.1, '3D'] - # 'module.layer3.4.conv2.weight': [0.1, '3D'] - 'module.layer3.5.conv1.weight': [0.3, '3D'] - #'module.layer3.5.conv2.weight': [0.1, '3D'] - 'module.layer3.6.conv1.weight': [0.3, '3D'] - # 'module.layer3.6.conv2.weight': [0.1, '3D'] - 'module.layer3.7.conv1.weight': [0.3, '3D'] - # 'module.layer3.7.conv2.weight': [0.1, '3D'] - 'module.layer3.8.conv1.weight': [0.3, '3D'] - # 'module.layer3.8.conv2.weight': [0.2, '3D'] extensions: net_thinner: @@ -192,7 +174,19 @@ lr_schedulers: policies: - pruner: - instance_name: filter_pruner + instance_name: filter_pruner_60 + epochs: [180] + + - pruner: + instance_name: filter_pruner_50 + epochs: [180] + + - pruner: + instance_name: filter_pruner_30 + epochs: [180] + + - pruner: + instance_name: filter_pruner_10 epochs: [180] - extension: diff --git a/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank_v2.yaml b/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank_v2.yaml index bc0d5df0238c4d98d51ae241bb375aa977904ce0..6e00522e43dd568918e08625ede4baba212929b8 100755 --- a/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank_v2.yaml +++ b/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank_v2.yaml @@ -115,69 +115,51 @@ version: 1 pruners: - filter_pruner: + filter_pruner_70: class: 'L1RankedStructureParameterPruner' - reg_regims: - #'module.conv1.weight': [0.2, '3D'] - 'module.layer1.0.conv1.weight': [0.7, '3D'] - #'module.layer1.0.conv2.weight': [0.4, '3D'] - 'module.layer1.1.conv1.weight': [0.7, '3D'] - #'module.layer1.1.conv2.weight': [0.7, '3D'] - 'module.layer1.2.conv1.weight': [0.7, '3D'] - #'module.layer1.2.conv2.weight': [0.7, '3D'] - 'module.layer1.3.conv1.weight': [0.7, '3D'] - #'module.layer1.3.conv2.weight': [0.7, '3D'] - 'module.layer1.4.conv1.weight': [0.7, '3D'] - #'module.layer1.4.conv2.weight': [0.7, '3D'] - 'module.layer1.5.conv1.weight': [0.7, '3D'] - #'module.layer1.5.conv2.weight': [0.7, '3D'] - 'module.layer1.6.conv1.weight': [0.7, '3D'] - #'module.layer1.6.conv2.weight': [0.7, '3D'] - 'module.layer1.7.conv1.weight': [0.7, '3D'] - #'module.layer1.7.conv2.weight': [0.7, '3D'] - 'module.layer1.8.conv1.weight': [0.7, '3D'] - #'module.layer1.8.conv2.weight': [0.2, '3D'] + group_type: Filters + desired_sparsity: 0.7 + weights: [ + module.layer1.0.conv1.weight, + module.layer1.1.conv1.weight, + module.layer1.2.conv1.weight, + module.layer1.3.conv1.weight, + module.layer1.4.conv1.weight, + module.layer1.5.conv1.weight, + module.layer1.6.conv1.weight, + module.layer1.7.conv1.weight, + module.layer1.8.conv1.weight] - ##'module.layer2.0.conv1.weight': [0.6, '3D'] - #'module.layer2.0.conv2.weight': [0.4, '3D'] - #'module.layer2.0.downsample.0.weight': [0.4, '3D'] - 'module.layer2.1.conv1.weight': [0.6, '3D'] - #'module.layer2.1.conv2.weight': [0.4, '3D'] - 'module.layer2.2.conv1.weight': [0.6, '3D'] - #'module.layer2.2.conv2.weight': [0.4, '3D'] - 'module.layer2.3.conv1.weight': [0.6, '3D'] - #'module.layer2.3.conv2.weight': [0.4, '3D'] + filter_pruner_60: + class: 
'L1RankedStructureParameterPruner' + group_type: Filters + desired_sparsity: 0.6 + weights: [ + module.layer2.1.conv1.weight, + module.layer2.2.conv1.weight, + module.layer2.3.conv1.weight, + module.layer2.4.conv1.weight, + module.layer2.6.conv1.weight, + module.layer2.7.conv1.weight] + + filter_pruner_20: + class: 'L1RankedStructureParameterPruner' + group_type: Filters + desired_sparsity: 0.2 + weights: [module.layer3.1.conv1.weight] - 'module.layer2.4.conv1.weight': [0.6, '3D'] - # 'module.layer2.4.conv2.weight': [0.4, '3D'] - #'module.layer2.5.conv1.weight': [0.6, '3D'] - # 'module.layer2.5.conv2.weight': [0.4, '3D'] - 'module.layer2.6.conv1.weight': [0.6, '3D'] - # 'module.layer2.6.conv2.weight': [0.2, '3D'] - 'module.layer2.7.conv1.weight': [0.6, '3D'] - # 'module.layer2.7.conv2.weight': [0.2, '3D'] - ##'module.layer2.8.conv1.weight': [0.4, '3D'] - # 'module.layer2.8.conv2.weight': [0.2, '3D'] + filter_pruner_40: + class: 'L1RankedStructureParameterPruner' + group_type: Filters + desired_sparsity: 0.4 + weights: [ + module.layer3.2.conv1.weight, + module.layer3.3.conv1.weight, + module.layer3.5.conv1.weight, + module.layer3.6.conv1.weight, + module.layer3.7.conv1.weight, + module.layer3.8.conv1.weight] - #'module.layer3.0.conv1.weight': [0.1, '3D'] - # 'module.layer3.0.conv2.weight': [0.1, '3D'] - # 'module.layer3.0.downsample.0.weight': [0.1, '3D'] - 'module.layer3.1.conv1.weight': [0.2, '3D'] - # 'module.layer3.1.conv2.weight': [0.1, '3D'] - 'module.layer3.2.conv1.weight': [0.4, '3D'] - # 'module.layer3.2.conv2.weight': [0.1, '3D'] - 'module.layer3.3.conv1.weight': [0.4, '3D'] - # 'module.layer3.3.conv2.weight': [0.1, '3D'] - #'module.layer3.4.conv1.weight': [0.1, '3D'] - # 'module.layer3.4.conv2.weight': [0.1, '3D'] - 'module.layer3.5.conv1.weight': [0.4, '3D'] - #'module.layer3.5.conv2.weight': [0.1, '3D'] - 'module.layer3.6.conv1.weight': [0.4, '3D'] - # 'module.layer3.6.conv2.weight': [0.1, '3D'] - 'module.layer3.7.conv1.weight': [0.4, '3D'] - # 'module.layer3.7.conv2.weight': [0.1, '3D'] - 'module.layer3.8.conv1.weight': [0.4, '3D'] - # 'module.layer3.8.conv2.weight': [0.2, '3D'] extensions: net_thinner: @@ -194,7 +176,19 @@ lr_schedulers: policies: - pruner: - instance_name: filter_pruner + instance_name: filter_pruner_70 + epochs: [180] + + - pruner: + instance_name: filter_pruner_60 + epochs: [180] + + - pruner: + instance_name: filter_pruner_40 + epochs: [180] + + - pruner: + instance_name: filter_pruner_20 epochs: [180] - extension: diff --git a/examples/pruning_filters_for_efficient_convnets/vgg19.schedule_filter_rank.yaml b/examples/pruning_filters_for_efficient_convnets/vgg19.schedule_filter_rank.yaml index cb20dd1ea5aea4e8d58803ce14b35cffaf10990b..0304fd8eab4ab1a8b7fbd664c25dcdf60c4bd90a 100755 --- a/examples/pruning_filters_for_efficient_convnets/vgg19.schedule_filter_rank.yaml +++ b/examples/pruning_filters_for_efficient_convnets/vgg19.schedule_filter_rank.yaml @@ -10,27 +10,30 @@ version: 1 pruners: vgg_manual: class: 'L1RankedStructureParameterPruner' - reg_regims: -# 'features.module.0.weight': [0.1, '3D'] - 'features.module.2.weight': [0.1, '3D'] - 'features.module.5.weight': [0.1, '3D'] - 'features.module.7.weight': [0.1, '3D'] - 'features.module.10.weight': [0.1, '3D'] - 'features.module.12.weight': [0.1, '3D'] - 'features.module.14.weight': [0.1, '3D'] - 'features.module.16.weight': [0.1, '3D'] - 'features.module.19.weight': [0.1, '3D'] + group_type: Filters + desired_sparsity: 0.1 + weights: [ + features.module.2.weight, + features.module.5.weight, + 
features.module.7.weight, + features.module.10.weight, + features.module.12.weight, + features.module.14.weight, + features.module.16.weight, + features.module.19.weight] vgg_manual2: class: 'L1RankedStructureParameterPruner' - reg_regims: - 'features.module.21.weight': [0.1, '3D'] - 'features.module.23.weight': [0.1, '3D'] - 'features.module.25.weight': [0.1, '3D'] - 'features.module.28.weight': [0.1, '3D'] - 'features.module.30.weight': [0.1, '3D'] - 'features.module.32.weight': [0.1, '3D'] - 'features.module.34.weight': [0.1, '3D'] + group_type: Filters + desired_sparsity: 0.1 + weights: [ + features.module.21.weight, + features.module.23.weight, + features.module.25.weight, + features.module.28.weight, + features.module.30.weight, + features.module.32.weight, + features.module.34.weight] extensions: net_thinner: diff --git a/imgs/pruning_structs_ex1.png b/imgs/pruning_structs_ex1.png new file mode 100755 index 0000000000000000000000000000000000000000..7eea414984bc50838f5a3eb3c07ed024f39846a0 Binary files /dev/null and b/imgs/pruning_structs_ex1.png differ diff --git a/imgs/pruning_structs_ex2.png b/imgs/pruning_structs_ex2.png new file mode 100755 index 0000000000000000000000000000000000000000..cbf08789c2b0df7d5d232bf7f0c920f718568856 Binary files /dev/null and b/imgs/pruning_structs_ex2.png differ diff --git a/imgs/pruning_structs_ex3.png b/imgs/pruning_structs_ex3.png new file mode 100755 index 0000000000000000000000000000000000000000..22f98c2d18b10594a4b9c5fa81efe3fd14743412 Binary files /dev/null and b/imgs/pruning_structs_ex3.png differ diff --git a/imgs/pruning_structs_ex4.png b/imgs/pruning_structs_ex4.png new file mode 100755 index 0000000000000000000000000000000000000000..02d48a2af960886085cb8789dc3ea66e6b66e37c Binary files /dev/null and b/imgs/pruning_structs_ex4.png differ diff --git a/imgs/pruning_structs_ex5.png b/imgs/pruning_structs_ex5.png new file mode 100755 index 0000000000000000000000000000000000000000..e70a0c807f9502e134fdcae640498a7140c6b9a8 Binary files /dev/null and b/imgs/pruning_structs_ex5.png differ diff --git a/tests/test_pruning.py b/tests/test_pruning.py index fc3b0f19adb3cec0fd0f7feadab084f212d03a69..90b506544b9604b1aa7ceaa7279058b4b79d4899 100755 --- a/tests/test_pruning.py +++ b/tests/test_pruning.py @@ -138,8 +138,10 @@ def ranked_filter_pruning(config, ratio_to_prune, is_parallel): assert distiller.sparsity_3D(conv1_p) == 0.0 # Create a filter-ranking pruner - reg_regims = {pair[0] + ".weight": [ratio_to_prune, "3D"]} - pruner = distiller.pruning.L1RankedStructureParameterPruner("filter_pruner", reg_regims) + pruner = distiller.pruning.L1RankedStructureParameterPruner("filter_pruner", + group_type="Filters", + desired_sparsity=ratio_to_prune, + weights=pair[0] + ".weight") pruner.set_param_mask(conv1_p, pair[0] + ".weight", zeros_mask_dict, meta=None) conv1 = common.find_module_by_name(model, pair[0]) @@ -347,8 +349,10 @@ def test_conv_fc_interface(is_parallel=parallel, model=None, zeros_mask_dict=Non assert conv_p.dim() == 4 # Create a filter-ranking pruner - reg_regims = {conv_name + ".weight": [ratio_to_prune, "3D"]} - pruner = distiller.pruning.L1RankedStructureParameterPruner("filter_pruner", reg_regims) + pruner = distiller.pruning.L1RankedStructureParameterPruner("filter_pruner", + group_type="Filters", + desired_sparsity=ratio_to_prune, + weights=conv_name + ".weight") pruner.set_param_mask(conv_p, conv_name + ".weight", zeros_mask_dict, meta=None) # Use the mask to prune diff --git a/tests/test_ranking.py b/tests/test_ranking.py index 
a0fa14a222d2e7f53a0e418c569577b3801d22f6..d48e927c1055d36735a8a87c06911dddd6922289 100755 --- a/tests/test_ranking.py +++ b/tests/test_ranking.py @@ -55,9 +55,8 @@ def test_ch_ranking(): [37, 38]]]]) fraction_to_prune = 0.5 - bottomk_channels, channel_mags = distiller.pruning.L1RankedStructureParameterPruner.rank_channels(fraction_to_prune, param) - logger.info("bottom {}% channels: {}".format(fraction_to_prune*100, bottomk_channels)) - assert bottomk_channels == torch.tensor([90.]) + binary_map = distiller.pruning.L1RankedStructureParameterPruner.rank_and_prune_channels(fraction_to_prune, param) + assert all(binary_map == torch.tensor([0., 1.])) def test_ranked_channel_pruning(): @@ -71,8 +70,10 @@ def test_ranked_channel_pruning(): assert distiller.sparsity_ch(conv1_p) == 0.0 # # Create a channel-ranking pruner - reg_regims = {"layer1.0.conv1.weight": [0.1, "Channels"]} - pruner = distiller.pruning.L1RankedStructureParameterPruner("channel_pruner", reg_regims) + pruner = distiller.pruning.L1RankedStructureParameterPruner("channel_pruner", + group_type="Channels", + desired_sparsity=0.1, + weights="layer1.0.conv1.weight") pruner.set_param_mask(conv1_p, "layer1.0.conv1.weight", zeros_mask_dict, meta=None) conv1 = common.find_module_by_name(model, "layer1.0.conv1")
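# ---------------------------------------------------------------------------
# Illustrative usage sketch -- not part of the patch above.  The tests in
# tests/test_pruning.py and tests/test_ranking.py exercise the refactored
# ranking-pruner interface, which replaces the old reg_regims dict with
# explicit group_type / desired_sparsity / weights arguments (the same
# keywords the YAML schedules above pass to the pruner classes).  The
# model/mask setup below (create_model, create_model_masks_dict) mirrors what
# the test helpers are assumed to do; it is not shown in this diff, so treat
# those two calls as assumptions about the surrounding Distiller API.
import distiller
from distiller.models import create_model

# Build a data-parallel CIFAR-10 ResNet20 and a per-parameter mask dictionary
# (assumed helpers; the tests obtain these through their common setup code).
model = create_model(False, 'cifar10', 'resnet20_cifar', parallel=True)
zeros_mask_dict = distiller.create_model_masks_dict(model)

param_name = 'module.layer2.1.conv1.weight'
param = dict(model.named_parameters())[param_name]

# One-shot L1 ranking and pruning of 50% of this layer's filters, using the
# keyword-argument constructor shown in the test changes above.
pruner = distiller.pruning.L1RankedStructureParameterPruner(
    "filter_pruner",
    group_type="Filters",
    desired_sparsity=0.5,
    weights=param_name)
pruner.set_param_mask(param, param_name, zeros_mask_dict, meta=None)

# The computed mask is stored in zeros_mask_dict; applying it zeroes whole
# filters (apply_mask() usage is assumed, following the test suite's pattern).
zeros_mask_dict[param_name].apply_mask(param)
print("3D (filter) sparsity:", distiller.sparsity_3D(param))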