Skip to content
Snippets Groups Projects
Commit d89a3a07 authored by Neta Zmora's avatar Neta Zmora
Browse files

AGP for filters: added FC element-wise pruning to filter-pruning

In short: improves top1 results.  Might be just due to random conditions.
parent 73c3a7a7
No related branches found
No related tags found
No related merge requests found
#
# This schedule performs filter-pruning using L1-norm ranking and AGP for the setting the pruning-rate decay.
# The final Linear layer (FC) is also pruned to 70%.
# Curiously, we achieve slightly better Top1, when compared to the same schedule without the FC pruning.
#
# Best Top1: 74.564 (epoch 84) vs. 76.15 baseline (-1.6%)
# No. of Parameters: 10,901,696 (of 25,502,912) = 42.74% dense (57.26% sparse)
# Total MACs: 1,822,031,872 (of 4,089,184,256) = 44.56% compute = 2.24x
# time python3 compress_classifier.py -a=resnet50 --pretrained -p=50 ~/datasets/imagenet/ -j=22 --epochs=100 --lr=0.0005 --compress=resnet50.schedule_agp.filters.yaml --validation-size=0 --num-best-scores=10
#
# Parameters:
# +----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# | | Name | Shape | NNZ (dense) | NNZ (sparse) | Cols (%) | Rows (%) | Ch (%) | 2D (%) | 3D (%) | Fine (%) | Std | Mean | Abs-Mean |
# |----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
# | 0 | module.conv1.weight | (64, 3, 7, 7) | 9408 | 9408 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.11119 | -0.00042 | 0.06789 |
# | 1 | module.layer1.0.conv1.weight | (32, 64, 1, 1) | 2048 | 2048 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.07661 | -0.00610 | 0.04643 |
# | 2 | module.layer1.0.conv2.weight | (32, 32, 3, 3) | 9216 | 9216 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.04028 | 0.00160 | 0.02608 |
# | 3 | module.layer1.0.conv3.weight | (256, 32, 1, 1) | 8192 | 8192 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.03803 | -0.00044 | 0.02407 |
# | 4 | module.layer1.0.downsample.0.weight | (256, 64, 1, 1) | 16384 | 16384 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.05144 | -0.00304 | 0.02863 |
# | 5 | module.layer1.1.conv1.weight | (32, 256, 1, 1) | 8192 | 8192 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.03162 | 0.00100 | 0.02178 |
# | 6 | module.layer1.1.conv2.weight | (32, 32, 3, 3) | 9216 | 9216 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.03687 | 0.00027 | 0.02591 |
# | 7 | module.layer1.1.conv3.weight | (256, 32, 1, 1) | 8192 | 8192 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.03185 | -0.00048 | 0.02030 |
# | 8 | module.layer1.2.conv1.weight | (32, 256, 1, 1) | 8192 | 8192 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.03011 | 0.00019 | 0.02205 |
# | 9 | module.layer1.2.conv2.weight | (32, 32, 3, 3) | 9216 | 9216 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.03635 | -0.00009 | 0.02744 |
# | 10 | module.layer1.2.conv3.weight | (256, 32, 1, 1) | 8192 | 8192 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02803 | -0.00242 | 0.01684 |
# | 11 | module.layer2.0.conv1.weight | (64, 256, 1, 1) | 16384 | 16384 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.03596 | -0.00125 | 0.02536 |
# | 12 | module.layer2.0.conv2.weight | (64, 64, 3, 3) | 36864 | 36864 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02418 | 0.00002 | 0.01789 |
# | 13 | module.layer2.0.conv3.weight | (512, 64, 1, 1) | 32768 | 32768 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02696 | 0.00000 | 0.01652 |
# | 14 | module.layer2.0.downsample.0.weight | (512, 256, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02060 | -0.00044 | 0.01214 |
# | 15 | module.layer2.1.conv1.weight | (64, 512, 1, 1) | 32768 | 32768 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01739 | -0.00021 | 0.01075 |
# | 16 | module.layer2.1.conv2.weight | (64, 64, 3, 3) | 36864 | 36864 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02545 | 0.00070 | 0.01662 |
# | 17 | module.layer2.1.conv3.weight | (512, 64, 1, 1) | 32768 | 32768 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02249 | -0.00146 | 0.01323 |
# | 18 | module.layer2.2.conv1.weight | (64, 512, 1, 1) | 32768 | 32768 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02338 | -0.00056 | 0.01624 |
# | 19 | module.layer2.2.conv2.weight | (64, 64, 3, 3) | 36864 | 36864 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02410 | 0.00015 | 0.01685 |
# | 20 | module.layer2.2.conv3.weight | (512, 64, 1, 1) | 32768 | 32768 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02576 | 0.00017 | 0.01794 |
# | 21 | module.layer2.3.conv1.weight | (64, 512, 1, 1) | 32768 | 32768 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02341 | -0.00082 | 0.01743 |
# | 22 | module.layer2.3.conv2.weight | (64, 64, 3, 3) | 36864 | 36864 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02380 | -0.00048 | 0.01804 |
# | 23 | module.layer2.3.conv3.weight | (512, 64, 1, 1) | 32768 | 32768 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02311 | -0.00123 | 0.01596 |
# | 24 | module.layer3.0.conv1.weight | (128, 512, 1, 1) | 65536 | 65536 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02988 | -0.00090 | 0.02166 |
# | 25 | module.layer3.0.conv2.weight | (128, 128, 3, 3) | 147456 | 147456 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01793 | -0.00014 | 0.01335 |
# | 26 | module.layer3.0.conv3.weight | (1024, 128, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02273 | -0.00044 | 0.01634 |
# | 27 | module.layer3.0.downsample.0.weight | (1024, 512, 1, 1) | 524288 | 524288 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01441 | -0.00002 | 0.00988 |
# | 28 | module.layer3.1.conv1.weight | (128, 1024, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01521 | -0.00035 | 0.01075 |
# | 29 | module.layer3.1.conv2.weight | (128, 128, 3, 3) | 147456 | 147456 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01686 | -0.00003 | 0.01215 |
# | 30 | module.layer3.1.conv3.weight | (1024, 128, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01961 | -0.00062 | 0.01396 |
# | 31 | module.layer3.2.conv1.weight | (128, 1024, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01551 | -0.00032 | 0.01105 |
# | 32 | module.layer3.2.conv2.weight | (128, 128, 3, 3) | 147456 | 147456 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01649 | -0.00058 | 0.01217 |
# | 33 | module.layer3.2.conv3.weight | (1024, 128, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01839 | -0.00051 | 0.01337 |
# | 34 | module.layer3.3.conv1.weight | (128, 1024, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01679 | -0.00058 | 0.01252 |
# | 35 | module.layer3.3.conv2.weight | (128, 128, 3, 3) | 147456 | 147456 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01554 | -0.00052 | 0.01180 |
# | 36 | module.layer3.3.conv3.weight | (1024, 128, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01745 | -0.00095 | 0.01283 |
# | 37 | module.layer3.4.conv1.weight | (128, 1024, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01739 | -0.00078 | 0.01312 |
# | 38 | module.layer3.4.conv2.weight | (128, 128, 3, 3) | 147456 | 147456 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01537 | -0.00063 | 0.01168 |
# | 39 | module.layer3.4.conv3.weight | (1024, 128, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01709 | -0.00124 | 0.01253 |
# | 40 | module.layer3.5.conv1.weight | (128, 1024, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01870 | -0.00072 | 0.01434 |
# | 41 | module.layer3.5.conv2.weight | (128, 128, 3, 3) | 147456 | 147456 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01530 | -0.00072 | 0.01172 |
# | 42 | module.layer3.5.conv3.weight | (1024, 128, 1, 1) | 131072 | 131072 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01855 | -0.00212 | 0.01395 |
# | 43 | module.layer4.0.conv1.weight | (256, 1024, 1, 1) | 262144 | 262144 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02219 | -0.00086 | 0.01714 |
# | 44 | module.layer4.0.conv2.weight | (256, 256, 3, 3) | 589824 | 589824 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01231 | -0.00011 | 0.00960 |
# | 45 | module.layer4.0.conv3.weight | (2048, 256, 1, 1) | 524288 | 524288 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01452 | -0.00058 | 0.01130 |
# | 46 | module.layer4.0.downsample.0.weight | (2048, 1024, 1, 1) | 2097152 | 2097152 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00903 | -0.00019 | 0.00688 |
# | 47 | module.layer4.1.conv1.weight | (256, 2048, 1, 1) | 524288 | 524288 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01434 | -0.00029 | 0.01122 |
# | 48 | module.layer4.1.conv2.weight | (256, 256, 3, 3) | 589824 | 589824 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01229 | -0.00059 | 0.00963 |
# | 49 | module.layer4.1.conv3.weight | (2048, 256, 1, 1) | 524288 | 524288 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01432 | 0.00003 | 0.01107 |
# | 50 | module.layer4.2.conv1.weight | (256, 2048, 1, 1) | 524288 | 524288 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01774 | -0.00010 | 0.01394 |
# | 51 | module.layer4.2.conv2.weight | (256, 256, 3, 3) | 589824 | 589824 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01076 | -0.00036 | 0.00845 |
# | 52 | module.layer4.2.conv3.weight | (2048, 256, 1, 1) | 524288 | 524288 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01319 | 0.00018 | 0.00993 |
# | 53 | module.fc.weight | (1000, 2048) | 2048000 | 614400 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 70.00000 | 0.03188 | 0.00306 | 0.01495 |
# | 54 | Total sparsity: | - | 12335296 | 10901696 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 11.62193 | 0.00000 | 0.00000 | 0.00000 |
# +----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# 2018-12-13 17:32:34,087 - Total sparsity: 11.62
#
# 2018-12-13 17:32:34,087 - --- validate (epoch=99)-----------
# 2018-12-13 17:32:34,088 - 50000 samples (256 per mini-batch)
# 2018-12-13 17:32:50,281 - Epoch: [99][ 50/ 195] Loss 0.734106 Top1 80.734375 Top5 95.226562
# 2018-12-13 17:32:57,293 - Epoch: [99][ 100/ 195] Loss 0.858832 Top1 77.910156 Top5 93.960938
# 2018-12-13 17:33:04,432 - Epoch: [99][ 150/ 195] Loss 0.981425 Top1 75.471354 Top5 92.401042
# 2018-12-13 17:33:10,060 - ==> Top1: 74.426 Top5: 91.962 Loss: 1.025
#
# 2018-12-13 17:33:10,194 - ==> Best Top1: 75.912 on Epoch: 0
# 2018-12-13 17:33:10,194 - ==> Best Top1: 75.492 on Epoch: 1
# 2018-12-13 17:33:10,194 - ==> Best Top1: 74.942 on Epoch: 2
# 2018-12-13 17:33:10,195 - ==> Best Top1: 74.564 on Epoch: 84 <======
# 2018-12-13 17:33:10,195 - ==> Best Top1: 74.514 on Epoch: 94
# 2018-12-13 17:33:10,195 - ==> Best Top1: 74.494 on Epoch: 80
# 2018-12-13 17:33:10,195 - ==> Best Top1: 74.488 on Epoch: 91
# 2018-12-13 17:33:10,195 - ==> Best Top1: 74.482 on Epoch: 82
# 2018-12-13 17:33:10,195 - ==> Best Top1: 74.482 on Epoch: 93
# 2018-12-13 17:33:10,195 - ==> Best Top1: 74.472 on Epoch: 95
# 2018-12-13 17:33:10,195 - Saving checkpoint to: logs/resnet50_filters_v4_with_FC___2018.12.11-172607/resnet50_filters_v4_with_FC_checkpoint.pth.tar
# 2018-12-13 17:33:10,457 - --- test ---------------------
# 2018-12-13 17:33:10,458 - 50000 samples (256 per mini-batch)
# 2018-12-13 17:33:26,953 - Test: [ 50/ 195] Loss 0.734106 Top1 80.734375 Top5 95.226562
# 2018-12-13 17:33:33,762 - Test: [ 100/ 195] Loss 0.858832 Top1 77.910156 Top5 93.960938
# 2018-12-13 17:33:40,901 - Test: [ 150/ 195] Loss 0.981425 Top1 75.471354 Top5 92.401042
# 2018-12-13 17:33:47,076 - ==> Top1: 74.426 Top5: 91.962 Loss: 1.025
version: 1
pruners:
fc_pruner:
class: AutomatedGradualPruner
initial_sparsity : 0.05
final_sparsity: 0.70
weights: module.fc.weight
filter_pruner:
class: L1RankedStructureParameterPruner_AGP
initial_sparsity : 0.05
final_sparsity: 0.50
group_type: Filters
weights: [module.layer1.0.conv1.weight,
module.layer1.1.conv1.weight,
module.layer1.2.conv1.weight,
module.layer2.0.conv1.weight,
module.layer2.1.conv1.weight,
module.layer2.2.conv1.weight,
module.layer2.3.conv1.weight,
module.layer3.0.conv1.weight,
module.layer3.1.conv1.weight,
module.layer3.2.conv1.weight,
module.layer3.3.conv1.weight,
module.layer3.4.conv1.weight,
module.layer3.5.conv1.weight,
module.layer4.0.conv1.weight,
module.layer4.1.conv1.weight,
module.layer4.2.conv1.weight,
module.layer1.0.conv2.weight,
module.layer1.1.conv2.weight,
module.layer1.2.conv2.weight,
module.layer2.0.conv2.weight,
module.layer2.1.conv2.weight,
module.layer2.2.conv2.weight,
module.layer2.3.conv2.weight,
module.layer3.0.conv2.weight,
module.layer3.1.conv2.weight,
module.layer3.2.conv2.weight,
module.layer3.3.conv2.weight,
module.layer3.4.conv2.weight,
module.layer3.5.conv2.weight,
module.layer4.0.conv2.weight,
module.layer4.1.conv2.weight,
module.layer4.2.conv2.weight]
fine_pruner:
class: AutomatedGradualPruner
initial_sparsity : 0.05
final_sparsity: 0.70
weights: [
module.layer4.0.conv2.weight,
module.layer4.0.conv3.weight,
module.layer4.0.downsample.0.weight,
module.layer4.1.conv1.weight,
module.layer4.1.conv2.weight,
module.layer4.1.conv3.weight,
module.layer4.2.conv1.weight,
module.layer4.2.conv2.weight,
module.layer4.2.conv3.weight]
extensions:
net_thinner:
class: 'FilterRemover'
thinning_func_str: remove_filters
arch: 'resnet50'
dataset: 'imagenet'
lr_schedulers:
pruning_lr:
class: ExponentialLR
gamma: 0.95
policies:
# - pruner:
# instance_name : fine_pruner
# starting_epoch: 0
# ending_epoch: 45
# frequency: 3
- pruner:
instance_name : filter_pruner
# args:
# mini_batch_pruning_frequency: 1
starting_epoch: 0
ending_epoch: 30
frequency: 1
- pruner:
instance_name : fc_pruner
starting_epoch: 0
ending_epoch: 30
frequency: 3
# After completeing the pruning, we perform network thinning and continue fine-tuning.
- extension:
instance_name: net_thinner
epochs: [31]
- lr_scheduler:
instance_name: pruning_lr
starting_epoch: 40
ending_epoch: 80
frequency: 1
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment