Unverified commit c9794e4a authored by Guy Jacob, committed by GitHub

Add knowledge distillation flow (#41)

* Implemented as a Policy
* Integrated in image classification sample
* Updated docs and README
parent df74040e
Showing 608 additions and 9 deletions
......@@ -72,8 +72,9 @@ Highlighted features:
- The compression schedule is expressed in a YAML file so that a single file captures the details of experiments. This [dependency injection](https://en.wikipedia.org/wiki/Dependency_injection) design decouples the Distiller scheduler and library from future extensions of algorithms.
* Quantization:
- Automatic mechanism to transform existing models to quantized versions, with customizable bit-width configuration for different layers. No need to re-write the model for different quantization methods.
- Support for training with quantization in the loop
- Support for [training with quantization](https://nervanasystems.github.io/distiller/quantization/index.html#training-with-quantization) in the loop
- One-shot 8-bit quantization of trained full-precision models
* Training with [knowledge distillation](https://nervanasystems.github.io/distiller/knowledge_distillation/index.html), in conjunction with the other available pruning / regularization / quantization methods.
* Export statistics summaries using Pandas dataframes, which makes it easy to slice, query, display and graph the data.
* A set of [Jupyter notebooks](https://nervanasystems.github.io/distiller/jupyter/index.html) to plan experiments and analyze compression results. The graphs and visualizations you see on this page originate from the included Jupyter notebooks.
+ Take a look at [this notebook](https://github.com/NervanaSystems/distiller/blob/master/jupyter/alexnet_insights.ipynb), which compares visual aspects of dense and sparse Alexnet models.
......
......@@ -23,6 +23,7 @@ from .sensitivity import *
from .directives import *
from .policy import *
from .thinning import *
from .knowledge_distillation import KnowledgeDistillationPolicy, DistillationLossWeights
#del utils
del dict_config
......
#
# Copyright (c) 2018 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
import torch.nn.functional as F
from collections import namedtuple
from .policy import ScheduledTrainingPolicy, PolicyLoss, LossComponent

DistillationLossWeights = namedtuple('DistillationLossWeights',
                                     ['distill', 'student', 'teacher'])

def add_distillation_args(argparser, arch_choices=None, enable_pretrained=False):
    """
    Helper function to make it easier to add command line arguments for knowledge distillation to any script

    Arguments:
        argparser (argparse.ArgumentParser): Existing parser to which to add the arguments
        arch_choices: Optional list of choices to be enforced by the parser for model selection
        enable_pretrained (bool): Flag to enable/disable the argument for using a pre-trained teacher model
    """
    group = argparser.add_argument_group('Knowledge Distillation Training Arguments')
    group.add_argument('--kd-teacher', choices=arch_choices, metavar='ARCH',
                       help='Model architecture for teacher model')
    if enable_pretrained:
        group.add_argument('--kd-pretrained', action='store_true', help='Use pre-trained model for teacher')
    group.add_argument('--kd-resume', type=str, default='', metavar='PATH',
                       help='Path to checkpoint from which to load teacher weights')
    group.add_argument('--kd-temperature', '--kd-temp', dest='kd_temp', type=float, default=1.0, metavar='TEMP',
                       help='Knowledge distillation softmax temperature')
    group.add_argument('--kd-distill-wt', '--kd-dw', type=float, default=0.5, metavar='WEIGHT',
                       help='Weight for distillation loss (student vs. teacher soft targets)')
    group.add_argument('--kd-student-wt', '--kd-sw', type=float, default=0.5, metavar='WEIGHT',
                       help='Weight for student vs. labels loss')
    group.add_argument('--kd-teacher-wt', '--kd-tw', type=float, default=0.0, metavar='WEIGHT',
                       help='Weight for teacher vs. labels loss')
    group.add_argument('--kd-start-epoch', type=int, default=0, metavar='EPOCH_NUM',
                       help='Epoch from which to enable distillation')
class KnowledgeDistillationPolicy(ScheduledTrainingPolicy):
    """
    Policy which enables knowledge distillation from a teacher model to a student model, as presented in [1].

    Notes:
        1. In addition to the standard policy callbacks, this class also provides a 'forward' function that must
           be called instead of calling the student model directly, as is usually done. This is needed to facilitate
           running the teacher model in addition to the student, and for caching the logits for loss calculation.
        2. [TO BE ENABLED IN THE NEAR FUTURE] Option to train the teacher model in parallel with the student model,
           described as "scheme A" in [2]. This can be achieved by passing a teacher loss weight > 0.
        3. [1] proposes a weighted average between the different losses. We allow arbitrary weights to be assigned
           to each loss.

    Arguments:
        student_model (nn.Module): The student model, that is - the main model being trained. If it is initialized
            with random weights only, this matches "scheme B" in [2]. If it has been bootstrapped with trained
            FP32 weights, this matches "scheme C".
        teacher_model (nn.Module): The teacher model from which soft targets are generated for knowledge
            distillation. Usually this is a pre-trained model; in the future it will be possible to train this
            model as well (see Note 2 above).
        temperature (float): Temperature value used when calculating soft targets and logits (see [1]).
        loss_weights (DistillationLossWeights): Named tuple with 3 loss weights:
            (a) 'distill' for student predictions vs. teacher soft targets (default: 0.5)
            (b) 'student' for student predictions vs. true labels (default: 0.5)
            (c) 'teacher' for teacher predictions vs. true labels (default: 0). Currently this is just a
                placeholder, and cannot be set to a non-zero value.

    [1] Hinton et al., Distilling the Knowledge in a Neural Network (https://arxiv.org/abs/1503.02531)
    [2] Mishra and Marr, Apprentice: Using Knowledge Distillation Techniques To Improve Low-Precision Network
        Accuracy (https://arxiv.org/abs/1711.05852)
    """
    def __init__(self, student_model, teacher_model, temperature=1.0,
                 loss_weights=DistillationLossWeights(0.5, 0.5, 0)):
        super(KnowledgeDistillationPolicy, self).__init__()

        if loss_weights.teacher != 0:
            raise NotImplementedError('Using teacher vs. labels loss is not supported yet, '
                                      'for now teacher loss weight must be set to 0')

        self.active = False

        self.student = student_model
        self.teacher = teacher_model
        self.temperature = temperature
        self.loss_wts = loss_weights

        self.last_students_logits = None
        self.last_teacher_logits = None
    def forward(self, *inputs):
        """
        Performs forward propagation through both student and teacher models and caches the logits.
        This function MUST be used instead of calling the student model directly.

        Returns:
            The student model's output, to be consistent with what a script using this policy would expect
        """
        if not self.active:
            return self.student(*inputs)

        if self.loss_wts.teacher == 0:
            with torch.no_grad():
                self.last_teacher_logits = self.teacher(*inputs)
        else:
            self.last_teacher_logits = self.teacher(*inputs)

        out = self.student(*inputs)
        self.last_students_logits = out.new_tensor(out, requires_grad=True)

        return out

    # Since the "forward" function isn't a policy callback, we use the epoch callbacks to toggle the
    # activation of distillation according to the schedule defined by the user
    def on_epoch_begin(self, model, zeros_mask_dict, meta):
        self.active = True

    def on_epoch_end(self, model, zeros_mask_dict, meta):
        self.active = False
    def before_backward_pass(self, model, epoch, minibatch_id, minibatches_per_epoch, loss, zeros_mask_dict,
                             optimizer=None):
        # TODO: Consider adding 'labels' as an argument to this callback, so we can support teacher vs. labels loss
        # (Otherwise we can't do it with a sub-class of ScheduledTrainingPolicy)
        if not self.active:
            return None

        if self.last_teacher_logits is None or self.last_students_logits is None:
            raise RuntimeError("KnowledgeDistillationPolicy: Student and/or teacher logits were not cached. "
                               "Make sure to call KnowledgeDistillationPolicy.forward() in your script instead of "
                               "calling the model directly.")

        # Calculate the distillation loss using temperature-softened distributions
        soft_log_probs = F.log_softmax(self.last_students_logits / self.temperature, dim=1)
        soft_targets = F.softmax(self.last_teacher_logits / self.temperature, dim=1)

        # The averaging used in the PyTorch KLDivLoss implementation is wrong, so we work around it as suggested in
        # https://pytorch.org/docs/stable/nn.html#kldivloss
        # (Also see https://github.com/pytorch/pytorch/issues/6622, https://github.com/pytorch/pytorch/issues/2259)
        distillation_loss = F.kl_div(soft_log_probs, soft_targets.detach(), size_average=False) / soft_targets.shape[0]

        # The loss passed to the callback is the student's loss vs. the true labels, so we can use it directly - no
        # need to calculate it again
        overall_loss = self.loss_wts.student * loss + self.loss_wts.distill * distillation_loss
        return PolicyLoss(overall_loss,
                          [LossComponent('Distill Loss', distillation_loss)])
docs-src/docs/imgs/knowledge_distillation.png (232 KiB)
# Knowledge Distillation
(For details on how to train a model with knowledge distillation in Distiller, see [here](schedule.md#knowledge-distillation))
Knowledge distillation is a model compression method in which a small model is trained to mimic a pre-trained, larger model (or ensemble of models). This training setting is sometimes referred to as "teacher-student", where the large model is the teacher and the small model is the student (we'll be using these terms interchangeably).
The method was first proposed by [Bucila et al., 2006](#bucila-et-al-2006) and generalized by [Hinton et al., 2015](#hinton-et-al-2015). The implementation in Distiller is based on the latter publication. Here we'll provide a summary of the method. For more information, the reader may refer to the paper (a [video lecture](https://www.youtube.com/watch?v=EK61htlw8hY) with [slides](http://www.ttic.edu/dl/dark14.pdf) is also available).
In distillation, knowledge is transferred from the teacher model to the student by minimizing a loss function in which the target is the distribution of class probabilities predicted by the teacher model. That is - the output of a softmax function on the teacher model's logits. However, in many cases, this probability distribution has the correct class at a very high probability, with all other class probabilities very close to 0. As such, it doesn't provide much information beyond the ground truth labels already provided in the dataset. To tackle this issue, [Hinton et al., 2015](#hinton-et-al-2015) introduced the concept of "softmax temperature". The probability \(p_i\) of class \(i\) is calculated from the logits \(z\) as:
\[p_i = \frac{\exp\left(\frac{z_i}{T}\right)}{\sum_{j} \exp\left(\frac{z_j}{T}\right)}\]
where \(T\) is the temperature parameter. When \(T=1\) we get the standard softmax function. As \(T\) grows, the probability distribution generated by the softmax function becomes softer, providing more information as to which classes the teacher found more similar to the predicted class. Hinton calls this the "dark knowledge" embedded in the teacher model, and it is this dark knowledge that we are transferring to the student model in the distillation process. When computing the loss function vs. the teacher's soft targets, we use the same value of \(T\) to compute the softmax on the student's logits. We call this loss the "distillation loss".
[Hinton et al., 2015](#hinton-et-al-2015) found that it is also beneficial to train the distilled model to produce the correct labels (based on the ground truth) in addition to the teacher's soft-labels. Hence, we also calculate the "standard" loss between the student's predicted class probabilities and the ground-truth labels (also called "hard labels/targets"). We dub this loss the "student loss". When calculating the class probabilities for the student loss we use \(T = 1\).
The overall loss function, incorporating both distillation and student losses, is calculated as:
\[\mathcal{L}(x;W) = \alpha * \mathcal{H}(y, \sigma(z_s; T=1)) + \beta * \mathcal{H}(\sigma(z_t; T=\tau), \sigma(z_s; T=\tau))\]
where \(x\) is the input, \(W\) are the student model parameters, \(y\) is the ground truth label, \(\mathcal{H}\) is the cross-entropy loss function, \(\sigma\) is the softmax function parameterized by the temperature \(T\), and \(\alpha\) and \(\beta\) are coefficients. \(z_s\) and \(z_t\) are the logits of the student and teacher respectively.
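To make the loss computation concrete, here is a minimal PyTorch sketch of the equation above. This is only an illustration of the math, not Distiller's own implementation (which lives in `distiller/knowledge_distillation.py`); the function name and the default values for `temperature`, `alpha` and `beta` are arbitrary choices for the example.
```
import torch
import torch.nn.functional as F

def distillation_overall_loss(student_logits, teacher_logits, labels,
                              temperature=4.0, alpha=0.5, beta=0.5):
    # Student loss: standard cross-entropy vs. the hard labels, computed with T = 1
    student_loss = F.cross_entropy(student_logits, labels)

    # Distillation loss: cross-entropy between the teacher's softened distribution
    # and the student's softened log-probabilities, both computed with T = temperature
    soft_log_probs = F.log_softmax(student_logits / temperature, dim=1)
    soft_targets = F.softmax(teacher_logits.detach() / temperature, dim=1)
    distillation_loss = -(soft_targets * soft_log_probs).sum(dim=1).mean()

    return alpha * student_loss + beta * distillation_loss

# Toy usage: a batch of 8 samples with 10 classes
student_logits = torch.randn(8, 10, requires_grad=True)
teacher_logits = torch.randn(8, 10)
labels = torch.randint(0, 10, (8,))
distillation_overall_loss(student_logits, teacher_logits, labels).backward()
```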
![Knowledge Distillation](imgs/knowledge_distillation.png)
## New Hyper-Parameters
In general, \(\tau\), \(\alpha\) and \(\beta\) are hyper-parameters.
In their experiments, [Hinton et al., 2015](#hinton-et-al-2015) use temperature values ranging from 1 to 20. They note that empirically, when the student model is very small compared to the teacher model, lower temperatures work better. This makes sense if we consider that as we raise the temperature, the resulting soft-labels distribution becomes richer in information, and a very small model might not be able to capture all of this information. However, there's no clear way to predict up front what kind of capacity for information the student model will have.
With regards to \(\alpha\) and \(\beta\), [Hinton et al., 2015](#hinton-et-al-2015) use a weighted average between the distillation loss and the student loss. That is, \(\beta = 1 - \alpha\). They note that in general, they obtained the best results when setting \(\alpha\) to be much smaller than \(\beta\) (although in one of their experiments they use \(\alpha = \beta = 0.5\)). Other works which utilize knowledge distillation don't use a weighted average. Some set \(\alpha = 1\) while leaving \(\beta\) tunable, while others don't set any constraints.
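To build some intuition for the effect of the temperature, the following toy sketch (with made-up logit values) shows how raising \(T\) softens a nearly one-hot teacher distribution:
```
import torch
import torch.nn.functional as F

# Toy teacher logits for a single sample over 5 classes
logits = torch.tensor([[8.0, 2.0, 1.0, 0.5, 0.1]])

for T in (1.0, 5.0, 20.0):
    print(T, F.softmax(logits / T, dim=1))

# T=1  -> approximately [0.996, 0.002, 0.001, 0.001, 0.000]  (nearly one-hot, little extra information)
# T=5  -> approximately [0.506, 0.152, 0.125, 0.113, 0.104]  (class similarities become visible)
# T=20 -> approximately [0.263, 0.195, 0.185, 0.181, 0.177]  (approaching uniform)
```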
## <a name="combining"></a>Combining with Other Model Compression Techniques
In the "basic" scenario, the smaller (student) model is a pre-defined architecture which just has a smaller number of parameters compared to the teacher model. For example, we could train ResNet-18 by distilling knowledge from ResNet-34. But, a model with smaller capacity can also be obtained by other model compression techniques - sparsification and/or quantization. So, for example, we could train a 4-bit ResNet-18 model with some method using quantization-aware training, and use a distillation loss function as described above. In that case, the teacher model can even be a FP32 ResNet-18 model. Same goes for pruning and regularization.
[Tann et al., 2017](#tann-et-al-2017), [Mishra and Marr, 2018](#mishra-and-marr-2018) and [Polino et al., 2018](#polino-et-al-2018) are some works that combine knowledge distillation with **quantization**. [Theis et al., 2018](#theis-et-al-2018) and [Ashok et al., 2018](#ashok-et-al-2018) combine distillation with **pruning**.
## References
<div id="bucila-et-al-2006"></div>
**Cristian Bucila, Rich Caruana, and Alexandru Niculescu-Mizil**. Model Compression. [KDD, 2006](https://www.cs.cornell.edu/~caruana/compression.kdd06.pdf)
<div id="hinton-et-al-2015"></div>
**Geoffrey Hinton, Oriol Vinyals and Jeff Dean**. Distilling the Knowledge in a Neural Network. [arxiv:1503.02531](https://arxiv.org/abs/1503.02531)
<div id="tann-et-al-2017"></div>
**Hokchhay Tann, Soheil Hashemi, Iris Bahar and Sherief Reda**. Hardware-Software Codesign of Accurate, Multiplier-free Deep Neural Networks. [DAC, 2017](https://arxiv.org/abs/1705.04288)
<div id="mishra-and-marr-2018"></div>
**Asit Mishra and Debbie Marr**. Apprentice: Using Knowledge Distillation Techniques To Improve Low-Precision Network Accuracy. [ICLR, 2018](https://openreview.net/forum?id=B1ae1lZRb)
<div id="polino-et-al-2018"></div>
**Antonio Polino, Razvan Pascanu and Dan Alistarh**. Model compression via distillation and quantization. [ICLR, 2018](https://openreview.net/forum?id=S1XolQbRW)
<div id="ashok-et-al-2018"></div>
**Anubhav Ashok, Nicholas Rhinehart, Fares Beainy and Kris M. Kitani**. N2N learning: Network to Network Compression via Policy Gradient Reinforcement Learning. [ICLR, 2018](https://openreview.net/forum?id=B1hcZZ-AW)
<div id="theis-et-al-2018"></div>
**Lucas Theis, Iryna Korshunova, Alykhan Tejani and Ferenc Huszár**. Faster gaze prediction with dense networks and Fisher pruning. [arxiv:1801.05787](https://arxiv.org/abs/1801.05787)
......@@ -49,7 +49,7 @@ Another possible optimization point is **scale-factor scope**. The most common w
Naively quantizing a FP32 model to INT4 and lower usually incurs significant accuracy degradation. Many works have tried to mitigate this effect. They usually employ one or more of the following concepts in order to improve model accuracy:
- **Training / Re-Training**: For INT4 and lower, training is required in order to obtain reasonable accuracy. The training loop is modified to take quantization into account. See details in the [next section](#training-with-quantization).
[Zhou S et al., 2016](#zhou-et-al-2016) have shown that bootstrapping the quantized model with trained FP32 weights leads to higher accuracy, as opposed to training from scratch. Other methods *require* a trained FP32 model, either as a starting point ([Zhou A et al., 2017](#zhou-et-al-2017)), or as a teacher network in a knowledge distillation training setup ([Mishra and Marr, 2018](#mishra-and-marr-2018)).
[Zhou S et al., 2016](#zhou-et-al-2016) have shown that bootstrapping the quantized model with trained FP32 weights leads to higher accuracy, as opposed to training from scratch. Other methods *require* a trained FP32 model, either as a starting point ([Zhou A et al., 2017](#zhou-et-al-2017)), or as a teacher network in a knowledge distillation training setup (see [here](knowledge_distillation.md#combining)).
- **Replacing the activation function**: The most common activation function in vision models is ReLU, which is unbounded. That is - its dynamic range is not limited for positive inputs. This is very problematic for INT4 and below due to the very limited range and resolution. Therefore, most methods replace ReLU with another function which is bounded. In some cases a clipping function with hard coded values is used ([Zhou S et al., 2016](#zhou-et-al-2016), [Mishra et al., 2018](#mishra-et-al-2018)). Another method learns the clipping value per layer, with better results ([Choi et al., 2018](#choi-et-al-2018)). Once the clipping value is set, the scale factor used for quantization is also set, and no further calibration steps are required (as opposed to INT8 methods described above).
- **Modifying network structure**: [Mishra et al., 2018](#mishra-et-al-2018) try to compensate for the loss of information due to quantization by using wider layers (more channels). [Lin et al., 2017](#lin-et-al-2017) proposed a binary quantization method in which a single FP32 convolution is replaced with multiple binary convolutions, each scaled to represent a different "base", covering a larger dynamic range overall.
- **First and last layer**: Many methods do not quantize the first and last layer of the model. It has been observed by [Han et al., 2015](#han-et-al-2015) that the first convolutional layer is more sensitive to weights pruning, and some quantization works cite the same reason and show it empirically ([Zhou S et al., 2016](#zhou-et-al-2016), [Choi et al., 2018](#choi-et-al-2018)). Some works also note that these layers usually constitute a very small portion of the overall computation within the model, further reducing the motivation to quantize them ([Rastegari et al., 2016](#rastegari-et-al-2016)). Most methods keep the first and last layers at FP32. However, [Choi et al., 2018](#choi-et-al-2018) showed that "conservative" quantization of these layers, e.g. to INT8, does not reduce accuracy.
......@@ -90,9 +90,6 @@ An important question in this context is how to back-propagate through the quant
<div id="zhou-et-al-2017"></div>
**Aojun Zhou, Anbang Yao, Yiwen Guo, Lin Xu and Yurong Chen**. Incremental Network Quantization: Towards Lossless CNNs with Low-precision Weights. [ICLR, 2017](https://arxiv.org/abs/1702.03044)
<div id="mishra-and-marr-2018"></div>
**Asit Mishra and Debbie Marr**. Apprentice: Using Knowledge Distillation Techniques To Improve Low-Precision Network Accuracy. [ICLR, 2018](https://openreview.net/forum?id=B1ae1lZRb)
<div id="mishra-et-al-2018"></div>
**Asit Mishra, Eriko Nurvitadhi, Jeffrey J Cook and Debbie Marr**. WRPN: Wide Reduced-Precision Networks. [ICLR, 2018](https://openreview.net/forum?id=B1ZvaaeAZ)
......
......@@ -10,7 +10,7 @@ Let's briefly discuss the main mechanisms and abstractions: A schedule specifica
These define the **what** part of the schedule.
The Policies define the **when** part of the schedule: at which epoch to start applying the Pruner/Regularizer/Quantizer/LR-decay, the epoch to end, and how often to invoke the policy (frequency of application). A policy also defines the instance of Pruner/Regularizer/Quantizer/LR-decay it is managing.
The CompressionScheduler is configured from a YAML file or from a dictionary, but you can also manually create Policies, Pruners, Regularizers and Quantizers from code.
The `CompressionScheduler` is configured from a YAML file or from a dictionary, but you can also manually create Policies, Pruners, Regularizers and Quantizers from code.
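For example, here is a minimal sketch of building a scheduler programmatically, using the knowledge distillation policy described later in this document; `model`, `teacher_model` and `num_epochs` are assumed to be defined by the surrounding script:
```
import distiller

# Create a scheduler directly in code instead of loading a YAML schedule
compression_scheduler = distiller.CompressionScheduler(model)

# Create a policy (here: knowledge distillation) and register it with the scheduler
dlw = distiller.DistillationLossWeights(distill=0.5, student=0.5, teacher=0.0)
kd_policy = distiller.KnowledgeDistillationPolicy(model, teacher_model, temperature=4.0, loss_weights=dlw)
compression_scheduler.add_policy(kd_policy, starting_epoch=0, ending_epoch=num_epochs, frequency=1)
```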
## Syntax through example
We'll use ```alexnet.schedule_agp.yaml``` to explain some of the YAML syntax for configuring Sensitivity Pruning of Alexnet.
......@@ -311,3 +311,74 @@ policies:
```
**Important Note**: As mentioned [here](design.md#training-with-quantization), since the quantizer modifies the model's parameters (assuming training with quantization in the loop is used), the call to `prepare_model()` must be performed before the optimizer is created. Therefore, currently, the starting epoch for a quantization policy must be 0; otherwise the quantization process will not work as expected. If one wishes to do a "warm-startup" (or "boot-strapping"), that is, train for a few epochs with full precision and only then start quantizing, the only way to do this right now is to execute a separate run to generate the boot-strapped weights, and then execute a second run that resumes from the checkpoint containing the boot-strapped weights.
## Knowledge Distillation
Knowledge distillation (see [here](knowledge_distillation.md)) is also implemented as a `Policy`, which should be added to the scheduler. However, with the current implementation, it cannot be defined within the YAML file like the rest of the policies described above.
To make the integration of this method into applications a bit easier, a helper function can be used that will add a set of command-line arguments related to knowledge distillation:
```
import argparse
import distiller
parser = argparse.ArgumentParser()
distiller.knowledge_distillation.add_distillation_args(parser)
```
(The `add_distillation_args` function accepts some optional arguments; see its implementation in `distiller/knowledge_distillation.py` for details.)
These are the command line arguments exposed by this function:
```
Knowledge Distillation Training Arguments:
--kd-teacher ARCH Model architecture for teacher model
--kd-pretrained Use pre-trained model for teacher
--kd-resume PATH Path to checkpoint from which to load teacher weights
--kd-temperature TEMP, --kd-temp TEMP
Knowledge distillation softmax temperature
--kd-distill-wt WEIGHT, --kd-dw WEIGHT
Weight for distillation loss (student vs. teacher soft
targets)
--kd-student-wt WEIGHT, --kd-sw WEIGHT
Weight for student vs. labels loss
--kd-teacher-wt WEIGHT, --kd-tw WEIGHT
Weight for teacher vs. labels loss
--kd-start-epoch EPOCH_NUM
Epoch from which to enable distillation
```
Once arguments have been parsed, some initialization code is required, similar to the following:
```
# Assuming:
# "args" variable holds command line arguments
# "model" variable holds the model we're going to train, that is - the student model
# "compression_scheduler" variable holds a CompressionScheduler instance
args.kd_policy = None
if args.kd_teacher:
# Create teacher model - replace this with your model creation code
teacher = create_model(args.kd_pretrained, args.dataset, args.kd_teacher, device_ids=args.gpus)
if args.kd_resume:
teacher, _, _ = apputils.load_checkpoint(teacher, chkpt_file=args.kd_resume)
# Create policy and add to scheduler
dlw = distiller.DistillationLossWeights(args.kd_distill_wt, args.kd_student_wt, args.kd_teacher_wt)
args.kd_policy = distiller.KnowledgeDistillationPolicy(model, teacher, args.kd_temp, dlw)
compression_scheduler.add_policy(args.kd_policy, starting_epoch=args.kd_start_epoch, ending_epoch=args.epochs,
frequency=1)
```
Finally, during the training loop, we need to perform forward propagation through the teacher model as well. The `KnowledgeDistillationPolicy` class keeps a reference to both the student and teacher models, and exposes a `forward` function that performs forward propagation on both of them. Since this is not one of the standard policy callbacks, we need to call this function manually from our training loop, as follows:
```
if args.kd_policy is None:
# Revert to a "normal" forward-prop call if no knowledge distillation policy is present
output = model(input_var)
else:
output = args.kd_policy.forward(input_var)
```
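Putting it all together, the following is a condensed, illustrative sketch of where these calls sit in a typical training loop. Variable names such as `train_loader`, `criterion` and `optimizer` are assumptions, and the exact value returned by the scheduler's `before_backward_pass` may differ between Distiller versions, so treat this as a sketch rather than a drop-in snippet:
```
for epoch in range(args.epochs):
    compression_scheduler.on_epoch_begin(epoch)
    for train_step, (inputs, target) in enumerate(train_loader):
        compression_scheduler.on_minibatch_begin(epoch, train_step, len(train_loader))

        # Forward through the KD policy (if present) so that teacher logits are cached
        output = model(inputs) if args.kd_policy is None else args.kd_policy.forward(inputs)
        loss = criterion(output, target)

        # The KD policy's before_backward_pass callback combines the student loss with the
        # distillation loss; here we assume the scheduler returns the overall loss to optimize
        loss = compression_scheduler.before_backward_pass(epoch, train_step, len(train_loader),
                                                          loss, optimizer=optimizer)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        compression_scheduler.on_minibatch_end(epoch, train_step, len(train_loader))
    compression_scheduler.on_epoch_end(epoch)
```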
To see this integration in action, take a look at the image classification sample at `examples/classifier_compression/compress_classifier.py`.
......@@ -21,6 +21,7 @@ pages:
- 'Pruning': 'pruning.md'
- 'Regularization': 'regularization.md'
- 'Quantization': 'quantization.md'
- 'Knowledge Distillation': 'knowledge_distillation.md'
- Algorithms:
- Pruning: algo_pruning.md
- Quantization: algo_quantization.md
......
......@@ -77,6 +77,10 @@
<a class="" href="/quantization/index.html">Quantization</a>
</li>
<li class="">
<a class="" href="/knowledge_distillation/index.html">Knowledge Distillation</a>
</li>
</ul>
</li>
......
......@@ -84,6 +84,10 @@
<a class="" href="../quantization/index.html">Quantization</a>
</li>
<li class="">
<a class="" href="../knowledge_distillation/index.html">Knowledge Distillation</a>
</li>
</ul>
</li>
......@@ -296,7 +300,7 @@ abundant and gradually reduce the number of weights being pruned each time as th
<a href="../algo_quantization/index.html" class="btn btn-neutral float-right" title="Quantization">Next <span class="icon icon-circle-arrow-right"></span></a>
<a href="../quantization/index.html" class="btn btn-neutral" title="Quantization"><span class="icon icon-circle-arrow-left"></span> Previous</a>
<a href="../knowledge_distillation/index.html" class="btn btn-neutral" title="Knowledge Distillation"><span class="icon icon-circle-arrow-left"></span> Previous</a>
</div>
......@@ -322,7 +326,7 @@ abundant and gradually reduce the number of weights being pruned each time as th
<span class="rst-current-version" data-toggle="rst-current-version">
<span><a href="../quantization/index.html" style="color: #fcfcfc;">&laquo; Previous</a></span>
<span><a href="../knowledge_distillation/index.html" style="color: #fcfcfc;">&laquo; Previous</a></span>
<span style="margin-left: 15px"><a href="../algo_quantization/index.html" style="color: #fcfcfc">Next &raquo;</a></span>
......
......@@ -84,6 +84,10 @@
<a class="" href="../quantization/index.html">Quantization</a>
</li>
<li class="">
<a class="" href="../knowledge_distillation/index.html">Knowledge Distillation</a>
</li>
</ul>
</li>
......
......@@ -84,6 +84,10 @@
<a class="" href="../quantization/index.html">Quantization</a>
</li>
<li class="">
<a class="" href="../knowledge_distillation/index.html">Knowledge Distillation</a>
</li>
</ul>
</li>
......
docs/imgs/decision_boundary.png (274 KiB)
docs/imgs/knowledge_distillation.png (232 KiB)
......@@ -98,6 +98,10 @@
<a class="" href="quantization/index.html">Quantization</a>
</li>
<li class="">
<a class="" href="knowledge_distillation/index.html">Knowledge Distillation</a>
</li>
</ul>
</li>
......@@ -246,5 +250,5 @@ And of course, if we used a sparse or compressed representation, then we are red
<!--
MkDocs version : 0.17.2
Build Date UTC : 2018-07-22 11:48:56
Build Date UTC : 2018-09-03 21:12:52
-->
......@@ -100,6 +100,10 @@
<a class="" href="../quantization/index.html">Quantization</a>
</li>
<li class="">
<a class="" href="../knowledge_distillation/index.html">Knowledge Distillation</a>
</li>
</ul>
</li>
......
......@@ -84,6 +84,10 @@
<a class="" href="../quantization/index.html">Quantization</a>
</li>
<li class="">
<a class="" href="../knowledge_distillation/index.html">Knowledge Distillation</a>
</li>
</ul>
</li>
......
<!DOCTYPE html>
<!--[if IE 8]><html class="no-js lt-ie9" lang="en" > <![endif]-->
<!--[if gt IE 8]><!--> <html class="no-js" lang="en" > <!--<![endif]-->
<head>
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<link rel="shortcut icon" href="../img/favicon.ico">
<title>Knowledge Distillation - Neural Network Distiller</title>
<link href='https://fonts.googleapis.com/css?family=Lato:400,700|Roboto+Slab:400,700|Inconsolata:400,700' rel='stylesheet' type='text/css'>
<link rel="stylesheet" href="../css/theme.css" type="text/css" />
<link rel="stylesheet" href="../css/theme_extra.css" type="text/css" />
<link rel="stylesheet" href="../css/highlight.css">
<link href="../extra.css" rel="stylesheet">
<script>
// Current page data
var mkdocs_page_name = "Knowledge Distillation";
var mkdocs_page_input_path = "knowledge_distillation.md";
var mkdocs_page_url = "/knowledge_distillation/index.html";
</script>
<script src="../js/jquery-2.1.1.min.js"></script>
<script src="../js/modernizr-2.8.3.min.js"></script>
<script type="text/javascript" src="../js/highlight.pack.js"></script>
</head>
<body class="wy-body-for-nav" role="document">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side stickynav">
<div class="wy-side-nav-search">
<a href="../index.html" class="icon icon-home"> Neural Network Distiller</a>
<div role="search">
<form id ="rtd-search-form" class="wy-form" action="../search.html" method="get">
<input type="text" name="q" placeholder="Search docs" />
</form>
</div>
</div>
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
<ul class="current">
<li class="toctree-l1">
<a class="" href="../index.html">Home</a>
</li>
<li class="toctree-l1">
<a class="" href="../install/index.html">Installation</a>
</li>
<li class="toctree-l1">
<a class="" href="../usage/index.html">Usage</a>
</li>
<li class="toctree-l1">
<a class="" href="../schedule/index.html">Compression scheduling</a>
</li>
<li class="toctree-l1">
<span class="caption-text">Compressing models</span>
<ul class="subnav">
<li class="">
<a class="" href="../pruning/index.html">Pruning</a>
</li>
<li class="">
<a class="" href="../regularization/index.html">Regularization</a>
</li>
<li class="">
<a class="" href="../quantization/index.html">Quantization</a>
</li>
<li class=" current">
<a class="current" href="index.html">Knowledge Distillation</a>
<ul class="subnav">
<li class="toctree-l3"><a href="#knowledge-distillation">Knowledge Distillation</a></li>
<ul>
<li><a class="toctree-l4" href="#new-hyper-parameters">New Hyper-Parameters</a></li>
<li><a class="toctree-l4" href="#references">References</a></li>
</ul>
</ul>
</li>
</ul>
</li>
<li class="toctree-l1">
<span class="caption-text">Algorithms</span>
<ul class="subnav">
<li class="">
<a class="" href="../algo_pruning/index.html">Pruning</a>
</li>
<li class="">
<a class="" href="../algo_quantization/index.html">Quantization</a>
</li>
</ul>
</li>
<li class="toctree-l1">
<a class="" href="../model_zoo/index.html">Model Zoo</a>
</li>
<li class="toctree-l1">
<a class="" href="../jupyter/index.html">Jupyter notebooks</a>
</li>
<li class="toctree-l1">
<a class="" href="../design/index.html">Design</a>
</li>
</ul>
</div>
&nbsp;
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
<nav class="wy-nav-top" role="navigation" aria-label="top navigation">
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
<a href="../index.html">Neural Network Distiller</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="navigation" aria-label="breadcrumbs navigation">
<ul class="wy-breadcrumbs">
<li><a href="../index.html">Docs</a> &raquo;</li>
<li>Compressing models &raquo;</li>
<li>Knowledge Distillation</li>
<li class="wy-breadcrumbs-aside">
</li>
</ul>
<hr/>
</div>
<div role="main">
<div class="section">
<h1 id="knowledge-distillation">Knowledge Distillation</h1>
<p>(For details on how to train a model with knowledge distillation in Distiller, see <a href="../schedule/index.html#knowledge-distillation">here</a>)</p>
<p>Knowledge distillation is a model compression method in which a small model is trained to mimic a pre-trained, larger model (or ensemble of models). This training setting is sometimes referred to as "teacher-student", where the large model is the teacher and the small model is the student (we'll be using these terms interchangeably).</p>
<p>The method was first proposed by <a href="#bucila-et-al-2006">Bucila et al., 2006</a> and generalized by <a href="#hinton-et-al-2015">Hinton et al., 2015</a>. The implementation in Distiller is based on the latter publication. Here we'll provide a summary of the method. For more information the reader may refer to the paper (a <a href="https://www.youtube.com/watch?v=EK61htlw8hY">video lecture</a> with <a href="http://www.ttic.edu/dl/dark14.pdf">slides</a> is also available).</p>
<p>In distillation, knowledge is transferred from the teacher model to the student by minimizing a loss function in which the target is the distribution of class probabilities predicted by the teacher model. That is - the output of a softmax function on the teacher model's logits. However, in many cases, this probability distribution has the correct class at a very high probability, with all other class probabilities very close to 0. As such, it doesn't provide much information beyond the ground truth labels already provided in the dataset. To tackle this issue, <a href="#hinton-et-al-2015">Hinton et al., 2015</a> introduced the concept of "softmax temperature". The probability <script type="math/tex">p_i</script> of class <script type="math/tex">i</script> is calculated from the logits <script type="math/tex">z</script> as:</p>
<p>
<script type="math/tex; mode=display">p_i = \frac{exp\left(\frac{z_i}{T}\right)}{\sum_{j} \exp\left(\frac{z_j}{T}\right)}</script>
</p>
<p>where <script type="math/tex">T</script> is the temperature parameter. When <script type="math/tex">T=1</script> we get the standard softmax function. As <script type="math/tex">T</script> grows, the probability distribution generated by the softmax function becomes softer, providing more information as to which classes the teacher found more similar to the predicted class. Hinton calls this the "dark knowledge" embedded in the teacher model, and it is this dark knowledge that we are transferring to the student model in the distillation process. When computing the loss function vs. the teacher's soft targets, we use the same value of <script type="math/tex">T</script> to compute the softmax on the student's logits. We call this loss the "distillation loss".</p>
<p><a href="#hinton-et-al-2015">Hinton et al., 2015</a> found that it is also beneficial to train the distilled model to produce the correct labels (based on the ground truth) in addition to the teacher's soft-labels. Hence, we also calculate the "standard" loss between the student's predicted class probabilities and the ground-truth labels (also called "hard labels/targets"). We dub this loss the "student loss". When calculating the class probabilities for the student loss we use <script type="math/tex">T = 1</script>. </p>
<p>The overall loss function, incorporating both distillation and student losses, is calculated as:</p>
<p>
<script type="math/tex; mode=display">\mathcal{L}(x;W) = \alpha * \mathcal{H}(y, \sigma(z_s; T=1)) + \beta * \mathcal{H}(\sigma(z_t; T=\tau), \sigma(z_s, T=\tau))</script>
</p>
<p>where <script type="math/tex">x</script> is the input, <script type="math/tex">W</script> are the student model parameters, <script type="math/tex">y</script> is the ground truth label, <script type="math/tex">\mathcal{H}</script> is the cross-entropy loss function, <script type="math/tex">\sigma</script> is the softmax function parameterized by the temperature <script type="math/tex">T</script>, and <script type="math/tex">\alpha</script> and <script type="math/tex">\beta</script> are coefficients. <script type="math/tex">z_s</script> and <script type="math/tex">z_t</script> are the logits of the student and teacher respectively.</p>
<p><img alt="Knowledge Distillation" src="../imgs/knowledge_distillation.png" /></p>
<h2 id="new-hyper-parameters">New Hyper-Parameters</h2>
<p>In general, <script type="math/tex">\tau</script>, <script type="math/tex">\alpha</script> and <script type="math/tex">\beta</script> are hyper-parameters.</p>
<p>In their experiments, <a href="#hinton-et-al-2015">Hinton et al., 2015</a> use temperature values ranging from 1 to 20. They note that empirically, when the student model is very small compared to the teacher model, lower temperatures work better. This makes sense if we consider that as we raise the temperature, the resulting soft-labels distribution becomes richer in information, and a very small model might not be able to capture all of this information. However, there's no clear way to predict up front what kind of capacity for information the student model will have.</p>
<p>With regards to <script type="math/tex">\alpha</script> and <script type="math/tex">\beta</script>, <a href="#hinton-et-al-2015">Hinton et al., 2015</a> use a weighted average between the distillation loss and the student loss. That is, <script type="math/tex">\beta = 1 - \alpha</script>. They note that in general, they obtained the best results when setting <script type="math/tex">\alpha</script> to be much smaller than <script type="math/tex">\beta</script> (although in one of their experiments they use <script type="math/tex">\alpha = \beta = 0.5</script>). Other works which utilize knowledge distillation don't use a weighted average. Some set <script type="math/tex">\alpha = 1</script> while leaving <script type="math/tex">\beta</script> tunable, while others don't set any constraints.</p>
<h2 id="combining-with-other-model-compression-techniques"><a name="combining"></a>Combining with Other Model Compression Techniques</h2>
<p>In the "basic" scenario, the smaller (student) model is a pre-defined architecture which just has a smaller number of parameters compared to the teacher model. For example, we could train ResNet-18 by distilling knowledge from ResNet-34. But, a model with smaller capacity can also be obtained by other model compression techniques - sparsification and/or quantization. So, for example, we could train a 4-bit ResNet-18 model with some method using quantization-aware training, and use a distillation loss function as described above. In that case, the teacher model can even be a FP32 ResNet-18 model. Same goes for pruning and regularization.</p>
<p><a href="#tann-et-al-2017">Tann et al., 2017</a>, <a href="#mishra-and-marr-2018">Mishra and Marr, 2018</a> and <a href="#polino-et-al-2018">Polino et al., 2018</a> are some works that combine knowledge distillation with <strong>quantization</strong>. <a href="#theis-et-al-2018">Theis et al., 2018</a> and <a href="#ashok-et-al-2018">Ashok et al., 2018</a> combine distillation with <strong>pruning</strong>.</p>
<h2 id="references">References</h2>
<p><div id="bucila-et-al-2006"></div>
<strong>Cristian Bucila, Rich Caruana, and Alexandru Niculescu-Mizil</strong>. Model Compression. <a href="https://www.cs.cornell.edu/~caruana/compression.kdd06.pdf">KDD, 2006</a></p>
<div id="hinton-et-al-2015"></div>
<p><strong>Geoffrey Hinton, Oriol Vinyals and Jeff Dean</strong>. Distilling the Knowledge in a Neural Network. <a href="https://arxiv.org/abs/1503.02531">arxiv:1503.02531</a></p>
<div id="tann-et-al-2017"></div>
<p><strong>Hokchhay Tann, Soheil Hashemi, Iris Bahar and Sherief Reda</strong>. Hardware-Software Codesign of Accurate, Multiplier-free Deep Neural Networks. <a href="https://arxiv.org/abs/1705.04288">DAC, 2017</a></p>
<div id="mishra-and-marr-2018"></div>
<p><strong>Asit Mishra and Debbie Marr</strong>. Apprentice: Using Knowledge Distillation Techniques To Improve Low-Precision Network Accuracy. <a href="https://openreview.net/forum?id=B1ae1lZRb">ICLR, 2018</a></p>
<div id="polino-et-al-2018"></div>
<p><strong>Antonio Polino, Razvan Pascanu and Dan Alistarh</strong>. Model compression via distillation and quantization. <a href="https://openreview.net/forum?id=S1XolQbRW">ICLR, 2018</a></p>
<div id="ashok-et-al-2018"></div>
<p><strong>Anubhav Ashok, Nicholas Rhinehart, Fares Beainy and Kris M. Kitani</strong>. N2N learning: Network to Network Compression via Policy Gradient Reinforcement Learning. <a href="https://openreview.net/forum?id=B1hcZZ-AW">ICLR, 2018</a></p>
<div id="theis-et-al-2018"></div>
<p><strong>Lucas Theis, Iryna Korshunova, Alykhan Tejani and Ferenc Huszár</strong>. Faster gaze prediction with dense networks and Fisher pruning. <a href="https://arxiv.org/abs/1801.05787">arxiv:1801.05787</a></p>
</div>
</div>
<footer>
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
<a href="../algo_pruning/index.html" class="btn btn-neutral float-right" title="Pruning">Next <span class="icon icon-circle-arrow-right"></span></a>
<a href="../quantization/index.html" class="btn btn-neutral" title="Quantization"><span class="icon icon-circle-arrow-left"></span> Previous</a>
</div>
<hr/>
<div role="contentinfo">
<!-- Copyright etc -->
</div>
Built with <a href="http://www.mkdocs.org">MkDocs</a> using a <a href="https://github.com/snide/sphinx_rtd_theme">theme</a> provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
</div>
</div>
</section>
</div>
<div class="rst-versions" role="note" style="cursor: pointer">
<span class="rst-current-version" data-toggle="rst-current-version">
<span><a href="../quantization/index.html" style="color: #fcfcfc;">&laquo; Previous</a></span>
<span style="margin-left: 15px"><a href="../algo_pruning/index.html" style="color: #fcfcfc">Next &raquo;</a></span>
</span>
</div>
<script>var base_url = '..';</script>
<script src="../js/theme.js"></script>
<script src="https://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS_HTML"></script>
<script src="../search/require.js"></script>
<script src="../search/search.js"></script>
</body>
</html>
......@@ -84,6 +84,10 @@
<a class="" href="../quantization/index.html">Quantization</a>
</li>
<li class="">
<a class="" href="../knowledge_distillation/index.html">Knowledge Distillation</a>
</li>
</ul>
</li>
......
......@@ -106,6 +106,10 @@
<a class="" href="../quantization/index.html">Quantization</a>
</li>
<li class="">
<a class="" href="../knowledge_distillation/index.html">Knowledge Distillation</a>
</li>
</ul>
</li>
......