From 32a7e4bfcf9fcdea76c3d778efb62b664fe6b088 Mon Sep 17 00:00:00 2001
From: Guy Jacob <guy.jacob@intel.com>
Date: Thu, 30 Apr 2020 10:23:09 +0300
Subject: [PATCH] Knowledge distillation fixes (#503)

Fixed two long-standing bugs in knowledge distillation:
 * Distillation loss needs to be scaled by T^2 (#122)
 * Use tensor.clone instead of new_tensor when caching student logits (#234)
Updated the example results and added the script used to generate them
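
To make the two fixes concrete, here is a minimal, self-contained sketch
(illustrative variable names and dummy data; not the exact code in
distiller/knowledge_distillation.py):

    import torch
    import torch.nn.functional as F

    # Dummy stand-ins for a CIFAR-10 mini-batch (illustrative only)
    torch.manual_seed(0)
    student_logits = torch.randn(128, 10, requires_grad=True)
    teacher_logits = torch.randn(128, 10)
    labels = torch.randint(0, 10, (128,))

    T = 5.0                           # distillation temperature
    student_wt, distill_wt = 0.3, 0.7

    hard_loss = F.cross_entropy(student_logits, labels)
    soft_log_probs = F.log_softmax(student_logits / T, dim=1)
    soft_targets = F.softmax(teacher_logits / T, dim=1)

    # 'batchmean' sums the KL divergence and divides by the batch size
    distill_loss = F.kl_div(soft_log_probs, soft_targets.detach(),
                            reduction='batchmean')

    # Fix for #122: scale the distillation loss by T^2 so its gradients keep
    # roughly the same magnitude as the hard-label loss when T is changed
    overall_loss = student_wt * hard_loss + distill_wt * distill_loss * T ** 2

    # Fix for #234: clone() keeps the cached logits attached to the autograd
    # graph; new_tensor() copies the data but detaches it from the graph
    cached_student_logits = student_logits.clone()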
---
 README.md                                     |   6 +-
 distiller/knowledge_distillation.py           |  20 ++-
 examples/README.md                            |  14 +-
 .../preact_resnet_cifar_base_fp32.yaml        | 143 +++++++++++------
 ...preact_resnet_cifar_quant_distill_tests.sh |  72 +++++++++
 .../preact_resnet_cifar_dorefa.yaml           | 151 ++++++++++++------
 6 files changed, 290 insertions(+), 116 deletions(-)
 create mode 100644 examples/quantization/preact_resnet_cifar_quant_distill_tests.sh

diff --git a/README.md b/README.md
index 625010a..ac520da 100755
--- a/README.md
+++ b/README.md
@@ -174,9 +174,9 @@ If you do not use CUDA 10.1 in your environment, please refer to [PyTorch websit
 
 Distiller comes with sample applications and tutorials covering a range of model types:
 
-| Model Type | Sparsity | Post-training quantization | Quantization-aware training | Auto Compression (AMC) |
-|------------|:--------:|:--------------------------:|:---------------------------:|:----------------------:|
-| [Image classification](https://github.com/NervanaSystems/distiller/tree/master/examples/classifier_compression) | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| Model Type | Sparsity | Post-training quantization | Quantization-aware training | Auto Compression (AMC) | Knowledge Distillation |
+|------------|:--------:|:--------------------------:|:---------------------------:|:----------------------:|:--------:|
+| [Image classification](https://github.com/NervanaSystems/distiller/tree/master/examples/classifier_compression) | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
 | [Word-level language model](https://github.com/NervanaSystems/distiller/tree/master/examples/word_language_model)| :white_check_mark: | :white_check_mark: | | |
 | [Translation (GNMT)](https://github.com/NervanaSystems/distiller/tree/master/examples/GNMT) | | :white_check_mark: | | |
 | [Recommendation System (NCF)](https://github.com/NervanaSystems/distiller/tree/master/examples/ncf) | |  :white_check_mark: | | |
diff --git a/distiller/knowledge_distillation.py b/distiller/knowledge_distillation.py
index acba5ce..48f313a 100644
--- a/distiller/knowledge_distillation.py
+++ b/distiller/knowledge_distillation.py
@@ -104,7 +104,7 @@ class KnowledgeDistillationPolicy(ScheduledTrainingPolicy):
 
     def forward(self, *inputs):
         """
-        Performs forward propagation through both student and teached models and caches the logits.
+        Performs forward propagation through both student and teacher models and caches the logits.
         This function MUST be used instead of calling the student model directly.
 
         Returns:
@@ -120,7 +120,7 @@ class KnowledgeDistillationPolicy(ScheduledTrainingPolicy):
             self.last_teacher_logits = self.teacher(*inputs)
 
         out = self.student(*inputs)
-        self.last_students_logits = out.new_tensor(out, requires_grad=True)
+        self.last_students_logits = out.clone()
 
         return out
 
@@ -150,14 +150,20 @@ class KnowledgeDistillationPolicy(ScheduledTrainingPolicy):
         # soft_targets = F.softmax(self.cached_teacher_logits[minibatch_id] / self.temperature)
         soft_targets = F.softmax(self.last_teacher_logits / self.temperature, dim=1)
 
-        # The averaging used in PyTorch KL Div implementation is wrong, so we work around as suggested in
-        # https://pytorch.org/docs/stable/nn.html#kldivloss
-        # (Also see https://github.com/pytorch/pytorch/issues/6622, https://github.com/pytorch/pytorch/issues/2259)
-        distillation_loss = F.kl_div(soft_log_probs, soft_targets.detach(), size_average=False) / soft_targets.shape[0]
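+        # reduction='batchmean' sums the KL divergence over all elements and divides by the batch size,
+        # matching the manual "size_average=False ... / batch_size" workaround used previously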
+        distillation_loss = F.kl_div(soft_log_probs, soft_targets.detach(), reduction='batchmean')
+
+        # According to [1]:
+        # "Since the magnitudes of the gradients produced by the soft targets scale as 1/(T^2) it is important
+        #  to multiply them by T^2 when using both hard and soft targets. This ensures that the relative contributions
+        #  of the hard and soft targets remain roughly unchanged if the temperature used for distillation is changed
+        #  while experimenting with meta-parameters."
+        distillation_loss_scaled = distillation_loss * self.temperature ** 2
 
         # The loss passed to the callback is the student's loss vs. the true labels, so we can use it directly, no
         # need to calculate again
+        overall_loss = self.loss_wts.student * loss + self.loss_wts.distill * distillation_loss_scaled
 
-        overall_loss = self.loss_wts.student * loss + self.loss_wts.distill * distillation_loss
+        # For logging purposes, we return the un-scaled distillation loss so it's comparable between runs with
+        # different temperatures
         return PolicyLoss(overall_loss,
                           [LossComponent('Distill Loss', distillation_loss)])
diff --git a/examples/README.md b/examples/README.md
index 8d8bbe1..9cf5a2c 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -2,12 +2,12 @@
 
 Distiller comes with sample applications and tutorials covering a range of model types:
 
-| Model Type | Sparsity | Post-training quantization | Quantization-aware training | Auto Compression (AMC) | In Directories |
-|------------|:--------:|:--------------------------:|:---------------------------:|:----------------------:|----------------|
-| **Image classification** | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | [classifier_compression](https://github.com/NervanaSystems/distiller/tree/master/examples/classifier_compression), [auto_compression/amc](https://github.com/NervanaSystems/distiller/tree/master/examples/auto_compression/amc) |
-| **Word-level language model** | :white_check_mark: | :white_check_mark: | | |[word_language_model](https://github.com/NervanaSystems/distiller/tree/master/examples/word_language_model) |
-| **Translation (GNMT)** | | :white_check_mark: | | | [GNMT](https://github.com/NervanaSystems/distiller/tree/master/examples/GNMT) |
-| **Recommendation System (NCF)** | |  :white_check_mark: | | | [ncf](https://github.com/NervanaSystems/distiller/tree/master/examples/ncf) |
-| **Object Detection** |  :white_check_mark: | | | | [object_detection_compression](https://github.com/NervanaSystems/distiller/tree/master/examples/object_detection_compression) |
+| Model Type | Sparsity | Post-training quantization | Quantization-aware training | Auto Compression (AMC) | Knowledge Distillation | In Directories |
+|------------|:--------:|:--------------------------:|:---------------------------:|:----------------------:|:----------------------:|----------------|
+| **Image classification** | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | [classifier_compression](https://github.com/NervanaSystems/distiller/tree/master/examples/classifier_compression), [auto_compression/amc](https://github.com/NervanaSystems/distiller/tree/master/examples/auto_compression/amc) |
+| **Word-level language model** | :white_check_mark: | :white_check_mark: | | | |[word_language_model](https://github.com/NervanaSystems/distiller/tree/master/examples/word_language_model) |
+| **Translation (GNMT)** | | :white_check_mark: | | | | [GNMT](https://github.com/NervanaSystems/distiller/tree/master/examples/GNMT) |
+| **Recommendation System (NCF)** | |  :white_check_mark: | | | | [ncf](https://github.com/NervanaSystems/distiller/tree/master/examples/ncf) |
+| **Object Detection** |  :white_check_mark: | | | | | [object_detection_compression](https://github.com/NervanaSystems/distiller/tree/master/examples/object_detection_compression) |
 
 The directories specified in the table contain the code implementing each of the modalities. The rest of the sub-directories in this directory are each dedicated to a specific compression method, and contain YAML schedules and other files that can be used with the sample applications. Most of these files contain details on the results obtained and how to re-produce them.
diff --git a/examples/quantization/fp32_baselines/preact_resnet_cifar_base_fp32.yaml b/examples/quantization/fp32_baselines/preact_resnet_cifar_base_fp32.yaml
index f0d1343..34a1517 100644
--- a/examples/quantization/fp32_baselines/preact_resnet_cifar_base_fp32.yaml
+++ b/examples/quantization/fp32_baselines/preact_resnet_cifar_base_fp32.yaml
@@ -16,14 +16,12 @@
 #
 # Notes:
 #  * Replace 'preact_resnet44_cifar' with the required teacher model
-#  * Use the command line above to train teacher models, which can be used as checkpoints passed to '--kd-resume'
-#  * In this example we're using a distillation temperature of 5.0, and we give a weight of 0.7 to the distillation loss
-#    (that is - the loss of the student predictions vs. the teacher's soft targets).
-#  * Note we don't change any of the other training hyper-parameters
+#  * Use the command line at the top to train teacher models; the resulting checkpoints can then be passed to '--kd-resume'
+#    (the shell script referenced below can also be used to train the baseline teacher models)
 #  * More details on knowledge distillation at: 
 #    https://nervanasystems.github.io/distiller/schedule/index.html#knowledge-distillation
 #
-# See some experimental results with the hyper-parameters shown above after the YAML schedule
+# See some experimental results below, after the YAML schedule
 
 lr_schedulers:
   training_lr:
@@ -38,49 +36,94 @@ policies:
       ending_epoch: 200
       frequency: 1
 
-# The results listed here are based on 4 runs in each configuration. All results are Top-1:
+# The results listed here are based on 4 runs in each configuration. All results are Top-1.
+# 
+# Notes:
+#  * In this example we test three distillation temperatures: 1.0, 2.0 and 5.0.
+#    This is controlled by the '--kd-temp' command line argument.
+#  * We give a weight of 0.7 to the distillation loss (that is - the loss of the student predictions vs.
+#    the teacher's soft targets) and a weight of 0.3 to the "standard" student vs. labels loss.
+#    This is done by passing the command line arguments: '--kd-dw 0.7 --kd-sw 0.3'
+#    (see the example command below)
+#  * We don't change any of the other training hyper-parameters
+# 
+# The shell script used to generate these results is provided at:
+# <distiller_root>/examples/quantization/preact_resnet_cifar_quant_distill_tests.sh
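+#
+# For illustration, a single distillation run from that script looks roughly as follows (run from
+# <distiller_root>/examples/classifier_compression; the dataset and teacher checkpoint paths are placeholders):
+#   python compress_classifier.py -a preact_resnet20_cifar <path_to_cifar10> --lr 0.1 -p 50 -b 128 -j 1 \
+#     --epochs 200 --wd 0.0002 --vs 0 --compress=../quantization/fp32_baselines/preact_resnet_cifar_base_fp32.yaml \
+#     --kd-teacher preact_resnet44_cifar --kd-resume <path_to_teacher_checkpoint> --kd-temp 5 --kd-dw 0.7 --kd-sw 0.3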
+#
+# (additional results and some limited analysis are included in the preact_resnet_cifar_dorefa.yaml file,
+#  located in <distiller_root>/examples/quantization/quant_aware_train/)
 #
-# +-------+--------------+-------------------------+
-# |       |              |           FP32          |
-# +-------+--------------+-------------------------+
-# | Depth | Distillation | Best  | Worst | Average |
-# |       | Teacher      |       |       |         |
-# +-------+--------------+-------+-------+---------+
-# | 20    | None         | 92.4  | 91.91 | 92.2225 |
-# +-------+--------------+-------+-------+---------+
-# | 20    | 32           | 92.85 | 92.68 | 92.7375 |
-# +-------+--------------+-------+-------+---------+
-# | 20    | 44           | 93.09 | 92.64 | 92.795  |
-# +-------+--------------+-------+-------+---------+
-# | 20    | 56           | 92.77 | 92.52 | 92.6475 |
-# +-------+--------------+-------+-------+---------+
-# | 20    | 110          | 92.87 | 92.66 | 92.7725 |
-# +-------+--------------+-------+-------+---------+
-# |       |              |       |       |         |
-# +-------+--------------+-------+-------+---------+
-# | 32    | None         | 93.31 | 92.93 | 93.13   |
-# +-------+--------------+-------+-------+---------+
-# | 32    | 44           | 93.54 | 93.35 | 93.48   |
-# +-------+--------------+-------+-------+---------+
-# | 32    | 56           | 93.58 | 93.47 | 93.5125 |
-# +-------+--------------+-------+-------+---------+
-# | 32    | 110          | 93.6  | 93.29 | 93.4575 |
-# +-------+--------------+-------+-------+---------+
-# |       |              |       |       |         |
-# +-------+--------------+-------+-------+---------+
-# | 44    | None         | 94.07 | 93.5  | 93.7425 |
-# +-------+--------------+-------+-------+---------+
-# | 44    | 56           | 94.08 | 93.58 | 93.875  |
-# +-------+--------------+-------+-------+---------+
-# | 44    | 110          | 94.13 | 93.75 | 93.95   |
-# +-------+--------------+-------+-------+---------+
-# |       |              |       |       |         |
-# +-------+--------------+-------+-------+---------+
-# | 56    | None         | 94.2  | 93.52 | 93.8    |
-# +-------+--------------+-------+-------+---------+
-# | 56    | 110          | 94.47 | 94.0  | 94.16   |
-# +-------+--------------+-------+-------+---------+
-# |       |              |       |       |         |
-# +-------+--------------+-------+-------+---------+
-# | 110   | None         | 94.66 | 94.42 | 94.54   |
-# +-------+--------------+-------+-------+---------+
+# +-------+-----------------------+-------------------------+
+# |       | Distillation Settings |      Results: FP32      |
+# +-------+---------+-------------+-------+-------+---------+
+# | Depth | Teacher | Temperature | Best  | Worst | Average |
+# +-------+---------+-------------+-------+-------+---------+
+# | 20    | None    | N/A         | 92.21 | 91.96 | 92.13   |
+# |       +---------+-------------+-------+-------+---------+
+# |       | 32      | 1           | 92.58 | 92.18 | 92.375  |
+# |       |         +-------------+-------+-------+---------+
+# |       |         | 2           | 92.92 | 92.61 | 92.775  |
+# |       |         +-------------+-------+-------+---------+
+# |       |         | 5           | 93.2  | 93.11 | 93.1725 |
+# |       +---------+-------------+-------+-------+---------+
+# |       | 44      | 1           | 92.77 | 92.23 | 92.4625 |
+# |       |         +-------------+-------+-------+---------+
+# |       |         | 2           | 93.02 | 92.74 | 92.8625 |
+# |       |         +-------------+-------+-------+---------+
+# |       |         | 5           | 93.3  | 93.11 | 93.21   |
+# |       +---------+-------------+-------+-------+---------+
+# |       | 56      | 1           | 92.58 | 92    | 92.2875 |
+# |       |         +-------------+-------+-------+---------+
+# |       |         | 2           | 92.93 | 92.51 | 92.73   |
+# |       |         +-------------+-------+-------+---------+
+# |       |         | 5           | 93.11 | 92.82 | 92.96   |
+# |       +---------+-------------+-------+-------+---------+
+# |       | 110     | 1           | 92.46 | 92.16 | 92.25   |
+# |       |         +-------------+-------+-------+---------+
+# |       |         | 2           | 92.9  | 92.54 | 92.745  |
+# |       |         +-------------+-------+-------+---------+
+# |       |         | 5           | 93.33 | 93    | 93.1475 |
+# +-------+---------+-------------+-------+-------+---------+
+# | 32    | None    | N/A         | 93.19 | 92.72 | 92.9975 |
+# |       +---------+-------------+-------+-------+---------+
+# |       | 44      | 1           | 93.43 | 93.22 | 93.3475 |
+# |       |         +-------------+-------+-------+---------+
+# |       |         | 2           | 94.01 | 93.37 | 93.605  |
+# |       |         +-------------+-------+-------+---------+
+# |       |         | 5           | 93.9  | 93.42 | 93.68   |
+# |       +---------+-------------+-------+-------+---------+
+# |       | 56      | 1           | 93.34 | 92.93 | 93.16   |
+# |       |         +-------------+-------+-------+---------+
+# |       |         | 2           | 93.99 | 93.29 | 93.5925 |
+# |       |         +-------------+-------+-------+---------+
+# |       |         | 5           | 93.9  | 93.62 | 93.78   |
+# |       +---------+-------------+-------+-------+---------+
+# |       | 110     | 1           | 93.43 | 93.18 | 93.335  |
+# |       |         +-------------+-------+-------+---------+
+# |       |         | 2           | 93.89 | 93.5  | 93.6675 |
+# |       |         +-------------+-------+-------+---------+
+# |       |         | 5           | 93.89 | 93.63 | 93.795  |
+# +-------+---------+-------------+-------+-------+---------+
+# | 44    | None    | N/A         | 93.81 | 93.23 | 93.58   |
+# |       +---------+-------------+-------+-------+---------+
+# |       | 56      | 1           | 93.97 | 93.57 | 93.7775 |
+# |       |         +-------------+-------+-------+---------+
+# |       |         | 2           | 94.11 | 93.89 | 94.01   |
+# |       |         +-------------+-------+-------+---------+
+# |       |         | 5           | 94.28 | 93.97 | 94.1225 |
+# |       +---------+-------------+-------+-------+---------+
+# |       | 110     | 1           | 93.6  | 93.48 | 93.5375 |
+# |       |         +-------------+-------+-------+---------+
+# |       |         | 2           | 94.36 | 93.9  | 94.1075 |
+# |       |         +-------------+-------+-------+---------+
+# |       |         | 5           | 94.39 | 94.04 | 94.225  |
+# +-------+---------+-------------+-------+-------+---------+
+# | 56    | None    | N/A         | 93.99 | 93.7  | 93.88   |
+# |       +---------+-------------+-------+-------+---------+
+# |       | 110     | 1           | 94.33 | 93.85 | 94.0675 |
+# |       |         +-------------+-------+-------+---------+
+# |       |         | 2           | 94.33 | 94.1  | 94.2425 |
+# |       |         +-------------+-------+-------+---------+
+# |       |         | 5           | 94.41 | 94.18 | 94.3125 |
+# +-------+---------+-------------+-------+-------+---------+
+# | 110   | None    | N/A         | 94.69 | 94.33 | 94.475  |
+# +-------+---------+-------------+-------+-------+---------+
diff --git a/examples/quantization/preact_resnet_cifar_quant_distill_tests.sh b/examples/quantization/preact_resnet_cifar_quant_distill_tests.sh
new file mode 100644
index 0000000..54e8361
--- /dev/null
+++ b/examples/quantization/preact_resnet_cifar_quant_distill_tests.sh
@@ -0,0 +1,72 @@
+#!/usr/bin/env bash
+
+# This script was used to generate the results included in
+# <distiller_root>/examples/quantization/fp32_baselines/preact_resnet_cifar_base_fp32.yaml
+# and
+# <distiller_root>/examples/quantization/quant_aware_train/preact_resnet_cifar_dorefa.yaml
+#
+# IMPORTANT: 
+#  * It is assumed that the script is executed from the following directory:
+#     <distiller_root>/examples/classifier_compression
+#    Some of the paths used are relative to this directory.
+#  * Some of the paths might need to be modified for your own system, e.g. 'dataset' and 't_ckpt'
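+#  * Example invocation (the arguments are the GPU ID and a name suffix for the runs; they default
+#    to '0' and 'try1' respectively):
+#      bash ../quantization/preact_resnet_cifar_quant_distill_tests.sh 0 try1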
+
+gpu=$1
+suffix=$2
+
+if [ -z "$gpu" ]; then
+  gpu=0
+fi
+
+if [ -z "$suffix" ]; then
+  suffix="try1"
+fi
+
+# Modify dataset path to your own
+dataset="$MACHINE_HOME/datasets/cifar10"
+base_fp32_sched="../quantization/fp32_baselines/preact_resnet_cifar_base_fp32.yaml"
+dorefa_w3_a8_sched="../quantization/quant_aware_train/preact_resnet_cifar_dorefa.yaml"
+base_out_dir="logs/presnet_cifar"
+
+base_args="${dataset} --lr 0.1 -p 50 -b 128 -j 1 --epochs 200 --wd 0.0002 --vs 0 --gpus ${gpu}"
+base_cmd="python compress_classifier.py"
+
+# No distillation
+for mode in "base_fp32" "dorefa_w3_a8"; do
+  for depth in 20 32 44 56 110; do
+    arch=preact_resnet${depth}_cifar
+    sched=${mode}_sched
+    out_dir="${base_out_dir}/presnet${depth}"
+    exp_name="presnet${depth}_${mode}_${suffix}"
+    set -x
+    ${base_cmd} -a ${arch} ${base_args} --compress=${!sched} -o ${out_dir} -n ${exp_name}
+    set +x
+  done
+done
+
+# With distillation
+for mode in "base_fp32" "dorefa_w3_a8"; do
+  sched=${mode}_sched
+  for s_depth in 20 32 44 56; do
+    for temp in 1 2 5; do
+      for dw in 0.7; do
+        for sw in 0.3; do
+          for t_depth in 32 44 56 110; do
+            if (( $t_depth > $s_depth )); then
+              s_arch=preact_resnet${s_depth}_cifar
+              t_arch=preact_resnet${t_depth}_cifar
+              # Change t_ckpt path to point to your pre-trained checkpoints
+              t_ckpt="../baselines/models/${t_arch}10/checkpoint_best.pth.tar"
+              out_dir="${base_out_dir}/presnet${s_depth}"
+              exp_name="presnet${s_depth}_${t_depth}_t_${temp}_dw_${dw}_${mode}_${suffix}"
+              kd_args="--kd-teacher ${t_arch} --kd-resume ${t_ckpt} --kd-temp ${temp} --kd-dw ${dw} --kd-sw ${sw}"
+              set -x
+              ${base_cmd} -a ${s_arch} ${base_args} --compress=${!sched} ${kd_args} -o ${out_dir} -n ${exp_name}
+              set +x
+            fi
+          done
+        done
+      done
+    done
+  done
+done
diff --git a/examples/quantization/quant_aware_train/preact_resnet_cifar_dorefa.yaml b/examples/quantization/quant_aware_train/preact_resnet_cifar_dorefa.yaml
index cb9e72d..536c0d0 100644
--- a/examples/quantization/quant_aware_train/preact_resnet_cifar_dorefa.yaml
+++ b/examples/quantization/quant_aware_train/preact_resnet_cifar_dorefa.yaml
@@ -21,13 +21,10 @@
 # Notes:
 #  * Replace 'preact_resnet44_cifar' with the required teacher model
 #  * To train baseline FP32 that can be used as teacher models, see preact_resnet_cifar_base_fp32.yaml
-#  * In this example we're using a distillation temperature of 5.0, and we give a weight of 0.7 to the distillation loss
-#    (that is - the loss of the student predictions vs. the teacher's soft targets).
-#  * Note we don't change any of the other training hyper-parameters
 #  * More details on knowledge distillation at: 
 #    https://nervanasystems.github.io/distiller/schedule/index.html#knowledge-distillation
 #
-# See some experimental results with the hyper-parameters shown above after the YAML schedule
+# See some experimental results below, after the YAML schedule
 
 quantizers:
   dorefa_quantizer:
@@ -68,49 +65,105 @@ policies:
       ending_epoch: 161
       frequency: 1
 
-# The results listed here are based on 4 runs in each configuration. All results are Top-1:
+# The results listed here are based on 4 runs in each configuration. All results are Top-1.
 #
-# +-------+--------------+-------------------------+-------------------------+
-# |       |              |           FP32          |       DoReFa w3-a8      |
-# +-------+--------------+-------------------------+-------------------------+
-# | Depth | Distillation | Best  | Worst | Average | Best  | Worst | Average |
-# |       | Teacher      |       |       |         |       |       |         |
-# +-------+--------------+-------+-------+---------+-------+-------+---------+
-# | 20    | None         | 92.4  | 91.91 | 92.2225 | 91.87 | 91.34 | 91.605  |
-# +-------+--------------+-------+-------+---------+-------+-------+---------+
-# | 20    | 32           | 92.85 | 92.68 | 92.7375 | 92.16 | 91.96 | 92.0725 |
-# +-------+--------------+-------+-------+---------+-------+-------+---------+
-# | 20    | 44           | 93.09 | 92.64 | 92.795  | 92.54 | 91.9  | 92.2225 |
-# +-------+--------------+-------+-------+---------+-------+-------+---------+
-# | 20    | 56           | 92.77 | 92.52 | 92.6475 | 92.53 | 91.92 | 92.15   |
-# +-------+--------------+-------+-------+---------+-------+-------+---------+
-# | 20    | 110          | 92.87 | 92.66 | 92.7725 | 92.12 | 92.01 | 92.0825 |
-# +-------+--------------+-------+-------+---------+-------+-------+---------+
-# |       |              |       |       |         |       |       |         |
-# +-------+--------------+-------+-------+---------+-------+-------+---------+
-# | 32    | None         | 93.31 | 92.93 | 93.13   | 92.66 | 92.33 | 92.485  |
-# +-------+--------------+-------+-------+---------+-------+-------+---------+
-# | 32    | 44           | 93.54 | 93.35 | 93.48   | 93.41 | 93.2  | 93.2875 |
-# +-------+--------------+-------+-------+---------+-------+-------+---------+
-# | 32    | 56           | 93.58 | 93.47 | 93.5125 | 93.18 | 92.76 | 92.93   |
-# +-------+--------------+-------+-------+---------+-------+-------+---------+
-# | 32    | 110          | 93.6  | 93.29 | 93.4575 | 93.36 | 92.99 | 93.175  |
-# +-------+--------------+-------+-------+---------+-------+-------+---------+
-# |       |              |       |       |         |       |       |         |
-# +-------+--------------+-------+-------+---------+-------+-------+---------+
-# | 44    | None         | 94.07 | 93.5  | 93.7425 | 93.08 | 92.66 | 92.8125 |
-# +-------+--------------+-------+-------+---------+-------+-------+---------+
-# | 44    | 56           | 94.08 | 93.58 | 93.875  | 93.46 | 93.28 | 93.3875 |
-# +-------+--------------+-------+-------+---------+-------+-------+---------+
-# | 44    | 110          | 94.13 | 93.75 | 93.95   | 93.45 | 93.24 | 93.3825 |
-# +-------+--------------+-------+-------+---------+-------+-------+---------+
-# |       |              |       |       |         |       |       |         |
-# +-------+--------------+-------+-------+---------+-------+-------+---------+
-# | 56    | None         | 94.2  | 93.52 | 93.8    | 93.44 | 92.91 | 93.0975 |
-# +-------+--------------+-------+-------+---------+-------+-------+---------+
-# | 56    | 110          | 94.47 | 94.0  | 94.16   | 93.83 | 93.56 | 93.7225 |
-# +-------+--------------+-------+-------+---------+-------+-------+---------+
-# |       |              |       |       |         |       |       |         |
-# +-------+--------------+-------+-------+---------+-------+-------+---------+
-# | 110   | None         | 94.66 | 94.42 | 94.54   | 93.53 | 93.24 | 93.395  |
-# +-------+--------------+-------+-------+---------+-------+-------+---------+
+# Notes:
+#  * In this example we test three distillation temperatures: 1.0, 2.0 and 5.0.
+#    This is controlled by the '--kd-temp' command line argument.
+#  * We give a weight of 0.7 to the distillation loss (that is - the loss of the student predictions vs.
+#    the teacher's soft targets) and a weight of 0.3 to the "standard" student vs. labels loss.
+#    This is done by passing the command line arguments: '--kd-dw 0.7 --kd-sw 0.3'
+#    (see the example command below)
+#  * We don't change any of the other training hyper-parameters
+# 
+# The shell script used to generate these results is provided at:
+# <distiller_root>/examples/quantization/preact_resnet_cifar_quant_distill_tests.sh
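+#
+# For illustration, a single DoReFa + distillation run from that script looks roughly as follows (run from
+# <distiller_root>/examples/classifier_compression; the dataset and teacher checkpoint paths are placeholders):
+#   python compress_classifier.py -a preact_resnet20_cifar <path_to_cifar10> --lr 0.1 -p 50 -b 128 -j 1 \
+#     --epochs 200 --wd 0.0002 --vs 0 --compress=../quantization/quant_aware_train/preact_resnet_cifar_dorefa.yaml \
+#     --kd-teacher preact_resnet44_cifar --kd-resume <path_to_teacher_checkpoint> --kd-temp 5 --kd-dw 0.7 --kd-sw 0.3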
+#
+# Some notable outcomes:
+#  * FP32 runs:
+#     * Distillation improves results in almost all of the tested configurations, in some cases by over 1.0%
+#       in the average result.
+#     * A higher distillation temperature yields better results - going from temperature 1.0 to 5.0, the
+#       average result improves by 0.24% to 0.9% (0.58% on average).
+#  * DoReFa runs:
+#     * In the settings we used, distillation is much less beneficial than it is for FP32.
+#     * As the student model's depth increases, the benefit from distillation decreases.
+#     * In contrast to FP32, a higher distillation temperature yields worse results, and by a much larger margin:
+#       going from temperature 1.0 to 5.0, the average result degrades by 8.1% to 11.6% (10.2% on average).
+#
+#     Note that we used the same loss weights in all runs, so additional hyper-parameter tuning might still
+#     yield better results and provide further insights.
+#
+# +-------+-----------------------+-------------------------+-------------------------+
+# |       | Distillation Settings |      Results: FP32      |  Results: DoReFa w3-a8  |
+# +-------+---------+-------------+-------+-------+---------+-------+-------+---------+
+# | Depth | Teacher | Temperature | Best  | Worst | Average | Best  | Worst | Average |
+# +-------+---------+-------------+-------+-------+---------+-------+-------+---------+
+# | 20    | None    | N/A         | 92.21 | 91.96 | 92.13   | 91.56 | 91.31 | 91.44   |
+# |       +---------+-------------+-------+-------+---------+-------+-------+---------+
+# |       | 32      | 1           | 92.58 | 92.18 | 92.375  | 91.84 | 91.36 | 91.6425 |
+# |       |         +-------------+-------+-------+---------+-------+-------+---------+
+# |       |         | 2           | 92.92 | 92.61 | 92.775  | 91.15 | 90.66 | 90.9225 |
+# |       |         +-------------+-------+-------+---------+-------+-------+---------+
+# |       |         | 5           | 93.2  | 93.11 | 93.1725 | 82.07 | 78.09 | 79.995  |
+# |       +---------+-------------+-------+-------+---------+-------+-------+---------+
+# |       | 44      | 1           | 92.77 | 92.23 | 92.4625 | 91.59 | 91.19 | 91.395  |
+# |       |         +-------------+-------+-------+---------+-------+-------+---------+
+# |       |         | 2           | 93.02 | 92.74 | 92.8625 | 91.09 | 90.68 | 90.8725 |
+# |       |         +-------------+-------+-------+---------+-------+-------+---------+
+# |       |         | 5           | 93.3  | 93.11 | 93.21   | 83.22 | 79.1  | 80.94   |
+# |       +---------+-------------+-------+-------+---------+-------+-------+---------+
+# |       | 56      | 1           | 92.58 | 92    | 92.2875 | 91.89 | 91.22 | 91.6475 |
+# |       |         +-------------+-------+-------+---------+-------+-------+---------+
+# |       |         | 2           | 92.93 | 92.51 | 92.73   | 90.7  | 90.36 | 90.565  |
+# |       |         +-------------+-------+-------+---------+-------+-------+---------+
+# |       |         | 5           | 93.11 | 92.82 | 92.96   | 82.78 | 78.51 | 80.855  |
+# |       +---------+-------------+-------+-------+---------+-------+-------+---------+
+# |       | 110     | 1           | 92.46 | 92.16 | 92.25   | 91.73 | 91.59 | 91.66   |
+# |       |         +-------------+-------+-------+---------+-------+-------+---------+
+# |       |         | 2           | 92.9  | 92.54 | 92.745  | 90.92 | 90.46 | 90.5875 |
+# |       |         +-------------+-------+-------+---------+-------+-------+---------+
+# |       |         | 5           | 93.33 | 93    | 93.1475 | 83.87 | 82.18 | 82.8325 |
+# +-------+---------+-------------+-------+-------+---------+-------+-------+---------+
+# | 32    | None    | N/A         | 93.19 | 92.72 | 92.9975 | 92.92 | 92.27 | 92.4925 |
+# |       +---------+-------------+-------+-------+---------+-------+-------+---------+
+# |       | 44      | 1           | 93.43 | 93.22 | 93.3475 | 92.82 | 92.41 | 92.555  |
+# |       |         +-------------+-------+-------+---------+-------+-------+---------+
+# |       |         | 2           | 94.01 | 93.37 | 93.605  | 92.09 | 91.38 | 91.8375 |
+# |       |         +-------------+-------+-------+---------+-------+-------+---------+
+# |       |         | 5           | 93.9  | 93.42 | 93.68   | 82.78 | 77.99 | 80.9475 |
+# |       +---------+-------------+-------+-------+---------+-------+-------+---------+
+# |       | 56      | 1           | 93.34 | 92.93 | 93.16   | 92.86 | 92.26 | 92.5575 |
+# |       |         +-------------+-------+-------+---------+-------+-------+---------+
+# |       |         | 2           | 93.99 | 93.29 | 93.5925 | 92.02 | 91.1  | 91.6075 |
+# |       |         +-------------+-------+-------+---------+-------+-------+---------+
+# |       |         | 5           | 93.9  | 93.62 | 93.78   | 82.64 | 80.87 | 82.05   |
+# |       +---------+-------------+-------+-------+---------+-------+-------+---------+
+# |       | 110     | 1           | 93.43 | 93.18 | 93.335  | 92.92 | 92.26 | 92.545  |
+# |       |         +-------------+-------+-------+---------+-------+-------+---------+
+# |       |         | 2           | 93.89 | 93.5  | 93.6675 | 92.18 | 91.67 | 91.8925 |
+# |       |         +-------------+-------+-------+---------+-------+-------+---------+
+# |       |         | 5           | 93.89 | 93.63 | 93.795  | 83.03 | 80.13 | 82.25   |
+# +-------+---------+-------------+-------+-------+---------+-------+-------+---------+
+# | 44    | None    | N/A         | 93.81 | 93.23 | 93.58   | 93.07 | 92.67 | 92.8475 |
+# |       +---------+-------------+-------+-------+---------+-------+-------+---------+
+# |       | 56      | 1           | 93.97 | 93.57 | 93.7775 | 92.98 | 92.7  | 92.84   |
+# |       |         +-------------+-------+-------+---------+-------+-------+---------+
+# |       |         | 2           | 94.11 | 93.89 | 94.01   | 92.55 | 91.53 | 91.96   |
+# |       |         +-------------+-------+-------+---------+-------+-------+---------+
+# |       |         | 5           | 94.28 | 93.97 | 94.1225 | 83.56 | 81.83 | 82.6375 |
+# |       +---------+-------------+-------+-------+---------+-------+-------+---------+
+# |       | 110     | 1           | 93.6  | 93.48 | 93.5375 | 92.88 | 92.65 | 92.7825 |
+# |       |         +-------------+-------+-------+---------+-------+-------+---------+
+# |       |         | 2           | 94.36 | 93.9  | 94.1075 | 92.12 | 91.33 | 91.785  |
+# |       |         +-------------+-------+-------+---------+-------+-------+---------+
+# |       |         | 5           | 94.39 | 94.04 | 94.225  | 84.33 | 81.95 | 83.51   |
+# +-------+---------+-------------+-------+-------+---------+-------+-------+---------+
+# | 56    | None    | N/A         | 93.99 | 93.7  | 93.88   | 93.2  | 92.89 | 93.035  |
+# |       +---------+-------------+-------+-------+---------+-------+-------+---------+
+# |       | 110     | 1           | 94.33 | 93.85 | 94.0675 | 93.23 | 92.81 | 93.005  |
+# |       |         +-------------+-------+-------+---------+-------+-------+---------+
+# |       |         | 2           | 94.33 | 94.1  | 94.2425 | 92.26 | 92.04 | 92.145  |
+# |       |         +-------------+-------+-------+---------+-------+-------+---------+
+# |       |         | 5           | 94.41 | 94.18 | 94.3125 | 85.36 | 84.59 | 84.9125 |
+# +-------+---------+-------------+-------+-------+---------+-------+-------+---------+
+# | 110   | None    | N/A         | 94.69 | 94.33 | 94.475  | 93.56 | 92.69 | 93.1175 |
+# +-------+---------+-------------+-------+-------+---------+-------+-------+---------+
-- 
GitLab