From 78cdad0730a8c3850bdb3f0e045b3b2f390b2576 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Tue, 30 Oct 2018 16:38:18 +0200
Subject: [PATCH] ResNet20-Cifar: improved the results of L1 filter pruning

Small improvement in the results
---
 .../resnet20_filters.schedule_agp.yaml        | 105 ++++++++++--------
 1 file changed, 59 insertions(+), 46 deletions(-)

diff --git a/examples/agp-pruning/resnet20_filters.schedule_agp.yaml b/examples/agp-pruning/resnet20_filters.schedule_agp.yaml
index 406180e..88becf2 100755
--- a/examples/agp-pruning/resnet20_filters.schedule_agp.yaml
+++ b/examples/agp-pruning/resnet20_filters.schedule_agp.yaml
@@ -6,62 +6,63 @@
 # Baseline results:
 #     Top1: 91.780    Top5: 99.710    Loss: 0.376
 #     Total MACs: 40,813,184
+#     # of parameters: 270,896
 #
 # Results:
-#     Top1: 91.760    Top5: 99.700    Loss: 1.546
-#     Total MACs: 35,947,136
+#     Top1: 91.73
+#     Total MACs: 30,655,104
 #     Total sparsity: 41.10
+#     # of parameters: 120,000  (=55.7% of the baseline parameters)
 #
-# time python3 compress_classifier.py --arch resnet20_cifar  ../../../data.cifar10 -p=50 --lr=0.1 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp.yaml -j=1 --deterministic --resume=../ssl/checkpoints/checkpoint_trained_dense.pth.tar
+# time python3 compress_classifier.py --arch resnet20_cifar  ../../../data.cifar10 -p=50 --lr=0.2 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp.yaml -j=1 --deterministic --resume=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --validation-size=0
 #
 # Parameters:
 # +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
 # |    | Name                                | Shape          |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
 # |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
