diff --git a/examples/ssl/vgg16_cifar_ssl_channels_training.yaml b/examples/ssl/vgg16_cifar_ssl_channels_training.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9b7155bfab9b1355b207b66230c7744ca03e7c2b
--- /dev/null
+++ b/examples/ssl/vgg16_cifar_ssl_channels_training.yaml
@@ -0,0 +1,92 @@
+# SSL: Channel regularization
+# We compressed the compute from 313M MACs to 85M MACs (x3.7), and the parameters from 14.7M to 725K (x20) using SSL
+# with channel-wise regularization, with a drop of 0.19% Top1 accuracy and without much effort.
+#
+# time python3 compress_classifier.py --arch vgg16_cifar ../../../data.cifar10 -p=50 --lr=0.05 --epochs=180 --compress=../ssl/vgg16_cifar_ssl_channels_training.yaml -j=1 --deterministic
+#
+# The results below are from the SSL training session, and you can follow-up with some fine-tuning:
+# time python3 compress_classifier.py --arch vgg16_cifar ../../../data.cifar10 --resume=checkpoint.vgg16_cifar.pth.tar --lr=0.01 --epochs=20
+# ==> Top1: 91.010 Top5: 99.480 Loss: 0.513
+#
+# Parameters:
+# +----------+---------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
+# | | Name | Shape | NNZ (dense) | NNZ (sparse) | Cols (%) | Rows (%) | Ch (%) | 2D (%) | 3D (%) | Fine (%) | Std | Mean | Abs-Mean |
+# |----------+---------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
+# | 0.00000 | features.module.0.weight | (31, 3, 3, 3) | 837 | 837 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.32376 | -0.00329 | 0.23517 |
+# | 1.00000 | features.module.2.weight | (47, 31, 3, 3) | 13113 | 13113 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.07184 | -0.00374 | 0.04210 |
+# | 2.00000 | features.module.5.weight | (98, 47, 3, 3) | 41454 | 41454 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.04835 | -0.00384 | 0.03511 |
+# | 3.00000 | features.module.7.weight | (117, 98, 3, 3) | 103194 | 103194 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.03096 | -0.00500 | 0.02305 |
+# | 4.00000 | features.module.10.weight | (193, 117, 3, 3) | 203229 | 203229 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02948 | -0.00345 | 0.02259 |
+# | 5.00000 | features.module.12.weight | (164, 193, 3, 3) | 284868 | 284868 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01766 | -0.00233 | 0.01313 |
+# | 6.00000 | features.module.14.weight | (24, 164, 3, 3) | 35424 | 35424 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.01755 | -0.00082 | 0.01235 |
+# | 7.00000 | features.module.17.weight | (15, 24, 3, 3) | 3240 | 3240 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.04479 | 0.00043 | 0.03239 |
+# | 8.00000 | features.module.19.weight | (9, 15, 3, 3) | 1215 | 1215 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.08220 | 0.00163 | 0.06065 |
+# | 9.00000 | features.module.21.weight | (7, 9, 3, 3) | 567 | 567 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.12944 | 0.00961 | 0.09122 |
+# | 10.00000 | features.module.24.weight | (5, 7, 3, 3) | 315 | 315 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.17881 | 0.02638 | 0.12776 |
+# | 11.00000 | features.module.26.weight | (7, 5, 3, 3) | 315 | 315 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.18656 | 0.03477 | 0.12432 |
+# | 12.00000 | features.module.28.weight | (512, 7, 3, 3) | 32256 | 32256 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.02310 | 0.00009 | 0.01329 |
+# | 13.00000 | classifier.weight | (10, 512) | 5120 | 5120 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.10157 | -0.00002 | 0.07181 |
+# | 14.00000 | Total sparsity: | - | 725147 | 725147 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 |
+# +----------+---------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
+#
+# Total sparsity: 0.00
+# --- test ---------------------
+# 10000 samples (256 per mini-batch)
+# Test: [ 10/ 39] Loss 0.454324 Top1 90.664062 Top5 99.609375
+# Test: [ 20/ 39] Loss 0.450643 Top1 90.722656 Top5 99.511719
+# Test: [ 30/ 39] Loss 0.441285 Top1 90.807292 Top5 99.557292
+# Test: [ 40/ 39] Loss 0.458055 Top1 90.740000 Top5 99.580000
+# ==> Top1: 90.740 Top5: 99.580 Loss: 0.458
+
+lr_schedulers:
+  training_lr:
+    class: StepLR
+    step_size: 45
+    gamma: 0.10
+
+regularizers:
+  Channels_groups_regularizer:
+    class: GroupLassoRegularizer
+    reg_regims:
+      features.module.0.weight: [0.0008, Channels]
+      features.module.2.weight: [0.0008, Channels]
+      features.module.5.weight: [0.0008, Channels]
+      features.module.7.weight: [0.0008, Channels]
+      features.module.10.weight: [0.0008, Channels]
+      features.module.12.weight: [0.0008, Channels]
+      features.module.14.weight: [0.0008, Channels]
+      features.module.17.weight: [0.0008, Channels]
+      features.module.19.weight: [0.0008, Channels]
+      features.module.21.weight: [0.0008, Channels]
+      features.module.24.weight: [0.0008, Channels]
+      features.module.26.weight: [0.0008, Channels]
+      features.module.28.weight: [0.0008, Channels]
+
+    threshold_criteria: Mean_Abs
+
+extensions:
+  net_thinner:
+    class: ChannelRemover
+    thinning_func_str: remove_channels
+    arch: vgg16_cifar
+    dataset: cifar10
+
+policies:
+  - lr_scheduler:
+      instance_name: training_lr
+    starting_epoch: 45
+    ending_epoch: 300
+    frequency: 1
+
+# After completing the regularization, we perform network thinning and exit.
+  - extension:
+      instance_name: net_thinner
+    epochs: [179]
+
+  - regularizer:
+      instance_name: Channels_groups_regularizer
+      args:
+        keep_mask: true
+    starting_epoch: 0
+    ending_epoch: 180
+    frequency: 1