From 37d5774b01b1af506228d733069df1d0da2b8e92 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Mon, 3 Dec 2018 22:54:21 +0200
Subject: [PATCH] ResNet50 Dynamic Surgery: new schedule with improved results

Top1: 75.52% (-0.63% from TorchVision dense ResNet50)
Total sparsity: 82.6%
---
 .../resnet50.network_surgery2.yaml            | 203 ++++++++++++++++++
 1 file changed, 203 insertions(+)
 create mode 100755 examples/network_surgery/resnet50.network_surgery2.yaml

diff --git a/examples/network_surgery/resnet50.network_surgery2.yaml b/examples/network_surgery/resnet50.network_surgery2.yaml
new file mode 100755
index 0000000..5cbe324
--- /dev/null
+++ b/examples/network_surgery/resnet50.network_surgery2.yaml
@@ -0,0 +1,203 @@
+# This schedule follows the methodology proposed by Intel Labs China in the paper:
+#   Dynamic Network Surgery for Efficient DNNs, Yiwen Guo, Anbang Yao, Yurong Chen.
+#   NIPS 2016, https://arxiv.org/abs/1608.04493.
+#
+# Top1 is 75.518 (on Epoch: 99) vs the published Top1: 76.15 (https://pytorch.org/docs/stable/torchvision/models.html)
+# Total sparsity: 82.6%
+#
+# time python3 compress_classifier.py -a=resnet50 --pretrained -p=50 ../../../data.imagenet/ -j=22 --epochs=100 --lr=0.001 --compress=resnet50.network_surgery2.yaml --validation-size=0  --masks-sparsity --num-best-scores=10
+#
+#
+# Parameters:
+# +----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
+# |    | Name                                | Shape              |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
+# |----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
+# |  0 | module.conv1.weight                 | (64, 3, 7, 7)      |          9408 |           9408 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10739 | -0.00040 |    0.06567 |
+# |  1 | module.layer1.0.conv1.weight        | (64, 64, 1, 1)     |          4096 |            812 |    0.00000 |    0.00000 |  3.12500 | 80.17578 |  7.81250 |   80.17578 | 0.05457 | -0.00405 |    0.02019 |
+# |  2 | module.layer1.0.conv2.weight        | (64, 64, 3, 3)     |         36864 |           5052 |    0.00000 |    0.00000 |  7.81250 | 51.00098 |  6.25000 |   86.29557 | 0.02182 |  0.00054 |    0.00697 |
+# |  3 | module.layer1.0.conv3.weight        | (256, 64, 1, 1)    |         16384 |           2477 |    0.00000 |    0.00000 |  6.25000 | 84.88159 | 13.28125 |   84.88159 | 0.02654 |  0.00022 |    0.00923 |
+# |  4 | module.layer1.0.downsample.0.weight | (256, 64, 1, 1)    |         16384 |           2975 |    0.00000 |    0.00000 |  1.56250 | 81.84204 | 14.06250 |   81.84204 | 0.04410 | -0.00247 |    0.01580 |
+# |  5 | module.layer1.1.conv1.weight        | (64, 256, 1, 1)    |         16384 |           2026 |    0.00000 |    0.00000 | 14.45312 | 87.63428 |  6.25000 |   87.63428 | 0.02121 |  0.00072 |    0.00704 |
+# |  6 | module.layer1.1.conv2.weight        | (64, 64, 3, 3)     |         36864 |           4064 |    0.00000 |    0.00000 |  6.25000 | 52.31934 |  0.00000 |   88.97569 | 0.01952 |  0.00019 |    0.00595 |
+# |  7 | module.layer1.1.conv3.weight        | (256, 64, 1, 1)    |         16384 |           1997 |    0.00000 |    0.00000 |  0.00000 | 87.81128 |  5.85938 |   87.81128 | 0.02324 |  0.00021 |    0.00751 |
+# |  8 | module.layer1.2.conv1.weight        | (64, 256, 1, 1)    |         16384 |           2994 |    0.00000 |    0.00000 |  9.37500 | 81.72607 |  0.00000 |   81.72607 | 0.02169 | -0.00005 |    0.00874 |
+# |  9 | module.layer1.2.conv2.weight        | (64, 64, 3, 3)     |         36864 |           4551 |    0.00000 |    0.00000 |  0.00000 | 45.41016 |  0.00000 |   87.65462 | 0.02076 | -0.00029 |    0.00698 |
+# | 10 | module.layer1.2.conv3.weight        | (256, 64, 1, 1)    |         16384 |           1938 |    0.00000 |    0.00000 |  0.00000 | 88.17139 | 10.15625 |   88.17139 | 0.02266 | -0.00103 |    0.00724 |
+# | 11 | module.layer2.0.conv1.weight        | (128, 256, 1, 1)   |         32768 |           5757 |    0.00000 |    0.00000 |  6.25000 | 82.43103 |  0.00000 |   82.43103 | 0.02551 | -0.00083 |    0.00988 |
+# | 12 | module.layer2.0.conv2.weight        | (128, 128, 3, 3)   |        147456 |          23222 |    0.00000 |    0.00000 |  0.00000 | 43.48755 |  0.00000 |   84.25157 | 0.01525 | -0.00010 |    0.00572 |
+# | 13 | module.layer2.0.conv3.weight        | (512, 128, 1, 1)   |         65536 |           6978 |    0.00000 |    0.00000 |  0.00000 | 89.35242 | 28.90625 |   89.35242 | 0.01970 |  0.00022 |    0.00584 |
+# | 14 | module.layer2.0.downsample.0.weight | (512, 256, 1, 1)   |        131072 |          13839 |    0.00000 |    0.00000 |  0.00000 | 89.44168 | 14.06250 |   89.44168 | 0.01643 | -0.00022 |    0.00459 |
+# | 15 | module.layer2.1.conv1.weight        | (128, 512, 1, 1)   |         65536 |           6780 |    0.00000 |    0.00000 | 17.18750 | 89.65454 |  0.00000 |   89.65454 | 0.01183 |  0.00018 |    0.00345 |
+# | 16 | module.layer2.1.conv2.weight        | (128, 128, 3, 3)   |        147456 |          15531 |    0.00000 |    0.00000 |  0.00000 | 60.41260 |  2.34375 |   89.46737 | 0.01378 |  0.00027 |    0.00402 |
+# | 17 | module.layer2.1.conv3.weight        | (512, 128, 1, 1)   |         65536 |           6229 |    0.00000 |    0.00000 |  0.00000 | 90.49530 | 19.72656 |   90.49530 | 0.01613 | -0.00081 |    0.00447 |
+# | 18 | module.layer2.2.conv1.weight        | (128, 512, 1, 1)   |         65536 |           9000 |    0.00000 |    0.00000 |  1.95312 | 86.26709 |  0.00000 |   86.26709 | 0.01634 | -0.00031 |    0.00554 |
+# | 19 | module.layer2.2.conv2.weight        | (128, 128, 3, 3)   |        147456 |          16032 |    0.00000 |    0.00000 |  0.00000 | 52.87476 |  0.00000 |   89.12760 | 0.01431 | -0.00007 |    0.00440 |
+# | 20 | module.layer2.2.conv3.weight        | (512, 128, 1, 1)   |         65536 |           6783 |    0.00000 |    0.00000 |  0.00000 | 89.64996 |  5.85938 |   89.64996 | 0.01736 | -0.00006 |    0.00516 |
+# | 21 | module.layer2.3.conv1.weight        | (128, 512, 1, 1)   |         65536 |           8544 |    0.00000 |    0.00000 |  1.75781 | 86.96289 |  0.00000 |   86.96289 | 0.01625 | -0.00028 |    0.00555 |
+# | 22 | module.layer2.3.conv2.weight        | (128, 128, 3, 3)   |        147456 |          23301 |    0.00000 |    0.00000 |  0.00000 | 33.40454 |  0.00000 |   84.19800 | 0.01532 | -0.00025 |    0.00578 |
+# | 23 | module.layer2.3.conv3.weight        | (512, 128, 1, 1)   |         65536 |           7932 |    0.00000 |    0.00000 |  0.00000 | 87.89673 | 17.57812 |   87.89673 | 0.01693 | -0.00015 |    0.00545 |
+# | 24 | module.layer3.0.conv1.weight        | (256, 512, 1, 1)   |        131072 |          19673 |    0.00000 |    0.00000 |  0.00000 | 84.99069 |  0.00000 |   84.99069 | 0.02137 | -0.00038 |    0.00765 |
+# | 25 | module.layer3.0.conv2.weight        | (256, 256, 3, 3)   |        589824 |         101140 |    0.00000 |    0.00000 |  0.00000 | 47.00165 |  0.00000 |   82.85251 | 0.01212 | -0.00015 |    0.00467 |
+# | 26 | module.layer3.0.conv3.weight        | (1024, 256, 1, 1)  |        262144 |          29483 |    0.00000 |    0.00000 |  0.00000 | 88.75313 |  6.05469 |   88.75313 | 0.01549 |  0.00009 |    0.00486 |
+# | 27 | module.layer3.0.downsample.0.weight | (1024, 512, 1, 1)  |        524288 |          52743 |    0.00000 |    0.00000 |  0.00000 | 89.94007 |  4.98047 |   89.94007 | 0.01071 |  0.00006 |    0.00310 |
+# | 28 | module.layer3.1.conv1.weight        | (256, 1024, 1, 1)  |        262144 |          28594 |    0.00000 |    0.00000 |  8.10547 | 89.09225 |  0.00000 |   89.09225 | 0.01041 | -0.00005 |    0.00319 |
+# | 29 | module.layer3.1.conv2.weight        | (256, 256, 3, 3)   |        589824 |          65069 |    0.00000 |    0.00000 |  0.00000 | 54.10919 |  0.00000 |   88.96807 | 0.00993 | -0.00002 |    0.00310 |
+# | 30 | module.layer3.1.conv3.weight        | (1024, 256, 1, 1)  |        262144 |          27368 |    0.00000 |    0.00000 |  0.00000 | 89.55994 |  1.95312 |   89.55994 | 0.01346 | -0.00056 |    0.00400 |
+# | 31 | module.layer3.2.conv1.weight        | (256, 1024, 1, 1)  |        262144 |          26238 |    0.00000 |    0.00000 |  1.75781 | 89.99100 |  0.00000 |   89.99100 | 0.01042 | -0.00007 |    0.00305 |
+# | 32 | module.layer3.2.conv2.weight        | (256, 256, 3, 3)   |        589824 |          67618 |    0.00000 |    0.00000 |  0.00000 | 45.94727 |  0.00000 |   88.53590 | 0.00971 | -0.00023 |    0.00312 |
+# | 33 | module.layer3.2.conv3.weight        | (1024, 256, 1, 1)  |        262144 |          28073 |    0.00000 |    0.00000 |  0.00000 | 89.29100 |  0.97656 |   89.29100 | 0.01248 | -0.00014 |    0.00381 |
+# | 34 | module.layer3.3.conv1.weight        | (256, 1024, 1, 1)  |        262144 |          27645 |    0.00000 |    0.00000 |  0.48828 | 89.45427 |  0.00000 |   89.45427 | 0.01131 | -0.00005 |    0.00343 |
+# | 35 | module.layer3.3.conv2.weight        | (256, 256, 3, 3)   |        589824 |          69321 |    0.00000 |    0.00000 |  0.00000 | 44.19861 |  0.00000 |   88.24717 | 0.00961 | -0.00017 |    0.00315 |
+# | 36 | module.layer3.3.conv3.weight        | (1024, 256, 1, 1)  |        262144 |          29057 |    0.00000 |    0.00000 |  0.00000 | 88.91563 |  3.61328 |   88.91563 | 0.01201 | -0.00033 |    0.00376 |
+# | 37 | module.layer3.4.conv1.weight        | (256, 1024, 1, 1)  |        262144 |          28934 |    0.00000 |    0.00000 |  0.09766 | 88.96255 |  0.00000 |   88.96255 | 0.01172 | -0.00016 |    0.00366 |
+# | 38 | module.layer3.4.conv2.weight        | (256, 256, 3, 3)   |        589824 |          70785 |    0.00000 |    0.00000 |  0.00000 | 44.31305 |  0.00000 |   87.99896 | 0.00958 | -0.00025 |    0.00319 |
+# | 39 | module.layer3.4.conv3.weight        | (1024, 256, 1, 1)  |        262144 |          29261 |    0.00000 |    0.00000 |  0.00000 | 88.83781 |  1.46484 |   88.83781 | 0.01205 | -0.00054 |    0.00379 |
+# | 40 | module.layer3.5.conv1.weight        | (256, 1024, 1, 1)  |        262144 |          30074 |    0.00000 |    0.00000 |  0.00000 | 88.52768 |  0.00000 |   88.52768 | 0.01263 | -0.00009 |    0.00405 |
+# | 41 | module.layer3.5.conv2.weight        | (256, 256, 3, 3)   |        589824 |          72988 |    0.00000 |    0.00000 |  0.00000 | 44.97986 |  0.00000 |   87.62546 | 0.00984 | -0.00028 |    0.00332 |
+# | 42 | module.layer3.5.conv3.weight        | (1024, 256, 1, 1)  |        262144 |          31194 |    0.00000 |    0.00000 |  0.00000 | 88.10043 |  1.66016 |   88.10043 | 0.01284 | -0.00089 |    0.00420 |
+# | 43 | module.layer4.0.conv1.weight        | (512, 1024, 1, 1)  |        524288 |         114432 |    0.00000 |    0.00000 |  0.00000 | 78.17383 |  0.00000 |   78.17383 | 0.01710 | -0.00038 |    0.00754 |
+# | 44 | module.layer4.0.conv2.weight        | (512, 512, 3, 3)   |       2359296 |         461529 |    0.00000 |    0.00000 |  0.00000 | 41.99524 |  0.00000 |   80.43785 | 0.00872 | -0.00015 |    0.00370 |
+# | 45 | module.layer4.0.conv3.weight        | (2048, 512, 1, 1)  |       1048576 |         190377 |    0.00000 |    0.00000 |  0.00000 | 81.84423 |  0.00000 |   81.84423 | 0.01097 | -0.00013 |    0.00443 |
+# | 46 | module.layer4.0.downsample.0.weight | (2048, 1024, 1, 1) |       2097152 |         296214 |    0.00000 |    0.00000 |  0.00000 | 85.87542 |  0.00000 |   85.87542 | 0.00690 | -0.00001 |    0.00243 |
+# | 47 | module.layer4.1.conv1.weight        | (512, 2048, 1, 1)  |       1048576 |         235460 |    0.00000 |    0.00000 |  0.00000 | 77.54478 |  0.00000 |   77.54478 | 0.01123 | -0.00028 |    0.00503 |
+# | 48 | module.layer4.1.conv2.weight        | (512, 512, 3, 3)   |       2359296 |         569044 |    0.00000 |    0.00000 |  0.00000 | 27.84805 |  0.00000 |   75.88077 | 0.00897 | -0.00042 |    0.00423 |
+# | 49 | module.layer4.1.conv3.weight        | (2048, 512, 1, 1)  |       1048576 |         193763 |    0.00000 |    0.00000 |  0.00000 | 81.52132 |  0.00000 |   81.52132 | 0.01092 |  0.00017 |    0.00445 |
+# | 50 | module.layer4.2.conv1.weight        | (512, 2048, 1, 1)  |       1048576 |         254128 |    0.00000 |    0.00000 |  0.00000 | 75.76447 |  0.00000 |   75.76447 | 0.01357 | -0.00013 |    0.00634 |
+# | 51 | module.layer4.2.conv2.weight        | (512, 512, 3, 3)   |       2359296 |         537393 |    0.00000 |    0.00000 |  0.00000 | 47.85385 |  0.00000 |   77.22232 | 0.00767 | -0.00029 |    0.00354 |
+# | 52 | module.layer4.2.conv3.weight        | (2048, 512, 1, 1)  |       1048576 |         162045 |    0.00000 |    0.00000 |  0.00000 | 84.54618 |  0.14648 |   84.54618 | 0.00990 |  0.00027 |    0.00362 |
+# | 53 | module.fc.weight                    | (1000, 2048)       |       2048000 |         396407 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |   80.64419 | 0.03125 |  0.00427 |    0.01213 |
+# | 54 | Total sparsity:                     | -                  |      25502912 |        4434272 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |   82.61268 | 0.00000 |  0.00000 |    0.00000 |
+# +----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
+# 2018-11-14 09:44:15,161 - Total sparsity: 82.61
+#
+# 2018-11-14 09:44:15,304 - --- validate (epoch=99)-----------
+# 2018-11-14 09:44:15,305 - 50000 samples (256 per mini-batch)
+# 2018-11-14 09:44:34,609 - Epoch: [99][   50/  195]    Loss 0.697465    Top1 81.437500    Top5 95.703125
+# 2018-11-14 09:44:42,914 - Epoch: [99][  100/  195]    Loss 0.816492    Top1 78.804688    Top5 94.542969
+# 2018-11-14 09:44:51,516 - Epoch: [99][  150/  195]    Loss 0.930595    Top1 76.380208    Top5 93.135417
+# 2018-11-14 09:44:58,448 - ==> Top1: 75.518    Top5: 92.620    Loss: 0.975
+#
+# 2018-11-14 09:44:58,508 - ==> Best Top1: 75.518 on Epoch: 99
+# 2018-11-14 09:44:58,508 - ==> Best Top1: 75.494 on Epoch: 83
+# 2018-11-14 09:44:58,509 - ==> Best Top1: 75.480 on Epoch: 89
+# 2018-11-14 09:44:58,509 - ==> Best Top1: 75.462 on Epoch: 91
+# 2018-11-14 09:44:58,509 - ==> Best Top1: 75.454 on Epoch: 97
+# 2018-11-14 09:44:58,509 - ==> Best Top1: 75.452 on Epoch: 93
+# 2018-11-14 09:44:58,509 - ==> Best Top1: 75.450 on Epoch: 96
+# 2018-11-14 09:44:58,510 - ==> Best Top1: 75.448 on Epoch: 90
+# 2018-11-14 09:44:58,510 - ==> Best Top1: 75.438 on Epoch: 94
+# 2018-11-14 09:44:58,510 - ==> Best Top1: 75.436 on Epoch: 73
+# 2018-11-14 09:44:58,510 - Saving checkpoint to: logs/resnet50_lr_0.001_mult_0.005___2018.11.12-041119/resnet50_lr_0.001_mult_0.005_checkpoint.pth.tar
+# 2018-11-14 09:44:59,539 - --- test ---------------------
+# 2018-11-14 09:44:59,540 - 50000 samples (256 per mini-batch)
+# 2018-11-14 09:45:18,661 - Test: [   50/  195]    Loss 0.697465    Top1 81.437500    Top5 95.703125
+# 2018-11-14 09:45:27,176 - Test: [  100/  195]    Loss 0.816492    Top1 78.804688    Top5 94.542969
+# 2018-11-14 09:45:36,202 - Test: [  150/  195]    Loss 0.930595    Top1 76.380208    Top5 93.135417
+# 2018-11-14 09:45:43,168 - ==> Top1: 75.518    Top5: 92.620    Loss: 0.975
+# --- validate (epoch=359)-----------
+# 10000 samples (256 per mini-batch)
+# ==> Top1: 91.480    Top5: 99.600    Loss: 0.363
+#
+# ==> Best Top1: 91.790 (0.0 sparsity) on Epoch: 181
+#
+# Saving checkpoint to: logs/2018.10.31-232827/checkpoint.pth.tar
+# --- test ---------------------
+# 10000 samples (256 per mini-batch)
+# ==> Top1: 91.480    Top5: 99.600    Loss: 0.363
+#
+#
+# Log file for this run: /home/cvds_lab/nzmora/pytorch_workspace/distiller/examples/classifier_compression/logs/2018.10.31-232827/2018.10.31-232827.log
+#
+# real    64m27.317s
+# user    118m46.020s
+# sys     14m3.627s
+
+version: 1
+pruners:
+  pruner1:
+    class: SplicingPruner  # pruner for the Dynamic Network Surgery schedule described in the header comment
+    low_thresh_mult: 0.9  # 0.6 (presumably a previously tried value - confirm)
+    hi_thresh_mult: 1.1  # 0.7 (presumably a previously tried value - confirm)
+    sensitivity_multiplier: 0.005   # matches "mult_0.005" in the logged checkpoint name; 0.015 presumably tried earlier - confirm
+    sensitivities:  # one value per weight tensor; conv1 stays commented out below and shows 0% sparsity in the results table
+      #'module.conv1.weight': 0.60
+      module.layer1.0.conv1.weight: 0.10
+      module.layer1.0.conv2.weight: 0.40
+      module.layer1.0.conv3.weight: 0.40
+      module.layer1.0.downsample.0.weight: 0.20
+      module.layer1.1.conv1.weight: 0.60
+      module.layer1.1.conv2.weight: 0.60
+      module.layer1.1.conv3.weight: 0.60
+      module.layer1.2.conv1.weight: 0.30
+      module.layer1.2.conv2.weight: 0.60
+      module.layer1.2.conv3.weight: 0.60
+
+      module.layer2.0.conv1.weight: 0.30
+      module.layer2.0.conv2.weight: 0.40
+      module.layer2.0.conv3.weight: 0.60
+      module.layer2.0.downsample.0.weight: 0.50
+      module.layer2.1.conv1.weight: 0.60
+      module.layer2.1.conv2.weight: 0.60
+      module.layer2.1.conv3.weight: 0.60
+      module.layer2.2.conv1.weight: 0.40
+      module.layer2.2.conv2.weight: 0.60
+      module.layer2.2.conv3.weight: 0.60
+      module.layer2.3.conv1.weight: 0.50
+      module.layer2.3.conv2.weight: 0.40
+      module.layer2.3.conv3.weight: 0.50
+
+      module.layer3.0.conv1.weight: 0.40
+      module.layer3.0.conv2.weight: 0.30
+      module.layer3.0.conv3.weight: 0.60
+      module.layer3.0.downsample.0.weight: 0.60
+      module.layer3.1.conv1.weight: 0.60
+      module.layer3.1.conv2.weight: 0.60
+      module.layer3.1.conv3.weight: 0.60
+      module.layer3.2.conv1.weight: 0.60
+      module.layer3.2.conv2.weight: 0.60
+      module.layer3.2.conv3.weight: 0.60
+      module.layer3.3.conv1.weight: 0.60
+      module.layer3.3.conv2.weight: 0.60
+      module.layer3.3.conv3.weight: 0.60
+      module.layer3.4.conv1.weight: 0.60
+      module.layer3.4.conv2.weight: 0.60
+      module.layer3.4.conv3.weight: 0.60
+      module.layer3.5.conv1.weight: 0.60
+      module.layer3.5.conv2.weight: 0.60
+      module.layer3.5.conv3.weight: 0.60
+
+      module.layer4.0.conv1.weight: 0.20
+      module.layer4.0.conv2.weight: 0.30
+      module.layer4.0.conv3.weight: 0.30
+      module.layer4.0.downsample.0.weight: 0.40
+      module.layer4.1.conv1.weight: 0.15
+      module.layer4.1.conv2.weight: 0.15
+      module.layer4.1.conv3.weight: 0.30
+      module.layer4.2.conv1.weight: 0.15
+      module.layer4.2.conv2.weight: 0.30
+      module.layer4.2.conv3.weight: 0.45
+      module.fc.weight: 0.50
+
+lr_schedulers:
+  training_lr:
+    class: StepLR  # presumably torch.optim.lr_scheduler.StepLR - confirm how the loader resolves this name
+    step_size: 45  # assuming StepLR semantics: number of epochs between successive LR decays
+    gamma: 0.10  # assuming StepLR semantics: LR is multiplied by this factor at each decay
+
+policies:
+  - pruner:
+      instance_name: pruner1  # refers to pruners.pruner1 defined above
+      args:
+        keep_mask: True
+        #mini_batch_pruning_frequency: 1
+        mask_on_forward_only: True
+    starting_epoch: 0
+    ending_epoch: 47  # pruner policy window ends at epoch 47; the header command trains for --epochs=100
+    frequency: 1
+
+
+  - lr_scheduler:
+      instance_name: training_lr  # refers to lr_schedulers.training_lr defined above
+    starting_epoch: 0
+    ending_epoch: 400  # larger than the --epochs=100 used in the header command
+    frequency: 1
-- 
GitLab