-# |  0 | module.conv1.weight                 | (16, 3, 3, 3)  |           432 |            432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.26754 | -0.00478 |    0.18996 |
-# |  1 | module.layer1.0.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10113 | -0.00595 |    0.07182 |
-# |  2 | module.layer1.0.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09882 | -0.00013 |    0.07256 |
-# |  3 | module.layer1.1.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08715 | -0.01028 |    0.06691 |
-# |  4 | module.layer1.1.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08150 | -0.00316 |    0.06242 |
-# |  5 | module.layer1.2.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11227 | -0.00627 |    0.08206 |
-# |  6 | module.layer1.2.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09145 |  0.00145 |    0.06919 |
-# |  7 | module.layer2.0.conv1.weight        | (20, 16, 3, 3) |          2880 |           2880 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09975 | -0.00178 |    0.07747 |
-# |  8 | module.layer2.0.conv2.weight        | (32, 20, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08692 | -0.00438 |    0.06784 |
-# |  9 | module.layer2.0.downsample.0.weight | (32, 16, 1, 1) |           512 |            512 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.17339 | -0.00644 |    0.12457 |
-# | 10 | module.layer2.1.conv1.weight        | (20, 32, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07515 | -0.00582 |    0.05967 |
-# | 11 | module.layer2.1.conv2.weight        | (32, 20, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06694 | -0.00409 |    0.05272 |
-# | 12 | module.layer2.2.conv1.weight        | (20, 32, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07822 | -0.00873 |    0.06161 |
-# | 13 | module.layer2.2.conv2.weight        | (32, 20, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06251 |  0.00119 |    0.04923 |
-# | 14 | module.layer3.0.conv1.weight        | (64, 32, 3, 3) |         18432 |          18432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06655 | -0.00436 |    0.05293 |
-# | 15 | module.layer3.0.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06298 | -0.00286 |    0.05019 |
-# | 16 | module.layer3.0.downsample.0.weight | (64, 32, 1, 1) |          2048 |           2048 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08574 | -0.00490 |    0.06750 |
-# | 17 | module.layer3.1.conv1.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 |  7.00684 |  0.00000 |   69.99783 | 0.05113 | -0.00318 |    0.02568 |
-# | 18 | module.layer3.1.conv2.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 |  7.64160 |  0.00000 |   69.99783 | 0.04585 | -0.00355 |    0.02293 |
-# | 19 | module.layer3.2.conv1.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 | 10.88867 |  0.00000 |   69.99783 | 0.04487 | -0.00409 |    0.02258 |
-# | 20 | module.layer3.2.conv2.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 | 31.51855 |  1.56250 |   69.99783 | 0.02512 |  0.00008 |    0.01251 |
-# | 21 | module.fc.weight                    | (10, 64)       |           640 |            320 |    0.00000 |   50.00000 |  0.00000 |  0.00000 |  0.00000 |   50.00000 | 0.48359 | -0.00001 |    0.30379 |
-# | 22 | Total sparsity:                     | -              |        251888 |         148352 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |   41.10398 | 0.00000 |  0.00000 |    0.00000 |
+# |  0 | module.conv1.weight                 | (16, 3, 3, 3)  |           432 |            432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.42267 | -0.01028 |    0.29903 |
+# |  1 | module.layer1.0.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.15895 | -0.01265 |    0.11210 |
+# |  2 | module.layer1.0.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.15610 |  0.00257 |    0.11472 |
+# |  3 | module.layer1.1.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13888 | -0.01590 |    0.10543 |
+# |  4 | module.layer1.1.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13052 | -0.00519 |    0.10135 |
+# |  5 | module.layer1.2.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.18351 | -0.01298 |    0.13564 |
+# |  6 | module.layer1.2.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.14909 | -0.00098 |    0.11435 |
+# |  7 | module.layer2.0.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.17438 | -0.00580 |    0.13427 |
+# |  8 | module.layer2.0.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.18654 | -0.00126 |    0.14499 |
+# |  9 | module.layer2.0.downsample.0.weight | (16, 16, 1, 1) |           256 |            256 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.34412 | -0.01243 |    0.24940 |
+# | 10 | module.layer2.1.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11833 | -0.00937 |    0.08865 |
+# | 11 | module.layer2.1.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09171 | -0.00197 |    0.06956 |
+# | 12 | module.layer2.2.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13403 | -0.01057 |    0.09999 |
+# | 13 | module.layer2.2.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09652 |  0.00544 |    0.07033 |
+# | 14 | module.layer3.0.conv1.weight        | (64, 16, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13635 | -0.00543 |    0.10654 |
+# | 15 | module.layer3.0.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09992 | -0.00600 |    0.07893 |
+# | 16 | module.layer3.0.downsample.0.weight | (64, 16, 1, 1) |          1024 |           1024 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.17133 | -0.00926 |    0.13503 |
+# | 17 | module.layer3.1.conv1.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 |  8.47168 |  1.56250 |   69.99783 | 0.07819 | -0.00423 |    0.03752 |
+# | 18 | module.layer3.1.conv2.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  1.56250 |  8.37402 |  0.00000 |   69.99783 | 0.07238 | -0.00539 |    0.03450 |
+# | 19 | module.layer3.2.conv1.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 | 11.93848 |  3.12500 |   69.99783 | 0.07195 | -0.00571 |    0.03462 |
+# | 20 | module.layer3.2.conv2.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  3.12500 | 28.75977 |  1.56250 |   69.99783 | 0.04405 |  0.00060 |    0.02004 |
+# | 21 | module.fc.weight                    | (10, 64)       |           640 |            320 |    0.00000 |   50.00000 |  0.00000 |  0.00000 |  0.00000 |   50.00000 | 0.57112 | -0.00001 |    0.36129 |
+# | 22 | Total sparsity:                     | -              |        223536 |         120000 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |   46.31737 | 0.00000 |  0.00000 |    0.00000 |
 # +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
