From d89a3a07ebd1111aad080d2d9a4ce9713c0ccaf5 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Wed, 19 Dec 2018 15:52:28 +0200
Subject: [PATCH] AGP for filters: added FC element-wise pruning to
 filter-pruning

In short: improves Top1 results.  The gain might simply be due to random variation.
---
 ...resnet50.schedule_agp.filters_with_FC.yaml | 207 ++++++++++++++++++
 1 file changed, 207 insertions(+)
 create mode 100755 examples/agp-pruning/resnet50.schedule_agp.filters_with_FC.yaml

diff --git a/examples/agp-pruning/resnet50.schedule_agp.filters_with_FC.yaml b/examples/agp-pruning/resnet50.schedule_agp.filters_with_FC.yaml
new file mode 100755
index 0000000..15cc334
--- /dev/null
+++ b/examples/agp-pruning/resnet50.schedule_agp.filters_with_FC.yaml
@@ -0,0 +1,207 @@
+#
+# This schedule performs filter pruning using L1-norm ranking, with AGP setting the pruning-rate decay.
+# The final Linear layer (FC) is also pruned, element-wise, to 70% sparsity.
+# Curiously, we achieve slightly better Top1 than with the same schedule without the FC pruning.
+#
+# Best Top1: 74.564 (epoch 84) vs. 76.15 baseline (-1.6%)
+# No. of Parameters: 10,901,696 (of 25,502,912) = 42.74% dense (57.26% sparse)
+# Total MACs: 1,822,031,872 (of 4,089,184,256) = 44.56% compute = 2.24x
+# time python3 compress_classifier.py -a=resnet50 --pretrained -p=50 ~/datasets/imagenet/ -j=22 --epochs=100 --lr=0.0005 --compress=resnet50.schedule_agp.filters_with_FC.yaml --validation-size=0 --num-best-scores=10
+#
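+# Both pruners below follow the AGP (Automated Gradual Pruning) schedule of Zhu & Gupta
+# ("To prune, or not to prune"), which decays the pruning rate cubically.  As a rough
+# sketch (the exact details are in Distiller's AGP implementation), the target sparsity
+# at epoch t is:
+#
+#   s_t = final_sparsity + (initial_sparsity - final_sparsity) *
+#         (1 - (t - starting_epoch) / (ending_epoch - starting_epoch)) ** 3
+#
+# For the filter_pruner below (initial_sparsity=0.05, final_sparsity=0.50, epochs 0..30),
+# the midpoint epoch 15 therefore targets roughly 0.50 - 0.45 * 0.5**3 = 0.44, i.e. ~44%
+# of the filters pruned; the fc_pruner follows the same curve up to 70%.
+#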
+# Parameters:
+# +----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
+# |    | Name                                | Shape              |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
+# |----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
+# |  0 | module.conv1.weight                 | (64, 3, 7, 7)      |          9408 |           9408 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11119 | -0.00042 |    0.06789 |
+# |  1 | module.layer1.0.conv1.weight        | (32, 64, 1, 1)     |          2048 |           2048 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07661 | -0.00610 |    0.04643 |
+# |  2 | module.layer1.0.conv2.weight        | (32, 32, 3, 3)     |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04028 |  0.00160 |    0.02608 |
+# |  3 | module.layer1.0.conv3.weight        | (256, 32, 1, 1)    |          8192 |           8192 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03803 | -0.00044 |    0.02407 |
+# |  4 | module.layer1.0.downsample.0.weight | (256, 64, 1, 1)    |         16384 |          16384 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05144 | -0.00304 |    0.02863 |
+# |  5 | module.layer1.1.conv1.weight        | (32, 256, 1, 1)    |          8192 |           8192 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03162 |  0.00100 |    0.02178 |
+# |  6 | module.layer1.1.conv2.weight        | (32, 32, 3, 3)     |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03687 |  0.00027 |    0.02591 |
+# |  7 | module.layer1.1.conv3.weight        | (256, 32, 1, 1)    |          8192 |           8192 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03185 | -0.00048 |    0.02030 |
+# |  8 | module.layer1.2.conv1.weight        | (32, 256, 1, 1)    |          8192 |           8192 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03011 |  0.00019 |    0.02205 |
+# |  9 | module.layer1.2.conv2.weight        | (32, 32, 3, 3)     |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03635 | -0.00009 |    0.02744 |
+# | 10 | module.layer1.2.conv3.weight        | (256, 32, 1, 1)    |          8192 |           8192 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02803 | -0.00242 |    0.01684 |
+# | 11 | module.layer2.0.conv1.weight        | (64, 256, 1, 1)    |         16384 |          16384 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03596 | -0.00125 |    0.02536 |
+# | 12 | module.layer2.0.conv2.weight        | (64, 64, 3, 3)     |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02418 |  0.00002 |    0.01789 |
+# | 13 | module.layer2.0.conv3.weight        | (512, 64, 1, 1)    |         32768 |          32768 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02696 |  0.00000 |    0.01652 |
+# | 14 | module.layer2.0.downsample.0.weight | (512, 256, 1, 1)   |        131072 |         131072 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02060 | -0.00044 |    0.01214 |
+# | 15 | module.layer2.1.conv1.weight        | (64, 512, 1, 1)    |         32768 |          32768 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01739 | -0.00021 |    0.01075 |
+# | 16 | module.layer2.1.conv2.weight        | (64, 64, 3, 3)     |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02545 |  0.00070 |    0.01662 |
+# | 17 | module.layer2.1.conv3.weight        | (512, 64, 1, 1)    |         32768 |          32768 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02249 | -0.00146 |    0.01323 |
+# | 18 | module.layer2.2.conv1.weight        | (64, 512, 1, 1)    |         32768 |          32768 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02338 | -0.00056 |    0.01624 |
+# | 19 | module.layer2.2.conv2.weight        | (64, 64, 3, 3)     |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02410 |  0.00015 |    0.01685 |
+# | 20 | module.layer2.2.conv3.weight        | (512, 64, 1, 1)    |         32768 |          32768 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02576 |  0.00017 |    0.01794 |
+# | 21 | module.layer2.3.conv1.weight        | (64, 512, 1, 1)    |         32768 |          32768 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02341 | -0.00082 |    0.01743 |
+# | 22 | module.layer2.3.conv2.weight        | (64, 64, 3, 3)     |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02380 | -0.00048 |    0.01804 |
+# | 23 | module.layer2.3.conv3.weight        | (512, 64, 1, 1)    |         32768 |          32768 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02311 | -0.00123 |    0.01596 |
+# | 24 | module.layer3.0.conv1.weight        | (128, 512, 1, 1)   |         65536 |          65536 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02988 | -0.00090 |    0.02166 |
+# | 25 | module.layer3.0.conv2.weight        | (128, 128, 3, 3)   |        147456 |         147456 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01793 | -0.00014 |    0.01335 |
+# | 26 | module.layer3.0.conv3.weight        | (1024, 128, 1, 1)  |        131072 |         131072 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02273 | -0.00044 |    0.01634 |
+# | 27 | module.layer3.0.downsample.0.weight | (1024, 512, 1, 1)  |        524288 |         524288 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01441 | -0.00002 |    0.00988 |
+# | 28 | module.layer3.1.conv1.weight        | (128, 1024, 1, 1)  |        131072 |         131072 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01521 | -0.00035 |    0.01075 |
+# | 29 | module.layer3.1.conv2.weight        | (128, 128, 3, 3)   |        147456 |         147456 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01686 | -0.00003 |    0.01215 |
+# | 30 | module.layer3.1.conv3.weight        | (1024, 128, 1, 1)  |        131072 |         131072 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01961 | -0.00062 |    0.01396 |
+# | 31 | module.layer3.2.conv1.weight        | (128, 1024, 1, 1)  |        131072 |         131072 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01551 | -0.00032 |    0.01105 |
+# | 32 | module.layer3.2.conv2.weight        | (128, 128, 3, 3)   |        147456 |         147456 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01649 | -0.00058 |    0.01217 |
+# | 33 | module.layer3.2.conv3.weight        | (1024, 128, 1, 1)  |        131072 |         131072 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01839 | -0.00051 |    0.01337 |
+# | 34 | module.layer3.3.conv1.weight        | (128, 1024, 1, 1)  |        131072 |         131072 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01679 | -0.00058 |    0.01252 |
+# | 35 | module.layer3.3.conv2.weight        | (128, 128, 3, 3)   |        147456 |         147456 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01554 | -0.00052 |    0.01180 |
+# | 36 | module.layer3.3.conv3.weight        | (1024, 128, 1, 1)  |        131072 |         131072 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01745 | -0.00095 |    0.01283 |
+# | 37 | module.layer3.4.conv1.weight        | (128, 1024, 1, 1)  |        131072 |         131072 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01739 | -0.00078 |    0.01312 |
+# | 38 | module.layer3.4.conv2.weight        | (128, 128, 3, 3)   |        147456 |         147456 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01537 | -0.00063 |    0.01168 |
+# | 39 | module.layer3.4.conv3.weight        | (1024, 128, 1, 1)  |        131072 |         131072 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01709 | -0.00124 |    0.01253 |
+# | 40 | module.layer3.5.conv1.weight        | (128, 1024, 1, 1)  |        131072 |         131072 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01870 | -0.00072 |    0.01434 |
+# | 41 | module.layer3.5.conv2.weight        | (128, 128, 3, 3)   |        147456 |         147456 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01530 | -0.00072 |    0.01172 |
+# | 42 | module.layer3.5.conv3.weight        | (1024, 128, 1, 1)  |        131072 |         131072 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01855 | -0.00212 |    0.01395 |
+# | 43 | module.layer4.0.conv1.weight        | (256, 1024, 1, 1)  |        262144 |         262144 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02219 | -0.00086 |    0.01714 |
+# | 44 | module.layer4.0.conv2.weight        | (256, 256, 3, 3)   |        589824 |         589824 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01231 | -0.00011 |    0.00960 |
+# | 45 | module.layer4.0.conv3.weight        | (2048, 256, 1, 1)  |        524288 |         524288 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01452 | -0.00058 |    0.01130 |
+# | 46 | module.layer4.0.downsample.0.weight | (2048, 1024, 1, 1) |       2097152 |        2097152 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.00903 | -0.00019 |    0.00688 |
+# | 47 | module.layer4.1.conv1.weight        | (256, 2048, 1, 1)  |        524288 |         524288 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01434 | -0.00029 |    0.01122 |
+# | 48 | module.layer4.1.conv2.weight        | (256, 256, 3, 3)   |        589824 |         589824 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01229 | -0.00059 |    0.00963 |
+# | 49 | module.layer4.1.conv3.weight        | (2048, 256, 1, 1)  |        524288 |         524288 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01432 |  0.00003 |    0.01107 |
+# | 50 | module.layer4.2.conv1.weight        | (256, 2048, 1, 1)  |        524288 |         524288 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01774 | -0.00010 |    0.01394 |
+# | 51 | module.layer4.2.conv2.weight        | (256, 256, 3, 3)   |        589824 |         589824 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01076 | -0.00036 |    0.00845 |
+# | 52 | module.layer4.2.conv3.weight        | (2048, 256, 1, 1)  |        524288 |         524288 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01319 |  0.00018 |    0.00993 |
+# | 53 | module.fc.weight                    | (1000, 2048)       |       2048000 |         614400 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |   70.00000 | 0.03188 |  0.00306 |    0.01495 |
+# | 54 | Total sparsity:                     | -                  |      12335296 |       10901696 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |   11.62193 | 0.00000 |  0.00000 |    0.00000 |
+# +----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
+# 2018-12-13 17:32:34,087 - Total sparsity: 11.62
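+#
+# Note: the "NNZ (dense)" column above reflects the already-thinned layer shapes, so the
+# table totals (12,335,296 dense / 10,901,696 non-zero) are smaller than the 25,502,912
+# weights of the unpruned ResNet50 quoted at the top.  Relative to that baseline,
+# 10,901,696 / 25,502,912 = ~42.7% of the weights remain, and the 11.62% reported here is
+# only the element-wise sparsity left inside the thinned model (1,433,600 zeros, all of
+# them in the 70%-sparse FC layer).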
+#
+# 2018-12-13 17:32:34,087 - --- validate (epoch=99)-----------
+# 2018-12-13 17:32:34,088 - 50000 samples (256 per mini-batch)
+# 2018-12-13 17:32:50,281 - Epoch: [99][   50/  195]    Loss 0.734106    Top1 80.734375    Top5 95.226562
+# 2018-12-13 17:32:57,293 - Epoch: [99][  100/  195]    Loss 0.858832    Top1 77.910156    Top5 93.960938
+# 2018-12-13 17:33:04,432 - Epoch: [99][  150/  195]    Loss 0.981425    Top1 75.471354    Top5 92.401042
+# 2018-12-13 17:33:10,060 - ==> Top1: 74.426    Top5: 91.962    Loss: 1.025
+#
+# 2018-12-13 17:33:10,194 - ==> Best Top1: 75.912 on Epoch: 0
+# 2018-12-13 17:33:10,194 - ==> Best Top1: 75.492 on Epoch: 1
+# 2018-12-13 17:33:10,194 - ==> Best Top1: 74.942 on Epoch: 2
+# 2018-12-13 17:33:10,195 - ==> Best Top1: 74.564 on Epoch: 84   <======
+# 2018-12-13 17:33:10,195 - ==> Best Top1: 74.514 on Epoch: 94
+# 2018-12-13 17:33:10,195 - ==> Best Top1: 74.494 on Epoch: 80
+# 2018-12-13 17:33:10,195 - ==> Best Top1: 74.488 on Epoch: 91
+# 2018-12-13 17:33:10,195 - ==> Best Top1: 74.482 on Epoch: 82
+# 2018-12-13 17:33:10,195 - ==> Best Top1: 74.482 on Epoch: 93
+# 2018-12-13 17:33:10,195 - ==> Best Top1: 74.472 on Epoch: 95
+# 2018-12-13 17:33:10,195 - Saving checkpoint to: logs/resnet50_filters_v4_with_FC___2018.12.11-172607/resnet50_filters_v4_with_FC_checkpoint.pth.tar
+# 2018-12-13 17:33:10,457 - --- test ---------------------
+# 2018-12-13 17:33:10,458 - 50000 samples (256 per mini-batch)
+# 2018-12-13 17:33:26,953 - Test: [   50/  195]    Loss 0.734106    Top1 80.734375    Top5 95.226562
+# 2018-12-13 17:33:33,762 - Test: [  100/  195]    Loss 0.858832    Top1 77.910156    Top5 93.960938
+# 2018-12-13 17:33:40,901 - Test: [  150/  195]    Loss 0.981425    Top1 75.471354    Top5 92.401042
+# 2018-12-13 17:33:47,076 - ==> Top1: 74.426    Top5: 91.962    Loss: 1.025
+
+
+version: 1
+
+pruners:
+  fc_pruner:
+    class: AutomatedGradualPruner
+    initial_sparsity: 0.05
+    final_sparsity: 0.70
+    weights: module.fc.weight
+
+
+  filter_pruner:
+    class: L1RankedStructureParameterPruner_AGP
+    initial_sparsity: 0.05
+    final_sparsity: 0.50
+    group_type: Filters
+    weights: [module.layer1.0.conv1.weight,
+              module.layer1.1.conv1.weight,
+              module.layer1.2.conv1.weight,
+              module.layer2.0.conv1.weight,
+              module.layer2.1.conv1.weight,
+              module.layer2.2.conv1.weight,
+              module.layer2.3.conv1.weight,
+              module.layer3.0.conv1.weight,
+              module.layer3.1.conv1.weight,
+              module.layer3.2.conv1.weight,
+              module.layer3.3.conv1.weight,
+              module.layer3.4.conv1.weight,
+              module.layer3.5.conv1.weight,
+              module.layer4.0.conv1.weight,
+              module.layer4.1.conv1.weight,
+              module.layer4.2.conv1.weight,
+
+
+              module.layer1.0.conv2.weight,
+              module.layer1.1.conv2.weight,
+              module.layer1.2.conv2.weight,
+              module.layer2.0.conv2.weight,
+              module.layer2.1.conv2.weight,
+              module.layer2.2.conv2.weight,
+              module.layer2.3.conv2.weight,
+              module.layer3.0.conv2.weight,
+              module.layer3.1.conv2.weight,
+              module.layer3.2.conv2.weight,
+              module.layer3.3.conv2.weight,
+              module.layer3.4.conv2.weight,
+              module.layer3.5.conv2.weight,
+              module.layer4.0.conv2.weight,
+              module.layer4.1.conv2.weight,
+              module.layer4.2.conv2.weight]
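+
+# The filter_pruner ranks each 3D filter (i.e. each output channel) of the listed conv
+# layers by the L1-norm of its weights and zeroes out the lowest-ranked filters, with AGP
+# deciding what fraction to prune at each epoch.  Only the conv1/conv2 weights of each
+# bottleneck are listed; conv3 and the downsample convs are left out (presumably because
+# their outputs feed the residual additions), and thinning merely shrinks conv3's input
+# channels to match the pruned conv2.  A minimal, hypothetical sketch of the ranking idea
+# (not Distiller's actual code):
+#
+#   import torch
+#
+#   def l1_filter_mask(weight, fraction_to_prune):
+#       # weight: (out_channels, in_channels, kH, kW)
+#       norms = weight.abs().sum(dim=(1, 2, 3))       # L1-norm per filter
+#       n_prune = int(fraction_to_prune * weight.size(0))
+#       mask = torch.ones_like(weight)
+#       if n_prune > 0:
+#           _, prune_idx = norms.topk(n_prune, largest=False)
+#           mask[prune_idx] = 0                       # zero entire filters
+#       return mask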
+
+  fine_pruner:
+    class: AutomatedGradualPruner
+    initial_sparsity: 0.05
+    final_sparsity: 0.70
+    weights: [
+      module.layer4.0.conv2.weight,
+      module.layer4.0.conv3.weight,
+      module.layer4.0.downsample.0.weight,
+      module.layer4.1.conv1.weight,
+      module.layer4.1.conv2.weight,
+      module.layer4.1.conv3.weight,
+      module.layer4.2.conv1.weight,
+      module.layer4.2.conv2.weight,
+      module.layer4.2.conv3.weight]
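+
+# Note: fine_pruner is defined but currently inactive: its policy is commented out in the
+# "policies" section below, so only filter_pruner and fc_pruner actually run.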
+
+extensions:
+  net_thinner:
+    class: 'FilterRemover'
+    thinning_func_str: remove_filters
+    arch: 'resnet50'
+    dataset: 'imagenet'
+
+lr_schedulers:
+  pruning_lr:
+    class: ExponentialLR
+    gamma: 0.95
+
+policies:
+#  - pruner:
+#     instance_name : fine_pruner
+#    starting_epoch: 0
+#    ending_epoch: 45
+#    frequency: 3
+
+  - pruner:
+      instance_name: filter_pruner
+#     args:
+#       mini_batch_pruning_frequency: 1
+    starting_epoch: 0
+    ending_epoch: 30
+    frequency: 1
+
+  - pruner:
+      instance_name: fc_pruner
+    starting_epoch: 0
+    ending_epoch: 30
+    frequency: 3
+
+# After completing the pruning, we perform network thinning and continue fine-tuning.
+  - extension:
+      instance_name: net_thinner
+    epochs: [31]
+
+
+  - lr_scheduler:
+      instance_name: pruning_lr
+    starting_epoch: 40
+    ending_epoch: 80
+    frequency: 1
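+
+# The ExponentialLR policy multiplies the learning rate by gamma=0.95 once per epoch
+# between epochs 40 and 80 (assuming the scheduler is stepped once per epoch, as
+# frequency: 1 suggests), so the initial LR of 0.0005 decays to roughly
+# 0.0005 * 0.95^40 ~= 6e-5 by the end of fine-tuning.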
-- 
GitLab