Skip to content
Snippets Groups Projects
Commit ff6985ad authored by Neta Zmora's avatar Neta Zmora
Browse files

AGP filter-pruning: a new ResNet20 schedule

Another schedule or ResNet20 Filter-wise pruning for ResNet20, with
64.6% sparsity, 25.4% compute reduction and Top1 91.47% (vs. 91.78
basline).
parent 961fcfdd
No related branches found
No related tags found
No related merge requests found
# This is a hybrid pruning schedule composed of several pruning techniques, all using AGP scheduling:
# 1. Filter pruning (and thinning) to reduce compute and activation sizes of some layers.
# 2. Fine grained pruning to reduce the parameter memory requirements of layers with large weights tensors.
# 3. Row pruning for the last linear (fully-connected) layer.
#
# Baseline results:
# Top1: 91.780 Top5: 99.710 Loss: 0.376
# Total MACs: 40,813,184
# # of parameters: 270,896
#
# Results:
# Top1: 91.470 on Epoch: 288
# Total MACs: 30,433,920 (74.6% of the original compute)
# Total sparsity: 56.41%
# # of parameters: 95922 (=35.4% of the baseline parameters ==> 64.6% sparsity)
#
# time python3 compress_classifier.py --arch resnet20_cifar ../../../data.cifar10 -p=50 --lr=0.4 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp_3.yaml -j=1 --deterministic --resume=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --validation-size=0
#
# Parameters:
# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# | | Name | Shape | NNZ (dense) | NNZ (sparse) | Cols (%) | Rows (%) | Ch (%) | 2D (%) | 3D (%) | Fine (%) | Std | Mean | Abs-Mean |
# |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
# | 0 | module.conv1.weight | (16, 3, 3, 3) | 432 | 432 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.41372 | -0.00535 | 0.29289 |
# | 1 | module.layer1.0.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.15610 | -0.01373 | 0.11096 |
# | 2 | module.layer1.0.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.15429 | 0.00180 | 0.11294 |
# | 3 | module.layer1.1.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.13297 | -0.01580 | 0.10052 |
# | 4 | module.layer1.1.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.12638 | -0.00556 | 0.09699 |
# | 5 | module.layer1.2.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.17940 | -0.01313 | 0.13183 |
# | 6 | module.layer1.2.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.14671 | -0.00056 | 0.11065 |
# | 7 | module.layer2.0.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.16872 | -0.00380 | 0.12838 |
# | 8 | module.layer2.0.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.18371 | 0.00119 | 0.14401 |
# | 9 | module.layer2.0.downsample.0.weight | (16, 16, 1, 1) | 256 | 256 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.33976 | 0.00148 | 0.24721 |
# | 10 | module.layer2.1.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.12741 | -0.00734 | 0.09754 |
# | 11 | module.layer2.1.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.10207 | 0.00286 | 0.07914 |
# | 12 | module.layer2.2.conv1.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.13480 | -0.00943 | 0.10174 |
# | 13 | module.layer2.2.conv2.weight | (16, 16, 3, 3) | 2304 | 2304 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.09721 | 0.00049 | 0.07094 |
# | 14 | module.layer3.0.conv1.weight | (64, 16, 3, 3) | 9216 | 4608 | 0.00000 | 0.00000 | 0.00000 | 2.63672 | 1.56250 | 50.00000 | 0.11758 | -0.00484 | 0.07093 |
# | 15 | module.layer3.0.conv2.weight | (64, 64, 3, 3) | 36864 | 18432 | 0.00000 | 0.00000 | 1.56250 | 2.00195 | 0.00000 | 50.00000 | 0.08720 | -0.00522 | 0.05143 |
# | 16 | module.layer3.0.downsample.0.weight | (64, 16, 1, 1) | 1024 | 1024 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.16003 | -0.01049 | 0.12534 |
# | 17 | module.layer3.1.conv1.weight | (63, 64, 3, 3) | 36288 | 10887 | 0.00000 | 0.00000 | 0.00000 | 9.20139 | 1.58730 | 69.99835 | 0.07613 | -0.00415 | 0.03605 |
# | 18 | module.layer3.1.conv2.weight | (64, 63, 3, 3) | 36288 | 10887 | 0.00000 | 0.00000 | 1.58730 | 9.10218 | 0.00000 | 69.99835 | 0.07025 | -0.00544 | 0.03305 |
# | 19 | module.layer3.2.conv1.weight | (62, 64, 3, 3) | 35712 | 10714 | 0.00000 | 0.00000 | 0.00000 | 13.33165 | 3.22581 | 69.99888 | 0.07118 | -0.00550 | 0.03367 |
# | 20 | module.layer3.2.conv2.weight | (64, 62, 3, 3) | 35712 | 10714 | 0.00000 | 0.00000 | 3.22581 | 28.80544 | 0.00000 | 69.99888 | 0.04353 | 0.00071 | 0.01894 |
# | 21 | module.fc.weight | (10, 64) | 640 | 320 | 0.00000 | 50.00000 | 0.00000 | 0.00000 | 0.00000 | 50.00000 | 0.57334 | -0.00001 | 0.35840 |
# | 22 | Total sparsity: | - | 220080 | 95922 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 56.41494 | 0.00000 | 0.00000 | 0.00000 |
# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# Total sparsity: 56.41
#
# --- validate (epoch=359)-----------
# 10000 samples (256 per mini-batch)
# ==> Top1: 91.140 Top5: 99.750 Loss: 0.331
#
# ==> Best Top1: 91.470 on Epoch: 288
# Saving checkpoint to: logs/2018.11.08-232134/checkpoint.pth.tar
# --- test ---------------------
# 10000 samples (256 per mini-batch)
# ==> Top1: 91.140 Top5: 99.750 Loss: 0.331
#
#
# Log file for this run: /home/cvds_lab/nzmora/sandbox_5/distiller/examples/classifier_compression/logs/2018.11.08-232134/2018.11.08-232134.log
#
# real 37m51.274s
# user 85m48.506s
# sys 10m35.410s
version: 1
pruners:
low_pruner:
class: L1RankedStructureParameterPruner_AGP
initial_sparsity : 0.10
final_sparsity: 0.50
reg_regims:
module.layer2.0.conv1.weight: Filters
module.layer2.0.conv2.weight: Filters
module.layer2.0.downsample.0.weight: Filters
module.layer2.1.conv2.weight: Filters
module.layer2.2.conv2.weight: Filters # to balance the BN
module.layer2.1.conv1.weight: Filters
module.layer2.2.conv1.weight: Filters
#module.layer3.0.conv2.weight: Filters
#module.layer3.0.downsample.0.weight: Filters
#module.layer3.1.conv2.weight: Filters
#module.layer3.2.conv2.weight: Filters
fine_pruner1:
class: AutomatedGradualPruner
initial_sparsity : 0.05
final_sparsity: 0.50
weights: [module.layer3.0.conv1.weight, module.layer3.0.conv2.weight]
fine_pruner2:
class: AutomatedGradualPruner
initial_sparsity : 0.05
final_sparsity: 0.70
weights: [module.layer3.1.conv1.weight, module.layer3.1.conv2.weight,
module.layer3.2.conv1.weight, module.layer3.2.conv2.weight]
fc_pruner:
class: L1RankedStructureParameterPruner_AGP
initial_sparsity : 0.05
final_sparsity: 0.50
reg_regims:
module.fc.weight: Rows
lr_schedulers:
pruning_lr:
class: StepLR
step_size: 50
gamma: 0.10
extensions:
net_thinner:
class: 'FilterRemover'
thinning_func_str: remove_filters
arch: 'resnet20_cifar'
dataset: 'cifar10'
policies:
- pruner:
instance_name : low_pruner
starting_epoch: 180
ending_epoch: 210
frequency: 2
- pruner:
instance_name : fine_pruner1
starting_epoch: 210
ending_epoch: 230
frequency: 2
- pruner:
instance_name : fine_pruner2
starting_epoch: 210
ending_epoch: 230
frequency: 2
- pruner:
instance_name : fc_pruner
starting_epoch: 210
ending_epoch: 230
frequency: 2
# Currently the thinner is disabled until the the structure pruner is done, because it interacts
# with the sparsity goals of the L1RankedStructureParameterPruner_AGP.
# This can be fixed rather easily.
# - extension:
# instance_name: net_thinner
# starting_epoch: 0
# ending_epoch: 20
# frequency: 2
# After completeing the pruning, we perform network thinning and continue fine-tuning.
- extension:
instance_name: net_thinner
epochs: [212]
- lr_scheduler:
instance_name: pruning_lr
starting_epoch: 180
ending_epoch: 400
frequency: 1
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment