Commit 6d455597 authored by Guy Jacob
Add knowledge distillation examples + some results

parent 75a1e7e8
@@ -10,3 +10,4 @@ env/
.idea/
logs/
.DS_Store
.vscode/
For examples of knowledge distillation usage, see:
* quantization/preact_resnet_cifar_base_fp32.yaml
* quantization/preact_resnet_cifar_dorefa.yaml
\ No newline at end of file
# time python3 compress_classifier.py -a preact_resnet20_cifar --lr 0.1 -p 50 -b 128 ../../../data.cifar10/ -j 1
# --epochs 200 --compress=../quantization/preact_resnet20_cifar_base_fp32.yaml --out-dir="logs/" --wd=0.0002 --vs=0
#2018-07-18 12:25:56,477 - --- validate (epoch=199)-----------
#2018-07-18 12:25:56,477 - 10000 samples (128 per mini-batch)
#2018-07-18 12:25:57,810 - Epoch: [199][ 50/ 78] Loss 0.312961 Top1 92.140625 Top5 99.765625
#2018-07-18 12:25:58,402 - ==> Top1: 92.270 Top5: 99.800 Loss: 0.307
#
#2018-07-18 12:25:58,404 - ==> Best validation Top1: 92.560 Epoch: 127
#2018-07-18 12:25:58,404 - Saving checkpoint to: logs/checkpoint.pth.tar
#2018-07-18 12:25:58,418 - --- test ---------------------
#2018-07-18 12:25:58,418 - 10000 samples (128 per mini-batch)
#2018-07-18 12:25:59,664 - Test: [ 50/ 78] Loss 0.312961 Top1 92.140625 Top5 99.765625
#2018-07-18 12:26:00,248 - ==> Top1: 92.270 Top5: 99.800 Loss: 0.307
lr_schedulers:
  training_lr:
    class: MultiStepMultiGammaLR
    milestones: [80, 120, 160]
    gammas: [0.1, 0.1, 0.2]

policies:
  - lr_scheduler:
      instance_name: training_lr
      starting_epoch: 0
      ending_epoch: 200
      frequency: 1
quantizers:
  dorefa_quantizer:
    class: DorefaQuantizer
    bits_activations: 8
    bits_weights: 3
    bits_overrides:
      # Don't quantize first and last layer
      conv1:
        wts: null
        acts: null
      layer1.0.pre_relu:
        wts: null
        acts: null
      final_relu:
        wts: null
        acts: null
      fc:
        wts: null
        acts: null

lr_schedulers:
  training_lr:
    class: MultiStepMultiGammaLR
    milestones: [80, 120, 160]
    gammas: [0.1, 0.1, 0.2]

policies:
  - quantizer:
      instance_name: dorefa_quantizer
      starting_epoch: 0
      ending_epoch: 200
      frequency: 1
  - lr_scheduler:
      instance_name: training_lr
      starting_epoch: 0
      ending_epoch: 161
      frequency: 1
# Scheduler for training baseline FP32 models of pre-activation ResNet on CIFAR-10
# Applicable to ResNet 20 / 32 / 44 / 56 / 110
#
# Command line for training (running from the compress_classifier.py directory):
# python compress_classifier.py -a preact_resnet20_cifar --lr 0.1 -p 50 -b 128 <path_to_cifar10_dataset> -j 1 --epochs 200 --compress=../quantization/preact_resnet_cifar_base_fp32.yaml --wd=0.0002 --vs=0 --gpus 0
#
# Notes:
# * Replace '-a preact_resnet20_cifar' with the required depth
# * '--wd=0.0002': Weight decay of 0.0002 is used
# * '--vs=0': We train on the entire training dataset, and validate using the test set
#
# Knowledge Distillation:
# -----------------------
# To train these models with knowledge distillation, add the following arguments to the command line:
# --kd-teacher preact_resnet44_cifar --kd-resume <path_to_teacher_model_checkpoint> --kd-temp 5.0 --kd-dw 0.7 --kd-sw 0.3
#
# Notes:
# * Replace 'preact_resnet44_cifar' with the required teacher model
# * Use the command line above to train teacher models; the resulting checkpoints can then be passed to '--kd-resume'
# * In this example we use a distillation temperature of 5.0, and give a weight of 0.7 to the distillation loss
#   (that is, the loss of the student's predictions vs. the teacher's soft targets) and a weight of 0.3 to the standard student loss.
# * Note we don't change any of the other training hyper-parameters
# * More details on knowledge distillation at:
# https://nervanasystems.github.io/distiller/schedule/index.html#knowledge-distillation
#
# Experimental results obtained with the hyper-parameters shown above are listed after the YAML schedule
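#
# For illustration, combining the training command with the distillation arguments above gives a full student-training
# command line such as (paths are placeholders):
# python compress_classifier.py -a preact_resnet20_cifar --lr 0.1 -p 50 -b 128 <path_to_cifar10_dataset> -j 1 --epochs 200 --compress=../quantization/preact_resnet_cifar_base_fp32.yaml --wd=0.0002 --vs=0 --gpus 0 --kd-teacher preact_resnet44_cifar --kd-resume <path_to_teacher_model_checkpoint> --kd-temp 5.0 --kd-dw 0.7 --kd-sw 0.3
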
lr_schedulers:
  training_lr:
    class: MultiStepMultiGammaLR
    milestones: [80, 120, 160]
    gammas: [0.1, 0.1, 0.2]
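
# Note: assuming MultiStepMultiGammaLR applies the gammas cumulatively at the corresponding milestones, with the
# '--lr 0.1' from the command line this schedule gives a learning rate of 0.1 for epochs 0-79, 0.01 for epochs
# 80-119, 0.001 for epochs 120-159, and 0.0002 from epoch 160 onwards.
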
policies:
  - lr_scheduler:
      instance_name: training_lr
      starting_epoch: 0
      ending_epoch: 200
      frequency: 1
# The results listed here are Top-1 accuracies (%) on the test set, based on 4 runs in each configuration:
# +-------+--------------+-------------------------+
# |       |              |          FP32           |
# +-------+--------------+-------------------------+
# | Depth | Distillation | Best  | Worst | Average |
# |       | Teacher      |       |       |         |
# +-------+--------------+-------+-------+---------+
# | 20    | None         | 92.4  | 91.91 | 92.2225 |
# +-------+--------------+-------+-------+---------+
# | 20    | 32           | 92.85 | 92.68 | 92.7375 |
# +-------+--------------+-------+-------+---------+
# | 20    | 44           | 93.09 | 92.64 | 92.795  |
# +-------+--------------+-------+-------+---------+
# | 20    | 56           | 92.77 | 92.52 | 92.6475 |
# +-------+--------------+-------+-------+---------+
# | 20    | 110          | 92.87 | 92.66 | 92.7725 |
# +-------+--------------+-------+-------+---------+
# |       |              |       |       |         |
# +-------+--------------+-------+-------+---------+
# | 32    | None         | 93.31 | 92.93 | 93.13   |
# +-------+--------------+-------+-------+---------+
# | 32    | 44           | 93.54 | 93.35 | 93.48   |
# +-------+--------------+-------+-------+---------+
# | 32    | 56           | 93.58 | 93.47 | 93.5125 |
# +-------+--------------+-------+-------+---------+
# | 32    | 110          | 93.6  | 93.29 | 93.4575 |
# +-------+--------------+-------+-------+---------+
# |       |              |       |       |         |
# +-------+--------------+-------+-------+---------+
# | 44    | None         | 94.07 | 93.5  | 93.7425 |
# +-------+--------------+-------+-------+---------+
# | 44    | 56           | 94.08 | 93.58 | 93.875  |
# +-------+--------------+-------+-------+---------+
# | 44    | 110          | 94.13 | 93.75 | 93.95   |
# +-------+--------------+-------+-------+---------+
# |       |              |       |       |         |
# +-------+--------------+-------+-------+---------+
# | 56    | None         | 94.2  | 93.52 | 93.8    |
# +-------+--------------+-------+-------+---------+
# | 56    | 110          | 94.47 | 94.0  | 94.16   |
# +-------+--------------+-------+-------+---------+
# |       |              |       |       |         |
# +-------+--------------+-------+-------+---------+
# | 110   | None         | 94.66 | 94.42 | 94.54   |
# +-------+--------------+-------+-------+---------+
# Scheduler for training pre-activation ResNet on CIFAR-10, quantized using the DoReFa scheme
# See:
# https://nervanasystems.github.io/distiller/algo_quantization/index.html#dorefa
# https://arxiv.org/abs/1606.06160
#
# Applicable to ResNet 20 / 32 / 44 / 56 / 110
#
# Command line for training (running from the compress_classifier.py directory):
# python compress_classifier.py -a preact_resnet20_cifar --lr 0.1 -p 50 -b 128 <path_to_cifar10_dataset> -j 1 --epochs 200 --compress=../quantization/preact_resnet_cifar_dorefa.yaml --wd=0.0002 --vs=0 --gpus 0
#
# Notes:
# * Replace '-a preact_resnet20_cifar' with the required depth
# * '--wd=0.0002': Weight decay of 0.0002 is used
# * '--vs=0': We train on the entire training dataset, and validate using the test set
#
# Knowledge Distillation:
# -----------------------
# To train these models with knowledge distillation, add the following arguments to the command line:
# --kd-teacher preact_resnet44_cifar --kd-resume <path_to_teacher_model_checkpoint> --kd-temp 5.0 --kd-dw 0.7 --kd-sw 0.3
#
# Notes:
# * Replace 'preact_resnet44_cifar' with the required teacher model
# * To train baseline FP32 models that can be used as teachers, see preact_resnet_cifar_base_fp32.yaml
# * In this example we use a distillation temperature of 5.0, and give a weight of 0.7 to the distillation loss
#   (that is, the loss of the student's predictions vs. the teacher's soft targets) and a weight of 0.3 to the standard student loss.
# * Note we don't change any of the other training hyper-parameters
# * More details on knowledge distillation at:
# https://nervanasystems.github.io/distiller/schedule/index.html#knowledge-distillation
#
# Experimental results obtained with the hyper-parameters shown above are listed after the YAML schedule
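#
# For illustration, a typical two-step flow with this schedule would be: first train an FP32 teacher using
# preact_resnet_cifar_base_fp32.yaml, then train the quantized student while distilling from that teacher's
# checkpoint (paths are placeholders):
#   python compress_classifier.py -a preact_resnet44_cifar --lr 0.1 -p 50 -b 128 <path_to_cifar10_dataset> -j 1 --epochs 200 --compress=../quantization/preact_resnet_cifar_base_fp32.yaml --wd=0.0002 --vs=0 --gpus 0
#   python compress_classifier.py -a preact_resnet20_cifar --lr 0.1 -p 50 -b 128 <path_to_cifar10_dataset> -j 1 --epochs 200 --compress=../quantization/preact_resnet_cifar_dorefa.yaml --wd=0.0002 --vs=0 --gpus 0 --kd-teacher preact_resnet44_cifar --kd-resume <path_to_teacher_model_checkpoint> --kd-temp 5.0 --kd-dw 0.7 --kd-sw 0.3
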
quantizers:
  dorefa_quantizer:
    class: DorefaQuantizer
    bits_activations: 8
    bits_weights: 3
    bits_overrides:
      # Don't quantize first and last layer
      conv1:
        wts: null
        acts: null
      layer1.0.pre_relu:
        wts: null
        acts: null
      final_relu:
        wts: null
        acts: null
      fc:
        wts: null
        acts: null
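
# Note: with the settings above, all remaining layers use 3-bit weights and 8-bit activations, while the layers
# listed under 'bits_overrides' are kept at full precision (the 'null' entries disable quantization for them).
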
lr_schedulers:
  training_lr:
    class: MultiStepMultiGammaLR
    milestones: [80, 120, 160]
    gammas: [0.1, 0.1, 0.2]

policies:
  - quantizer:
      instance_name: dorefa_quantizer
      starting_epoch: 0
      ending_epoch: 200
      frequency: 1
  - lr_scheduler:
      instance_name: training_lr
      starting_epoch: 0
      ending_epoch: 161
      frequency: 1
# The results listed here are Top-1 accuracies (%) on the test set, based on 4 runs in each configuration:
# +-------+--------------+-------------------------+-------------------------+
# |       |              |          FP32           |      DoReFa w3-a8       |
# +-------+--------------+-------------------------+-------------------------+
# | Depth | Distillation | Best  | Worst | Average | Best  | Worst | Average |
# |       | Teacher      |       |       |         |       |       |         |
# +-------+--------------+-------+-------+---------+-------+-------+---------+
# | 20    | None         | 92.4  | 91.91 | 92.2225 | 91.87 | 91.34 | 91.605  |
# +-------+--------------+-------+-------+---------+-------+-------+---------+
# | 20    | 32           | 92.85 | 92.68 | 92.7375 | 92.16 | 91.96 | 92.0725 |
# +-------+--------------+-------+-------+---------+-------+-------+---------+
# | 20    | 44           | 93.09 | 92.64 | 92.795  | 92.54 | 91.9  | 92.2225 |
# +-------+--------------+-------+-------+---------+-------+-------+---------+
# | 20    | 56           | 92.77 | 92.52 | 92.6475 | 92.53 | 91.92 | 92.15   |
# +-------+--------------+-------+-------+---------+-------+-------+---------+
# | 20    | 110          | 92.87 | 92.66 | 92.7725 | 92.12 | 92.01 | 92.0825 |
# +-------+--------------+-------+-------+---------+-------+-------+---------+
# |       |              |       |       |         |       |       |         |
# +-------+--------------+-------+-------+---------+-------+-------+---------+
# | 32    | None         | 93.31 | 92.93 | 93.13   | 92.66 | 92.33 | 92.485  |
# +-------+--------------+-------+-------+---------+-------+-------+---------+
# | 32    | 44           | 93.54 | 93.35 | 93.48   | 93.41 | 93.2  | 93.2875 |
# +-------+--------------+-------+-------+---------+-------+-------+---------+
# | 32    | 56           | 93.58 | 93.47 | 93.5125 | 93.18 | 92.76 | 92.93   |
# +-------+--------------+-------+-------+---------+-------+-------+---------+
# | 32    | 110          | 93.6  | 93.29 | 93.4575 | 93.36 | 92.99 | 93.175  |
# +-------+--------------+-------+-------+---------+-------+-------+---------+
# |       |              |       |       |         |       |       |         |
# +-------+--------------+-------+-------+---------+-------+-------+---------+
# | 44    | None         | 94.07 | 93.5  | 93.7425 | 93.08 | 92.66 | 92.8125 |
# +-------+--------------+-------+-------+---------+-------+-------+---------+
# | 44    | 56           | 94.08 | 93.58 | 93.875  | 93.46 | 93.28 | 93.3875 |
# +-------+--------------+-------+-------+---------+-------+-------+---------+
# | 44    | 110          | 94.13 | 93.75 | 93.95   | 93.45 | 93.24 | 93.3825 |
# +-------+--------------+-------+-------+---------+-------+-------+---------+
# |       |              |       |       |         |       |       |         |
# +-------+--------------+-------+-------+---------+-------+-------+---------+
# | 56    | None         | 94.2  | 93.52 | 93.8    | 93.44 | 92.91 | 93.0975 |
# +-------+--------------+-------+-------+---------+-------+-------+---------+
# | 56    | 110          | 94.47 | 94.0  | 94.16   | 93.83 | 93.56 | 93.7225 |
# +-------+--------------+-------+-------+---------+-------+-------+---------+
# |       |              |       |       |         |       |       |         |
# +-------+--------------+-------+-------+---------+-------+-------+---------+
# | 110   | None         | 94.66 | 94.42 | 94.54   | 93.53 | 93.24 | 93.395  |
# +-------+--------------+-------+-------+---------+-------+-------+---------+