diff --git a/distiller/pruning/automated_gradual_pruner.py b/distiller/pruning/automated_gradual_pruner.py
index bede4887de3ee225eb8d798df37a955b056b372e..e267a686df218f67065f04b2e93a5e6407f5554d 100755
--- a/distiller/pruning/automated_gradual_pruner.py
+++ b/distiller/pruning/automated_gradual_pruner.py
@@ -99,10 +99,10 @@ class StructuredAGP(AutomatedGradualPrunerBase):
 # TODO: this class parameterization is cumbersome: the ranking functions (per structure)
 # should come from the YAML schedule
 class L1RankedStructureParameterPruner_AGP(StructuredAGP):
-    def __init__(self, name, initial_sparsity, final_sparsity, group_type, weights, group_dependency=None):
+    def __init__(self, name, initial_sparsity, final_sparsity, group_type, weights, group_dependency=None, kwargs=None):
         super().__init__(name, initial_sparsity, final_sparsity)
-        self.pruner = L1RankedStructureParameterPruner(name, group_type, desired_sparsity=0,
-                                                       weights=weights, group_dependency=group_dependency)
+        self.pruner = L1RankedStructureParameterPruner(name, group_type, desired_sparsity=0, weights=weights,
+                                                       group_dependency=group_dependency, kwargs=kwargs)
 
 
 class ActivationAPoZRankedFilterPruner_AGP(StructuredAGP):
diff --git a/distiller/pruning/ranked_structures_pruner.py b/distiller/pruning/ranked_structures_pruner.py
index 58f66818b21d10b8b859f3b3777cd0c8088a7a84..1ca9237a42a5d83616c9566ee1487599d8317dd7 100755
--- a/distiller/pruning/ranked_structures_pruner.py
+++ b/distiller/pruning/ranked_structures_pruner.py
@@ -14,6 +14,7 @@
 #  limitations under the License.
 #
 
+from functools import partial
 import numpy as np
 import logging
 import torch
@@ -82,12 +83,17 @@ class L1RankedStructureParameterPruner(RankedStructureParameterPruner):
     This class prunes to a prescribed percentage of structured-sparsity (level pruning).
     """
-    def __init__(self, name, group_type, desired_sparsity, weights, group_dependency=None):
+    def __init__(self, name, group_type, desired_sparsity, weights, group_dependency=None, kwargs=None):
         super().__init__(name, group_type, desired_sparsity, weights, group_dependency)
-        if group_type not in ['3D', 'Filters', 'Channels', 'Rows']:
-            raise ValueError("Structure {} was requested but"
-                             "currently only filter (3D) and channel ranking is supported".
+        if group_type not in ['3D', 'Filters', 'Channels', 'Rows', 'Blocks']:
+            raise ValueError("Structure {} was requested but "
+                             "ranking of this structure type is currently not supported".
                              format(group_type))
+        if group_type == 'Blocks':
+            try:
+                self.block_shape = kwargs['block_shape']
+            except KeyError:
+                raise ValueError("When defining a block pruner, you must also specify the block shape")
 
     def prune_group(self, fraction_to_prune, param, param_name, zeros_mask_dict, model=None, binary_map=None):
         if fraction_to_prune == 0:
@@ -98,6 +104,8 @@ class L1RankedStructureParameterPruner(RankedStructureParameterPruner):
             group_pruning_fn = self.rank_and_prune_channels
         elif self.group_type == 'Rows':
             group_pruning_fn = self.rank_and_prune_rows
+        elif self.group_type == 'Blocks':
+            group_pruning_fn = partial(self.rank_and_prune_blocks, block_shape=self.block_shape)
 
         binary_map = group_pruning_fn(fraction_to_prune, param, param_name, zeros_mask_dict, model, binary_map)
         return binary_map
@@ -207,6 +215,94 @@ class L1RankedStructureParameterPruner(RankedStructureParameterPruner):
                        distiller.sparsity(zeros_mask_dict[param_name].mask),
                        fraction_to_prune, num_rows_to_prune, rows_mags.size(0))
 
+    @staticmethod
+    def rank_and_prune_blocks(fraction_to_prune, param, param_name=None,
+                              zeros_mask_dict=None, model=None, binary_map=None, block_shape=None):
+        """Block-wise pruning for 4D tensors.
+
+        The block shape is specified as a 4-element list: [block_repetitions, block_depth, block_height, block_width].
+        The dimension 'block_repetitions' specifies across how many consecutive filters the "basic block"
+        (shaped as [block_depth, block_height, block_width]) is repeated to produce a (4D) "super block".
+
+        For example:
+
+          block_pruner:
+            class: L1RankedStructureParameterPruner_AGP
+            initial_sparsity : 0.05
+            final_sparsity: 0.70
+            group_type: Blocks
+            kwargs:
+              block_shape: [1,8,1,1]  # [block_repetitions, block_depth, block_height, block_width]
+
+        Currently the only supported block shape is: block_repetitions x block_depth x 1 x 1
+        """
+        if len(block_shape) != 4:
+            raise ValueError("The block shape must be specified as a 4-element list")
+        block_repetitions, block_depth, block_height, block_width = block_shape
+        if not block_width == block_height == 1:
+            raise ValueError("Currently the only supported block shape is: block_repetitions x block_depth x 1 x 1")
+        super_block_volume = distiller.volume(block_shape)
+        if distiller.volume(param) % super_block_volume != 0:
+            raise ValueError("The super-block volume must divide the weight-tensor volume exactly.")
+        num_super_blocks = distiller.volume(param) // super_block_volume
+
+        num_filters = param.size(0)
+        num_channels = param.size(1)
+        kernel_size = param.size(2) * param.size(3)
+
+        if block_depth > 1:
+            view_dims = (
+                num_filters*num_channels//(block_repetitions*block_depth),
+                block_repetitions*block_depth,
+                kernel_size,
+            )
+        else:
+            view_dims = (
+                num_filters // block_repetitions,
+                block_repetitions,
+                -1,
+            )
+
+        def rank_blocks(fraction_to_prune, param):
+            # Create a view where each block is a column
+            view1 = param.view(*view_dims)
+            # Next, compute the sums of each column (block)
+            block_sums = view1.abs().sum(dim=1)
+
+            # Flatten to get a vector of per-block magnitudes
+            block_mags = block_sums.view(-1)
+            k = int(fraction_to_prune * block_mags.size(0))
+            if k == 0:
+                msglogger.info("Too few blocks (%d) - can't prune %.1f%% blocks",
+                               block_mags.size(0), 100*fraction_to_prune)
+                return None, None
+
+            bottomk, _ = torch.topk(block_mags, k, largest=False, sorted=True)
+            return bottomk, block_mags
+
+        def binary_map_to_mask(binary_map, param):
+            # Broadcast each block's keep/prune decision back over the block's
+            # elements, then restore the original 4-D weight shape
+            a = binary_map.view(view_dims[0], view_dims[2])
+            c = a.unsqueeze(1)
+            d = c.expand(*view_dims).contiguous()
+            return d.view(num_filters, num_channels, param.size(2), param.size(3))
+
+        if binary_map is None:
+            bottomk_blocks, block_mags = rank_blocks(fraction_to_prune, param)
+            if bottomk_blocks is None:
+                # rank_blocks() returns None when fraction_to_prune is too low to prune anything
+                return
+            threshold = bottomk_blocks[-1]
+            binary_map = block_mags.gt(threshold).type(param.data.type())
+
+        if zeros_mask_dict is not None:
+            zeros_mask_dict[param_name].mask = binary_map_to_mask(binary_map, param)
+            msglogger.info("L1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)", param_name,
+                           distiller.sparsity_blocks(zeros_mask_dict[param_name].mask, block_shape=block_shape),
+                           fraction_to_prune, binary_map.sum().item(), num_super_blocks)
+        return binary_map
+
 
 def mask_from_filter_order(filters_ordered_by_criterion, param, num_filters, binary_map):
     if binary_map is None:
diff --git a/distiller/utils.py b/distiller/utils.py
index 17d1097c52b02fdb87ced151ba86fc78d08d3e65..b1e70c8c9b1bbc6426fcbef72a9f1aa0563d4cd8 100755
--- a/distiller/utils.py
+++ b/distiller/utils.py
@@ -131,7 +131,7 @@ def volume(tensor):
     """return the volume of a pytorch tensor"""
     if isinstance(tensor, torch.FloatTensor) or isinstance(tensor, torch.cuda.FloatTensor):
         return np.prod(tensor.shape)
-    if isinstance(tensor, tuple):
+    if isinstance(tensor, tuple) or isinstance(tensor, list):
         return np.prod(tensor)
     raise ValueError
 
@@ -246,6 +246,47 @@ def density_ch(tensor):
     return 1 - sparsity_ch(tensor)
 
 
+def sparsity_blocks(tensor, block_shape):
+    """Block-wise sparsity for 4D tensors
+
+    Currently the only supported block shape is: block_repetitions x block_depth x 1 x 1
+    """
+    if tensor.dim() != 4:
+        raise ValueError("sparsity_blocks is only supported for 4-D tensors")
+
+    if len(block_shape) != 4:
+        raise ValueError("Block shape must be specified as a 4-element list")
+    block_repetitions, block_depth, block_height, block_width = block_shape
+    if not block_width == block_height == 1:
+        raise ValueError("Currently the only supported block shape is: block_repetitions x block_depth x 1 x 1")
+
+    super_block_volume = volume(block_shape)
+    num_super_blocks = volume(tensor) / super_block_volume
+
+    num_filters, num_channels = tensor.size(0), tensor.size(1)
+    kernel_size = tensor.size(2) * tensor.size(3)
+
+    # Create a view where each block is a column
+    if block_depth > 1:
+        view_dims = (
+            num_filters*num_channels//(block_repetitions*block_depth),
+            block_repetitions*block_depth,
+            kernel_size,
+        )
+    else:
+        view_dims = (
+            num_filters // block_repetitions,
+            block_repetitions,
+            -1,
+        )
+    view1 = tensor.view(*view_dims)
+
+    # Compute the sum of each column (block); a zero sum means the block is pruned
+    block_sums = view1.abs().sum(dim=1)
+    nonzero_blocks = len(torch.nonzero(block_sums))
+    return 1 - nonzero_blocks/num_super_blocks
+
+
 def sparsity_matrix(tensor, dim):
     """Generic sparsity computation for 2D matrices"""
     if tensor.dim() != 2:
diff --git a/examples/agp-pruning/resnet50.schedule_agp.1x1x8-blocks.yaml b/examples/agp-pruning/resnet50.schedule_agp.1x1x8-blocks.yaml
new file mode 100755
index 0000000000000000000000000000000000000000..342492cd57ae00f93555028a51a71776cd9cf4a1
--- /dev/null
+++ b/examples/agp-pruning/resnet50.schedule_agp.1x1x8-blocks.yaml
@@ -0,0 +1,147 @@
+#
+# This schedule performs 1x1x8 block pruning, using L1-norm ranking to select blocks and AGP to set the pruning-rate decay.
+# The final Linear layer (FC) is also pruned to 70%.
+# +# Best Top1: 76.358 (epoch 72) vs. 76.15 baseline (+0.2%) +# +# time python3 compress_classifier.py -a=resnet50 --pretrained -p=50 ../../../data.imagenet/ -j=22 --epochs=100 --lr=0.0005 --compress=../agp-pruning/resnet50.schedule_agp.1x1x8-blocks.yaml --validation-size=0 --num-best-scores=10 +# +# Parameters: +# +----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+ +# | | Name | Shape | NNZ (dense) | NNZ (sparse) | Cols (%) | Rows (%) | Ch (%) | 2D (%) | 3D (%) | Fine (%) | Std | Mean | Abs-Mean | +# |----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------| +# | 0 | module.conv1.weight | (64, 3, 7, 7) | 9408 | 9408 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.11032 | -0.00044 | 0.06760 | +# | 1 | module.layer1.0.conv1.weight | (64, 64, 1, 1) | 4096 | 4096 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.06353 | -0.00371 | 0.03587 | +# | 2 | module.layer1.0.conv2.weight | (64, 64, 3, 3) | 36864 | 11064 | 0.00000 | 0.00000 | 0.00000 | 28.12500 | 7.81250 | 69.98698 | 0.02277 | 0.00061 | 0.00835 | +# | 3 | module.layer1.0.conv3.weight | (256, 64, 1, 1) | 16384 | 16384 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.03142 | 0.00034 | 0.01880 | +# | 4 | module.layer1.0.downsample.0.weight | (256, 64, 1, 1) | 16384 | 16384 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.05117 | -0.00298 | 0.02858 | +# | 5 | module.layer1.1.conv1.weight | (64, 256, 1, 1) | 16384 | 16384 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02722 | 0.00105 | 0.01803 | +# | 6 | module.layer1.1.conv2.weight | (64, 64, 3, 3) | 36864 | 11064 | 0.00000 | 0.00000 | 0.00000 | 18.75000 | 1.56250 | 69.98698 | 0.02097 | 0.00016 | 0.00841 | +# | 7 | module.layer1.1.conv3.weight | (256, 64, 1, 1) | 16384 | 16384 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02896 | -0.00002 | 0.01815 | +# | 8 | module.layer1.2.conv1.weight | (64, 256, 1, 1) | 16384 | 16384 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02671 | 0.00015 | 0.01929 | +# | 9 | module.layer1.2.conv2.weight | (64, 64, 3, 3) | 36864 | 11064 | 0.00000 | 0.00000 | 0.00000 | 13.47656 | 0.00000 | 69.98698 | 0.02149 | -0.00033 | 0.00930 | +# | 10 | module.layer1.2.conv3.weight | (256, 64, 1, 1) | 16384 | 16384 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02755 | -0.00215 | 0.01658 | +# | 11 | module.layer2.0.conv1.weight | (128, 256, 1, 1) | 32768 | 32768 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.03152 | -0.00126 | 0.02213 | +# | 12 | module.layer2.0.conv2.weight | (128, 128, 3, 3) | 147456 | 44240 | 0.00000 | 0.00000 | 0.00000 | 19.04297 | 0.00000 | 69.99783 | 0.01489 | -0.00011 | 0.00633 | +# | 13 | module.layer2.0.conv3.weight | (512, 128, 1, 1) | 65536 | 65536 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02486 | 0.00003 | 0.01535 | +# | 14 | module.layer2.0.downsample.0.weight | (512, 256, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02046 | -0.00033 | 0.01198 | +# | 15 | module.layer2.1.conv1.weight | (128, 512, 1, 1) | 65536 | 65536 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01482 | -0.00005 | 0.00895 | +# | 16 | 
module.layer2.1.conv2.weight | (128, 128, 3, 3) | 147456 | 44240 | 0.00000 | 0.00000 | 0.00000 | 22.36328 | 0.78125 | 69.99783 | 0.01512 | 0.00037 | 0.00598 | +# | 17 | module.layer2.1.conv3.weight | (512, 128, 1, 1) | 65536 | 65536 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01964 | -0.00101 | 0.01122 | +# | 18 | module.layer2.2.conv1.weight | (128, 512, 1, 1) | 65536 | 65536 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02073 | -0.00067 | 0.01437 | +# | 19 | module.layer2.2.conv2.weight | (128, 128, 3, 3) | 147456 | 44240 | 0.00000 | 0.00000 | 0.00000 | 14.64844 | 0.00000 | 69.99783 | 0.01522 | 0.00006 | 0.00622 | +# | 20 | module.layer2.2.conv3.weight | (512, 128, 1, 1) | 65536 | 65536 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02328 | -0.00032 | 0.01636 | +# | 21 | module.layer2.3.conv1.weight | (128, 512, 1, 1) | 65536 | 65536 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02161 | -0.00079 | 0.01598 | +# | 22 | module.layer2.3.conv2.weight | (128, 128, 3, 3) | 147456 | 44240 | 0.00000 | 0.00000 | 0.00000 | 12.79297 | 0.00000 | 69.99783 | 0.01498 | -0.00022 | 0.00650 | +# | 23 | module.layer2.3.conv3.weight | (512, 128, 1, 1) | 65536 | 65536 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02159 | -0.00091 | 0.01488 | +# | 24 | module.layer3.0.conv1.weight | (256, 512, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02732 | -0.00100 | 0.01950 | +# | 25 | module.layer3.0.conv2.weight | (256, 256, 3, 3) | 589824 | 176952 | 0.00000 | 0.00000 | 0.00000 | 24.70703 | 0.00000 | 69.99919 | 0.01165 | -0.00010 | 0.00486 | +# | 26 | module.layer3.0.conv3.weight | (1024, 256, 1, 1) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02073 | -0.00034 | 0.01470 | +# | 27 | module.layer3.0.downsample.0.weight | (1024, 512, 1, 1) | 524288 | 524288 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01429 | 0.00003 | 0.00978 | +# | 28 | module.layer3.1.conv1.weight | (256, 1024, 1, 1) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01360 | -0.00050 | 0.00955 | +# | 29 | module.layer3.1.conv2.weight | (256, 256, 3, 3) | 589824 | 176952 | 0.00000 | 0.00000 | 0.00000 | 16.29639 | 0.00000 | 69.99919 | 0.01055 | 0.00002 | 0.00442 | +# | 30 | module.layer3.1.conv3.weight | (1024, 256, 1, 1) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01820 | -0.00090 | 0.01307 | +# | 31 | module.layer3.2.conv1.weight | (256, 1024, 1, 1) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01407 | -0.00041 | 0.01007 | +# | 32 | module.layer3.2.conv2.weight | (256, 256, 3, 3) | 589824 | 176952 | 0.00000 | 0.00000 | 0.00000 | 11.88965 | 0.00000 | 69.99919 | 0.01011 | -0.00021 | 0.00433 | +# | 33 | module.layer3.2.conv3.weight | (1024, 256, 1, 1) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01698 | -0.00063 | 0.01240 | +# | 34 | module.layer3.3.conv1.weight | (256, 1024, 1, 1) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01550 | -0.00058 | 0.01146 | +# | 35 | module.layer3.3.conv2.weight | (256, 256, 3, 3) | 589824 | 176952 | 0.00000 | 0.00000 | 0.00000 | 11.92627 | 0.00000 | 69.99919 | 0.00985 | -0.00019 | 0.00429 | +# | 36 | module.layer3.3.conv3.weight | (1024, 256, 1, 1) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01625 | -0.00094 | 0.01197 | +# 
| 37 | module.layer3.4.conv1.weight | (256, 1024, 1, 1) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01617 | -0.00080 | 0.01216 | +# | 38 | module.layer3.4.conv2.weight | (256, 256, 3, 3) | 589824 | 176952 | 0.00000 | 0.00000 | 0.00000 | 11.99951 | 0.00000 | 69.99919 | 0.00980 | -0.00028 | 0.00428 | +# | 39 | module.layer3.4.conv3.weight | (1024, 256, 1, 1) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01623 | -0.00131 | 0.01196 | +# | 40 | module.layer3.5.conv1.weight | (256, 1024, 1, 1) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01757 | -0.00075 | 0.01337 | +# | 41 | module.layer3.5.conv2.weight | (256, 256, 3, 3) | 589824 | 176952 | 0.00000 | 0.00000 | 0.00000 | 11.16943 | 0.00000 | 69.99919 | 0.01001 | -0.00032 | 0.00438 | +# | 42 | module.layer3.5.conv3.weight | (1024, 256, 1, 1) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01725 | -0.00191 | 0.01294 | +# | 43 | module.layer4.0.conv1.weight | (512, 1024, 1, 1) | 524288 | 524288 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02114 | -0.00099 | 0.01632 | +# | 44 | module.layer4.0.conv2.weight | (512, 512, 3, 3) | 2359296 | 707792 | 0.00000 | 0.00000 | 0.00000 | 19.15894 | 0.00000 | 69.99986 | 0.00801 | -0.00012 | 0.00358 | +# | 45 | module.layer4.0.conv3.weight | (2048, 512, 1, 1) | 1048576 | 1048576 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01369 | -0.00055 | 0.01057 | +# | 46 | module.layer4.0.downsample.0.weight | (2048, 1024, 1, 1) | 2097152 | 2097152 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00892 | -0.00015 | 0.00679 | +# | 47 | module.layer4.1.conv1.weight | (512, 2048, 1, 1) | 1048576 | 1048576 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01330 | -0.00056 | 0.01038 | +# | 48 | module.layer4.1.conv2.weight | (512, 512, 3, 3) | 2359296 | 707792 | 0.00000 | 0.00000 | 0.00000 | 13.93127 | 0.00000 | 69.99986 | 0.00781 | -0.00028 | 0.00351 | +# | 49 | module.layer4.1.conv3.weight | (2048, 512, 1, 1) | 1048576 | 1048576 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01347 | -0.00007 | 0.01039 | +# | 50 | module.layer4.2.conv1.weight | (512, 2048, 1, 1) | 1048576 | 1048576 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01628 | -0.00034 | 0.01277 | +# | 51 | module.layer4.2.conv2.weight | (512, 512, 3, 3) | 2359296 | 707792 | 0.00000 | 0.00000 | 0.00000 | 23.70911 | 0.00000 | 69.99986 | 0.00686 | -0.00021 | 0.00310 | +# | 52 | module.layer4.2.conv3.weight | (2048, 512, 1, 1) | 1048576 | 1048576 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01262 | 0.00002 | 0.00943 | +# | 53 | module.fc.weight | (1000, 2048) | 2048000 | 614400 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 70.00000 | 0.03148 | 0.00299 | 0.01480 | +# | 54 | Total sparsity: | - | 25502912 | 16147304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 36.68447 | 0.00000 | 0.00000 | 0.00000 | +# +----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+ +# 2018-12-24 13:52:24,645 - Total sparsity: 36.68 +# +# 2018-12-24 13:52:24,645 - --- validate (epoch=72)----------- +# 2018-12-24 13:52:24,645 - 50000 samples (256 per mini-batch) +# 2018-12-24 13:52:44,774 - Epoch: [72][ 50/ 195] Loss 0.676330 Top1 82.195312 Top5 96.039062 +# 2018-12-24 13:52:52,702 
- Epoch: [72][  100/  195]    Loss 0.799058    Top1 79.386719    Top5 94.863281
+# 2018-12-24 13:53:00,916 - Epoch: [72][  150/  195]    Loss 0.911178    Top1 77.216146    Top5 93.466146
+# 2018-12-24 13:53:08,224 - ==> Top1: 76.358    Top5: 92.972    Loss: 0.952
+#
+# 2018-12-24 13:53:08,308 - ==> Best Top1: 76.454 on Epoch: 1
+# 2018-12-24 13:53:08,308 - ==> Best Top1: 76.446 on Epoch: 0
+# 2018-12-24 13:53:08,308 - ==> Best Top1: 76.416 on Epoch: 3
+# 2018-12-24 13:53:08,308 - ==> Best Top1: 76.358 on Epoch: 72
+# 2018-12-24 13:53:08,308 - ==> Best Top1: 76.344 on Epoch: 2
+# 2018-12-24 13:53:08,308 - ==> Best Top1: 76.326 on Epoch: 69
+# 2018-12-24 13:53:08,308 - ==> Best Top1: 76.320 on Epoch: 68
+# 2018-12-24 13:53:08,308 - ==> Best Top1: 76.318 on Epoch: 70
+# 2018-12-24 13:53:08,309 - ==> Best Top1: 76.300 on Epoch: 58
+# 2018-12-24 13:53:08,309 - ==> Best Top1: 76.284 on Epoch: 71
+
+version: 1
+
+pruners:
+  fc_pruner:
+    class: AutomatedGradualPruner
+    initial_sparsity : 0.05
+    final_sparsity: 0.70
+    weights: module.fc.weight
+
+
+  block_pruner:
+    class: L1RankedStructureParameterPruner_AGP
+    initial_sparsity : 0.05
+    final_sparsity: 0.70
+    group_type: Blocks
+    kwargs:
+      block_shape: [1,8,1,1]  # [block_repetitions, block_depth, block_height, block_width]
+    weights: [module.layer1.0.conv2.weight,
+              module.layer1.1.conv2.weight,
+              module.layer1.2.conv2.weight,
+              module.layer2.0.conv2.weight,
+              module.layer2.1.conv2.weight,
+              module.layer2.2.conv2.weight,
+              module.layer2.3.conv2.weight,
+              module.layer3.0.conv2.weight,
+              module.layer3.1.conv2.weight,
+              module.layer3.2.conv2.weight,
+              module.layer3.3.conv2.weight,
+              module.layer3.4.conv2.weight,
+              module.layer3.5.conv2.weight,
+              module.layer4.0.conv2.weight,
+              module.layer4.1.conv2.weight,
+              module.layer4.2.conv2.weight]
+
+
+lr_schedulers:
+  pruning_lr:
+    class: ExponentialLR
+    gamma: 0.95
+
+
+policies:
+  - pruner:
+      instance_name : block_pruner
+      starting_epoch: 0
+      ending_epoch: 30
+      frequency: 1
+
+  - pruner:
+      instance_name : fc_pruner
+      starting_epoch: 0
+      ending_epoch: 30
+      frequency: 3
+
+  - lr_scheduler:
+      instance_name: pruning_lr
+      starting_epoch: 40
+      ending_epoch: 80
+      frequency: 1
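
For readers who want to experiment with the ranking logic outside of Distiller, below is a minimal, self-contained sketch of the block-ranking math that rank_and_prune_blocks implements. The helper name block_binary_map and the toy weight tensor are illustrative assumptions, not part of this patch; the view construction and L1 ranking mirror the code above, under the same block_width == block_height == 1 restriction.

    # Standalone sketch (illustrative, not part of the patch)
    import torch

    def block_binary_map(param, fraction_to_prune, block_shape):
        """Rank the [block_repetitions, block_depth, 1, 1] blocks of a 4-D tensor
        by L1 magnitude and return a flat keep(1)/prune(0) map, one entry per block."""
        block_repetitions, block_depth, block_height, block_width = block_shape
        assert block_width == block_height == 1  # same restriction as the patch
        num_filters, num_channels = param.size(0), param.size(1)
        kernel_size = param.size(2) * param.size(3)

        # Same view construction as the patch: each block becomes one column
        if block_depth > 1:
            view_dims = (num_filters * num_channels // (block_repetitions * block_depth),
                         block_repetitions * block_depth,
                         kernel_size)
        else:
            view_dims = (num_filters // block_repetitions, block_repetitions, -1)

        # The L1 magnitude of a block is the sum of absolute values in its column
        block_mags = param.view(*view_dims).abs().sum(dim=1).view(-1)
        k = int(fraction_to_prune * block_mags.size(0))
        if k == 0:
            return None  # too few blocks to prune at this fraction
        bottomk, _ = torch.topk(block_mags, k, largest=False, sorted=True)
        # Keep every block whose magnitude exceeds the k-th smallest magnitude
        return block_mags.gt(bottomk[-1]).float()

    if __name__ == '__main__':
        torch.manual_seed(0)
        w = torch.randn(64, 64, 3, 3)                # shaped like layer1.0.conv2
        bmap = block_binary_map(w, 0.7, [1, 8, 1, 1])
        print("blocks kept: {} / {}".format(int(bmap.sum().item()), bmap.numel()))

Running this keeps roughly 30% of the 4,608 blocks (ties at the threshold may prune slightly more), matching the thresholding done by block_mags.gt(threshold) in the patch; in the actual pruner this per-block decision is then broadcast back over the block's elements by binary_map_to_mask to form the 4-D weight mask.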