From 1198a8890a6d858cb6b11124e3d6be4638f29408 Mon Sep 17 00:00:00 2001
From: Lev Zlotnik <46742999+levzlotnik@users.noreply.github.com>
Date: Thu, 8 Aug 2019 19:01:35 +0300
Subject: [PATCH] Added GNMT post-training quantization example (#252)

---
 .gitignore                                    |    5 +
 README.md                                     |    4 +-
 distiller/modules/__init__.py                 |   11 +-
 distiller/modules/aggregate.py                |   16 +
 distiller/modules/eltwise.py                  |   15 +-
 distiller/modules/matmul.py                   |   39 +
 distiller/quantization/range_linear.py        |  135 +-
 examples/GNMT/LICENSE                         |   22 +
 examples/GNMT/README.md                       |  121 +
 examples/GNMT/download_dataset.sh             |  172 +
 examples/GNMT/download_trained_model.sh       |    4 +
 examples/GNMT/model_stats.yaml                | 2868 +++++++++++++++++
 examples/GNMT/quantize_gnmt.ipynb             | 1069 ++++++
 examples/GNMT/requirements.txt                |    3 +
 examples/GNMT/scripts/filter_dataset.py       |   79 +
 examples/GNMT/seq2seq/__init__.py             |    0
 examples/GNMT/seq2seq/data/__init__.py        |    0
 examples/GNMT/seq2seq/data/config.py          |   21 +
 examples/GNMT/seq2seq/data/dataset.py         |  123 +
 examples/GNMT/seq2seq/data/sampler.py         |   85 +
 examples/GNMT/seq2seq/data/tokenizer.py       |   44 +
 examples/GNMT/seq2seq/inference/__init__.py   |    0
 .../GNMT/seq2seq/inference/beam_search.py     |  245 ++
 examples/GNMT/seq2seq/inference/inference.py  |   88 +
 examples/GNMT/seq2seq/models/__init__.py      |    4 +
 examples/GNMT/seq2seq/models/attention.py     |  164 +
 examples/GNMT/seq2seq/models/decoder.py       |  140 +
 examples/GNMT/seq2seq/models/encoder.py       |   62 +
 examples/GNMT/seq2seq/models/gnmt.py          |   36 +
 examples/GNMT/seq2seq/models/seq2seq_base.py  |   22 +
 examples/GNMT/seq2seq/utils.py                |  116 +
 examples/GNMT/translate.py                    |  336 ++
 examples/GNMT/verify_dataset.sh               |   73 +
 33 files changed, 6106 insertions(+), 16 deletions(-)
 create mode 100644 distiller/modules/aggregate.py
 create mode 100644 distiller/modules/matmul.py
 create mode 100644 examples/GNMT/LICENSE
 create mode 100644 examples/GNMT/README.md
 create mode 100644 examples/GNMT/download_dataset.sh
 create mode 100644 examples/GNMT/download_trained_model.sh
 create mode 100644 examples/GNMT/model_stats.yaml
 create mode 100644 examples/GNMT/quantize_gnmt.ipynb
 create mode 100644 examples/GNMT/requirements.txt
 create mode 100644 examples/GNMT/scripts/filter_dataset.py
 create mode 100644 examples/GNMT/seq2seq/__init__.py
 create mode 100644 examples/GNMT/seq2seq/data/__init__.py
 create mode 100644 examples/GNMT/seq2seq/data/config.py
 create mode 100644 examples/GNMT/seq2seq/data/dataset.py
 create mode 100644 examples/GNMT/seq2seq/data/sampler.py
 create mode 100644 examples/GNMT/seq2seq/data/tokenizer.py
 create mode 100644 examples/GNMT/seq2seq/inference/__init__.py
 create mode 100644 examples/GNMT/seq2seq/inference/beam_search.py
 create mode 100644 examples/GNMT/seq2seq/inference/inference.py
 create mode 100644 examples/GNMT/seq2seq/models/__init__.py
 create mode 100644 examples/GNMT/seq2seq/models/attention.py
 create mode 100644 examples/GNMT/seq2seq/models/decoder.py
 create mode 100644 examples/GNMT/seq2seq/models/encoder.py
 create mode 100644 examples/GNMT/seq2seq/models/gnmt.py
 create mode 100644 examples/GNMT/seq2seq/models/seq2seq_base.py
 create mode 100644 examples/GNMT/seq2seq/utils.py
 create mode 100644 examples/GNMT/translate.py
 create mode 100644 examples/GNMT/verify_dataset.sh

diff --git a/.gitignore b/.gitignore
index 913ea4c..a1670e9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,6 +10,11 @@ __pycache__/
 .cache
 pytest_collaterals/
 
+# GNMT sample
+examples/GNMT/data
+examples/GNMT/model_best.pth
+examples/GNMT/output_file*
+
 # Virtual env
 env/
 .env/
diff --git a/README.md b/README.md
index 60f97f0..3dcede3 100755
--- a/README.md
+++ b/README.md
@@ -250,8 +250,8 @@ For more details, there are some other resources you can refer to:
 + [Preparing a model for quantization](https://nervanasystems.github.io/distiller/prepare_model_quant.html)
 + [Tutorial: Using Distiller to prune a PyTorch language model](https://nervanasystems.github.io/distiller/tutorial-lang_model.html)
 + [Tutorial: Pruning Filters & Channels](https://nervanasystems.github.io/distiller/tutorial-struct_pruning.html)
-+ [Tutorial: Post-Training Quantization of a Language Model
-](https://nervanasystems.github.io/distiller/tutorial-lang_model_quant.html)
++ [Tutorial: Post-Training Quantization of a Language Model](https://nervanasystems.github.io/distiller/tutorial-lang_model_quant.html)
++ [Post-Training Quantization of GNMT (translation model)](https://github.com/NervanaSystems/distiller/tree/master/examples/GNMT)
 + [Post-training quantization command line examples](https://github.com/NervanaSystems/distiller/blob/master/examples/quantization/post_train_quant/command_line.md)
 
 ### Example invocations of the sample application
diff --git a/distiller/modules/__init__.py b/distiller/modules/__init__.py
index 5bd7d5c..03c3f57 100644
--- a/distiller/modules/__init__.py
+++ b/distiller/modules/__init__.py
@@ -14,10 +14,13 @@
 # limitations under the License.
 #
 
-from .eltwise import EltwiseAdd, EltwiseMult
+from .eltwise import *
 from .grouping import *
-from .rnn import DistillerLSTM, DistillerLSTMCell, convert_model_to_distiller_lstm
+from .matmul import *
+from .rnn import *
+from .aggregate import Norm
 
-__all__ = ['EltwiseAdd', 'EltwiseMult',
+__all__ = ['EltwiseAdd', 'EltwiseMult', 'EltwiseDiv', 'Matmul', 'BatchMatmul',
            'Concat', 'Chunk', 'Split', 'Stack',
-           'DistillerLSTMCell', 'DistillerLSTM', 'convert_model_to_distiller_lstm']
+           'DistillerLSTMCell', 'DistillerLSTM', 'convert_model_to_distiller_lstm',
+           'Norm']
diff --git a/distiller/modules/aggregate.py b/distiller/modules/aggregate.py
new file mode 100644
index 0000000..a217627
--- /dev/null
+++ b/distiller/modules/aggregate.py
@@ -0,0 +1,16 @@
+import torch
+import torch.nn as nn
+
+
+class Norm(nn.Module):
+    """
+    A module wrapper for vector/matrix norm
+    """
+    def __init__(self, p='fro', dim=None, keepdim=False):
+        super(Norm, self).__init__()
+        self.p = p
+        self.dim = dim
+        self.keepdim = keepdim
+
+    def forward(self, x: torch.Tensor):
+        return torch.norm(x, p=self.p, dim=self.dim, keepdim=self.keepdim)
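+
+
+# Usage sketch (illustrative): expressing the norm as a module, rather than a
+# bare torch.norm() call, lets Distiller tools (quantizers, activation stats
+# collectors) observe and replace the operation:
+#
+#   norm = Norm(p=2, dim=-1, keepdim=True)
+#   y = x / norm(x)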
diff --git a/distiller/modules/eltwise.py b/distiller/modules/eltwise.py
index b61a168..8435059 100644
--- a/distiller/modules/eltwise.py
+++ b/distiller/modules/eltwise.py
@@ -13,13 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
+import torch
 import torch.nn as nn
 
 
 class EltwiseAdd(nn.Module):
     def __init__(self, inplace=False):
         super(EltwiseAdd, self).__init__()
         self.inplace = inplace
 
     def forward(self, *input):
@@ -47,3 +48,15 @@ class EltwiseMult(nn.Module):
             for t in input[1:]:
                 res = res * t
         return res
+
+
+class EltwiseDiv(nn.Module):
+    def __init__(self, inplace=False):
+        super(EltwiseDiv, self).__init__()
+        self.inplace = inplace
+
+    def forward(self, x: torch.Tensor, y):
+        if self.inplace:
+            return x.div_(y)
+        return x.div(y)
+
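+# Usage sketch (illustrative): EltwiseDiv replaces the bare `/` operator so the
+# operation is visible to Distiller tools, e.g. the division by norm in the
+# GNMT attention block:
+#
+#   div = EltwiseDiv()
+#   y = div(x, norm(x))   # instead of: y = x / norm(x)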
diff --git a/distiller/modules/matmul.py b/distiller/modules/matmul.py
new file mode 100644
index 0000000..bdbeccd
--- /dev/null
+++ b/distiller/modules/matmul.py
@@ -0,0 +1,39 @@
+#
+# Copyright (c) 2019 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import torch
+import torch.nn as nn
+
+
+class Matmul(nn.Module):
+    """
+    A wrapper module for matmul operation between 2 tensors.
+    """
+    def __init__(self):
+        super(Matmul, self).__init__()
+
+    def forward(self, a: torch.Tensor, b: torch.Tensor):
+        return a.matmul(b)
+
+
+class BatchMatmul(nn.Module):
+    """
+    A wrapper module for torch.bmm operation between 2 tensors.
+    """
+    def __init__(self):
+        super(BatchMatmul, self).__init__()
+
+    def forward(self, a: torch.Tensor, b: torch.Tensor):
+        return torch.bmm(a, b)
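+
+
+# Usage sketch (illustrative): declare the product as a sub-module instead of
+# calling torch.matmul / torch.bmm directly, so that PostTrainLinearQuantizer
+# can find it and swap in RangeLinearQuantMatmulWrapper:
+#
+#   self.bmm = BatchMatmul()                        # in __init__
+#   scores = self.bmm(query, keys.transpose(1, 2))  # in forward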
diff --git a/distiller/quantization/range_linear.py b/distiller/quantization/range_linear.py
index a213c60..7bbedc4 100644
--- a/distiller/quantization/range_linear.py
+++ b/distiller/quantization/range_linear.py
@@ -18,7 +18,7 @@ import torch.nn as nn
 import argparse
 from enum import Enum
 from collections import OrderedDict
-from functools import reduce, partial
+from functools import reduce, partial, update_wrapper
 import logging
 import os
 from copy import deepcopy
@@ -414,7 +414,7 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper):
         per_channel_wts (bool): Enable quantization of weights using separate quantization parameters per
             output channel
         activation_stats (dict): See RangeLinearQuantWrapper
-        clip_n_stds (float): See RangeLinearQuantWrapper
+        clip_n_stds (int): See RangeLinearQuantWrapper
         scale_approx_mult_bits (int): See RangeLinearQuantWrapper
     """
     def __init__(self, wrapped_module, num_bits_acts, num_bits_params, num_bits_accum=32,
@@ -468,6 +468,19 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper):
                 # Dynamic ranges - save in auxiliary buffer, requantize each time based on dynamic input scale factor
                 self.register_buffer('base_b_q', base_b_q)
 
+        # A flag indicating that the simulated quantized weights are pre-shifted, for faster inference.
+        # On the first forward pass, `w_zero_point` is added to the weights, and all subsequent
+        # calls run with these shifted weights.
+        # Upon calling `self.state_dict()`, we restore the actual quantized weights.
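+        # For example (illustrative numbers): with w_zero_point = 2, the first
+        # forward pass replaces w_q with (w_q + 2) in-place, so subsequent passes
+        # compute (input_q + in_0_zero_point) @ (w_q + 2) without re-shifting the
+        # weights each time.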
+        self.is_simulated_quant_weight_shifted = False
+
+    def state_dict(self, destination=None, prefix='', keep_vars=False):
+        if self.is_simulated_quant_weight_shifted:
+            # We want to return the weights to their integer representation:
+            self.wrapped_module.weight.data -= self.w_zero_point
+            self.is_simulated_quant_weight_shifted = False
+        return super(RangeLinearQuantParamLayerWrapper, self).state_dict(destination, prefix, keep_vars)
+
     def get_inputs_quantization_params(self, input):
         if not self.preset_act_stats:
             self.in_0_scale, self.in_0_zero_point = _get_quant_params_from_tensor(
@@ -501,15 +514,16 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper):
         # to the input and weights and pass those to the wrapped model. Functionally, since at this point we're
         # dealing solely with integer values, the results are the same either way.
 
-        if self.mode != LinearQuantMode.SYMMETRIC:
-            input_q += self.in_0_zero_point
+        if self.mode != LinearQuantMode.SYMMETRIC and not self.is_simulated_quant_weight_shifted:
+            # We "store" the w_zero_point inside our wrapped module's weights to
+            # improve performance on inference.
             self.wrapped_module.weight.data += self.w_zero_point
+            self.is_simulated_quant_weight_shifted = True
 
+        input_q += self.in_0_zero_point
         accum = self.wrapped_module.forward(input_q)
         clamp(accum.data, self.accum_min_q_val, self.accum_max_q_val, inplace=True)
 
-        if self.mode != LinearQuantMode.SYMMETRIC:
-            self.wrapped_module.weight.data -= self.w_zero_point
         return accum
 
     def get_output_quantization_params(self, accumulator):
@@ -553,6 +567,96 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper):
         return tmpstr
 
 
+class RangeLinearQuantMatmulWrapper(RangeLinearQuantWrapper):
+    """
+    Wrapper for quantizing the Matmul/BatchMatmul operation between 2 input tensors.
+    output = input1 @ input2
+    where:
+        input1.shape = (input_batch, input_size)
+        input2.shape = (input_size, output_size)
+    The mathematical calculation is:
+        y_f = i1_f * i2_f
+        iN_f = (iN_q + zp_iN) / scale_iN  =>
+        y_q = scale_y * y_f + zp_y = scale_y * (i1_f * i2_f) + zp_y =
+
+                    scale_y
+        y_q = ------------------- * (i1_q + zp_i1) * (i2_q + zp_i2) + zp_y
+               scale_i1 * scale_i2
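+    For example (illustrative numbers): with scale_i1 = scale_i2 = 2, scale_y = 8
+    and all zero points equal to 0, the requantization factor is 8 / (2 * 2) = 2,
+    so y_q = 2 * (i1_q * i2_q).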
+    Args:
+        wrapped_module (distiller.modules.Matmul or distiller.modules.BatchMatmul): Module to be wrapped
+        num_bits_acts (int): Number of bits used for inputs and output quantization
+        num_bits_accum (int): Number of bits allocated for the accumulator of intermediate integer results
+        mode (LinearQuantMode): Quantization mode to use (symmetric / asymmetric-signed/unsigned)
+        clip_acts (ClipNode): See RangeLinearQuantWrapper
+        activation_stats (dict): See RangeLinearQuantWrapper
+        clip_n_stds (int): See RangeLinearQuantWrapper
+        scale_approx_mult_bits (int): See RangeLinearQuantWrapper
+    """
+    def __init__(self, wrapped_module, num_bits_acts, num_bits_accum=32,
+                 mode=LinearQuantMode.SYMMETRIC, clip_acts=ClipMode.NONE, activation_stats=None,
+                 clip_n_stds=None, scale_approx_mult_bits=None):
+        super(RangeLinearQuantMatmulWrapper, self).__init__(wrapped_module, num_bits_acts, num_bits_accum, mode,
+                                                            clip_acts, activation_stats, clip_n_stds,
+                                                            scale_approx_mult_bits)
+
+        if not isinstance(wrapped_module, (distiller.modules.Matmul, distiller.modules.BatchMatmul)):
+            raise ValueError(self.__class__.__name__ + ' can wrap only Matmul or BatchMatmul modules')
+        if self.preset_act_stats:
+            self.register_buffer('accum_scale', self.in_0_scale * self.in_1_scale)
+        else:
+            self.accum_scale = 1
+
+    def get_inputs_quantization_params(self, input0, input1):
+        if not self.preset_act_stats:
+            self.in_0_scale, self.in_0_zero_point = _get_quant_params_from_tensor(
+                input0, self.num_bits_acts, self.mode, clip=self.clip_acts,
+                num_stds=self.clip_n_stds, scale_approx_mult_bits=self.scale_approx_mult_bits)
+            self.in_1_scale, self.in_1_zero_point = _get_quant_params_from_tensor(
+                input1, self.num_bits_acts, self.mode, clip=self.clip_acts,
+                num_stds=self.clip_n_stds, scale_approx_mult_bits=self.scale_approx_mult_bits)
+        return [self.in_0_scale, self.in_1_scale], [self.in_0_zero_point, self.in_1_zero_point]
+
+    def quantized_forward(self, input0_q, input1_q):
+        accum = self.wrapped_module.forward(input0_q + self.in_0_zero_point,
+                                            input1_q + self.in_1_zero_point)
+        clamp(accum.data, self.accum_min_q_val, self.accum_max_q_val, inplace=True)
+        return accum
+
+    def get_output_quantization_params(self, accumulator):
+        if self.preset_act_stats:
+            return self.output_scale, self.output_zero_point
+
+        y_f = accumulator / self.accum_scale
+        return _get_quant_params_from_tensor(y_f, self.num_bits_acts, self.mode, clip=self.clip_acts,
+                                             num_stds=self.clip_n_stds,
+                                             scale_approx_mult_bits=self.scale_approx_mult_bits)
+
+    def get_accum_to_output_re_quantization_params(self, output_scale, output_zero_point):
+        requant_scale = output_scale / self.accum_scale
+        if self.scale_approx_mult_bits is not None:
+            requant_scale = approx_scale_as_mult_and_shift(requant_scale, self.scale_approx_mult_bits)
+        return requant_scale, output_zero_point
+
+    def extra_repr(self):
+        tmpstr = 'mode={0}, '.format(str(self.mode).split('.')[1])
+        tmpstr += 'num_bits_acts={0}, num_bits_accum={1}, '.format(self.num_bits_acts, self.num_bits_accum)
+        tmpstr += 'clip_acts={0}, '.format(_enum_to_str(self.clip_acts))
+        if self.clip_acts == ClipMode.N_STD:
+            tmpstr += 'num_stds={} '.format(self.clip_n_stds)
+        tmpstr += 'scale_approx_mult_bits={}'.format(self.scale_approx_mult_bits)
+        tmpstr += '\npreset_activation_stats={0}'.format(self.preset_act_stats)
+        if self.preset_act_stats:
+            tmpstr += '\nin_0_scale={0:.4f}, in_0_zero_point={1:.4f}'.format(self.in_0_scale.item(),
+                                                                             self.in_0_zero_point.item())
+
+            tmpstr += '\nin_1_scale={0:.4f}, in_1_zero_point={1:.4f}'.format(self.in_1_scale.item(),
+                                                                             self.in_1_zero_point.item())
+
+            tmpstr += '\nout_scale={0:.4f}, out_zero_point={1:.4f}'.format(self.output_scale.item(),
+                                                                           self.output_zero_point.item())
+        return tmpstr
+
+
 class NoStatsError(NotImplementedError):
     pass
 
@@ -877,12 +981,25 @@ class PostTrainLinearQuantizer(Quantizer):
         self.replacement_factory[nn.Conv3d] = replace_param_layer
         self.replacement_factory[nn.Linear] = replace_param_layer
 
-        self.replacement_factory[distiller.modules.Concat] = partial(
+        factory_concat = partial(
             replace_non_param_layer, RangeLinearQuantConcatWrapper)
-        self.replacement_factory[distiller.modules.EltwiseAdd] = partial(
+        factory_eltwiseadd = partial(
             replace_non_param_layer, RangeLinearQuantEltwiseAddWrapper)
-        self.replacement_factory[distiller.modules.EltwiseMult] = partial(
+        factory_eltwisemult = partial(
             replace_non_param_layer, RangeLinearQuantEltwiseMultWrapper)
+        factory_matmul = partial(
+            replace_non_param_layer, RangeLinearQuantMatmulWrapper)
+
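+        # functools.partial objects don't carry the wrapped function's metadata
+        # (__name__, __doc__, etc.); update_wrapper copies it over from
+        # replace_non_param_layer so the factories stay introspectable.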
+        update_wrapper(factory_concat, replace_non_param_layer)
+        update_wrapper(factory_eltwiseadd, replace_non_param_layer)
+        update_wrapper(factory_eltwisemult, replace_non_param_layer)
+        update_wrapper(factory_matmul, replace_non_param_layer)
+
+        self.replacement_factory[distiller.modules.Concat] = factory_concat
+        self.replacement_factory[distiller.modules.EltwiseAdd] = factory_eltwiseadd
+        self.replacement_factory[distiller.modules.EltwiseMult] = factory_eltwisemult
+        self.replacement_factory[distiller.modules.Matmul] = factory_matmul
+        self.replacement_factory[distiller.modules.BatchMatmul] = factory_matmul
         self.replacement_factory[nn.Embedding] = replace_embedding
 
         save_dir = msglogger.logdir if hasattr(msglogger, 'logdir') else '.'
diff --git a/examples/GNMT/LICENSE b/examples/GNMT/LICENSE
new file mode 100644
index 0000000..4343c76
--- /dev/null
+++ b/examples/GNMT/LICENSE
@@ -0,0 +1,22 @@
+MIT License
+
+Copyright (c) 2017 Elad Hoffer
+Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/examples/GNMT/README.md b/examples/GNMT/README.md
new file mode 100644
index 0000000..893ce1f
--- /dev/null
+++ b/examples/GNMT/README.md
@@ -0,0 +1,121 @@
+# Google's Neural Machine Translation
+In this example we quantize [MLPerf's implementation of GNMT](https://github.com/mlperf/training/tree/master/rnn_translator/pytorch)
+using **post-training quantization**, and explore different quantization configurations to achieve the highest accuracy.
+
+Note that this folder contains only the code required to run evaluation; all training code was removed. A link to a pre-trained model is provided below.
+
+For a summary of the quantization results, see [below](#results).
+
+## Running the Example
+
+This example is implemented as a Jupyter notebook.
+
+### Install Requirements
+
+    pip install -r requirements.txt
+
+(This will install [sacrebleu](https://pypi.org/project/sacrebleu/))
+
+### Get the Dataset
+
+Download the data using the following command:
+
+    bash download_dataset.sh
+
+Verify data with:
+
+    bash verify_dataset.sh
+
+### Download the Pre-trained Model
+
+    wget https://zenodo.org/record/2581623/files/model_best.pth
+
+### Run the Example
+
+    jupyter notebook
+
+Then open the `quantize_gnmt.ipynb` notebook.
+
+## Summary of Quantization Results
+
+### What is Quantized
+
+The following operations / modules are fully quantized:
+
+* Linear (fully-connected)
+* Embedding
+* Element-wise addition
+* Element-wise multiplication
+* MatMul / Batch MatMul
+* Concat
+
+The following operations do not have a quantized implementation. They run in FP32, with quantize + de-quantize applied at the op boundaries (inputs and outputs):
+
+* Softmax
+* Tanh
+* Sigmoid
+* Division by norm in the attention block. That is, in pseudo code:
+  ```python
+  x = quant_dequant(x)
+  y = x / norm(x)
+  y = quant_dequant(y)
+  ```
+
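+Applying such a configuration boils down to a few lines of code (a minimal sketch; the argument values here are illustrative - see the notebook for the exact configurations behind the results below):
+
+```python
+from distiller.quantization import PostTrainLinearQuantizer, LinearQuantMode
+
+quantizer = PostTrainLinearQuantizer(
+    model,                                    # GNMT model, built from distiller.modules wrappers
+    bits_activations=8,
+    bits_parameters=8,
+    mode=LinearQuantMode.ASYMMETRIC_UNSIGNED,
+    per_channel_wts=True,
+    model_activation_stats='model_stats.yaml')  # pre-collected stats shipped with this example
+quantizer.prepare_model()
+```
+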
+### Results
+
+| Precision | Mode       | Per-Channel | Clip Activations                                               | BLEU Score |
+|-----------|------------|-------------|---------------------------------------------------------------|------------|
+| FP32      | N/A        | N/A         | N/A                                                           | 22.16      |
+| INT8      | Symmetric  | No          | No                                                            | 18.05      |
+| INT8      | Asymmetric | No          | No                                                            | 18.52      |
+| INT8      | Asymmetric | Yes         | AVG in all layers                                             | 9.63       |
+| INT8      | Asymmetric | Yes         | AVG in all layers except attention block                      | 16.94      |
+| INT8      | Asymmetric | Yes         | AVG in all layers except attention block and final classifier | 21.49      |
+
+## Dataset / Environment
+
+### Publication / Attribution
+
+We use [WMT16 English-German](http://www.statmt.org/wmt16/translation-task.html) for training.
+
+### Data preprocessing
+
+The script uses the [subword-nmt](https://github.com/rsennrich/subword-nmt) package to segment text into subword units (BPE); by default it builds a shared vocabulary of 32,000 tokens. Preprocessing removes all pairs of sentences that can't be decoded by a latin-1 encoder.
+
+## Model
+
+### Publication / Attribution
+
+The implemented model is similar to the one from the [Google's Neural Machine Translation System: Bridging the Gap between Human and Machine Translation](https://arxiv.org/abs/1609.08144) paper.
+
+The most important difference is in the attention mechanism. This repository implements `gnmt_v2` attention: the output from the first LSTM layer of the decoder goes into the attention module, and the re-weighted context is then concatenated with the inputs to all subsequent LSTM layers of the decoder at the current timestep.
+
+The same attention mechanism is also implemented in the default GNMT-like models from [tensorflow/nmt](https://github.com/tensorflow/nmt) and [NVIDIA/OpenSeq2Seq](https://github.com/NVIDIA/OpenSeq2Seq).
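+
+In pseudo code, a decoder step looks roughly like this (an illustrative sketch, not the exact implementation):
+
+```python
+out = lstm_layers[0](x)                       # first decoder LSTM layer
+context = attention(out, encoder_outputs)     # re-weighted encoder context
+for layer in lstm_layers[1:]:                 # context concatenated into every later layer
+    out = layer(torch.cat((out, context), dim=2))
+```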
+
+### Structure
+
+* general:
+  * encoder and decoder are using shared embeddings
+  * data-parallel multi-gpu training
+  * dynamic loss scaling with backoff for Tensor Cores (mixed precision) training
+  * trained with label smoothing loss (smoothing factor 0.1)
+* encoder:
+  * 4-layer LSTM, hidden size 1024, first layer is bidirectional, the rest are
+    unidirectional
+  * with residual connections starting from 3rd layer
+  * uses standard LSTM layer (accelerated by cudnn)
+* decoder:
+  * 4-layer unidirectional LSTM with hidden size 1024 and fully-connected
+    classifier
+  * with residual connections starting from 3rd layer
+  * uses standard LSTM layer (accelerated by cudnn)
+* attention:
+  * normalized Bahdanau attention
+  * model uses `gnmt_v2` attention mechanism
+  * output from the first LSTM layer of the decoder goes into attention,
+    then the re-weighted context is concatenated with the input to all
+    subsequent LSTM layers of the decoder at the current timestep
+* inference:
+  * beam search with default beam size 5
+  * with coverage penalty and length normalization
+  * BLEU computed by [sacrebleu](https://pypi.org/project/sacrebleu/)
diff --git a/examples/GNMT/download_dataset.sh b/examples/GNMT/download_dataset.sh
new file mode 100644
index 0000000..8dbe528
--- /dev/null
+++ b/examples/GNMT/download_dataset.sh
@@ -0,0 +1,172 @@
+#! /usr/bin/env bash
+
+# Copyright 2017 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -e
+
+
+OUTPUT_DIR=${1:-"data"}
+echo "Writing to ${OUTPUT_DIR}. To change this, set the OUTPUT_DIR environment variable."
+
+OUTPUT_DIR_DATA="${OUTPUT_DIR}/data"
+
+mkdir -p $OUTPUT_DIR_DATA
+
+echo "Downloading Europarl v7. This may take a while..."
+wget -nc -nv -O ${OUTPUT_DIR_DATA}/europarl-v7-de-en.tgz \
+  http://www.statmt.org/europarl/v7/de-en.tgz
+
+echo "Downloading Common Crawl corpus. This may take a while..."
+wget -nc -nv -O ${OUTPUT_DIR_DATA}/common-crawl.tgz \
+  http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz
+
+echo "Downloading News Commentary v11. This may take a while..."
+wget -nc -nv -O ${OUTPUT_DIR_DATA}/nc-v11.tgz \
+  http://data.statmt.org/wmt16/translation-task/training-parallel-nc-v11.tgz
+
+echo "Downloading dev/test sets"
+wget -nc -nv -O  ${OUTPUT_DIR_DATA}/dev.tgz \
+  http://data.statmt.org/wmt16/translation-task/dev.tgz
+wget -nc -nv -O  ${OUTPUT_DIR_DATA}/test.tgz \
+  http://data.statmt.org/wmt16/translation-task/test.tgz
+
+# Extract everything
+echo "Extracting all files..."
+mkdir -p "${OUTPUT_DIR_DATA}/europarl-v7-de-en"
+tar -xvzf "${OUTPUT_DIR_DATA}/europarl-v7-de-en.tgz" -C "${OUTPUT_DIR_DATA}/europarl-v7-de-en"
+mkdir -p "${OUTPUT_DIR_DATA}/common-crawl"
+tar -xvzf "${OUTPUT_DIR_DATA}/common-crawl.tgz" -C "${OUTPUT_DIR_DATA}/common-crawl"
+mkdir -p "${OUTPUT_DIR_DATA}/nc-v11"
+tar -xvzf "${OUTPUT_DIR_DATA}/nc-v11.tgz" -C "${OUTPUT_DIR_DATA}/nc-v11"
+mkdir -p "${OUTPUT_DIR_DATA}/dev"
+tar -xvzf "${OUTPUT_DIR_DATA}/dev.tgz" -C "${OUTPUT_DIR_DATA}/dev"
+mkdir -p "${OUTPUT_DIR_DATA}/test"
+tar -xvzf "${OUTPUT_DIR_DATA}/test.tgz" -C "${OUTPUT_DIR_DATA}/test"
+
+# Concatenate Training data
+cat "${OUTPUT_DIR_DATA}/europarl-v7-de-en/europarl-v7.de-en.en" \
+  "${OUTPUT_DIR_DATA}/common-crawl/commoncrawl.de-en.en" \
+  "${OUTPUT_DIR_DATA}/nc-v11/training-parallel-nc-v11/news-commentary-v11.de-en.en" \
+  > "${OUTPUT_DIR}/train.en"
+wc -l "${OUTPUT_DIR}/train.en"
+
+cat "${OUTPUT_DIR_DATA}/europarl-v7-de-en/europarl-v7.de-en.de" \
+  "${OUTPUT_DIR_DATA}/common-crawl/commoncrawl.de-en.de" \
+  "${OUTPUT_DIR_DATA}/nc-v11/training-parallel-nc-v11/news-commentary-v11.de-en.de" \
+  > "${OUTPUT_DIR}/train.de"
+wc -l "${OUTPUT_DIR}/train.de"
+
+# Clone Moses
+if [ ! -d "${OUTPUT_DIR}/mosesdecoder" ]; then
+  echo "Cloning moses for data processing"
+  git clone https://github.com/moses-smt/mosesdecoder.git "${OUTPUT_DIR}/mosesdecoder"
+  cd ${OUTPUT_DIR}/mosesdecoder
+  git reset --hard 8c5eaa1a122236bbf927bde4ec610906fea599e6
+  cd -
+fi
+
+# Convert SGM files
+# Convert newstest2014 data into raw text format
+${OUTPUT_DIR}/mosesdecoder/scripts/ems/support/input-from-sgm.perl \
+  < ${OUTPUT_DIR_DATA}/dev/dev/newstest2014-deen-src.de.sgm \
+  > ${OUTPUT_DIR_DATA}/dev/dev/newstest2014.de
+${OUTPUT_DIR}/mosesdecoder/scripts/ems/support/input-from-sgm.perl \
+  < ${OUTPUT_DIR_DATA}/dev/dev/newstest2014-deen-ref.en.sgm \
+  > ${OUTPUT_DIR_DATA}/dev/dev/newstest2014.en
+
+# Convert newstest2015 data into raw text format
+${OUTPUT_DIR}/mosesdecoder/scripts/ems/support/input-from-sgm.perl \
+  < ${OUTPUT_DIR_DATA}/dev/dev/newstest2015-deen-src.de.sgm \
+  > ${OUTPUT_DIR_DATA}/dev/dev/newstest2015.de
+${OUTPUT_DIR}/mosesdecoder/scripts/ems/support/input-from-sgm.perl \
+  < ${OUTPUT_DIR_DATA}/dev/dev/newstest2015-deen-ref.en.sgm \
+  > ${OUTPUT_DIR_DATA}/dev/dev/newstest2015.en
+
+# Convert newstest2016 data into raw text format
+${OUTPUT_DIR}/mosesdecoder/scripts/ems/support/input-from-sgm.perl \
+  < ${OUTPUT_DIR_DATA}/test/test/newstest2016-deen-src.de.sgm \
+  > ${OUTPUT_DIR_DATA}/test/test/newstest2016.de
+${OUTPUT_DIR}/mosesdecoder/scripts/ems/support/input-from-sgm.perl \
+  < ${OUTPUT_DIR_DATA}/test/test/newstest2016-deen-ref.en.sgm \
+  > ${OUTPUT_DIR_DATA}/test/test/newstest2016.en
+
+# Copy dev/test data to output dir
+cp ${OUTPUT_DIR_DATA}/dev/dev/newstest20*.de ${OUTPUT_DIR}
+cp ${OUTPUT_DIR_DATA}/dev/dev/newstest20*.en ${OUTPUT_DIR}
+cp ${OUTPUT_DIR_DATA}/test/test/newstest20*.de ${OUTPUT_DIR}
+cp ${OUTPUT_DIR_DATA}/test/test/newstest20*.en ${OUTPUT_DIR}
+
+# Tokenize data
+for f in ${OUTPUT_DIR}/*.de; do
+  echo "Tokenizing $f..."
+  ${OUTPUT_DIR}/mosesdecoder/scripts/tokenizer/tokenizer.perl -q -l de -threads 8 < $f > ${f%.*}.tok.de
+done
+
+for f in ${OUTPUT_DIR}/*.en; do
+  echo "Tokenizing $f..."
+  ${OUTPUT_DIR}/mosesdecoder/scripts/tokenizer/tokenizer.perl -q -l en -threads 8 < $f > ${f%.*}.tok.en
+done
+
+# Clean all corpora
+for f in ${OUTPUT_DIR}/*.en; do
+  fbase=${f%.*}
+  echo "Cleaning ${fbase}..."
+  ${OUTPUT_DIR}/mosesdecoder/scripts/training/clean-corpus-n.perl $fbase de en "${fbase}.clean" 1 80
+done
+
+# Create dev dataset
+cat "${OUTPUT_DIR}/newstest2015.tok.clean.en" \
+   "${OUTPUT_DIR}/newstest2016.tok.clean.en" \
+   > "${OUTPUT_DIR}/newstest_dev.tok.clean.en"
+
+cat "${OUTPUT_DIR}/newstest2015.tok.clean.de" \
+   "${OUTPUT_DIR}/newstest2016.tok.clean.de" \
+   > "${OUTPUT_DIR}/newstest_dev.tok.clean.de"
+
+# Filter datasets
+python3 scripts/filter_dataset.py -f1 ${OUTPUT_DIR}/train.tok.clean.en -f2 ${OUTPUT_DIR}/train.tok.clean.de
+python3 scripts/filter_dataset.py -f1 ${OUTPUT_DIR}/newstest_dev.tok.clean.en -f2 ${OUTPUT_DIR}/newstest_dev.tok.clean.de
+
+# Generate Subword Units (BPE)
+# Clone Subword NMT
+if [ ! -d "${OUTPUT_DIR}/subword-nmt" ]; then
+  git clone https://github.com/rsennrich/subword-nmt.git "${OUTPUT_DIR}/subword-nmt"
+  cd ${OUTPUT_DIR}/subword-nmt
+  git reset --hard 48ba99e657591c329e0003f0c6e32e493fa959ef
+  cd -
+fi
+
+# Learn Shared BPE
+for merge_ops in 32000; do
+  echo "Learning BPE with merge_ops=${merge_ops}. This may take a while..."
+  cat "${OUTPUT_DIR}/train.tok.clean.de" "${OUTPUT_DIR}/train.tok.clean.en" | \
+    ${OUTPUT_DIR}/subword-nmt/learn_bpe.py -s $merge_ops > "${OUTPUT_DIR}/bpe.${merge_ops}"
+
+  echo "Apply BPE with merge_ops=${merge_ops} to tokenized files..."
+  for lang in en de; do
+    for f in ${OUTPUT_DIR}/*.tok.${lang} ${OUTPUT_DIR}/*.tok.clean.${lang}; do
+      outfile="${f%.*}.bpe.${merge_ops}.${lang}"
+      ${OUTPUT_DIR}/subword-nmt/apply_bpe.py -c "${OUTPUT_DIR}/bpe.${merge_ops}" < $f > "${outfile}"
+      echo ${outfile}
+    done
+  done
+
+  # Create vocabulary file for BPE
+  cat "${OUTPUT_DIR}/train.tok.clean.bpe.${merge_ops}.en" "${OUTPUT_DIR}/train.tok.clean.bpe.${merge_ops}.de" | \
+    ${OUTPUT_DIR}/subword-nmt/get_vocab.py | cut -f1 -d ' ' > "${OUTPUT_DIR}/vocab.bpe.${merge_ops}"
+
+done
+
+echo "All done."
diff --git a/examples/GNMT/download_trained_model.sh b/examples/GNMT/download_trained_model.sh
new file mode 100644
index 0000000..77a4a23
--- /dev/null
+++ b/examples/GNMT/download_trained_model.sh
@@ -0,0 +1,4 @@
+#!/bin/bash
+
+wget https://zenodo.org/record/2581623/files/model_best.pth
+
diff --git a/examples/GNMT/model_stats.yaml b/examples/GNMT/model_stats.yaml
new file mode 100644
index 0000000..32de9d7
--- /dev/null
+++ b/examples/GNMT/model_stats.yaml
@@ -0,0 +1,2868 @@
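+# Pre-collected activation statistics for the pre-trained GNMT model, in the
+# format produced by Distiller's quantization calibration stats collector
+# (QuantCalibrationStatsCollector): per-layer min / max / avg_min / avg_max /
+# mean / std / shape for each input and for the output. Entries with
+# min = +1.8e+308 and max = -1.8e+308 are initialization sentinels for modules
+# that were never exercised during stats collection (e.g. the dropout layers).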
+encoder.rnn_layers.0.cells.0.fc_gate_x:
+  inputs:
+    0:
+      min: -5.349642276763916
+      max: 5.236245632171631
+      avg_min: -2.8281705602780907
+      avg_max: 2.8138245348501125
+      mean: -0.0016106524300209288
+      std: 0.8805053743938167
+      shape: (1, 1024)
+  output:
+    min: -33.63774490356445
+    max: 32.98709487915039
+    avg_min: -19.3743170070767
+    avg_max: 18.377735952944196
+    mean: -1.2778008344058638
+    std: 5.574481019638121
+    shape: (1, 4096)
+encoder.rnn_layers.0.cells.0.fc_gate_h:
+  inputs:
+    0:
+      min: -1.0
+      max: 0.9997552633285522
+      avg_min: -0.9242292824059608
+      avg_max: 0.8228753872990404
+      mean: -0.002536592865063468
+      std: 0.23544047089016987
+      shape: (1, 1024)
+  output:
+    min: -22.65224838256836
+    max: 17.655214309692383
+    avg_min: -11.509061529416135
+    avg_max: 6.881520398276833
+    mean: -2.373206253470364
+    std: 2.8913670592601197
+    shape: (1, 4096)
+encoder.rnn_layers.0.cells.0.eltwiseadd_gate:
+  inputs:
+    0:
+      min: -33.63774490356445
+      max: 32.98709487915039
+      avg_min: -19.3743170070767
+      avg_max: 18.377735952944196
+      mean: -1.2778008344058638
+      std: 5.574481019638121
+      shape: (1, 4096)
+    1:
+      min: -22.65224838256836
+      max: 17.655214309692383
+      avg_min: -11.509061529416135
+      avg_max: 6.881520398276833
+      mean: -2.373206253470364
+      std: 2.8913670592601197
+      shape: (1, 4096)
+  output:
+    min: -40.20295333862305
+    max: 34.10171890258789
+    avg_min: -23.746535006890156
+    avg_max: 18.852929802167523
+    mean: -3.6510070722635595
+    std: 6.424404119623572
+    shape: (1, 4096)
+encoder.rnn_layers.0.cells.0.act_f:
+  inputs:
+    0:
+      min: -40.20295333862305
+      max: 30.254833221435547
+      avg_min: -23.296182616223867
+      avg_max: 11.772601394017986
+      mean: -7.432837164007311
+      std: 5.768370101504384
+      shape: (1, 1024)
+  output:
+    min: 3.4680012470380246e-18
+    max: 1.0
+    avg_min: 2.613044365988191e-07
+    avg_max: 0.9997981336035177
+    mean: 0.11436928112791814
+    std: 0.2739924793333765
+    shape: (1, 1024)
+encoder.rnn_layers.0.cells.0.act_i:
+  inputs:
+    0:
+      min: -33.06312561035156
+      max: 34.10171890258789
+      avg_min: -20.231574570140438
+      avg_max: 15.238747643933712
+      mean: -3.533190907395537
+      std: 5.892242776130438
+      shape: (1, 1024)
+  output:
+    min: 4.37388120446097e-15
+    max: 1.0
+    avg_min: 1.3271535511620103e-07
+    avg_max: 0.9999727739693318
+    mean: 0.27593418032485884
+    std: 0.3954827759906979
+    shape: (1, 1024)
+encoder.rnn_layers.0.cells.0.act_o:
+  inputs:
+    0:
+      min: -34.778221130371094
+      max: 30.507577896118164
+      avg_min: -20.07659851906139
+      avg_max: 14.524603366545767
+      mean: -3.567072719363816
+      std: 5.69751302438932
+      shape: (1, 1024)
+  output:
+    min: 7.870648145337088e-16
+    max: 1.0
+    avg_min: 2.206335442308734e-07
+    avg_max: 0.9999628556888489
+    mean: 0.2618299175561274
+    std: 0.3833761347669008
+    shape: (1, 1024)
+encoder.rnn_layers.0.cells.0.act_g:
+  inputs:
+    0:
+      min: -33.982521057128906
+      max: 33.54863739013672
+      avg_min: -19.17722672297466
+      avg_max: 18.48774343457505
+      mean: -0.0709275224609481
+      std: 6.116932441002841
+      shape: (1, 1024)
+  output:
+    min: -1.0
+    max: 1.0
+    avg_min: -1.0
+    avg_max: 0.9999999999947166
+    mean: -0.007298258008836419
+    std: 0.9385271485476292
+    shape: (1, 1024)
+encoder.rnn_layers.0.cells.0.eltwisemult_cell_forget:
+  inputs:
+    0:
+      min: 3.4680012470380246e-18
+      max: 1.0
+      avg_min: 2.613044365988191e-07
+      avg_max: 0.9997981336035177
+      mean: 0.11436928112791814
+      std: 0.2739924793333765
+      shape: (1, 1024)
+    1:
+      min: -38.982444763183594
+      max: 5.8226799964904785
+      avg_min: -5.096891815574018
+      avg_max: 1.6211083284579975
+      mean: -0.00908752453454041
+      std: 0.5260199992435526
+      shape: (1, 1024)
+  output:
+    min: -38.97710418701172
+    max: 4.823007583618164
+    avg_min: -4.562025848362994
+    avg_max: 1.0867458244325479
+    mean: -0.0052387909689317535
+    std: 0.24397997338840743
+    shape: (1, 1024)
+encoder.rnn_layers.0.cells.0.eltwisemult_cell_input:
+  inputs:
+    0:
+      min: 4.37388120446097e-15
+      max: 1.0
+      avg_min: 1.3271535511620103e-07
+      avg_max: 0.9999727739693318
+      mean: 0.27593418032485884
+      std: 0.3954827759906979
+      shape: (1, 1024)
+    1:
+      min: -1.0
+      max: 1.0
+      avg_min: -1.0
+      avg_max: 0.9999999999947166
+      mean: -0.007298258008836419
+      std: 0.9385271485476292
+      shape: (1, 1024)
+  output:
+    min: -1.0
+    max: 1.0
+    avg_min: -0.9999414832227521
+    avg_max: 0.9998500162423687
+    mean: -0.004381999838699326
+    std: 0.4680802328708717
+    shape: (1, 1024)
+encoder.rnn_layers.0.cells.0.eltwiseadd_cell:
+  inputs:
+    0:
+      min: -38.97710418701172
+      max: 4.823007583618164
+      avg_min: -4.562025848362994
+      avg_max: 1.0867458244325479
+      mean: -0.0052387909689317535
+      std: 0.24397997338840743
+      shape: (1, 1024)
+    1:
+      min: -1.0
+      max: 1.0
+      avg_min: -0.9999414832227521
+      avg_max: 0.9998500162423687
+      mean: -0.004381999838699326
+      std: 0.4680802328708717
+      shape: (1, 1024)
+  output:
+    min: -39.97100830078125
+    max: 5.8226799964904785
+    avg_min: -5.376022642979372
+    avg_max: 1.6904500271956118
+    mean: -0.009620790795830529
+    std: 0.5374170427844072
+    shape: (1, 1024)
+encoder.rnn_layers.0.cells.0.act_h:
+  inputs:
+    0:
+      min: -39.97100830078125
+      max: 5.8226799964904785
+      avg_min: -5.376022642979372
+      avg_max: 1.6904500271956118
+      mean: -0.009620790795830529
+      std: 0.5374170427844072
+      shape: (1, 1024)
+  output:
+    min: -1.0
+    max: 0.999982476234436
+    avg_min: -0.9784979219253714
+    avg_max: 0.9111029639984748
+    mean: -0.0042562744052456435
+    std: 0.3862437789687602
+    shape: (1, 1024)
+encoder.rnn_layers.0.cells.0.eltwisemult_hidden:
+  inputs:
+    0:
+      min: 7.870648145337088e-16
+      max: 1.0
+      avg_min: 2.206335442308734e-07
+      avg_max: 0.9999628556888489
+      mean: 0.2618299175561274
+      std: 0.3833761347669008
+      shape: (1, 1024)
+    1:
+      min: -1.0
+      max: 0.999982476234436
+      avg_min: -0.9784979219253714
+      avg_max: 0.9111029639984748
+      mean: -0.0042562744052456435
+      std: 0.3862437789687602
+      shape: (1, 1024)
+  output:
+    min: -1.0
+    max: 0.9997552633285522
+    avg_min: -0.957775045493643
+    avg_max: 0.8561479863309971
+    mean: -0.0026583670083781055
+    std: 0.24113962918670837
+    shape: (1, 1024)
+encoder.rnn_layers.0.cells_reverse.0.fc_gate_x:
+  inputs:
+    0:
+      min: -5.349642276763916
+      max: 5.236245632171631
+      avg_min: -2.828170560278106
+      avg_max: 2.8138245348501405
+      mean: -0.0016106524300209281
+      std: 0.8805050697722259
+      shape: (1, 1024)
+  output:
+    min: -33.94660949707031
+    max: 31.446853637695312
+    avg_min: -19.276340029144965
+    avg_max: 18.700496405858367
+    mean: -0.4969901222221274
+    std: 5.666540069661909
+    shape: (1, 4096)
+encoder.rnn_layers.0.cells_reverse.0.fc_gate_h:
+  inputs:
+    0:
+      min: -1.0
+      max: 1.0
+      avg_min: -0.9213716166421844
+      avg_max: 0.9217871793855027
+      mean: -0.0008159191441113994
+      std: 0.2933630896077876
+      shape: (1, 1024)
+  output:
+    min: -26.57025146484375
+    max: 27.723560333251953
+    avg_min: -13.664213478780777
+    avg_max: 13.256404327800057
+    mean: -1.6347749916045144
+    std: 3.3098913033147026
+    shape: (1, 4096)
+encoder.rnn_layers.0.cells_reverse.0.eltwiseadd_gate:
+  inputs:
+    0:
+      min: -33.94660949707031
+      max: 31.446853637695312
+      avg_min: -19.276340029144965
+      avg_max: 18.700496405858367
+      mean: -0.4969901222221274
+      std: 5.666540069661909
+      shape: (1, 4096)
+    1:
+      min: -26.57025146484375
+      max: 27.723560333251953
+      avg_min: -13.664213478780777
+      avg_max: 13.256404327800057
+      mean: -1.6347749916045144
+      std: 3.3098913033147026
+      shape: (1, 4096)
+  output:
+    min: -38.69415283203125
+    max: 36.723724365234375
+    avg_min: -23.688260936494675
+    avg_max: 21.477238427920927
+    mean: -2.131765110762214
+    std: 6.632126340207978
+    shape: (1, 4096)
+encoder.rnn_layers.0.cells_reverse.0.act_f:
+  inputs:
+    0:
+      min: -38.21221160888672
+      max: 35.994014739990234
+      avg_min: -21.775707706900725
+      avg_max: 17.50440279642364
+      mean: -3.8737916625820072
+      std: 6.84193347529921
+      shape: (1, 1024)
+  output:
+    min: 2.5389102872982082e-17
+    max: 1.0
+    avg_min: 1.0996477415504891e-07
+    avg_max: 0.999999346159746
+    mean: 0.2890327369231582
+    std: 0.4023833394248511
+    shape: (1, 1024)
+encoder.rnn_layers.0.cells_reverse.0.act_i:
+  inputs:
+    0:
+      min: -38.2786750793457
+      max: 32.760406494140625
+      avg_min: -21.206531206573878
+      avg_max: 16.691689617168617
+      mean: -3.073931181881033
+      std: 6.2677184148953256
+      shape: (1, 1024)
+  output:
+    min: 2.3756509732875972e-17
+    max: 1.0
+    avg_min: 1.7373452237735842e-08
+    avg_max: 0.9999988585396135
+    mean: 0.3108025884336769
+    std: 0.4110379687299371
+    shape: (1, 1024)
+encoder.rnn_layers.0.cells_reverse.0.act_o:
+  inputs:
+    0:
+      min: -38.69415283203125
+      max: 33.763919830322266
+      avg_min: -19.567437244926133
+      avg_max: 17.492547388038638
+      mean: -1.5878526072804218
+      std: 6.0648371330715545
+      shape: (1, 1024)
+  output:
+    min: 1.5679886799044933e-17
+    max: 1.0
+    avg_min: 7.394662946045065e-08
+    avg_max: 0.9999994433218858
+    mean: 0.39423968600374154
+    std: 0.4279581642861166
+    shape: (1, 1024)
+encoder.rnn_layers.0.cells_reverse.0.act_g:
+  inputs:
+    0:
+      min: -37.34138870239258
+      max: 36.723724365234375
+      avg_min: -20.704529739459048
+      avg_max: 20.84152177737822
+      mean: 0.008514990976159644
+      std: 6.647000615879316
+      shape: (1, 1024)
+  output:
+    min: -1.0
+    max: 1.0
+    avg_min: -1.0
+    avg_max: 1.0
+    mean: 0.0014180161330502351
+    std: 0.9442339332189296
+    shape: (1, 1024)
+encoder.rnn_layers.0.cells_reverse.0.eltwisemult_cell_forget:
+  inputs:
+    0:
+      min: 2.5389102872982082e-17
+      max: 1.0
+      avg_min: 1.0996477415504891e-07
+      avg_max: 0.999999346159746
+      mean: 0.2890327369231582
+      std: 0.4023833394248511
+      shape: (1, 1024)
+    1:
+      min: -29.373083114624023
+      max: 35.754878997802734
+      avg_min: -3.36926156270743
+      avg_max: 3.9063054220693965
+      mean: -0.0008279334693150963
+      std: 0.6366862423327254
+      shape: (1, 1024)
+  output:
+    min: -29.373075485229492
+    max: 34.76329040527344
+    avg_min: -3.194934033077149
+    avg_max: 3.729673293131146
+    mean: 0.00015181826198886874
+    std: 0.3949363630374722
+    shape: (1, 1024)
+encoder.rnn_layers.0.cells_reverse.0.eltwisemult_cell_input:
+  inputs:
+    0:
+      min: 2.3756509732875972e-17
+      max: 1.0
+      avg_min: 1.7373452237735842e-08
+      avg_max: 0.9999988585396135
+      mean: 0.3108025884336769
+      std: 0.4110379687299371
+      shape: (1, 1024)
+    1:
+      min: -1.0
+      max: 1.0
+      avg_min: -1.0
+      avg_max: 1.0
+      mean: 0.0014180161330502351
+      std: 0.9442339332189296
+      shape: (1, 1024)
+  output:
+    min: -1.0
+    max: 1.0
+    avg_min: -0.999991173282417
+    avg_max: 0.9999928301513564
+    mean: -0.0005279613764160218
+    std: 0.5003141965724145
+    shape: (1, 1024)
+encoder.rnn_layers.0.cells_reverse.0.eltwiseadd_cell:
+  inputs:
+    0:
+      min: -29.373075485229492
+      max: 34.76329040527344
+      avg_min: -3.194934033077149
+      avg_max: 3.729673293131146
+      mean: 0.00015181826198886874
+      std: 0.3949363630374722
+      shape: (1, 1024)
+    1:
+      min: -1.0
+      max: 1.0
+      avg_min: -0.999991173282417
+      avg_max: 0.9999928301513564
+      mean: -0.0005279613764160218
+      std: 0.5003141965724145
+      shape: (1, 1024)
+  output:
+    min: -29.373083114624023
+    max: 35.754878997802734
+    avg_min: -3.50784800743793
+    avg_max: 4.100828251393332
+    mean: -0.00037614313566861415
+    std: 0.6483412291978414
+    shape: (1, 1024)
+encoder.rnn_layers.0.cells_reverse.0.act_h:
+  inputs:
+    0:
+      min: -29.373083114624023
+      max: 35.754878997802734
+      avg_min: -3.50784800743793
+      avg_max: 4.100828251393332
+      mean: -0.00037614313566861415
+      std: 0.6483412291978414
+      shape: (1, 1024)
+  output:
+    min: -1.0
+    max: 1.0
+    avg_min: -0.9818865314740792
+    avg_max: 0.9818776522137024
+    mean: -0.0013970042477792044
+    std: 0.45158473029410123
+    shape: (1, 1024)
+encoder.rnn_layers.0.cells_reverse.0.eltwisemult_hidden:
+  inputs:
+    0:
+      min: 1.5679886799044933e-17
+      max: 1.0
+      avg_min: 7.394662946045065e-08
+      avg_max: 0.9999994433218858
+      mean: 0.39423968600374154
+      std: 0.4279581642861166
+      shape: (1, 1024)
+    1:
+      min: -1.0
+      max: 1.0
+      avg_min: -0.9818865314740792
+      avg_max: 0.9818776522137024
+      mean: -0.0013970042477792044
+      std: 0.45158473029410123
+      shape: (1, 1024)
+  output:
+    min: -1.0
+    max: 1.0
+    avg_min: -0.9556271536847694
+    avg_max: 0.9564541215110065
+    mean: -0.0005957838928230832
+    std: 0.2985633225360773
+    shape: (1, 1024)
+encoder.rnn_layers.0.dropout:
+  output:
+    min: 1.7976931348623157e+308
+    max: -1.7976931348623157e+308
+    avg_min: 0
+    avg_max: 0
+    mean: 0
+    std: 0
+    shape: ''
+encoder.rnn_layers.1.cells.0.fc_gate_x:
+  inputs:
+    0:
+      min: -1.0
+      max: 1.0
+      avg_min: -0.3797274232525334
+      avg_max: 0.3723408187603086
+      mean: -0.0006261473494000447
+      std: 0.1684595854192406
+      shape: (128, 2048)
+  output:
+    min: -28.018692016601562
+    max: 26.689453125
+    avg_min: -5.231459095807397
+    avg_max: 4.8616450890998655
+    mean: -0.5998025477032523
+    std: 2.1475245530098204
+    shape: (128, 4096)
+encoder.rnn_layers.1.cells.0.fc_gate_h:
+  inputs:
+    0:
+      min: -0.9999995231628418
+      max: 0.9999927878379822
+      avg_min: -0.9296687000683234
+      avg_max: 0.9067187918312595
+      mean: -0.00016366511202441184
+      std: 0.17426053388600266
+      shape: (128, 1024)
+  output:
+    min: -22.963897705078125
+    max: 23.043453216552734
+    avg_min: -7.452188115125264
+    avg_max: 8.383283630046785
+    mean: -0.9507504072732051
+    std: 1.9114254594378783
+    shape: (128, 4096)
+encoder.rnn_layers.1.cells.0.eltwiseadd_gate:
+  inputs:
+    0:
+      min: -28.018692016601562
+      max: 26.689453125
+      avg_min: -5.231459095807397
+      avg_max: 4.8616450890998655
+      mean: -0.5998025477032523
+      std: 2.1475245530098204
+      shape: (128, 4096)
+    1:
+      min: -22.963897705078125
+      max: 23.043453216552734
+      avg_min: -7.452188115125264
+      avg_max: 8.383283630046785
+      mean: -0.9507504072732051
+      std: 1.9114254594378783
+      shape: (128, 4096)
+  output:
+    min: -31.574262619018555
+    max: 32.37839126586914
+    avg_min: -10.47135662152461
+    avg_max: 10.518839543297785
+    mean: -1.5505529552098096
+    std: 3.1590192373764134
+    shape: (128, 4096)
+encoder.rnn_layers.1.cells.0.act_f:
+  inputs:
+    0:
+      min: -31.574262619018555
+      max: 25.064706802368164
+      avg_min: -10.176423978218846
+      avg_max: 8.090042744440224
+      mean: -3.3116955970157718
+      std: 3.4230677646590277
+      shape: (128, 1024)
+  output:
+    min: 1.9385273957893065e-14
+    max: 1.0
+    avg_min: 0.0009497451324273016
+    avg_max: 0.9984393432169701
+    mean: 0.16464387024575694
+    std: 0.2689671071958494
+    shape: (128, 1024)
+encoder.rnn_layers.1.cells.0.act_i:
+  inputs:
+    0:
+      min: -29.14598274230957
+      max: 28.608158111572266
+      avg_min: -8.99660675317649
+      avg_max: 6.228059029819189
+      mean: -1.7129896603174661
+      std: 2.7638664904974917
+      shape: (128, 1024)
+  output:
+    min: 2.1981662367068222e-13
+    max: 1.0
+    avg_min: 0.0017008052480266586
+    avg_max: 0.9786340831643511
+    mean: 0.2868795179297321
+    std: 0.27602027714229127
+    shape: (128, 1024)
+encoder.rnn_layers.1.cells.0.act_o:
+  inputs:
+    0:
+      min: -25.69948959350586
+      max: 27.677391052246094
+      avg_min: -8.206474038311828
+      avg_max: 9.042313837365008
+      mean: -1.2052700564898609
+      std: 2.6057113237143055
+      shape: (128, 1024)
+  output:
+    min: 6.9000694914722605e-12
+    max: 1.0
+    avg_min: 0.0033538953499784242
+    avg_max: 0.9974105607476547
+    mean: 0.33604429936415675
+    std: 0.2886479601419748
+    shape: (128, 1024)
+encoder.rnn_layers.1.cells.0.act_g:
+  inputs:
+    0:
+      min: -28.827739715576172
+      max: 32.37839126586914
+      avg_min: -8.882986145798268
+      avg_max: 8.65820987619275
+      mean: 0.027743494574744162
+      std: 2.8342722477735816
+      shape: (128, 1024)
+  output:
+    min: -1.0
+    max: 1.0
+    avg_min: -0.9999114334316583
+    avg_max: 0.9996711071705672
+    mean: 0.015988719393334825
+    std: 0.7484585498246807
+    shape: (128, 1024)
+encoder.rnn_layers.1.cells.0.eltwisemult_cell_forget:
+  inputs:
+    0:
+      min: 1.9385273957893065e-14
+      max: 1.0
+      avg_min: 0.0009497451324273016
+      avg_max: 0.9984393432169701
+      mean: 0.16464387024575694
+      std: 0.2689671071958494
+      shape: (128, 1024)
+    1:
+      min: -86.69337463378906
+      max: 80.73381805419922
+      avg_min: -12.103038570158164
+      avg_max: 12.752722294005268
+      mean: -0.023152882641702514
+      std: 1.2572998992641964
+      shape: (128, 1024)
+  output:
+    min: -86.67153930664062
+    max: 80.55928802490234
+    avg_min: -11.985433492784532
+    avg_max: 12.574206966781782
+    mean: -0.030583657865831754
+    std: 1.2066110910522718
+    shape: (128, 1024)
+encoder.rnn_layers.1.cells.0.eltwisemult_cell_input:
+  inputs:
+    0:
+      min: 2.1981662367068222e-13
+      max: 1.0
+      avg_min: 0.0017008052480266586
+      avg_max: 0.9786340831643511
+      mean: 0.2868795179297321
+      std: 0.27602027714229127
+      shape: (128, 1024)
+    1:
+      min: -1.0
+      max: 1.0
+      avg_min: -0.9999114334316583
+      avg_max: 0.9996711071705672
+      mean: 0.015988719393334825
+      std: 0.7484585498246807
+      shape: (128, 1024)
+  output:
+    min: -1.0
+    max: 1.0
+    avg_min: -0.9046535779632451
+    avg_max: 0.9683758158411752
+    mean: 0.006315402761504094
+    std: 0.3088953017599018
+    shape: (128, 1024)
+encoder.rnn_layers.1.cells.0.eltwiseadd_cell:
+  inputs:
+    0:
+      min: -86.67153930664062
+      max: 80.55928802490234
+      avg_min: -11.985433492784532
+      avg_max: 12.574206966781782
+      mean: -0.030583657865831754
+      std: 1.2066110910522718
+      shape: (128, 1024)
+    1:
+      min: -1.0
+      max: 1.0
+      avg_min: -0.9046535779632451
+      avg_max: 0.9683758158411752
+      mean: 0.006315402761504094
+      std: 0.3088953017599018
+      shape: (128, 1024)
+  output:
+    min: -87.62094116210938
+    max: 81.55608367919922
+    avg_min: -12.543978190728746
+    avg_max: 13.23383954467387
+    mean: -0.024268255073244073
+    std: 1.2882582959130635
+    shape: (128, 1024)
+encoder.rnn_layers.1.cells.0.act_h:
+  inputs:
+    0:
+      min: -87.62094116210938
+      max: 81.55608367919922
+      avg_min: -12.543978190728746
+      avg_max: 13.23383954467387
+      mean: -0.024268255073244073
+      std: 1.2882582959130635
+      shape: (128, 1024)
+  output:
+    min: -1.0
+    max: 1.0
+    avg_min: -0.9870239480229829
+    avg_max: 0.982027827093264
+    mean: 0.005185764388532731
+    std: 0.3119488266195677
+    shape: (128, 1024)
+encoder.rnn_layers.1.cells.0.eltwisemult_hidden:
+  inputs:
+    0:
+      min: 6.9000694914722605e-12
+      max: 1.0
+      avg_min: 0.0033538953499784242
+      avg_max: 0.9974105607476547
+      mean: 0.33604429936415675
+      std: 0.2886479601419748
+      shape: (128, 1024)
+    1:
+      min: -1.0
+      max: 1.0
+      avg_min: -0.9870239480229829
+      avg_max: 0.982027827093264
+      mean: 0.005185764388532731
+      std: 0.3119488266195677
+      shape: (128, 1024)
+  output:
+    min: -0.9999995231628418
+    max: 0.9999927878379822
+    avg_min: -0.9430685600378369
+    avg_max: 0.9199835463058242
+    mean: -0.00017594442323635664
+    std: 0.17510466702537447
+    shape: (128, 1024)
+encoder.rnn_layers.1.dropout:
+  output:
+    min: 1.7976931348623157e+308
+    max: -1.7976931348623157e+308
+    avg_min: 0
+    avg_max: 0
+    mean: 0
+    std: 0
+    shape: ''
+encoder.rnn_layers.2.cells.0.fc_gate_x:
+  inputs:
+    0:
+      min: -0.9999995231628418
+      max: 0.9999927878379822
+      avg_min: -0.9430685600378369
+      avg_max: 0.9199835463058242
+      mean: -0.00017594442323635664
+      std: 0.17510466702537447
+      shape: (128, 1024)
+  output:
+    min: -17.669666290283203
+    max: 18.855335235595703
+    avg_min: -5.826010705240624
+    avg_max: 5.579422310161379
+    mean: -0.7621265583790381
+    std: 1.6214516118322957
+    shape: (128, 4096)
+encoder.rnn_layers.2.cells.0.fc_gate_h:
+  inputs:
+    0:
+      min: -0.9999260902404785
+      max: 0.999937891960144
+      avg_min: -0.9288329214057661
+      avg_max: 0.8692965302824718
+      mean: -0.002193973255493973
+      std: 0.13664756692926694
+      shape: (128, 1024)
+  output:
+    min: -12.899184226989746
+    max: 12.18427848815918
+    avg_min: -5.383265833059953
+    avg_max: 6.135640713992537
+    mean: -0.7798955419852971
+    std: 1.2988276201978577
+    shape: (128, 4096)
+encoder.rnn_layers.2.cells.0.eltwiseadd_gate:
+  inputs:
+    0:
+      min: -17.669666290283203
+      max: 18.855335235595703
+      avg_min: -5.826010705240624
+      avg_max: 5.579422310161379
+      mean: -0.7621265583790381
+      std: 1.6214516118322957
+      shape: (128, 4096)
+    1:
+      min: -12.899184226989746
+      max: 12.18427848815918
+      avg_min: -5.383265833059953
+      avg_max: 6.135640713992537
+      mean: -0.7798955419852971
+      std: 1.2988276201978577
+      shape: (128, 4096)
+  output:
+    min: -22.86528968811035
+    max: 22.331012725830078
+    avg_min: -9.460085806430587
+    avg_max: 10.091681053441096
+    mean: -1.5420221014644224
+    std: 2.4497739415101045
+    shape: (128, 4096)
+encoder.rnn_layers.2.cells.0.act_f:
+  inputs:
+    0:
+      min: -22.86528968811035
+      max: 20.687637329101562
+      avg_min: -8.921805253498242
+      avg_max: 9.76669416798307
+      mean: -1.7373663354206654
+      std: 2.701828662299363
+      shape: (128, 1024)
+  output:
+    min: 1.1741696503975163e-10
+    max: 1.0
+    avg_min: 0.0011454785276338222
+    avg_max: 0.9971977971010819
+    mean: 0.2779621332383798
+    std: 0.3120235028974232
+    shape: (128, 1024)
+encoder.rnn_layers.2.cells.0.act_i:
+  inputs:
+    0:
+      min: -19.433446884155273
+      max: 22.331012725830078
+      avg_min: -9.11351156621439
+      avg_max: 5.007397494993483
+      mean: -2.482964531553934
+      std: 2.096762127403773
+      shape: (128, 1024)
+  output:
+    min: 3.632128597885753e-09
+    max: 1.0
+    avg_min: 0.0005238509440227943
+    avg_max: 0.9815354767928449
+    mean: 0.17026146355457997
+    std: 0.22359512877822252
+    shape: (128, 1024)
+encoder.rnn_layers.2.cells.0.act_o:
+  inputs:
+    0:
+      min: -20.721616744995117
+      max: 21.40572166442871
+      avg_min: -7.591665619305041
+      avg_max: 6.900344703221481
+      mean: -1.9338501979173957
+      std: 2.188371648150393
+      shape: (128, 1024)
+  output:
+    min: 1.0016504292664763e-09
+    max: 1.0
+    avg_min: 0.0014563385881592266
+    avg_max: 0.996341593846912
+    mean: 0.2302334359294881
+    std: 0.2644176532615949
+    shape: (128, 1024)
+encoder.rnn_layers.2.cells.0.act_g:
+  inputs:
+    0:
+      min: -20.245834350585938
+      max: 17.422372817993164
+      avg_min: -7.094027656303417
+      avg_max: 6.198857292499591
+      mean: -0.013907334995555003
+      std: 2.0268158859398273
+      shape: (128, 1024)
+  output:
+    min: -1.0
+    max: 1.0
+    avg_min: -0.9999658937755573
+    avg_max: 0.999913128430412
+    mean: 0.006783831941798703
+    std: 0.7667732969711837
+    shape: (128, 1024)
+encoder.rnn_layers.2.cells.0.eltwisemult_cell_forget:
+  inputs:
+    0:
+      min: 1.1741696503975163e-10
+      max: 1.0
+      avg_min: 0.0011454785276338222
+      avg_max: 0.9971977971010819
+      mean: 0.2779621332383798
+      std: 0.3120235028974232
+      shape: (128, 1024)
+    1:
+      min: -70.65399169921875
+      max: 79.07190704345703
+      avg_min: -9.986505569694323
+      avg_max: 11.433578196754656
+      mean: -0.007533524133234033
+      std: 0.9619221798317081
+      shape: (128, 1024)
+  output:
+    min: -70.47572326660156
+    max: 79.0152816772461
+    avg_min: -9.786985937847655
+    avg_max: 11.29573609329911
+    mean: -0.0077121059270617645
+    std: 0.9119806954490677
+    shape: (128, 1024)
+encoder.rnn_layers.2.cells.0.eltwisemult_cell_input:
+  inputs:
+    0:
+      min: 3.632128597885753e-09
+      max: 1.0
+      avg_min: 0.0005238509440227943
+      avg_max: 0.9815354767928449
+      mean: 0.17026146355457997
+      std: 0.22359512877822252
+      shape: (128, 1024)
+    1:
+      min: -1.0
+      max: 1.0
+      avg_min: -0.9999658937755573
+      avg_max: 0.999913128430412
+      mean: 0.006783831941798703
+      std: 0.7667732969711837
+      shape: (128, 1024)
+  output:
+    min: -1.0
+    max: 1.0
+    avg_min: -0.9706906902643418
+    avg_max: 0.9338277108983972
+    mean: 0.00031330419493738533
+    std: 0.23285220882823307
+    shape: (128, 1024)
+encoder.rnn_layers.2.cells.0.eltwiseadd_cell:
+  inputs:
+    0:
+      min: -70.47572326660156
+      max: 79.0152816772461
+      avg_min: -9.786985937847655
+      avg_max: 11.29573609329911
+      mean: -0.0077121059270617645
+      std: 0.9119806954490677
+      shape: (128, 1024)
+    1:
+      min: -1.0
+      max: 1.0
+      avg_min: -0.9706906902643418
+      avg_max: 0.9338277108983972
+      mean: 0.00031330419493738533
+      std: 0.23285220882823307
+      shape: (128, 1024)
+  output:
+    min: -71.34825897216797
+    max: 79.90394592285156
+    avg_min: -10.283224373522492
+    avg_max: 11.882522354690009
+    mean: -0.007398801722887531
+    std: 0.9852414934516581
+    shape: (128, 1024)
+encoder.rnn_layers.2.cells.0.act_h:
+  inputs:
+    0:
+      min: -71.34825897216797
+      max: 79.90394592285156
+      avg_min: -10.283224373522492
+      avg_max: 11.882522354690009
+      mean: -0.007398801722887531
+      std: 0.9852414934516581
+      shape: (128, 1024)
+  output:
+    min: -1.0
+    max: 1.0
+    avg_min: -0.9912044704627142
+    avg_max: 0.9643497033340547
+    mean: -0.0026311075007701672
+    std: 0.2839022589840224
+    shape: (128, 1024)
+encoder.rnn_layers.2.cells.0.eltwisemult_hidden:
+  inputs:
+    0:
+      min: 1.0016504292664763e-09
+      max: 1.0
+      avg_min: 0.0014563385881592266
+      avg_max: 0.996341593846912
+      mean: 0.2302334359294881
+      std: 0.2644176532615949
+      shape: (128, 1024)
+    1:
+      min: -1.0
+      max: 1.0
+      avg_min: -0.9912044704627142
+      avg_max: 0.9643497033340547
+      mean: -0.0026311075007701672
+      std: 0.2839022589840224
+      shape: (128, 1024)
+  output:
+    min: -0.9999260902404785
+    max: 0.999937891960144
+    avg_min: -0.9419258169546493
+    avg_max: 0.8823925270396857
+    mean: -0.002203096514809122
+    std: 0.1377186248199642
+    shape: (128, 1024)
+encoder.rnn_layers.2.dropout:
+  output:
+    min: 1.7976931348623157e+308
+    max: -1.7976931348623157e+308
+    avg_min: 0
+    avg_max: 0
+    mean: 0
+    std: 0
+    shape: ''
+encoder.rnn_layers.3.cells.0.fc_gate_x:
+  inputs:
+    0:
+      min: -1.999041199684143
+      max: 1.996099829673767
+      avg_min: -1.4410850684714778
+      avg_max: 1.4346018574915211
+      mean: -0.0023790409310782876
+      std: 0.23082689970193895
+      shape: (128, 1024)
+  output:
+    min: -18.958541870117188
+    max: 20.005878448486328
+    avg_min: -6.6773925654306705
+    avg_max: 7.386176010906286
+    mean: -0.7425925570313029
+    std: 1.9741718173439569
+    shape: (128, 4096)
+encoder.rnn_layers.3.cells.0.fc_gate_h:
+  inputs:
+    0:
+      min: -0.9999996423721313
+      max: 0.9999986290931702
+      avg_min: -0.9598443289464481
+      avg_max: 0.9680624933584159
+      mean: -5.464283497081485e-05
+      std: 0.1605789723297388
+      shape: (128, 1024)
+  output:
+    min: -17.731414794921875
+    max: 16.26653289794922
+    avg_min: -7.986687395696665
+    avg_max: 8.483382081558762
+    mean: -1.0937484681389607
+    std: 1.8863347454511785
+    shape: (128, 4096)
+encoder.rnn_layers.3.cells.0.eltwiseadd_gate:
+  inputs:
+    0:
+      min: -18.958541870117188
+      max: 20.005878448486328
+      avg_min: -6.6773925654306705
+      avg_max: 7.386176010906286
+      mean: -0.7425925570313029
+      std: 1.9741718173439569
+      shape: (128, 4096)
+    1:
+      min: -17.731414794921875
+      max: 16.26653289794922
+      avg_min: -7.986687395696665
+      avg_max: 8.483382081558762
+      mean: -1.0937484681389607
+      std: 1.8863347454511785
+      shape: (128, 4096)
+  output:
+    min: -25.56599998474121
+    max: 24.13156509399414
+    avg_min: -12.193620794438145
+    avg_max: 13.726210550973875
+    mean: -1.8363410236118098
+    std: 3.220885522048646
+    shape: (128, 4096)
+encoder.rnn_layers.3.cells.0.act_f:
+  inputs:
+    0:
+      min: -25.56599998474121
+      max: 24.13156509399414
+      avg_min: -10.901736432407242
+      avg_max: 11.459924239973608
+      mean: -1.7886736150289273
+      std: 3.611997766132868
+      shape: (128, 1024)
+  output:
+    min: 7.885464850532209e-12
+    max: 1.0
+    avg_min: 0.001426324373246965
+    avg_max: 0.9981180747150038
+    mean: 0.3085659271068626
+    std: 0.362517692693201
+    shape: (128, 1024)
+encoder.rnn_layers.3.cells.0.act_i:
+  inputs:
+    0:
+      min: -21.904094696044922
+      max: 19.92990493774414
+      avg_min: -10.301701992016742
+      avg_max: 5.387517666390012
+      mean: -3.4113261673007766
+      std: 2.452103517547047
+      shape: (128, 1024)
+  output:
+    min: 3.070241560987341e-10
+    max: 1.0
+    avg_min: 0.0005902890219026971
+    avg_max: 0.9767627610722911
+    mean: 0.1199813902886455
+    std: 0.20597901242826788
+    shape: (128, 1024)
+encoder.rnn_layers.3.cells.0.act_o:
+  inputs:
+    0:
+      min: -21.795564651489258
+      max: 20.31422233581543
+      avg_min: -9.119528548429464
+      avg_max: 12.915865638912122
+      mean: -2.1100439747307944
+      std: 2.7725332585946427
+      shape: (128, 1024)
+  output:
+    min: 3.422208905146107e-10
+    max: 1.0
+    avg_min: 0.0007607025358741477
+    avg_max: 0.9998594416467933
+    mean: 0.23564658800694233
+    std: 0.2900290182410238
+    shape: (128, 1024)
+encoder.rnn_layers.3.cells.0.act_g:
+  inputs:
+    0:
+      min: -22.5491943359375
+      max: 21.174640655517578
+      avg_min: -11.443979141296161
+      avg_max: 10.109944212490001
+      mean: -0.035320331854928055
+      std: 2.991404617719586
+      shape: (128, 1024)
+  output:
+    min: -1.0
+    max: 1.0
+    avg_min: -0.9999641429391214
+    avg_max: 0.9999893591374617
+    mean: -0.017447813011902232
+    std: 0.8383612744760685
+    shape: (128, 1024)
+encoder.rnn_layers.3.cells.0.eltwisemult_cell_forget:
+  inputs:
+    0:
+      min: 7.885464850532209e-12
+      max: 1.0
+      avg_min: 0.001426324373246965
+      avg_max: 0.9981180747150038
+      mean: 0.3085659271068626
+      std: 0.362517692693201
+      shape: (128, 1024)
+    1:
+      min: -72.4426498413086
+      max: 66.09469604492188
+      avg_min: -13.666880435458225
+      avg_max: 12.621639567200226
+      mean: 0.001298973259466056
+      std: 1.3620516853160491
+      shape: (128, 1024)
+  output:
+    min: -72.35834503173828
+    max: 66.04717254638672
+    avg_min: -13.53175763829178
+    avg_max: 12.483070900776225
+    mean: -0.0005504995357517113
+    std: 1.3219009414904979
+    shape: (128, 1024)
+encoder.rnn_layers.3.cells.0.eltwisemult_cell_input:
+  inputs:
+    0:
+      min: 3.070241560987341e-10
+      max: 1.0
+      avg_min: 0.0005902890219026971
+      avg_max: 0.9767627610722911
+      mean: 0.1199813902886455
+      std: 0.20597901242826788
+      shape: (128, 1024)
+    1:
+      min: -1.0
+      max: 1.0
+      avg_min: -0.9999641429391214
+      avg_max: 0.9999893591374617
+      mean: -0.017447813011902232
+      std: 0.8383612744760685
+      shape: (128, 1024)
+  output:
+    min: -1.0
+    max: 1.0
+    avg_min: -0.9390226983917377
+    avg_max: 0.9557934822118774
+    mean: 0.0013386550747985504
+    std: 0.21654626606567212
+    shape: (128, 1024)
+encoder.rnn_layers.3.cells.0.eltwiseadd_cell:
+  inputs:
+    0:
+      min: -72.35834503173828
+      max: 66.04717254638672
+      avg_min: -13.53175763829178
+      avg_max: 12.483070900776225
+      mean: -0.0005504995357517113
+      std: 1.3219009414904979
+      shape: (128, 1024)
+    1:
+      min: -1.0
+      max: 1.0
+      avg_min: -0.9390226983917377
+      avg_max: 0.9557934822118774
+      mean: 0.0013386550747985504
+      std: 0.21654626606567212
+      shape: (128, 1024)
+  output:
+    min: -73.15592956542969
+    max: 66.85343170166016
+    avg_min: -14.088761083898394
+    avg_max: 12.986537455179008
+    mean: 0.0007881555251539649
+    std: 1.3934345367324847
+    shape: (128, 1024)
+encoder.rnn_layers.3.cells.0.act_h:
+  inputs:
+    0:
+      min: -73.15592956542969
+      max: 66.85343170166016
+      avg_min: -14.088761083898394
+      avg_max: 12.986537455179008
+      mean: 0.0007881555251539649
+      std: 1.3934345367324847
+      shape: (128, 1024)
+  output:
+    min: -1.0
+    max: 1.0
+    avg_min: -0.9939409935101039
+    avg_max: 0.9954337507019633
+    mean: 0.005215724385428704
+    std: 0.3152480532527837
+    shape: (128, 1024)
+encoder.rnn_layers.3.cells.0.eltwisemult_hidden:
+  inputs:
+    0:
+      min: 3.422208905146107e-10
+      max: 1.0
+      avg_min: 0.0007607025358741477
+      avg_max: 0.9998594416467933
+      mean: 0.23564658800694233
+      std: 0.2900290182410238
+      shape: (128, 1024)
+    1:
+      min: -1.0
+      max: 1.0
+      avg_min: -0.9939409935101039
+      avg_max: 0.9954337507019633
+      mean: 0.005215724385428704
+      std: 0.3152480532527837
+      shape: (128, 1024)
+  output:
+    min: -0.9999996423721313
+    max: 0.9999986290931702
+    avg_min: -0.973251596759896
+    avg_max: 0.981412514344157
+    mean: -8.919282888996098e-05
+    std: 0.16194114228528503
+    shape: (128, 1024)
+encoder.rnn_layers.3.dropout:
+  output:
+    min: 1.7976931348623157e+308
+    max: -1.7976931348623157e+308
+    avg_min: 0
+    avg_max: 0
+    mean: 0
+    std: 0
+    shape: ''
+encoder.dropout:
+  inputs:
+    0:
+      min: -1.999041199684143
+      max: 1.996099829673767
+      avg_min: -1.2932450498143837
+      avg_max: 1.2545068023933308
+      mean: -0.0010675477744067996
+      std: 0.19403692465489641
+      shape: (128, 67, 2048)
+  output:
+    min: -1.999041199684143
+    max: 1.996099829673767
+    avg_min: -1.2932450498143837
+    avg_max: 1.2545068023933308
+    mean: -0.0010675477744067996
+    std: 0.19403692465489641
+    shape: (128, 67, 2048)
+encoder.embedder:
+  inputs:
+    0:
+      min: 0
+      max: 32103
+      avg_min: 3.2019918366436153
+      avg_max: 24950.48025676814
+      mean: 2937.1528870343564
+      std: 6014.0608835026405
+      shape: (128, 67)
+  output:
+    min: -5.349642276763916
+    max: 5.236245632171631
+    avg_min: -2.757863937090343
+    avg_max: 2.8868405888984556
+    mean: 0.0012734194185268903
+    std: 0.8584338496588495
+    shape: (128, 67, 1024)
+encoder.eltwiseadd_residuals.0:
+  inputs:
+    0:
+      min: -0.9999260902404785
+      max: 0.999937891960144
+      avg_min: -0.9973272457718849
+      avg_max: 0.9888223583499591
+      mean: -0.0022379238847255083
+      std: 0.13743977474031957
+      shape: (128, 67, 1024)
+    1:
+      min: -0.9999995231628418
+      max: 0.9999927878379822
+      avg_min: -0.9997136642535528
+      avg_max: 0.995096003015836
+      mean: -0.0001647915998243358
+      std: 0.17545655249766362
+      shape: (128, 67, 1024)
+  output:
+    min: -1.999041199684143
+    max: 1.996099829673767
+    avg_min: -1.8801900794108708
+    avg_max: 1.7707290103038151
+    mean: -0.0024027154916742197
+    std: 0.2309446271021139
+    shape: (128, 67, 1024)
+encoder.eltwiseadd_residuals.1:
+  inputs:
+    0:
+      min: -0.9999996423721313
+      max: 0.9999986290931702
+      avg_min: -0.9991549079616864
+      avg_max: 0.9982590352495511
+      mean: -4.938640427099017e-05
+      std: 0.16142380765711173
+      shape: (128, 67, 1024)
+    1:
+      min: -1.999041199684143
+      max: 1.996099829673767
+      avg_min: -1.8801900794108708
+      avg_max: 1.7707290103038151
+      mean: -0.0024027154916742197
+      std: 0.2309446271021139
+      shape: (128, 67, 1024)
+  output:
+    min: -2.9931235313415527
+    max: 2.9921934604644775
+    avg_min: -2.454954892396927
+    avg_max: 2.5017193754514055
+    mean: -0.0024521019076928496
+    std: 0.2936707415080642
+    shape: (128, 67, 1024)
+decoder.att_rnn.rnn.cells.0.fc_gate_x:
+  inputs:
+    0:
+      min: -5.343320369720459
+      max: 5.05069637298584
+      avg_min: -2.7414190683472004
+      avg_max: 2.8718470222494563
+      mean: 0.0012989030032717187
+      std: 0.8618320366411912
+      shape: (1280, 1024)
+  output:
+    min: -32.593318939208984
+    max: 33.264034271240234
+    avg_min: -19.51530230447147
+    avg_max: 18.39185942424816
+    mean: -0.319708395848218
+    std: 5.248436831320663
+    shape: (1280, 4096)
+decoder.att_rnn.rnn.cells.0.fc_gate_h:
+  inputs:
+    0:
+      min: -1.0
+      max: 1.0
+      avg_min: -0.9713709625970117
+      avg_max: 0.9817781538775797
+      mean: 0.007541471822687429
+      std: 0.3220154108800773
+      shape: (1280, 1024)
+  output:
+    min: -35.26850891113281
+    max: 38.68207550048828
+    avg_min: -19.4308424914486
+    avg_max: 22.16190616622404
+    mean: -1.9407802034043842
+    std: 4.835737332097758
+    shape: (1280, 4096)
+decoder.att_rnn.rnn.cells.0.eltwiseadd_gate:
+  inputs:
+    0:
+      min: -32.593318939208984
+      max: 33.264034271240234
+      avg_min: -19.51530230447147
+      avg_max: 18.39185942424816
+      mean: -0.319708395848218
+      std: 5.248436831320663
+      shape: (1280, 4096)
+    1:
+      min: -35.26850891113281
+      max: 38.68207550048828
+      avg_min: -19.4308424914486
+      avg_max: 22.16190616622404
+      mean: -1.9407802034043842
+      std: 4.835737332097758
+      shape: (1280, 4096)
+  output:
+    min: -45.18857192993164
+    max: 48.00703430175781
+    avg_min: -26.367431187897573
+    avg_max: 27.033264856660004
+    mean: -2.260488603289206
+    std: 7.086690069945273
+    shape: (1280, 4096)
+decoder.att_rnn.rnn.cells.0.act_f:
+  inputs:
+    0:
+      min: -42.005043029785156
+      max: 48.00703430175781
+      avg_min: -24.572684604666218
+      avg_max: 24.47041631977212
+      mean: -5.428667449415389
+      std: 7.5857557123719195
+      shape: (1280, 1024)
+  output:
+    min: 5.720600170657743e-19
+    max: 1.0
+    avg_min: 1.3068073496415576e-06
+    avg_max: 0.9999987296174087
+    mean: 0.22353376293282823
+    std: 0.37551881742884263
+    shape: (1280, 1024)
+decoder.att_rnn.rnn.cells.0.act_i:
+  inputs:
+    0:
+      min: -42.485984802246094
+      max: 37.89415740966797
+      avg_min: -20.737807757666975
+      avg_max: 17.757313247745003
+      mean: -2.6210182911726867
+      std: 6.095373006020184
+      shape: (1280, 1024)
+  output:
+    min: 3.5364804280239927e-19
+    max: 1.0
+    avg_min: 2.9600880243534555e-08
+    avg_max: 0.9999995308980515
+    mean: 0.330330826534649
+    std: 0.4106618580468271
+    shape: (1280, 1024)
+decoder.att_rnn.rnn.cells.0.act_o:
+  inputs:
+    0:
+      min: -38.675228118896484
+      max: 45.59527587890625
+      avg_min: -19.614914631039895
+      avg_max: 22.38339533752272
+      mean: -1.0162639051926945
+      std: 6.201834056672583
+      shape: (1280, 1024)
+  output:
+    min: 1.5979450257881012e-17
+    max: 1.0
+    avg_min: 7.854096923179388e-08
+    avg_max: 0.9999999956133669
+    mean: 0.4198552611838568
+    std: 0.4295806405099903
+    shape: (1280, 1024)
+decoder.att_rnn.rnn.cells.0.act_g:
+  inputs:
+    0:
+      min: -45.18857192993164
+      max: 45.650177001953125
+      avg_min: -24.537228854318712
+      avg_max: 25.075750009129578
+      mean: 0.023995243424084333
+      std: 7.12668924216872
+      shape: (1280, 1024)
+  output:
+    min: -1.0
+    max: 1.0
+    avg_min: -0.999999995680337
+    avg_max: 0.999999995680337
+    mean: -0.002128829708629359
+    std: 0.9478563806555594
+    shape: (1280, 1024)
+decoder.att_rnn.rnn.cells.0.eltwisemult_cell_forget:
+  inputs:
+    0:
+      min: 5.720600170657743e-19
+      max: 1.0
+      avg_min: 1.3068073496415576e-06
+      avg_max: 0.9999987296174087
+      mean: 0.22353376293282823
+      std: 0.37551881742884263
+      shape: (1280, 1024)
+    1:
+      min: -45.964332580566406
+      max: 74.85725402832031
+      avg_min: -8.37450361519716
+      avg_max: 30.634137733129954
+      mean: 0.06729932622641525
+      std: 1.6449935154656605
+      shape: (1280, 1024)
+  output:
+    min: -45.9624137878418
+    max: 74.85725402832031
+    avg_min: -8.23528243631459
+    avg_max: 30.575408138686335
+    mean: 0.06604976536836245
+    std: 1.5455946584670446
+    shape: (1280, 1024)
+decoder.att_rnn.rnn.cells.0.eltwisemult_cell_input:
+  inputs:
+    0:
+      min: 3.5364804280239927e-19
+      max: 1.0
+      avg_min: 2.9600880243534555e-08
+      avg_max: 0.9999995308980515
+      mean: 0.330330826534649
+      std: 0.4106618580468271
+      shape: (1280, 1024)
+    1:
+      min: -1.0
+      max: 1.0
+      avg_min: -0.999999995680337
+      avg_max: 0.999999995680337
+      mean: -0.002128829708629359
+      std: 0.9478563806555594
+      shape: (1280, 1024)
+  output:
+    min: -1.0
+    max: 1.0
+    avg_min: -0.9999764687894426
+    avg_max: 0.9999969378950895
+    mean: 0.0027081930132716804
+    std: 0.510135969194079
+    shape: (1280, 1024)
+decoder.att_rnn.rnn.cells.0.eltwiseadd_cell:
+  inputs:
+    0:
+      min: -45.9624137878418
+      max: 74.85725402832031
+      avg_min: -8.23528243631459
+      avg_max: 30.575408138686335
+      mean: 0.06604976536836245
+      std: 1.5455946584670446
+      shape: (1280, 1024)
+    1:
+      min: -1.0
+      max: 1.0
+      avg_min: -0.9999764687894426
+      avg_max: 0.9999969378950895
+      mean: 0.0027081930132716804
+      std: 0.510135969194079
+      shape: (1280, 1024)
+  output:
+    min: -46.92375183105469
+    max: 75.85040283203125
+    avg_min: -8.574188237645677
+    avg_max: 31.416855715031076
+    mean: 0.06875795865061662
+    std: 1.676732644640398
+    shape: (1280, 1024)
+decoder.att_rnn.rnn.cells.0.act_h:
+  inputs:
+    0:
+      min: -46.92375183105469
+      max: 75.85040283203125
+      avg_min: -8.574188237645677
+      avg_max: 31.416855715031076
+      mean: 0.06875795865061662
+      std: 1.676732644640398
+      shape: (1280, 1024)
+  output:
+    min: -1.0
+    max: 1.0
+    avg_min: -0.9937744413533918
+    avg_max: 0.9959949824247473
+    mean: 0.0061799199947633445
+    std: 0.47477836896857967
+    shape: (1280, 1024)
+decoder.att_rnn.rnn.cells.0.eltwisemult_hidden:
+  inputs:
+    0:
+      min: 1.5979450257881012e-17
+      max: 1.0
+      avg_min: 7.854096923179388e-08
+      avg_max: 0.9999999956133669
+      mean: 0.4198552611838568
+      std: 0.4295806405099903
+      shape: (1280, 1024)
+    1:
+      min: -1.0
+      max: 1.0
+      avg_min: -0.9937744413533918
+      avg_max: 0.9959949824247473
+      mean: 0.0061799199947633445
+      std: 0.47477836896857967
+      shape: (1280, 1024)
+  output:
+    min: -1.0
+    max: 1.0
+    avg_min: -0.9846230725224109
+    avg_max: 0.9952872082423628
+    mean: 0.007966578123805529
+    std: 0.32187307320175484
+    shape: (1280, 1024)
+decoder.att_rnn.rnn.dropout:
+  output:
+    min: 1.7976931348623157e+308
+    max: -1.7976931348623157e+308
+    avg_min: 0
+    avg_max: 0
+    mean: 0
+    std: 0
+    shape: ''
+decoder.att_rnn.attn.linear_q:
+  inputs:
+    0:
+      min: -1.0
+      max: 1.0
+      avg_min: -0.9846230725224109
+      avg_max: 0.9952872082423628
+      mean: 0.007966578123805529
+      std: 0.32187307320175484
+      shape: (1280, 1, 1024)
+  output:
+    min: -28.324504852294922
+    max: 30.094812393188477
+    avg_min: -16.16285647060098
+    avg_max: 16.370824370223502
+    mean: 0.07139282220419639
+    std: 5.614923564767698
+    shape: (1280, 1, 1024)
+decoder.att_rnn.attn.linear_k:
+  inputs:
+    0:
+      min: -2.9931235313415527
+      max: 2.9921934604644775
+      avg_min: -2.5274831603082375
+      avg_max: 2.571398892563376
+      mean: -0.002407128676685295
+      std: 0.3183030856867703
+      shape: (1280, 67, 1024)
+  output:
+    min: -27.933134078979492
+    max: 27.290098190307617
+    avg_min: -21.23075208556785
+    avg_max: 21.645756855439604
+    mean: 0.11102438071386868
+    std: 4.714025089771699
+    shape: (1280, 67, 1024)
+decoder.att_rnn.attn.dropout:
+  inputs:
+    0:
+      min: 0.0
+      max: 0.9910882711410522
+      avg_min: 1.299117331064688e-05
+      avg_max: 0.39005828052340624
+      mean: 0.013619308356258466
+      std: 0.05484342060778131
+      shape: (1280, 1, 67)
+  output:
+    min: 0.0
+    max: 0.9910882711410522
+    avg_min: 1.299117331064688e-05
+    avg_max: 0.39005828052340624
+    mean: 0.013619308356258466
+    std: 0.05484342060778131
+    shape: (1280, 1, 67)
+decoder.att_rnn.attn.eltwiseadd_qk:
+  inputs:
+    0:
+      min: -28.324504852294922
+      max: 30.094812393188477
+      avg_min: -16.16285647060098
+      avg_max: 16.370824370223502
+      mean: 0.07139282138405703
+      std: 5.614948532126775
+      shape: (1280, 1, 67, 1024)
+    1:
+      min: -27.933134078979492
+      max: 27.290098190307617
+      avg_min: -21.23075208556785
+      avg_max: 21.645756855439604
+      mean: 0.11102438071386868
+      std: 4.714025089771699
+      shape: (1280, 1, 67, 1024)
+  output:
+    min: -42.6893310546875
+    max: 41.46746063232422
+    avg_min: -26.125514688384673
+    avg_max: 26.724065986376143
+    mean: 0.18241720259650032
+    std: 6.144074303575077
+    shape: (1280, 1, 67, 1024)
+decoder.att_rnn.attn.eltwiseadd_norm_bias:
+  inputs:
+    0:
+      min: -42.6893310546875
+      max: 41.46746063232422
+      avg_min: -26.125514688384673
+      avg_max: 26.724065986376143
+      mean: 0.18241720259650032
+      std: 6.144074303575077
+      shape: (1280, 1, 67, 1024)
+    1:
+      min: -1.159269094467163
+      max: 1.030885934829712
+      avg_min: -1.159269094467163
+      avg_max: 1.030885934829712
+      mean: 0.01333538442850113
+      std: 0.3476333932113675
+      shape: (1024)
+  output:
+    min: -42.8883171081543
+    max: 41.60582733154297
+    avg_min: -26.312186408310776
+    avg_max: 26.855665168333605
+    mean: 0.19575258549512087
+    std: 6.287847745163477
+    shape: (1280, 1, 67, 1024)
+decoder.att_rnn.attn.eltwisemul_norm_scaler:
+  inputs:
+    0:
+      min: -0.164537250995636
+      max: 0.09800389409065247
+      avg_min: -0.164537250995636
+      avg_max: 0.09800389409065247
+      mean: 0.00032652184017933905
+      std: 0.031248209015367
+      shape: (1024)
+    1:
+      min: 1.2628638744354248
+      max: 1.2628638744354248
+      avg_min: 1.2628638744354248
+      avg_max: 1.2628638744354248
+      mean: 1.2628638744354248
+      std: .nan
+      shape: (1)
+  output:
+    min: -0.2077881544828415
+    max: 0.12376558035612106
+    avg_min: -0.2077881544828415
+    avg_max: 0.12376558035612106
+    mean: 0.0004123526159673929
+    std: 0.03946245597745369
+    shape: (1024)
+decoder.att_rnn.attn.tanh:
+  inputs:
+    0:
+      min: -42.8883171081543
+      max: 41.60582733154297
+      avg_min: -26.312186408310776
+      avg_max: 26.855665168333605
+      mean: 0.19575258549512087
+      std: 6.287847745163477
+      shape: (1280, 1, 67, 1024)
+  output:
+    min: -1.0
+    max: 1.0
+    avg_min: -0.999999995680337
+    avg_max: 0.999999995680337
+    mean: 0.03299323406470721
+    std: 0.9384829454472874
+    shape: (1280, 1, 67, 1024)
+decoder.att_rnn.attn.matmul_score:
+  inputs:
+    0:
+      min: -1.0
+      max: 1.0
+      avg_min: -0.999999995680337
+      avg_max: 0.999999995680337
+      mean: 0.03299323406470721
+      std: 0.9384829454472874
+      shape: (1280, 1, 67, 1024)
+    1:
+      min: -0.2077881544828415
+      max: 0.12376558035612106
+      avg_min: -0.2077881544828415
+      avg_max: 0.12376558035612106
+      mean: 0.0004123526159673929
+      std: 0.03946245597745369
+      shape: (1024)
+  output:
+    min: -5.075064182281494
+    max: 20.071443557739258
+    avg_min: 4.662021810703742
+    avg_max: 12.197041146406967
+    mean: 7.974570217477471
+    std: 2.9091675582996563
+    shape: (1280, 1, 67)
+decoder.att_rnn.attn.softmax_att:
+  inputs:
+    0:
+      min: -65504.0
+      max: 20.071443557739258
+      avg_min: -58239.877876515005
+      avg_max: 12.163362407550382
+      mean: -24693.233116322437
+      std: 31749.93291331495
+      shape: (1280, 1, 67)
+  output:
+    min: 0.0
+    max: 0.9910882711410522
+    avg_min: 1.299117331064688e-05
+    avg_max: 0.39005828052340624
+    mean: 0.013619308356258466
+    std: 0.05484342060778131
+    shape: (1280, 1, 67)
+decoder.att_rnn.attn.context_matmul:
+  inputs:
+    0:
+      min: 0.0
+      max: 0.9910882711410522
+      avg_min: 1.299117331064688e-05
+      avg_max: 0.39005828052340624
+      mean: 0.013619308356258466
+      std: 0.05484342060778131
+      shape: (1280, 1, 67)
+    1:
+      min: -2.9931235313415527
+      max: 2.9921934604644775
+      avg_min: -2.5274831603082375
+      avg_max: 2.571398892563376
+      mean: -0.002407128676685295
+      std: 0.3183030856867703
+      shape: (1280, 67, 1024)
+  output:
+    min: -2.748584032058716
+    max: 2.7187154293060303
+    avg_min: -1.0244215909684593
+    avg_max: 0.9287193868937127
+    mean: 0.0004324376445557447
+    std: 0.16600608798290178
+    shape: (1280, 1, 1024)
+decoder.att_rnn.dropout:
+  inputs:
+    0:
+      min: -1.0
+      max: 1.0
+      avg_min: -0.9846230725224109
+      avg_max: 0.9952872082423628
+      mean: 0.007966578123805529
+      std: 0.32187307320175484
+      shape: (1280, 1, 1024)
+  output:
+    min: -1.0
+    max: 1.0
+    avg_min: -0.9846230725224109
+    avg_max: 0.9952872082423628
+    mean: 0.007966578123805529
+    std: 0.32187307320175484
+    shape: (1280, 1, 1024)
+decoder.rnn_layers.0.cells.0.fc_gate_x:
+  inputs:
+    0:
+      min: -2.748584032058716
+      max: 2.7187154293060303
+      avg_min: -1.200279578834437
+      avg_max: 1.1144092399752523
+      mean: 0.004199507878490362
+      std: 0.2561140031003674
+      shape: (1280, 2048)
+  output:
+    min: -31.11646270751953
+    max: 31.012344360351562
+    avg_min: -11.667031159561684
+    avg_max: 10.85241473739067
+    mean: -1.1740948434960983
+    std: 3.0091560830798434
+    shape: (1280, 4096)
+decoder.rnn_layers.0.cells.0.fc_gate_h:
+  inputs:
+    0:
+      min: -0.9999914765357971
+      max: 0.9999936819076538
+      avg_min: -0.7896906179993346
+      avg_max: 0.7587759274110363
+      mean: -0.0010008882399605407
+      std: 0.1321120655297924
+      shape: (1280, 1024)
+  output:
+    min: -16.805160522460938
+    max: 22.041872024536133
+    avg_min: -5.688125228613951
+    avg_max: 5.603913084911495
+    mean: -0.49282410457060527
+    std: 1.2749332466602463
+    shape: (1280, 4096)
+decoder.rnn_layers.0.cells.0.eltwiseadd_gate:
+  inputs:
+    0:
+      min: -31.11646270751953
+      max: 31.012344360351562
+      avg_min: -11.667031159561684
+      avg_max: 10.85241473739067
+      mean: -1.1740948434960983
+      std: 3.0091560830798434
+      shape: (1280, 4096)
+    1:
+      min: -16.805160522460938
+      max: 22.041872024536133
+      avg_min: -5.688125228613951
+      avg_max: 5.603913084911495
+      mean: -0.49282410457060527
+      std: 1.2749332466602463
+      shape: (1280, 4096)
+  output:
+    min: -33.57273864746094
+    max: 30.18963623046875
+    avg_min: -13.689522092797787
+    avg_max: 12.499767167380655
+    mean: -1.6669189450446138
+    std: 3.6505088399213936
+    shape: (1280, 4096)
+decoder.rnn_layers.0.cells.0.act_f:
+  inputs:
+    0:
+      min: -30.98964500427246
+      max: 28.624046325683594
+      avg_min: -13.069127936845424
+      avg_max: 11.013900484127936
+      mean: -2.6897614965278116
+      std: 4.286025676313951
+      shape: (1280, 1024)
+  output:
+    min: 3.4783092882741465e-14
+    max: 1.0
+    avg_min: 1.753829452079096e-05
+    avg_max: 0.9996802927737837
+    mean: 0.2552954440455087
+    std: 0.35829928323219956
+    shape: (1280, 1024)
+decoder.rnn_layers.0.cells.0.act_i:
+  inputs:
+    0:
+      min: -29.928306579589844
+      max: 27.947511672973633
+      avg_min: -10.685935257525923
+      avg_max: 9.380031768927397
+      mean: -2.4518663697698173
+      std: 3.0217416033646693
+      shape: (1280, 1024)
+  output:
+    min: 1.005313792824293e-13
+    max: 1.0
+    avg_min: 0.00017471713297229053
+    avg_max: 0.9996398801214232
+    mean: 0.22650649123144959
+    std: 0.3026139789025641
+    shape: (1280, 1024)
+decoder.rnn_layers.0.cells.0.act_o:
+  inputs:
+    0:
+      min: -27.317293167114258
+      max: 23.132068634033203
+      avg_min: -11.532117180877865
+      avg_max: 8.524241487095862
+      mean: -1.5604653776026836
+      std: 3.2415928306767676
+      shape: (1280, 1024)
+  output:
+    min: 1.3685174955063717e-12
+    max: 1.0
+    avg_min: 9.716954026726062e-05
+    avg_max: 0.9978680859455905
+    mean: 0.3295869104145614
+    std: 0.34560961994352346
+    shape: (1280, 1024)
+decoder.rnn_layers.0.cells.0.act_g:
+  inputs:
+    0:
+      min: -33.57273864746094
+      max: 30.18963623046875
+      avg_min: -11.964716329735301
+      avg_max: 11.46917405824984
+      mean: 0.03441744941392035
+      std: 3.2752884250043954
+      shape: (1280, 1024)
+  output:
+    min: -1.0
+    max: 1.0
+    avg_min: -0.9999999861703825
+    avg_max: 0.9999999514456561
+    mean: 0.009212502904665593
+    std: 0.8613827573443271
+    shape: (1280, 1024)
+decoder.rnn_layers.0.cells.0.eltwisemult_cell_forget:
+  inputs:
+    0:
+      min: 3.4783092882741465e-14
+      max: 1.0
+      avg_min: 1.753829452079096e-05
+      avg_max: 0.9996802927737837
+      mean: 0.2552954440455087
+      std: 0.35829928323219956
+      shape: (1280, 1024)
+    1:
+      min: -50.77605438232422
+      max: 51.83877944946289
+      avg_min: -11.694220676448898
+      avg_max: 11.783769438574792
+      mean: -0.013233005671694572
+      std: 0.9262749365617192
+      shape: (1280, 1024)
+  output:
+    min: -50.32355880737305
+    max: 51.78708267211914
+    avg_min: -11.160377957814195
+    avg_max: 11.075782726721808
+    mean: -0.010043481503626429
+    std: 0.8084092966416248
+    shape: (1280, 1024)
+decoder.rnn_layers.0.cells.0.eltwisemult_cell_input:
+  inputs:
+    0:
+      min: 1.005313792824293e-13
+      max: 1.0
+      avg_min: 0.00017471713297229053
+      avg_max: 0.9996398801214232
+      mean: 0.22650649123144959
+      std: 0.3026139789025641
+      shape: (1280, 1024)
+    1:
+      min: -1.0
+      max: 1.0
+      avg_min: -0.9999999861703825
+      avg_max: 0.9999999514456561
+      mean: 0.009212502904665593
+      std: 0.8613827573443271
+      shape: (1280, 1024)
+  output:
+    min: -1.0
+    max: 1.0
+    avg_min: -0.9977656979239379
+    avg_max: 0.9981463519374972
+    mean: -0.0036360137033707263
+    std: 0.3448696757960219
+    shape: (1280, 1024)
+decoder.rnn_layers.0.cells.0.eltwiseadd_cell:
+  inputs:
+    0:
+      min: -50.32355880737305
+      max: 51.78708267211914
+      avg_min: -11.160377957814195
+      avg_max: 11.075782726721808
+      mean: -0.010043481503626429
+      std: 0.8084092966416248
+      shape: (1280, 1024)
+    1:
+      min: -1.0
+      max: 1.0
+      avg_min: -0.9977656979239379
+      avg_max: 0.9981463519374972
+      mean: -0.0036360137033707263
+      std: 0.3448696757960219
+      shape: (1280, 1024)
+  output:
+    min: -51.32215881347656
+    max: 52.70581817626953
+    avg_min: -11.736164883080493
+    avg_max: 11.741759218124862
+    mean: -0.013679495197402003
+    std: 0.9252576031128981
+    shape: (1280, 1024)
+decoder.rnn_layers.0.cells.0.act_h:
+  inputs:
+    0:
+      min: -51.32215881347656
+      max: 52.70581817626953
+      avg_min: -11.736164883080493
+      avg_max: 11.741759218124862
+      mean: -0.013679495197402003
+      std: 0.9252576031128981
+      shape: (1280, 1024)
+  output:
+    min: -1.0
+    max: 1.0
+    avg_min: -0.9954183305247454
+    avg_max: 0.9928096467189585
+    mean: -0.005522266373139992
+    std: 0.3470194696164207
+    shape: (1280, 1024)
+decoder.rnn_layers.0.cells.0.eltwisemult_hidden:
+  inputs:
+    0:
+      min: 1.3685174955063717e-12
+      max: 1.0
+      avg_min: 9.716954026726062e-05
+      avg_max: 0.9978680859455905
+      mean: 0.3295869104145614
+      std: 0.34560961994352346
+      shape: (1280, 1024)
+    1:
+      min: -1.0
+      max: 1.0
+      avg_min: -0.9954183305247454
+      avg_max: 0.9928096467189585
+      mean: -0.005522266373139992
+      std: 0.3470194696164207
+      shape: (1280, 1024)
+  output:
+    min: -0.9999914765357971
+    max: 0.9999985694885254
+    avg_min: -0.7889921560716086
+    avg_max: 0.7594633606377601
+    mean: -0.0010860526439966734
+    std: 0.13179261237432402
+    shape: (1280, 1024)
+decoder.rnn_layers.0.dropout:
+  output:
+    min: 1.7976931348623157e+308
+    max: -1.7976931348623157e+308
+    avg_min: 0
+    avg_max: 0
+    mean: 0
+    std: 0
+    shape: ''
+decoder.rnn_layers.1.cells.0.fc_gate_x:
+  inputs:
+    0:
+      min: -2.748584032058716
+      max: 2.7187154293060303
+      avg_min: -1.1142985966098458
+      avg_max: 0.996185907755004
+      mean: -0.00032680749604648707
+      std: 0.14988080842341334
+      shape: (1280, 2048)
+  output:
+    min: -22.17432403564453
+    max: 24.88280487060547
+    avg_min: -7.668186707710953
+    avg_max: 7.5104381103194156
+    mean: -0.8090552802835963
+    std: 2.096109569488443
+    shape: (1280, 4096)
+decoder.rnn_layers.1.cells.0.fc_gate_h:
+  inputs:
+    0:
+      min: -0.9998709559440613
+      max: 0.9998664259910583
+      avg_min: -0.6980919281250983
+      avg_max: 0.7355864281399854
+      mean: 0.00030723568868500604
+      std: 0.13119853605881615
+      shape: (1280, 1024)
+  output:
+    min: -10.816532135009766
+    max: 10.679582595825195
+    avg_min: -4.561453889193147
+    avg_max: 4.048953114064886
+    mean: -0.47011378199029547
+    std: 0.9846700195969109
+    shape: (1280, 4096)
+decoder.rnn_layers.1.cells.0.eltwiseadd_gate:
+  inputs:
+    0:
+      min: -22.17432403564453
+      max: 24.88280487060547
+      avg_min: -7.668186707710953
+      avg_max: 7.5104381103194156
+      mean: -0.8090552802835963
+      std: 2.096109569488443
+      shape: (1280, 4096)
+    1:
+      min: -10.816532135009766
+      max: 10.679582595825195
+      avg_min: -4.561453889193147
+      avg_max: 4.048953114064886
+      mean: -0.47011378199029547
+      std: 0.9846700195969109
+      shape: (1280, 4096)
+  output:
+    min: -24.22309684753418
+    max: 24.288864135742188
+    avg_min: -9.790749129284656
+    avg_max: 8.619773781969311
+    mean: -1.2791690607754058
+    std: 2.5821871111713826
+    shape: (1280, 4096)
+decoder.rnn_layers.1.cells.0.act_f:
+  inputs:
+    0:
+      min: -24.22309684753418
+      max: 23.923398971557617
+      avg_min: -9.621848200680162
+      avg_max: 6.821592486440458
+      mean: -2.6714931534247466
+      std: 2.5443121998813525
+      shape: (1280, 1024)
+  output:
+    min: 3.020248634522105e-11
+    max: 1.0
+    avg_min: 0.0003451858472697981
+    avg_max: 0.9935724002256816
+    mean: 0.1752624285941042
+    std: 0.25292423787376883
+    shape: (1280, 1024)
+decoder.rnn_layers.1.cells.0.act_i:
+  inputs:
+    0:
+      min: -20.21559715270996
+      max: 22.597978591918945
+      avg_min: -7.931391174873612
+      avg_max: 6.599652247616419
+      mean: -1.3425940824358642
+      std: 2.214069210841872
+      shape: (1280, 1024)
+  output:
+    min: 1.6614134512593637e-09
+    max: 1.0
+    avg_min: 0.0012407131223308894
+    avg_max: 0.994309761283102
+    mean: 0.3069548572800805
+    std: 0.29030061738546137
+    shape: (1280, 1024)
+decoder.rnn_layers.1.cells.0.act_o:
+  inputs:
+    0:
+      min: -20.679615020751953
+      max: 24.288864135742188
+      avg_min: -8.320266337876905
+      avg_max: 7.602936600299369
+      mean: -1.092906504534604
+      std: 2.4894375831577364
+      shape: (1280, 1024)
+  output:
+    min: 1.0446175036094019e-09
+    max: 1.0
+    avg_min: 0.0006933753792029073
+    avg_max: 0.9977307225211285
+    mean: 0.3514365402500282
+    std: 0.32005181986355885
+    shape: (1280, 1024)
+decoder.rnn_layers.1.cells.0.act_g:
+  inputs:
+    0:
+      min: -21.044048309326172
+      max: 21.751747131347656
+      avg_min: -7.526448396886325
+      avg_max: 7.799964112378239
+      mean: -0.009682508857590904
+      std: 2.3471144877832884
+      shape: (1280, 1024)
+  output:
+    min: -1.0
+    max: 1.0
+    avg_min: -0.9999812895662334
+    avg_max: 0.9999866421637916
+    mean: -0.007454321779905927
+    std: 0.8034238800947899
+    shape: (1280, 1024)
+decoder.rnn_layers.1.cells.0.eltwisemult_cell_forget:
+  inputs:
+    0:
+      min: 3.020248634522105e-11
+      max: 1.0
+      avg_min: 0.0003451858472697981
+      avg_max: 0.9935724002256816
+      mean: 0.1752624285941042
+      std: 0.25292423787376883
+      shape: (1280, 1024)
+    1:
+      min: -50.93753433227539
+      max: 30.134031295776367
+      avg_min: -2.980211339807237
+      avg_max: 2.858599665124761
+      mean: -0.008311334580423152
+      std: 0.4794455076286656
+      shape: (1280, 1024)
+  output:
+    min: -50.889991760253906
+    max: 30.027212142944336
+    avg_min: -2.4969031673468915
+    avg_max: 2.3347478714216994
+    mean: -0.0015736418809647336
+    std: 0.25571878953430177
+    shape: (1280, 1024)
+decoder.rnn_layers.1.cells.0.eltwisemult_cell_input:
+  inputs:
+    0:
+      min: 1.6614134512593637e-09
+      max: 1.0
+      avg_min: 0.0012407131223308894
+      avg_max: 0.994309761283102
+      mean: 0.3069548572800805
+      std: 0.29030061738546137
+      shape: (1280, 1024)
+    1:
+      min: -1.0
+      max: 1.0
+      avg_min: -0.9999812895662334
+      avg_max: 0.9999866421637916
+      mean: -0.007454321779905927
+      std: 0.8034238800947899
+      shape: (1280, 1024)
+  output:
+    min: -1.0
+    max: 1.0
+    avg_min: -0.9808258289701476
+    avg_max: 0.9891012119108362
+    mean: -0.006563609112533941
+    std: 0.3551230793477708
+    shape: (1280, 1024)
+decoder.rnn_layers.1.cells.0.eltwiseadd_cell:
+  inputs:
+    0:
+      min: -50.889991760253906
+      max: 30.027212142944336
+      avg_min: -2.4969031673468915
+      avg_max: 2.3347478714216994
+      mean: -0.0015736418809647336
+      std: 0.25571878953430177
+      shape: (1280, 1024)
+    1:
+      min: -1.0
+      max: 1.0
+      avg_min: -0.9808258289701476
+      avg_max: 0.9891012119108362
+      mean: -0.006563609112533941
+      std: 0.3551230793477708
+      shape: (1280, 1024)
+  output:
+    min: -51.780860900878906
+    max: 30.70660400390625
+    avg_min: -3.067318748858543
+    avg_max: 2.8984877801343285
+    mean: -0.008137250974908283
+    std: 0.4813941324820594
+    shape: (1280, 1024)
+decoder.rnn_layers.1.cells.0.act_h:
+  inputs:
+    0:
+      min: -51.780860900878906
+      max: 30.70660400390625
+      avg_min: -3.067318748858543
+      avg_max: 2.8984877801343285
+      mean: -0.008137250974908283
+      std: 0.4813941324820594
+      shape: (1280, 1024)
+  output:
+    min: -1.0
+    max: 1.0
+    avg_min: -0.9705970220016633
+    avg_max: 0.9670456421509231
+    mean: -0.0057490570247533766
+    std: 0.345665687214933
+    shape: (1280, 1024)
+decoder.rnn_layers.1.cells.0.eltwisemult_hidden:
+  inputs:
+    0:
+      min: 1.0446175036094019e-09
+      max: 1.0
+      avg_min: 0.0006933753792029073
+      avg_max: 0.9977307225211285
+      mean: 0.3514365402500282
+      std: 0.32005181986355885
+      shape: (1280, 1024)
+    1:
+      min: -1.0
+      max: 1.0
+      avg_min: -0.9705970220016633
+      avg_max: 0.9670456421509231
+      mean: -0.0057490570247533766
+      std: 0.345665687214933
+      shape: (1280, 1024)
+  output:
+    min: -0.9998709559440613
+    max: 0.9998664259910583
+    avg_min: -0.7024310171101872
+    avg_max: 0.743544528524527
+    mean: 0.00032440792050908126
+    std: 0.1314242965426167
+    shape: (1280, 1024)
+decoder.rnn_layers.1.dropout:
+  output:
+    min: 1.7976931348623157e+308
+    max: -1.7976931348623157e+308
+    avg_min: 0
+    avg_max: 0
+    mean: 0
+    std: 0
+    shape: ''
+decoder.rnn_layers.2.cells.0.fc_gate_x:
+  inputs:
+    0:
+      min: -2.748584032058716
+      max: 2.7187154293060303
+      avg_min: -1.1982355600662467
+      avg_max: 1.1189356212870456
+      mean: -0.00016460354038312733
+      std: 0.1760178086847606
+      shape: (1280, 2048)
+  output:
+    min: -24.67524528503418
+    max: 28.180944442749023
+    avg_min: -8.691330929284696
+    avg_max: 8.328183968147563
+    mean: -0.8204240057408129
+    std: 2.453284969094989
+    shape: (1280, 4096)
+decoder.rnn_layers.2.cells.0.fc_gate_h:
+  inputs:
+    0:
+      min: -0.9999999403953552
+      max: 0.9999994039535522
+      avg_min: -0.814193978671278
+      avg_max: 0.8058140833056374
+      mean: -0.0030742165989599258
+      std: 0.16821435780985747
+      shape: (1280, 1024)
+  output:
+    min: -18.816926956176758
+    max: 14.035696983337402
+    avg_min: -6.604132010829581
+    avg_max: 5.89386603040307
+    mean: -0.7366783961109576
+    std: 1.5059158176081058
+    shape: (1280, 4096)
+decoder.rnn_layers.2.cells.0.eltwiseadd_gate:
+  inputs:
+    0:
+      min: -24.67524528503418
+      max: 28.180944442749023
+      avg_min: -8.691330929284696
+      avg_max: 8.328183968147563
+      mean: -0.8204240057408129
+      std: 2.453284969094989
+      shape: (1280, 4096)
+    1:
+      min: -18.816926956176758
+      max: 14.035696983337402
+      avg_min: -6.604132010829581
+      avg_max: 5.89386603040307
+      mean: -0.7366783961109576
+      std: 1.5059158176081058
+      shape: (1280, 4096)
+  output:
+    min: -29.283180236816406
+    max: 26.588882446289062
+    avg_min: -12.15073210014387
+    avg_max: 10.381155111280712
+    mean: -1.557102406895562
+    std: 3.2988025400040257
+    shape: (1280, 4096)
+decoder.rnn_layers.2.cells.0.act_f:
+  inputs:
+    0:
+      min: -29.283180236816406
+      max: 23.37925148010254
+      avg_min: -12.019422794459913
+      avg_max: 7.526722565527717
+      mean: -3.6089404096094384
+      std: 3.0091061671078934
+      shape: (1280, 1024)
+  output:
+    min: 1.9163568843773293e-13
+    max: 1.0
+    avg_min: 0.00010279474620529197
+    avg_max: 0.9962542612231184
+    mean: 0.13819005828811198
+    std: 0.24204215179724617
+    shape: (1280, 1024)
+decoder.rnn_layers.2.cells.0.act_i:
+  inputs:
+    0:
+      min: -24.22443962097168
+      max: 24.299768447875977
+      avg_min: -9.08453069510085
+      avg_max: 7.826490100581997
+      mean: -0.91120249788078
+      std: 2.6741725568824375
+      shape: (1280, 1024)
+  output:
+    min: 3.0161959735375277e-11
+    max: 1.0
+    avg_min: 0.0005223408363930552
+    avg_max: 0.9983030597815357
+    mean: 0.38171999553281244
+    std: 0.3367532650457234
+    shape: (1280, 1024)
+decoder.rnn_layers.2.cells.0.act_o:
+  inputs:
+    0:
+      min: -25.566701889038086
+      max: 26.588882446289062
+      avg_min: -10.116214546192888
+      avg_max: 8.086509210607966
+      mean: -1.6018465338798051
+      std: 2.822778696064494
+      shape: (1280, 1024)
+  output:
+    min: 7.879931950005581e-12
+    max: 1.0
+    avg_min: 0.00020159611936281438
+    avg_max: 0.9988285692890035
+    mean: 0.30748807730969463
+    std: 0.3263006812590941
+    shape: (1280, 1024)
+decoder.rnn_layers.2.cells.0.act_g:
+  inputs:
+    0:
+      min: -26.263608932495117
+      max: 23.93798828125
+      avg_min: -10.2838829126251
+      avg_max: 10.15937283199823
+      mean: -0.10642016830597752
+      std: 3.5525409631159355
+      shape: (1280, 1024)
+  output:
+    min: -1.0
+    max: 1.0
+    avg_min: -0.9999990369161875
+    avg_max: 0.999999153747988
+    mean: -0.0269237342450161
+    std: 0.879016886211502
+    shape: (1280, 1024)
+decoder.rnn_layers.2.cells.0.eltwisemult_cell_forget:
+  inputs:
+    0:
+      min: 1.9163568843773293e-13
+      max: 1.0
+      avg_min: 0.00010279474620529197
+      avg_max: 0.9962542612231184
+      mean: 0.13819005828811198
+      std: 0.24204215179724617
+      shape: (1280, 1024)
+    1:
+      min: -50.77392578125
+      max: 57.88676452636719
+      avg_min: -4.026170471076211
+      avg_max: 3.778732706154328
+      mean: -0.010375739289892385
+      std: 0.6429531313365376
+      shape: (1280, 1024)
+  output:
+    min: -49.94257354736328
+    max: 57.657936096191406
+    avg_min: -3.367933605661549
+    avg_max: 3.1268611868780627
+    mean: -0.00047371524729856
+    std: 0.358745687372424
+    shape: (1280, 1024)
+decoder.rnn_layers.2.cells.0.eltwisemult_cell_input:
+  inputs:
+    0:
+      min: 3.0161959735375277e-11
+      max: 1.0
+      avg_min: 0.0005223408363930552
+      avg_max: 0.9983030597815357
+      mean: 0.38171999553281244
+      std: 0.3367532650457234
+      shape: (1280, 1024)
+    1:
+      min: -1.0
+      max: 1.0
+      avg_min: -0.9999990369161875
+      avg_max: 0.999999153747988
+      mean: -0.0269237342450161
+      std: 0.879016886211502
+      shape: (1280, 1024)
+  output:
+    min: -1.0
+    max: 1.0
+    avg_min: -0.9962779700086355
+    avg_max: 0.9960211656066814
+    mean: -0.01004450515169632
+    std: 0.4632697182775279
+    shape: (1280, 1024)
+decoder.rnn_layers.2.cells.0.eltwiseadd_cell:
+  inputs:
+    0:
+      min: -49.94257354736328
+      max: 57.657936096191406
+      avg_min: -3.367933605661549
+      avg_max: 3.1268611868780627
+      mean: -0.00047371524729856
+      std: 0.358745687372424
+      shape: (1280, 1024)
+    1:
+      min: -1.0
+      max: 1.0
+      avg_min: -0.9962779700086355
+      avg_max: 0.9960211656066814
+      mean: -0.01004450515169632
+      std: 0.4632697182775279
+      shape: (1280, 1024)
+  output:
+    min: -50.93840026855469
+    max: 58.63222122192383
+    avg_min: -4.101547773835359
+    avg_max: 3.8555672365293057
+    mean: -0.010518220426035512
+    std: 0.6487821329571182
+    shape: (1280, 1024)
+decoder.rnn_layers.2.cells.0.act_h:
+  inputs:
+    0:
+      min: -50.93840026855469
+      max: 58.63222122192383
+      avg_min: -4.101547773835359
+      avg_max: 3.8555672365293057
+      mean: -0.010518220426035512
+      std: 0.6487821329571182
+      shape: (1280, 1024)
+  output:
+    min: -1.0
+    max: 1.0
+    avg_min: -0.985261068585214
+    avg_max: 0.9829632449016149
+    mean: -0.008678492148203425
+    std: 0.4269098495511477
+    shape: (1280, 1024)
+decoder.rnn_layers.2.cells.0.eltwisemult_hidden:
+  inputs:
+    0:
+      min: 7.879931950005581e-12
+      max: 1.0
+      avg_min: 0.00020159611936281438
+      avg_max: 0.9988285692890035
+      mean: 0.30748807730969463
+      std: 0.3263006812590941
+      shape: (1280, 1024)
+    1:
+      min: -1.0
+      max: 1.0
+      avg_min: -0.985261068585214
+      avg_max: 0.9829632449016149
+      mean: -0.008678492148203425
+      std: 0.4269098495511477
+      shape: (1280, 1024)
+  output:
+    min: -0.9999999403953552
+    max: 0.9999994039535522
+    avg_min: -0.8279972857303842
+    avg_max: 0.8186234360665419
+    mean: -0.0031441434696200784
+    std: 0.1694763855118817
+    shape: (1280, 1024)
+decoder.rnn_layers.2.dropout:
+  output:
+    min: 1.7976931348623157e+308
+    max: -1.7976931348623157e+308
+    avg_min: 0
+    avg_max: 0
+    mean: 0
+    std: 0
+    shape: ''
+decoder.classifier.classifier:
+  inputs:
+    0:
+      min: -2.9619157314300537
+      max: 2.9662914276123047
+      avg_min: -1.3161585628316654
+      avg_max: 1.2828886617770356
+      mean: -0.003905788197230935
+      std: 0.2593071699498008
+      shape: (1280, 1, 1024)
+  output:
+    min: -19.020658493041992
+    max: 20.08169937133789
+    avg_min: -9.04961023920039
+    avg_max: 8.032180467616302
+    mean: -3.873352058817827
+    std: 1.696852165054472
+    shape: (1280, 1, 32317)
+decoder.dropout:
+  inputs:
+    0:
+      min: -1.9940483570098877
+      max: 1.9767669439315796
+      avg_min: -0.9086800352129599
+      avg_max: 0.9094174789020626
+      mean: 0.002039626916454153
+      std: 0.22761719307769354
+      shape: (1280, 1, 1024)
+  output:
+    min: -1.9940483570098877
+    max: 1.9767669439315796
+    avg_min: -0.9086800352129599
+    avg_max: 0.9094174789020626
+    mean: 0.002039626916454153
+    std: 0.22761719307769354
+    shape: (1280, 1, 1024)
+decoder.eltwiseadd_residuals.0:
+  inputs:
+    0:
+      min: -0.9998709559440613
+      max: 0.9998664259910583
+      avg_min: -0.7024310171101872
+      avg_max: 0.743544528524527
+      mean: 0.00032440792050908126
+      std: 0.1314242965426167
+      shape: (1280, 1, 1024)
+    1:
+      min: -0.9999914765357971
+      max: 0.9999985694885254
+      avg_min: -0.7889921560716086
+      avg_max: 0.7594633606377601
+      mean: -0.0010860526439966734
+      std: 0.13179261237432402
+      shape: (1280, 1, 1024)
+  output:
+    min: -1.9940483570098877
+    max: 1.9767669439315796
+    avg_min: -0.9524248770448603
+    avg_max: 0.9735018678260661
+    mean: -0.0007616447304464125
+    std: 0.1854876875589218
+    shape: (1280, 1, 1024)
+decoder.eltwiseadd_residuals.1:
+  inputs:
+    0:
+      min: -0.9999999403953552
+      max: 0.9999994039535522
+      avg_min: -0.8279972857303842
+      avg_max: 0.8186234360665419
+      mean: -0.0031441434696200784
+      std: 0.1694763855118817
+      shape: (1280, 1, 1024)
+    1:
+      min: -1.9940483570098877
+      max: 1.9767669439315796
+      avg_min: -0.9524248770448603
+      avg_max: 0.9735018678260661
+      mean: -0.0007616447304464125
+      std: 0.1854876875589218
+      shape: (1280, 1, 1024)
+  output:
+    min: -2.9619157314300537
+    max: 2.9662914276123047
+    avg_min: -1.3161585628316654
+    avg_max: 1.2828886617770356
+    mean: -0.003905788197230935
+    std: 0.2593071699498008
+    shape: (1280, 1, 1024)
+decoder.attention_concats.0:
+  inputs:
+    0:
+      min: -1.0
+      max: 1.0
+      avg_min: -0.9846230725224109
+      avg_max: 0.9952872082423628
+      mean: 0.007966578123805529
+      std: 0.32187307320175484
+      shape: (1280, 1, 1024)
+    1:
+      min: -2.748584032058716
+      max: 2.7187154293060303
+      avg_min: -1.0244215909684593
+      avg_max: 0.9287193868937127
+      mean: 0.0004324376445557447
+      std: 0.16600608798290178
+      shape: (1280, 1, 1024)
+  output:
+    min: -2.748584032058716
+    max: 2.7187154293060303
+    avg_min: -1.200279578834437
+    avg_max: 1.1144092399752523
+    mean: 0.004199507878490362
+    std: 0.2561140031003674
+    shape: (1280, 1, 2048)
+decoder.attention_concats.1:
+  inputs:
+    0:
+      min: -0.9999914765357971
+      max: 0.9999985694885254
+      avg_min: -0.7889921560716086
+      avg_max: 0.7594633606377601
+      mean: -0.0010860526439966734
+      std: 0.13179261237432402
+      shape: (1280, 1, 1024)
+    1:
+      min: -2.748584032058716
+      max: 2.7187154293060303
+      avg_min: -1.0244215909684593
+      avg_max: 0.9287193868937127
+      mean: 0.0004324376445557447
+      std: 0.16600608798290178
+      shape: (1280, 1, 1024)
+  output:
+    min: -2.748584032058716
+    max: 2.7187154293060303
+    avg_min: -1.1142985966098458
+    avg_max: 0.996185907755004
+    mean: -0.00032680749604648707
+    std: 0.14988080842341334
+    shape: (1280, 1, 2048)
+decoder.attention_concats.2:
+  inputs:
+    0:
+      min: -1.9940483570098877
+      max: 1.9767669439315796
+      avg_min: -0.9524248770448603
+      avg_max: 0.9735018678260661
+      mean: -0.0007616447304464125
+      std: 0.1854876875589218
+      shape: (1280, 1, 1024)
+    1:
+      min: -2.748584032058716
+      max: 2.7187154293060303
+      avg_min: -1.0244215909684593
+      avg_max: 0.9287193868937127
+      mean: 0.0004324376445557447
+      std: 0.16600608798290178
+      shape: (1280, 1, 1024)
+  output:
+    min: -2.748584032058716
+    max: 2.7187154293060303
+    avg_min: -1.1982355600662467
+    avg_max: 1.1189356212870456
+    mean: -0.00016460354038312733
+    std: 0.1760178086847606
+    shape: (1280, 1, 2048)
diff --git a/examples/GNMT/quantize_gnmt.ipynb b/examples/GNMT/quantize_gnmt.ipynb
new file mode 100644
index 0000000..a2904d8
--- /dev/null
+++ b/examples/GNMT/quantize_gnmt.ipynb
@@ -0,0 +1,1069 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Quantizing Neural Machine Translation Models"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We continue our quest to quantize every Neural Network!  \n",
+    "On this chapter: __Google's Neural Machine Translation model__.  \n",
+    "A brief summary - using stacked LSTMs and attention mechanism, this model encodes a sentence into a list of vectors and then decodes it to the other language tokens until an end token is reached.  \n",
+    "To read more - refer to <a id=\"ref-1\" href=\"#cite-wu2016google\">Google's paper</a>."
+   ]
+  },
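+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Conceptually, the decoding loop looks roughly like the sketch below. This is illustrative pseudocode only; the actual implementation in `seq2seq/inference/beam_search.py` performs batched beam search rather than greedy decoding, and the names here are placeholders.\n",
+    "\n",
+    "```python\n",
+    "context = encoder(source_tokens)  # a list of annotation vectors\n",
+    "token, hidden, output = BOS, None, []\n",
+    "while token != EOS:\n",
+    "    # the decoder attends over the encoder's context at every step\n",
+    "    token, hidden = decoder(token, hidden, context)\n",
+    "    output.append(token)\n",
+    "```"
+   ]
+  },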
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Table of Contents\n",
+    "* [Quantizing Neural Machine Translation Models](#Quantizing-Neural-Machine-Translation-Models)\n",
+    "\t* [Getting the resources](#Getting-the-resources)\n",
+    "\t* [Loading the model](#Loading-the-model)\n",
+    "\t* [Evaulation of the model](#Evaulation-of-the-model)\n",
+    "\t* [Quantizing the model](#Quantizing-the-model)\n",
+    "\t\t* [Collecting the statistics](#Collecting-the-statistics)\n",
+    "\t\t* [Defining the Quantizer](#Defining-the-Quantizer)\n",
+    "\t\t* [Quantizing the model](#Quantizing-the-model)\n",
+    "\t\t* [Evaluating the quantized model](#Evaluating-the-quantized-model)\n",
+    "\t\t* [Finding the right quantization](#Finding-the-right-quantization)\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Getting the resources"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "In this project, we modified the [`mlperf/training/rnn_translator`](https://github.com/mlperf/training/tree/master/rnn_translator) project to enable quantization of the GNMT model.  \n",
+    "The instructions to download and setup the required environment for this task are in `README.md` (located in the current directory).  \n",
+    "Download the pretrained model using the command:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Uncomment the line below to download the pretrained model:\n",
+    "#! wget https://zenodo.org/record/2581623/files/model_best.pth"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "At this point, you should have everything ready to start quantizing!"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Preparing The Model For Quantization\n",
+    "\n",
+    "In order to be able to fully quantize the model, we modify it according to the instructions laid out in the Distiller [documentation](https://nervanasystems.github.io/distiller/prepare_model_quant.html). This mostly amounts to making sure every quantize-able operation is invoked via a dedicated PyTorch Module. You can compare the code under `seq2seq/models` in this example with the [original](https://github.com/mlperf/training/tree/master/rnn_translator/pytorch/seq2seq/models).\n",
+    "\n",
+    "For example, in `seq2seq/models/attention.py`, we added the following code to the `__init__` function of the `BahdanauAttention` class:\n",
+    "\n",
+    "```python\n",
+    "# Adding submodules for basic ops to allow quantization:\n",
+    "self.eltwiseadd_qk = EltwiseAdd()\n",
+    "self.eltwiseadd_norm_bias = EltwiseAdd()\n",
+    "self.eltwisemul_norm_scaler = EltwiseMult()\n",
+    "self.matmul_score = Matmul()\n",
+    "self.context_matmul = BatchMatmul()\n",
+    "```\n",
+    "\n",
+    "We're creating modules for operations that were invoked directly in the `forward` function in the original code. This enables Distiller to detect these operations and replace them with quantized counterparts.\n"
+   ]
+  },
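+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As a minimal sketch (the variable names here are illustrative, not the exact code in `attention.py`), the `forward` function then routes the computation through these submodules instead of using raw tensor operators:\n",
+    "\n",
+    "```python\n",
+    "# Before: a direct tensor op, invisible to Distiller's module-level machinery\n",
+    "# sum_qk = att_query + att_keys\n",
+    "# After: the same computation, but through a dedicated module\n",
+    "sum_qk = self.eltwiseadd_qk(att_query, att_keys)\n",
+    "```"
+   ]
+  },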
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Loading the model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "import distiller\n",
+    "from distiller.modules import DistillerLSTM\n",
+    "from distiller.quantization import PostTrainLinearQuantizer\n",
+    "from ast import literal_eval\n",
+    "from itertools import zip_longest\n",
+    "from copy import deepcopy\n",
+    "\n",
+    "from seq2seq import models\n",
+    "from seq2seq.inference.inference import Translator\n",
+    "from seq2seq.utils import AverageMeter\n",
+    "import subprocess\n",
+    "import os\n",
+    "import seq2seq.data.config as config\n",
+    "from seq2seq.data.dataset import ParallelDataset\n",
+    "import logging\n",
+    "from seq2seq.utils import AverageMeter\n",
+    "# Import utilities from the example:\n",
+    "from translate import grouper, write_output, checkpoint_from_distributed, unwrap_distributed\n",
+    "from itertools import takewhile\n",
+    "from tqdm import tqdm\n",
+    "import logging\n",
+    "logging.disable(logging.INFO)  # Disables mlperf output"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Define some constants\n",
+    "batch_first=True\n",
+    "batch_size=128\n",
+    "beam_size=10\n",
+    "cov_penalty_factor=0.1\n",
+    "dataset_dir='./data'\n",
+    "input='./data/newstest2014.tok.clean.bpe.32000.en'\n",
+    "len_norm_const=5.0\n",
+    "len_norm_factor=0.6\n",
+    "max_seq_len=80\n",
+    "model='model_best.pth'\n",
+    "output='output_file'\n",
+    "print_freq=1\n",
+    "reference='./data/newstest2014.de'"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Loading the model\n",
+    "checkpoint = torch.load('./model_best.pth', map_location={'cuda:0': 'cpu'})\n",
+    "vocab_size = checkpoint['tokenizer'].vocab_size\n",
+    "model_config = dict(vocab_size=vocab_size, math=checkpoint['config'].math,\n",
+    "                    **literal_eval(checkpoint['config'].model_config))\n",
+    "model_config['batch_first'] = batch_first\n",
+    "model = models.GNMT(**model_config)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "GNMT(\n",
+       "  (encoder): ResidualRecurrentEncoder(\n",
+       "    (rnn_layers): ModuleList(\n",
+       "      (0): LSTM(1024, 1024, batch_first=True, bidirectional=True)\n",
+       "      (1): LSTM(2048, 1024, batch_first=True)\n",
+       "      (2): LSTM(1024, 1024, batch_first=True)\n",
+       "      (3): LSTM(1024, 1024, batch_first=True)\n",
+       "    )\n",
+       "    (dropout): Dropout(p=0.2)\n",
+       "    (embedder): Embedding(32317, 1024, padding_idx=0)\n",
+       "    (eltwiseadd_residuals): ModuleList(\n",
+       "      (0): EltwiseAdd()\n",
+       "      (1): EltwiseAdd()\n",
+       "    )\n",
+       "  )\n",
+       "  (decoder): ResidualRecurrentDecoder(\n",
+       "    (att_rnn): RecurrentAttention(\n",
+       "      (rnn): LSTM(1024, 1024, batch_first=True)\n",
+       "      (attn): BahdanauAttention(\n",
+       "        (linear_q): Linear(in_features=1024, out_features=1024, bias=False)\n",
+       "        (linear_k): Linear(in_features=1024, out_features=1024, bias=False)\n",
+       "        (dropout): Dropout(p=0)\n",
+       "        (eltwiseadd_qk): EltwiseAdd()\n",
+       "        (eltwiseadd_norm_bias): EltwiseAdd()\n",
+       "        (eltwisemul_norm_scaler): EltwiseMult()\n",
+       "        (tanh): Tanh()\n",
+       "        (matmul_score): Matmul()\n",
+       "        (softmax_att): Softmax()\n",
+       "        (context_matmul): BatchMatmul()\n",
+       "      )\n",
+       "      (dropout): Dropout(p=0)\n",
+       "    )\n",
+       "    (rnn_layers): ModuleList(\n",
+       "      (0): LSTM(2048, 1024, batch_first=True)\n",
+       "      (1): LSTM(2048, 1024, batch_first=True)\n",
+       "      (2): LSTM(2048, 1024, batch_first=True)\n",
+       "    )\n",
+       "    (embedder): Embedding(32317, 1024, padding_idx=0)\n",
+       "    (classifier): Classifier(\n",
+       "      (classifier): Linear(in_features=1024, out_features=32317, bias=True)\n",
+       "    )\n",
+       "    (dropout): Dropout(p=0.2)\n",
+       "    (eltwiseadd_residuals): ModuleList(\n",
+       "      (0): EltwiseAdd()\n",
+       "      (1): EltwiseAdd()\n",
+       "    )\n",
+       "    (attention_concats): ModuleList(\n",
+       "      (0): Concat()\n",
+       "      (1): Concat()\n",
+       "      (2): Concat()\n",
+       "    )\n",
+       "  )\n",
+       ")"
+      ]
+     },
+     "execution_count": 5,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "state_dict = checkpoint['state_dict']\n",
+    "if checkpoint_from_distributed(state_dict):\n",
+    "    state_dict = unwrap_distributed(state_dict)\n",
+    "\n",
+    "model.load_state_dict(state_dict)\n",
+    "torch.cuda.set_device(0)\n",
+    "model = model.cuda()\n",
+    "model.eval()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Evaulation of the model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "tokenizer = checkpoint['tokenizer']\n",
+    "\n",
+    "\n",
+    "test_data = ParallelDataset(\n",
+    "    src_fname=os.path.join(dataset_dir, config.SRC_TEST_FNAME),\n",
+    "    tgt_fname=os.path.join(dataset_dir, config.TGT_TEST_FNAME),\n",
+    "    tokenizer=tokenizer,\n",
+    "    min_len=0,\n",
+    "    max_len=150,\n",
+    "    sort=False)\n",
+    "\n",
+    "def get_loader():\n",
+    "    return test_data.get_loader(batch_size=batch_size,\n",
+    "                                   batch_first=True,\n",
+    "                                   shuffle=False,\n",
+    "                                   num_workers=0,\n",
+    "                                   drop_last=False,\n",
+    "                                   distributed=False)\n",
+    "def get_translator(model):\n",
+    "    return Translator(model,\n",
+    "                       tokenizer,\n",
+    "                       beam_size=beam_size,\n",
+    "                       max_seq_len=max_seq_len,\n",
+    "                       len_norm_factor=len_norm_factor,\n",
+    "                       len_norm_const=len_norm_const,\n",
+    "                       cov_penalty_factor=cov_penalty_factor,\n",
+    "                       cuda=True)\n",
+    "torch.cuda.empty_cache()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def evaluate(model, test_path):\n",
+    "    test_file = open(test_path, 'w', encoding='UTF-8')\n",
+    "    model.eval()\n",
+    "    translator = get_translator(model)\n",
+    "    stats = {}\n",
+    "    iterations = enumerate(tqdm(get_loader()))\n",
+    "    for i, (src, tgt, indices) in iterations:\n",
+    "        src, src_length = src\n",
+    "        if translator.batch_first:\n",
+    "            batch_size = src.size(0)\n",
+    "        else:\n",
+    "            batch_size = src.size(1)\n",
+    "        bos = [translator.insert_target_start] * (batch_size * beam_size)\n",
+    "        bos = torch.LongTensor(bos)\n",
+    "        if translator.batch_first:\n",
+    "            bos = bos.view(-1, 1)\n",
+    "        else:\n",
+    "            bos = bos.view(1, -1)\n",
+    "        src_length = torch.LongTensor(src_length)\n",
+    "        stats['total_enc_len'] = int(src_length.sum())\n",
+    "        src = src.cuda()\n",
+    "        src_length = src_length.cuda()\n",
+    "        bos = bos.cuda()\n",
+    "        with torch.no_grad():\n",
+    "            context = translator.model.encode(src, src_length)\n",
+    "            context = [context, src_length, None]\n",
+    "            if beam_size == 1:\n",
+    "                generator = translator.generator.greedy_search\n",
+    "            else:\n",
+    "                generator = translator.generator.beam_search\n",
+    "            preds, lengths, counter = generator(batch_size, bos, context)\n",
+    "        stats['total_dec_len'] = lengths.sum().item()\n",
+    "        stats['iters'] = counter\n",
+    "        preds = preds.cpu()\n",
+    "        lengths = lengths.cpu()\n",
+    "        output = []\n",
+    "        for idx, pred in enumerate(preds):\n",
+    "            end = lengths[idx] - 1\n",
+    "            pred = pred[1: end]\n",
+    "            pred = pred.tolist()\n",
+    "            out = translator.tok.detokenize(pred)\n",
+    "            output.append(out)\n",
+    "        output = [output[indices.index(i)] for i in range(len(output))]\n",
+    "        for line in output:\n",
+    "            test_file.write(line)\n",
+    "            test_file.write('\\n')\n",
+    "        total_tokens = stats['total_dec_len'] + stats['total_enc_len']\n",
+    "    test_file.close()\n",
+    "    # run moses detokenizer\n",
+    "    detok_path = os.path.join(dataset_dir, config.DETOKENIZER)\n",
+    "    detok_test_path = test_path + '.detok'\n",
+    "\n",
+    "    with open(detok_test_path, 'w') as detok_test_file, \\\n",
+    "            open(test_path, 'r') as test_file:\n",
+    "        subprocess.run(['perl', detok_path], stdin=test_file,\n",
+    "                       stdout=detok_test_file, stderr=subprocess.DEVNULL)\n",
+    "    # run sacrebleu\n",
+    "    reference_path = os.path.join(dataset_dir,\n",
+    "                                  config.TGT_TEST_TARGET_FNAME)\n",
+    "    sacrebleu = subprocess.run(['sacrebleu --input {} {} --score-only -lc --tokenize intl'.\n",
+    "                                format(detok_test_path, reference_path)],\n",
+    "                               stdout=subprocess.PIPE, shell=True)\n",
+    "    bleu = float(sacrebleu.stdout.strip())\n",
+    "    print('BLEU on test dataset: {}'.format(bleu))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "100%|██████████| 24/24 [00:54<00:00,  2.03s/it]\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "BLEU on test dataset: 22.16\n"
+     ]
+    }
+   ],
+   "source": [
+    "evaluate(model, output)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Quantizing the model"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As we already noted, we modified the model from `mlperf` to a modular implementation so we can quantize each and every operation in the graph.  \n",
+    "However, the default `nn.LSTM` was implemented in C++/CUDA, and we don't have usual access to it's operations hence we can't quantize it properly. This is why we'll convert the `nn.LSTM` to a `DistillerLSTM`, which is an entirely modular implementation of the LSTM - identical in functionality to the original `nn.LSTM`.  \n",
+    "This is done by simply calling `DistillerLSTM.from_pytorch_impl` for a single `nn.LSTM` and  \n",
+    "`convert_model_to_distiller_lstm` for an entire model containing multiple different LSTMs.\n"
+   ]
+  },
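+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "A minimal sketch of converting a single LSTM (the constructor arguments are illustrative):\n",
+    "\n",
+    "```python\n",
+    "lstm = nn.LSTM(1024, 1024, batch_first=True)\n",
+    "# Returns a modular LSTM carrying over the trained weights\n",
+    "modular_lstm = DistillerLSTM.from_pytorch_impl(lstm)\n",
+    "```\n",
+    "\n",
+    "Below we convert the entire model at once and verify that the BLEU score is unchanged:"
+   ]
+  },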
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "100%|██████████| 24/24 [02:21<00:00,  5.05s/it]\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "BLEU on test dataset: 22.16\n"
+     ]
+    }
+   ],
+   "source": [
+    "from distiller.modules import convert_model_to_distiller_lstm\n",
+    "model = convert_model_to_distiller_lstm(model)\n",
+    "evaluate(model, output)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Collecting the statistics"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The quantizer uses statistics to define the range of the quantization. We collect these statistics using a `QuantCalibrationStatsCollector` instance like this:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "100%|██████████| 24/24 [58:48<00:00, 130.39s/it]\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "BLEU on test dataset: 22.16\n"
+     ]
+    }
+   ],
+   "source": [
+    "import os\n",
+    "from distiller.data_loggers import QuantCalibrationStatsCollector, collector_context\n",
+    "\n",
+    "stats_file = './model_stats.yaml'\n",
+    "\n",
+    "if not os.path.isfile(stats_file): # Collect stats.\n",
+    "    model_copy = deepcopy(model)\n",
+    "    distiller.utils.assign_layer_fq_names(model_copy)\n",
+    "    collector = QuantCalibrationStatsCollector(model_copy)\n",
+    "    with collector_context(collector):\n",
+    "        val_loss = evaluate(model_copy, output + '.temp')\n",
+    "    collector.save(stats_file)\n",
+    "    del model_copy\n",
+    "    torch.cuda.empty_cache()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Defining the Quantizer"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "A distiller `Quantizer` object replaces each submodule in a model with its quantized counterpart, using a \n",
+    "`replacement_factory`.  \n",
+    "`Quantizer.replacement_factory` is a dictionary which maps from a module type (e.g. `nn.Linear` and `nn.Conv`) to a function. This function takes a module and quantization configuration, and returns a quantized version of the same module."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Replacing 'Matmul' modules using 'replace_non_param_layer' function\n",
+      "Replacing 'EltwiseMult' modules using 'replace_non_param_layer' function\n",
+      "Replacing 'EltwiseAdd' modules using 'replace_non_param_layer' function\n",
+      "Replacing 'BatchMatmul' modules using 'replace_non_param_layer' function\n",
+      "Replacing 'Linear' modules using 'replace_param_layer' function\n",
+      "Replacing 'Embedding' modules using 'replace_embedding' function\n",
+      "Replacing 'Concat' modules using 'replace_non_param_layer' function\n",
+      "Replacing 'Conv2d' modules using 'replace_param_layer' function\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Basic quantizer defintion\n",
+    "quantizer = PostTrainLinearQuantizer(deepcopy(model), \n",
+    "                                    mode=\"SYMMETRIC\",  # As was suggested in GNMT's paper\n",
+    "                                    model_activation_stats=stats_file)\n",
+    "# We take a look at the replacement factory:\n",
+    "for t, rf in quantizer.replacement_factory.items():\n",
+    "    print(\"Replacing '{}' modules using '{}' function\".format(t.__name__, rf.__name__))"
+   ]
+  },
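+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Since `replacement_factory` is a plain dictionary keyed by module type, it can be extended before calling `prepare_model()`. A hedged sketch, assuming a hypothetical custom module `MyEltwiseOp` that behaves like `EltwiseAdd`:\n",
+    "\n",
+    "```python\n",
+    "# Hypothetical: route our own module type to an existing replacement function\n",
+    "quantizer.replacement_factory[MyEltwiseOp] = \\\n",
+    "    quantizer.replacement_factory[distiller.modules.EltwiseAdd]\n",
+    "```"
+   ]
+  },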
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Quantizing the model"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "This is done by simply calling `quantizer.prepare_model()`"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "WARNING: Logging before flag parsing goes to stderr.\n",
+      "W0807 15:01:57.225715 140098610173696 range_linear.py:1063] /home/cvds_lab/lev_z/distiller/distiller/quantization/range_linear.py:1063: UserWarning: Model contains a bidirectional DistillerLSTM module. Automatic BN folding and statistics optimization based on tracing is not yet supported for models containing such modules.\n",
+      "Will perform specific optimization for the DistillerLSTM modules, but any other potential opportunities for optimization in the model will be ignored.\n",
+      "  'opportunities for optimization in the model will be ignored.', UserWarning)\n",
+      "\n",
+      "W0807 15:01:57.301426 140098610173696 quantizer.py:270] /home/cvds_lab/lev_z/distiller/distiller/quantization/quantizer.py:270: UserWarning: Module 'decoder.embedder' references to same module as 'encoder.embedder'. Replacing with reference the same wrapper.\n",
+      "  UserWarning)\n",
+      "\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "GNMT(\n",
+       "  (encoder): ResidualRecurrentEncoder(\n",
+       "    (rnn_layers): ModuleList(\n",
+       "      (0): DistillerLSTM(1024, 1024, num_layers=1, dropout=0.00, bidirectional=True)\n",
+       "      (1): DistillerLSTM(2048, 1024, num_layers=1, dropout=0.00, bidirectional=False)\n",
+       "      (2): DistillerLSTM(1024, 1024, num_layers=1, dropout=0.00, bidirectional=False)\n",
+       "      (3): DistillerLSTM(1024, 1024, num_layers=1, dropout=0.00, bidirectional=False)\n",
+       "    )\n",
+       "    (dropout): Dropout(p=0.2)\n",
+       "    (embedder): RangeLinearEmbeddingWrapper(\n",
+       "      (wrapped_module): Embedding(32317, 1024, padding_idx=0)\n",
+       "    )\n",
+       "    (eltwiseadd_residuals): ModuleList(\n",
+       "      (0): RangeLinearQuantEltwiseAddWrapper(\n",
+       "        mode=SYMMETRIC, num_bits_acts=8, num_bits_accum=32, clip_acts=NONE, scale_approx_mult_bits=None\n",
+       "        preset_activation_stats=True\n",
+       "        in_0_scale=127.00788879394531, in_0_zero_point=0.0\n",
+       "        in_1_scale=127.00006103515625, in_1_zero_point=0.0\n",
+       "        out_scale=63.530452728271484, out_zero_point=0.0\n",
+       "        (wrapped_module): EltwiseAdd()\n",
+       "      )\n",
+       "      (1): RangeLinearQuantEltwiseAddWrapper(\n",
+       "        mode=SYMMETRIC, num_bits_acts=8, num_bits_accum=32, clip_acts=NONE, scale_approx_mult_bits=None\n",
+       "        preset_activation_stats=True\n",
+       "        in_0_scale=127.00004577636719, in_0_zero_point=0.0\n",
+       "        in_1_scale=63.530452728271484, in_1_zero_point=0.0\n",
+       "        out_scale=42.43059158325195, out_zero_point=0.0\n",
+       "        (wrapped_module): EltwiseAdd()\n",
+       "      )\n",
+       "    )\n",
+       "  )\n",
+       "  (decoder): ResidualRecurrentDecoder(\n",
+       "    (att_rnn): RecurrentAttention(\n",
+       "      (rnn): DistillerLSTM(1024, 1024, num_layers=1, dropout=0.00, bidirectional=False)\n",
+       "      (attn): BahdanauAttention(\n",
+       "        (linear_q): RangeLinearQuantParamLayerWrapper(\n",
+       "          mode=SYMMETRIC, num_bits_acts=8, num_bits_params=8, num_bits_accum=32, clip_acts=NONE, per_channel_wts=False, scale_approx_mult_bits=None\n",
+       "          preset_activation_stats=True\n",
+       "          w_scale=53.0649, w_zero_point=0.0000\n",
+       "          in_scale=127.0000, in_zero_point=0.0000\n",
+       "          out_scale=4.2200, out_zero_point=0.0000\n",
+       "          (wrapped_module): Linear(in_features=1024, out_features=1024, bias=False)\n",
+       "        )\n",
+       "        (linear_k): RangeLinearQuantParamLayerWrapper(\n",
+       "          mode=SYMMETRIC, num_bits_acts=8, num_bits_params=8, num_bits_accum=32, clip_acts=NONE, per_channel_wts=False, scale_approx_mult_bits=None\n",
+       "          preset_activation_stats=True\n",
+       "          w_scale=40.9795, w_zero_point=0.0000\n",
+       "          in_scale=42.4306, in_zero_point=0.0000\n",
+       "          out_scale=4.5466, out_zero_point=0.0000\n",
+       "          (wrapped_module): Linear(in_features=1024, out_features=1024, bias=False)\n",
+       "        )\n",
+       "        (dropout): Dropout(p=0)\n",
+       "        (eltwiseadd_qk): RangeLinearQuantEltwiseAddWrapper(\n",
+       "          mode=SYMMETRIC, num_bits_acts=8, num_bits_accum=32, clip_acts=NONE, scale_approx_mult_bits=None\n",
+       "          preset_activation_stats=True\n",
+       "          in_0_scale=4.219996452331543, in_0_zero_point=0.0\n",
+       "          in_1_scale=4.546571731567383, in_1_zero_point=0.0\n",
+       "          out_scale=2.974982261657715, out_zero_point=0.0\n",
+       "          (wrapped_module): EltwiseAdd()\n",
+       "        )\n",
+       "        (eltwiseadd_norm_bias): RangeLinearQuantEltwiseAddWrapper(\n",
+       "          mode=SYMMETRIC, num_bits_acts=8, num_bits_accum=32, clip_acts=NONE, scale_approx_mult_bits=None\n",
+       "          preset_activation_stats=True\n",
+       "          in_0_scale=2.974982261657715, in_0_zero_point=0.0\n",
+       "          in_1_scale=109.55178833007812, in_1_zero_point=0.0\n",
+       "          out_scale=2.961179256439209, out_zero_point=0.0\n",
+       "          (wrapped_module): EltwiseAdd()\n",
+       "        )\n",
+       "        (eltwisemul_norm_scaler): RangeLinearQuantEltwiseMultWrapper(\n",
+       "          mode=SYMMETRIC, num_bits_acts=8, num_bits_accum=32, clip_acts=NONE, scale_approx_mult_bits=None\n",
+       "          preset_activation_stats=True\n",
+       "          in_0_scale=771.8616943359375, in_0_zero_point=0.0\n",
+       "          in_1_scale=100.56507873535156, in_1_zero_point=0.0\n",
+       "          out_scale=611.1994018554688, out_zero_point=0.0\n",
+       "          (wrapped_module): EltwiseMult()\n",
+       "        )\n",
+       "        (tanh): Tanh()\n",
+       "        (matmul_score): RangeLinearQuantMatmulWrapper(\n",
+       "          mode=SYMMETRIC, num_bits_acts=8,  num_bits_accum=32, clip_acts=NONE, scale_approx_mult_bits=None\n",
+       "          preset_activation_stats=True\n",
+       "          in_0_scale=127.0000, in_0_zero_point=0.0000\n",
+       "          in_1_scale=611.1994, in_1_zero_point=0.0000\n",
+       "          out_scale=6.3274, out_zero_point=0.0000\n",
+       "          (wrapped_module): Matmul()\n",
+       "        )\n",
+       "        (softmax_att): Softmax()\n",
+       "        (context_matmul): RangeLinearQuantMatmulWrapper(\n",
+       "          mode=SYMMETRIC, num_bits_acts=8,  num_bits_accum=32, clip_acts=NONE, scale_approx_mult_bits=None\n",
+       "          preset_activation_stats=True\n",
+       "          in_0_scale=128.1420, in_0_zero_point=0.0000\n",
+       "          in_1_scale=42.4306, in_1_zero_point=0.0000\n",
+       "          out_scale=46.2056, out_zero_point=0.0000\n",
+       "          (wrapped_module): BatchMatmul()\n",
+       "        )\n",
+       "      )\n",
+       "      (dropout): Dropout(p=0)\n",
+       "    )\n",
+       "    (rnn_layers): ModuleList(\n",
+       "      (0): DistillerLSTM(2048, 1024, num_layers=1, dropout=0.00, bidirectional=False)\n",
+       "      (1): DistillerLSTM(2048, 1024, num_layers=1, dropout=0.00, bidirectional=False)\n",
+       "      (2): DistillerLSTM(2048, 1024, num_layers=1, dropout=0.00, bidirectional=False)\n",
+       "    )\n",
+       "    (embedder): RangeLinearEmbeddingWrapper(\n",
+       "      (wrapped_module): Embedding(32317, 1024, padding_idx=0)\n",
+       "    )\n",
+       "    (classifier): Classifier(\n",
+       "      (classifier): RangeLinearQuantParamLayerWrapper(\n",
+       "        mode=SYMMETRIC, num_bits_acts=8, num_bits_params=8, num_bits_accum=32, clip_acts=NONE, per_channel_wts=False, scale_approx_mult_bits=None\n",
+       "        preset_activation_stats=True\n",
+       "        w_scale=40.8311, w_zero_point=0.0000\n",
+       "        in_scale=42.8144, in_zero_point=0.0000\n",
+       "        out_scale=6.3242, out_zero_point=0.0000\n",
+       "        (wrapped_module): Linear(in_features=1024, out_features=32317, bias=True)\n",
+       "      )\n",
+       "    )\n",
+       "    (dropout): Dropout(p=0.2)\n",
+       "    (eltwiseadd_residuals): ModuleList(\n",
+       "      (0): RangeLinearQuantEltwiseAddWrapper(\n",
+       "        mode=SYMMETRIC, num_bits_acts=8, num_bits_accum=32, clip_acts=NONE, scale_approx_mult_bits=None\n",
+       "        preset_activation_stats=True\n",
+       "        in_0_scale=127.01639556884766, in_0_zero_point=0.0\n",
+       "        in_1_scale=127.00018310546875, in_1_zero_point=0.0\n",
+       "        out_scale=63.68953323364258, out_zero_point=0.0\n",
+       "        (wrapped_module): EltwiseAdd()\n",
+       "      )\n",
+       "      (1): RangeLinearQuantEltwiseAddWrapper(\n",
+       "        mode=SYMMETRIC, num_bits_acts=8, num_bits_accum=32, clip_acts=NONE, scale_approx_mult_bits=None\n",
+       "        preset_activation_stats=True\n",
+       "        in_0_scale=127.00001525878906, in_0_zero_point=0.0\n",
+       "        in_1_scale=63.68953323364258, in_1_zero_point=0.0\n",
+       "        out_scale=42.81440353393555, out_zero_point=0.0\n",
+       "        (wrapped_module): EltwiseAdd()\n",
+       "      )\n",
+       "    )\n",
+       "    (attention_concats): ModuleList(\n",
+       "      (0): RangeLinearQuantConcatWrapper(\n",
+       "        mode=SYMMETRIC, num_bits_acts=8, num_bits_accum=32, clip_acts=NONE, scale_approx_mult_bits=None\n",
+       "        preset_activation_stats=True\n",
+       "        in_0_scale=127.0, in_0_zero_point=0.0\n",
+       "        in_1_scale=46.20560836791992, in_1_zero_point=0.0\n",
+       "        out_scale=46.20560836791992, out_zero_point=0.0\n",
+       "        (wrapped_module): Concat()\n",
+       "      )\n",
+       "      (1): RangeLinearQuantConcatWrapper(\n",
+       "        mode=SYMMETRIC, num_bits_acts=8, num_bits_accum=32, clip_acts=NONE, scale_approx_mult_bits=None\n",
+       "        preset_activation_stats=True\n",
+       "        in_0_scale=127.00018310546875, in_0_zero_point=0.0\n",
+       "        in_1_scale=46.20560836791992, in_1_zero_point=0.0\n",
+       "        out_scale=46.20560836791992, out_zero_point=0.0\n",
+       "        (wrapped_module): Concat()\n",
+       "      )\n",
+       "      (2): RangeLinearQuantConcatWrapper(\n",
+       "        mode=SYMMETRIC, num_bits_acts=8, num_bits_accum=32, clip_acts=NONE, scale_approx_mult_bits=None\n",
+       "        preset_activation_stats=True\n",
+       "        in_0_scale=63.68953323364258, in_0_zero_point=0.0\n",
+       "        in_1_scale=46.20560836791992, in_1_zero_point=0.0\n",
+       "        out_scale=46.20560836791992, out_zero_point=0.0\n",
+       "        (wrapped_module): Concat()\n",
+       "      )\n",
+       "    )\n",
+       "  )\n",
+       ")"
+      ]
+     },
+     "execution_count": 12,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "dummy_input = (torch.ones(1, 2).to(dtype=torch.long),\n",
+    "               torch.ones(1).to(dtype=torch.long),\n",
+    "               torch.ones(1, 2).to(dtype=torch.long))\n",
+    "quantizer.prepare_model(dummy_input)\n",
+    "quantizer.model"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "If you'd like to know how these functions replace the modules - I recommend reading the source code for them in  \n",
+    "`{DISTILLER_ROOT}/distiller/quantization/range_linear.py:PostTrainLinearQuantizer`."
+   ]
+  },
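+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "A quick sanity check of the scale factors printed above (assuming the convention $scale = 127 / M$ for 8-bit symmetric quantization, where $M$ is the recorded maximum absolute activation): a `tanh` output is bounded in $[-1, 1]$, so $M \\approx 1$ and the scale is about $127$ - which matches the `in_0_scale=127.0` values shown on several of the wrappers."
+   ]
+  },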
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Evaluating the quantized model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "100%|██████████| 24/24 [13:49<00:00, 31.91s/it]\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "BLEU on test dataset: 18.05\n"
+     ]
+    }
+   ],
+   "source": [
+    "#torch.cuda.empty_cache()\n",
+    "evaluate(quantizer.model, output)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Finding the right quantization"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As we can see here, we quantized our model entirely and it lost some accuracy, so we want to apply more strategies to quantize better.  \n",
+    "Symmetric quantization means our range is the biggest we can hold our activations in:\n",
+    "$$\n",
+    "    M = \\max \\{ |\\text{acts}|\\},\\, \\text{range}_{symmetric} = [-M, M]\n",
+    "$$\n",
+    "This way we waste resolution. However, if we use assymetric quantization - we may get better results:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "W0807 15:15:53.630417 140098610173696 range_linear.py:1063] /home/cvds_lab/lev_z/distiller/distiller/quantization/range_linear.py:1063: UserWarning: Model contains a bidirectional DistillerLSTM module. Automatic BN folding and statistics optimization based on tracing is not yet supported for models containing such modules.\n",
+      "Will perform specific optimization for the DistillerLSTM modules, but any other potential opportunities for optimization in the model will be ignored.\n",
+      "  'opportunities for optimization in the model will be ignored.', UserWarning)\n",
+      "\n",
+      "W0807 15:15:53.755373 140098610173696 quantizer.py:270] /home/cvds_lab/lev_z/distiller/distiller/quantization/quantizer.py:270: UserWarning: Module 'decoder.embedder' references to same module as 'encoder.embedder'. Replacing with reference the same wrapper.\n",
+      "  UserWarning)\n",
+      "\n",
+      "100%|██████████| 24/24 [14:25<00:00, 28.42s/it]\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "BLEU on test dataset: 18.52\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Basic quantizer defintion\n",
+    "quantizer = PostTrainLinearQuantizer(deepcopy(model), \n",
+    "                                    mode=\"ASYMMETRIC_SIGNED\",  \n",
+    "                                    model_activation_stats=stats_file)\n",
+    "quantizer.prepare_model()\n",
+    "evaluate(quantizer.model, output)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Here - we quantized asymmetrically, meaning our range still holds all the activations, but it's smaller than in the symmetrical case.  \n",
+    "The formula is:\n",
+    "$$\n",
+    "    \\text{range}_{asymmetric} = \\left[\\min\\{ \\text{acts}\\}, \\max \\{ \\text{acts}\\}\\right] \n",
+    "    \\subset \\text{range}_{symmetric}\n",
+    "$$\n",
+    "And we indeed got a slightly better result.  \n",
+    "However - some part of the activations during the evaluations are outliers, meaning they are way outside the range of most of their buddies. We're going to intercept this in two ways -\n",
+    "1. Quantize each channel separately, that way we achieve more accuracy. We'll add the argument `per_channel_wts=True`.\n",
+    "2. Limit the quantization range to a smaller one, thus clamping these outliers."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We'll try using the same technique as in `quantize_lstm.ipynb` - clipping the activations according to the average range recorded for them:  \n",
+    "\n",
+    "\n",
+    "$$\n",
+    "    m = \\underset{b\\in\\text{batches}}{\\text{avg}}\\left\\{\\min_{b}\\{\\text{acts}\\}\\right\\},\\,\n",
+    "    M = \\underset{b\\in\\text{batches}}{\\text{avg}}\\left\\{\\max_{b}\\{\\text{acts}\\}\\right\\}\n",
+    "$$\n",
+    "\n",
+    "\n",
+    "$$\n",
+    "    \\text{range}_{clipped} = [m,M] \\subset \\text{range}_{asymmetric} \\subset \\text{range}_{symmetric}\n",
+    "$$\n",
+    "\n",
+    "This is done by specifying `clip_acts=\"AVG\"` in the quantizer. "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "W0807 15:30:25.722338 140098610173696 range_linear.py:1063] /home/cvds_lab/lev_z/distiller/distiller/quantization/range_linear.py:1063: UserWarning: Model contains a bidirectional DistillerLSTM module. Automatic BN folding and statistics optimization based on tracing is not yet supported for models containing such modules.\n",
+      "Will perform specific optimization for the DistillerLSTM modules, but any other potential opportunities for optimization in the model will be ignored.\n",
+      "  'opportunities for optimization in the model will be ignored.', UserWarning)\n",
+      "\n",
+      "W0807 15:30:26.860130 140098610173696 quantizer.py:270] /home/cvds_lab/lev_z/distiller/distiller/quantization/quantizer.py:270: UserWarning: Module 'decoder.embedder' references to same module as 'encoder.embedder'. Replacing with reference the same wrapper.\n",
+      "  UserWarning)\n",
+      "\n",
+      "100%|██████████| 24/24 [13:58<00:00, 29.49s/it]\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "BLEU on test dataset: 9.63\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Basic quantizer defintion\n",
+    "quantizer = PostTrainLinearQuantizer(deepcopy(model), \n",
+    "                                    mode=\"ASYMMETRIC_SIGNED\",  \n",
+    "                                    model_activation_stats=stats_file,\n",
+    "                                    per_channel_wts=True,\n",
+    "                                    clip_acts=\"AVG\")\n",
+    "quantizer.prepare_model()\n",
+    "evaluate(quantizer.model, output)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Oh no! This is bad... turns out that by clamping the outliers we actually \"removed\" useful features from important layers like the attention layer. In the attention layer we have a softmax which relies on high values to pass a correct score of importance of features. Let's try clipping all the other values, except in the attention layer:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "W0807 15:44:32.041999 140098610173696 range_linear.py:1063] /home/cvds_lab/lev_z/distiller/distiller/quantization/range_linear.py:1063: UserWarning: Model contains a bidirectional DistillerLSTM module. Automatic BN folding and statistics optimization based on tracing is not yet supported for models containing such modules.\n",
+      "Will perform specific optimization for the DistillerLSTM modules, but any other potential opportunities for optimization in the model will be ignored.\n",
+      "  'opportunities for optimization in the model will be ignored.', UserWarning)\n",
+      "\n",
+      "W0807 15:44:33.205797 140098610173696 quantizer.py:270] /home/cvds_lab/lev_z/distiller/distiller/quantization/quantizer.py:270: UserWarning: Module 'decoder.embedder' references to same module as 'encoder.embedder'. Replacing with reference the same wrapper.\n",
+      "  UserWarning)\n",
+      "\n",
+      "100%|██████████| 24/24 [13:54<00:00, 32.72s/it]\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "BLEU on test dataset: 16.94\n"
+     ]
+    }
+   ],
+   "source": [
+    "# No clipping in the attention layer\n",
+    "overrides_yaml = \"\"\"\n",
+    ".*att_rnn.attn.*:\n",
+    "    clip_acts: NONE # Quantize without clipping\n",
+    "\"\"\"\n",
+    "overrides = distiller.utils.yaml_ordered_load(overrides_yaml)\n",
+    "# Basic quantizer defintion\n",
+    "quantizer = PostTrainLinearQuantizer(deepcopy(model), \n",
+    "                                    mode=\"ASYMMETRIC_SIGNED\",  \n",
+    "                                    model_activation_stats=stats_file,\n",
+    "                                    overrides=overrides,\n",
+    "                                    per_channel_wts=True,\n",
+    "                                    clip_acts=\"AVG\")\n",
+    "quantizer.prepare_model()\n",
+    "evaluate(quantizer.model, output)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The accuracy is somewhat \"restored\", by still we would like to get a score as close to the original model as possible. How about leaving the `classifier` asymmetric, without clipping it?"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "W0807 15:58:34.769241 140098610173696 range_linear.py:1063] /home/cvds_lab/lev_z/distiller/distiller/quantization/range_linear.py:1063: UserWarning: Model contains a bidirectional DistillerLSTM module. Automatic BN folding and statistics optimization based on tracing is not yet supported for models containing such modules.\n",
+      "Will perform specific optimization for the DistillerLSTM modules, but any other potential opportunities for optimization in the model will be ignored.\n",
+      "  'opportunities for optimization in the model will be ignored.', UserWarning)\n",
+      "\n",
+      "W0807 15:58:35.894991 140098610173696 quantizer.py:270] /home/cvds_lab/lev_z/distiller/distiller/quantization/quantizer.py:270: UserWarning: Module 'decoder.embedder' references to same module as 'encoder.embedder'. Replacing with reference the same wrapper.\n",
+      "  UserWarning)\n",
+      "\n",
+      "100%|██████████| 24/24 [13:14<00:00, 28.11s/it]\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "BLEU on test dataset: 21.49\n"
+     ]
+    }
+   ],
+   "source": [
+    "# No clipping in the attention layer and in the final classifier\n",
+    "overrides_yaml = \"\"\"\n",
+    ".*att_rnn.attn.*:\n",
+    "    clip_acts: NONE # Quantize without clipping\n",
+    "decoder.classifier.classifier:\n",
+    "    clip_acts: NONE # Quantize without clipping\n",
+    "\"\"\"\n",
+    "overrides = distiller.utils.yaml_ordered_load(overrides_yaml)\n",
+    "# Basic quantizer defintion\n",
+    "quantizer = PostTrainLinearQuantizer(deepcopy(model), \n",
+    "                                    mode=\"ASYMMETRIC_SIGNED\",  \n",
+    "                                    model_activation_stats=stats_file,\n",
+    "                                    overrides=overrides,\n",
+    "                                    per_channel_wts=True,\n",
+    "                                    clip_acts=\"AVG\")\n",
+    "quantizer.prepare_model()\n",
+    "evaluate(quantizer.model, output)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Finally, some good results! So now we know better which layers are sensitive to clipping and which are complimented by it.  "
+   ]
+  },
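+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Summary of the runs in this notebook:\n",
+    "\n",
+    "| Configuration | BLEU |\n",
+    "|---|---|\n",
+    "| FP32 baseline (unchanged after `DistillerLSTM` conversion) | 22.16 |\n",
+    "| INT8, symmetric | 18.05 |\n",
+    "| INT8, asymmetric-signed | 18.52 |\n",
+    "| + per-channel weights, `clip_acts=\"AVG\"` everywhere | 9.63 |\n",
+    "| + no clipping in the attention layer | 16.94 |\n",
+    "| + no clipping in the attention layer and classifier | 21.49 |"
+   ]
+  },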
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# References\n",
+    "\n",
+    "<a id=\"cite-wu2016google\"/><sup><a href=#ref-1>[^]</a></sup>Wu, Yonghui and Schuster, Mike and Chen, Zhifeng and Le, Quoc V and Norouzi, Mohammad and Macherey, Wolfgang and Krikun, Maxim and Cao, Yuan and Gao, Qin and Macherey, Klaus and others. 2016. _Google's neural machine translation system: Bridging the gap between human and machine translation_.\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "<!--bibtex\n",
+    "\n",
+    "@article{wu2016google,\n",
+    "  title={Google's neural machine translation system: Bridging the gap between human and machine translation},\n",
+    "  author={Wu, Yonghui and Schuster, Mike and Chen, Zhifeng and Le, Quoc V and Norouzi, Mohammad and Macherey, Wolfgang and Krikun, Maxim and Cao, Yuan and Gao, Qin and Macherey, Klaus and others},\n",
+    "  journal={arXiv preprint arXiv:1609.08144},\n",
+    "  year={2016}\n",
+    "}\n",
+    "\n",
+    "-->"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.5.2"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/examples/GNMT/requirements.txt b/examples/GNMT/requirements.txt
new file mode 100644
index 0000000..5cf238b
--- /dev/null
+++ b/examples/GNMT/requirements.txt
@@ -0,0 +1,3 @@
+sacrebleu==1.2.10
+# numpy==1.14.2
+# mlperf-compliance==0.0.4
diff --git a/examples/GNMT/scripts/filter_dataset.py b/examples/GNMT/scripts/filter_dataset.py
new file mode 100644
index 0000000..3168d47
--- /dev/null
+++ b/examples/GNMT/scripts/filter_dataset.py
@@ -0,0 +1,79 @@
+import argparse
+from collections import Counter
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Clean dataset')
+    parser.add_argument('-f1', '--file1', help='file1')
+    parser.add_argument('-f2', '--file2', help='file2')
+    return parser.parse_args()
+
+
+def save_output(fname, data):
+    with open(fname, 'w') as f:
+        f.writelines(data)
+
+
+def main():
+    """
+    Discards all sentence pairs that can't be encoded with the latin-1 codec.
+
+    It aims to filter out sentences with rare unicode glyphs and pairs which
+    are most likely not valid English-German sentences.
+
+    Examples of discarded sentences:
+
+        ✿★★★Hommage au king de la pop ★★★✿ ✿★★★Que son âme repos...
+
+        Для их осуществления нам, прежде всего, необходимо преодолеть
+        возражения рыночных фундаменталистов, которые хотят ликвидировать или
+        уменьшить роль МВФ.
+
+        practised as a scientist in various medical departments of the ⇗Medical
+        University of Hanover , the ⇗University of Ulm , and the ⇗RWTH Aachen
+        (rheumatology, pharmacology, physiology, pathology, microbiology,
+        immunology and electron-microscopy).
+
+        The same shift】 and press 【】 【alt out with a smaller diameter
+        circle.
+
+        Brought to you by ABMSUBS ♥leira(Coordinator/Translator)
+        ♥chibichan93(Timer/Typesetter) ♥ja...
+
+        Some examples: &0u - ☺ &0U - ☻ &tel - ☏ &PI - ¶ &SU - ☼ &cH- - ♥ &M2=♫
+        &sn - ﺵ SGML maps SGML to unicode.
+    """
+    args = parse_args()
+
+    c = Counter()
+    skipped = 0
+    valid = 0
+    data1 = []
+    data2 = []
+
+    with open(args.file1) as f1, open(args.file2) as f2:
+        for idx, lines in enumerate(zip(f1, f2)):
+            line1, line2 = lines
+            if idx % 100000 == 1:
+                print('Processed {} lines'.format(idx))
+            try:
+                line1.encode('latin1')
+                line2.encode('latin1')
+            except UnicodeEncodeError:
+                skipped += 1
+            else:
+                data1.append(line1)
+                data2.append(line2)
+                valid += 1
+                c.update(line1)
+
+    ratio = valid / (skipped + valid)
+    print('Skipped: {}, Valid: {}, Valid ratio {}'.format(skipped, valid, ratio))
+    print('Character frequency:', c)
+
+    save_output(args.file1, data1)
+    save_output(args.file2, data2)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/examples/GNMT/seq2seq/__init__.py b/examples/GNMT/seq2seq/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/examples/GNMT/seq2seq/data/__init__.py b/examples/GNMT/seq2seq/data/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/examples/GNMT/seq2seq/data/config.py b/examples/GNMT/seq2seq/data/config.py
new file mode 100644
index 0000000..69af2a1
--- /dev/null
+++ b/examples/GNMT/seq2seq/data/config.py
@@ -0,0 +1,21 @@
+PAD_TOKEN = '<pad>'
+UNK_TOKEN = '<unk>'
+BOS_TOKEN = '<s>'
+EOS_TOKEN = '<\s>'
+
+PAD, UNK, BOS, EOS = [0, 1, 2, 3]
+
+VOCAB_FNAME = 'vocab.bpe.32000'
+
+SRC_TRAIN_FNAME = 'train.tok.clean.bpe.32000.en'
+TGT_TRAIN_FNAME = 'train.tok.clean.bpe.32000.de'
+
+SRC_VAL_FNAME = 'newstest_dev.tok.clean.bpe.32000.en'
+TGT_VAL_FNAME = 'newstest_dev.tok.clean.bpe.32000.de'
+
+SRC_TEST_FNAME = 'newstest2014.tok.bpe.32000.en'
+TGT_TEST_FNAME = 'newstest2014.tok.bpe.32000.de'
+
+TGT_TEST_TARGET_FNAME = 'newstest2014.de'
+
+DETOKENIZER = 'mosesdecoder/scripts/tokenizer/detokenizer.perl'
diff --git a/examples/GNMT/seq2seq/data/dataset.py b/examples/GNMT/seq2seq/data/dataset.py
new file mode 100644
index 0000000..0bf39b4
--- /dev/null
+++ b/examples/GNMT/seq2seq/data/dataset.py
@@ -0,0 +1,123 @@
+import logging
+
+import torch
+from torch.utils.data import Dataset
+from torch.utils.data.sampler import SequentialSampler, RandomSampler
+from seq2seq.data.sampler import BucketingSampler
+from torch.utils.data import DataLoader
+
+import seq2seq.data.config as config
+
+
+def build_collate_fn(batch_first=False, sort=False):
+    def collate_seq(seq):
+        lengths = [len(s) for s in seq]
+        batch_length = max(lengths)
+
+        shape = (batch_length, len(seq))
+        seq_tensor = torch.full(shape, config.PAD, dtype=torch.int64)
+
+        for i, s in enumerate(seq):
+            end_seq = lengths[i]
+            seq_tensor[:end_seq, i].copy_(s[:end_seq])
+
+        if batch_first:
+            seq_tensor = seq_tensor.t()
+
+        return (seq_tensor, lengths)
+
+    def collate(seqs):
+        src_seqs, tgt_seqs = zip(*seqs)
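+        # When sorting, order by source length (descending) for efficient RNN
+        # processing; 'indices' records the permutation so callers can restore
+        # the original sentence order.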
+        if sort:
+            key = lambda item: len(item[1])
+            indices, src_seqs = zip(*sorted(enumerate(src_seqs), key=key,
+                                        reverse=True))
+            tgt_seqs = [tgt_seqs[idx] for idx in indices]
+        else:
+            indices = range(len(src_seqs))
+
+        return tuple([collate_seq(s) for s in [src_seqs, tgt_seqs]] + [indices])
+
+    return collate
+
+
+class ParallelDataset(Dataset):
+    def __init__(self, src_fname, tgt_fname, tokenizer,
+                 min_len, max_len, sort=False, max_size=None):
+
+        self.min_len = min_len
+        self.max_len = max_len
+
+        self.src = self.process_data(src_fname, tokenizer, max_size)
+        self.tgt = self.process_data(tgt_fname, tokenizer, max_size)
+        assert len(self.src) == len(self.tgt)
+
+        self.filter_data(min_len, max_len)
+        assert len(self.src) == len(self.tgt)
+
+        lengths = [len(s) + len(t) for (s, t) in zip(self.src, self.tgt)]
+        self.lengths = torch.tensor(lengths)
+
+        if sort:
+            self.sort_by_length()
+
+    def sort_by_length(self):
+        self.lengths, indices = self.lengths.sort(descending=True)
+
+        self.src = [self.src[idx] for idx in indices]
+        self.tgt = [self.tgt[idx] for idx in indices]
+
+    def filter_data(self, min_len, max_len):
+        logging.info('filtering data, min len: {}, max len: {}'.format(min_len, max_len))
+
+        initial_len = len(self.src)
+
+        filtered_src = []
+        filtered_tgt = []
+        for src, tgt in zip(self.src, self.tgt):
+            if min_len <= len(src) <= max_len and \
+                    min_len <= len(tgt) <= max_len:
+                filtered_src.append(src)
+                filtered_tgt.append(tgt)
+
+        self.src = filtered_src
+        self.tgt = filtered_tgt
+
+        filtered_len = len(self.src)
+        logging.info('pairs before: {}, after: {}'.format(initial_len, filtered_len))
+
+    def process_data(self, fname, tokenizer, max_size):
+        logging.info('processing data from {}'.format(fname))
+        data = []
+        with open(fname) as dfile:
+            for idx, line in enumerate(dfile):
+                if max_size and idx == max_size:
+                    break
+                entry = tokenizer.segment(line)
+                entry = torch.tensor(entry)
+                data.append(entry)
+        return data
+
+    def __len__(self):
+        return len(self.src)
+
+    def __getitem__(self, idx):
+        return self.src[idx], self.tgt[idx]
+
+    def get_loader(self, batch_size=1, shuffle=False, num_workers=0, batch_first=False,
+                   drop_last=False, distributed=False, bucket=True):
+
+        collate_fn = build_collate_fn(batch_first, sort=True)
+
+        if shuffle:
+            sampler = BucketingSampler(self, batch_size, bucket)
+        else:
+            sampler = SequentialSampler(self)
+
+        return DataLoader(self,
+                          batch_size=batch_size,
+                          collate_fn=collate_fn,
+                          sampler=sampler,
+                          num_workers=num_workers,
+                          pin_memory=False,
+                          drop_last=drop_last)
diff --git a/examples/GNMT/seq2seq/data/sampler.py b/examples/GNMT/seq2seq/data/sampler.py
new file mode 100644
index 0000000..74085c4
--- /dev/null
+++ b/examples/GNMT/seq2seq/data/sampler.py
@@ -0,0 +1,85 @@
+import torch
+from torch.utils.data.sampler import Sampler
+
+from seq2seq.utils import get_world_size, get_rank
+
+
+class BucketingSampler(Sampler):
+
+    def __init__(self, dataset, batch_size, bucket=True, world_size=None, rank=None):
+        if world_size is None:
+            world_size = get_world_size()
+        if rank is None:
+            rank = get_rank()
+
+        self.dataset = dataset
+        self.world_size = world_size
+        self.rank = rank
+        self.epoch = 0
+        self.bucket = bucket
+
+        self.batch_size = batch_size
+        self.global_batch_size = batch_size * world_size
+
+        self.data_len = len(self.dataset)
+        self.num_samples = self.data_len // self.global_batch_size \
+            * self.global_batch_size
+
+    def __iter__(self):
+
+        # deterministically shuffle based on epoch
+        g = torch.Generator()
+        g.manual_seed(self.epoch)
+
+        # generate permutation
+        indices = torch.randperm(self.data_len, generator=g)
+        # make indices evenly divisible by (batch_size * world_size)
+        indices = indices[:self.num_samples]
+
+        if self.bucket:
+            # begin shards
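+            # Bucketing: split the permuted indices into shards, sort each shard
+            # by sentence length so each batch holds similarly-sized sequences,
+            # then shuffle whole batches to keep the epoch order random.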
+            batches_in_shard = 80
+            shard_size = self.global_batch_size * batches_in_shard
+            nshards = (self.num_samples + shard_size - 1) // shard_size
+
+            lengths = self.dataset.lengths[indices]
+
+            shards = [indices[i * shard_size:(i+1) * shard_size] for i in range(nshards)]
+            len_shards = [lengths[i * shard_size:(i+1) * shard_size] for i in range(nshards)]
+
+            indices = []
+            for len_shard in len_shards:
+                _, ind = len_shard.sort()
+                indices.append(ind)
+
+            output = tuple(shard[idx] for shard,idx in zip(shards, indices))
+            indices = torch.cat(output)
+            # global reshuffle
+            indices = indices.view(-1, self.global_batch_size)
+            order = torch.randperm(indices.shape[0], generator=g)
+            indices = indices[order, :]
+            indices = indices.view(-1)
+            # end shards
+
+        assert len(indices) == self.num_samples
+
+        # build indices for each individual worker
+        # ranks are getting consecutive batches,
+        # default pytorch DistributedSampler assigns strided batches
+        # with offset = length / world_size
+        indices = indices.view(-1, self.batch_size)
+        indices = indices[self.rank::self.world_size].contiguous()
+        indices = indices.view(-1)
+        indices = indices.tolist()
+
+        assert len(indices) == self.num_samples // self.world_size
+
+        return iter(indices)
+
+    def __len__(self):
+        return self.num_samples // self.world_size
+
+    def set_epoch(self, epoch):
+        self.epoch = epoch
diff --git a/examples/GNMT/seq2seq/data/tokenizer.py b/examples/GNMT/seq2seq/data/tokenizer.py
new file mode 100644
index 0000000..09b827a
--- /dev/null
+++ b/examples/GNMT/seq2seq/data/tokenizer.py
@@ -0,0 +1,44 @@
+import logging
+from collections import defaultdict
+
+import seq2seq.data.config as config
+
+def default():
+    return config.UNK
+
+class Tokenizer:
+    def __init__(self, vocab_fname, separator='@@'):
+
+        self.separator = separator
+
+        logging.info('building vocabulary from {}'.format(vocab_fname))
+        vocab = [config.PAD_TOKEN, config.UNK_TOKEN,
+                 config.BOS_TOKEN, config.EOS_TOKEN]
+
+        with open(vocab_fname) as vfile:
+            for line in vfile:
+                vocab.append(line.strip())
+
+        logging.info('size of vocabulary: {}'.format(len(vocab)))
+        self.vocab_size = len(vocab)
+
+        self.tok2idx = defaultdict(default)
+        for idx, token in enumerate(vocab):
+            self.tok2idx[token] = idx
+
+        self.idx2tok = {}
+        for key, value in self.tok2idx.items():
+            self.idx2tok[value] = key
+
+    def segment(self, line):
+        line = line.strip().split()
+        entry = [self.tok2idx[i] for i in line]
+        entry = [config.BOS] + entry + [config.EOS]
+        return entry
+
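+    # Maps token indices back to text and strips the BPE separator ('@@').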
+    def detokenize(self, inputs, delim=' '):
+        detok = delim.join([self.idx2tok[idx] for idx in inputs])
+        detok = detok.replace(
+            self.separator + ' ', '').replace(self.separator, '')
+        return detok
diff --git a/examples/GNMT/seq2seq/inference/__init__.py b/examples/GNMT/seq2seq/inference/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/examples/GNMT/seq2seq/inference/beam_search.py b/examples/GNMT/seq2seq/inference/beam_search.py
new file mode 100644
index 0000000..90c46b7
--- /dev/null
+++ b/examples/GNMT/seq2seq/inference/beam_search.py
@@ -0,0 +1,245 @@
+import torch
+
+from seq2seq.data.config import BOS
+from seq2seq.data.config import EOS
+
+
+class SequenceGenerator(object):
+    def __init__(self,
+                 model,
+                 beam_size=5,
+                 max_seq_len=100,
+                 cuda=False,
+                 len_norm_factor=0.6,
+                 len_norm_const=5,
+                 cov_penalty_factor=0.1):
+
+        self.model = model
+        self.cuda = cuda
+        self.beam_size = beam_size
+        self.max_seq_len = max_seq_len
+        self.len_norm_factor = len_norm_factor
+        self.len_norm_const = len_norm_const
+        self.cov_penalty_factor = cov_penalty_factor
+
+        self.batch_first = self.model.batch_first
+
+    def greedy_search(self, batch_size, initial_input, initial_context=None):
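+        """
+        Decodes by taking the single most likely token at every step until
+        EOS or max_seq_len; finished sentences are dropped from the active
+        set. Returns (translation, lengths, counter).
+        """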
+        max_seq_len = self.max_seq_len
+
+        translation = torch.zeros(batch_size, max_seq_len, dtype=torch.int64)
+        lengths = torch.ones(batch_size, dtype=torch.int64)
+        active = torch.arange(0, batch_size, dtype=torch.int64)
+        base_mask = torch.arange(0, batch_size, dtype=torch.int64)
+
+        if self.cuda:
+            translation = translation.cuda()
+            lengths = lengths.cuda()
+            active = active.cuda()
+            base_mask = base_mask.cuda()
+
+        translation[:, 0] = BOS
+        words, context = initial_input, initial_context
+
+        if self.batch_first:
+            word_view = (-1, 1)
+            ctx_batch_dim = 0
+        else:
+            word_view = (1, -1)
+            ctx_batch_dim = 1
+
+        counter = 0
+        for idx in range(1, max_seq_len):
+            if not len(active):
+                break
+            counter += 1
+
+            words = words.view(word_view)
+            words, logprobs, attn, context = self.model.generate(words, context, 1)
+            words = words.view(-1)
+
+            translation[active, idx] = words
+            lengths[active] += 1
+
+            terminating = (words == EOS)
+
+            if terminating.any():
+                not_terminating = ~terminating
+
+                mask = base_mask[:len(active)]
+                mask = mask.masked_select(not_terminating)
+                active = active.masked_select(not_terminating)
+
+                words = words[mask]
+                context[0] = context[0].index_select(ctx_batch_dim, mask)
+                context[1] = context[1].index_select(0, mask)
+                context[2] = context[2].index_select(1, mask)
+
+        return translation, lengths, counter
+
+    def beam_search(self, batch_size, initial_input, initial_context=None):
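+        """
+        Beam search decoding. When a beam terminates, its score is
+        length-normalized and a coverage penalty over the accumulated
+        attention scores is added (as in the GNMT paper,
+        arxiv.org/abs/1609.08144). Returns (translation, lengths, counter)
+        for the best beam of each sentence.
+        """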
+        beam_size = self.beam_size
+        norm_const = self.len_norm_const
+        norm_factor = self.len_norm_factor
+        max_seq_len = self.max_seq_len
+        cov_penalty_factor = self.cov_penalty_factor
+
+        translation = torch.zeros(batch_size * beam_size, max_seq_len, dtype=torch.int64)
+        lengths = torch.ones(batch_size * beam_size, dtype=torch.int64)
+        scores = torch.zeros(batch_size * beam_size, dtype=torch.float32)
+
+        active = torch.arange(0, batch_size * beam_size, dtype=torch.int64)
+        base_mask = torch.arange(0, batch_size * beam_size, dtype=torch.int64)
+        global_offset = torch.arange(0, batch_size * beam_size, beam_size, dtype=torch.int64)
+
+        eos_beam_fill = torch.tensor([0] + (beam_size - 1) * [float('-inf')])
+
+        if self.cuda:
+            translation = translation.cuda()
+            lengths = lengths.cuda()
+            active = active.cuda()
+            base_mask = base_mask.cuda()
+            scores = scores.cuda()
+            global_offset = global_offset.cuda()
+            eos_beam_fill = eos_beam_fill.cuda()
+
+        translation[:, 0] = BOS
+
+        words, context = initial_input, initial_context
+
+        if self.batch_first:
+            word_view = (-1, 1)
+            ctx_batch_dim = 0
+            attn_query_dim = 1
+        else:
+            word_view = (1, -1)
+            ctx_batch_dim = 1
+            attn_query_dim = 0
+
+        # replicate context
+        if self.batch_first:
+            # context[0] (encoder state): (batch, seq, feature)
+            _, seq, feature = context[0].shape
+            context[0] = context[0].unsqueeze(1)
+            context[0] = context[0].expand(-1, beam_size, -1, -1)
+            context[0] = context[0].contiguous().view(batch_size * beam_size, seq, feature)
+            # context[0]: (batch * beam, seq, feature)
+        else:
+            # context[0] (encoder state): (seq, batch, feature)
+            seq, _, feature = context[0].shape
+            context[0] = context[0].unsqueeze(2)
+            context[0] = context[0].expand(-1, -1, beam_size, -1)
+            context[0] = context[0].contiguous().view(seq, batch_size * beam_size, feature)
+            # context[0]: (seq, batch * beam, feature)
+
+        # context[1] (encoder seq length): (batch)
+        context[1] = context[1].unsqueeze(1)
+        context[1] = context[1].expand(-1, beam_size)
+        context[1] = context[1].contiguous().view(batch_size * beam_size)
+        # context[1]: (batch * beam)
+
+        accu_attn_scores = torch.zeros(batch_size * beam_size, seq)
+        if self.cuda:
+            accu_attn_scores = accu_attn_scores.cuda()
+
+        counter = 0
+        for idx in range(1, self.max_seq_len):
+            if not len(active):
+                break
+            counter += 1
+
+            eos_mask = (words == EOS)
+            eos_mask = eos_mask.view(-1, beam_size)
+
+            terminating, _ = eos_mask.min(dim=1)
+
+            lengths[active[~eos_mask.view(-1)]] += 1
+
+            words, logprobs, attn, context = self.model.generate(words, context, beam_size)
+
+            attn = attn.float().squeeze(attn_query_dim)
+            attn = attn.masked_fill(eos_mask.view(-1).unsqueeze(1), 0)
+            accu_attn_scores[active] += attn
+
+            # words: (batch, beam, k)
+            words = words.view(-1, beam_size, beam_size)
+            words = words.masked_fill(eos_mask.unsqueeze(2), EOS)
+
+            # logprobs: (batch, beam, k)
+            logprobs = logprobs.float().view(-1, beam_size, beam_size)
+
+            if eos_mask.any():
+                logprobs[eos_mask] = eos_beam_fill
+
+            active_scores = scores[active].view(-1, beam_size)
+            # new_scores: (batch, beam, k)
+            new_scores = active_scores.unsqueeze(2) + logprobs
+
+            if idx == 1:
+                new_scores[:, 1:, :].fill_(float('-inf'))
+
+            new_scores = new_scores.view(-1, beam_size * beam_size)
+            # index: (batch, beam)
+            _, index = new_scores.topk(beam_size, dim=1)
+            # floor division: the beam each top-k candidate came from
+            source_beam = index // beam_size
+
+            best_scores = torch.gather(new_scores, 1, index)
+            scores[active] = best_scores.view(-1)
+
+            words = words.view(-1, beam_size * beam_size)
+            words = torch.gather(words, 1, index)
+
+            # words: (1, batch * beam)
+            words = words.view(word_view)
+
+            offset = global_offset[:source_beam.shape[0]]
+            source_beam += offset.unsqueeze(1)
+
+            translation[active, :] = translation[active[source_beam.view(-1)], :]
+            translation[active, idx] = words.view(-1)
+
+            lengths[active] = lengths[active[source_beam.view(-1)]]
+
+            context[2] = context[2].index_select(1, source_beam.view(-1))
+
+            if terminating.any():
+                not_terminating = ~terminating
+                not_terminating = not_terminating.unsqueeze(1)
+                not_terminating = not_terminating.expand(-1, beam_size).contiguous()
+
+                normalization_mask = active.view(-1, beam_size)[terminating]
+
+                # length normalization
+                norm = lengths[normalization_mask].float()
+                norm = (norm_const + norm) / (norm_const + 1.0)
+                norm = norm ** norm_factor
+
+                scores[normalization_mask] /= norm
+
+                # coverage penalty
+                penalty = accu_attn_scores[normalization_mask]
+                penalty = penalty.clamp(0, 1)
+                penalty = penalty.log()
+                penalty[penalty == float('-inf')] = 0
+                penalty = penalty.sum(dim=-1)
+
+                scores[normalization_mask] += cov_penalty_factor * penalty
+
+                mask = base_mask[:len(active)]
+                mask = mask.masked_select(not_terminating.view(-1))
+
+                words = words.index_select(ctx_batch_dim, mask)
+                context[0] = context[0].index_select(ctx_batch_dim, mask)
+                context[1] = context[1].index_select(0, mask)
+                context[2] = context[2].index_select(1, mask)
+
+                active = active.masked_select(not_terminating.view(-1))
+
+        scores = scores.view(batch_size, beam_size)
+        _, idx = scores.max(dim=1)
+
+        translation = translation[idx + global_offset, :]
+        lengths = lengths[idx + global_offset]
+
+        return translation, lengths, counter
diff --git a/examples/GNMT/seq2seq/inference/inference.py b/examples/GNMT/seq2seq/inference/inference.py
new file mode 100644
index 0000000..e0ddbd9
--- /dev/null
+++ b/examples/GNMT/seq2seq/inference/inference.py
@@ -0,0 +1,88 @@
+import torch
+
+from seq2seq.data.config import BOS
+from seq2seq.data.config import EOS
+from seq2seq.inference.beam_search import SequenceGenerator
+from seq2seq.utils import batch_padded_sequences
+
+
+class Translator(object):
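+    """
+    End-to-end translation: tokenizes the input sentences, encodes them,
+    decodes with greedy or beam search via SequenceGenerator, and
+    detokenizes the result.
+    """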
+
+    def __init__(self, model, tok,
+                 beam_size=5,
+                 len_norm_factor=0.6,
+                 len_norm_const=5.0,
+                 cov_penalty_factor=0.1,
+                 max_seq_len=50,
+                 cuda=False):
+
+        self.model = model
+        self.tok = tok
+        self.insert_target_start = [BOS]
+        self.insert_src_start = [BOS]
+        self.insert_src_end = [EOS]
+        self.batch_first = model.batch_first
+        self.cuda = cuda
+        self.beam_size = beam_size
+
+        self.generator = SequenceGenerator(
+            model=self.model,
+            beam_size=beam_size,
+            max_seq_len=max_seq_len,
+            cuda=cuda,
+            len_norm_factor=len_norm_factor,
+            len_norm_const=len_norm_const,
+            cov_penalty_factor=cov_penalty_factor)
+
+    def translate(self, input_sentences):
+        stats = {}
+        batch_size = len(input_sentences)
+        beam_size = self.beam_size
+
+        src_tok = [torch.tensor(self.tok.segment(line)) for line in input_sentences]
+
+        bos = [self.insert_target_start] * (batch_size * beam_size)
+        bos = torch.LongTensor(bos)
+        if self.batch_first:
+            bos = bos.view(-1, 1)
+        else:
+            bos = bos.view(1, -1)
+
+        src = batch_padded_sequences(src_tok, self.batch_first, sort=True)
+        src, src_length, indices = src
+
+        src_length = torch.LongTensor(src_length)
+        stats['total_enc_len'] = int(src_length.sum())
+
+        if self.cuda:
+            src = src.cuda()
+            src_length = src_length.cuda()
+            bos = bos.cuda()
+
+        with torch.no_grad():
+            context = self.model.encode(src, src_length)
+            context = [context, src_length, None]
+
+            if beam_size == 1:
+                generator = self.generator.greedy_search
+            else:
+                generator = self.generator.beam_search
+
+            preds, lengths, counter = generator(batch_size, bos, context)
+
+        preds = preds.cpu()
+        lengths = lengths.cpu()
+
+        output = []
+        for idx, pred in enumerate(preds):
+            end = lengths[idx] - 1
+            pred = pred[1: end]
+            pred = pred.tolist()
+            out = self.tok.detokenize(pred)
+            output.append(out)
+
+        stats['total_dec_len'] = int(lengths.sum())
+        stats['iters'] = counter
+
+        output = [output[indices.index(i)] for i in range(len(output))]  # restore original input order
+        return output, stats
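+
+# Minimal usage sketch (assumes `model` is a loaded GNMT instance, `tok` its
+# Tokenizer, and that the input text is already BPE-tokenized):
+#   translator = Translator(model, tok, beam_size=5, cuda=True)
+#   sentences, stats = translator.translate(['ein Haus .'])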
diff --git a/examples/GNMT/seq2seq/models/__init__.py b/examples/GNMT/seq2seq/models/__init__.py
new file mode 100644
index 0000000..6217b23
--- /dev/null
+++ b/examples/GNMT/seq2seq/models/__init__.py
@@ -0,0 +1,4 @@
+from .seq2seq_base import Seq2Seq
+from .gnmt import GNMT, ResidualRecurrentDecoder, ResidualRecurrentEncoder
+
+__all__ = ['GNMT']
diff --git a/examples/GNMT/seq2seq/models/attention.py b/examples/GNMT/seq2seq/models/attention.py
new file mode 100644
index 0000000..3a5a0ee
--- /dev/null
+++ b/examples/GNMT/seq2seq/models/attention.py
@@ -0,0 +1,164 @@
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.parameter import Parameter
+from distiller.modules import *
+
+
+class BahdanauAttention(nn.Module):
+    """
+    It should be very similar to tf.contrib.seq2seq.BahdanauAttention
+    """
+
+    def __init__(self, query_size, key_size, num_units, normalize=False,
+                 dropout=0, batch_first=False):
+
+        super(BahdanauAttention, self).__init__()
+
+        self.normalize = normalize
+        self.batch_first = batch_first
+        self.num_units = num_units
+
+        self.linear_q = nn.Linear(query_size, num_units, bias=False)
+        self.linear_k = nn.Linear(key_size, num_units, bias=False)
+
+        self.linear_att = Parameter(torch.Tensor(num_units))
+
+        self.dropout = nn.Dropout(dropout)
+        self.mask = None
+
+        # Adding submodules for basic ops to allow quantization:
+        self.eltwiseadd_qk = EltwiseAdd()
+        self.eltwiseadd_norm_bias = EltwiseAdd()
+        self.eltwisemul_norm_scaler = EltwiseMult()
+        self.tanh = nn.Tanh()
+        self.matmul_score = Matmul()
+        self.softmax_att = nn.Softmax(dim=-1)
+        self.context_matmul = BatchMatmul()
+
+        if self.normalize:
+            self.normalize_scalar = Parameter(torch.Tensor(1))
+            self.normalize_bias = Parameter(torch.Tensor(num_units))
+        else:
+            self.register_parameter('normalize_scalar', None)
+            self.register_parameter('normalize_bias', None)
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        stdv = 1. / math.sqrt(self.num_units)
+        self.linear_att.data.uniform_(-stdv, stdv)
+
+        if self.normalize:
+            self.normalize_scalar.data.fill_(stdv)
+            self.normalize_bias.data.zero_()
+
+    def set_mask(self, context_len, context):
+        """
+        sets self.mask which is applied before softmax
+        ones for inactive context fields, zeros for active context fields
+
+        :param context_len: b
+        :param context: if batch_first: (b x t_k x n) else: (t_k x b x n)
+
+        self.mask: (b x t_k)
+        """
+
+        if self.batch_first:
+            max_len = context.size(1)
+        else:
+            max_len = context.size(0)
+
+        indices = torch.arange(0, max_len, dtype=torch.int64, device=context.device)
+        self.mask = indices >= (context_len.unsqueeze(1))
+
+    def calc_score(self, att_query, att_keys):
+        """
+        Calculate Bahdanau score
+
+        :param att_query: b x t_q x n
+        :param att_keys: b x t_k x n
+
+        return b x t_q x t_k scores
+        """
+
+        b, t_k, n = att_keys.size()
+        t_q = att_query.size(1)
+
+        att_query = att_query.unsqueeze(2).expand(b, t_q, t_k, n)
+        att_keys = att_keys.unsqueeze(1).expand(b, t_q, t_k, n)
+        sum_qk = self.eltwiseadd_qk(att_query, att_keys)
+
+        if self.normalize:
+            sum_qk = self.eltwiseadd_norm_bias(sum_qk, self.normalize_bias)
+
+            tmp = self.linear_att.to(torch.float32)
+            linear_att = tmp / tmp.norm()
+            linear_att = linear_att.to(self.normalize_scalar)
+
+            linear_att = self.eltwisemul_norm_scaler(linear_att, self.normalize_scalar)
+        else:
+            linear_att = self.linear_att
+
+        out = self.matmul_score(self.tanh(sum_qk), linear_att)
+        return out
+
+    def forward(self, query, keys):
+        """
+
+        :param query: if batch_first: (b x t_q x n) else: (t_q x b x n)
+        :param keys: if batch_first: (b x t_k x n) else (t_k x b x n)
+
+        :returns: (context, scores_normalized)
+        context: if batch_first: (b x t_q x n) else (t_q x b x n)
+        scores_normalized: if batch_first (b x t_q x t_k) else (t_q x b x t_k)
+        """
+
+        # first dim of keys and query has to be 'batch', it's needed for bmm
+        if not self.batch_first:
+            keys = keys.transpose(0, 1)
+            if query.dim() == 3:
+                query = query.transpose(0, 1)
+
+        if query.dim() == 2:
+            single_query = True
+            query = query.unsqueeze(1)
+        else:
+            single_query = False
+
+        b = query.size(0)
+        t_k = keys.size(1)
+        t_q = query.size(1)
+
+        # FC layers to transform query and key
+        processed_query = self.linear_q(query)
+        # TODO move this out of decoder for efficiency during inference
+        processed_key = self.linear_k(keys)
+
+        # scores: (b x t_q x t_k)
+        scores = self.calc_score(processed_query, processed_key)
+
+        if self.mask is not None:
+            mask = self.mask.unsqueeze(1).expand(b, t_q, t_k)
+            # TODO I can't use -INF because of overflow check in pytorch
+            scores.data.masked_fill_(mask, -65504.0)
+
+        # Normalize the scores, softmax over t_k
+        scores_normalized = self.softmax_att(scores)
+
+        # Calculate the weighted average of the attention inputs according to
+        # the scores
+        scores_normalized = self.dropout(scores_normalized)
+        # context: (b x t_q x n)
+        context = self.context_matmul(scores_normalized, keys)
+
+        if single_query:
+            context = context.squeeze(1)
+            scores_normalized = scores_normalized.squeeze(1)
+        elif not self.batch_first:
+            context = context.transpose(0, 1)
+            scores_normalized = scores_normalized.transpose(0, 1)
+
+        return context, scores_normalized
diff --git a/examples/GNMT/seq2seq/models/decoder.py b/examples/GNMT/seq2seq/models/decoder.py
new file mode 100644
index 0000000..66d1333
--- /dev/null
+++ b/examples/GNMT/seq2seq/models/decoder.py
@@ -0,0 +1,140 @@
+import itertools
+
+import torch
+import torch.nn as nn
+
+from seq2seq.models.attention import BahdanauAttention
+import seq2seq.data.config as config
+from distiller.modules import *
+
+
+class RecurrentAttention(nn.Module):
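+    """
+    First decoder layer: an LSTM whose outputs query Bahdanau attention
+    over the encoder context.
+    """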
+
+    def __init__(self, input_size, context_size, hidden_size, num_layers=1,
+                 bias=True, batch_first=False, dropout=0):
+
+        super(RecurrentAttention, self).__init__()
+
+        self.rnn = nn.LSTM(input_size, hidden_size, num_layers, bias,
+                           batch_first)
+
+        self.attn = BahdanauAttention(hidden_size, context_size, context_size,
+                                      normalize=True, batch_first=batch_first)
+
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, inputs, hidden, context, context_len):
+        # set attention mask, sequences have different lengths, this mask
+        # allows to include only valid elements of context in attention's
+        # softmax
+        self.attn.set_mask(context_len, context)
+
+        rnn_outputs, hidden = self.rnn(inputs, hidden)
+        attn_outputs, scores = self.attn(rnn_outputs, context)
+        rnn_outputs = self.dropout(rnn_outputs)
+
+        return rnn_outputs, hidden, attn_outputs, scores
+
+
+class Classifier(nn.Module):
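+    """
+    Linear projection onto the vocabulary. In fp16, out_features is padded
+    to a multiple of 8 so the matmul hits Tensor Core (HMMA) kernels; the
+    padding is sliced off again in forward().
+    """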
+
+    def __init__(self, in_features, out_features, math='fp32'):
+        super(Classifier, self).__init__()
+
+        self.out_features = out_features
+
+        # padding required to trigger HMMA kernels
+        if math == 'fp16':
+            out_features = (out_features + 7) // 8 * 8
+
+        self.classifier = nn.Linear(in_features, out_features)
+
+    def forward(self, x):
+        out = self.classifier(x)
+        out = out[..., :self.out_features]
+        return out
+
+
+class ResidualRecurrentDecoder(nn.Module):
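+    """
+    GNMT decoder: one RecurrentAttention layer followed by (num_layers - 1)
+    LSTM layers. The attention context is concatenated to every upper
+    layer's input, and residual connections start from the third layer.
+    """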
+
+    def __init__(self, vocab_size, hidden_size=128, num_layers=8, bias=True,
+                 dropout=0, batch_first=False, math='fp32', embedder=None):
+
+        super(ResidualRecurrentDecoder, self).__init__()
+
+        self.num_layers = num_layers
+
+        self.att_rnn = RecurrentAttention(hidden_size, hidden_size,
+                                          hidden_size, num_layers=1,
+                                          batch_first=batch_first)
+
+        self.rnn_layers = nn.ModuleList()
+        for _ in range(num_layers - 1):
+            self.rnn_layers.append(
+                nn.LSTM(2 * hidden_size, hidden_size, num_layers=1, bias=bias,
+                        batch_first=batch_first))
+
+        if embedder is not None:
+            self.embedder = embedder
+        else:
+            self.embedder = nn.Embedding(vocab_size, hidden_size,
+                                        padding_idx=config.PAD)
+
+        self.classifier = Classifier(hidden_size, vocab_size, math)
+        self.dropout = nn.Dropout(p=dropout)
+
+        # Adding submodules for basic ops to allow quantization:
+        self.eltwiseadd_residuals = nn.ModuleList([EltwiseAdd() for _ in range(1, len(self.rnn_layers))])
+        self.attention_concats = nn.ModuleList([Concat(2) for _ in range(len(self.rnn_layers))])
+
+    def init_hidden(self, hidden):
+        if hidden is not None:
+            # per-layer chunks
+            hidden = hidden.chunk(self.num_layers)
+            # (h, c) chunks for LSTM layer
+            hidden = tuple(i.chunk(2) for i in hidden)
+        else:
+            hidden = [None] * self.num_layers
+
+        self.next_hidden = []
+        return hidden
+
+    def append_hidden(self, h):
+        if self.inference:
+            self.next_hidden.append(h)
+
+    def package_hidden(self):
+        if self.inference:
+            hidden = torch.cat(tuple(itertools.chain(*self.next_hidden)))
+        else:
+            hidden = None
+        return hidden
+
+    def forward(self, inputs, context, inference=False):
+        self.inference = inference
+
+        enc_context, enc_len, hidden = context
+        hidden = self.init_hidden(hidden)
+
+        x = self.embedder(inputs)
+
+        x, h, attn, scores = self.att_rnn(x, hidden[0], enc_context, enc_len)
+        self.append_hidden(h)
+
+        x = self.dropout(x)
+        x = self.attention_concats[0](x, attn)
+        x, h = self.rnn_layers[0](x, hidden[1])
+        self.append_hidden(h)
+
+        for i in range(1, len(self.rnn_layers)):
+            residual = x
+            x = self.dropout(x)
+            x = self.attention_concats[i](x, attn)
+            x, h = self.rnn_layers[i](x, hidden[i + 1])
+            self.append_hidden(h)
+            x = self.eltwiseadd_residuals[i-1](x, residual)
+
+        x = self.classifier(x)
+        hidden = self.package_hidden()
+
+        return x, scores, [enc_context, enc_len, hidden]
diff --git a/examples/GNMT/seq2seq/models/encoder.py b/examples/GNMT/seq2seq/models/encoder.py
new file mode 100644
index 0000000..c7cf1bd
--- /dev/null
+++ b/examples/GNMT/seq2seq/models/encoder.py
@@ -0,0 +1,62 @@
+import torch.nn as nn
+from torch.nn.utils.rnn import pack_padded_sequence
+from torch.nn.utils.rnn import pad_packed_sequence
+from distiller.modules import *
+
+import seq2seq.data.config as config
+
+
+class ResidualRecurrentEncoder(nn.Module):
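+    """
+    GNMT encoder: one bidirectional LSTM followed by (num_layers - 1)
+    unidirectional LSTM layers, with residual connections starting from
+    the third layer.
+    """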
+
+    def __init__(self, vocab_size, hidden_size=128, num_layers=8, bias=True,
+                 dropout=0, batch_first=False, embedder=None):
+
+        super(ResidualRecurrentEncoder, self).__init__()
+        self.batch_first = batch_first
+        self.rnn_layers = nn.ModuleList()
+        self.rnn_layers.append(
+            nn.LSTM(hidden_size, hidden_size, num_layers=1, bias=bias,
+                    batch_first=batch_first, bidirectional=True))
+
+        self.rnn_layers.append(
+            nn.LSTM((2 * hidden_size), hidden_size, num_layers=1, bias=bias,
+                    batch_first=batch_first))
+
+        for _ in range(num_layers - 2):
+            self.rnn_layers.append(
+                nn.LSTM(hidden_size, hidden_size, num_layers=1, bias=bias,
+                        batch_first=batch_first))
+
+        self.dropout = nn.Dropout(p=dropout)
+
+        if embedder is not None:
+            self.embedder = embedder
+        else:
+            self.embedder = nn.Embedding(vocab_size, hidden_size,
+                                        padding_idx=config.PAD)
+
+        # Adding submodules for basic ops to allow quantization:
+        self.eltwiseadd_residuals = nn.ModuleList([EltwiseAdd() for _ in range(2, len(self.rnn_layers))])
+
+    def forward(self, inputs, lengths):
+        x = self.embedder(inputs)
+
+        # bidirectional layer
+        x = pack_padded_sequence(x, lengths.cpu().numpy(),
+                                 batch_first=self.batch_first)
+        x, _ = self.rnn_layers[0](x)
+        x, _ = pad_packed_sequence(x, batch_first=self.batch_first)
+
+        # 1st unidirectional layer
+        x = self.dropout(x)
+        x, _ = self.rnn_layers[1](x)
+
+        # the rest of unidirectional layers,
+        # with residual connections starting from 3rd layer
+        for i in range(2, len(self.rnn_layers)):
+            residual = x
+            x = self.dropout(x)
+            x, _ = self.rnn_layers[i](x)
+            x = self.eltwiseadd_residuals[i-2](x, residual)
+
+        return x
diff --git a/examples/GNMT/seq2seq/models/gnmt.py b/examples/GNMT/seq2seq/models/gnmt.py
new file mode 100644
index 0000000..c576029
--- /dev/null
+++ b/examples/GNMT/seq2seq/models/gnmt.py
@@ -0,0 +1,36 @@
+import torch.nn as nn
+
+import seq2seq.data.config as config
+from .seq2seq_base import Seq2Seq
+from .decoder import ResidualRecurrentDecoder
+from .encoder import ResidualRecurrentEncoder
+
+
+class GNMT(Seq2Seq):
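+    """
+    GNMT model: ResidualRecurrentEncoder + ResidualRecurrentDecoder. With
+    share_embedding=True a single embedding matrix is used for both the
+    source and target vocabularies.
+    """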
+    def __init__(self, vocab_size, hidden_size=512, num_layers=8, bias=True,
+                 dropout=0.2, batch_first=False, math='fp32',
+                 share_embedding=False):
+
+        super(GNMT, self).__init__(batch_first=batch_first)
+
+        if share_embedding:
+            embedder = nn.Embedding(vocab_size, hidden_size, padding_idx=config.PAD)
+        else:
+            embedder = None
+
+        self.encoder = ResidualRecurrentEncoder(vocab_size, hidden_size,
+                                                num_layers, bias, dropout,
+                                                batch_first, embedder)
+
+        self.decoder = ResidualRecurrentDecoder(vocab_size, hidden_size,
+                                                num_layers, bias, dropout,
+                                                batch_first, math, embedder)
+
+    def forward(self, input_encoder, input_enc_len, input_decoder):
+        context = self.encode(input_encoder, input_enc_len)
+        context = (context, input_enc_len, None)
+        output, _, _ = self.decode(input_decoder, context)
+
+        return output
diff --git a/examples/GNMT/seq2seq/models/seq2seq_base.py b/examples/GNMT/seq2seq/models/seq2seq_base.py
new file mode 100644
index 0000000..844a3c4
--- /dev/null
+++ b/examples/GNMT/seq2seq/models/seq2seq_base.py
@@ -0,0 +1,22 @@
+import torch.nn as nn
+from torch.nn.functional import log_softmax
+
+
+class Seq2Seq(nn.Module):
+    def __init__(self, encoder=None, decoder=None, batch_first=False):
+        super(Seq2Seq, self).__init__()
+        self.encoder = encoder
+        self.decoder = decoder
+        self.batch_first = batch_first
+
+    def encode(self, inputs, lengths):
+        return self.encoder(inputs, lengths)
+
+    def decode(self, inputs, context, inference=False):
+        return self.decoder(inputs, context, inference)
+
+    def generate(self, inputs, context, beam_size):
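+        """
+        Performs one inference decoding step: returns the beam_size most
+        likely next words with their log-probabilities, plus the attention
+        scores and the updated decoder context.
+        """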
+        logits, scores, new_context = self.decode(inputs, context, True)
+        logprobs = log_softmax(logits, dim=-1)
+        logprobs, words = logprobs.topk(beam_size, dim=-1)
+        return words, logprobs, scores, new_context
diff --git a/examples/GNMT/seq2seq/utils.py b/examples/GNMT/seq2seq/utils.py
new file mode 100644
index 0000000..8629447
--- /dev/null
+++ b/examples/GNMT/seq2seq/utils.py
@@ -0,0 +1,116 @@
+from contextlib import contextmanager
+import os
+import logging.config
+
+import numpy as np
+import torch
+from torch.nn.utils.rnn import pack_padded_sequence
+
+import seq2seq.data.config as config
+
+
+def barrier():
+    """ Calls all_reduce on dummy tensor."""
+    if torch.distributed.is_initialized():
+        torch.distributed.all_reduce(torch.cuda.FloatTensor(1))
+        torch.cuda.synchronize()
+
+
+def get_rank():
+    if torch.distributed.is_initialized():
+        rank = torch.distributed.get_rank()
+    else:
+        rank = 0
+    return rank
+
+def get_world_size():
+    if torch.distributed.is_initialized():
+        world_size = torch.distributed.get_world_size()
+    else:
+        world_size = 1
+    return world_size
+
+
+@contextmanager
+def sync_workers():
+    """ Gets distributed rank and synchronizes workers at exit"""
+    rank = get_rank()
+    yield rank
+    barrier()
+
+
+def setup_logging(log_file='log.log'):
+    """Sets up file and console logging with a per-record rank field."""
+    class RankFilter(logging.Filter):
+        def __init__(self, rank):
+            super(RankFilter, self).__init__()
+            self.rank = rank
+
+        def filter(self, record):
+            record.rank = self.rank
+            return True
+
+    rank = get_rank()
+    rank_filter = RankFilter(rank)
+
+    logging.basicConfig(level=logging.DEBUG,
+                        format="%(asctime)s - %(levelname)s - %(rank)s - %(message)s",
+                        datefmt="%Y-%m-%d %H:%M:%S",
+                        filename=log_file,
+                        filemode='w')
+    console = logging.StreamHandler()
+    console.setLevel(logging.INFO)
+    formatter = logging.Formatter('%(rank)s: %(message)s')
+    console.setFormatter(formatter)
+    logging.getLogger('').addHandler(console)
+    logging.getLogger('').addFilter(rank_filter)
+
+
+class AverageMeter(object):
+    """Computes and stores the average and current value"""
+
+    def __init__(self, skip_first=True):
+        self.reset()
+        self.skip = skip_first
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+
+        if self.skip:
+            self.skip = False
+        else:
+            self.sum += val * n
+            self.count += n
+            self.avg = self.sum / self.count
+
+
+def batch_padded_sequences(seq, batch_first=False, sort=False):
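+    """
+    Packs a list of token-id tensors into a single PAD-filled LongTensor of
+    shape (max_len, batch), transposed if batch_first. With sort=True the
+    sequences are ordered longest-first and the returned `indices` give the
+    permutation needed to restore the original order.
+    """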
+    if sort:
+        key = lambda item: len(item[1])
+        indices, seq = zip(*sorted(enumerate(seq), key=key, reverse=True))
+    else:
+        indices = range(len(seq))
+
+    lengths = [len(sentence) for sentence in seq]
+    batch_length = max(lengths)
+    seq_tensor = torch.LongTensor(batch_length, len(seq)).fill_(config.PAD)
+    for idx, sentence in enumerate(seq):
+        end_seq = lengths[idx]
+        seq_tensor[:end_seq, idx].copy_(sentence[:end_seq])
+    if batch_first:
+        seq_tensor = seq_tensor.t()
+    return seq_tensor, lengths, indices
+
+
+def debug_tensor(tensor, name):
+    logging.info(name)
+    tensor = tensor.float().cpu().numpy()
+    logging.info('MIN: {min} MAX: {max} AVG: {mean} STD: {std} NAN: {nans} INF: {infs}'
+                 .format(min=tensor.min(), max=tensor.max(), mean=tensor.mean(),
+                         std=tensor.std(), nans=np.isnan(tensor).sum(), infs=np.isinf(tensor).sum()))
diff --git a/examples/GNMT/translate.py b/examples/GNMT/translate.py
new file mode 100644
index 0000000..75275cf
--- /dev/null
+++ b/examples/GNMT/translate.py
@@ -0,0 +1,336 @@
+# Copyright 2019 The MLPerf Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+import argparse
+import codecs
+import logging
+import os
+import subprocess
+import time
+import warnings
+from ast import literal_eval
+from itertools import zip_longest
+
+import torch
+
+import seq2seq.data.config as config
+from seq2seq import models
+from seq2seq.data.dataset import ParallelDataset
+from seq2seq.inference.inference import Translator
+from seq2seq.utils import AverageMeter
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='GNMT Translate',
+                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    # data
+    dataset = parser.add_argument_group('data setup')
+    dataset.add_argument('--dataset-dir', default=None, required=True,
+                         help='path to directory with input data')
+    dataset.add_argument('-i', '--input', required=True,
+                         help='input file (tokenized)')
+    dataset.add_argument('-o', '--output', required=True,
+                         help='output file (tokenized)')
+    dataset.add_argument('-m', '--model', required=True,
+                         help='model checkpoint file')
+    dataset.add_argument('-r', '--reference', default=None,
+                         help='full path to the file with reference \
+                         translations (for sacrebleu)')
+
+    # parameters
+    params = parser.add_argument_group('inference setup')
+    params.add_argument('--batch-size', default=128, type=int,
+                        help='batch size')
+    params.add_argument('--beam-size', default=5, type=int,
+                        help='beam size')
+    params.add_argument('--max-seq-len', default=80, type=int,
+                        help='maximum prediction sequence length')
+    params.add_argument('--len-norm-factor', default=0.6, type=float,
+                        help='length normalization factor')
+    params.add_argument('--cov-penalty-factor', default=0.1, type=float,
+                        help='coverage penalty factor')
+    params.add_argument('--len-norm-const', default=5.0, type=float,
+                        help='length normalization constant')
+    # general setup
+    general = parser.add_argument_group('general setup')
+
+    general.add_argument('--mode', default='accuracy',
+                         choices=['accuracy', 'performance'],
+                         help='test in accuracy or performance mode')
+
+    general.add_argument('--math', default='fp16', choices=['fp32', 'fp16'],
+                         help='arithmetic type')
+
+    batch_first_parser = general.add_mutually_exclusive_group(required=False)
+    batch_first_parser.add_argument('--batch-first', dest='batch_first',
+                                    action='store_true',
+                                    help='uses (batch, seq, feature) data \
+                                    format for RNNs')
+    batch_first_parser.add_argument('--seq-first', dest='batch_first',
+                                    action='store_false',
+                                    help='uses (seq, batch, feature) data \
+                                    format for RNNs')
+    batch_first_parser.set_defaults(batch_first=True)
+
+    cuda_parser = general.add_mutually_exclusive_group(required=False)
+    cuda_parser.add_argument('--cuda', dest='cuda', action='store_true',
+                             help='enables cuda (use \'--no-cuda\' to disable)')
+    cuda_parser.add_argument('--no-cuda', dest='cuda', action='store_false',
+                             help=argparse.SUPPRESS)
+    cuda_parser.set_defaults(cuda=True)
+
+    cudnn_parser = general.add_mutually_exclusive_group(required=False)
+    cudnn_parser.add_argument('--cudnn', dest='cudnn', action='store_true',
+                              help='enables cudnn (use \'--no-cudnn\' to disable)')
+    cudnn_parser.add_argument('--no-cudnn', dest='cudnn', action='store_false',
+                              help=argparse.SUPPRESS)
+    cudnn_parser.set_defaults(cudnn=True)
+
+    general.add_argument('--print-freq', '-p', default=1, type=int,
+                         help='print log every PRINT_FREQ batches')
+
+    return parser.parse_args()
+
+
+def grouper(iterable, size, fillvalue=None):
+    args = [iter(iterable)] * size
+    return zip_longest(*args, fillvalue=fillvalue)
+
+
+def write_output(output_file, lines):
+    for line in lines:
+        output_file.write(line)
+        output_file.write('\n')
+
+
+def checkpoint_from_distributed(state_dict):
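+    """Returns True if the state_dict comes from a DistributedDataParallel
+    model (its keys carry a 'module.' prefix)."""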
+    ret = False
+    for key, _ in state_dict.items():
+        if key.find('module.') != -1:
+            ret = True
+            break
+    return ret
+
+
+def unwrap_distributed(state_dict):
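+    """Strips the 'module.' prefix that DistributedDataParallel adds to
+    parameter names, so the checkpoint loads into a plain model."""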
+    new_state_dict = {}
+    for key, value in state_dict.items():
+        new_key = key.replace('module.', '')
+        new_state_dict[new_key] = value
+
+    return new_state_dict
+
+
+def main():
+    args = parse_args()
+    print(args)
+
+    if args.cuda:
+        torch.cuda.set_device(0)
+    if not args.cuda and torch.cuda.is_available():
+        warnings.warn('cuda is available but not enabled')
+    if args.math == 'fp16' and not args.cuda:
+        raise RuntimeError('fp16 requires cuda')
+    if not args.cudnn:
+        torch.backends.cudnn.enabled = False
+
+    checkpoint = torch.load(args.model, map_location={'cuda:0': 'cpu'})
+
+    vocab_size = checkpoint['tokenizer'].vocab_size
+    model_config = dict(vocab_size=vocab_size, math=checkpoint['config'].math,
+                        **literal_eval(checkpoint['config'].model_config))
+    model_config['batch_first'] = args.batch_first
+    model = models.GNMT(**model_config)
+
+    state_dict = checkpoint['state_dict']
+    if checkpoint_from_distributed(state_dict):
+        state_dict = unwrap_distributed(state_dict)
+
+    model.load_state_dict(state_dict)
+
+    if args.math == 'fp32':
+        dtype = torch.FloatTensor
+    elif args.math == 'fp16':
+        dtype = torch.HalfTensor
+
+    model.type(dtype)
+    if args.cuda:
+        model = model.cuda()
+    model.eval()
+
+    tokenizer = checkpoint['tokenizer']
+
+    test_data = ParallelDataset(
+        src_fname=os.path.join(args.dataset_dir, config.SRC_TEST_FNAME),
+        tgt_fname=os.path.join(args.dataset_dir, config.TGT_TEST_FNAME),
+        tokenizer=tokenizer,
+        min_len=0,
+        max_len=150,
+        sort=False)
+
+    test_loader = test_data.get_loader(batch_size=args.batch_size,
+                                       batch_first=True,
+                                       shuffle=False,
+                                       num_workers=0,
+                                       drop_last=False,
+                                       distributed=False)
+
+    translator = Translator(model,
+                            tokenizer,
+                            beam_size=args.beam_size,
+                            max_seq_len=args.max_seq_len,
+                            len_norm_factor=args.len_norm_factor,
+                            len_norm_const=args.len_norm_const,
+                            cov_penalty_factor=args.cov_penalty_factor,
+                            cuda=args.cuda)
+
+    if args.cuda:
+        torch.cuda.empty_cache()
+
+    # only write the output to file in accuracy mode
+    if args.mode == 'accuracy':
+        test_file = open(args.output, 'w', encoding='UTF-8')
+
+    batch_time = AverageMeter(False)
+    tot_tok_per_sec = AverageMeter(False)
+    iterations = AverageMeter(False)
+    enc_seq_len = AverageMeter(False)
+    dec_seq_len = AverageMeter(False)
+    stats = {}
+
+    for i, (src, tgt, indices) in enumerate(test_loader):
+        translate_timer = time.time()
+        src, src_length = src
+
+        if translator.batch_first:
+            batch_size = src.size(0)
+        else:
+            batch_size = src.size(1)
+        beam_size = args.beam_size
+
+        bos = [translator.insert_target_start] * (batch_size * beam_size)
+        bos = torch.LongTensor(bos)
+        if translator.batch_first:
+            bos = bos.view(-1, 1)
+        else:
+            bos = bos.view(1, -1)
+
+        src_length = torch.LongTensor(src_length)
+        stats['total_enc_len'] = int(src_length.sum())
+
+        if args.cuda:
+            src = src.cuda()
+            src_length = src_length.cuda()
+            bos = bos.cuda()
+
+        with torch.no_grad():
+            context = translator.model.encode(src, src_length)
+            context = [context, src_length, None]
+
+            if beam_size == 1:
+                generator = translator.generator.greedy_search
+            else:
+                generator = translator.generator.beam_search
+            preds, lengths, counter = generator(batch_size, bos, context)
+
+        stats['total_dec_len'] = lengths.sum().item()
+        stats['iters'] = counter
+
+        preds = preds.cpu()
+        lengths = lengths.cpu()
+
+        output = []
+        for idx, pred in enumerate(preds):
+            end = lengths[idx] - 1
+            pred = pred[1: end]
+            pred = pred.tolist()
+            out = translator.tok.detokenize(pred)
+            output.append(out)
+
+        # only write the output to file in accuracy mode
+        if args.mode == 'accuracy':
+            output = [output[indices.index(i)] for i in range(len(output))]  # restore original sentence order
+            for line in output:
+                test_file.write(line)
+                test_file.write('\n')
+
+
+        # Get timing
+        elapsed = time.time() - translate_timer
+        batch_time.update(elapsed, batch_size)
+
+        total_tokens = stats['total_dec_len'] + stats['total_enc_len']
+        ttps = total_tokens / elapsed
+        tot_tok_per_sec.update(ttps, batch_size)
+
+        iterations.update(stats['iters'])
+        enc_seq_len.update(stats['total_enc_len'] / batch_size, batch_size)
+        dec_seq_len.update(stats['total_dec_len'] / batch_size, batch_size)
+
+        if i % args.print_freq == 0:
+            log = 'TEST '
+            log += 'Time {:.3f} ({:.3f})\t'.format(batch_time.val, batch_time.avg)
+            log += 'Decoder iters {:.1f} ({:.1f})\t'.format(iterations.val, iterations.avg)
+            log += 'Tok/s {:.0f} ({:.0f})'.format(tot_tok_per_sec.val, tot_tok_per_sec.avg)
+            print(log)
+
+
+    # summary timing
+    time_per_sentence = (batch_time.avg / batch_size)
+    log = 'TEST SUMMARY:\n'
+    log += 'Lines translated: {}\t'.format(len(test_loader.dataset))
+    log += 'Avg total tokens/s: {:.0f}\n'.format(tot_tok_per_sec.avg)
+    log += 'Avg time per batch: {:.3f} s\t'.format(batch_time.avg)
+    log += 'Avg time per sentence: {:.3f} ms\n'.format(1000 * time_per_sentence)
+    log += 'Avg encoder seq len: {:.2f}\t'.format(enc_seq_len.avg)
+    log += 'Avg decoder seq len: {:.2f}\t'.format(dec_seq_len.avg)
+    log += 'Total decoder iterations: {}'.format(int(iterations.sum))
+    print(log)
+
+    # only write the output to file in accuracy mode
+    if args.mode == 'accuracy':
+        test_file.close()
+
+        test_path = args.output
+        # run moses detokenizer
+        detok_path = os.path.join(args.dataset_dir, config.DETOKENIZER)
+        detok_test_path = test_path + '.detok'
+
+        with open(detok_test_path, 'w') as detok_test_file, \
+                open(test_path, 'r') as test_file:
+            subprocess.run(['perl', detok_path], stdin=test_file,
+                           stdout=detok_test_file, stderr=subprocess.DEVNULL)
+
+
+        # run sacrebleu
+        reference_path = os.path.join(args.dataset_dir,
+                                      config.TGT_TEST_TARGET_FNAME)
+        sacrebleu = subprocess.run(['sacrebleu', '--input', detok_test_path,
+                                    reference_path, '--score-only', '-lc',
+                                    '--tokenize', 'intl'],
+                                   stdout=subprocess.PIPE)
+        bleu = float(sacrebleu.stdout.strip())
+
+        print('BLEU on test dataset: {}'.format(bleu))
+
+        print('Finished evaluation on test set')
+
+if __name__ == '__main__':
+    main()
diff --git a/examples/GNMT/verify_dataset.sh b/examples/GNMT/verify_dataset.sh
new file mode 100644
index 0000000..d149e2a
--- /dev/null
+++ b/examples/GNMT/verify_dataset.sh
@@ -0,0 +1,73 @@
+#!/bin/bash
+
+set -e
+
+check_md5() {
+  local file="$1"
+  local expected="$2"
+  local actual
+  # cat | md5sum (rather than md5sum FILE) keeps the "  -" suffix that the
+  # expected checksums were recorded with
+  actual=$(cat "data/${file}" | md5sum)
+  if [[ "$actual" = "$expected" ]]; then
+    echo "OK: correct data/${file}"
+  else
+    echo "ERROR: incorrect data/${file}"
+    echo "ERROR: expected $expected"
+    echo "ERROR: found $actual"
+  fi
+}
+
+check_md5 'train.tok.clean.bpe.32000.en' 'b7482095b787264a310d4933d197a134  -'
+check_md5 'train.tok.clean.bpe.32000.de' '409064aedaef5b7c458ff19a7beda462  -'
+check_md5 'newstest_dev.tok.clean.bpe.32000.en' '704c4ba8c8b63df1f6678d32b91438b5  -'
+check_md5 'newstest_dev.tok.clean.bpe.32000.de' 'd27f5a64c839e20c5caa8b9d60075dde  -'
+check_md5 'newstest2014.tok.bpe.32000.en' 'cb014e2509f86cd81d5a87c240c07464  -'
+check_md5 'newstest2014.tok.bpe.32000.de' 'd616740f6026dc493e66efdf9ac1cb04  -'
+check_md5 'newstest2014.de' 'f6c3818b477e4a25cad68b61cc883c17  -'
-- 
GitLab