<!DOCTYPE html> <!--[if IE 8]><html class="no-js lt-ie9" lang="en" > <![endif]--> <!--[if gt IE 8]><!--> <html class="no-js" lang="en" > <!--<![endif]--> <head> <meta charset="utf-8"> <meta http-equiv="X-UA-Compatible" content="IE=edge"> <meta name="viewport" content="width=device-width, initial-scale=1.0"> <link rel="shortcut icon" href="img/favicon.ico"> <title>Compression Scheduling - Neural Network Distiller</title> <link href='https://fonts.googleapis.com/css?family=Lato:400,700|Roboto+Slab:400,700|Inconsolata:400,700' rel='stylesheet' type='text/css'> <link rel="stylesheet" href="css/theme.css" type="text/css" /> <link rel="stylesheet" href="css/theme_extra.css" type="text/css" /> <link rel="stylesheet" href="//cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/styles/github.min.css"> <link href="extra.css" rel="stylesheet"> <script> // Current page data var mkdocs_page_name = "Compression Scheduling"; var mkdocs_page_input_path = "schedule.md"; var mkdocs_page_url = null; </script> <script src="js/jquery-2.1.1.min.js" defer></script> <script src="js/modernizr-2.8.3.min.js" defer></script> <script src="//cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/highlight.min.js"></script> <script>hljs.initHighlightingOnLoad();</script> </head> <body class="wy-body-for-nav" role="document"> <div class="wy-grid-for-nav"> <nav data-toggle="wy-nav-shift" class="wy-nav-side stickynav"> <div class="wy-side-nav-search"> <a href="index.html" class="icon icon-home"> Neural Network Distiller</a> <div role="search"> <form id ="rtd-search-form" class="wy-form" action="./search.html" method="get"> <input type="text" name="q" placeholder="Search docs" title="Type search term here" /> </form> </div> </div> <div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation"> <ul class="current"> <li class="toctree-l1"> <a class="" href="index.html">Home</a> </li> <li class="toctree-l1"> <a class="" href="install.html">Installation</a> </li> <li class="toctree-l1"> <a class="" href="usage.html">Usage</a> </li> <li class="toctree-l1 current"> <a class="current" href="schedule.html">Compression Scheduling</a> <ul class="subnav"> <li class="toctree-l2"><a href="#compression-scheduler">Compression scheduler</a></li> <ul> <li><a class="toctree-l3" href="#high-level-overview">High level overview</a></li> <li><a class="toctree-l3" href="#syntax-through-example">Syntax through example</a></li> <li><a class="toctree-l3" href="#regularization">Regularization</a></li> <li><a class="toctree-l3" href="#mixing-it-up">Mixing it up</a></li> <li><a class="toctree-l3" href="#quantization-aware-training">Quantization-Aware Training</a></li> <li><a class="toctree-l3" href="#post-training-quantization">Post-Training Quantization</a></li> <li><a class="toctree-l3" href="#pruning-fine-control">Pruning Fine-Control</a></li> <li><a class="toctree-l3" href="#knowledge-distillation">Knowledge Distillation</a></li> </ul> </ul> </li> <li class="toctree-l1"> <span class="caption-text">Compressing Models</span> <ul class="subnav"> <li class=""> <a class="" href="pruning.html">Pruning</a> </li> <li class=""> <a class="" href="regularization.html">Regularization</a> </li> <li class=""> <a class="" href="quantization.html">Quantization</a> </li> <li class=""> <a class="" href="knowledge_distillation.html">Knowledge Distillation</a> </li> <li class=""> <a class="" href="conditional_computation.html">Conditional Computation</a> </li> </ul> </li> <li class="toctree-l1"> <span 
class="caption-text">Algorithms</span> <ul class="subnav"> <li class=""> <a class="" href="algo_pruning.html">Pruning</a> </li> <li class=""> <a class="" href="algo_quantization.html">Quantization</a> </li> <li class=""> <a class="" href="algo_earlyexit.html">Early Exit</a> </li> </ul> </li> <li class="toctree-l1"> <a class="" href="model_zoo.html">Model Zoo</a> </li> <li class="toctree-l1"> <a class="" href="jupyter.html">Jupyter Notebooks</a> </li> <li class="toctree-l1"> <a class="" href="design.html">Design</a> </li> <li class="toctree-l1"> <span class="caption-text">Tutorials</span> <ul class="subnav"> <li class=""> <a class="" href="tutorial-struct_pruning.html">Pruning Filters and Channels</a> </li> <li class=""> <a class="" href="tutorial-lang_model.html">Pruning a Language Model</a> </li> <li class=""> <a class="" href="tutorial-lang_model_quant.html">Quantizing a Language Model</a> </li> </ul> </li> </ul> </div> </nav> <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"> <nav class="wy-nav-top" role="navigation" aria-label="top navigation"> <i data-toggle="wy-nav-top" class="fa fa-bars"></i> <a href="index.html">Neural Network Distiller</a> </nav> <div class="wy-nav-content"> <div class="rst-content"> <div role="navigation" aria-label="breadcrumbs navigation"> <ul class="wy-breadcrumbs"> <li><a href="index.html">Docs</a> »</li> <li>Compression Scheduling</li> <li class="wy-breadcrumbs-aside"> </li> </ul> <hr/> </div> <div role="main"> <div class="section">
<h1 id="compression-scheduler">Compression scheduler</h1>
<p>In iterative pruning, we create a pruning regimen that specifies how to prune, and what to prune, at every stage of the pruning and training process. This motivated the design of <code>CompressionScheduler</code>: it needed to be part of the training loop, and to be able to make and implement pruning, regularization and quantization decisions. We wanted to be able to change the particulars of the compression schedule without touching the code, and settled on using YAML as a container for this specification. We found that when we run many experiments on the same code base, it is easier to maintain all of them if we decouple the experiment specifics from the code base. For the same reason, we also added learning-rate decay scheduling to the scheduler: we wanted the freedom to change the LR-decay policy without changing code. </p>
<h2 id="high-level-overview">High level overview</h2>
<p>Let's briefly discuss the main mechanisms and abstractions: a schedule specification is composed of a list of sections defining instances of Pruners, Regularizers, Quantizers, LR-schedulers and Policies.</p>
<ul>
<li>Pruners, Regularizers and Quantizers are very similar: they implement a pruning, regularization or quantization algorithm, respectively. </li>
<li>An LR-scheduler specifies the LR-decay algorithm. </li>
</ul>
<p>These define the <strong>what</strong> part of the schedule. </p>
<p>The Policies define the <strong>when</strong> part of the schedule: at which epoch to start applying the Pruner/Regularizer/Quantizer/LR-decay, the epoch to end, and how often to invoke the policy (frequency of application). A policy also defines the instance of Pruner/Regularizer/Quantizer/LR-decay it is managing.<br />
The <code>CompressionScheduler</code> is configured from a YAML file or from a dictionary, but you can also manually create Policies, Pruners, Regularizers and Quantizers from code.</p>
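<p>To make the <strong>when</strong> part concrete, the sketch below shows one way to drive the scheduler from a training loop, following the callback pattern used in Distiller's example applications (<code>on_epoch_begin</code>, <code>on_minibatch_begin</code>, <code>before_backward_pass</code>, <code>on_minibatch_end</code>, <code>on_epoch_end</code>). Treat it as a minimal sketch: <code>create_model</code>, <code>train_loader</code>, <code>criterion</code> and <code>num_epochs</code> are placeholders, and the exact callback signatures should be verified against the Distiller source.</p>
<pre><code class="python">import torch
import distiller

model = create_model()            # placeholder: your model-creation code
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# Parse the YAML schedule and attach the resulting policies to a scheduler.
compression_scheduler = distiller.file_config(model, optimizer, 'alexnet.schedule_agp.yaml')
steps_per_epoch = len(train_loader)

for epoch in range(num_epochs):
    compression_scheduler.on_epoch_begin(epoch)
    for train_step, (inputs, labels) in enumerate(train_loader):
        compression_scheduler.on_minibatch_begin(epoch, train_step, steps_per_epoch)
        output = model(inputs)
        loss = criterion(output, labels)
        # Policies may adjust the loss (e.g. regularizers add their penalty term).
        loss = compression_scheduler.before_backward_pass(epoch, train_step, steps_per_epoch, loss)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        compression_scheduler.on_minibatch_end(epoch, train_step, steps_per_epoch)
    compression_scheduler.on_epoch_end(epoch)
</code></pre>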
<h2 id="syntax-through-example">Syntax through example</h2>
<p>We'll use <code>alexnet.schedule_agp.yaml</code> to explain some of the YAML syntax for configuring Sensitivity Pruning of Alexnet.</p>
<pre><code>version: 1
pruners:
  my_pruner:
    class: 'SensitivityPruner'
    sensitivities:
      'features.module.0.weight': 0.25
      'features.module.3.weight': 0.35
      'features.module.6.weight': 0.40
      'features.module.8.weight': 0.45
      'features.module.10.weight': 0.55
      'classifier.1.weight': 0.875
      'classifier.4.weight': 0.875
      'classifier.6.weight': 0.625

lr_schedulers:
  pruning_lr:
    class: ExponentialLR
    gamma: 0.9

policies:
  - pruner:
      instance_name : 'my_pruner'
    starting_epoch: 0
    ending_epoch: 38
    frequency: 2

  - lr_scheduler:
      instance_name: pruning_lr
    starting_epoch: 24
    ending_epoch: 200
    frequency: 1
</code></pre>
<p>There is only one version of the YAML syntax, and the version number is not verified at the moment. However, to be future-proof it is probably better to let the YAML parser know that you are using version-1 syntax, in case there is ever a version 2.</p>
<pre><code>version: 1
</code></pre>
<p>In the <code>pruners</code> section, we define the instances of pruners we want the scheduler to instantiate and use.<br />
We define a single pruner instance, named <code>my_pruner</code>, of algorithm <code>SensitivityPruner</code>. We will refer to this instance in the <code>Policies</code> section.<br />
Then we list the sensitivity multipliers, \(s\), of each of the weight tensors.<br />
You may list as many Pruners as you want in this section, as long as each has a unique name. You can use several types of pruners in one schedule.</p>
<pre><code>pruners:
  my_pruner:
    class: 'SensitivityPruner'
    sensitivities:
      'features.module.0.weight': 0.25
      'features.module.3.weight': 0.35
      'features.module.6.weight': 0.40
      'features.module.8.weight': 0.45
      'features.module.10.weight': 0.55
      'classifier.1.weight': 0.875
      'classifier.4.weight': 0.875
      'classifier.6.weight': 0.625
</code></pre>
<p>Next, we want to specify the learning-rate decay scheduling in the <code>lr_schedulers</code> section. We assign a name to this instance: <code>pruning_lr</code>. As in the <code>pruners</code> section, you may use any name, as long as all LR-schedulers have a unique name. At the moment, only one instance of LR-scheduler is allowed. The LR-scheduler must be a subclass of PyTorch's <a href="http://pytorch.org/docs/master/_modules/torch/optim/lr_scheduler.html">_LRScheduler</a>. You can use any of the schedulers defined in <code>torch.optim.lr_scheduler</code> (see <a href="https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate">here</a>), and we have also implemented some additional schedulers in Distiller (see <a href="https://github.com/NervanaSystems/distiller/blob/master/distiller/learning_rate.py">here</a>). The keyword arguments (kwargs) are passed directly to the LR-scheduler's constructor, so that as new LR-schedulers are added to <code>torch.optim.lr_scheduler</code>, they can be used without changing the application code.</p>
<pre><code>lr_schedulers:
  pruning_lr:
    class: ExponentialLR
    gamma: 0.9
</code></pre>
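<p>Because the kwargs are forwarded verbatim, the <code>pruning_lr</code> entry above corresponds directly to constructing the standard PyTorch scheduler in code. A small sketch (the one-parameter optimizer is just a placeholder):</p>
<pre><code class="python">import torch

optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)  # placeholder optimizer
# 'class: ExponentialLR' plus 'gamma: 0.9' becomes:
pruning_lr = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
</code></pre>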
<p>Finally, we define the <code>policies</code> section, which defines the actual scheduling. A <code>Policy</code> manages an instance of a <code>Pruner</code>, <code>Regularizer</code>, <code>Quantizer</code>, or <code>LRScheduler</code> by naming the instance. In the example below, a <code>PruningPolicy</code> uses the pruner instance named <code>my_pruner</code>: it activates it at a frequency of 2 epochs (i.e. every other epoch), starting at epoch 0, and ending at epoch 38. </p>
<pre><code>policies:
  - pruner:
      instance_name : 'my_pruner'
    starting_epoch: 0
    ending_epoch: 38
    frequency: 2

  - lr_scheduler:
      instance_name: pruning_lr
    starting_epoch: 24
    ending_epoch: 200
    frequency: 1
</code></pre>
<p>This is <em>iterative pruning</em>:</p>
<ol>
<li>
<p>Train Connectivity</p>
</li>
<li>
<p>Prune Connections</p>
</li>
<li>
<p>Retrain Weights</p>
</li>
<li>
<p>Goto 2</p>
</li>
</ol>
<p>It is described in <a href="https://arxiv.org/abs/1506.02626">Learning both Weights and Connections for Efficient Neural Networks</a>:</p>
<blockquote>
<p>"Our method prunes redundant connections using a three-step method. First, we train the network to learn which connections are important. Next, we prune the unimportant connections. Finally, we retrain the network to fine tune the weights of the remaining connections...After an initial training phase, we remove all connections whose weight is lower than a threshold. This pruning converts a dense, fully-connected layer to a sparse layer. This first phase learns the topology of the networks — learning which connections are important and removing the unimportant connections. We then retrain the sparse network so the remaining connections can compensate for the connections that have been removed. The phases of pruning and retraining may be repeated iteratively to further reduce network complexity."</p>
</blockquote>
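<p>As noted earlier, the same schedule can also be assembled from code instead of YAML. The sketch below is illustrative only: <code>add_policy</code> is used the same way in the Knowledge Distillation example later on this page, but the constructor arguments of <code>SensitivityPruner</code> and <code>PruningPolicy</code> shown here are assumptions that should be checked against <code>distiller/pruning</code> and <code>distiller/policy.py</code>.</p>
<pre><code class="python">import distiller
from distiller.policy import PruningPolicy

# Illustrative sketch only - verify constructor arguments against the Distiller source.
sensitivities = {'features.module.0.weight': 0.25,
                 'classifier.6.weight': 0.625}          # abbreviated
pruner = distiller.pruning.SensitivityPruner('my_pruner', sensitivities)
policy = PruningPolicy(pruner, pruner_args=None)

compression_scheduler = distiller.CompressionScheduler(model)   # 'model' created elsewhere
compression_scheduler.add_policy(policy, starting_epoch=0, ending_epoch=38, frequency=2)
</code></pre>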
<h2 id="regularization">Regularization</h2>
<p>You can also define and schedule regularization.</p>
<h3 id="l1-regularization">L1 regularization</h3>
<p>Format (this is an informal specification, not a valid <a href="https://en.wikipedia.org/wiki/Augmented_Backus%E2%80%93Naur_form">ABNF</a> specification):</p>
<pre><code>regularizers:
  <REGULARIZER_NAME_STR>:
    class: L1Regularizer
    reg_regims:
      <PYTORCH_PARAM_NAME_STR>: <STRENGTH_FLOAT>
      ...
      <PYTORCH_PARAM_NAME_STR>: <STRENGTH_FLOAT>
    threshold_criteria: [Mean_Abs | Max]
</code></pre>
<p>For example:</p>
<pre><code>version: 1

regularizers:
  my_L1_reg:
    class: L1Regularizer
    reg_regims:
      'module.layer3.1.conv1.weight': 0.000002
      'module.layer3.1.conv2.weight': 0.000002
      'module.layer3.1.conv3.weight': 0.000002
      'module.layer3.2.conv1.weight': 0.000002
    threshold_criteria: Mean_Abs

policies:
  - regularizer:
      instance_name: my_L1_reg
    starting_epoch: 0
    ending_epoch: 60
    frequency: 1
</code></pre>
<h3 id="group-regularization">Group regularization</h3>
<p>Format (informal specification):</p>
<pre><code>regularizers:
  <REGULARIZER_NAME_STR>:
    class: GroupLassoRegularizer
    reg_regims:
      <PYTORCH_PARAM_NAME_STR>: [<STRENGTH_FLOAT>, <'2D' | '3D' | '4D' | 'Channels' | 'Cols' | 'Rows'>]
      <PYTORCH_PARAM_NAME_STR>: [<STRENGTH_FLOAT>, <'2D' | '3D' | '4D' | 'Channels' | 'Cols' | 'Rows'>]
    threshold_criteria: [Mean_Abs | Max]
</code></pre>
<p>For example:</p>
<pre><code>version: 1

regularizers:
  my_filter_regularizer:
    class: GroupLassoRegularizer
    reg_regims:
      'module.layer3.1.conv1.weight': [0.00005, '3D']
      'module.layer3.1.conv2.weight': [0.00005, '3D']
      'module.layer3.1.conv3.weight': [0.00005, '3D']
      'module.layer3.2.conv1.weight': [0.00005, '3D']
    threshold_criteria: Mean_Abs

policies:
  - regularizer:
      instance_name: my_filter_regularizer
    starting_epoch: 0
    ending_epoch: 60
    frequency: 1
</code></pre>
<h2 id="mixing-it-up">Mixing it up</h2>
<p>You can mix pruning and regularization.</p>
<pre><code>version: 1

pruners:
  my_pruner:
    class: 'SensitivityPruner'
    sensitivities:
      'features.module.0.weight': 0.25
      'features.module.3.weight': 0.35
      'features.module.6.weight': 0.40
      'features.module.8.weight': 0.45
      'features.module.10.weight': 0.55
      'classifier.1.weight': 0.875
      'classifier.4.weight': 0.875
      'classifier.6.weight': 0.625

regularizers:
  2d_groups_regularizer:
    class: GroupLassoRegularizer
    reg_regims:
      'features.module.0.weight': [0.000012, '2D']
      'features.module.3.weight': [0.000012, '2D']
      'features.module.6.weight': [0.000012, '2D']
      'features.module.8.weight': [0.000012, '2D']
      'features.module.10.weight': [0.000012, '2D']

lr_schedulers:
  # Learning rate decay scheduler
  pruning_lr:
    class: ExponentialLR
    gamma: 0.9

policies:
  - pruner:
      instance_name : 'my_pruner'
    starting_epoch: 0
    ending_epoch: 38
    frequency: 2

  - regularizer:
      instance_name: '2d_groups_regularizer'
    starting_epoch: 0
    ending_epoch: 38
    frequency: 1

  - lr_scheduler:
      instance_name: pruning_lr
    starting_epoch: 24
    ending_epoch: 200
    frequency: 1
</code></pre>
<h2 id="quantization-aware-training">Quantization-Aware Training</h2>
<p>Similarly to pruners and regularizers, specifying a quantizer in the scheduler YAML follows the constructor arguments of the <code>Quantizer</code> class (see details <a href="design.html#quantization">here</a>). <strong>Note</strong> that only a single quantizer instance may be defined per YAML.</p>
<p>Let's see an example:</p>
<pre><code>quantizers:
  dorefa_quantizer:
    class: DorefaQuantizer
    bits_activations: 8
    bits_weights: 4
    overrides:
      conv1:
        bits_weights: null
        bits_activations: null
      relu1:
        bits_weights: null
        bits_activations: null
      final_relu:
        bits_weights: null
        bits_activations: null
      fc:
        bits_weights: null
        bits_activations: null
</code></pre>
<ul>
<li>The specific quantization method we're instantiating here is <code>DorefaQuantizer</code>.</li>
<li>Then we define the default bit-widths for activations and weights, in this case 8 and 4-bits, respectively. </li>
<li>Then, we define the <code>overrides</code> mapping.
In the example above, we choose not to quantize the first and last layer of the model. In the case of <code>DorefaQuantizer</code>, the weights are quantized as part of the convolution / FC layers, but the activations are quantized in separate layers, which replace the ReLU layers in the original model (remember - even though we replaced the ReLU modules with our own quantization modules, the names of the modules aren't changed). So, in all, we need to reference the first layer with parameters <code>conv1</code>, the first activation layer <code>relu1</code>, the last activation layer <code>final_relu</code> and the last layer with parameters <code>fc</code>.</li>
<li>Specifying <code>null</code> means "do not quantize".</li>
<li>Note that for quantizers, we reference names of modules, not names of parameters as we do for pruners and regularizers.</li>
</ul>
<h3 id="defining-overrides-for-groups-of-layers-using-regular-expressions">Defining overrides for <strong>groups of layers</strong> using regular expressions</h3>
<p>Suppose we have a sub-module in our model named <code>block1</code>, which contains multiple convolution layers which we would like to quantize to, say, 2-bits. The convolution layers are named <code>conv1</code>, <code>conv2</code> and so on. In that case we would define the following:</p>
<pre><code>overrides:
  'block1\.conv*':
    bits_weights: 2
    bits_activations: null
</code></pre>
<ul>
<li><strong>RegEx Note</strong>: Remember that the dot (<code>.</code>) is a meta-character (i.e. a reserved character) in regular expressions. So, to match the actual dot characters which separate sub-modules in PyTorch module names, we need to escape it: <code>\.</code></li>
</ul>
<p><strong>Overlapping patterns</strong> are also possible, which lets us define an override for a group of layers and also "single out" specific layers for different overrides. For example, let's take the last example and configure a different override for <code>block1.conv1</code>:</p>
<pre><code>overrides:
  'block1\.conv1':
    bits_weights: 4
    bits_activations: null
  'block1\.conv*':
    bits_weights: 2
    bits_activations: null
</code></pre>
<ul>
<li><strong>Important Note</strong>: The patterns are evaluated eagerly - first match wins. So, to properly quantize a model using "broad" patterns and more "specific" patterns as just shown, make sure the specific pattern is listed <strong>before</strong> the broad one.</li>
</ul>
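<p>To illustrate the first-match-wins behavior described above, here is a small, self-contained sketch (plain Python and the standard <code>re</code> module, not Distiller code) that resolves overrides the same way the two-pattern example does:</p>
<pre><code class="python">import re

# Override patterns in YAML order - the specific pattern is listed first.
overrides = [
    (r'block1\.conv1', {'bits_weights': 4, 'bits_activations': None}),
    (r'block1\.conv*', {'bits_weights': 2, 'bits_activations': None}),
]

def find_override(module_name):
    for pattern, override in overrides:
        if re.match(pattern, module_name):
            return override          # first match wins
    return None

print(find_override('block1.conv1'))  # gets the 4-bit override (specific pattern listed first)
print(find_override('block1.conv2'))  # falls through to the broad 2-bit pattern
</code></pre>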
<p>The <code>QuantizationPolicy</code>, which controls the quantization procedure during training, is actually quite simplistic. All it does is call the <code>prepare_model()</code> function of the <code>Quantizer</code> when it's initialized, followed by the first call to <code>quantize_params()</code>. Then, at the end of each epoch, after the float copy of the weights has been updated, it calls the <code>quantize_params()</code> function again.</p>
<pre><code>policies:
  - quantizer:
      instance_name: dorefa_quantizer
    starting_epoch: 0
    ending_epoch: 200
    frequency: 1
</code></pre>
<p><strong>Important Note</strong>: As mentioned <a href="design.html#quantization-aware-training">here</a>, since the quantizer modifies the model's parameters (assuming training with quantization in the loop is used), the call to <code>prepare_model()</code> must be performed before an optimizer is called. Therefore, currently, the starting epoch for a quantization policy must be 0, otherwise the quantization process will not work as expected. If one wishes to do a "warm-startup" (or "boot-strapping"), training for a few epochs with full precision and only then starting to quantize, the only way to do this right now is to execute a separate run to generate the boot-strapped weights, and then execute a second run which resumes from the checkpoint with the boot-strapped weights.</p>
<h2 id="post-training-quantization">Post-Training Quantization</h2>
<p>Post-training quantization differs from the other techniques described here. Since it is not executed during training, it requires neither Policies nor a Scheduler. Currently, the only method implemented for post-training quantization is <a href="algo_quantization.html#range-based-linear-quantization">range-based linear quantization</a>. Quantizing a model using this method requires adding 2 lines of code:</p>
<pre><code class="python">quantizer = distiller.quantization.PostTrainLinearQuantizer(model, <quantizer arguments>)
quantizer.prepare_model()
# Execute evaluation on model as usual
</code></pre>
<p>See the documentation for <code>PostTrainLinearQuantizer</code> in <a href="https://github.com/NervanaSystems/distiller/blob/master/distiller/quantization/range_linear.py">range_linear.py</a> for details on the available arguments.<br />
In addition to directly instantiating the quantizer with arguments, it can also be configured from a YAML file. The syntax for the YAML file is exactly the same as seen in the quantization-aware training section above. Not surprisingly, the <code>class</code> defined must be <code>PostTrainLinearQuantizer</code>, and any other components or policies defined in the YAML file are ignored. We'll see how to create the quantizer in this manner below.</p>
<p>If more configurability is needed, a helper function can be used that will add a set of command-line arguments to configure the quantizer:</p>
<pre><code class="python">parser = argparse.ArgumentParser()
distiller.quantization.add_post_train_quant_args(parser)
args = parser.parse_args()
</code></pre>
<p>These are the available command line arguments:</p>
<pre><code>Arguments controlling quantization at evaluation time ("post-training quantization"):
  --quantize-eval, --qe
                        Apply linear quantization to model before evaluation.
                        Applicable only if --evaluate is also set
  --qe-calibration PORTION_OF_TEST_SET
                        Run the model in evaluation mode on the specified
                        portion of the test dataset and collect statistics.
                        Ignores all other '--qe-*' arguments
  --qe-mode QE_MODE, --qem QE_MODE
                        Linear quantization mode. Choices: sym | asym_s |
                        asym_u
  --qe-bits-acts NUM_BITS, --qeba NUM_BITS
                        Number of bits for quantization of activations
  --qe-bits-wts NUM_BITS, --qebw NUM_BITS
                        Number of bits for quantization of weights
  --qe-bits-accum NUM_BITS
                        Number of bits for quantization of the accumulator
  --qe-clip-acts QE_CLIP_ACTS, --qeca QE_CLIP_ACTS
                        Activations clipping mode. Choices: none | avg | n_std
  --qe-clip-n-stds QE_CLIP_N_STDS
                        When qe-clip-acts is set to 'n_std', this is the
                        number of standard deviations to use
  --qe-no-clip-layers LAYER_NAME [LAYER_NAME ...], --qencl LAYER_NAME [LAYER_NAME ...]
                        List of layer names for which not to clip activations.
                        Applicable only if --qe-clip-acts is not 'none'
  --qe-per-channel, --qepc
                        Enable per-channel quantization of weights (per
                        output channel)
  --qe-scale-approx-bits NUM_BITS, --qesab NUM_BITS
                        Enable scale factor approximation using integer
                        multiply + bit shift, and use this number of bits for
                        the integer multiplier
  --qe-stats-file PATH  Path to YAML file with calibration stats.
                        If not given, dynamic quantization will be run
                        (Note that not all layer types are supported for
                        dynamic quantization)
  --qe-config-file PATH
                        Path to YAML file containing configuration for
                        PostTrainLinearQuantizer (if present, all other --qe*
                        arguments are ignored)
</code></pre>
<p>(Note that <code>--quantize-eval</code> and <code>--qe-calibration</code> are mutually exclusive.)</p>
<p>When using these command line arguments, the quantizer can be invoked as follows:</p>
<pre><code class="python">if args.quantize_eval:
    quantizer = distiller.quantization.PostTrainLinearQuantizer.from_args(model, args)
    quantizer.prepare_model()
    # Execute evaluation on model as usual
</code></pre>
<p>Note that the command-line arguments don't expose the <code>overrides</code> parameter of the quantizer, which allows fine-grained control over how each layer is quantized. To utilize this functionality, configure with a YAML file.</p>
<p>To see integration of these command line arguments in use, see the <a href="https://github.com/NervanaSystems/distiller/blob/master/examples/classifier_compression/compress_classifier.py">image classification example</a>. For example invocations of post-training quantization see <a href="https://github.com/NervanaSystems/distiller/blob/master/examples/quantization/post_train_quant">here</a>.</p>
<h3 id="collecting-statistics-for-quantization">Collecting Statistics for Quantization</h3>
<p>To generate statistics that can be used for static quantization of activations, do the following (shown here assuming the command line argument <code>--qe-calibration</code> shown above is used, which specifies the portion of the test set to use for statistics generation):</p>
<pre><code class="python"># (assumes: from distiller.data_loggers import collector_context)
if args.qe_calibration:
    distiller.utils.assign_layer_fq_names(model)
    msglogger.info("Generating quantization calibration stats based on {0} users".format(args.qe_calibration))
    collector = distiller.data_loggers.QuantCalibrationStatsCollector(model)
    with collector_context(collector):
        # Here call your model evaluation function, making sure to execute only
        # the portion of the dataset specified by the qe_calibration argument
        yaml_path = 'some/dir/quantization_stats.yaml'
        collector.save(yaml_path)
</code></pre>
<p>The generated YAML stats file can then be provided using the <code>--qe-stats-file</code> argument. An example of a generated stats file can be found <a href="https://github.com/NervanaSystems/distiller/blob/master/examples/quantization/post_train_quant/stats/resnet18_quant_stats.yaml">here</a>.</p>
<h2 id="pruning-fine-control">Pruning Fine-Control</h2>
<p>Sometimes the default pruning process doesn't satisfy our needs and we require finer control over the pruning process (e.g. over masking, gradient handling, and weight updates). Below we will explain the math and nuances of fine-control configuration.</p>
<h3 id="setting-up-the-problem">Setting up the problem</h3>
<p>We represent the weights of a DNN as the set
<script type="math/tex; mode=display"> \theta=\left\{\theta_{l} : 0 \leq l \leq L\right\} </script>
where <script type="math/tex">\theta_{l}</script> represents the parameters tensor (weights and biases) of layer <script type="math/tex"> l </script> in a network having <script type="math/tex"> L </script> layers. Usually we do not prune biases because of their small size and relative importance.
Therefore, we will consider only the network weights (also known as network connections):
<script type="math/tex; mode=display"> W=\left\{W_{l} : 0 \leq l \leq L\right\} </script>
We wish to optimize some objective (e.g. minimize the energy required to execute a network in inference mode) under some performance constraint (e.g. accuracy), and we do this by maximizing the sparsity of the network weights (sometimes under some chosen sparsity-pattern constraint). </p>
<p>We formalize pruning as a 3-step action:</p>
<ol>
<li>
<p>Generating a mask - in which we define a sparsity-inducing function per layer, <script type="math/tex"> P_l </script>, such that
<script type="math/tex; mode=display"> M_{l}=P_{l}\left(W_{l}\right) </script>
<script type="math/tex"> M_{l} </script> is a binary matrix which is used to mask <script type="math/tex"> W_{l} </script>. <script type="math/tex"> P_l</script> is implemented by subclasses of <code>distiller.pruner</code>.</p>
</li>
<li>
<p>Masking the weights using the Hadamard product:
<script type="math/tex; mode=display"> \widehat{W}_{l}=M_{l} \circ W_{l} </script> </p>
</li>
<li>
<p>Updating the weights (performed by the optimizer). By default, we compute the data-loss using the masked weights, and calculate the gradient of this loss with respect to the masked-weights. We update the weights by making a small adjustment to the <em>masked weights</em>:
<script type="math/tex; mode=display"> W_{l} \leftarrow \widehat{W}_{l}-\alpha \frac{\partial Loss(\widehat{W}_{l})}{\partial \widehat{W}_{l}} </script>
We show below how to change this default behavior. We also provide a more exact description of the weights update when using PyTorch's SGD optimizer.</p>
</li>
</ol>
<p>The pruning regimen follows a pruning-rate schedule which, analogously to learning-rate annealing, changes the pruning rate according to a configurable strategy over time. The schedule allows us to configure new masks either once at the beginning of epochs (most common), or at the beginning of mini-batches (for finer control). In the former, the masks are calculated and assigned to <script type="math/tex">\{M_{l}\}</script> once, at the beginning of epochs (the specific epochs are determined by the schedule). The pseudo-code below shows the typical training-loop with <code>CompressionScheduler</code> callbacks in bold font, and the three pruning actions described above in burgundy.</p>
<p><center><img alt="Masking" src="imgs/pruning_algorithm_pseudo_code.png" /></center><br>
<center><strong>Figure 1: Pruning algorithm pseudo-code</strong></center></p>
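<p>A toy, self-contained illustration of the three pruning actions formalized above (this is just the math on a random tensor, not Distiller's API):</p>
<pre><code class="python">import torch

W = torch.randn(8, 8, requires_grad=True)

# 1. Generate a mask M_l = P_l(W_l); here P_l is simple magnitude thresholding.
M = (W.abs() > 0.5).float()

# 2. Mask the weights with the Hadamard (element-wise) product.
W_hat = M * W

# 3. Compute the data-loss on the masked weights and adjust the masked weights.
loss = (W_hat ** 2).sum()            # stand-in for the real data loss
loss.backward()                      # W.grad is the gradient autograd computed through the masking op
with torch.no_grad():
    W.copy_(W_hat - 0.1 * W.grad)    # small adjustment to the masked weights
</code></pre>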
<p>We can perform masking by adding the masking operation to the network graph. We call this <em>in-graph masking</em>, as depicted in the bottom of Figure 2. In the forward-pass we apply element-wise multiplication of the weights <script type="math/tex"> W_{l} </script> and the mask <script type="math/tex"> M_{l} </script> to obtain the masked weights <script type="math/tex">\widehat{W}_{l}</script>, which we apply to the Convolution operation. In the backward-pass we mask <script type="math/tex">\frac{\partial L}{\partial \widehat{W}}</script> to obtain <script type="math/tex">\frac{\partial L}{\partial W}</script> with which we update <script type="math/tex"> W_{l} </script>.</p>
<p><center><img alt="Masking" src="imgs/pruning_masking.png" /></center><br>
<center><strong>Figure 2: Forward and backward weight masking</strong></center></p>
<p>In Distiller we perform <em>out-of-graph masking</em>, in which we directly set the value of <script type="math/tex">\widehat{W}_{l}</script> by applying a mask on <script type="math/tex"> W_{l} </script>. In the backward-pass we make sure that the weights are updated by the <em>proper</em> gradients. In the common pruning use-case we want the optimizer to update only the unmasked weights, but we can configure this behavior using the fine-control arguments, as explained below.</p>
<h3 id="fine-control">Fine-Control</h3>
<p>For finer control over the behavior of the pruning process, Distiller provides a set of <code>PruningPolicy</code> arguments in the <code>args</code> field, as in the sample below.</p>
<pre><code class="YAML">pruners:
  random_filter_pruner:
    class: BernoulliFilterPruner
    desired_sparsity: 0.1
    group_type: Filters
    weights: [module.conv1.weight]

policies:
  - pruner:
      instance_name: random_filter_pruner
      args:
        mini_batch_pruning_frequency: 16
        discard_masks_at_minibatch_end: True
        use_double_copies: True
        mask_on_forward_only: True
        mask_gradients: True
    starting_epoch: 15
    ending_epoch: 180
    frequency: 1
</code></pre>
<h4 id="controls">Controls</h4>
<ul>
<li>
<p><code>mini_batch_pruning_frequency</code> (default: 0): controls pruning scheduling at the mini-batch granularity. Every <code>mini_batch_pruning_frequency</code> training steps (i.e. mini-batches) we configure a new mask. In between mask updates, we mask mini-batches with the current mask.</p>
</li>
<li>
<p><code>discard_masks_at_minibatch_end</code> (default: False): discards the pruning mask at the end of the mini-batch. In the example YAML above, a new mask is computed once every 16 mini-batches, applied in one forward-pass, and then discarded. In the next 15 mini-batches the mask is <code>Null</code> so we do not mask.</p>
</li>
<li>
<p><code>mask_gradients</code> (default: False): mask the weights gradients after performing the backward-pass, and before invoking the optimizer.<br />
<br> One way to mask the gradients in PyTorch is to register to the backward callback of the weight tensors we want to mask, and alter the gradients there. We do this by setting <code>mask_gradients: True</code>, as in the sample YAML above (a standalone sketch of this hook-based masking follows the list below).
<br> This is sufficient if our weights optimization uses plain-vanilla SGD, because the update maintains the sparsity of the weights: <script type="math/tex">\widehat{W}_{l}</script> is sparse by definition, and the gradients are sparse because we mask them.
<script type="math/tex; mode=display"> W_{l} \leftarrow \widehat{W}_{l}-\alpha \frac{\partial Loss(\widehat{W}_{l})}{\partial \widehat{W}_{l}} </script>
<br> But this is not always the case. For example, <a href="https://github.com/pytorch/pytorch/blob/master/torch/optim/sgd.py">PyTorch’s SGD optimizer</a> with weight-decay (<script type="math/tex">\lambda</script>) and momentum (<script type="math/tex">\rho</script>) has the optimization logic listed below:
<br>1. <script type="math/tex"> \Delta p=\frac{\partial Loss\left(\widehat{W}_{l}^{i}\right)}{\partial \widehat{W}_{l}^{i}}+\lambda \widehat{W}_{l}^{i} </script>
<br>2.
<script type="math/tex"> v_{i}=\begin{cases} \Delta p & \text{if } i=0 \\ \rho v_{i-1}+(1-dampening)\Delta p & \text{if } i>0 \end{cases} </script><br>
<br>3. <script type="math/tex"> W_{l}^{i+1} = \widehat{W}_{l}^{i}-\alpha v_{i} </script>
<br><br>
Let’s look at the weight optimization update at some arbitrary step (i.e. mini-batch) <em>k</em>.
<br> We want to show that masking the weights and gradients (<script type="math/tex">W_{l}^{i=k}</script> and <script type="math/tex"> \frac{\partial Loss\left(\widehat{W}_{l}^{i=k}\right)}{\partial \widehat{W}_{l}^{i=k}} </script>) is not sufficient to guarantee that <script type="math/tex">W_{l}^{i=k+1}</script> is sparse. This is easy to do: if we allow for the general case where <script type="math/tex">v_i</script> is not necessarily sparse, then <script type="math/tex">W_{l}^{i+1}</script> is not necessarily sparse. Concretely, the momentum buffer <script type="math/tex">v_{i-1}</script> can hold non-zero values at masked coordinates, accumulated in steps taken before the current mask was computed, and step (3) folds these values back into <script type="math/tex">W_{l}^{i+1}</script>.
<hr> <strong><em>Masking the weights in the forward-pass, and gradients in the backward-pass, is not sufficient to maintain the sparsity of the weights!</em></strong> <hr>
This is an important insight, and it means that naïve in-graph masking is also not sufficient to guarantee sparsity of the updated weights. </p>
</li>
<li>
<p><code>use_double_copies</code> (default: False): If you want to compute the gradients using the masked weights and also to update the unmasked weights (instead of updating the masked weights, per usual), set <code>use_double_copies = True</code>. This changes step (3) to:
<br>3. <script type="math/tex"> W_{l}^{i+1} = W_{l}^{i}-\alpha \Delta p </script> <br></p>
</li>
<li>
<p><code>mask_on_forward_only</code> (default: False): when set to <code>False</code> the weights will <em>also</em> be masked after the Optimizer is done updating the weights, to remove any updates of the masked gradients.<br />
<br> If we want to guarantee the sparsity of the updated weights, we must explicitly mask the weights after step (3) above:
<br>4. <script type="math/tex"> {W}_{l}^{i+1} \leftarrow M_{l}^{i} \circ {W}_{l}^{i+1} </script>
<br> This argument defaults to <code>False</code>, but you can skip step (4) by setting <code>mask_on_forward_only = True</code>.
<br> Finally, note that <code>mask_gradients</code> and <code>not mask_on_forward_only</code> are mutually exclusive, or simply put: if you are masking in the backward-pass, you should choose to either do it via <code>mask_gradients</code> or <code>mask_on_forward_only=False</code>, but not both.</p>
</li>
</ul>
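<p>To make the <code>mask_gradients</code> mechanism concrete, here is a minimal, self-contained PyTorch sketch of masking a gradient through a backward hook, as described in the list above (illustrative only - this is not Distiller's implementation):</p>
<pre><code class="python">import torch

weight = torch.randn(64, 3, 3, 3, requires_grad=True)
mask = (weight.abs() > 0.1).float()       # binary mask M_l

# The hook receives the gradient computed in the backward-pass, and the value
# it returns replaces that gradient - here we simply zero the masked entries.
weight.register_hook(lambda grad: grad * mask)

loss = (weight * mask).pow(2).sum()       # forward-pass uses the masked weights
loss.backward()                           # weight.grad is now masked as well
</code></pre>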
<h2 id="knowledge-distillation">Knowledge Distillation</h2>
<p>Knowledge distillation (see <a href="knowledge_distillation.html">here</a>) is also implemented as a <code>Policy</code>, which should be added to the scheduler. However, with the current implementation, it cannot be defined within the YAML file like the rest of the policies described above.</p>
<p>To make the integration of this method into applications a bit easier, a helper function can be used that will add a set of command-line arguments related to knowledge distillation:</p>
<pre><code>import argparse
import distiller

parser = argparse.ArgumentParser()
distiller.knowledge_distillation.add_distillation_args(parser)
</code></pre>
<p>(The <code>add_distillation_args</code> function accepts some optional arguments, see its implementation at <code>distiller/knowledge_distillation.py</code> for details.)</p>
<p>These are the command line arguments exposed by this function:</p>
<pre><code>Knowledge Distillation Training Arguments:
  --kd-teacher ARCH     Model architecture for teacher model
  --kd-pretrained       Use pre-trained model for teacher
  --kd-resume PATH      Path to checkpoint from which to load teacher weights
  --kd-temperature TEMP, --kd-temp TEMP
                        Knowledge distillation softmax temperature
  --kd-distill-wt WEIGHT, --kd-dw WEIGHT
                        Weight for distillation loss (student vs. teacher soft
                        targets)
  --kd-student-wt WEIGHT, --kd-sw WEIGHT
                        Weight for student vs. labels loss
  --kd-teacher-wt WEIGHT, --kd-tw WEIGHT
                        Weight for teacher vs. labels loss
  --kd-start-epoch EPOCH_NUM
                        Epoch from which to enable distillation
</code></pre>
<p>Once arguments have been parsed, some initialization code is required, similar to the following:</p>
<pre><code># Assuming:
# "args" variable holds command line arguments
# "model" variable holds the model we're going to train, that is - the student model
# "compression_scheduler" variable holds a CompressionScheduler instance

args.kd_policy = None
if args.kd_teacher:
    # Create teacher model - replace this with your model creation code
    teacher = create_model(args.kd_pretrained, args.dataset, args.kd_teacher, device_ids=args.gpus)
    if args.kd_resume:
        teacher, _, _ = apputils.load_checkpoint(teacher, chkpt_file=args.kd_resume)

    # Create policy and add to scheduler
    dlw = distiller.DistillationLossWeights(args.kd_distill_wt, args.kd_student_wt, args.kd_teacher_wt)
    args.kd_policy = distiller.KnowledgeDistillationPolicy(model, teacher, args.kd_temp, dlw)
    compression_scheduler.add_policy(args.kd_policy, starting_epoch=args.kd_start_epoch,
                                     ending_epoch=args.epochs, frequency=1)
</code></pre>
<p>Finally, during the training loop, we need to perform forward propagation through the teacher model as well. The <code>KnowledgeDistillationPolicy</code> class keeps a reference to both the student and teacher models, and exposes a <code>forward</code> function that performs forward propagation on both of them.
Since this is not one of the standard policy callbacks, we need to call this function manually from our training loop, as follows:</p>
<pre><code>if args.kd_policy is None:
    # Revert to a "normal" forward-prop call if no knowledge distillation policy is present
    output = model(input_var)
else:
    output = args.kd_policy.forward(input_var)
</code></pre>
<p>To see this integration in action, take a look at the image classification sample at <code>examples/classifier_compression/compress_classifier.py</code>.</p>
</div>
</div> <footer> <div class="rst-footer-buttons" role="navigation" aria-label="footer navigation"> <a href="pruning.html" class="btn btn-neutral float-right" title="Pruning">Next <span class="icon icon-circle-arrow-right"></span></a> <a href="usage.html" class="btn btn-neutral" title="Usage"><span class="icon icon-circle-arrow-left"></span> Previous</a> </div> <hr/> <div role="contentinfo"> <!-- Copyright etc --> </div> Built with <a href="http://www.mkdocs.org">MkDocs</a> using a <a href="https://github.com/snide/sphinx_rtd_theme">theme</a> provided by <a href="https://readthedocs.org">Read the Docs</a>. </footer> </div> </div> </section> </div> <div class="rst-versions" role="note" style="cursor: pointer"> <span class="rst-current-version" data-toggle="rst-current-version"> <span><a href="usage.html" style="color: #fcfcfc;">« Previous</a></span> <span style="margin-left: 15px"><a href="pruning.html" style="color: #fcfcfc">Next »</a></span> </span> </div> <script>var base_url = '.';</script> <script src="js/theme.js" defer></script> <script src="https://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS_HTML" defer></script> <script src="search/main.js" defer></script> </body> </html>