diff --git a/.gitignore b/.gitignore
index 69edb25f12fb33f2910fde05a6a4c92bb567ac14..dbdc98ba97a1ec21fd80995b80df28fcfe114822 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,3 +10,4 @@ env/
 .idea/
 logs/
 .DS_Store
+.vscode/
diff --git a/examples/knowledge_distillation.txt b/examples/knowledge_distillation.txt
new file mode 100644
index 0000000000000000000000000000000000000000..9d9e9f52565f8134f1cb5373eb5f4803355af555
--- /dev/null
+++ b/examples/knowledge_distillation.txt
@@ -0,0 +1,3 @@
+For examples of knowledge distillation usage see:
+* quantization/preact_resnet_cifar_base_fp32.yaml
+* quantization/preact_resnet_cifar_dorefa.yaml
\ No newline at end of file
diff --git a/examples/quantization/preact_resnet20_cifar_base_fp32.yaml b/examples/quantization/preact_resnet20_cifar_base_fp32.yaml
deleted file mode 100644
index 63aac98002be1b280e491e65faa12ec7454adcd5..0000000000000000000000000000000000000000
--- a/examples/quantization/preact_resnet20_cifar_base_fp32.yaml
+++ /dev/null
@@ -1,30 +0,0 @@
-
-# time python3 compress_classifier.py -a preact_resnet20_cifar --lr 0.1 -p 50 -b 128 ../../../data.cifar10/ -j 1
-# --epochs 200 --compress=../quantization/preact_resnet20_cifar_base_fp32.yaml --out-dir="logs/" --wd=0.0002 --vs=0
-
-
-#2018-07-18 12:25:56,477 - --- validate (epoch=199)-----------
-#2018-07-18 12:25:56,477 - 10000 samples (128 per mini-batch)
-#2018-07-18 12:25:57,810 - Epoch: [199][ 50/ 78] Loss 0.312961 Top1 92.140625 Top5 99.765625
-#2018-07-18 12:25:58,402 - ==> Top1: 92.270 Top5: 99.800 Loss: 0.307
-#
-#2018-07-18 12:25:58,404 - ==> Best validation Top1: 92.560 Epoch: 127
-#2018-07-18 12:25:58,404 - Saving checkpoint to: logs/checkpoint.pth.tar
-#2018-07-18 12:25:58,418 - --- test ---------------------
-#2018-07-18 12:25:58,418 - 10000 samples (128 per mini-batch)
-#2018-07-18 12:25:59,664 - Test: [ 50/ 78] Loss 0.312961 Top1 92.140625 Top5 99.765625
-#2018-07-18 12:26:00,248 - ==> Top1: 92.270 Top5: 99.800 Loss: 0.307
-
-
-lr_schedulers:
-  training_lr:
-    class: MultiStepMultiGammaLR
-    milestones: [80, 120, 160]
-    gammas: [0.1, 0.1, 0.2]
-
-policies:
-    - lr_scheduler:
-        instance_name: training_lr
-      starting_epoch: 0
-      ending_epoch: 200
-      frequency: 1
diff --git a/examples/quantization/preact_resnet20_cifar_dorefa.yaml b/examples/quantization/preact_resnet20_cifar_dorefa.yaml
deleted file mode 100644
index a9870530bc6be26214b85610b417ea89ebd3a324..0000000000000000000000000000000000000000
--- a/examples/quantization/preact_resnet20_cifar_dorefa.yaml
+++ /dev/null
@@ -1,38 +0,0 @@
-quantizers:
-  dorefa_quantizer:
-    class: DorefaQuantizer
-    bits_activations: 8
-    bits_weights: 3
-    bits_overrides:
-      # Don't quantize first and last layer
-      conv1:
-        wts: null
-        acts: null
-      layer1.0.pre_relu:
-        wts: null
-        acts: null
-      final_relu:
-        wts: null
-        acts: null
-      fc:
-        wts: null
-        acts: null
-
-lr_schedulers:
-  training_lr:
-    class: MultiStepMultiGammaLR
-    milestones: [80, 120, 160]
-    gammas: [0.1, 0.1, 0.2]
-
-policies:
-    - quantizer:
-        instance_name: dorefa_quantizer
-      starting_epoch: 0
-      ending_epoch: 200
-      frequency: 1
-
-    - lr_scheduler:
-        instance_name: training_lr
-      starting_epoch: 0
-      ending_epoch: 161
-      frequency: 1
diff --git a/examples/quantization/preact_resnet_cifar_base_fp32.yaml b/examples/quantization/preact_resnet_cifar_base_fp32.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..3b5b41dcce8e27ecefa2e034c35470f9810216df
--- /dev/null
+++ b/examples/quantization/preact_resnet_cifar_base_fp32.yaml
@@ -0,0 +1,86 @@
+# Scheduler for training baseline FP32 models of pre-activation ResNet on CIFAR-10
+# Applicable to ResNet 20 / 32 / 44 / 56 / 110
+#
+# Command line for training (running from the compress_classifier.py directory):
+# python compress_classifier.py -a preact_resnet20_cifar --lr 0.1 -p 50 -b 128 <path_to_cifar10_dataset> -j 1 --epochs 200 --compress=../quantization/preact_resnet_cifar_base_fp32.yaml --wd=0.0002 --vs=0 --gpus 0
+#
+# Notes:
+# * Replace '-a preact_resnet20_cifar' with the required depth
+# * '--wd=0.0002': Weight decay of 0.0002 is used
+# * '--vs=0': We train on the entire training dataset, and validate using the test set
+#
+# Knowledge Distillation:
+# -----------------------
+# To train these models with knowledge distillation, add the following arguments to the command line:
+# --kd-teacher preact_resnet44_cifar --kd-resume <path_to_teacher_model_checkpoint> --kd-temp 5.0 --kd-dw 0.7 --kd-sw 0.3
+#
+# Notes:
+# * Replace 'preact_resnet44_cifar' with the required teacher model
+# * Use the command line above to train the teacher models, whose checkpoints can then be passed to '--kd-resume'
+# * In this example we're using a distillation temperature of 5.0, and we give a weight of 0.7 to the distillation loss
+#   (that is, the loss of the student's predictions vs. the teacher's soft targets), and a weight of 0.3 to the student loss.
+# * Note that we don't change any of the other training hyper-parameters
+# * More details on knowledge distillation at:
+#   https://nervanasystems.github.io/distiller/schedule/index.html#knowledge-distillation
+#
+# Experimental results with the hyper-parameters shown above are listed after the YAML schedule
+
+lr_schedulers:
+  training_lr:
+    class: MultiStepMultiGammaLR
+    milestones: [80, 120, 160]
+    gammas: [0.1, 0.1, 0.2]
+
+policies:
+    - lr_scheduler:
+        instance_name: training_lr
+      starting_epoch: 0
+      ending_epoch: 200
+      frequency: 1
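+
+# Note on the schedule above: MultiStepMultiGammaLR takes one gamma per milestone rather than the single
+# gamma of a standard multi-step scheduler. Assuming the gammas are applied cumulatively at each milestone
+# (and the '--lr 0.1' starting point from the command line above), a rough Python sketch of the resulting
+# learning rate per epoch would be:
+#
+#   initial_lr, milestones, gammas = 0.1, [80, 120, 160], [0.1, 0.1, 0.2]
+#
+#   def lr_at(epoch):
+#       lr = initial_lr
+#       for milestone, gamma in zip(milestones, gammas):
+#           if epoch >= milestone:
+#               lr *= gamma  # 0.1 -> 0.01 at epoch 80 -> 0.001 at 120 -> 0.0002 at 160
+#       return lr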
+
+# The results listed here are based on 4 runs in each configuration:
+
+# +-------+--------------+-------------------------+
+# |       |              |           FP32          |
+# +-------+--------------+-------------------------+
+# | Depth | Distillation | Best  | Worst | Average |
+# |       | Teacher      |       |       |         |
+# +-------+--------------+-------+-------+---------+
+# | 20    | None         | 92.4  | 91.91 | 92.2225 |
+# +-------+--------------+-------+-------+---------+
+# | 20    | 32           | 92.85 | 92.68 | 92.7375 |
+# +-------+--------------+-------+-------+---------+
+# | 20    | 44           | 93.09 | 92.64 | 92.795  |
+# +-------+--------------+-------+-------+---------+
+# | 20    | 56           | 92.77 | 92.52 | 92.6475 |
+# +-------+--------------+-------+-------+---------+
+# | 20    | 110          | 92.87 | 92.66 | 92.7725 |
+# +-------+--------------+-------+-------+---------+
+# |       |              |       |       |         |
+# +-------+--------------+-------+-------+---------+
+# | 32    | None         | 93.31 | 92.93 | 93.13   |
+# +-------+--------------+-------+-------+---------+
+# | 32    | 44           | 93.54 | 93.35 | 93.48   |
+# +-------+--------------+-------+-------+---------+
+# | 32    | 56           | 93.58 | 93.47 | 93.5125 |
+# +-------+--------------+-------+-------+---------+
+# | 32    | 110          | 93.6  | 93.29 | 93.4575 |
+# +-------+--------------+-------+-------+---------+
+# |       |              |       |       |         |
+# +-------+--------------+-------+-------+---------+
+# | 44    | None         | 94.07 | 93.5  | 93.7425 |
+# +-------+--------------+-------+-------+---------+
+# | 44    | 56           | 94.08 | 93.58 | 93.875  |
+# +-------+--------------+-------+-------+---------+
+# | 44    | 110          | 94.13 | 93.75 | 93.95   |
+# +-------+--------------+-------+-------+---------+
+# |       |              |       |       |         |
+# +-------+--------------+-------+-------+---------+
+# | 56    | None         | 94.2  | 93.52 | 93.8    |
+# +-------+--------------+-------+-------+---------+
+# | 56    | 110          | 94.47 | 94.0  | 94.16   |
+# +-------+--------------+-------+-------+---------+
+# |       |              |       |       |         |
+# +-------+--------------+-------+-------+---------+
+# | 110   | None         | 94.66 | 94.42 | 94.54   |
+# +-------+--------------+-------+-------+---------+
diff --git a/examples/quantization/preact_resnet_cifar_dorefa.yaml b/examples/quantization/preact_resnet_cifar_dorefa.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7ae1faadaecade9a30bd6245a2f82db47ea9442d
--- /dev/null
+++ b/examples/quantization/preact_resnet_cifar_dorefa.yaml
@@ -0,0 +1,116 @@
+# Scheduler for training pre-activation ResNet on CIFAR-10, quantized using the DoReFa scheme
+# See:
+#  https://nervanasystems.github.io/distiller/algo_quantization/index.html#dorefa
+#  https://arxiv.org/abs/1606.06160
+#
+# Applicable to ResNet 20 / 32 / 44 / 56 / 110
+#
+# Command line for training (running from the compress_classifier.py directory):
+# python compress_classifier.py -a preact_resnet20_cifar --lr 0.1 -p 50 -b 128 <path_to_cifar10_dataset> -j 1 --epochs 200 --compress=../quantization/preact_resnet_cifar_dorefa.yaml --wd=0.0002 --vs=0 --gpus 0
+#
+# Notes:
+# * Replace '-a preact_resnet20_cifar' with the required depth
+# * '--wd=0.0002': Weight decay of 0.0002 is used
+# * '--vs=0': We train on the entire training dataset, and validate using the test set
+#
+# Knowledge Distillation:
+# -----------------------
+# To train these models with knowledge distillation, add the following arguments to the command line:
+# --kd-teacher preact_resnet44_cifar --kd-resume <path_to_teacher_model_checkpoint> --kd-temp 5.0 --kd-dw 0.7 --kd-sw 0.3
+#
+# Notes:
+# * Replace 'preact_resnet44_cifar' with the required teacher model
+# * To train baseline FP32 models that can be used as teachers, see preact_resnet_cifar_base_fp32.yaml
+# * In this example we're using a distillation temperature of 5.0, and we give a weight of 0.7 to the distillation loss
+#   (that is, the loss of the student's predictions vs. the teacher's soft targets), and a weight of 0.3 to the student loss.
+# * Note that we don't change any of the other training hyper-parameters
+# * More details on knowledge distillation at:
+#   https://nervanasystems.github.io/distiller/schedule/index.html#knowledge-distillation
+#
+# Experimental results with the hyper-parameters shown above are listed after the YAML schedule
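+#
+# The '--kd-*' arguments above configure a weighted knowledge-distillation loss. As a rough illustration
+# only (not necessarily Distiller's exact implementation), the combined loss for temperature 5.0,
+# distillation weight 0.7 and student weight 0.3 could be computed along these lines:
+#
+#   import torch.nn.functional as F
+#
+#   def kd_loss(student_logits, teacher_logits, labels, T=5.0, dw=0.7, sw=0.3):
+#       soft_targets = F.softmax(teacher_logits / T, dim=1)
+#       distill = F.kl_div(F.log_softmax(student_logits / T, dim=1), soft_targets,
+#                          reduction='batchmean')
+#       student = F.cross_entropy(student_logits, labels)
+#       # some formulations also scale the distillation term by T**2
+#       return dw * distill + sw * student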
+
+quantizers:
+  dorefa_quantizer:
+    class: DorefaQuantizer
+    bits_activations: 8
+    bits_weights: 3
+    bits_overrides:
+      # Don't quantize first and last layer
+      conv1:
+        wts: null
+        acts: null
+      layer1.0.pre_relu:
+        wts: null
+        acts: null
+      final_relu:
+        wts: null
+        acts: null
+      fc:
+        wts: null
+        acts: null
+
+lr_schedulers:
+  training_lr:
+    class: MultiStepMultiGammaLR
+    milestones: [80, 120, 160]
+    gammas: [0.1, 0.1, 0.2]
+
+policies:
+    - quantizer:
+        instance_name: dorefa_quantizer
+      starting_epoch: 0
+      ending_epoch: 200
+      frequency: 1
+
+    - lr_scheduler:
+        instance_name: training_lr
+      starting_epoch: 0
+      ending_epoch: 161
+      frequency: 1
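+
+# For reference, the DorefaQuantizer above maps weights to 3 bits and activations to 8 bits (except for
+# the overridden layers). A rough NumPy sketch of the DoReFa-Net weight quantization from the paper
+# (https://arxiv.org/abs/1606.06160); illustrative only, not Distiller's code:
+#
+#   import numpy as np
+#
+#   def quantize_k(x, k):                       # x in [0, 1] -> k-bit fixed point in [0, 1]
+#       n = 2 ** k - 1
+#       return np.round(x * n) / n
+#
+#   def dorefa_quantize_weights(w, k=3):
+#       w = np.tanh(w)
+#       w = w / (2 * np.max(np.abs(w))) + 0.5   # normalize to [0, 1]
+#       return 2 * quantize_k(w, k) - 1         # map back to [-1, 1]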
+
+# The results listed here are based on 4 runs in each configuration:
+
+# +-------+--------------+-------------------------+-------------------------+
+# |       |              |           FP32          |       DoReFa w3-a8      |
+# +-------+--------------+-------------------------+-------------------------+
+# | Depth | Distillation | Best  | Worst | Average | Best  | Worst | Average |
+# |       | Teacher      |       |       |         |       |       |         |
+# +-------+--------------+-------+-------+---------+-------+-------+---------+
+# | 20    | None         | 92.4  | 91.91 | 92.2225 | 91.87 | 91.34 | 91.605  |
+# +-------+--------------+-------+-------+---------+-------+-------+---------+
+# | 20    | 32           | 92.85 | 92.68 | 92.7375 | 92.16 | 91.96 | 92.0725 |
+# +-------+--------------+-------+-------+---------+-------+-------+---------+
+# | 20    | 44           | 93.09 | 92.64 | 92.795  | 92.54 | 91.9  | 92.2225 |
+# +-------+--------------+-------+-------+---------+-------+-------+---------+
+# | 20    | 56           | 92.77 | 92.52 | 92.6475 | 92.53 | 91.92 | 92.15   |
+# +-------+--------------+-------+-------+---------+-------+-------+---------+
+# | 20    | 110          | 92.87 | 92.66 | 92.7725 | 92.12 | 92.01 | 92.0825 |
+# +-------+--------------+-------+-------+---------+-------+-------+---------+
+# |       |              |       |       |         |       |       |         |
+# +-------+--------------+-------+-------+---------+-------+-------+---------+
+# | 32    | None         | 93.31 | 92.93 | 93.13   | 92.66 | 92.33 | 92.485  |
+# +-------+--------------+-------+-------+---------+-------+-------+---------+
+# | 32    | 44           | 93.54 | 93.35 | 93.48   | 93.41 | 93.2  | 93.2875 |
+# +-------+--------------+-------+-------+---------+-------+-------+---------+
+# | 32    | 56           | 93.58 | 93.47 | 93.5125 | 93.18 | 92.76 | 92.93   |
+# +-------+--------------+-------+-------+---------+-------+-------+---------+
+# | 32    | 110          | 93.6  | 93.29 | 93.4575 | 93.36 | 92.99 | 93.175  |
+# +-------+--------------+-------+-------+---------+-------+-------+---------+
+# |       |              |       |       |         |       |       |         |
+# +-------+--------------+-------+-------+---------+-------+-------+---------+
+# | 44    | None         | 94.07 | 93.5  | 93.7425 | 93.08 | 92.66 | 92.8125 |
+# +-------+--------------+-------+-------+---------+-------+-------+---------+
+# | 44    | 56           | 94.08 | 93.58 | 93.875  | 93.46 | 93.28 | 93.3875 |
+# +-------+--------------+-------+-------+---------+-------+-------+---------+
+# | 44    | 110          | 94.13 | 93.75 | 93.95   | 93.45 | 93.24 | 93.3825 |
+# +-------+--------------+-------+-------+---------+-------+-------+---------+
+# |       |              |       |       |         |       |       |         |
+# +-------+--------------+-------+-------+---------+-------+-------+---------+
+# | 56    | None         | 94.2  | 93.52 | 93.8    | 93.44 | 92.91 | 93.0975 |
+# +-------+--------------+-------+-------+---------+-------+-------+---------+
+# | 56    | 110          | 94.47 | 94.0  | 94.16   | 93.83 | 93.56 | 93.7225 |
+# +-------+--------------+-------+-------+---------+-------+-------+---------+
+# |       |              |       |       |         |       |       |         |
+# +-------+--------------+-------+-------+---------+-------+-------+---------+
+# | 110   | None         | 94.66 | 94.42 | 94.54   | 93.53 | 93.24 | 93.395  |
+# +-------+--------------+-------+-------+---------+-------+-------+---------+