Commit 443e7381 authored by Guy Jacob, committed by Neta Zmora

8-bit Quantization - Save model + add test + updated docs (#3)

parent 27ddbcc2
......@@ -196,6 +196,8 @@ This example performs 8-bit quantization of ResNet20 for CIFAR10. We've include
$ python3 compress_classifier.py -a resnet20_cifar ../../../data.cifar10 --resume ../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar --quantize --evaluate
```
The command-line above will save a checkpoint named `quantized_checkpoint.pth.tar` containing the quantized model parameters.
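To double-check what was written, the saved checkpoint can be opened directly with `torch.load` (a minimal inspection sketch, not part of this commit; the exact keys stored by `apputils.save_checkpoint` are an assumption):
```python
import torch

# Load the checkpoint produced by the command above; CPU is enough for inspection
chkpt = torch.load('quantized_checkpoint.pth.tar', map_location='cpu')

# Keys such as 'arch' and 'state_dict' are assumed here, not guaranteed
print(list(chkpt.keys()))
print(list(chkpt['state_dict'].keys())[:5])  # a few of the quantized parameter names
```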
### Explore the sample Jupyter notebooks
The set of notebooks that come with Distiller is described [here](https://nervanasystems.github.io/distiller/jupyter/index.html#using-the-distiller-notebooks), which also explains the steps to install the Jupyter notebook server.<br>
After installing and running the server, take a look at the [notebook](https://github.com/NervanaSystems/distiller/blob/master/jupyter/sensitivity_analysis.ipynb) covering pruning sensitivity analysis.
......
......@@ -2,17 +2,22 @@
## Symmetric Linear Quantization
**Linear** - float value quantized by multiplying with scale factor. **Symmetric** - no quantization bias (or "offset") used, so zero in the float domain is mapped to zero in the integer domain.
In the current implementation the scale factor is chosen so that the entire range of the tensor is quantized. So, we get: (Using \(q\) to denote the scale factor and \(x\) to denote the tensor being quantized)
\[q_x = \frac{2^n-1}{\max|x|}\]
\[x_q = round(q_x\cdot x_f)\]
Where \(n\) is the number of bits used for quantization.
For weights and bias the scale factor is determined once at quantization setup ("offline"), and for activations it is determined dynamically at runtime ("online").
Currently, quantized implementations are provided for **convolution** and **fully-connected** layers. These layers are quantized as follows (using \(x, y, w, b\) for input, output, weights, bias respectively):
\[y_f = \sum{x_f\cdot w_f} + b_f = \sum{x_f\cdot w_f} + b_f = \sum{\frac{x_q}{q_x}\cdot \frac{w_q}{q_w}} + \frac{b_q}{q_b} = \frac{1}{q_x q_w} \sum{(x_q w_q + \frac{q_b}{q_x q_w}b_q)}\]
\[y_q = round(q_y\cdot y_f) = round(\frac{q_y}{q_x q_w} \sum{(x_q w_q + \frac{q_b}{q_x q_w}b_q)})\]
Note how the bias has to be re-scaled to match the scale of the summation.
All other layers are executed in FP32. This is done by adding quantize and de-quantize operations at the beginning and end of the quantized layers.
In this method, a float value is quantized by multiplying with a numeric constant (the **scale factor**), hence it is **Linear**. We use a signed integer to represent the quantized range, with no quantization bias (or "offset") used. As a result, the floating-point range considered for quantization is **symmetric** with respect to zero.
In the current implementation the scale factor is chosen so that the entire range of the floating-point tensor is quantized (we do not attempt to remove outliers).
Let us denote the original floating-point tensor by \(x_f\), the quantized tensor by \(x_q\), the scale factor by \(q_x\) and the number of bits used for quantization by \(n\). Then, we get:
\[q_x = \frac{2^{n-1}-1}{\max|x|}\]
\[x_q = round(q_x x_f)\]
(The \(round\) operation is round-to-nearest-integer)
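To make the formulas concrete, here is a short PyTorch sketch (illustrative only, not code from this commit; `symmetric_quantize` and `dequantize` are hypothetical helper names):
```python
import torch

def symmetric_quantize(x_f, n=8):
    """Quantize a float tensor to n-bit signed integer values using a symmetric scale."""
    q_x = (2 ** (n - 1) - 1) / x_f.abs().max()  # scale factor covering the tensor's full range
    x_q = torch.round(q_x * x_f)                # round to the nearest integer
    return x_q, q_x

def dequantize(x_q, q_x):
    """Map quantized values back to the float domain."""
    return x_q / q_x

x_f = torch.randn(4, 4)
x_q, q_x = symmetric_quantize(x_f, n=8)
print((x_f - dequantize(x_q, q_x)).abs().max())  # maximum quantization error
```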
Let's see how a **convolution** or **fully-connected (FC)** layer is quantized using this method: (we denote input, output, weights and bias with \(x, y, w\) and \(b\) respectively)
\[y_f = \sum{x_f w_f} + b_f = \sum{\frac{x_q}{q_x} \frac{w_q}{q_w}} + \frac{b_q}{q_b} = \frac{1}{q_x q_w} \left(\sum{x_q w_q} + \frac{q_x q_w}{q_b}b_q\right)\]
\[y_q = round(q_y y_f) = round\left(\frac{q_y}{q_x q_w} \left(\sum{x_q w_q} + \frac{q_x q_w}{q_b}b_q\right)\right)\]
Note how the bias has to be re-scaled to match the scale of the summation.
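The following self-contained sketch (again illustrative, not the code added in this commit) walks the fully-connected case through these equations, including the bias re-scaling and the re-quantization of the output; `sym_scale` is a hypothetical helper:
```python
import torch

def sym_scale(t, n=8):
    # Scale factor covering the full range of the tensor (hypothetical helper)
    return (2 ** (n - 1) - 1) / t.abs().max()

n = 8
x_f, w_f, b_f = torch.randn(16), torch.randn(4, 16), torch.randn(4)

# "Offline" scales for weights and bias, "online" scale for the activation
q_x, q_w, q_b = sym_scale(x_f, n), sym_scale(w_f, n), sym_scale(b_f, n)
x_q, w_q, b_q = torch.round(q_x * x_f), torch.round(q_w * w_f), torch.round(q_b * b_f)

# Integer-domain accumulation; the bias is re-scaled to the scale of the summation
acc = w_q.matmul(x_q) + (q_x * q_w / q_b) * b_q
y_f = acc / (q_x * q_w)                          # back to the float domain

# Re-quantize the output with its own scale factor
q_y = sym_scale(y_f, n)
y_q = torch.round(q_y * y_f)

print((y_f - (w_f.matmul(x_f) + b_f)).abs().max())  # small error vs. the FP32 result
```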
!!! note
This method is implemented as **inference only**, with no back-propagation functionality. Hence, it can only be used to quantize a pre-trained FP32 model, and as such, using it with \(n < 8\) is likely to lead to severe accuracy degradation for any non-trivial workload.
\ No newline at end of file
### Implementation
We've implemented **convolution** and **FC** using this method.
- They are implemented by wrapping the existing PyTorch layers with quantization and de-quantization operations. That is, the computation is done on floating-point tensors, but the values themselves are restricted to integer values (a sketch of this wrapping appears after this list).
- All other layers are unaffected and are executed using their original FP32 implementation.
- For weights and bias the scale factor is determined once at quantization setup ("offline"), and for activations it is determined dynamically at runtime ("online").
- **Important note:** Currently, this method is implemented as **inference only**, with no back-propagation functionality. Hence, it can only be used to quantize a pre-trained FP32 model, with no re-training. As such, using it with \(n < 8\) is likely to lead to severe accuracy degradation for any non-trivial workload.
\ No newline at end of file
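To make the wrapping described in the list above concrete, here is a minimal, hypothetical sketch. It is **not** Distiller's actual `SymmetricLinearQuantizer`; bias handling is omitted for brevity and all names are illustrative:
```python
import torch
import torch.nn as nn

class QuantWrapper(nn.Module):
    """Hypothetical wrapper: the wrapped layer runs on float tensors whose values
    are restricted to integers, and the output is de-quantized back to FP32."""

    def __init__(self, layer, num_bits=8):
        super(QuantWrapper, self).__init__()
        assert layer.bias is None, "bias handling omitted in this sketch"
        self.layer = layer
        self.num_bits = num_bits
        # "Offline": determine the weight scale and quantize the weights once, at setup
        with torch.no_grad():
            self.q_w = (2 ** (num_bits - 1) - 1) / layer.weight.abs().max()
            layer.weight.copy_(torch.round(self.q_w * layer.weight))

    def forward(self, x):
        # "Online": the activation scale is determined per input, at runtime
        q_x = (2 ** (self.num_bits - 1) - 1) / x.abs().max()
        x_q = torch.round(q_x * x)
        acc = self.layer(x_q)          # accumulation over integer-valued float tensors
        return acc / (q_x * self.q_w)  # de-quantize the output back to FP32

fc = QuantWrapper(nn.Linear(16, 4, bias=False), num_bits=8)
print(fc(torch.randn(2, 16)).shape)    # torch.Size([2, 4])
```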
......@@ -167,21 +167,24 @@
<h1 id="quantization-algorithms">Quantization Algorithms</h1>
<h2 id="symmetric-linear-quantization">Symmetric Linear Quantization</h2>
<p><strong>Linear</strong> - float value quantized by multiplying with scale factor. <strong>Symmetric</strong> - no quantization bias (or "offset") used, so zero in the float domain is mapped to zero in the integer domain.<br />
In the current implementation the scale factor is chosen so that the entire range of the tensor is quantized. So, we get: (Using <script type="math/tex">q</script> to denote the scale factor and <script type="math/tex">x</script> to denote the tensor being quantized)
<script type="math/tex; mode=display">q_x = \frac{2^n-1}{\max|x|}</script>
<script type="math/tex; mode=display">x_q = round(q_x\cdot x_f)</script>
Where <script type="math/tex">n</script> is the number of bits used for quantization.<br />
For weights and bias the scale factor is determined once at quantization setup ("offline"), and for activations it is determined dynamically at runtime ("online").<br />
Currently, quantized implementations are provided for <strong>convolution</strong> and <strong>fully-connected</strong> layers. These layers are quantized as follows (using <script type="math/tex">x, y, w, b</script> for input, output, weights, bias respectively):
<script type="math/tex; mode=display">y_f = \sum{x_f\cdot w_f} + b_f = \sum{x_f\cdot w_f} + b_f = \sum{\frac{x_q}{q_x}\cdot \frac{w_q}{q_w}} + \frac{b_q}{q_b} = \frac{1}{q_x q_w} \sum{(x_q w_q + \frac{q_b}{q_x q_w}b_q)}</script>
<script type="math/tex; mode=display">y_q = round(q_y\cdot y_f) = round(\frac{q_y}{q_x q_w} \sum{(x_q w_q + \frac{q_b}{q_x q_w}b_q)})</script>
Note how the bias has to be re-scaled to match the scale of the summation.<br />
All other layers are executed in FP32. This is done by adding quantize and de-quantize operations at the beginning and end of the quantized layers. </p>
<div class="admonition note">
<p class="admonition-title">Note</p>
</div>
<p>This method is implemented as <strong>inference only</strong>, with no back-propagation functionality. Hence, it can only be used to quantize a pre-trained FP32 model, and as such, using it with <script type="math/tex">n < 8</script> is likely to lead to severe accuracy degradation for any non-trivial workload.</p>
<p>In this method, a float value is quantized by multiplying with a numeric constant (the <strong>scale factor</strong>), hence it is <strong>Linear</strong>. We use a signed integer to represent the quantized range, with no quantization bias (or "offset") used. As a result, the floating-point range considered for quantization is <strong>symmetric</strong> with respect to zero.<br />
In the current implementation the scale factor is chosen so that the entire range of the floating-point tensor is quantized (we do not attempt to remove outliers).<br />
Let us denote the original floating-point tensor by <script type="math/tex">x_f</script>, the quantized tensor by <script type="math/tex">x_q</script>, the scale factor by <script type="math/tex">q_x</script> and the number of bits used for quantization by <script type="math/tex">n</script>. Then, we get:
<script type="math/tex; mode=display">q_x = \frac{2^{n-1}-1}{\max|x|}</script>
<script type="math/tex; mode=display">x_q = round(q_x x_f)</script>
(The <script type="math/tex">round</script> operation is round-to-nearest-integer) </p>
<p>Let's see how a <strong>convolution</strong> or <strong>fully-connected (FC)</strong> layer is quantized using this method: (we denote input, output, weights and bias with <script type="math/tex">x, y, w</script> and <script type="math/tex">b</script> respectively)
<script type="math/tex; mode=display">y_f = \sum{x_f w_f} + b_f = \sum{\frac{x_q}{q_x} \frac{w_q}{q_w}} + \frac{b_q}{q_b} = \frac{1}{q_x q_w} \sum{(x_q w_q + \frac{q_b}{q_x q_w}b_q)}</script>
<script type="math/tex; mode=display">y_q = round(q_y y_f) = round(\frac{q_y}{q_x q_w} \sum{(x_q w_q + \frac{q_b}{q_x q_w}b_q)})</script>
Note how the bias has to be re-scaled to match the scale of the summation.</p>
<h3 id="implementation">Implementation</h3>
<p>We've implemented <strong>convolution</strong> and <strong>FC</strong> using this method. </p>
<ul>
<li>They are implemented by wrapping the existing PyTorch layers with quantization and de-quantization operations. That is - the computation is done on floating-point tensors, but the values themselves are restricted to integer values. </li>
<li>All other layers are unaffected and are executed using their original FP32 implementation. </li>
<li>For weights and bias the scale factor is determined once at quantization setup ("offline"), and for activations it is determined dynamically at runtime ("online"). </li>
<li><strong>Important note:</strong> Currently, this method is implemented as <strong>inference only</strong>, with no back-propagation functionality. Hence, it can only be used to quantize a pre-trained FP32 model, with no re-training. As such, using it with <script type="math/tex">n < 8</script> is likely to lead to severe accuracy degradation for any non-trivial workload.</li>
</ul>
</div>
</div>
......
......@@ -246,5 +246,5 @@ And of course, if we used a sparse or compressed representation, then we are red
<!--
MkDocs version : 0.17.2
Build Date UTC : 2018-05-09 14:06:49
Build Date UTC : 2018-05-14 13:58:17
-->
......@@ -4,7 +4,7 @@
<url>
<loc>/index.html</loc>
<lastmod>2018-05-09</lastmod>
<lastmod>2018-05-14</lastmod>
<changefreq>daily</changefreq>
</url>
......@@ -12,7 +12,7 @@
<url>
<loc>/install/index.html</loc>
<lastmod>2018-05-09</lastmod>
<lastmod>2018-05-14</lastmod>
<changefreq>daily</changefreq>
</url>
......@@ -20,7 +20,7 @@
<url>
<loc>/usage/index.html</loc>
<lastmod>2018-05-09</lastmod>
<lastmod>2018-05-14</lastmod>
<changefreq>daily</changefreq>
</url>
......@@ -28,7 +28,7 @@
<url>
<loc>/schedule/index.html</loc>
<lastmod>2018-05-09</lastmod>
<lastmod>2018-05-14</lastmod>
<changefreq>daily</changefreq>
</url>
......@@ -37,19 +37,19 @@
<url>
<loc>/pruning/index.html</loc>
<lastmod>2018-05-09</lastmod>
<lastmod>2018-05-14</lastmod>
<changefreq>daily</changefreq>
</url>
<url>
<loc>/regularization/index.html</loc>
<lastmod>2018-05-09</lastmod>
<lastmod>2018-05-14</lastmod>
<changefreq>daily</changefreq>
</url>
<url>
<loc>/quantization/index.html</loc>
<lastmod>2018-05-09</lastmod>
<lastmod>2018-05-14</lastmod>
<changefreq>daily</changefreq>
</url>
......@@ -59,13 +59,13 @@
<url>
<loc>/algo_pruning/index.html</loc>
<lastmod>2018-05-09</lastmod>
<lastmod>2018-05-14</lastmod>
<changefreq>daily</changefreq>
</url>
<url>
<loc>/algo_quantization/index.html</loc>
<lastmod>2018-05-09</lastmod>
<lastmod>2018-05-14</lastmod>
<changefreq>daily</changefreq>
</url>
......@@ -74,7 +74,7 @@
<url>
<loc>/model_zoo/index.html</loc>
<lastmod>2018-05-09</lastmod>
<lastmod>2018-05-14</lastmod>
<changefreq>daily</changefreq>
</url>
......@@ -82,7 +82,7 @@
<url>
<loc>/jupyter/index.html</loc>
<lastmod>2018-05-09</lastmod>
<lastmod>2018-05-14</lastmod>
<changefreq>daily</changefreq>
</url>
......@@ -90,7 +90,7 @@
<url>
<loc>/design/index.html</loc>
<lastmod>2018-05-09</lastmod>
<lastmod>2018-05-14</lastmod>
<changefreq>daily</changefreq>
</url>
......
......@@ -274,7 +274,11 @@ def main():
quantizer = quantization.SymmetricLinearQuantizer(model, 8, 8)
quantizer.prepare_model()
model.cuda()
test(test_loader, model, criterion, [pylogger], args.print_freq)
top1, _, _ = test(test_loader, model, criterion, [pylogger], args.print_freq)
if args.quantize:
    checkpoint_name = 'quantized'
    apputils.save_checkpoint(0, args.arch, model, optimizer, best_top1=top1,
                             name='_'.join([args.name, checkpoint_name]) if args.name else checkpoint_name)
exit()
if args.compress:
......
......@@ -24,6 +24,10 @@ import time
DS_CIFAR = 'cifar10'
distiller_root = os.path.realpath('..')
examples_root = os.path.join(distiller_root, 'examples')
script_path = os.path.realpath(os.path.join(examples_root, 'classifier_compression',
                                            'compress_classifier.py'))
###########
# Some Basic Logging Mechanisms
......@@ -98,6 +102,9 @@ TestConfig = namedtuple('TestConfig', ['args', 'dataset', 'checker_fn', 'checker
test_configs = [
    TestConfig('--arch simplenet_cifar --epochs 2', DS_CIFAR, accuracy_checker, [48.340, 92.630]),
    TestConfig('-a resnet20_cifar --resume {0} --quantize --evaluate'.
               format(os.path.join(examples_root, 'ssl', 'checkpoints', 'checkpoint_trained_dense.pth.tar')),
               DS_CIFAR, accuracy_checker, [91.620, 99.630]),
]
......@@ -143,7 +150,6 @@ def run_tests():
cifar10_path = validate_dataset_path(args.cifar10_path, default='data.cifar10', name='CIFAR-10')
datasets = {DS_CIFAR: cifar10_path}
script_path = os.path.realpath(os.path.join('..', 'examples', 'classifier_compression', 'compress_classifier.py'))
total_configs = len(test_configs)
failed_tests = []
for idx, tc in enumerate(test_configs):
......