From b9d53ff885115adf49bbf7965fd7d07e92ebe391 Mon Sep 17 00:00:00 2001
From: Bar <elhararb@gmail.com>
Date: Tue, 8 Jan 2019 22:07:12 +0200
Subject: [PATCH] Non-channel/filter block pruning (#119)

Block pruning: support specifying the block shape from the YAML file

Block pruning refers to pruning 4-D structures of a specific shape.  This is
why it is sometimes called structure-pruning or group-pruning (confusing, I
know).  Filter and channel pruning are specific cases of block pruning with
highly regular block shapes.  This commit adds support for pruning
blocks/groups/structures with less regular shapes, chosen to accelerate
inference on a specific hardware platform.  You can read more about the
regularity of shapes in [Exploring the Regularity of Sparse Structure in
Convolutional Neural Networks](https://arxiv.org/pdf/1705.08922.pdf).

When we want to introduce sparsity in order to reduce the compute load of a
certain layer, we need to understand how the HW and SW perform the layer's
operation, and how this operation is vectorized.  Then we can induce sparsity
that matches the vector shape.

For example, Intel AVX-512 is a set of SIMD instructions that apply the same
operation (Single Instruction) to a vector of inputs (Multiple Data).  The
following single instruction performs an element-wise multiplication of two
vectors of sixteen 32-bit elements:

    __m512i result = _mm512_mullo_epi32(vec_a, vec_b);

If either vec_a or vec_b is only partially sparse, we still need to execute
the multiplication, and the sparsity does not reduce the cost (power, latency)
of the computation.  However, if either vec_a or vec_b contains only zeros,
we can eliminate the instruction entirely.  In this case we say that we want
group sparsity of 16 elements: the HW/SW benefits from sparsity induced in
blocks of 16 elements.

Things are a bit more involved, because we also need to understand how the
software maps layer operations to the hardware.  For example, a 3x3
convolution can be computed as a direct convolution, as a matrix
multiplication, or as a Winograd matrix operation (to name a few).  These
low-level operations are then mapped to SIMD instructions.
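To make the block grouping concrete, here is a small, self-contained sketch
(an illustration only, not part of this patch; the helper name
count_zero_blocks is made up).  It views a 4-D weight tensor as [1, 8, 1, 1]
blocks, using roughly the same view/abs/sum grouping that
rank_and_prune_blocks and sparsity_blocks use below for the block_depth > 1
case, and counts how many blocks are entirely zero, which is the kind of
sparsity a SIMD kernel can exploit by skipping whole blocks:

    import torch

    def count_zero_blocks(weights, block_depth=8):
        # weights: 4-D convolution weights (filters x channels x kh x kw)
        num_filters, num_channels = weights.size(0), weights.size(1)
        kernel_size = weights.size(2) * weights.size(3)
        assert num_channels % block_depth == 0
        # View the tensor so that dim=1 enumerates the block_depth channels of each block
        blocks = weights.contiguous().view(num_filters * num_channels // block_depth,
                                           block_depth, kernel_size)
        block_mags = blocks.abs().sum(dim=1)   # L1 magnitude of every block
        return int((block_mags == 0).sum().item()), block_mags.numel()

    w = torch.randn(64, 64, 3, 3)
    w[:, :16, :, :] = 0.                       # zero the first 16 input channels
    zero, total = count_zero_blocks(w)
    print("1x8x1x1 block sparsity: %.1f%%" % (100.0 * zero / total))   # prints 25.0%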
Finally, the low-level SW needs to support a block-sparse storage format for
weight tensors (see for example:
http://www.netlib.org/linalg/html_templates/node90.html).
---
 distiller/pruning/automated_gradual_pruner.py |   6 +-
 distiller/pruning/ranked_structures_pruner.py | 102 +++++++++++-
 distiller/utils.py                            |  46 +++++-
 .../resnet50.schedule_agp.1x1x8-blocks.yaml   | 147 ++++++++++++++++++
 4 files changed, 293 insertions(+), 8 deletions(-)
 create mode 100755 examples/agp-pruning/resnet50.schedule_agp.1x1x8-blocks.yaml

diff --git a/distiller/pruning/automated_gradual_pruner.py b/distiller/pruning/automated_gradual_pruner.py
index bede488..e267a68 100755
--- a/distiller/pruning/automated_gradual_pruner.py
+++ b/distiller/pruning/automated_gradual_pruner.py
@@ -99,10 +99,10 @@ class StructuredAGP(AutomatedGradualPrunerBase):
 # TODO: this class parameterization is cumbersome: the ranking functions (per structure)
 # should come from the YAML schedule
 class L1RankedStructureParameterPruner_AGP(StructuredAGP):
-    def __init__(self, name, initial_sparsity, final_sparsity, group_type, weights, group_dependency=None):
+    def __init__(self, name, initial_sparsity, final_sparsity, group_type, weights, group_dependency=None, kwargs=None):
         super().__init__(name, initial_sparsity, final_sparsity)
-        self.pruner = L1RankedStructureParameterPruner(name, group_type, desired_sparsity=0,
-                                                       weights=weights, group_dependency=group_dependency)
+        self.pruner = L1RankedStructureParameterPruner(name, group_type, desired_sparsity=0, weights=weights,
+                                                       group_dependency=group_dependency, kwargs=kwargs)
 
 
 class ActivationAPoZRankedFilterPruner_AGP(StructuredAGP):
diff --git a/distiller/pruning/ranked_structures_pruner.py b/distiller/pruning/ranked_structures_pruner.py
index 58f6681..1ca9237 100755
--- a/distiller/pruning/ranked_structures_pruner.py
+++ b/distiller/pruning/ranked_structures_pruner.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 #
 
+from functools import partial
 import numpy as np
 import logging
 import torch
@@ -82,12 +83,17 @@ class L1RankedStructureParameterPruner(RankedStructureParameterPruner):
     This class prunes to a prescribed percentage of structured-sparsity (level pruning).
     """
-    def __init__(self, name, group_type, desired_sparsity, weights, group_dependency=None):
+    def __init__(self, name, group_type, desired_sparsity, weights, group_dependency=None, kwargs=None):
         super().__init__(name, group_type, desired_sparsity, weights, group_dependency)
-        if group_type not in ['3D', 'Filters', 'Channels', 'Rows']:
-            raise ValueError("Structure {} was requested but"
-                             "currently only filter (3D) and channel ranking is supported".
+        if group_type not in ['3D', 'Filters', 'Channels', 'Rows', 'Blocks']:
+            raise ValueError("Structure {} was requested but "
+                             "currently ranking of this shape is not supported".
                              format(group_type))
+        if group_type == 'Blocks':
+            try:
+                self.block_shape = kwargs['block_shape']
+            except KeyError:
+                raise ValueError("When defining a block pruner you must also specify the block shape")
 
     def prune_group(self, fraction_to_prune, param, param_name, zeros_mask_dict, model=None, binary_map=None):
         if fraction_to_prune == 0:
@@ -98,6 +104,8 @@ class L1RankedStructureParameterPruner(RankedStructureParameterPruner):
             group_pruning_fn = self.rank_and_prune_channels
         elif self.group_type == 'Rows':
             group_pruning_fn = self.rank_and_prune_rows
+        elif self.group_type == 'Blocks':
+            group_pruning_fn = partial(self.rank_and_prune_blocks, block_shape=self.block_shape)
 
         binary_map = group_pruning_fn(fraction_to_prune, param, param_name, zeros_mask_dict, model, binary_map)
         return binary_map
@@ -207,6 +215,92 @@ class L1RankedStructureParameterPruner(RankedStructureParameterPruner):
                        distiller.sparsity(zeros_mask_dict[param_name].mask),
                        fraction_to_prune, num_rows_to_prune, rows_mags.size(0))
 
+    @staticmethod
+    def rank_and_prune_blocks(fraction_to_prune, param, param_name=None,
+                              zeros_mask_dict=None, model=None, binary_map=None, block_shape=None):
+        """Block-wise pruning for 4D tensors.
+
+        The block shape is specified using a tuple: [block_repetitions, block_depth, block_height, block_width].
+        The dimension 'block_repetitions' specifies in how many consecutive filters the "basic block"
+        (shaped as [block_depth, block_height, block_width]) repeats to produce a (4D) "super block".
+
+        For example:
+
+            block_pruner:
+              class: L1RankedStructureParameterPruner_AGP
+              initial_sparsity : 0.05
+              final_sparsity: 0.70
+              group_type: Blocks
+              kwargs:
+                block_shape: [1,8,1,1]  # [block_repetitions, block_depth, block_height, block_width]
+
+        Currently the only supported block shape is: block_repetitions x block_depth x 1 x 1
+        """
+        if len(block_shape) != 4:
+            raise ValueError("The block shape must be specified as a 4-element tuple")
+        block_repetitions, block_depth, block_height, block_width = block_shape
+        if not block_width == block_height == 1:
+            raise ValueError("Currently the only supported block shape is: block_repetitions x block_depth x 1 x 1")
+        super_block_volume = distiller.volume(block_shape)
+        num_super_blocks = distiller.volume(param) / super_block_volume
+        if distiller.volume(param) % super_block_volume != 0:
+            raise ValueError("The super-block size must divide the weight tensor exactly.")
+
+        num_filters = param.size(0)
+        num_channels = param.size(1)
+        kernel_size = param.size(2) * param.size(3)
+
+        if block_depth > 1:
+            view_dims = (
+                num_filters*num_channels//(block_repetitions*block_depth),
+                block_repetitions*block_depth,
+                kernel_size,
+                )
+        else:
+            view_dims = (
+                num_filters // block_repetitions,
+                block_repetitions,
+                -1,
+                )
+
+        def rank_blocks(fraction_to_prune, param):
+            # Create a view where each block is a column
+            view1 = param.view(*view_dims)
+            # Next, compute the sums of each column (block)
+            block_sums = view1.abs().sum(dim=1)
+
+            # Now group by channels
+            block_mags = block_sums.view(-1)  # flatten
+            k = int(fraction_to_prune * block_mags.size(0))
+            if k == 0:
+                msglogger.info("Too few blocks (%d)- can't prune %.1f%% blocks",
+                               block_mags.size(0), 100*fraction_to_prune)
+                return None, None
+
+            bottomk, _ = torch.topk(block_mags, k, largest=False, sorted=True)
+            return bottomk, block_mags
+
+        def binary_map_to_mask(binary_map, param):
+            a = binary_map.view(view_dims[0], view_dims[2])
+            c = a.unsqueeze(1)
+            d = c.expand(*view_dims).contiguous()
+            return d.view(num_filters, num_channels, param.size(2), param.size(3))
+
+        if binary_map is None:
+            bottomk_blocks, block_mags = rank_blocks(fraction_to_prune, param)
+            if bottomk_blocks is None:
+                # Empty list means that fraction_to_prune is too low to prune anything
+                return
+            threshold = bottomk_blocks[-1]
+            binary_map = block_mags.gt(threshold).type(param.data.type())
+
+        if zeros_mask_dict is not None:
+            zeros_mask_dict[param_name].mask = binary_map_to_mask(binary_map, param)
+            msglogger.info("L1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)", param_name,
+                           distiller.sparsity_blocks(zeros_mask_dict[param_name].mask, block_shape=block_shape),
+                           fraction_to_prune, binary_map.sum().item(), num_super_blocks)
+        return binary_map
+
 
 def mask_from_filter_order(filters_ordered_by_criterion, param, num_filters, binary_map):
     if binary_map is None:
diff --git a/distiller/utils.py b/distiller/utils.py
index 17d1097..b1e70c8 100755
--- a/distiller/utils.py
+++ b/distiller/utils.py
@@ -131,7 +131,7 @@ def volume(tensor):
     """return the volume of a pytorch tensor"""
     if isinstance(tensor, torch.FloatTensor) or isinstance(tensor, torch.cuda.FloatTensor):
         return np.prod(tensor.shape)
-    if isinstance(tensor, tuple):
+    if isinstance(tensor, tuple) or isinstance(tensor, list):
         return np.prod(tensor)
     raise ValueError
@@ -246,6 +246,50 @@ def density_ch(tensor):
     return 1 - sparsity_ch(tensor)
 
 
+def sparsity_blocks(tensor, block_shape):
+    """Block-wise sparsity for 4D tensors
+
+    Currently the only supported block shape is: block_repetitions x block_depth x 1 x 1
+    """
+    if tensor.dim() != 4:
+        raise ValueError("sparsity_blocks is only supported for 4-D tensors")
+
+    if len(block_shape) != 4:
+        raise ValueError("Block shape must be specified as a 4-element tuple")
+    block_repetitions, block_depth, block_height, block_width = block_shape
+    if not block_width == block_height == 1:
+        raise ValueError("Currently the only supported block shape is: block_repetitions x block_depth x 1 x 1")
+
+    super_block_volume = volume(block_shape)
+    num_super_blocks = volume(tensor) / super_block_volume
+
+    num_filters, num_channels = tensor.size(0), tensor.size(1)
+    kernel_size = tensor.size(2) * tensor.size(3)
+
+    # Create a view where each block is a column
+    if block_depth > 1:
+        view_dims = (
+            num_filters*num_channels//(block_repetitions*block_depth),
+            block_repetitions*block_depth,
+            kernel_size,
+            )
+    else:
+        view_dims = (
+            num_filters // block_repetitions,
+            block_repetitions,
+            -1,
+            )
+    view1 = tensor.view(*view_dims)
+
+    # Next, compute the sums of each column (block)
+    block_sums = view1.abs().sum(dim=1)
+    nonzero_blocks = len(torch.nonzero(block_sums))
+    return 1 - nonzero_blocks/num_super_blocks
+
+
 def sparsity_matrix(tensor, dim):
     """Generic sparsity computation for 2D matrices"""
     if tensor.dim() != 2:
diff --git a/examples/agp-pruning/resnet50.schedule_agp.1x1x8-blocks.yaml b/examples/agp-pruning/resnet50.schedule_agp.1x1x8-blocks.yaml
new file mode 100755
index 0000000..342492c
--- /dev/null
+++ b/examples/agp-pruning/resnet50.schedule_agp.1x1x8-blocks.yaml
@@ -0,0 +1,147 @@
+#
+# This schedule performs 1x1x8 block pruning using L1-norm ranking and AGP for setting the pruning-rate decay.
+# The final Linear layer (FC) is also pruned to 70%.
+#
+# Best Top1: 76.358 (epoch 72) vs. 76.15 baseline (+0.2%)
+#
+# time python3 compress_classifier.py -a=resnet50 --pretrained -p=50 ../../../data.imagenet/ -j=22 --epochs=100 --lr=0.0005 --compress=../agp-pruning/resnet50.schedule_agp.1x1x8-blocks.yaml --validation-size=0 --num-best-scores=10
+#
+# Parameters:
+# +----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
+# |    | Name                                | Shape | NNZ (dense) | NNZ (sparse) | Cols (%) | Rows (%) | Ch (%) | 2D (%) | 3D (%) | Fine (%) | Std | Mean | Abs-Mean |
+# |----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
+# |  0 | module.conv1.weight                 | (64, 3, 7, 7) | 9408 | 9408 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.11032 | -0.00044 | 0.06760 |
+# |  1 | module.layer1.0.conv1.weight        | (64, 64, 1, 1) | 4096 | 4096 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.06353 | -0.00371 | 0.03587 |
+# |  2 | module.layer1.0.conv2.weight        | (64, 64, 3, 3) | 36864 | 11064 | 0.00000 | 0.00000 | 0.00000 | 28.12500 | 7.81250 | 69.98698 | 0.02277 | 0.00061 | 0.00835 |
+# |  3 | module.layer1.0.conv3.weight        | (256, 64, 1, 1) | 16384 | 16384 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.03142 | 0.00034 | 0.01880 |
+# |  4 | module.layer1.0.downsample.0.weight | (256, 64, 1, 1) | 16384 | 16384 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.05117 | -0.00298 | 0.02858 |
+# |  5 | module.layer1.1.conv1.weight        | (64, 256, 1, 1) | 16384 | 16384 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02722 | 0.00105 | 0.01803 |
+# |  6 | module.layer1.1.conv2.weight        | (64, 64, 3, 3) | 36864 | 11064 | 0.00000 | 0.00000 | 0.00000 | 18.75000 | 1.56250 | 69.98698 | 0.02097 | 0.00016 | 0.00841 |
+# |  7 | module.layer1.1.conv3.weight        | (256, 64, 1, 1) | 16384 | 16384 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02896 | -0.00002 | 0.01815 |
+# |  8 | module.layer1.2.conv1.weight        | (64, 256, 1, 1) | 16384 | 16384 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02671 | 0.00015 | 0.01929 |
+# |  9 | module.layer1.2.conv2.weight        | (64, 64, 3, 3) | 36864 | 11064 | 0.00000 | 0.00000 | 0.00000 | 13.47656 | 0.00000 | 69.98698 | 0.02149 | -0.00033 | 0.00930 |
+# | 10 | module.layer1.2.conv3.weight        | (256, 64, 1, 1) | 16384 | 16384 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02755 | -0.00215 | 0.01658 |
+# | 11 | module.layer2.0.conv1.weight        | (128, 256, 1, 1) | 32768 | 32768 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.03152 | -0.00126 | 0.02213 |
+# | 12 | module.layer2.0.conv2.weight        | (128, 128, 3, 3) | 147456 | 44240 | 0.00000 | 0.00000 | 0.00000 | 19.04297 | 0.00000 | 69.99783 | 0.01489 | -0.00011 | 0.00633 |
+# | 13 | module.layer2.0.conv3.weight        | (512, 128, 1, 1) | 65536 | 65536 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02486 | 0.00003 | 0.01535 |
+# | 14 | module.layer2.0.downsample.0.weight | (512, 256, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02046 | -0.00033 | 0.01198 |
+# | 15 | module.layer2.1.conv1.weight        | (128, 512, 1, 1) | 65536 | 65536 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01482 | -0.00005 | 0.00895 |
+# | 16 | module.layer2.1.conv2.weight        | (128, 128, 3, 3) | 147456 | 44240 | 0.00000 | 0.00000 | 0.00000 | 22.36328 | 0.78125 | 69.99783 | 0.01512 | 0.00037 | 0.00598 |
+# | 17 | module.layer2.1.conv3.weight        | (512, 128, 1, 1) | 65536 | 65536 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01964 | -0.00101 | 0.01122 |
+# | 18 | module.layer2.2.conv1.weight        | (128, 512, 1, 1) | 65536 | 65536 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02073 | -0.00067 | 0.01437 |
+# | 19 | module.layer2.2.conv2.weight        | (128, 128, 3, 3) | 147456 | 44240 | 0.00000 | 0.00000 | 0.00000 | 14.64844 | 0.00000 | 69.99783 | 0.01522 | 0.00006 | 0.00622 |
+# | 20 | module.layer2.2.conv3.weight        | (512, 128, 1, 1) | 65536 | 65536 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02328 | -0.00032 | 0.01636 |
+# | 21 | module.layer2.3.conv1.weight        | (128, 512, 1, 1) | 65536 | 65536 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02161 | -0.00079 | 0.01598 |
+# | 22 | module.layer2.3.conv2.weight        | (128, 128, 3, 3) | 147456 | 44240 | 0.00000 | 0.00000 | 0.00000 | 12.79297 | 0.00000 | 69.99783 | 0.01498 | -0.00022 | 0.00650 |
+# | 23 | module.layer2.3.conv3.weight        | (512, 128, 1, 1) | 65536 | 65536 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02159 | -0.00091 | 0.01488 |
+# | 24 | module.layer3.0.conv1.weight        | (256, 512, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02732 | -0.00100 | 0.01950 |
+# | 25 | module.layer3.0.conv2.weight        | (256, 256, 3, 3) | 589824 | 176952 | 0.00000 | 0.00000 | 0.00000 | 24.70703 | 0.00000 | 69.99919 | 0.01165 | -0.00010 | 0.00486 |
+# | 26 | module.layer3.0.conv3.weight        | (1024, 256, 1, 1) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02073 | -0.00034 | 0.01470 |
+# | 27 | module.layer3.0.downsample.0.weight | (1024, 512, 1, 1) | 524288 | 524288 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01429 | 0.00003 | 0.00978 |
+# | 28 | module.layer3.1.conv1.weight        | (256, 1024, 1, 1) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01360 | -0.00050 | 0.00955 |
+# | 29 | module.layer3.1.conv2.weight        | (256, 256, 3, 3) | 589824 | 176952 | 0.00000 | 0.00000 | 0.00000 | 16.29639 | 0.00000 | 69.99919 | 0.01055 | 0.00002 | 0.00442 |
+# | 30 | module.layer3.1.conv3.weight        | (1024, 256, 1, 1) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01820 | -0.00090 | 0.01307 |
+# | 31 | module.layer3.2.conv1.weight        | (256, 1024, 1, 1) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01407 | -0.00041 | 0.01007 |
+# | 32 | module.layer3.2.conv2.weight        | (256, 256, 3, 3) | 589824 | 176952 | 0.00000 | 0.00000 | 0.00000 | 11.88965 | 0.00000 | 69.99919 | 0.01011 | -0.00021 | 0.00433 |
+# | 33 | module.layer3.2.conv3.weight        | (1024, 256, 1, 1) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01698 | -0.00063 | 0.01240 |
+# | 34 | module.layer3.3.conv1.weight        | (256, 1024, 1, 1) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01550 | -0.00058 | 0.01146 |
+# | 35 | module.layer3.3.conv2.weight        | (256, 256, 3, 3) | 589824 | 176952 | 0.00000 | 0.00000 | 0.00000 | 11.92627 | 0.00000 | 69.99919 | 0.00985 | -0.00019 | 0.00429 |
+# | 36 | module.layer3.3.conv3.weight        | (1024, 256, 1, 1) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01625 | -0.00094 | 0.01197 |
+# | 37 | module.layer3.4.conv1.weight        | (256, 1024, 1, 1) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01617 | -0.00080 | 0.01216 |
+# | 38 | module.layer3.4.conv2.weight        | (256, 256, 3, 3) | 589824 | 176952 | 0.00000 | 0.00000 | 0.00000 | 11.99951 | 0.00000 | 69.99919 | 0.00980 | -0.00028 | 0.00428 |
+# | 39 | module.layer3.4.conv3.weight        | (1024, 256, 1, 1) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01623 | -0.00131 | 0.01196 |
+# | 40 | module.layer3.5.conv1.weight        | (256, 1024, 1, 1) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01757 | -0.00075 | 0.01337 |
+# | 41 | module.layer3.5.conv2.weight        | (256, 256, 3, 3) | 589824 | 176952 | 0.00000 | 0.00000 | 0.00000 | 11.16943 | 0.00000 | 69.99919 | 0.01001 | -0.00032 | 0.00438 |
+# | 42 | module.layer3.5.conv3.weight        | (1024, 256, 1, 1) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01725 | -0.00191 | 0.01294 |
+# | 43 | module.layer4.0.conv1.weight        | (512, 1024, 1, 1) | 524288 | 524288 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02114 | -0.00099 | 0.01632 |
+# | 44 | module.layer4.0.conv2.weight        | (512, 512, 3, 3) | 2359296 | 707792 | 0.00000 | 0.00000 | 0.00000 | 19.15894 | 0.00000 | 69.99986 | 0.00801 | -0.00012 | 0.00358 |
+# | 45 | module.layer4.0.conv3.weight        | (2048, 512, 1, 1) | 1048576 | 1048576 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01369 | -0.00055 | 0.01057 |
+# | 46 | module.layer4.0.downsample.0.weight | (2048, 1024, 1, 1) | 2097152 | 2097152 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00892 | -0.00015 | 0.00679 |
+# | 47 | module.layer4.1.conv1.weight        | (512, 2048, 1, 1) | 1048576 | 1048576 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01330 | -0.00056 | 0.01038 |
+# | 48 | module.layer4.1.conv2.weight        | (512, 512, 3, 3) | 2359296 | 707792 | 0.00000 | 0.00000 | 0.00000 | 13.93127 | 0.00000 | 69.99986 | 0.00781 | -0.00028 | 0.00351 |
+# | 49 | module.layer4.1.conv3.weight        | (2048, 512, 1, 1) | 1048576 | 1048576 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01347 | -0.00007 | 0.01039 |
+# | 50 | module.layer4.2.conv1.weight        | (512, 2048, 1, 1) | 1048576 | 1048576 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01628 | -0.00034 | 0.01277 |
+# | 51 | module.layer4.2.conv2.weight        | (512, 512, 3, 3) | 2359296 | 707792 | 0.00000 | 0.00000 | 0.00000 | 23.70911 | 0.00000 | 69.99986 | 0.00686 | -0.00021 | 0.00310 |
+# | 52 | module.layer4.2.conv3.weight        | (2048, 512, 1, 1) | 1048576 | 1048576 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01262 | 0.00002 | 0.00943 |
+# | 53 | module.fc.weight                    | (1000, 2048) | 2048000 | 614400 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 70.00000 | 0.03148 | 0.00299 | 0.01480 |
+# | 54 | Total sparsity:                     | - | 25502912 | 16147304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 36.68447 | 0.00000 | 0.00000 | 0.00000 |
+# +----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
+# 2018-12-24 13:52:24,645 - Total sparsity: 36.68
+#
+# 2018-12-24 13:52:24,645 - --- validate (epoch=72)-----------
+# 2018-12-24 13:52:24,645 - 50000 samples (256 per mini-batch)
+# 2018-12-24 13:52:44,774 - Epoch: [72][   50/  195]    Loss 0.676330    Top1 82.195312    Top5 96.039062
+# 2018-12-24 13:52:52,702 - Epoch: [72][  100/  195]    Loss 0.799058    Top1 79.386719    Top5 94.863281
+# 2018-12-24 13:53:00,916 - Epoch: [72][  150/  195]    Loss 0.911178    Top1 77.216146    Top5 93.466146
+# 2018-12-24 13:53:08,224 - ==> Top1: 76.358    Top5: 92.972    Loss: 0.952
+#
+# 2018-12-24 13:53:08,308 - ==> Best Top1: 76.454 on Epoch: 1
+# 2018-12-24 13:53:08,308 - ==> Best Top1: 76.446 on Epoch: 0
+# 2018-12-24 13:53:08,308 - ==> Best Top1: 76.416 on Epoch: 3
+# 2018-12-24 13:53:08,308 - ==> Best Top1: 76.358 on Epoch: 72
+# 2018-12-24 13:53:08,308 - ==> Best Top1: 76.344 on Epoch: 2
+# 2018-12-24 13:53:08,308 - ==> Best Top1: 76.326 on Epoch: 69
+# 2018-12-24 13:53:08,308 - ==> Best Top1: 76.320 on Epoch: 68
+# 2018-12-24 13:53:08,308 - ==> Best Top1: 76.318 on Epoch: 70
+# 2018-12-24 13:53:08,309 - ==> Best Top1: 76.300 on Epoch: 58
+# 2018-12-24 13:53:08,309 - ==> Best Top1: 76.284 on Epoch: 71
+
+version: 1
+
+pruners:
+  fc_pruner:
+    class: AutomatedGradualPruner
+    initial_sparsity : 0.05
+    final_sparsity: 0.70
+    weights: module.fc.weight
+
+
+  block_pruner:
+    class: L1RankedStructureParameterPruner_AGP
+    initial_sparsity : 0.05
+    final_sparsity: 0.70
+    group_type: Blocks
+    kwargs:
+      block_shape: [1,8,1,1]  # [block_repetitions, block_depth, block_height, block_width]
+    weights: [module.layer1.0.conv2.weight,
+              module.layer1.1.conv2.weight,
+              module.layer1.2.conv2.weight,
+              module.layer2.0.conv2.weight,
+              module.layer2.1.conv2.weight,
+              module.layer2.2.conv2.weight,
+              module.layer2.3.conv2.weight,
+              module.layer3.0.conv2.weight,
+              module.layer3.1.conv2.weight,
+              module.layer3.2.conv2.weight,
+              module.layer3.3.conv2.weight,
+              module.layer3.4.conv2.weight,
+              module.layer3.5.conv2.weight,
+              module.layer4.0.conv2.weight,
+              module.layer4.1.conv2.weight,
+              module.layer4.2.conv2.weight]
+
+
+lr_schedulers:
+  pruning_lr:
+    class: ExponentialLR
+    gamma: 0.95
+
+
+policies:
+  - pruner:
+      instance_name : block_pruner
+    starting_epoch: 0
+    ending_epoch: 30
+    frequency: 1
+
+  - pruner:
+      instance_name : fc_pruner
+    starting_epoch: 0
+    ending_epoch: 30
+    frequency: 3
+
+  - lr_scheduler:
+      instance_name: pruning_lr
+    starting_epoch: 40
+    ending_epoch: 80
+    frequency: 1
-- 
GitLab