diff --git a/README.md b/README.md index 625010a7aef75eff888f02633e89094eef3a4114..ac520dab73f56e6346c2583c6ecd96d224b3205e 100755 --- a/README.md +++ b/README.md @@ -174,9 +174,9 @@ If you do not use CUDA 10.1 in your environment, please refer to [PyTorch websit Distiller comes with sample applications and tutorials covering a range of model types: -| Model Type | Sparsity | Post-training quantization | Quantization-aware training | Auto Compression (AMC) | -|------------|:--------:|:--------------------------:|:---------------------------:|:----------------------:| -| [Image classification](https://github.com/NervanaSystems/distiller/tree/master/examples/classifier_compression) | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | +| Model Type | Sparsity | Post-training quantization | Quantization-aware training | Auto Compression (AMC) | Knowledge Distillation | +|------------|:--------:|:--------------------------:|:---------------------------:|:----------------------:|:--------:| +| [Image classification](https://github.com/NervanaSystems/distiller/tree/master/examples/classifier_compression) | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | | [Word-level language model](https://github.com/NervanaSystems/distiller/tree/master/examples/word_language_model)| :white_check_mark: | :white_check_mark: | | | | [Translation (GNMT)](https://github.com/NervanaSystems/distiller/tree/master/examples/GNMT) | | :white_check_mark: | | | | [Recommendation System (NCF)](https://github.com/NervanaSystems/distiller/tree/master/examples/ncf) | | :white_check_mark: | | | diff --git a/distiller/knowledge_distillation.py b/distiller/knowledge_distillation.py index acba5ceb504cd5526c60997138f29a9bbb2999e3..48f313ab6d07ed62d62f454f980fc127b5f0c285 100644 --- a/distiller/knowledge_distillation.py +++ b/distiller/knowledge_distillation.py @@ -104,7 +104,7 @@ class KnowledgeDistillationPolicy(ScheduledTrainingPolicy): def forward(self, *inputs): """ - Performs forward propagation through both student and teached models and caches the logits. + Performs forward propagation through both student and teacher models and caches the logits. This function MUST be used instead of calling the student model directly. Returns: @@ -120,7 +120,7 @@ class KnowledgeDistillationPolicy(ScheduledTrainingPolicy): self.last_teacher_logits = self.teacher(*inputs) out = self.student(*inputs) - self.last_students_logits = out.new_tensor(out, requires_grad=True) + self.last_students_logits = out.clone() return out @@ -150,14 +150,20 @@ class KnowledgeDistillationPolicy(ScheduledTrainingPolicy): # soft_targets = F.softmax(self.cached_teacher_logits[minibatch_id] / self.temperature) soft_targets = F.softmax(self.last_teacher_logits / self.temperature, dim=1) - # The averaging used in PyTorch KL Div implementation is wrong, so we work around as suggested in - # https://pytorch.org/docs/stable/nn.html#kldivloss - # (Also see https://github.com/pytorch/pytorch/issues/6622, https://github.com/pytorch/pytorch/issues/2259) - distillation_loss = F.kl_div(soft_log_probs, soft_targets.detach(), size_average=False) / soft_targets.shape[0] + distillation_loss = F.kl_div(soft_log_probs, soft_targets.detach(), reduction='batchmean') + + # According to [1]: + # "Since the magnitudes of the gradients produced by the soft targets scale as 1/(T^2) it is important + # to multiply them by T^2 when using both hard and soft targets. This ensures that the relative contributions + # of the hard and soft targets remain roughly unchanged if the temperature used for distillation is changed + # while experimenting with meta-parameters." + distillation_loss_scaled = distillation_loss * self.temperature ** 2 # The loss passed to the callback is the student's loss vs. the true labels, so we can use it directly, no # need to calculate again + overall_loss = self.loss_wts.student * loss + self.loss_wts.distill * distillation_loss_scaled - overall_loss = self.loss_wts.student * loss + self.loss_wts.distill * distillation_loss + # For logging purposes, we return the un-scaled distillation loss so it's comparable between runs with + # different temperatures return PolicyLoss(overall_loss, [LossComponent('Distill Loss', distillation_loss)]) diff --git a/examples/README.md b/examples/README.md index 8d8bbe1363ba954fb2084b82a8eb9f4a1e425078..9cf5a2caced168e9489369e007d426b3f27115db 100644 --- a/examples/README.md +++ b/examples/README.md @@ -2,12 +2,12 @@ Distiller comes with sample applications and tutorials covering a range of model types: -| Model Type | Sparsity | Post-training quantization | Quantization-aware training | Auto Compression (AMC) | In Directories | -|------------|:--------:|:--------------------------:|:---------------------------:|:----------------------:|----------------| -| **Image classification** | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | [classifier_compression](https://github.com/NervanaSystems/distiller/tree/master/examples/classifier_compression), [auto_compression/amc](https://github.com/NervanaSystems/distiller/tree/master/examples/auto_compression/amc) | -| **Word-level language model** | :white_check_mark: | :white_check_mark: | | |[word_language_model](https://github.com/NervanaSystems/distiller/tree/master/examples/word_language_model) | -| **Translation (GNMT)** | | :white_check_mark: | | | [GNMT](https://github.com/NervanaSystems/distiller/tree/master/examples/GNMT) | -| **Recommendation System (NCF)** | | :white_check_mark: | | | [ncf](https://github.com/NervanaSystems/distiller/tree/master/examples/ncf) | -| **Object Detection** | :white_check_mark: | | | | [object_detection_compression](https://github.com/NervanaSystems/distiller/tree/master/examples/object_detection_compression) | +| Model Type | Sparsity | Post-training quantization | Quantization-aware training | Auto Compression (AMC) | Knowledge Distillation | In Directories | +|------------|:--------:|:--------------------------:|:---------------------------:|:----------------------:|:----------------------:|----------------| +| **Image classification** | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | [classifier_compression](https://github.com/NervanaSystems/distiller/tree/master/examples/classifier_compression), [auto_compression/amc](https://github.com/NervanaSystems/distiller/tree/master/examples/auto_compression/amc) | +| **Word-level language model** | :white_check_mark: | :white_check_mark: | | | |[word_language_model](https://github.com/NervanaSystems/distiller/tree/master/examples/word_language_model) | +| **Translation (GNMT)** | | :white_check_mark: | | | | [GNMT](https://github.com/NervanaSystems/distiller/tree/master/examples/GNMT) | +| **Recommendation System (NCF)** | | :white_check_mark: | | | | [ncf](https://github.com/NervanaSystems/distiller/tree/master/examples/ncf) | +| **Object Detection** | :white_check_mark: | | | | | [object_detection_compression](https://github.com/NervanaSystems/distiller/tree/master/examples/object_detection_compression) | The directories specified in the table contain the code implementing each of the modalities. The rest of the sub-directories in this directory are each dedicated to a specific compression method, and contain YAML schedules and other files that can be used with the sample applications. Most of these files contain details on the results obtained and how to re-produce them. diff --git a/examples/quantization/fp32_baselines/preact_resnet_cifar_base_fp32.yaml b/examples/quantization/fp32_baselines/preact_resnet_cifar_base_fp32.yaml index f0d13439b34f2e173b67fc292553d4f03e22e10e..34a1517471fbf6f50e82e7dd3368390091e02bf2 100644 --- a/examples/quantization/fp32_baselines/preact_resnet_cifar_base_fp32.yaml +++ b/examples/quantization/fp32_baselines/preact_resnet_cifar_base_fp32.yaml @@ -16,14 +16,12 @@ # # Notes: # * Replace 'preact_resnet44_cifar' with the required teacher model -# * Use the command line above to train teacher models, which can be used as checkpoints passed to '--kd-resume' -# * In this example we're using a distillation temperature of 5.0, and we give a weight of 0.7 to the distillation loss -# (that is - the loss of the student predictions vs. the teacher's soft targets). -# * Note we don't change any of the other training hyper-parameters +# * Use the command line at the top to train teacher models, which can be used as checkpoints passed to '--kd-resume' +# (the shell script we point to below can be used to train baseline teacher models) # * More details on knowledge distillation at: # https://nervanasystems.github.io/distiller/schedule/index.html#knowledge-distillation # -# See some experimental results with the hyper-parameters shown above after the YAML schedule +# See some experimental results below after the YAML schedule lr_schedulers: training_lr: @@ -38,49 +36,94 @@ policies: ending_epoch: 200 frequency: 1 -# The results listed here are based on 4 runs in each configuration. All results are Top-1: +# The results listed here are based on 4 runs in each configuration. All results are Top-1. +# +# Notes: +# * In this example we're testing three distillation temperatures: 1.0, 2.0, 5.0 +# This is controlled by the '--kd-temp' command line argument. +# * We give a weight of 0.7 to the distillation loss (that is - the loss of the student predictions vs. +# the teacher's soft targets) and 0.3 weight to the "standard" studen vs. labels loss. +# This is done by passing the command line arguments: '--kd-dw 0.7 --kd-sw 0.3' +# * We don't change any of the other training hyper-parameters +# +# The shell script used to generate these results is provided at: +# <distiller_root>/examples/quantization/preact_resnet_cifar_quant_distill_tests.sh +# +# (additional results and some limited analysis is included in the preact_resnet_cifar_dorefa.yaml file, +# located in the same directory as this file) # -# +-------+--------------+-------------------------+ -# | | | FP32 | -# +-------+--------------+-------------------------+ -# | Depth | Distillation | Best | Worst | Average | -# | | Teacher | | | | -# +-------+--------------+-------+-------+---------+ -# | 20 | None | 92.4 | 91.91 | 92.2225 | -# +-------+--------------+-------+-------+---------+ -# | 20 | 32 | 92.85 | 92.68 | 92.7375 | -# +-------+--------------+-------+-------+---------+ -# | 20 | 44 | 93.09 | 92.64 | 92.795 | -# +-------+--------------+-------+-------+---------+ -# | 20 | 56 | 92.77 | 92.52 | 92.6475 | -# +-------+--------------+-------+-------+---------+ -# | 20 | 110 | 92.87 | 92.66 | 92.7725 | -# +-------+--------------+-------+-------+---------+ -# | | | | | | -# +-------+--------------+-------+-------+---------+ -# | 32 | None | 93.31 | 92.93 | 93.13 | -# +-------+--------------+-------+-------+---------+ -# | 32 | 44 | 93.54 | 93.35 | 93.48 | -# +-------+--------------+-------+-------+---------+ -# | 32 | 56 | 93.58 | 93.47 | 93.5125 | -# +-------+--------------+-------+-------+---------+ -# | 32 | 110 | 93.6 | 93.29 | 93.4575 | -# +-------+--------------+-------+-------+---------+ -# | | | | | | -# +-------+--------------+-------+-------+---------+ -# | 44 | None | 94.07 | 93.5 | 93.7425 | -# +-------+--------------+-------+-------+---------+ -# | 44 | 56 | 94.08 | 93.58 | 93.875 | -# +-------+--------------+-------+-------+---------+ -# | 44 | 110 | 94.13 | 93.75 | 93.95 | -# +-------+--------------+-------+-------+---------+ -# | | | | | | -# +-------+--------------+-------+-------+---------+ -# | 56 | None | 94.2 | 93.52 | 93.8 | -# +-------+--------------+-------+-------+---------+ -# | 56 | 110 | 94.47 | 94.0 | 94.16 | -# +-------+--------------+-------+-------+---------+ -# | | | | | | -# +-------+--------------+-------+-------+---------+ -# | 110 | None | 94.66 | 94.42 | 94.54 | -# +-------+--------------+-------+-------+---------+ +# +-------+-----------------------+-------------------------+ +# | | Distillation Settings | Results: FP32 | +# +-------+---------+-------------+-------+-------+---------+ +# | Depth | Teacher | Temperature | Best | Worst | Average | +# +-------+---------+-------------+-------+-------+---------+ +# | 20 | None | N/A | 92.21 | 91.96 | 92.13 | +# | +---------+-------------+-------+-------+---------+ +# | | 32 | 1 | 92.58 | 92.18 | 92.375 | +# | | +-------------+-------+-------+---------+ +# | | | 2 | 92.92 | 92.61 | 92.775 | +# | | +-------------+-------+-------+---------+ +# | | | 5 | 93.2 | 93.11 | 93.1725 | +# | +---------+-------------+-------+-------+---------+ +# | | 44 | 1 | 92.77 | 92.23 | 92.4625 | +# | | +-------------+-------+-------+---------+ +# | | | 2 | 93.02 | 92.74 | 92.8625 | +# | | +-------------+-------+-------+---------+ +# | | | 5 | 93.3 | 93.11 | 93.21 | +# | +---------+-------------+-------+-------+---------+ +# | | 56 | 1 | 92.58 | 92 | 92.2875 | +# | | +-------------+-------+-------+---------+ +# | | | 2 | 92.93 | 92.51 | 92.73 | +# | | +-------------+-------+-------+---------+ +# | | | 5 | 93.11 | 92.82 | 92.96 | +# | +---------+-------------+-------+-------+---------+ +# | | 110 | 1 | 92.46 | 92.16 | 92.25 | +# | | +-------------+-------+-------+---------+ +# | | | 2 | 92.9 | 92.54 | 92.745 | +# | | +-------------+-------+-------+---------+ +# | | | 5 | 93.33 | 93 | 93.1475 | +# +-------+---------+-------------+-------+-------+---------+ +# | 32 | None | N/A | 93.19 | 92.72 | 92.9975 | +# | +---------+-------------+-------+-------+---------+ +# | | 44 | 1 | 93.43 | 93.22 | 93.3475 | +# | | +-------------+-------+-------+---------+ +# | | | 2 | 94.01 | 93.37 | 93.605 | +# | | +-------------+-------+-------+---------+ +# | | | 5 | 93.9 | 93.42 | 93.68 | +# | +---------+-------------+-------+-------+---------+ +# | | 56 | 1 | 93.34 | 92.93 | 93.16 | +# | | +-------------+-------+-------+---------+ +# | | | 2 | 93.99 | 93.29 | 93.5925 | +# | | +-------------+-------+-------+---------+ +# | | | 5 | 93.9 | 93.62 | 93.78 | +# | +---------+-------------+-------+-------+---------+ +# | | 110 | 1 | 93.43 | 93.18 | 93.335 | +# | | +-------------+-------+-------+---------+ +# | | | 2 | 93.89 | 93.5 | 93.6675 | +# | | +-------------+-------+-------+---------+ +# | | | 5 | 93.89 | 93.63 | 93.795 | +# +-------+---------+-------------+-------+-------+---------+ +# | 44 | None | N/A | 93.81 | 93.23 | 93.58 | +# | +---------+-------------+-------+-------+---------+ +# | | 56 | 1 | 93.97 | 93.57 | 93.7775 | +# | | +-------------+-------+-------+---------+ +# | | | 2 | 94.11 | 93.89 | 94.01 | +# | | +-------------+-------+-------+---------+ +# | | | 5 | 94.28 | 93.97 | 94.1225 | +# | +---------+-------------+-------+-------+---------+ +# | | 110 | 1 | 93.6 | 93.48 | 93.5375 | +# | | +-------------+-------+-------+---------+ +# | | | 2 | 94.36 | 93.9 | 94.1075 | +# | | +-------------+-------+-------+---------+ +# | | | 5 | 94.39 | 94.04 | 94.225 | +# +-------+---------+-------------+-------+-------+---------+ +# | 56 | None | N/A | 93.99 | 93.7 | 93.88 | +# | +---------+-------------+-------+-------+---------+ +# | | 110 | 1 | 94.33 | 93.85 | 94.0675 | +# | | +-------------+-------+-------+---------+ +# | | | 2 | 94.33 | 94.1 | 94.2425 | +# | | +-------------+-------+-------+---------+ +# | | | 5 | 94.41 | 94.18 | 94.3125 | +# +-------+---------+-------------+-------+-------+---------+ +# | 110 | None | N/A | 94.69 | 94.33 | 94.475 | +# +-------+---------+-------------+-------+-------+---------+ diff --git a/examples/quantization/preact_resnet_cifar_quant_distill_tests.sh b/examples/quantization/preact_resnet_cifar_quant_distill_tests.sh new file mode 100644 index 0000000000000000000000000000000000000000..54e8361fb748380cde1ee582f633c1eba2b16ebb --- /dev/null +++ b/examples/quantization/preact_resnet_cifar_quant_distill_tests.sh @@ -0,0 +1,72 @@ +#!/usr/bin/env bash + +# This script was used to generate the results included in +# <distiller_root>/examples/quantization/fp32_baselines/preact_resnet_cifar_base_fp32.yaml +# and +# <distiller_root>/examples/quantization/quant_aware_train/preact_resnet_cifar_dorefa.yaml +# +# IMPORTANT: +# * It is assumed that the script is executed from the following directory: +# <distiller_root>/examples/classifier_compression +# Some of the paths used are relative to this directory. +# * Some of the paths might need to be modified for your own system, e.g. 'dataset' and 't_ckpt' + +gpu=$1 +suffix=$2 + +if [ -z $gpu ]; then + gpu=0 +fi + +if [ -z $suffix ]; then + suffix="try1" +fi + +# Modify dataset path to your own +dataset="$MACHINE_HOME/datasets/cifar10" +base_fp32_sched="../quantization/fp32_baselines/preact_resnet_cifar_base_fp32.yaml" +dorefa_w3_a8_sched="../quantization/quant_aware_train/preact_resnet_cifar_dorefa.yaml" +base_out_dir="logs/presnet_cifar" + +base_args="${dataset} --lr 0.1 -p 50 -b 128 -j 1 --epochs 200 --wd 0.0002 --vs 0 --gpus ${gpu}" +base_cmd="python compress_classifier.py" + +# No distillation +for mode in "base_fp32" "dorefa_w3_a8"; do + for depth in 20 32 44 56 110; do + arch=preact_resnet${depth}_cifar + sched=${mode}_sched + out_dir="${base_out_dir}/presnet${depth}" + exp_name="presnet${depth}_${mode}_${suffix}" + set -x + ${base_cmd} -a ${arch} ${base_args} --compress=${!sched} -o ${out_dir} -n ${exp_name} + set +x + done +done + +# With distillation +for mode in "base_fp32" "dorefa_w3_a8"; do + sched=${mode}_sched + for s_depth in 20 32 44 56; do + for temp in 1 2 5; do + for dw in 0.7; do + for sw in 0.3; do + for t_depth in 32 44 56 110; do + if (( $t_depth > $s_depth )); then + s_arch=preact_resnet${s_depth}_cifar + t_arch=preact_resnet${t_depth}_cifar + # Change t_ckpt path to point to your pre-trained checkpoints + t_ckpt="../baselines/models/${t_arch}10/checkpoint_best.pth.tar" + out_dir="${base_out_dir}/presnet${s_depth}" + exp_name="presnet${s_depth}_${t_depth}_t_${temp}_dw_${dw}_${mode}_${suffix}" + kd_args="--kd-teacher ${t_arch} --kd-resume ${t_ckpt} --kd-temp ${temp} --kd-dw ${dw} --kd-sw ${sw}" + set -x + ${base_cmd} -a ${s_arch} ${base_args} --compress=${!sched} ${kd_args} -o ${out_dir} -n ${exp_name} + set +x + fi + done + done + done + done + done +done diff --git a/examples/quantization/quant_aware_train/preact_resnet_cifar_dorefa.yaml b/examples/quantization/quant_aware_train/preact_resnet_cifar_dorefa.yaml index cb9e72de7994b782057b33134f4e9be8dd6a3fc2..536c0d0af54828a1fb49d5c46a13776deb52658c 100644 --- a/examples/quantization/quant_aware_train/preact_resnet_cifar_dorefa.yaml +++ b/examples/quantization/quant_aware_train/preact_resnet_cifar_dorefa.yaml @@ -21,13 +21,10 @@ # Notes: # * Replace 'preact_resnet44_cifar' with the required teacher model # * To train baseline FP32 that can be used as teacher models, see preact_resnet_cifar_base_fp32.yaml -# * In this example we're using a distillation temperature of 5.0, and we give a weight of 0.7 to the distillation loss -# (that is - the loss of the student predictions vs. the teacher's soft targets). -# * Note we don't change any of the other training hyper-parameters # * More details on knowledge distillation at: # https://nervanasystems.github.io/distiller/schedule/index.html#knowledge-distillation # -# See some experimental results with the hyper-parameters shown above after the YAML schedule +# See some experimental results below after the YAML schedule quantizers: dorefa_quantizer: @@ -68,49 +65,105 @@ policies: ending_epoch: 161 frequency: 1 -# The results listed here are based on 4 runs in each configuration. All results are Top-1: +# The results listed here are based on 4 runs in each configuration. All results are Top-1. # -# +-------+--------------+-------------------------+-------------------------+ -# | | | FP32 | DoReFa w3-a8 | -# +-------+--------------+-------------------------+-------------------------+ -# | Depth | Distillation | Best | Worst | Average | Best | Worst | Average | -# | | Teacher | | | | | | | -# +-------+--------------+-------+-------+---------+-------+-------+---------+ -# | 20 | None | 92.4 | 91.91 | 92.2225 | 91.87 | 91.34 | 91.605 | -# +-------+--------------+-------+-------+---------+-------+-------+---------+ -# | 20 | 32 | 92.85 | 92.68 | 92.7375 | 92.16 | 91.96 | 92.0725 | -# +-------+--------------+-------+-------+---------+-------+-------+---------+ -# | 20 | 44 | 93.09 | 92.64 | 92.795 | 92.54 | 91.9 | 92.2225 | -# +-------+--------------+-------+-------+---------+-------+-------+---------+ -# | 20 | 56 | 92.77 | 92.52 | 92.6475 | 92.53 | 91.92 | 92.15 | -# +-------+--------------+-------+-------+---------+-------+-------+---------+ -# | 20 | 110 | 92.87 | 92.66 | 92.7725 | 92.12 | 92.01 | 92.0825 | -# +-------+--------------+-------+-------+---------+-------+-------+---------+ -# | | | | | | | | | -# +-------+--------------+-------+-------+---------+-------+-------+---------+ -# | 32 | None | 93.31 | 92.93 | 93.13 | 92.66 | 92.33 | 92.485 | -# +-------+--------------+-------+-------+---------+-------+-------+---------+ -# | 32 | 44 | 93.54 | 93.35 | 93.48 | 93.41 | 93.2 | 93.2875 | -# +-------+--------------+-------+-------+---------+-------+-------+---------+ -# | 32 | 56 | 93.58 | 93.47 | 93.5125 | 93.18 | 92.76 | 92.93 | -# +-------+--------------+-------+-------+---------+-------+-------+---------+ -# | 32 | 110 | 93.6 | 93.29 | 93.4575 | 93.36 | 92.99 | 93.175 | -# +-------+--------------+-------+-------+---------+-------+-------+---------+ -# | | | | | | | | | -# +-------+--------------+-------+-------+---------+-------+-------+---------+ -# | 44 | None | 94.07 | 93.5 | 93.7425 | 93.08 | 92.66 | 92.8125 | -# +-------+--------------+-------+-------+---------+-------+-------+---------+ -# | 44 | 56 | 94.08 | 93.58 | 93.875 | 93.46 | 93.28 | 93.3875 | -# +-------+--------------+-------+-------+---------+-------+-------+---------+ -# | 44 | 110 | 94.13 | 93.75 | 93.95 | 93.45 | 93.24 | 93.3825 | -# +-------+--------------+-------+-------+---------+-------+-------+---------+ -# | | | | | | | | | -# +-------+--------------+-------+-------+---------+-------+-------+---------+ -# | 56 | None | 94.2 | 93.52 | 93.8 | 93.44 | 92.91 | 93.0975 | -# +-------+--------------+-------+-------+---------+-------+-------+---------+ -# | 56 | 110 | 94.47 | 94.0 | 94.16 | 93.83 | 93.56 | 93.7225 | -# +-------+--------------+-------+-------+---------+-------+-------+---------+ -# | | | | | | | | | -# +-------+--------------+-------+-------+---------+-------+-------+---------+ -# | 110 | None | 94.66 | 94.42 | 94.54 | 93.53 | 93.24 | 93.395 | -# +-------+--------------+-------+-------+---------+-------+-------+---------+ +# Notes: +# * In this example we're testing three distillation temperatures: 1.0, 2.0, 5.0 +# This is controlled by the '--kd-temp' command line argument. +# * We give a weight of 0.7 to the distillation loss (that is - the loss of the student predictions vs. +# the teacher's soft targets) and 0.3 weight to the "standard" studen vs. labels loss. +# This is done by passing the command line arguments: '--kd-dw 0.7 --kd-sw 0.3' +# * We don't change any of the other training hyper-parameters +# +# The shell script used to generate these results is provided at: +# <distiller_root>/examples/quantization/preact_resnet_cifar_quant_distill_tests.sh +# +# Some notable outcomes: +# * FP32 runs: +# * Distillation improves results in almost all of the tested configurations, in some cases over 1.0% improvement +# in the average result. +# * Higher distillation temperature yields better results - between 0.24% to 0.9% (0.58% on average). +# * DoReFa runs: +# * Distillation is much less useful compared to FP32, in the settings we used. +# * As student model depth increases, the benefit from distillation decreases. +# * Opposite from FP32, higher distillation temperature yields worse results, and on a much bigger scale when +# comparing temperatures 1.0 and 5.0. The degradation is between 8.1% and 11.6% (10.2% on average). +# +# Note that we used the same weights on the losses in all runs. So additional hyper-parameter tuning might +# still yield better results and provide further insights. +# +# +-------+-----------------------+-------------------------+-------------------------+ +# | | Distillation Settings | Results: FP32 | Results: DoReFa w3-a8 | +# +-------+---------+-------------+-------+-------+---------+-------+-------+---------+ +# | Depth | Teacher | Temperature | Best | Worst | Average | Best | Worst | Average | +# +-------+---------+-------------+-------+-------+---------+-------+-------+---------+ +# | 20 | None | N/A | 92.21 | 91.96 | 92.13 | 91.56 | 91.31 | 91.44 | +# | +---------+-------------+-------+-------+---------+-------+-------+---------+ +# | | 32 | 1 | 92.58 | 92.18 | 92.375 | 91.84 | 91.36 | 91.6425 | +# | | +-------------+-------+-------+---------+-------+-------+---------+ +# | | | 2 | 92.92 | 92.61 | 92.775 | 91.15 | 90.66 | 90.9225 | +# | | +-------------+-------+-------+---------+-------+-------+---------+ +# | | | 5 | 93.2 | 93.11 | 93.1725 | 82.07 | 78.09 | 79.995 | +# | +---------+-------------+-------+-------+---------+-------+-------+---------+ +# | | 44 | 1 | 92.77 | 92.23 | 92.4625 | 91.59 | 91.19 | 91.395 | +# | | +-------------+-------+-------+---------+-------+-------+---------+ +# | | | 2 | 93.02 | 92.74 | 92.8625 | 91.09 | 90.68 | 90.8725 | +# | | +-------------+-------+-------+---------+-------+-------+---------+ +# | | | 5 | 93.3 | 93.11 | 93.21 | 83.22 | 79.1 | 80.94 | +# | +---------+-------------+-------+-------+---------+-------+-------+---------+ +# | | 56 | 1 | 92.58 | 92 | 92.2875 | 91.89 | 91.22 | 91.6475 | +# | | +-------------+-------+-------+---------+-------+-------+---------+ +# | | | 2 | 92.93 | 92.51 | 92.73 | 90.7 | 90.36 | 90.565 | +# | | +-------------+-------+-------+---------+-------+-------+---------+ +# | | | 5 | 93.11 | 92.82 | 92.96 | 82.78 | 78.51 | 80.855 | +# | +---------+-------------+-------+-------+---------+-------+-------+---------+ +# | | 110 | 1 | 92.46 | 92.16 | 92.25 | 91.73 | 91.59 | 91.66 | +# | | +-------------+-------+-------+---------+-------+-------+---------+ +# | | | 2 | 92.9 | 92.54 | 92.745 | 90.92 | 90.46 | 90.5875 | +# | | +-------------+-------+-------+---------+-------+-------+---------+ +# | | | 5 | 93.33 | 93 | 93.1475 | 83.87 | 82.18 | 82.8325 | +# +-------+---------+-------------+-------+-------+---------+-------+-------+---------+ +# | 32 | None | N/A | 93.19 | 92.72 | 92.9975 | 92.92 | 92.27 | 92.4925 | +# | +---------+-------------+-------+-------+---------+-------+-------+---------+ +# | | 44 | 1 | 93.43 | 93.22 | 93.3475 | 92.82 | 92.41 | 92.555 | +# | | +-------------+-------+-------+---------+-------+-------+---------+ +# | | | 2 | 94.01 | 93.37 | 93.605 | 92.09 | 91.38 | 91.8375 | +# | | +-------------+-------+-------+---------+-------+-------+---------+ +# | | | 5 | 93.9 | 93.42 | 93.68 | 82.78 | 77.99 | 80.9475 | +# | +---------+-------------+-------+-------+---------+-------+-------+---------+ +# | | 56 | 1 | 93.34 | 92.93 | 93.16 | 92.86 | 92.26 | 92.5575 | +# | | +-------------+-------+-------+---------+-------+-------+---------+ +# | | | 2 | 93.99 | 93.29 | 93.5925 | 92.02 | 91.1 | 91.6075 | +# | | +-------------+-------+-------+---------+-------+-------+---------+ +# | | | 5 | 93.9 | 93.62 | 93.78 | 82.64 | 80.87 | 82.05 | +# | +---------+-------------+-------+-------+---------+-------+-------+---------+ +# | | 110 | 1 | 93.43 | 93.18 | 93.335 | 92.92 | 92.26 | 92.545 | +# | | +-------------+-------+-------+---------+-------+-------+---------+ +# | | | 2 | 93.89 | 93.5 | 93.6675 | 92.18 | 91.67 | 91.8925 | +# | | +-------------+-------+-------+---------+-------+-------+---------+ +# | | | 5 | 93.89 | 93.63 | 93.795 | 83.03 | 80.13 | 82.25 | +# +-------+---------+-------------+-------+-------+---------+-------+-------+---------+ +# | 44 | None | N/A | 93.81 | 93.23 | 93.58 | 93.07 | 92.67 | 92.8475 | +# | +---------+-------------+-------+-------+---------+-------+-------+---------+ +# | | 56 | 1 | 93.97 | 93.57 | 93.7775 | 92.98 | 92.7 | 92.84 | +# | | +-------------+-------+-------+---------+-------+-------+---------+ +# | | | 2 | 94.11 | 93.89 | 94.01 | 92.55 | 91.53 | 91.96 | +# | | +-------------+-------+-------+---------+-------+-------+---------+ +# | | | 5 | 94.28 | 93.97 | 94.1225 | 83.56 | 81.83 | 82.6375 | +# | +---------+-------------+-------+-------+---------+-------+-------+---------+ +# | | 110 | 1 | 93.6 | 93.48 | 93.5375 | 92.88 | 92.65 | 92.7825 | +# | | +-------------+-------+-------+---------+-------+-------+---------+ +# | | | 2 | 94.36 | 93.9 | 94.1075 | 92.12 | 91.33 | 91.785 | +# | | +-------------+-------+-------+---------+-------+-------+---------+ +# | | | 5 | 94.39 | 94.04 | 94.225 | 84.33 | 81.95 | 83.51 | +# +-------+---------+-------------+-------+-------+---------+-------+-------+---------+ +# | 56 | None | N/A | 93.99 | 93.7 | 93.88 | 93.2 | 92.89 | 93.035 | +# | +---------+-------------+-------+-------+---------+-------+-------+---------+ +# | | 110 | 1 | 94.33 | 93.85 | 94.0675 | 93.23 | 92.81 | 93.005 | +# | | +-------------+-------+-------+---------+-------+-------+---------+ +# | | | 2 | 94.33 | 94.1 | 94.2425 | 92.26 | 92.04 | 92.145 | +# | | +-------------+-------+-------+---------+-------+-------+---------+ +# | | | 5 | 94.41 | 94.18 | 94.3125 | 85.36 | 84.59 | 84.9125 | +# +-------+---------+-------------+-------+-------+---------+-------+-------+---------+ +# | 110 | None | N/A | 94.69 | 94.33 | 94.475 | 93.56 | 92.69 | 93.1175 | +# +-------+---------+-------------+-------+-------+---------+-------+-------+---------+