From d89a3a07ebd1111aad080d2d9a4ce9713c0ccaf5 Mon Sep 17 00:00:00 2001 From: Neta Zmora <neta.zmora@intel.com> Date: Wed, 19 Dec 2018 15:52:28 +0200 Subject: [PATCH] AGP for filters: added FC element-wise pruning to filter-pruning In short: improves top1 results. Might be just due to random conditions. --- ...resnet50.schedule_agp.filters_with_FC.yaml | 207 ++++++++++++++++++ 1 file changed, 207 insertions(+) create mode 100755 examples/agp-pruning/resnet50.schedule_agp.filters_with_FC.yaml diff --git a/examples/agp-pruning/resnet50.schedule_agp.filters_with_FC.yaml b/examples/agp-pruning/resnet50.schedule_agp.filters_with_FC.yaml new file mode 100755 index 0000000..15cc334 --- /dev/null +++ b/examples/agp-pruning/resnet50.schedule_agp.filters_with_FC.yaml @@ -0,0 +1,207 @@ +# +# This schedule performs filter-pruning using L1-norm ranking and AGP for the setting the pruning-rate decay. +# The final Linear layer (FC) is also pruned to 70%. +# Curiously, we achieve slightly better Top1, when compared to the same schedule without the FC pruning. +# +# Best Top1: 74.564 (epoch 84) vs. 76.15 baseline (-1.6%) +# No. of Parameters: 10,901,696 (of 25,502,912) = 42.74% dense (57.26% sparse) +# Total MACs: 1,822,031,872 (of 4,089,184,256) = 44.56% compute = 2.24x +# time python3 compress_classifier.py -a=resnet50 --pretrained -p=50 ~/datasets/imagenet/ -j=22 --epochs=100 --lr=0.0005 --compress=resnet50.schedule_agp.filters.yaml --validation-size=0 --num-best-scores=10 +# +# Parameters: +# +----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+ +# | | Name | Shape | NNZ (dense) | NNZ (sparse) | Cols (%) | Rows (%) | Ch (%) | 2D (%) | 3D (%) | Fine (%) | Std | Mean | Abs-Mean | +# |----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------| +# | 0 | module.conv1.weight | (64, 3, 7, 7) | 9408 | 9408 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.11119 | -0.00042 | 0.06789 | +# | 1 | module.layer1.0.conv1.weight | (32, 64, 1, 1) | 2048 | 2048 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.07661 | -0.00610 | 0.04643 | +# | 2 | module.layer1.0.conv2.weight | (32, 32, 3, 3) | 9216 | 9216 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.04028 | 0.00160 | 0.02608 | +# | 3 | module.layer1.0.conv3.weight | (256, 32, 1, 1) | 8192 | 8192 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.03803 | -0.00044 | 0.02407 | +# | 4 | module.layer1.0.downsample.0.weight | (256, 64, 1, 1) | 16384 | 16384 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.05144 | -0.00304 | 0.02863 | +# | 5 | module.layer1.1.conv1.weight | (32, 256, 1, 1) | 8192 | 8192 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.03162 | 0.00100 | 0.02178 | +# | 6 | module.layer1.1.conv2.weight | (32, 32, 3, 3) | 9216 | 9216 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.03687 | 0.00027 | 0.02591 | +# | 7 | module.layer1.1.conv3.weight | (256, 32, 1, 1) | 8192 | 8192 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.03185 | -0.00048 | 0.02030 | +# | 8 | module.layer1.2.conv1.weight | (32, 256, 1, 1) | 8192 | 8192 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.03011 | 0.00019 | 0.02205 | +# | 9 | module.layer1.2.conv2.weight | (32, 32, 3, 3) | 9216 | 9216 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.03635 | -0.00009 | 0.02744 | +# | 10 | module.layer1.2.conv3.weight | (256, 32, 1, 1) | 8192 | 8192 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02803 | -0.00242 | 0.01684 | +# | 11 | module.layer2.0.conv1.weight | (64, 256, 1, 1) | 16384 | 16384 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.03596 | -0.00125 | 0.02536 | +# | 12 | module.layer2.0.conv2.weight | (64, 64, 3, 3) | 36864 | 36864 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02418 | 0.00002 | 0.01789 | +# | 13 | module.layer2.0.conv3.weight | (512, 64, 1, 1) | 32768 | 32768 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02696 | 0.00000 | 0.01652 | +# | 14 | module.layer2.0.downsample.0.weight | (512, 256, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02060 | -0.00044 | 0.01214 | +# | 15 | module.layer2.1.conv1.weight | (64, 512, 1, 1) | 32768 | 32768 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01739 | -0.00021 | 0.01075 | +# | 16 | module.layer2.1.conv2.weight | (64, 64, 3, 3) | 36864 | 36864 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02545 | 0.00070 | 0.01662 | +# | 17 | module.layer2.1.conv3.weight | (512, 64, 1, 1) | 32768 | 32768 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02249 | -0.00146 | 0.01323 | +# | 18 | module.layer2.2.conv1.weight | (64, 512, 1, 1) | 32768 | 32768 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02338 | -0.00056 | 0.01624 | +# | 19 | module.layer2.2.conv2.weight | (64, 64, 3, 3) | 36864 | 36864 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02410 | 0.00015 | 0.01685 | +# | 20 | module.layer2.2.conv3.weight | (512, 64, 1, 1) | 32768 | 32768 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02576 | 0.00017 | 0.01794 | +# | 21 | module.layer2.3.conv1.weight | (64, 512, 1, 1) | 32768 | 32768 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02341 | -0.00082 | 0.01743 | +# | 22 | module.layer2.3.conv2.weight | (64, 64, 3, 3) | 36864 | 36864 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02380 | -0.00048 | 0.01804 | +# | 23 | module.layer2.3.conv3.weight | (512, 64, 1, 1) | 32768 | 32768 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02311 | -0.00123 | 0.01596 | +# | 24 | module.layer3.0.conv1.weight | (128, 512, 1, 1) | 65536 | 65536 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02988 | -0.00090 | 0.02166 | +# | 25 | module.layer3.0.conv2.weight | (128, 128, 3, 3) | 147456 | 147456 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01793 | -0.00014 | 0.01335 | +# | 26 | module.layer3.0.conv3.weight | (1024, 128, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02273 | -0.00044 | 0.01634 | +# | 27 | module.layer3.0.downsample.0.weight | (1024, 512, 1, 1) | 524288 | 524288 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01441 | -0.00002 | 0.00988 | +# | 28 | module.layer3.1.conv1.weight | (128, 1024, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01521 | -0.00035 | 0.01075 | +# | 29 | module.layer3.1.conv2.weight | (128, 128, 3, 3) | 147456 | 147456 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01686 | -0.00003 | 0.01215 | +# | 30 | module.layer3.1.conv3.weight | (1024, 128, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01961 | -0.00062 | 0.01396 | +# | 31 | module.layer3.2.conv1.weight | (128, 1024, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01551 | -0.00032 | 0.01105 | +# | 32 | module.layer3.2.conv2.weight | (128, 128, 3, 3) | 147456 | 147456 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01649 | -0.00058 | 0.01217 | +# | 33 | module.layer3.2.conv3.weight | (1024, 128, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01839 | -0.00051 | 0.01337 | +# | 34 | module.layer3.3.conv1.weight | (128, 1024, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01679 | -0.00058 | 0.01252 | +# | 35 | module.layer3.3.conv2.weight | (128, 128, 3, 3) | 147456 | 147456 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01554 | -0.00052 | 0.01180 | +# | 36 | module.layer3.3.conv3.weight | (1024, 128, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01745 | -0.00095 | 0.01283 | +# | 37 | module.layer3.4.conv1.weight | (128, 1024, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01739 | -0.00078 | 0.01312 | +# | 38 | module.layer3.4.conv2.weight | (128, 128, 3, 3) | 147456 | 147456 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01537 | -0.00063 | 0.01168 | +# | 39 | module.layer3.4.conv3.weight | (1024, 128, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01709 | -0.00124 | 0.01253 | +# | 40 | module.layer3.5.conv1.weight | (128, 1024, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01870 | -0.00072 | 0.01434 | +# | 41 | module.layer3.5.conv2.weight | (128, 128, 3, 3) | 147456 | 147456 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01530 | -0.00072 | 0.01172 | +# | 42 | module.layer3.5.conv3.weight | (1024, 128, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01855 | -0.00212 | 0.01395 | +# | 43 | module.layer4.0.conv1.weight | (256, 1024, 1, 1) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02219 | -0.00086 | 0.01714 | +# | 44 | module.layer4.0.conv2.weight | (256, 256, 3, 3) | 589824 | 589824 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01231 | -0.00011 | 0.00960 | +# | 45 | module.layer4.0.conv3.weight | (2048, 256, 1, 1) | 524288 | 524288 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01452 | -0.00058 | 0.01130 | +# | 46 | module.layer4.0.downsample.0.weight | (2048, 1024, 1, 1) | 2097152 | 2097152 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00903 | -0.00019 | 0.00688 | +# | 47 | module.layer4.1.conv1.weight | (256, 2048, 1, 1) | 524288 | 524288 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01434 | -0.00029 | 0.01122 | +# | 48 | module.layer4.1.conv2.weight | (256, 256, 3, 3) | 589824 | 589824 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01229 | -0.00059 | 0.00963 | +# | 49 | module.layer4.1.conv3.weight | (2048, 256, 1, 1) | 524288 | 524288 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01432 | 0.00003 | 0.01107 | +# | 50 | module.layer4.2.conv1.weight | (256, 2048, 1, 1) | 524288 | 524288 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01774 | -0.00010 | 0.01394 | +# | 51 | module.layer4.2.conv2.weight | (256, 256, 3, 3) | 589824 | 589824 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01076 | -0.00036 | 0.00845 | +# | 52 | module.layer4.2.conv3.weight | (2048, 256, 1, 1) | 524288 | 524288 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01319 | 0.00018 | 0.00993 | +# | 53 | module.fc.weight | (1000, 2048) | 2048000 | 614400 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 70.00000 | 0.03188 | 0.00306 | 0.01495 | +# | 54 | Total sparsity: | - | 12335296 | 10901696 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 11.62193 | 0.00000 | 0.00000 | 0.00000 | +# +----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+ +# 2018-12-13 17:32:34,087 - Total sparsity: 11.62 +# +# 2018-12-13 17:32:34,087 - --- validate (epoch=99)----------- +# 2018-12-13 17:32:34,088 - 50000 samples (256 per mini-batch) +# 2018-12-13 17:32:50,281 - Epoch: [99][ 50/ 195] Loss 0.734106 Top1 80.734375 Top5 95.226562 +# 2018-12-13 17:32:57,293 - Epoch: [99][ 100/ 195] Loss 0.858832 Top1 77.910156 Top5 93.960938 +# 2018-12-13 17:33:04,432 - Epoch: [99][ 150/ 195] Loss 0.981425 Top1 75.471354 Top5 92.401042 +# 2018-12-13 17:33:10,060 - ==> Top1: 74.426 Top5: 91.962 Loss: 1.025 +# +# 2018-12-13 17:33:10,194 - ==> Best Top1: 75.912 on Epoch: 0 +# 2018-12-13 17:33:10,194 - ==> Best Top1: 75.492 on Epoch: 1 +# 2018-12-13 17:33:10,194 - ==> Best Top1: 74.942 on Epoch: 2 +# 2018-12-13 17:33:10,195 - ==> Best Top1: 74.564 on Epoch: 84 <====== +# 2018-12-13 17:33:10,195 - ==> Best Top1: 74.514 on Epoch: 94 +# 2018-12-13 17:33:10,195 - ==> Best Top1: 74.494 on Epoch: 80 +# 2018-12-13 17:33:10,195 - ==> Best Top1: 74.488 on Epoch: 91 +# 2018-12-13 17:33:10,195 - ==> Best Top1: 74.482 on Epoch: 82 +# 2018-12-13 17:33:10,195 - ==> Best Top1: 74.482 on Epoch: 93 +# 2018-12-13 17:33:10,195 - ==> Best Top1: 74.472 on Epoch: 95 +# 2018-12-13 17:33:10,195 - Saving checkpoint to: logs/resnet50_filters_v4_with_FC___2018.12.11-172607/resnet50_filters_v4_with_FC_checkpoint.pth.tar +# 2018-12-13 17:33:10,457 - --- test --------------------- +# 2018-12-13 17:33:10,458 - 50000 samples (256 per mini-batch) +# 2018-12-13 17:33:26,953 - Test: [ 50/ 195] Loss 0.734106 Top1 80.734375 Top5 95.226562 +# 2018-12-13 17:33:33,762 - Test: [ 100/ 195] Loss 0.858832 Top1 77.910156 Top5 93.960938 +# 2018-12-13 17:33:40,901 - Test: [ 150/ 195] Loss 0.981425 Top1 75.471354 Top5 92.401042 +# 2018-12-13 17:33:47,076 - ==> Top1: 74.426 Top5: 91.962 Loss: 1.025 + + +version: 1 + +pruners: + fc_pruner: + class: AutomatedGradualPruner + initial_sparsity : 0.05 + final_sparsity: 0.70 + weights: module.fc.weight + + + filter_pruner: + class: L1RankedStructureParameterPruner_AGP + initial_sparsity : 0.05 + final_sparsity: 0.50 + group_type: Filters + weights: [module.layer1.0.conv1.weight, + module.layer1.1.conv1.weight, + module.layer1.2.conv1.weight, + module.layer2.0.conv1.weight, + module.layer2.1.conv1.weight, + module.layer2.2.conv1.weight, + module.layer2.3.conv1.weight, + module.layer3.0.conv1.weight, + module.layer3.1.conv1.weight, + module.layer3.2.conv1.weight, + module.layer3.3.conv1.weight, + module.layer3.4.conv1.weight, + module.layer3.5.conv1.weight, + module.layer4.0.conv1.weight, + module.layer4.1.conv1.weight, + module.layer4.2.conv1.weight, + + + module.layer1.0.conv2.weight, + module.layer1.1.conv2.weight, + module.layer1.2.conv2.weight, + module.layer2.0.conv2.weight, + module.layer2.1.conv2.weight, + module.layer2.2.conv2.weight, + module.layer2.3.conv2.weight, + module.layer3.0.conv2.weight, + module.layer3.1.conv2.weight, + module.layer3.2.conv2.weight, + module.layer3.3.conv2.weight, + module.layer3.4.conv2.weight, + module.layer3.5.conv2.weight, + module.layer4.0.conv2.weight, + module.layer4.1.conv2.weight, + module.layer4.2.conv2.weight] + + fine_pruner: + class: AutomatedGradualPruner + initial_sparsity : 0.05 + final_sparsity: 0.70 + weights: [ + module.layer4.0.conv2.weight, + module.layer4.0.conv3.weight, + module.layer4.0.downsample.0.weight, + module.layer4.1.conv1.weight, + module.layer4.1.conv2.weight, + module.layer4.1.conv3.weight, + module.layer4.2.conv1.weight, + module.layer4.2.conv2.weight, + module.layer4.2.conv3.weight] + +extensions: + net_thinner: + class: 'FilterRemover' + thinning_func_str: remove_filters + arch: 'resnet50' + dataset: 'imagenet' + +lr_schedulers: + pruning_lr: + class: ExponentialLR + gamma: 0.95 + +policies: +# - pruner: +# instance_name : fine_pruner +# starting_epoch: 0 +# ending_epoch: 45 +# frequency: 3 + + - pruner: + instance_name : filter_pruner +# args: +# mini_batch_pruning_frequency: 1 + starting_epoch: 0 + ending_epoch: 30 + frequency: 1 + + - pruner: + instance_name : fc_pruner + starting_epoch: 0 + ending_epoch: 30 + frequency: 3 + +# After completeing the pruning, we perform network thinning and continue fine-tuning. + - extension: + instance_name: net_thinner + epochs: [31] + + + - lr_scheduler: + instance_name: pruning_lr + starting_epoch: 40 + ending_epoch: 80 + frequency: 1 -- GitLab