-# Total sparsity: 41.10
+# Total sparsity: 46.32
 #
 # --- validate (epoch=359)-----------
-# 5000 samples (256 per mini-batch)
-# ==> Top1: 93.720    Top5: 99.880    Loss: 1.529
+# 10000 samples (256 per mini-batch)
+# ==> Top1: 91.490    Top5: 99.710    Loss: 0.346
 #
-# ==> Best Top1: 96.900   On Epoch: 181
+# ==> Best Top1: 91.730   On Epoch: 344
 #
-# Saving checkpoint to: logs/2018.10.15-111439/checkpoint.pth.tar
+# Saving checkpoint to: logs/2018.10.30-150931/checkpoint.pth.tar
 # --- test ---------------------
 # 10000 samples (256 per mini-batch)
-# ==> Top1: 91.760    Top5: 99.700    Loss: 1.546
+# ==> Top1: 91.490    Top5: 99.710    Loss: 0.346
 #
 #
-# Log file for this run: /home/cvds_lab/nzmora/pytorch_workspace/distiller/examples/classifier_compression/logs/2018.10.15-111439/2018.10.15-111439.log
+# Log file for this run: /home/cvds_lab/nzmora/sandbox_5/distiller/examples/classifier_compression/logs/2018.10.30-150931/2018.10.30-150931.log
 #
-# real    31m55.802s
-# user    73m1.353s
-# sys     8m46.687s
-
+# real    36m36.329s
+# user    82m32.685s
+# sys     10m8.746s
 
 version: 1
 
@@ -69,12 +70,24 @@ pruners:
   low_pruner:
     class: L1RankedStructureParameterPruner_AGP
     initial_sparsity : 0.10
-    final_sparsity: 0.40
+    final_sparsity: 0.50
     reg_regims:
       module.layer2.0.conv1.weight: Filters
+
+      module.layer2.0.conv2.weight: Filters
+      module.layer2.0.downsample.0.weight: Filters
+      module.layer2.1.conv2.weight: Filters
+      module.layer2.2.conv2.weight: Filters  # to balance the BN
+
       module.layer2.1.conv1.weight: Filters
       module.layer2.2.conv1.weight: Filters
 
+      #module.layer3.0.conv2.weight: Filters
+      #module.layer3.0.downsample.0.weight: Filters
+      #module.layer3.1.conv2.weight: Filters
+      #module.layer3.2.conv2.weight: Filters
+
+
   fine_pruner:
     class:  AutomatedGradualPruner
     initial_sparsity : 0.05
@@ -95,7 +108,6 @@ lr_schedulers:
     step_size: 50
     gamma: 0.10
 
-
 extensions:
   net_thinner:
       class: 'FilterRemover'
@@ -103,23 +115,24 @@ extensions:
       arch: 'resnet20_cifar'
       dataset: 'cifar10'
 
+
 policies:
   - pruner:
       instance_name : low_pruner
     starting_epoch: 180
-    ending_epoch: 200
+    ending_epoch: 210
     frequency: 2
 
   - pruner:
       instance_name : fine_pruner
-    starting_epoch: 200
-    ending_epoch: 220
+    starting_epoch: 210
+    ending_epoch: 230
     frequency: 2
 
   - pruner:
       instance_name : fc_pruner
-    starting_epoch: 200
-    ending_epoch: 220
+    starting_epoch: 210
+    ending_epoch: 230
     frequency: 2
 
   # Currently the thinner is disabled until the the structure pruner is done, because it interacts
@@ -134,10 +147,10 @@ policies:
 # After completeing the pruning, we perform network thinning and continue fine-tuning.
   - extension:
       instance_name: net_thinner
-    epochs: [202]
+    epochs: [212]
 
   - lr_scheduler:
       instance_name: pruning_lr
-    starting_epoch: 0
+    starting_epoch: 180
     ending_epoch: 400
     frequency: 1
-- 
GitLab