From 1198a8890a6d858cb6b11124e3d6be4638f29408 Mon Sep 17 00:00:00 2001 From: Lev Zlotnik <46742999+levzlotnik@users.noreply.github.com> Date: Thu, 8 Aug 2019 19:01:35 +0300 Subject: [PATCH] Added GNMT post-training quantization example (#252) --- .gitignore | 5 + README.md | 4 +- distiller/modules/__init__.py | 11 +- distiller/modules/aggregate.py | 16 + distiller/modules/eltwise.py | 15 +- distiller/modules/matmul.py | 39 + distiller/quantization/range_linear.py | 135 +- examples/GNMT/LICENSE | 22 + examples/GNMT/README.md | 121 + examples/GNMT/download_dataset.sh | 172 + examples/GNMT/download_trained_model.sh | 4 + examples/GNMT/model_stats.yaml | 2868 +++++++++++++++++ examples/GNMT/quantize_gnmt.ipynb | 1069 ++++++ examples/GNMT/requirements.txt | 3 + examples/GNMT/scripts/filter_dataset.py | 79 + examples/GNMT/seq2seq/__init__.py | 0 examples/GNMT/seq2seq/data/__init__.py | 0 examples/GNMT/seq2seq/data/config.py | 21 + examples/GNMT/seq2seq/data/dataset.py | 123 + examples/GNMT/seq2seq/data/sampler.py | 85 + examples/GNMT/seq2seq/data/tokenizer.py | 44 + examples/GNMT/seq2seq/inference/__init__.py | 0 .../GNMT/seq2seq/inference/beam_search.py | 245 ++ examples/GNMT/seq2seq/inference/inference.py | 88 + examples/GNMT/seq2seq/models/__init__.py | 4 + examples/GNMT/seq2seq/models/attention.py | 164 + examples/GNMT/seq2seq/models/decoder.py | 140 + examples/GNMT/seq2seq/models/encoder.py | 62 + examples/GNMT/seq2seq/models/gnmt.py | 36 + examples/GNMT/seq2seq/models/seq2seq_base.py | 22 + examples/GNMT/seq2seq/utils.py | 116 + examples/GNMT/translate.py | 336 ++ examples/GNMT/verify_dataset.sh | 73 + 33 files changed, 6106 insertions(+), 16 deletions(-) create mode 100644 distiller/modules/aggregate.py create mode 100644 distiller/modules/matmul.py create mode 100644 examples/GNMT/LICENSE create mode 100644 examples/GNMT/README.md create mode 100644 examples/GNMT/download_dataset.sh create mode 100644 examples/GNMT/download_trained_model.sh create mode 100644 examples/GNMT/model_stats.yaml create mode 100644 examples/GNMT/quantize_gnmt.ipynb create mode 100644 examples/GNMT/requirements.txt create mode 100644 examples/GNMT/scripts/filter_dataset.py create mode 100644 examples/GNMT/seq2seq/__init__.py create mode 100644 examples/GNMT/seq2seq/data/__init__.py create mode 100644 examples/GNMT/seq2seq/data/config.py create mode 100644 examples/GNMT/seq2seq/data/dataset.py create mode 100644 examples/GNMT/seq2seq/data/sampler.py create mode 100644 examples/GNMT/seq2seq/data/tokenizer.py create mode 100644 examples/GNMT/seq2seq/inference/__init__.py create mode 100644 examples/GNMT/seq2seq/inference/beam_search.py create mode 100644 examples/GNMT/seq2seq/inference/inference.py create mode 100644 examples/GNMT/seq2seq/models/__init__.py create mode 100644 examples/GNMT/seq2seq/models/attention.py create mode 100644 examples/GNMT/seq2seq/models/decoder.py create mode 100644 examples/GNMT/seq2seq/models/encoder.py create mode 100644 examples/GNMT/seq2seq/models/gnmt.py create mode 100644 examples/GNMT/seq2seq/models/seq2seq_base.py create mode 100644 examples/GNMT/seq2seq/utils.py create mode 100644 examples/GNMT/translate.py create mode 100644 examples/GNMT/verify_dataset.sh diff --git a/.gitignore b/.gitignore index 913ea4c..a1670e9 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,11 @@ __pycache__/ .cache pytest_collaterals/ +# GNMT sample +examples/GNMT/data +examples/GNMT/model_best.pth +examples/GNMT/output_file* + # Virtual env env/ .env/ diff --git a/README.md b/README.md index 
60f97f0..3dcede3 100755 --- a/README.md +++ b/README.md @@ -250,8 +250,8 @@ For more details, there are some other resources you can refer to: + [Preparing a model for quantization](https://nervanasystems.github.io/distiller/prepare_model_quant.html) + [Tutorial: Using Distiller to prune a PyTorch language model](https://nervanasystems.github.io/distiller/tutorial-lang_model.html) + [Tutorial: Pruning Filters & Channels](https://nervanasystems.github.io/distiller/tutorial-struct_pruning.html) -+ [Tutorial: Post-Training Quantization of a Language Model -](https://nervanasystems.github.io/distiller/tutorial-lang_model_quant.html) ++ [Tutorial: Post-Training Quantization of a Language Model](https://nervanasystems.github.io/distiller/tutorial-lang_model_quant.html) ++ [Tutorial: Post-Training Quantization of GNMT (translation model)](https://nervanasystems.github.io/distiller/tutorial-lang_model_quant.html) + [Post-training quantization command line examples](https://github.com/NervanaSystems/distiller/blob/master/examples/quantization/post_train_quant/command_line.md) ### Example invocations of the sample application diff --git a/distiller/modules/__init__.py b/distiller/modules/__init__.py index 5bd7d5c..03c3f57 100644 --- a/distiller/modules/__init__.py +++ b/distiller/modules/__init__.py @@ -14,10 +14,13 @@ # limitations under the License. # -from .eltwise import EltwiseAdd, EltwiseMult +from .eltwise import * from .grouping import * -from .rnn import DistillerLSTM, DistillerLSTMCell, convert_model_to_distiller_lstm +from .matmul import * +from .rnn import * +from .aggregate import Norm -__all__ = ['EltwiseAdd', 'EltwiseMult', +__all__ = ['EltwiseAdd', 'EltwiseMult', 'EltwiseDiv', 'Matmul', 'BatchMatmul', 'Concat', 'Chunk', 'Split', 'Stack', - 'DistillerLSTMCell', 'DistillerLSTM', 'convert_model_to_distiller_lstm'] + 'DistillerLSTMCell', 'DistillerLSTM', 'convert_model_to_distiller_lstm', + 'Norm'] diff --git a/distiller/modules/aggregate.py b/distiller/modules/aggregate.py new file mode 100644 index 0000000..a217627 --- /dev/null +++ b/distiller/modules/aggregate.py @@ -0,0 +1,16 @@ +import torch +import torch.nn as nn + + +class Norm(nn.Module): + """ + A module wrapper for vector/matrix norm + """ + def __init__(self, p='fro', dim=None, keepdim=False): + super(Norm, self).__init__() + self.p = p + self.dim = dim + self.keepdim = keepdim + + def forward(self, x: torch.Tensor): + return torch.norm(x, p=self.p, dim=self.dim, keepdim=self.keepdim) diff --git a/distiller/modules/eltwise.py b/distiller/modules/eltwise.py index b61a168..8435059 100644 --- a/distiller/modules/eltwise.py +++ b/distiller/modules/eltwise.py @@ -13,13 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
 #
-
+import torch
 import torch.nn as nn
 
 
 class EltwiseAdd(nn.Module):
     def __init__(self, inplace=False):
         super(EltwiseAdd, self).__init__()
+        self.inplace = inplace
 
     def forward(self, *input):
@@ -47,3 +48,15 @@ class EltwiseMult(nn.Module):
         for t in input[1:]:
             res = res * t
         return res
+
+
+class EltwiseDiv(nn.Module):
+    def __init__(self, inplace=False):
+        super(EltwiseDiv, self).__init__()
+        self.inplace = inplace
+
+    def forward(self, x: torch.Tensor, y):
+        if self.inplace:
+            return x.div_(y)
+        return x.div(y)
+
diff --git a/distiller/modules/matmul.py b/distiller/modules/matmul.py
new file mode 100644
index 0000000..bdbeccd
--- /dev/null
+++ b/distiller/modules/matmul.py
@@ -0,0 +1,39 @@
+#
+# Copyright (c) 2019 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import torch
+import torch.nn as nn
+
+
+class Matmul(nn.Module):
+    """
+    A wrapper module for the matmul operation between 2 tensors.
+    """
+    def __init__(self):
+        super(Matmul, self).__init__()
+
+    def forward(self, a: torch.Tensor, b: torch.Tensor):
+        return a.matmul(b)
+
+
+class BatchMatmul(nn.Module):
+    """
+    A wrapper module for the torch.bmm operation between 2 tensors.
+    """
+    def __init__(self):
+        super(BatchMatmul, self).__init__()
+
+    def forward(self, a: torch.Tensor, b: torch.Tensor):
+        return torch.bmm(a, b)
\ No newline at end of file
diff --git a/distiller/quantization/range_linear.py b/distiller/quantization/range_linear.py
index a213c60..7bbedc4 100644
--- a/distiller/quantization/range_linear.py
+++ b/distiller/quantization/range_linear.py
@@ -18,7 +18,7 @@ import torch.nn as nn
 import argparse
 from enum import Enum
 from collections import OrderedDict
-from functools import reduce, partial
+from functools import reduce, partial, update_wrapper
 import logging
 import os
 from copy import deepcopy
@@ -414,7 +414,7 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper):
         per_channel_wts (bool): Enable quantization of weights using separate quantization parameters per
             output channel
         activation_stats (dict): See RangeLinearQuantWrapper
-        clip_n_stds (float): See RangeLinearQuantWrapper
+        clip_n_stds (int): See RangeLinearQuantWrapper
         scale_approx_mult_bits (int): See RangeLinearQuantWrapper
     """
     def __init__(self, wrapped_module, num_bits_acts, num_bits_params, num_bits_accum=32,
@@ -468,6 +468,19 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper):
             # Dynamic ranges - save in auxiliary buffer, requantize each time based on dynamic input scale factor
             self.register_buffer('base_b_q', base_b_q)
 
+        # A flag indicating that the simulated quantized weights are pre-shifted for faster performance.
+        # In the first forward pass, `w_zero_point` is added into the weights to allow faster inference,
+        # and all subsequent calls are done with these shifted weights.
+        # Upon calling `self.state_dict()` we restore the actual quantized weights.
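+        # (Illustrative note, assuming the x_f = (x_q + zp_x) / scale_x convention used in this file:
+        # for a linear op, scale_x * scale_w * y_f = (x_q + zp_x) * (w_q + zp_w), so caching the weights
+        # as (w_q + zp_w) replaces a per-forward-pass tensor addition with a one-time shift.)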
+ self.is_simulated_quant_weight_shifted = False + + def state_dict(self, destination=None, prefix='', keep_vars=False): + if self.is_simulated_quant_weight_shifted: + # We want to return the weights to their integer representation: + self.wrapped_module.weight.data -= self.w_zero_point + self.is_simulated_quant_weight_shifted = False + return super(RangeLinearQuantParamLayerWrapper, self).state_dict(destination, prefix, keep_vars) + def get_inputs_quantization_params(self, input): if not self.preset_act_stats: self.in_0_scale, self.in_0_zero_point = _get_quant_params_from_tensor( @@ -501,15 +514,16 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper): # to the input and weights and pass those to the wrapped model. Functionally, since at this point we're # dealing solely with integer values, the results are the same either way. - if self.mode != LinearQuantMode.SYMMETRIC: - input_q += self.in_0_zero_point + if self.mode != LinearQuantMode.SYMMETRIC and not self.is_simulated_quant_weight_shifted: + # We "store" the w_zero_point inside our wrapped module's weights to + # improve performance on inference. self.wrapped_module.weight.data += self.w_zero_point + self.is_simulated_quant_weight_shifted = True + input_q += self.in_0_zero_point accum = self.wrapped_module.forward(input_q) clamp(accum.data, self.accum_min_q_val, self.accum_max_q_val, inplace=True) - if self.mode != LinearQuantMode.SYMMETRIC: - self.wrapped_module.weight.data -= self.w_zero_point return accum def get_output_quantization_params(self, accumulator): @@ -553,6 +567,96 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper): return tmpstr +class RangeLinearQuantMatmulWrapper(RangeLinearQuantWrapper): + """ + Wrapper for quantizing the Matmul/BatchMatmul operation between 2 input tensors. 
+    output = input1 @ input2
+    where:
+        input1.shape = (input_batch, input_size)
+        input2.shape = (input_size, output_size)
+    The mathematical calculation is:
+        y_f = i1_f * i2_f
+        iN_f = (iN_q + zp_iN) / scale_iN   =>
+        y_q = scale_y * y_f - zp_y = scale_y * (i1_f * i2_f) - zp_y =
+
+                     scale_y
+            = ------------------- * ((i1_q + zp_i1) * (i2_q + zp_i2)) - zp_y
+              scale_i1 * scale_i2
+
+    Args:
+        wrapped_module (distiller.modules.Matmul or distiller.modules.BatchMatmul): Module to be wrapped
+        num_bits_acts (int): Number of bits used for inputs and output quantization
+        num_bits_accum (int): Number of bits allocated for the accumulator of intermediate integer results
+        mode (LinearQuantMode): Quantization mode to use (symmetric / asymmetric-signed/unsigned)
+        clip_acts (ClipMode): See RangeLinearQuantWrapper
+        activation_stats (dict): See RangeLinearQuantWrapper
+        clip_n_stds (int): See RangeLinearQuantWrapper
+        scale_approx_mult_bits (int): See RangeLinearQuantWrapper
+    """
+    def __init__(self, wrapped_module, num_bits_acts, num_bits_accum=32,
+                 mode=LinearQuantMode.SYMMETRIC, clip_acts=ClipMode.NONE, activation_stats=None,
+                 clip_n_stds=None, scale_approx_mult_bits=None):
+        super(RangeLinearQuantMatmulWrapper, self).__init__(wrapped_module, num_bits_acts, num_bits_accum, mode,
+                                                            clip_acts, activation_stats, clip_n_stds,
+                                                            scale_approx_mult_bits)
+
+        if not isinstance(wrapped_module, (distiller.modules.Matmul, distiller.modules.BatchMatmul)):
+            raise ValueError(self.__class__.__name__ + ' can wrap only Matmul or BatchMatmul modules')
+        if self.preset_act_stats:
+            self.register_buffer('accum_scale', self.in_0_scale * self.in_1_scale)
+        else:
+            self.accum_scale = 1
+
+    def get_inputs_quantization_params(self, input0, input1):
+        if not self.preset_act_stats:
+            self.in_0_scale, self.in_0_zero_point = _get_quant_params_from_tensor(
+                input0, self.num_bits_acts, self.mode, clip=self.clip_acts,
+                num_stds=self.clip_n_stds, scale_approx_mult_bits=self.scale_approx_mult_bits)
+            self.in_1_scale, self.in_1_zero_point = _get_quant_params_from_tensor(
+                input1, self.num_bits_acts, self.mode, clip=self.clip_acts,
+                num_stds=self.clip_n_stds, scale_approx_mult_bits=self.scale_approx_mult_bits)
+        return [self.in_0_scale, self.in_1_scale], [self.in_0_zero_point, self.in_1_zero_point]
+
+    def quantized_forward(self, input0_q, input1_q):
+        accum = self.wrapped_module.forward(input0_q + self.in_0_zero_point,
+                                            input1_q + self.in_1_zero_point)
+        clamp(accum.data, self.accum_min_q_val, self.accum_max_q_val, inplace=True)
+        return accum
+
+    def get_output_quantization_params(self, accumulator):
+        if self.preset_act_stats:
+            return self.output_scale, self.output_zero_point
+
+        y_f = accumulator / self.accum_scale
+        return _get_quant_params_from_tensor(y_f, self.num_bits_acts, self.mode, clip=self.clip_acts,
+                                             num_stds=self.clip_n_stds,
+                                             scale_approx_mult_bits=self.scale_approx_mult_bits)
+
+    def get_accum_to_output_re_quantization_params(self, output_scale, output_zero_point):
+        requant_scale = output_scale / self.accum_scale
+        if self.scale_approx_mult_bits is not None:
+            requant_scale = approx_scale_as_mult_and_shift(requant_scale, self.scale_approx_mult_bits)
+        return requant_scale, output_zero_point
+
+    def extra_repr(self):
+        tmpstr = 'mode={0}, '.format(str(self.mode).split('.')[1])
+        tmpstr += 'num_bits_acts={0}, num_bits_accum={1}, '.format(self.num_bits_acts, self.num_bits_accum)
+        tmpstr += 'clip_acts={0}, '.format(_enum_to_str(self.clip_acts))
+        if self.clip_acts == ClipMode.N_STD:
+            tmpstr += 'num_stds={}
'.format(self.clip_n_stds) + tmpstr += 'scale_approx_mult_bits={}'.format(self.scale_approx_mult_bits) + tmpstr += '\npreset_activation_stats={0}'.format(self.preset_act_stats) + if self.preset_act_stats: + tmpstr += '\nin_0_scale={0:.4f}, in_0_zero_point={1:.4f}'.format(self.in_0_scale.item(), + self.in_0_zero_point.item()) + + tmpstr += '\nin_1_scale={0:.4f}, in_1_zero_point={1:.4f}'.format(self.in_1_scale.item(), + self.in_1_zero_point.item()) + + tmpstr += '\nout_scale={0:.4f}, out_zero_point={1:.4f}'.format(self.output_scale.item(), + self.output_zero_point.item()) + return tmpstr + + class NoStatsError(NotImplementedError): pass @@ -877,12 +981,25 @@ class PostTrainLinearQuantizer(Quantizer): self.replacement_factory[nn.Conv3d] = replace_param_layer self.replacement_factory[nn.Linear] = replace_param_layer - self.replacement_factory[distiller.modules.Concat] = partial( + factory_concat = partial( replace_non_param_layer, RangeLinearQuantConcatWrapper) - self.replacement_factory[distiller.modules.EltwiseAdd] = partial( + factory_eltwiseadd = partial( replace_non_param_layer, RangeLinearQuantEltwiseAddWrapper) - self.replacement_factory[distiller.modules.EltwiseMult] = partial( + factory_eltwisemult = partial( replace_non_param_layer, RangeLinearQuantEltwiseMultWrapper) + factory_matmul = partial( + replace_non_param_layer, RangeLinearQuantMatmulWrapper) + + update_wrapper(factory_concat, replace_non_param_layer) + update_wrapper(factory_eltwiseadd, replace_non_param_layer) + update_wrapper(factory_eltwisemult, replace_non_param_layer) + update_wrapper(factory_matmul, replace_non_param_layer) + + self.replacement_factory[distiller.modules.Concat] = factory_concat + self.replacement_factory[distiller.modules.EltwiseAdd] = factory_eltwiseadd + self.replacement_factory[distiller.modules.EltwiseMult] = factory_eltwisemult + self.replacement_factory[distiller.modules.Matmul] = factory_matmul + self.replacement_factory[distiller.modules.BatchMatmul] = factory_matmul self.replacement_factory[nn.Embedding] = replace_embedding save_dir = msglogger.logdir if hasattr(msglogger, 'logdir') else '.' diff --git a/examples/GNMT/LICENSE b/examples/GNMT/LICENSE new file mode 100644 index 0000000..4343c76 --- /dev/null +++ b/examples/GNMT/LICENSE @@ -0,0 +1,22 @@ +MIT License + +Copyright (c) 2017 Elad Hoffer +Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
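The quantizer changes above only take effect for operations expressed as modules: `PostTrainLinearQuantizer` swaps module types through its `replacement_factory`, so functional calls such as `torch.matmul` or `torch.bmm` in a model must first be rewritten with the new `distiller.modules.Matmul` / `BatchMatmul` wrappers. A minimal sketch of such a rewrite (the `DotProductScores` module below is illustrative, not part of this patch):

```python
import torch.nn as nn
import distiller.modules


class DotProductScores(nn.Module):
    """Toy module: torch.bmm rewritten as distiller.modules.BatchMatmul so that
    PostTrainLinearQuantizer can replace it with RangeLinearQuantMatmulWrapper."""
    def __init__(self):
        super(DotProductScores, self).__init__()
        self.bmm = distiller.modules.BatchMatmul()  # instead of calling torch.bmm directly

    def forward(self, query, keys):
        # (batch, t_q, dim) x (batch, dim, t_k) -> (batch, t_q, t_k)
        return self.bmm(query, keys.transpose(1, 2))
```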
diff --git a/examples/GNMT/README.md b/examples/GNMT/README.md
new file mode 100644
index 0000000..893ce1f
--- /dev/null
+++ b/examples/GNMT/README.md
@@ -0,0 +1,121 @@
+# Google's Neural Machine Translation
+In this example we quantize [MLPerf's implementation of GNMT](https://github.com/mlperf/training/tree/master/rnn_translator/pytorch)
+and show different quantization configurations for achieving the highest accuracy using **post-training quantization**.
+
+Note that this folder contains only the code required to run evaluation; all training code was removed.
+A link to a pre-trained model is provided below.
+
+For a summary of the quantization results, see [below](#results).
+
+## Running the Example
+
+This example is implemented as a Jupyter notebook.
+
+### Install Requirements
+
+    pip install -r requirements.txt
+
+(This will install [sacrebleu](https://pypi.org/project/sacrebleu/))
+
+### Get the Dataset
+
+Download the data using the following command:
+
+    bash download_dataset.sh
+
+Verify the data with:
+
+    bash verify_dataset.sh
+
+### Download the Pre-trained Model
+
+    wget https://zenodo.org/record/2581623/files/model_best.pth
+
+### Run the Example
+
+    jupyter notebook
+
+Then open the `quantize_gnmt.ipynb` notebook.
+
+## Summary of Quantization Results
+
+### What is Quantized
+
+The following operations / modules are fully quantized:
+
+* Linear (fully-connected)
+* Embedding
+* Element-wise addition
+* Element-wise multiplication
+* MatMul / Batch MatMul
+* Concat
+
+The following operations do not have a quantized implementation. They run in FP32, with
+quantize + de-quantize applied at the op boundaries (input and output):
+
+* Softmax
+* Tanh
+* Sigmoid
+* Division by norm in the attention block. That is, in pseudo code:
+  ```python
+  x = quant_dequant(x)
+  y = x / norm(x)
+  y = quant_dequant(y)
+  ```
+
+### Results
+
+| Precision | Mode       | Per-Channel | Clip Activations                                               | BLEU Score |
+|-----------|------------|-------------|----------------------------------------------------------------|------------|
+| FP32      | N/A        | N/A         | N/A                                                            | 22.16      |
+| INT8      | Symmetric  | No          | No                                                             | 18.05      |
+| INT8      | Asymmetric | No          | No                                                             | 18.52      |
+| INT8      | Asymmetric | Yes         | AVG in all layers                                              | 9.63       |
+| INT8      | Asymmetric | Yes         | AVG in all layers except attention block                       | 16.94      |
+| INT8      | Asymmetric | Yes         | AVG in all layers except attention block and final classifier  | 21.49      |
+
+## Dataset / Environment
+
+### Publication / Attribution
+
+We use [WMT16 English-German](http://www.statmt.org/wmt16/translation-task.html) for training.
+
+### Data preprocessing
+
+The script uses the [subword-nmt](https://github.com/rsennrich/subword-nmt) package to segment text into subword units (BPE).
+By default it builds a shared vocabulary of 32,000 tokens.
+Preprocessing removes all pairs of sentences that can't be decoded by a latin-1 encoder.
+
+## Model
+
+### Publication / Attribution
+
+The implemented model is similar to the one from the [Google's Neural Machine Translation System: Bridging the Gap
+between Human and Machine Translation](https://arxiv.org/abs/1609.08144) paper.
+
+The most important difference is in the attention mechanism. This repository implements `gnmt_v2` attention: the output
+from the first LSTM layer of the decoder goes into the attention module, then the re-weighted context is concatenated
+with the inputs to all subsequent LSTM layers in the decoder at the current timestep (see the sketch below).
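+
+A minimal, illustrative sketch of one `gnmt_v2` decoder timestep (the names, sizes and the toy attention
+below are hypothetical stand-ins, not this example's actual API):
+
+```python
+import torch
+import torch.nn as nn
+
+hidden = 8
+first_cell = nn.LSTMCell(hidden, hidden)
+upper_cells = [nn.LSTMCell(2 * hidden, hidden) for _ in range(3)]  # layers 2-4 see [h, ctx]
+
+def toy_attention(query, keys):
+    # Stand-in for the model's normalized Bahdanau attention
+    scores = torch.softmax(query @ keys.t(), dim=-1)  # (1, src_len)
+    return scores @ keys                              # re-weighted context, (1, hidden)
+
+x = torch.randn(1, hidden)     # embedded target token at the current timestep
+keys = torch.randn(5, hidden)  # encoder outputs for 5 source positions
+h, _ = first_cell(x)           # the first decoder LSTM layer supplies the attention query
+ctx = toy_attention(h, keys)
+for i, cell in enumerate(upper_cells):
+    out, _ = cell(torch.cat([h, ctx], dim=-1))  # context concatenated at every upper layer
+    h = out + h if i >= 1 else out              # residual connections from the 3rd layer on
+```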
+
+The same attention mechanism is also implemented in the default GNMT-like models from
+[tensorflow/nmt](https://github.com/tensorflow/nmt) and [NVIDIA/OpenSeq2Seq](https://github.com/NVIDIA/OpenSeq2Seq).
+
+### Structure
+
+* general:
+  * encoder and decoder use shared embeddings
+  * data-parallel multi-GPU training
+  * dynamic loss scaling with backoff for Tensor Cores (mixed precision) training
+  * trained with label smoothing loss (smoothing factor 0.1)
+* encoder:
+  * 4-layer LSTM, hidden size 1024, first layer is bidirectional, the rest are
+    unidirectional
+  * with residual connections starting from the 3rd layer
+  * uses standard LSTM layers (accelerated by cuDNN)
+* decoder:
+  * 4-layer unidirectional LSTM with hidden size 1024 and a fully-connected
+    classifier
+  * with residual connections starting from the 3rd layer
+  * uses standard LSTM layers (accelerated by cuDNN)
+* attention:
+  * normalized Bahdanau attention
+  * model uses the `gnmt_v2` attention mechanism
+  * output from the first LSTM layer of the decoder goes into attention,
+    then the re-weighted context is concatenated with the input to all subsequent
+    LSTM layers in the decoder at the current timestep
+* inference:
+  * beam search with a default beam size of 5
+  * with coverage penalty and length normalization
+  * BLEU computed by [sacrebleu](https://pypi.org/project/sacrebleu/)
diff --git a/examples/GNMT/download_dataset.sh b/examples/GNMT/download_dataset.sh
new file mode 100644
index 0000000..8dbe528
--- /dev/null
+++ b/examples/GNMT/download_dataset.sh
@@ -0,0 +1,172 @@
+#! /usr/bin/env bash
+
+# Copyright 2017 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -e
+
+
+OUTPUT_DIR=${1:-"data"}
+echo "Writing to ${OUTPUT_DIR}. To change this, pass a different output directory as the first argument."
+
+OUTPUT_DIR_DATA="${OUTPUT_DIR}/data"
+
+mkdir -p $OUTPUT_DIR_DATA
+
+echo "Downloading Europarl v7. This may take a while..."
+wget -nc -nv -O ${OUTPUT_DIR_DATA}/europarl-v7-de-en.tgz \
+  http://www.statmt.org/europarl/v7/de-en.tgz
+
+echo "Downloading Common Crawl corpus. This may take a while..."
+wget -nc -nv -O ${OUTPUT_DIR_DATA}/common-crawl.tgz \
+  http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz
+
+echo "Downloading News Commentary v11. This may take a while..."
+wget -nc -nv -O ${OUTPUT_DIR_DATA}/nc-v11.tgz \
+  http://data.statmt.org/wmt16/translation-task/training-parallel-nc-v11.tgz
+
+echo "Downloading dev/test sets"
+wget -nc -nv -O ${OUTPUT_DIR_DATA}/dev.tgz \
+  http://data.statmt.org/wmt16/translation-task/dev.tgz
+wget -nc -nv -O ${OUTPUT_DIR_DATA}/test.tgz \
+  http://data.statmt.org/wmt16/translation-task/test.tgz
+
+# Extract everything
+echo "Extracting all files..."
+mkdir -p "${OUTPUT_DIR_DATA}/europarl-v7-de-en" +tar -xvzf "${OUTPUT_DIR_DATA}/europarl-v7-de-en.tgz" -C "${OUTPUT_DIR_DATA}/europarl-v7-de-en" +mkdir -p "${OUTPUT_DIR_DATA}/common-crawl" +tar -xvzf "${OUTPUT_DIR_DATA}/common-crawl.tgz" -C "${OUTPUT_DIR_DATA}/common-crawl" +mkdir -p "${OUTPUT_DIR_DATA}/nc-v11" +tar -xvzf "${OUTPUT_DIR_DATA}/nc-v11.tgz" -C "${OUTPUT_DIR_DATA}/nc-v11" +mkdir -p "${OUTPUT_DIR_DATA}/dev" +tar -xvzf "${OUTPUT_DIR_DATA}/dev.tgz" -C "${OUTPUT_DIR_DATA}/dev" +mkdir -p "${OUTPUT_DIR_DATA}/test" +tar -xvzf "${OUTPUT_DIR_DATA}/test.tgz" -C "${OUTPUT_DIR_DATA}/test" + +# Concatenate Training data +cat "${OUTPUT_DIR_DATA}/europarl-v7-de-en/europarl-v7.de-en.en" \ + "${OUTPUT_DIR_DATA}/common-crawl/commoncrawl.de-en.en" \ + "${OUTPUT_DIR_DATA}/nc-v11/training-parallel-nc-v11/news-commentary-v11.de-en.en" \ + > "${OUTPUT_DIR}/train.en" +wc -l "${OUTPUT_DIR}/train.en" + +cat "${OUTPUT_DIR_DATA}/europarl-v7-de-en/europarl-v7.de-en.de" \ + "${OUTPUT_DIR_DATA}/common-crawl/commoncrawl.de-en.de" \ + "${OUTPUT_DIR_DATA}/nc-v11/training-parallel-nc-v11/news-commentary-v11.de-en.de" \ + > "${OUTPUT_DIR}/train.de" +wc -l "${OUTPUT_DIR}/train.de" + +# Clone Moses +if [ ! -d "${OUTPUT_DIR}/mosesdecoder" ]; then + echo "Cloning moses for data processing" + git clone https://github.com/moses-smt/mosesdecoder.git "${OUTPUT_DIR}/mosesdecoder" + cd ${OUTPUT_DIR}/mosesdecoder + git reset --hard 8c5eaa1a122236bbf927bde4ec610906fea599e6 + cd - +fi + +# Convert SGM files +# Convert newstest2014 data into raw text format +${OUTPUT_DIR}/mosesdecoder/scripts/ems/support/input-from-sgm.perl \ + < ${OUTPUT_DIR_DATA}/dev/dev/newstest2014-deen-src.de.sgm \ + > ${OUTPUT_DIR_DATA}/dev/dev/newstest2014.de +${OUTPUT_DIR}/mosesdecoder/scripts/ems/support/input-from-sgm.perl \ + < ${OUTPUT_DIR_DATA}/dev/dev/newstest2014-deen-ref.en.sgm \ + > ${OUTPUT_DIR_DATA}/dev/dev/newstest2014.en + +# Convert newstest2015 data into raw text format +${OUTPUT_DIR}/mosesdecoder/scripts/ems/support/input-from-sgm.perl \ + < ${OUTPUT_DIR_DATA}/dev/dev/newstest2015-deen-src.de.sgm \ + > ${OUTPUT_DIR_DATA}/dev/dev/newstest2015.de +${OUTPUT_DIR}/mosesdecoder/scripts/ems/support/input-from-sgm.perl \ + < ${OUTPUT_DIR_DATA}/dev/dev/newstest2015-deen-ref.en.sgm \ + > ${OUTPUT_DIR_DATA}/dev/dev/newstest2015.en + +# Convert newstest2016 data into raw text format +${OUTPUT_DIR}/mosesdecoder/scripts/ems/support/input-from-sgm.perl \ + < ${OUTPUT_DIR_DATA}/test/test/newstest2016-deen-src.de.sgm \ + > ${OUTPUT_DIR_DATA}/test/test/newstest2016.de +${OUTPUT_DIR}/mosesdecoder/scripts/ems/support/input-from-sgm.perl \ + < ${OUTPUT_DIR_DATA}/test/test/newstest2016-deen-ref.en.sgm \ + > ${OUTPUT_DIR_DATA}/test/test/newstest2016.en + +# Copy dev/test data to output dir +cp ${OUTPUT_DIR_DATA}/dev/dev/newstest20*.de ${OUTPUT_DIR} +cp ${OUTPUT_DIR_DATA}/dev/dev/newstest20*.en ${OUTPUT_DIR} +cp ${OUTPUT_DIR_DATA}/test/test/newstest20*.de ${OUTPUT_DIR} +cp ${OUTPUT_DIR_DATA}/test/test/newstest20*.en ${OUTPUT_DIR} + +# Tokenize data +for f in ${OUTPUT_DIR}/*.de; do + echo "Tokenizing $f..." + ${OUTPUT_DIR}/mosesdecoder/scripts/tokenizer/tokenizer.perl -q -l de -threads 8 < $f > ${f%.*}.tok.de +done + +for f in ${OUTPUT_DIR}/*.en; do + echo "Tokenizing $f..." + ${OUTPUT_DIR}/mosesdecoder/scripts/tokenizer/tokenizer.perl -q -l en -threads 8 < $f > ${f%.*}.tok.en +done + +# Clean all corpora +for f in ${OUTPUT_DIR}/*.en; do + fbase=${f%.*} + echo "Cleaning ${fbase}..." 
+ ${OUTPUT_DIR}/mosesdecoder/scripts/training/clean-corpus-n.perl $fbase de en "${fbase}.clean" 1 80 +done + +# Create dev dataset +cat "${OUTPUT_DIR}/newstest2015.tok.clean.en" \ + "${OUTPUT_DIR}/newstest2016.tok.clean.en" \ + > "${OUTPUT_DIR}/newstest_dev.tok.clean.en" + +cat "${OUTPUT_DIR}/newstest2015.tok.clean.de" \ + "${OUTPUT_DIR}/newstest2016.tok.clean.de" \ + > "${OUTPUT_DIR}/newstest_dev.tok.clean.de" + +# Filter datasets +python3 scripts/filter_dataset.py -f1 ${OUTPUT_DIR}/train.tok.clean.en -f2 ${OUTPUT_DIR}/train.tok.clean.de +python3 scripts/filter_dataset.py -f1 ${OUTPUT_DIR}/newstest_dev.tok.clean.en -f2 ${OUTPUT_DIR}/newstest_dev.tok.clean.de + +# Generate Subword Units (BPE) +# Clone Subword NMT +if [ ! -d "${OUTPUT_DIR}/subword-nmt" ]; then + git clone https://github.com/rsennrich/subword-nmt.git "${OUTPUT_DIR}/subword-nmt" + cd ${OUTPUT_DIR}/subword-nmt + git reset --hard 48ba99e657591c329e0003f0c6e32e493fa959ef + cd - +fi + +# Learn Shared BPE +for merge_ops in 32000; do + echo "Learning BPE with merge_ops=${merge_ops}. This may take a while..." + cat "${OUTPUT_DIR}/train.tok.clean.de" "${OUTPUT_DIR}/train.tok.clean.en" | \ + ${OUTPUT_DIR}/subword-nmt/learn_bpe.py -s $merge_ops > "${OUTPUT_DIR}/bpe.${merge_ops}" + + echo "Apply BPE with merge_ops=${merge_ops} to tokenized files..." + for lang in en de; do + for f in ${OUTPUT_DIR}/*.tok.${lang} ${OUTPUT_DIR}/*.tok.clean.${lang}; do + outfile="${f%.*}.bpe.${merge_ops}.${lang}" + ${OUTPUT_DIR}/subword-nmt/apply_bpe.py -c "${OUTPUT_DIR}/bpe.${merge_ops}" < $f > "${outfile}" + echo ${outfile} + done + done + + # Create vocabulary file for BPE + cat "${OUTPUT_DIR}/train.tok.clean.bpe.${merge_ops}.en" "${OUTPUT_DIR}/train.tok.clean.bpe.${merge_ops}.de" | \ + ${OUTPUT_DIR}/subword-nmt/get_vocab.py | cut -f1 -d ' ' > "${OUTPUT_DIR}/vocab.bpe.${merge_ops}" + +done + +echo "All done." 
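For reference, the BPE segmentation that the script drives through subword-nmt's `apply_bpe.py` can also be invoked from Python. A sketch, assuming a pip-installed `subword_nmt` package and an illustrative path to the codes file produced above:

```python
from subword_nmt.apply_bpe import BPE

# Load the merge operations learned above (illustrative path to the codes file)
with open('data/bpe.32000', encoding='utf-8') as codes_file:
    bpe = BPE(codes_file)

# Segment an already-tokenized sentence into subword units
print(bpe.process_line('this is a tokenized sentence'))
```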
diff --git a/examples/GNMT/download_trained_model.sh b/examples/GNMT/download_trained_model.sh new file mode 100644 index 0000000..77a4a23 --- /dev/null +++ b/examples/GNMT/download_trained_model.sh @@ -0,0 +1,4 @@ +#!/bin/bash + +wget https://zenodo.org/record/2581623/files/model_best.pth + diff --git a/examples/GNMT/model_stats.yaml b/examples/GNMT/model_stats.yaml new file mode 100644 index 0000000..32de9d7 --- /dev/null +++ b/examples/GNMT/model_stats.yaml @@ -0,0 +1,2868 @@ +encoder.rnn_layers.0.cells.0.fc_gate_x: + inputs: + 0: + min: -5.349642276763916 + max: 5.236245632171631 + avg_min: -2.8281705602780907 + avg_max: 2.8138245348501125 + mean: -0.0016106524300209288 + std: 0.8805053743938167 + shape: (1, 1024) + output: + min: -33.63774490356445 + max: 32.98709487915039 + avg_min: -19.3743170070767 + avg_max: 18.377735952944196 + mean: -1.2778008344058638 + std: 5.574481019638121 + shape: (1, 4096) +encoder.rnn_layers.0.cells.0.fc_gate_h: + inputs: + 0: + min: -1.0 + max: 0.9997552633285522 + avg_min: -0.9242292824059608 + avg_max: 0.8228753872990404 + mean: -0.002536592865063468 + std: 0.23544047089016987 + shape: (1, 1024) + output: + min: -22.65224838256836 + max: 17.655214309692383 + avg_min: -11.509061529416135 + avg_max: 6.881520398276833 + mean: -2.373206253470364 + std: 2.8913670592601197 + shape: (1, 4096) +encoder.rnn_layers.0.cells.0.eltwiseadd_gate: + inputs: + 0: + min: -33.63774490356445 + max: 32.98709487915039 + avg_min: -19.3743170070767 + avg_max: 18.377735952944196 + mean: -1.2778008344058638 + std: 5.574481019638121 + shape: (1, 4096) + 1: + min: -22.65224838256836 + max: 17.655214309692383 + avg_min: -11.509061529416135 + avg_max: 6.881520398276833 + mean: -2.373206253470364 + std: 2.8913670592601197 + shape: (1, 4096) + output: + min: -40.20295333862305 + max: 34.10171890258789 + avg_min: -23.746535006890156 + avg_max: 18.852929802167523 + mean: -3.6510070722635595 + std: 6.424404119623572 + shape: (1, 4096) +encoder.rnn_layers.0.cells.0.act_f: + inputs: + 0: + min: -40.20295333862305 + max: 30.254833221435547 + avg_min: -23.296182616223867 + avg_max: 11.772601394017986 + mean: -7.432837164007311 + std: 5.768370101504384 + shape: (1, 1024) + output: + min: 3.4680012470380246e-18 + max: 1.0 + avg_min: 2.613044365988191e-07 + avg_max: 0.9997981336035177 + mean: 0.11436928112791814 + std: 0.2739924793333765 + shape: (1, 1024) +encoder.rnn_layers.0.cells.0.act_i: + inputs: + 0: + min: -33.06312561035156 + max: 34.10171890258789 + avg_min: -20.231574570140438 + avg_max: 15.238747643933712 + mean: -3.533190907395537 + std: 5.892242776130438 + shape: (1, 1024) + output: + min: 4.37388120446097e-15 + max: 1.0 + avg_min: 1.3271535511620103e-07 + avg_max: 0.9999727739693318 + mean: 0.27593418032485884 + std: 0.3954827759906979 + shape: (1, 1024) +encoder.rnn_layers.0.cells.0.act_o: + inputs: + 0: + min: -34.778221130371094 + max: 30.507577896118164 + avg_min: -20.07659851906139 + avg_max: 14.524603366545767 + mean: -3.567072719363816 + std: 5.69751302438932 + shape: (1, 1024) + output: + min: 7.870648145337088e-16 + max: 1.0 + avg_min: 2.206335442308734e-07 + avg_max: 0.9999628556888489 + mean: 0.2618299175561274 + std: 0.3833761347669008 + shape: (1, 1024) +encoder.rnn_layers.0.cells.0.act_g: + inputs: + 0: + min: -33.982521057128906 + max: 33.54863739013672 + avg_min: -19.17722672297466 + avg_max: 18.48774343457505 + mean: -0.0709275224609481 + std: 6.116932441002841 + shape: (1, 1024) + output: + min: -1.0 + max: 1.0 + avg_min: -1.0 + avg_max: 0.9999999999947166 + 
mean: -0.007298258008836419 + std: 0.9385271485476292 + shape: (1, 1024) +encoder.rnn_layers.0.cells.0.eltwisemult_cell_forget: + inputs: + 0: + min: 3.4680012470380246e-18 + max: 1.0 + avg_min: 2.613044365988191e-07 + avg_max: 0.9997981336035177 + mean: 0.11436928112791814 + std: 0.2739924793333765 + shape: (1, 1024) + 1: + min: -38.982444763183594 + max: 5.8226799964904785 + avg_min: -5.096891815574018 + avg_max: 1.6211083284579975 + mean: -0.00908752453454041 + std: 0.5260199992435526 + shape: (1, 1024) + output: + min: -38.97710418701172 + max: 4.823007583618164 + avg_min: -4.562025848362994 + avg_max: 1.0867458244325479 + mean: -0.0052387909689317535 + std: 0.24397997338840743 + shape: (1, 1024) +encoder.rnn_layers.0.cells.0.eltwisemult_cell_input: + inputs: + 0: + min: 4.37388120446097e-15 + max: 1.0 + avg_min: 1.3271535511620103e-07 + avg_max: 0.9999727739693318 + mean: 0.27593418032485884 + std: 0.3954827759906979 + shape: (1, 1024) + 1: + min: -1.0 + max: 1.0 + avg_min: -1.0 + avg_max: 0.9999999999947166 + mean: -0.007298258008836419 + std: 0.9385271485476292 + shape: (1, 1024) + output: + min: -1.0 + max: 1.0 + avg_min: -0.9999414832227521 + avg_max: 0.9998500162423687 + mean: -0.004381999838699326 + std: 0.4680802328708717 + shape: (1, 1024) +encoder.rnn_layers.0.cells.0.eltwiseadd_cell: + inputs: + 0: + min: -38.97710418701172 + max: 4.823007583618164 + avg_min: -4.562025848362994 + avg_max: 1.0867458244325479 + mean: -0.0052387909689317535 + std: 0.24397997338840743 + shape: (1, 1024) + 1: + min: -1.0 + max: 1.0 + avg_min: -0.9999414832227521 + avg_max: 0.9998500162423687 + mean: -0.004381999838699326 + std: 0.4680802328708717 + shape: (1, 1024) + output: + min: -39.97100830078125 + max: 5.8226799964904785 + avg_min: -5.376022642979372 + avg_max: 1.6904500271956118 + mean: -0.009620790795830529 + std: 0.5374170427844072 + shape: (1, 1024) +encoder.rnn_layers.0.cells.0.act_h: + inputs: + 0: + min: -39.97100830078125 + max: 5.8226799964904785 + avg_min: -5.376022642979372 + avg_max: 1.6904500271956118 + mean: -0.009620790795830529 + std: 0.5374170427844072 + shape: (1, 1024) + output: + min: -1.0 + max: 0.999982476234436 + avg_min: -0.9784979219253714 + avg_max: 0.9111029639984748 + mean: -0.0042562744052456435 + std: 0.3862437789687602 + shape: (1, 1024) +encoder.rnn_layers.0.cells.0.eltwisemult_hidden: + inputs: + 0: + min: 7.870648145337088e-16 + max: 1.0 + avg_min: 2.206335442308734e-07 + avg_max: 0.9999628556888489 + mean: 0.2618299175561274 + std: 0.3833761347669008 + shape: (1, 1024) + 1: + min: -1.0 + max: 0.999982476234436 + avg_min: -0.9784979219253714 + avg_max: 0.9111029639984748 + mean: -0.0042562744052456435 + std: 0.3862437789687602 + shape: (1, 1024) + output: + min: -1.0 + max: 0.9997552633285522 + avg_min: -0.957775045493643 + avg_max: 0.8561479863309971 + mean: -0.0026583670083781055 + std: 0.24113962918670837 + shape: (1, 1024) +encoder.rnn_layers.0.cells_reverse.0.fc_gate_x: + inputs: + 0: + min: -5.349642276763916 + max: 5.236245632171631 + avg_min: -2.828170560278106 + avg_max: 2.8138245348501405 + mean: -0.0016106524300209281 + std: 0.8805050697722259 + shape: (1, 1024) + output: + min: -33.94660949707031 + max: 31.446853637695312 + avg_min: -19.276340029144965 + avg_max: 18.700496405858367 + mean: -0.4969901222221274 + std: 5.666540069661909 + shape: (1, 4096) +encoder.rnn_layers.0.cells_reverse.0.fc_gate_h: + inputs: + 0: + min: -1.0 + max: 1.0 + avg_min: -0.9213716166421844 + avg_max: 0.9217871793855027 + mean: -0.0008159191441113994 + std: 
0.2933630896077876 + shape: (1, 1024) + output: + min: -26.57025146484375 + max: 27.723560333251953 + avg_min: -13.664213478780777 + avg_max: 13.256404327800057 + mean: -1.6347749916045144 + std: 3.3098913033147026 + shape: (1, 4096) +encoder.rnn_layers.0.cells_reverse.0.eltwiseadd_gate: + inputs: + 0: + min: -33.94660949707031 + max: 31.446853637695312 + avg_min: -19.276340029144965 + avg_max: 18.700496405858367 + mean: -0.4969901222221274 + std: 5.666540069661909 + shape: (1, 4096) + 1: + min: -26.57025146484375 + max: 27.723560333251953 + avg_min: -13.664213478780777 + avg_max: 13.256404327800057 + mean: -1.6347749916045144 + std: 3.3098913033147026 + shape: (1, 4096) + output: + min: -38.69415283203125 + max: 36.723724365234375 + avg_min: -23.688260936494675 + avg_max: 21.477238427920927 + mean: -2.131765110762214 + std: 6.632126340207978 + shape: (1, 4096) +encoder.rnn_layers.0.cells_reverse.0.act_f: + inputs: + 0: + min: -38.21221160888672 + max: 35.994014739990234 + avg_min: -21.775707706900725 + avg_max: 17.50440279642364 + mean: -3.8737916625820072 + std: 6.84193347529921 + shape: (1, 1024) + output: + min: 2.5389102872982082e-17 + max: 1.0 + avg_min: 1.0996477415504891e-07 + avg_max: 0.999999346159746 + mean: 0.2890327369231582 + std: 0.4023833394248511 + shape: (1, 1024) +encoder.rnn_layers.0.cells_reverse.0.act_i: + inputs: + 0: + min: -38.2786750793457 + max: 32.760406494140625 + avg_min: -21.206531206573878 + avg_max: 16.691689617168617 + mean: -3.073931181881033 + std: 6.2677184148953256 + shape: (1, 1024) + output: + min: 2.3756509732875972e-17 + max: 1.0 + avg_min: 1.7373452237735842e-08 + avg_max: 0.9999988585396135 + mean: 0.3108025884336769 + std: 0.4110379687299371 + shape: (1, 1024) +encoder.rnn_layers.0.cells_reverse.0.act_o: + inputs: + 0: + min: -38.69415283203125 + max: 33.763919830322266 + avg_min: -19.567437244926133 + avg_max: 17.492547388038638 + mean: -1.5878526072804218 + std: 6.0648371330715545 + shape: (1, 1024) + output: + min: 1.5679886799044933e-17 + max: 1.0 + avg_min: 7.394662946045065e-08 + avg_max: 0.9999994433218858 + mean: 0.39423968600374154 + std: 0.4279581642861166 + shape: (1, 1024) +encoder.rnn_layers.0.cells_reverse.0.act_g: + inputs: + 0: + min: -37.34138870239258 + max: 36.723724365234375 + avg_min: -20.704529739459048 + avg_max: 20.84152177737822 + mean: 0.008514990976159644 + std: 6.647000615879316 + shape: (1, 1024) + output: + min: -1.0 + max: 1.0 + avg_min: -1.0 + avg_max: 1.0 + mean: 0.0014180161330502351 + std: 0.9442339332189296 + shape: (1, 1024) +encoder.rnn_layers.0.cells_reverse.0.eltwisemult_cell_forget: + inputs: + 0: + min: 2.5389102872982082e-17 + max: 1.0 + avg_min: 1.0996477415504891e-07 + avg_max: 0.999999346159746 + mean: 0.2890327369231582 + std: 0.4023833394248511 + shape: (1, 1024) + 1: + min: -29.373083114624023 + max: 35.754878997802734 + avg_min: -3.36926156270743 + avg_max: 3.9063054220693965 + mean: -0.0008279334693150963 + std: 0.6366862423327254 + shape: (1, 1024) + output: + min: -29.373075485229492 + max: 34.76329040527344 + avg_min: -3.194934033077149 + avg_max: 3.729673293131146 + mean: 0.00015181826198886874 + std: 0.3949363630374722 + shape: (1, 1024) +encoder.rnn_layers.0.cells_reverse.0.eltwisemult_cell_input: + inputs: + 0: + min: 2.3756509732875972e-17 + max: 1.0 + avg_min: 1.7373452237735842e-08 + avg_max: 0.9999988585396135 + mean: 0.3108025884336769 + std: 0.4110379687299371 + shape: (1, 1024) + 1: + min: -1.0 + max: 1.0 + avg_min: -1.0 + avg_max: 1.0 + mean: 0.0014180161330502351 + std: 
0.9442339332189296 + shape: (1, 1024) + output: + min: -1.0 + max: 1.0 + avg_min: -0.999991173282417 + avg_max: 0.9999928301513564 + mean: -0.0005279613764160218 + std: 0.5003141965724145 + shape: (1, 1024) +encoder.rnn_layers.0.cells_reverse.0.eltwiseadd_cell: + inputs: + 0: + min: -29.373075485229492 + max: 34.76329040527344 + avg_min: -3.194934033077149 + avg_max: 3.729673293131146 + mean: 0.00015181826198886874 + std: 0.3949363630374722 + shape: (1, 1024) + 1: + min: -1.0 + max: 1.0 + avg_min: -0.999991173282417 + avg_max: 0.9999928301513564 + mean: -0.0005279613764160218 + std: 0.5003141965724145 + shape: (1, 1024) + output: + min: -29.373083114624023 + max: 35.754878997802734 + avg_min: -3.50784800743793 + avg_max: 4.100828251393332 + mean: -0.00037614313566861415 + std: 0.6483412291978414 + shape: (1, 1024) +encoder.rnn_layers.0.cells_reverse.0.act_h: + inputs: + 0: + min: -29.373083114624023 + max: 35.754878997802734 + avg_min: -3.50784800743793 + avg_max: 4.100828251393332 + mean: -0.00037614313566861415 + std: 0.6483412291978414 + shape: (1, 1024) + output: + min: -1.0 + max: 1.0 + avg_min: -0.9818865314740792 + avg_max: 0.9818776522137024 + mean: -0.0013970042477792044 + std: 0.45158473029410123 + shape: (1, 1024) +encoder.rnn_layers.0.cells_reverse.0.eltwisemult_hidden: + inputs: + 0: + min: 1.5679886799044933e-17 + max: 1.0 + avg_min: 7.394662946045065e-08 + avg_max: 0.9999994433218858 + mean: 0.39423968600374154 + std: 0.4279581642861166 + shape: (1, 1024) + 1: + min: -1.0 + max: 1.0 + avg_min: -0.9818865314740792 + avg_max: 0.9818776522137024 + mean: -0.0013970042477792044 + std: 0.45158473029410123 + shape: (1, 1024) + output: + min: -1.0 + max: 1.0 + avg_min: -0.9556271536847694 + avg_max: 0.9564541215110065 + mean: -0.0005957838928230832 + std: 0.2985633225360773 + shape: (1, 1024) +encoder.rnn_layers.0.dropout: + output: + min: 1.7976931348623157e+308 + max: -1.7976931348623157e+308 + avg_min: 0 + avg_max: 0 + mean: 0 + std: 0 + shape: '' +encoder.rnn_layers.1.cells.0.fc_gate_x: + inputs: + 0: + min: -1.0 + max: 1.0 + avg_min: -0.3797274232525334 + avg_max: 0.3723408187603086 + mean: -0.0006261473494000447 + std: 0.1684595854192406 + shape: (128, 2048) + output: + min: -28.018692016601562 + max: 26.689453125 + avg_min: -5.231459095807397 + avg_max: 4.8616450890998655 + mean: -0.5998025477032523 + std: 2.1475245530098204 + shape: (128, 4096) +encoder.rnn_layers.1.cells.0.fc_gate_h: + inputs: + 0: + min: -0.9999995231628418 + max: 0.9999927878379822 + avg_min: -0.9296687000683234 + avg_max: 0.9067187918312595 + mean: -0.00016366511202441184 + std: 0.17426053388600266 + shape: (128, 1024) + output: + min: -22.963897705078125 + max: 23.043453216552734 + avg_min: -7.452188115125264 + avg_max: 8.383283630046785 + mean: -0.9507504072732051 + std: 1.9114254594378783 + shape: (128, 4096) +encoder.rnn_layers.1.cells.0.eltwiseadd_gate: + inputs: + 0: + min: -28.018692016601562 + max: 26.689453125 + avg_min: -5.231459095807397 + avg_max: 4.8616450890998655 + mean: -0.5998025477032523 + std: 2.1475245530098204 + shape: (128, 4096) + 1: + min: -22.963897705078125 + max: 23.043453216552734 + avg_min: -7.452188115125264 + avg_max: 8.383283630046785 + mean: -0.9507504072732051 + std: 1.9114254594378783 + shape: (128, 4096) + output: + min: -31.574262619018555 + max: 32.37839126586914 + avg_min: -10.47135662152461 + avg_max: 10.518839543297785 + mean: -1.5505529552098096 + std: 3.1590192373764134 + shape: (128, 4096) +encoder.rnn_layers.1.cells.0.act_f: + inputs: + 0: + min: 
-31.574262619018555 + max: 25.064706802368164 + avg_min: -10.176423978218846 + avg_max: 8.090042744440224 + mean: -3.3116955970157718 + std: 3.4230677646590277 + shape: (128, 1024) + output: + min: 1.9385273957893065e-14 + max: 1.0 + avg_min: 0.0009497451324273016 + avg_max: 0.9984393432169701 + mean: 0.16464387024575694 + std: 0.2689671071958494 + shape: (128, 1024) +encoder.rnn_layers.1.cells.0.act_i: + inputs: + 0: + min: -29.14598274230957 + max: 28.608158111572266 + avg_min: -8.99660675317649 + avg_max: 6.228059029819189 + mean: -1.7129896603174661 + std: 2.7638664904974917 + shape: (128, 1024) + output: + min: 2.1981662367068222e-13 + max: 1.0 + avg_min: 0.0017008052480266586 + avg_max: 0.9786340831643511 + mean: 0.2868795179297321 + std: 0.27602027714229127 + shape: (128, 1024) +encoder.rnn_layers.1.cells.0.act_o: + inputs: + 0: + min: -25.69948959350586 + max: 27.677391052246094 + avg_min: -8.206474038311828 + avg_max: 9.042313837365008 + mean: -1.2052700564898609 + std: 2.6057113237143055 + shape: (128, 1024) + output: + min: 6.9000694914722605e-12 + max: 1.0 + avg_min: 0.0033538953499784242 + avg_max: 0.9974105607476547 + mean: 0.33604429936415675 + std: 0.2886479601419748 + shape: (128, 1024) +encoder.rnn_layers.1.cells.0.act_g: + inputs: + 0: + min: -28.827739715576172 + max: 32.37839126586914 + avg_min: -8.882986145798268 + avg_max: 8.65820987619275 + mean: 0.027743494574744162 + std: 2.8342722477735816 + shape: (128, 1024) + output: + min: -1.0 + max: 1.0 + avg_min: -0.9999114334316583 + avg_max: 0.9996711071705672 + mean: 0.015988719393334825 + std: 0.7484585498246807 + shape: (128, 1024) +encoder.rnn_layers.1.cells.0.eltwisemult_cell_forget: + inputs: + 0: + min: 1.9385273957893065e-14 + max: 1.0 + avg_min: 0.0009497451324273016 + avg_max: 0.9984393432169701 + mean: 0.16464387024575694 + std: 0.2689671071958494 + shape: (128, 1024) + 1: + min: -86.69337463378906 + max: 80.73381805419922 + avg_min: -12.103038570158164 + avg_max: 12.752722294005268 + mean: -0.023152882641702514 + std: 1.2572998992641964 + shape: (128, 1024) + output: + min: -86.67153930664062 + max: 80.55928802490234 + avg_min: -11.985433492784532 + avg_max: 12.574206966781782 + mean: -0.030583657865831754 + std: 1.2066110910522718 + shape: (128, 1024) +encoder.rnn_layers.1.cells.0.eltwisemult_cell_input: + inputs: + 0: + min: 2.1981662367068222e-13 + max: 1.0 + avg_min: 0.0017008052480266586 + avg_max: 0.9786340831643511 + mean: 0.2868795179297321 + std: 0.27602027714229127 + shape: (128, 1024) + 1: + min: -1.0 + max: 1.0 + avg_min: -0.9999114334316583 + avg_max: 0.9996711071705672 + mean: 0.015988719393334825 + std: 0.7484585498246807 + shape: (128, 1024) + output: + min: -1.0 + max: 1.0 + avg_min: -0.9046535779632451 + avg_max: 0.9683758158411752 + mean: 0.006315402761504094 + std: 0.3088953017599018 + shape: (128, 1024) +encoder.rnn_layers.1.cells.0.eltwiseadd_cell: + inputs: + 0: + min: -86.67153930664062 + max: 80.55928802490234 + avg_min: -11.985433492784532 + avg_max: 12.574206966781782 + mean: -0.030583657865831754 + std: 1.2066110910522718 + shape: (128, 1024) + 1: + min: -1.0 + max: 1.0 + avg_min: -0.9046535779632451 + avg_max: 0.9683758158411752 + mean: 0.006315402761504094 + std: 0.3088953017599018 + shape: (128, 1024) + output: + min: -87.62094116210938 + max: 81.55608367919922 + avg_min: -12.543978190728746 + avg_max: 13.23383954467387 + mean: -0.024268255073244073 + std: 1.2882582959130635 + shape: (128, 1024) +encoder.rnn_layers.1.cells.0.act_h: + inputs: + 0: + min: -87.62094116210938 + max: 
81.55608367919922 + avg_min: -12.543978190728746 + avg_max: 13.23383954467387 + mean: -0.024268255073244073 + std: 1.2882582959130635 + shape: (128, 1024) + output: + min: -1.0 + max: 1.0 + avg_min: -0.9870239480229829 + avg_max: 0.982027827093264 + mean: 0.005185764388532731 + std: 0.3119488266195677 + shape: (128, 1024) +encoder.rnn_layers.1.cells.0.eltwisemult_hidden: + inputs: + 0: + min: 6.9000694914722605e-12 + max: 1.0 + avg_min: 0.0033538953499784242 + avg_max: 0.9974105607476547 + mean: 0.33604429936415675 + std: 0.2886479601419748 + shape: (128, 1024) + 1: + min: -1.0 + max: 1.0 + avg_min: -0.9870239480229829 + avg_max: 0.982027827093264 + mean: 0.005185764388532731 + std: 0.3119488266195677 + shape: (128, 1024) + output: + min: -0.9999995231628418 + max: 0.9999927878379822 + avg_min: -0.9430685600378369 + avg_max: 0.9199835463058242 + mean: -0.00017594442323635664 + std: 0.17510466702537447 + shape: (128, 1024) +encoder.rnn_layers.1.dropout: + output: + min: 1.7976931348623157e+308 + max: -1.7976931348623157e+308 + avg_min: 0 + avg_max: 0 + mean: 0 + std: 0 + shape: '' +encoder.rnn_layers.2.cells.0.fc_gate_x: + inputs: + 0: + min: -0.9999995231628418 + max: 0.9999927878379822 + avg_min: -0.9430685600378369 + avg_max: 0.9199835463058242 + mean: -0.00017594442323635664 + std: 0.17510466702537447 + shape: (128, 1024) + output: + min: -17.669666290283203 + max: 18.855335235595703 + avg_min: -5.826010705240624 + avg_max: 5.579422310161379 + mean: -0.7621265583790381 + std: 1.6214516118322957 + shape: (128, 4096) +encoder.rnn_layers.2.cells.0.fc_gate_h: + inputs: + 0: + min: -0.9999260902404785 + max: 0.999937891960144 + avg_min: -0.9288329214057661 + avg_max: 0.8692965302824718 + mean: -0.002193973255493973 + std: 0.13664756692926694 + shape: (128, 1024) + output: + min: -12.899184226989746 + max: 12.18427848815918 + avg_min: -5.383265833059953 + avg_max: 6.135640713992537 + mean: -0.7798955419852971 + std: 1.2988276201978577 + shape: (128, 4096) +encoder.rnn_layers.2.cells.0.eltwiseadd_gate: + inputs: + 0: + min: -17.669666290283203 + max: 18.855335235595703 + avg_min: -5.826010705240624 + avg_max: 5.579422310161379 + mean: -0.7621265583790381 + std: 1.6214516118322957 + shape: (128, 4096) + 1: + min: -12.899184226989746 + max: 12.18427848815918 + avg_min: -5.383265833059953 + avg_max: 6.135640713992537 + mean: -0.7798955419852971 + std: 1.2988276201978577 + shape: (128, 4096) + output: + min: -22.86528968811035 + max: 22.331012725830078 + avg_min: -9.460085806430587 + avg_max: 10.091681053441096 + mean: -1.5420221014644224 + std: 2.4497739415101045 + shape: (128, 4096) +encoder.rnn_layers.2.cells.0.act_f: + inputs: + 0: + min: -22.86528968811035 + max: 20.687637329101562 + avg_min: -8.921805253498242 + avg_max: 9.76669416798307 + mean: -1.7373663354206654 + std: 2.701828662299363 + shape: (128, 1024) + output: + min: 1.1741696503975163e-10 + max: 1.0 + avg_min: 0.0011454785276338222 + avg_max: 0.9971977971010819 + mean: 0.2779621332383798 + std: 0.3120235028974232 + shape: (128, 1024) +encoder.rnn_layers.2.cells.0.act_i: + inputs: + 0: + min: -19.433446884155273 + max: 22.331012725830078 + avg_min: -9.11351156621439 + avg_max: 5.007397494993483 + mean: -2.482964531553934 + std: 2.096762127403773 + shape: (128, 1024) + output: + min: 3.632128597885753e-09 + max: 1.0 + avg_min: 0.0005238509440227943 + avg_max: 0.9815354767928449 + mean: 0.17026146355457997 + std: 0.22359512877822252 + shape: (128, 1024) +encoder.rnn_layers.2.cells.0.act_o: + inputs: + 0: + min: -20.721616744995117 + 
max: 21.40572166442871 + avg_min: -7.591665619305041 + avg_max: 6.900344703221481 + mean: -1.9338501979173957 + std: 2.188371648150393 + shape: (128, 1024) + output: + min: 1.0016504292664763e-09 + max: 1.0 + avg_min: 0.0014563385881592266 + avg_max: 0.996341593846912 + mean: 0.2302334359294881 + std: 0.2644176532615949 + shape: (128, 1024) +encoder.rnn_layers.2.cells.0.act_g: + inputs: + 0: + min: -20.245834350585938 + max: 17.422372817993164 + avg_min: -7.094027656303417 + avg_max: 6.198857292499591 + mean: -0.013907334995555003 + std: 2.0268158859398273 + shape: (128, 1024) + output: + min: -1.0 + max: 1.0 + avg_min: -0.9999658937755573 + avg_max: 0.999913128430412 + mean: 0.006783831941798703 + std: 0.7667732969711837 + shape: (128, 1024) +encoder.rnn_layers.2.cells.0.eltwisemult_cell_forget: + inputs: + 0: + min: 1.1741696503975163e-10 + max: 1.0 + avg_min: 0.0011454785276338222 + avg_max: 0.9971977971010819 + mean: 0.2779621332383798 + std: 0.3120235028974232 + shape: (128, 1024) + 1: + min: -70.65399169921875 + max: 79.07190704345703 + avg_min: -9.986505569694323 + avg_max: 11.433578196754656 + mean: -0.007533524133234033 + std: 0.9619221798317081 + shape: (128, 1024) + output: + min: -70.47572326660156 + max: 79.0152816772461 + avg_min: -9.786985937847655 + avg_max: 11.29573609329911 + mean: -0.0077121059270617645 + std: 0.9119806954490677 + shape: (128, 1024) +encoder.rnn_layers.2.cells.0.eltwisemult_cell_input: + inputs: + 0: + min: 3.632128597885753e-09 + max: 1.0 + avg_min: 0.0005238509440227943 + avg_max: 0.9815354767928449 + mean: 0.17026146355457997 + std: 0.22359512877822252 + shape: (128, 1024) + 1: + min: -1.0 + max: 1.0 + avg_min: -0.9999658937755573 + avg_max: 0.999913128430412 + mean: 0.006783831941798703 + std: 0.7667732969711837 + shape: (128, 1024) + output: + min: -1.0 + max: 1.0 + avg_min: -0.9706906902643418 + avg_max: 0.9338277108983972 + mean: 0.00031330419493738533 + std: 0.23285220882823307 + shape: (128, 1024) +encoder.rnn_layers.2.cells.0.eltwiseadd_cell: + inputs: + 0: + min: -70.47572326660156 + max: 79.0152816772461 + avg_min: -9.786985937847655 + avg_max: 11.29573609329911 + mean: -0.0077121059270617645 + std: 0.9119806954490677 + shape: (128, 1024) + 1: + min: -1.0 + max: 1.0 + avg_min: -0.9706906902643418 + avg_max: 0.9338277108983972 + mean: 0.00031330419493738533 + std: 0.23285220882823307 + shape: (128, 1024) + output: + min: -71.34825897216797 + max: 79.90394592285156 + avg_min: -10.283224373522492 + avg_max: 11.882522354690009 + mean: -0.007398801722887531 + std: 0.9852414934516581 + shape: (128, 1024) +encoder.rnn_layers.2.cells.0.act_h: + inputs: + 0: + min: -71.34825897216797 + max: 79.90394592285156 + avg_min: -10.283224373522492 + avg_max: 11.882522354690009 + mean: -0.007398801722887531 + std: 0.9852414934516581 + shape: (128, 1024) + output: + min: -1.0 + max: 1.0 + avg_min: -0.9912044704627142 + avg_max: 0.9643497033340547 + mean: -0.0026311075007701672 + std: 0.2839022589840224 + shape: (128, 1024) +encoder.rnn_layers.2.cells.0.eltwisemult_hidden: + inputs: + 0: + min: 1.0016504292664763e-09 + max: 1.0 + avg_min: 0.0014563385881592266 + avg_max: 0.996341593846912 + mean: 0.2302334359294881 + std: 0.2644176532615949 + shape: (128, 1024) + 1: + min: -1.0 + max: 1.0 + avg_min: -0.9912044704627142 + avg_max: 0.9643497033340547 + mean: -0.0026311075007701672 + std: 0.2839022589840224 + shape: (128, 1024) + output: + min: -0.9999260902404785 + max: 0.999937891960144 + avg_min: -0.9419258169546493 + avg_max: 0.8823925270396857 + mean: 
-0.002203096514809122 + std: 0.1377186248199642 + shape: (128, 1024) +encoder.rnn_layers.2.dropout: + output: + min: 1.7976931348623157e+308 + max: -1.7976931348623157e+308 + avg_min: 0 + avg_max: 0 + mean: 0 + std: 0 + shape: '' +encoder.rnn_layers.3.cells.0.fc_gate_x: + inputs: + 0: + min: -1.999041199684143 + max: 1.996099829673767 + avg_min: -1.4410850684714778 + avg_max: 1.4346018574915211 + mean: -0.0023790409310782876 + std: 0.23082689970193895 + shape: (128, 1024) + output: + min: -18.958541870117188 + max: 20.005878448486328 + avg_min: -6.6773925654306705 + avg_max: 7.386176010906286 + mean: -0.7425925570313029 + std: 1.9741718173439569 + shape: (128, 4096) +encoder.rnn_layers.3.cells.0.fc_gate_h: + inputs: + 0: + min: -0.9999996423721313 + max: 0.9999986290931702 + avg_min: -0.9598443289464481 + avg_max: 0.9680624933584159 + mean: -5.464283497081485e-05 + std: 0.1605789723297388 + shape: (128, 1024) + output: + min: -17.731414794921875 + max: 16.26653289794922 + avg_min: -7.986687395696665 + avg_max: 8.483382081558762 + mean: -1.0937484681389607 + std: 1.8863347454511785 + shape: (128, 4096) +encoder.rnn_layers.3.cells.0.eltwiseadd_gate: + inputs: + 0: + min: -18.958541870117188 + max: 20.005878448486328 + avg_min: -6.6773925654306705 + avg_max: 7.386176010906286 + mean: -0.7425925570313029 + std: 1.9741718173439569 + shape: (128, 4096) + 1: + min: -17.731414794921875 + max: 16.26653289794922 + avg_min: -7.986687395696665 + avg_max: 8.483382081558762 + mean: -1.0937484681389607 + std: 1.8863347454511785 + shape: (128, 4096) + output: + min: -25.56599998474121 + max: 24.13156509399414 + avg_min: -12.193620794438145 + avg_max: 13.726210550973875 + mean: -1.8363410236118098 + std: 3.220885522048646 + shape: (128, 4096) +encoder.rnn_layers.3.cells.0.act_f: + inputs: + 0: + min: -25.56599998474121 + max: 24.13156509399414 + avg_min: -10.901736432407242 + avg_max: 11.459924239973608 + mean: -1.7886736150289273 + std: 3.611997766132868 + shape: (128, 1024) + output: + min: 7.885464850532209e-12 + max: 1.0 + avg_min: 0.001426324373246965 + avg_max: 0.9981180747150038 + mean: 0.3085659271068626 + std: 0.362517692693201 + shape: (128, 1024) +encoder.rnn_layers.3.cells.0.act_i: + inputs: + 0: + min: -21.904094696044922 + max: 19.92990493774414 + avg_min: -10.301701992016742 + avg_max: 5.387517666390012 + mean: -3.4113261673007766 + std: 2.452103517547047 + shape: (128, 1024) + output: + min: 3.070241560987341e-10 + max: 1.0 + avg_min: 0.0005902890219026971 + avg_max: 0.9767627610722911 + mean: 0.1199813902886455 + std: 0.20597901242826788 + shape: (128, 1024) +encoder.rnn_layers.3.cells.0.act_o: + inputs: + 0: + min: -21.795564651489258 + max: 20.31422233581543 + avg_min: -9.119528548429464 + avg_max: 12.915865638912122 + mean: -2.1100439747307944 + std: 2.7725332585946427 + shape: (128, 1024) + output: + min: 3.422208905146107e-10 + max: 1.0 + avg_min: 0.0007607025358741477 + avg_max: 0.9998594416467933 + mean: 0.23564658800694233 + std: 0.2900290182410238 + shape: (128, 1024) +encoder.rnn_layers.3.cells.0.act_g: + inputs: + 0: + min: -22.5491943359375 + max: 21.174640655517578 + avg_min: -11.443979141296161 + avg_max: 10.109944212490001 + mean: -0.035320331854928055 + std: 2.991404617719586 + shape: (128, 1024) + output: + min: -1.0 + max: 1.0 + avg_min: -0.9999641429391214 + avg_max: 0.9999893591374617 + mean: -0.017447813011902232 + std: 0.8383612744760685 + shape: (128, 1024) +encoder.rnn_layers.3.cells.0.eltwisemult_cell_forget: + inputs: + 0: + min: 7.885464850532209e-12 + max: 1.0 + 
avg_min: 0.001426324373246965 + avg_max: 0.9981180747150038 + mean: 0.3085659271068626 + std: 0.362517692693201 + shape: (128, 1024) + 1: + min: -72.4426498413086 + max: 66.09469604492188 + avg_min: -13.666880435458225 + avg_max: 12.621639567200226 + mean: 0.001298973259466056 + std: 1.3620516853160491 + shape: (128, 1024) + output: + min: -72.35834503173828 + max: 66.04717254638672 + avg_min: -13.53175763829178 + avg_max: 12.483070900776225 + mean: -0.0005504995357517113 + std: 1.3219009414904979 + shape: (128, 1024) +encoder.rnn_layers.3.cells.0.eltwisemult_cell_input: + inputs: + 0: + min: 3.070241560987341e-10 + max: 1.0 + avg_min: 0.0005902890219026971 + avg_max: 0.9767627610722911 + mean: 0.1199813902886455 + std: 0.20597901242826788 + shape: (128, 1024) + 1: + min: -1.0 + max: 1.0 + avg_min: -0.9999641429391214 + avg_max: 0.9999893591374617 + mean: -0.017447813011902232 + std: 0.8383612744760685 + shape: (128, 1024) + output: + min: -1.0 + max: 1.0 + avg_min: -0.9390226983917377 + avg_max: 0.9557934822118774 + mean: 0.0013386550747985504 + std: 0.21654626606567212 + shape: (128, 1024) +encoder.rnn_layers.3.cells.0.eltwiseadd_cell: + inputs: + 0: + min: -72.35834503173828 + max: 66.04717254638672 + avg_min: -13.53175763829178 + avg_max: 12.483070900776225 + mean: -0.0005504995357517113 + std: 1.3219009414904979 + shape: (128, 1024) + 1: + min: -1.0 + max: 1.0 + avg_min: -0.9390226983917377 + avg_max: 0.9557934822118774 + mean: 0.0013386550747985504 + std: 0.21654626606567212 + shape: (128, 1024) + output: + min: -73.15592956542969 + max: 66.85343170166016 + avg_min: -14.088761083898394 + avg_max: 12.986537455179008 + mean: 0.0007881555251539649 + std: 1.3934345367324847 + shape: (128, 1024) +encoder.rnn_layers.3.cells.0.act_h: + inputs: + 0: + min: -73.15592956542969 + max: 66.85343170166016 + avg_min: -14.088761083898394 + avg_max: 12.986537455179008 + mean: 0.0007881555251539649 + std: 1.3934345367324847 + shape: (128, 1024) + output: + min: -1.0 + max: 1.0 + avg_min: -0.9939409935101039 + avg_max: 0.9954337507019633 + mean: 0.005215724385428704 + std: 0.3152480532527837 + shape: (128, 1024) +encoder.rnn_layers.3.cells.0.eltwisemult_hidden: + inputs: + 0: + min: 3.422208905146107e-10 + max: 1.0 + avg_min: 0.0007607025358741477 + avg_max: 0.9998594416467933 + mean: 0.23564658800694233 + std: 0.2900290182410238 + shape: (128, 1024) + 1: + min: -1.0 + max: 1.0 + avg_min: -0.9939409935101039 + avg_max: 0.9954337507019633 + mean: 0.005215724385428704 + std: 0.3152480532527837 + shape: (128, 1024) + output: + min: -0.9999996423721313 + max: 0.9999986290931702 + avg_min: -0.973251596759896 + avg_max: 0.981412514344157 + mean: -8.919282888996098e-05 + std: 0.16194114228528503 + shape: (128, 1024) +encoder.rnn_layers.3.dropout: + output: + min: 1.7976931348623157e+308 + max: -1.7976931348623157e+308 + avg_min: 0 + avg_max: 0 + mean: 0 + std: 0 + shape: '' +encoder.dropout: + inputs: + 0: + min: -1.999041199684143 + max: 1.996099829673767 + avg_min: -1.2932450498143837 + avg_max: 1.2545068023933308 + mean: -0.0010675477744067996 + std: 0.19403692465489641 + shape: (128, 67, 2048) + output: + min: -1.999041199684143 + max: 1.996099829673767 + avg_min: -1.2932450498143837 + avg_max: 1.2545068023933308 + mean: -0.0010675477744067996 + std: 0.19403692465489641 + shape: (128, 67, 2048) +encoder.embedder: + inputs: + 0: + min: 0 + max: 32103 + avg_min: 3.2019918366436153 + avg_max: 24950.48025676814 + mean: 2937.1528870343564 + std: 6014.0608835026405 + shape: (128, 67) + output: + min: 
-5.349642276763916 + max: 5.236245632171631 + avg_min: -2.757863937090343 + avg_max: 2.8868405888984556 + mean: 0.0012734194185268903 + std: 0.8584338496588495 + shape: (128, 67, 1024) +encoder.eltwiseadd_residuals.0: + inputs: + 0: + min: -0.9999260902404785 + max: 0.999937891960144 + avg_min: -0.9973272457718849 + avg_max: 0.9888223583499591 + mean: -0.0022379238847255083 + std: 0.13743977474031957 + shape: (128, 67, 1024) + 1: + min: -0.9999995231628418 + max: 0.9999927878379822 + avg_min: -0.9997136642535528 + avg_max: 0.995096003015836 + mean: -0.0001647915998243358 + std: 0.17545655249766362 + shape: (128, 67, 1024) + output: + min: -1.999041199684143 + max: 1.996099829673767 + avg_min: -1.8801900794108708 + avg_max: 1.7707290103038151 + mean: -0.0024027154916742197 + std: 0.2309446271021139 + shape: (128, 67, 1024) +encoder.eltwiseadd_residuals.1: + inputs: + 0: + min: -0.9999996423721313 + max: 0.9999986290931702 + avg_min: -0.9991549079616864 + avg_max: 0.9982590352495511 + mean: -4.938640427099017e-05 + std: 0.16142380765711173 + shape: (128, 67, 1024) + 1: + min: -1.999041199684143 + max: 1.996099829673767 + avg_min: -1.8801900794108708 + avg_max: 1.7707290103038151 + mean: -0.0024027154916742197 + std: 0.2309446271021139 + shape: (128, 67, 1024) + output: + min: -2.9931235313415527 + max: 2.9921934604644775 + avg_min: -2.454954892396927 + avg_max: 2.5017193754514055 + mean: -0.0024521019076928496 + std: 0.2936707415080642 + shape: (128, 67, 1024) +decoder.att_rnn.rnn.cells.0.fc_gate_x: + inputs: + 0: + min: -5.343320369720459 + max: 5.05069637298584 + avg_min: -2.7414190683472004 + avg_max: 2.8718470222494563 + mean: 0.0012989030032717187 + std: 0.8618320366411912 + shape: (1280, 1024) + output: + min: -32.593318939208984 + max: 33.264034271240234 + avg_min: -19.51530230447147 + avg_max: 18.39185942424816 + mean: -0.319708395848218 + std: 5.248436831320663 + shape: (1280, 4096) +decoder.att_rnn.rnn.cells.0.fc_gate_h: + inputs: + 0: + min: -1.0 + max: 1.0 + avg_min: -0.9713709625970117 + avg_max: 0.9817781538775797 + mean: 0.007541471822687429 + std: 0.3220154108800773 + shape: (1280, 1024) + output: + min: -35.26850891113281 + max: 38.68207550048828 + avg_min: -19.4308424914486 + avg_max: 22.16190616622404 + mean: -1.9407802034043842 + std: 4.835737332097758 + shape: (1280, 4096) +decoder.att_rnn.rnn.cells.0.eltwiseadd_gate: + inputs: + 0: + min: -32.593318939208984 + max: 33.264034271240234 + avg_min: -19.51530230447147 + avg_max: 18.39185942424816 + mean: -0.319708395848218 + std: 5.248436831320663 + shape: (1280, 4096) + 1: + min: -35.26850891113281 + max: 38.68207550048828 + avg_min: -19.4308424914486 + avg_max: 22.16190616622404 + mean: -1.9407802034043842 + std: 4.835737332097758 + shape: (1280, 4096) + output: + min: -45.18857192993164 + max: 48.00703430175781 + avg_min: -26.367431187897573 + avg_max: 27.033264856660004 + mean: -2.260488603289206 + std: 7.086690069945273 + shape: (1280, 4096) +decoder.att_rnn.rnn.cells.0.act_f: + inputs: + 0: + min: -42.005043029785156 + max: 48.00703430175781 + avg_min: -24.572684604666218 + avg_max: 24.47041631977212 + mean: -5.428667449415389 + std: 7.5857557123719195 + shape: (1280, 1024) + output: + min: 5.720600170657743e-19 + max: 1.0 + avg_min: 1.3068073496415576e-06 + avg_max: 0.9999987296174087 + mean: 0.22353376293282823 + std: 0.37551881742884263 + shape: (1280, 1024) +decoder.att_rnn.rnn.cells.0.act_i: + inputs: + 0: + min: -42.485984802246094 + max: 37.89415740966797 + avg_min: -20.737807757666975 + avg_max: 
17.757313247745003 + mean: -2.6210182911726867 + std: 6.095373006020184 + shape: (1280, 1024) + output: + min: 3.5364804280239927e-19 + max: 1.0 + avg_min: 2.9600880243534555e-08 + avg_max: 0.9999995308980515 + mean: 0.330330826534649 + std: 0.4106618580468271 + shape: (1280, 1024) +decoder.att_rnn.rnn.cells.0.act_o: + inputs: + 0: + min: -38.675228118896484 + max: 45.59527587890625 + avg_min: -19.614914631039895 + avg_max: 22.38339533752272 + mean: -1.0162639051926945 + std: 6.201834056672583 + shape: (1280, 1024) + output: + min: 1.5979450257881012e-17 + max: 1.0 + avg_min: 7.854096923179388e-08 + avg_max: 0.9999999956133669 + mean: 0.4198552611838568 + std: 0.4295806405099903 + shape: (1280, 1024) +decoder.att_rnn.rnn.cells.0.act_g: + inputs: + 0: + min: -45.18857192993164 + max: 45.650177001953125 + avg_min: -24.537228854318712 + avg_max: 25.075750009129578 + mean: 0.023995243424084333 + std: 7.12668924216872 + shape: (1280, 1024) + output: + min: -1.0 + max: 1.0 + avg_min: -0.999999995680337 + avg_max: 0.999999995680337 + mean: -0.002128829708629359 + std: 0.9478563806555594 + shape: (1280, 1024) +decoder.att_rnn.rnn.cells.0.eltwisemult_cell_forget: + inputs: + 0: + min: 5.720600170657743e-19 + max: 1.0 + avg_min: 1.3068073496415576e-06 + avg_max: 0.9999987296174087 + mean: 0.22353376293282823 + std: 0.37551881742884263 + shape: (1280, 1024) + 1: + min: -45.964332580566406 + max: 74.85725402832031 + avg_min: -8.37450361519716 + avg_max: 30.634137733129954 + mean: 0.06729932622641525 + std: 1.6449935154656605 + shape: (1280, 1024) + output: + min: -45.9624137878418 + max: 74.85725402832031 + avg_min: -8.23528243631459 + avg_max: 30.575408138686335 + mean: 0.06604976536836245 + std: 1.5455946584670446 + shape: (1280, 1024) +decoder.att_rnn.rnn.cells.0.eltwisemult_cell_input: + inputs: + 0: + min: 3.5364804280239927e-19 + max: 1.0 + avg_min: 2.9600880243534555e-08 + avg_max: 0.9999995308980515 + mean: 0.330330826534649 + std: 0.4106618580468271 + shape: (1280, 1024) + 1: + min: -1.0 + max: 1.0 + avg_min: -0.999999995680337 + avg_max: 0.999999995680337 + mean: -0.002128829708629359 + std: 0.9478563806555594 + shape: (1280, 1024) + output: + min: -1.0 + max: 1.0 + avg_min: -0.9999764687894426 + avg_max: 0.9999969378950895 + mean: 0.0027081930132716804 + std: 0.510135969194079 + shape: (1280, 1024) +decoder.att_rnn.rnn.cells.0.eltwiseadd_cell: + inputs: + 0: + min: -45.9624137878418 + max: 74.85725402832031 + avg_min: -8.23528243631459 + avg_max: 30.575408138686335 + mean: 0.06604976536836245 + std: 1.5455946584670446 + shape: (1280, 1024) + 1: + min: -1.0 + max: 1.0 + avg_min: -0.9999764687894426 + avg_max: 0.9999969378950895 + mean: 0.0027081930132716804 + std: 0.510135969194079 + shape: (1280, 1024) + output: + min: -46.92375183105469 + max: 75.85040283203125 + avg_min: -8.574188237645677 + avg_max: 31.416855715031076 + mean: 0.06875795865061662 + std: 1.676732644640398 + shape: (1280, 1024) +decoder.att_rnn.rnn.cells.0.act_h: + inputs: + 0: + min: -46.92375183105469 + max: 75.85040283203125 + avg_min: -8.574188237645677 + avg_max: 31.416855715031076 + mean: 0.06875795865061662 + std: 1.676732644640398 + shape: (1280, 1024) + output: + min: -1.0 + max: 1.0 + avg_min: -0.9937744413533918 + avg_max: 0.9959949824247473 + mean: 0.0061799199947633445 + std: 0.47477836896857967 + shape: (1280, 1024) +decoder.att_rnn.rnn.cells.0.eltwisemult_hidden: + inputs: + 0: + min: 1.5979450257881012e-17 + max: 1.0 + avg_min: 7.854096923179388e-08 + avg_max: 0.9999999956133669 + mean: 0.4198552611838568 + 
std: 0.4295806405099903 + shape: (1280, 1024) + 1: + min: -1.0 + max: 1.0 + avg_min: -0.9937744413533918 + avg_max: 0.9959949824247473 + mean: 0.0061799199947633445 + std: 0.47477836896857967 + shape: (1280, 1024) + output: + min: -1.0 + max: 1.0 + avg_min: -0.9846230725224109 + avg_max: 0.9952872082423628 + mean: 0.007966578123805529 + std: 0.32187307320175484 + shape: (1280, 1024) +decoder.att_rnn.rnn.dropout: + output: + min: 1.7976931348623157e+308 + max: -1.7976931348623157e+308 + avg_min: 0 + avg_max: 0 + mean: 0 + std: 0 + shape: '' +decoder.att_rnn.attn.linear_q: + inputs: + 0: + min: -1.0 + max: 1.0 + avg_min: -0.9846230725224109 + avg_max: 0.9952872082423628 + mean: 0.007966578123805529 + std: 0.32187307320175484 + shape: (1280, 1, 1024) + output: + min: -28.324504852294922 + max: 30.094812393188477 + avg_min: -16.16285647060098 + avg_max: 16.370824370223502 + mean: 0.07139282220419639 + std: 5.614923564767698 + shape: (1280, 1, 1024) +decoder.att_rnn.attn.linear_k: + inputs: + 0: + min: -2.9931235313415527 + max: 2.9921934604644775 + avg_min: -2.5274831603082375 + avg_max: 2.571398892563376 + mean: -0.002407128676685295 + std: 0.3183030856867703 + shape: (1280, 67, 1024) + output: + min: -27.933134078979492 + max: 27.290098190307617 + avg_min: -21.23075208556785 + avg_max: 21.645756855439604 + mean: 0.11102438071386868 + std: 4.714025089771699 + shape: (1280, 67, 1024) +decoder.att_rnn.attn.dropout: + inputs: + 0: + min: 0.0 + max: 0.9910882711410522 + avg_min: 1.299117331064688e-05 + avg_max: 0.39005828052340624 + mean: 0.013619308356258466 + std: 0.05484342060778131 + shape: (1280, 1, 67) + output: + min: 0.0 + max: 0.9910882711410522 + avg_min: 1.299117331064688e-05 + avg_max: 0.39005828052340624 + mean: 0.013619308356258466 + std: 0.05484342060778131 + shape: (1280, 1, 67) +decoder.att_rnn.attn.eltwiseadd_qk: + inputs: + 0: + min: -28.324504852294922 + max: 30.094812393188477 + avg_min: -16.16285647060098 + avg_max: 16.370824370223502 + mean: 0.07139282138405703 + std: 5.614948532126775 + shape: (1280, 1, 67, 1024) + 1: + min: -27.933134078979492 + max: 27.290098190307617 + avg_min: -21.23075208556785 + avg_max: 21.645756855439604 + mean: 0.11102438071386868 + std: 4.714025089771699 + shape: (1280, 1, 67, 1024) + output: + min: -42.6893310546875 + max: 41.46746063232422 + avg_min: -26.125514688384673 + avg_max: 26.724065986376143 + mean: 0.18241720259650032 + std: 6.144074303575077 + shape: (1280, 1, 67, 1024) +decoder.att_rnn.attn.eltwiseadd_norm_bias: + inputs: + 0: + min: -42.6893310546875 + max: 41.46746063232422 + avg_min: -26.125514688384673 + avg_max: 26.724065986376143 + mean: 0.18241720259650032 + std: 6.144074303575077 + shape: (1280, 1, 67, 1024) + 1: + min: -1.159269094467163 + max: 1.030885934829712 + avg_min: -1.159269094467163 + avg_max: 1.030885934829712 + mean: 0.01333538442850113 + std: 0.3476333932113675 + shape: (1024) + output: + min: -42.8883171081543 + max: 41.60582733154297 + avg_min: -26.312186408310776 + avg_max: 26.855665168333605 + mean: 0.19575258549512087 + std: 6.287847745163477 + shape: (1280, 1, 67, 1024) +decoder.att_rnn.attn.eltwisemul_norm_scaler: + inputs: + 0: + min: -0.164537250995636 + max: 0.09800389409065247 + avg_min: -0.164537250995636 + avg_max: 0.09800389409065247 + mean: 0.00032652184017933905 + std: 0.031248209015367 + shape: (1024) + 1: + min: 1.2628638744354248 + max: 1.2628638744354248 + avg_min: 1.2628638744354248 + avg_max: 1.2628638744354248 + mean: 1.2628638744354248 + std: .nan + shape: (1) + output: + min: 
-0.2077881544828415 + max: 0.12376558035612106 + avg_min: -0.2077881544828415 + avg_max: 0.12376558035612106 + mean: 0.0004123526159673929 + std: 0.03946245597745369 + shape: (1024) +decoder.att_rnn.attn.tanh: + inputs: + 0: + min: -42.8883171081543 + max: 41.60582733154297 + avg_min: -26.312186408310776 + avg_max: 26.855665168333605 + mean: 0.19575258549512087 + std: 6.287847745163477 + shape: (1280, 1, 67, 1024) + output: + min: -1.0 + max: 1.0 + avg_min: -0.999999995680337 + avg_max: 0.999999995680337 + mean: 0.03299323406470721 + std: 0.9384829454472874 + shape: (1280, 1, 67, 1024) +decoder.att_rnn.attn.matmul_score: + inputs: + 0: + min: -1.0 + max: 1.0 + avg_min: -0.999999995680337 + avg_max: 0.999999995680337 + mean: 0.03299323406470721 + std: 0.9384829454472874 + shape: (1280, 1, 67, 1024) + 1: + min: -0.2077881544828415 + max: 0.12376558035612106 + avg_min: -0.2077881544828415 + avg_max: 0.12376558035612106 + mean: 0.0004123526159673929 + std: 0.03946245597745369 + shape: (1024) + output: + min: -5.075064182281494 + max: 20.071443557739258 + avg_min: 4.662021810703742 + avg_max: 12.197041146406967 + mean: 7.974570217477471 + std: 2.9091675582996563 + shape: (1280, 1, 67) +decoder.att_rnn.attn.softmax_att: + inputs: + 0: + min: -65504.0 + max: 20.071443557739258 + avg_min: -58239.877876515005 + avg_max: 12.163362407550382 + mean: -24693.233116322437 + std: 31749.93291331495 + shape: (1280, 1, 67) + output: + min: 0.0 + max: 0.9910882711410522 + avg_min: 1.299117331064688e-05 + avg_max: 0.39005828052340624 + mean: 0.013619308356258466 + std: 0.05484342060778131 + shape: (1280, 1, 67) +decoder.att_rnn.attn.context_matmul: + inputs: + 0: + min: 0.0 + max: 0.9910882711410522 + avg_min: 1.299117331064688e-05 + avg_max: 0.39005828052340624 + mean: 0.013619308356258466 + std: 0.05484342060778131 + shape: (1280, 1, 67) + 1: + min: -2.9931235313415527 + max: 2.9921934604644775 + avg_min: -2.5274831603082375 + avg_max: 2.571398892563376 + mean: -0.002407128676685295 + std: 0.3183030856867703 + shape: (1280, 67, 1024) + output: + min: -2.748584032058716 + max: 2.7187154293060303 + avg_min: -1.0244215909684593 + avg_max: 0.9287193868937127 + mean: 0.0004324376445557447 + std: 0.16600608798290178 + shape: (1280, 1, 1024) +decoder.att_rnn.dropout: + inputs: + 0: + min: -1.0 + max: 1.0 + avg_min: -0.9846230725224109 + avg_max: 0.9952872082423628 + mean: 0.007966578123805529 + std: 0.32187307320175484 + shape: (1280, 1, 1024) + output: + min: -1.0 + max: 1.0 + avg_min: -0.9846230725224109 + avg_max: 0.9952872082423628 + mean: 0.007966578123805529 + std: 0.32187307320175484 + shape: (1280, 1, 1024) +decoder.rnn_layers.0.cells.0.fc_gate_x: + inputs: + 0: + min: -2.748584032058716 + max: 2.7187154293060303 + avg_min: -1.200279578834437 + avg_max: 1.1144092399752523 + mean: 0.004199507878490362 + std: 0.2561140031003674 + shape: (1280, 2048) + output: + min: -31.11646270751953 + max: 31.012344360351562 + avg_min: -11.667031159561684 + avg_max: 10.85241473739067 + mean: -1.1740948434960983 + std: 3.0091560830798434 + shape: (1280, 4096) +decoder.rnn_layers.0.cells.0.fc_gate_h: + inputs: + 0: + min: -0.9999914765357971 + max: 0.9999936819076538 + avg_min: -0.7896906179993346 + avg_max: 0.7587759274110363 + mean: -0.0010008882399605407 + std: 0.1321120655297924 + shape: (1280, 1024) + output: + min: -16.805160522460938 + max: 22.041872024536133 + avg_min: -5.688125228613951 + avg_max: 5.603913084911495 + mean: -0.49282410457060527 + std: 1.2749332466602463 + shape: (1280, 4096) 
+decoder.rnn_layers.0.cells.0.eltwiseadd_gate: + inputs: + 0: + min: -31.11646270751953 + max: 31.012344360351562 + avg_min: -11.667031159561684 + avg_max: 10.85241473739067 + mean: -1.1740948434960983 + std: 3.0091560830798434 + shape: (1280, 4096) + 1: + min: -16.805160522460938 + max: 22.041872024536133 + avg_min: -5.688125228613951 + avg_max: 5.603913084911495 + mean: -0.49282410457060527 + std: 1.2749332466602463 + shape: (1280, 4096) + output: + min: -33.57273864746094 + max: 30.18963623046875 + avg_min: -13.689522092797787 + avg_max: 12.499767167380655 + mean: -1.6669189450446138 + std: 3.6505088399213936 + shape: (1280, 4096) +decoder.rnn_layers.0.cells.0.act_f: + inputs: + 0: + min: -30.98964500427246 + max: 28.624046325683594 + avg_min: -13.069127936845424 + avg_max: 11.013900484127936 + mean: -2.6897614965278116 + std: 4.286025676313951 + shape: (1280, 1024) + output: + min: 3.4783092882741465e-14 + max: 1.0 + avg_min: 1.753829452079096e-05 + avg_max: 0.9996802927737837 + mean: 0.2552954440455087 + std: 0.35829928323219956 + shape: (1280, 1024) +decoder.rnn_layers.0.cells.0.act_i: + inputs: + 0: + min: -29.928306579589844 + max: 27.947511672973633 + avg_min: -10.685935257525923 + avg_max: 9.380031768927397 + mean: -2.4518663697698173 + std: 3.0217416033646693 + shape: (1280, 1024) + output: + min: 1.005313792824293e-13 + max: 1.0 + avg_min: 0.00017471713297229053 + avg_max: 0.9996398801214232 + mean: 0.22650649123144959 + std: 0.3026139789025641 + shape: (1280, 1024) +decoder.rnn_layers.0.cells.0.act_o: + inputs: + 0: + min: -27.317293167114258 + max: 23.132068634033203 + avg_min: -11.532117180877865 + avg_max: 8.524241487095862 + mean: -1.5604653776026836 + std: 3.2415928306767676 + shape: (1280, 1024) + output: + min: 1.3685174955063717e-12 + max: 1.0 + avg_min: 9.716954026726062e-05 + avg_max: 0.9978680859455905 + mean: 0.3295869104145614 + std: 0.34560961994352346 + shape: (1280, 1024) +decoder.rnn_layers.0.cells.0.act_g: + inputs: + 0: + min: -33.57273864746094 + max: 30.18963623046875 + avg_min: -11.964716329735301 + avg_max: 11.46917405824984 + mean: 0.03441744941392035 + std: 3.2752884250043954 + shape: (1280, 1024) + output: + min: -1.0 + max: 1.0 + avg_min: -0.9999999861703825 + avg_max: 0.9999999514456561 + mean: 0.009212502904665593 + std: 0.8613827573443271 + shape: (1280, 1024) +decoder.rnn_layers.0.cells.0.eltwisemult_cell_forget: + inputs: + 0: + min: 3.4783092882741465e-14 + max: 1.0 + avg_min: 1.753829452079096e-05 + avg_max: 0.9996802927737837 + mean: 0.2552954440455087 + std: 0.35829928323219956 + shape: (1280, 1024) + 1: + min: -50.77605438232422 + max: 51.83877944946289 + avg_min: -11.694220676448898 + avg_max: 11.783769438574792 + mean: -0.013233005671694572 + std: 0.9262749365617192 + shape: (1280, 1024) + output: + min: -50.32355880737305 + max: 51.78708267211914 + avg_min: -11.160377957814195 + avg_max: 11.075782726721808 + mean: -0.010043481503626429 + std: 0.8084092966416248 + shape: (1280, 1024) +decoder.rnn_layers.0.cells.0.eltwisemult_cell_input: + inputs: + 0: + min: 1.005313792824293e-13 + max: 1.0 + avg_min: 0.00017471713297229053 + avg_max: 0.9996398801214232 + mean: 0.22650649123144959 + std: 0.3026139789025641 + shape: (1280, 1024) + 1: + min: -1.0 + max: 1.0 + avg_min: -0.9999999861703825 + avg_max: 0.9999999514456561 + mean: 0.009212502904665593 + std: 0.8613827573443271 + shape: (1280, 1024) + output: + min: -1.0 + max: 1.0 + avg_min: -0.9977656979239379 + avg_max: 0.9981463519374972 + mean: -0.0036360137033707263 + std: 0.3448696757960219 
+ shape: (1280, 1024) +decoder.rnn_layers.0.cells.0.eltwiseadd_cell: + inputs: + 0: + min: -50.32355880737305 + max: 51.78708267211914 + avg_min: -11.160377957814195 + avg_max: 11.075782726721808 + mean: -0.010043481503626429 + std: 0.8084092966416248 + shape: (1280, 1024) + 1: + min: -1.0 + max: 1.0 + avg_min: -0.9977656979239379 + avg_max: 0.9981463519374972 + mean: -0.0036360137033707263 + std: 0.3448696757960219 + shape: (1280, 1024) + output: + min: -51.32215881347656 + max: 52.70581817626953 + avg_min: -11.736164883080493 + avg_max: 11.741759218124862 + mean: -0.013679495197402003 + std: 0.9252576031128981 + shape: (1280, 1024) +decoder.rnn_layers.0.cells.0.act_h: + inputs: + 0: + min: -51.32215881347656 + max: 52.70581817626953 + avg_min: -11.736164883080493 + avg_max: 11.741759218124862 + mean: -0.013679495197402003 + std: 0.9252576031128981 + shape: (1280, 1024) + output: + min: -1.0 + max: 1.0 + avg_min: -0.9954183305247454 + avg_max: 0.9928096467189585 + mean: -0.005522266373139992 + std: 0.3470194696164207 + shape: (1280, 1024) +decoder.rnn_layers.0.cells.0.eltwisemult_hidden: + inputs: + 0: + min: 1.3685174955063717e-12 + max: 1.0 + avg_min: 9.716954026726062e-05 + avg_max: 0.9978680859455905 + mean: 0.3295869104145614 + std: 0.34560961994352346 + shape: (1280, 1024) + 1: + min: -1.0 + max: 1.0 + avg_min: -0.9954183305247454 + avg_max: 0.9928096467189585 + mean: -0.005522266373139992 + std: 0.3470194696164207 + shape: (1280, 1024) + output: + min: -0.9999914765357971 + max: 0.9999985694885254 + avg_min: -0.7889921560716086 + avg_max: 0.7594633606377601 + mean: -0.0010860526439966734 + std: 0.13179261237432402 + shape: (1280, 1024) +decoder.rnn_layers.0.dropout: + output: + min: 1.7976931348623157e+308 + max: -1.7976931348623157e+308 + avg_min: 0 + avg_max: 0 + mean: 0 + std: 0 + shape: '' +decoder.rnn_layers.1.cells.0.fc_gate_x: + inputs: + 0: + min: -2.748584032058716 + max: 2.7187154293060303 + avg_min: -1.1142985966098458 + avg_max: 0.996185907755004 + mean: -0.00032680749604648707 + std: 0.14988080842341334 + shape: (1280, 2048) + output: + min: -22.17432403564453 + max: 24.88280487060547 + avg_min: -7.668186707710953 + avg_max: 7.5104381103194156 + mean: -0.8090552802835963 + std: 2.096109569488443 + shape: (1280, 4096) +decoder.rnn_layers.1.cells.0.fc_gate_h: + inputs: + 0: + min: -0.9998709559440613 + max: 0.9998664259910583 + avg_min: -0.6980919281250983 + avg_max: 0.7355864281399854 + mean: 0.00030723568868500604 + std: 0.13119853605881615 + shape: (1280, 1024) + output: + min: -10.816532135009766 + max: 10.679582595825195 + avg_min: -4.561453889193147 + avg_max: 4.048953114064886 + mean: -0.47011378199029547 + std: 0.9846700195969109 + shape: (1280, 4096) +decoder.rnn_layers.1.cells.0.eltwiseadd_gate: + inputs: + 0: + min: -22.17432403564453 + max: 24.88280487060547 + avg_min: -7.668186707710953 + avg_max: 7.5104381103194156 + mean: -0.8090552802835963 + std: 2.096109569488443 + shape: (1280, 4096) + 1: + min: -10.816532135009766 + max: 10.679582595825195 + avg_min: -4.561453889193147 + avg_max: 4.048953114064886 + mean: -0.47011378199029547 + std: 0.9846700195969109 + shape: (1280, 4096) + output: + min: -24.22309684753418 + max: 24.288864135742188 + avg_min: -9.790749129284656 + avg_max: 8.619773781969311 + mean: -1.2791690607754058 + std: 2.5821871111713826 + shape: (1280, 4096) +decoder.rnn_layers.1.cells.0.act_f: + inputs: + 0: + min: -24.22309684753418 + max: 23.923398971557617 + avg_min: -9.621848200680162 + avg_max: 6.821592486440458 + mean: 
-2.6714931534247466 + std: 2.5443121998813525 + shape: (1280, 1024) + output: + min: 3.020248634522105e-11 + max: 1.0 + avg_min: 0.0003451858472697981 + avg_max: 0.9935724002256816 + mean: 0.1752624285941042 + std: 0.25292423787376883 + shape: (1280, 1024) +decoder.rnn_layers.1.cells.0.act_i: + inputs: + 0: + min: -20.21559715270996 + max: 22.597978591918945 + avg_min: -7.931391174873612 + avg_max: 6.599652247616419 + mean: -1.3425940824358642 + std: 2.214069210841872 + shape: (1280, 1024) + output: + min: 1.6614134512593637e-09 + max: 1.0 + avg_min: 0.0012407131223308894 + avg_max: 0.994309761283102 + mean: 0.3069548572800805 + std: 0.29030061738546137 + shape: (1280, 1024) +decoder.rnn_layers.1.cells.0.act_o: + inputs: + 0: + min: -20.679615020751953 + max: 24.288864135742188 + avg_min: -8.320266337876905 + avg_max: 7.602936600299369 + mean: -1.092906504534604 + std: 2.4894375831577364 + shape: (1280, 1024) + output: + min: 1.0446175036094019e-09 + max: 1.0 + avg_min: 0.0006933753792029073 + avg_max: 0.9977307225211285 + mean: 0.3514365402500282 + std: 0.32005181986355885 + shape: (1280, 1024) +decoder.rnn_layers.1.cells.0.act_g: + inputs: + 0: + min: -21.044048309326172 + max: 21.751747131347656 + avg_min: -7.526448396886325 + avg_max: 7.799964112378239 + mean: -0.009682508857590904 + std: 2.3471144877832884 + shape: (1280, 1024) + output: + min: -1.0 + max: 1.0 + avg_min: -0.9999812895662334 + avg_max: 0.9999866421637916 + mean: -0.007454321779905927 + std: 0.8034238800947899 + shape: (1280, 1024) +decoder.rnn_layers.1.cells.0.eltwisemult_cell_forget: + inputs: + 0: + min: 3.020248634522105e-11 + max: 1.0 + avg_min: 0.0003451858472697981 + avg_max: 0.9935724002256816 + mean: 0.1752624285941042 + std: 0.25292423787376883 + shape: (1280, 1024) + 1: + min: -50.93753433227539 + max: 30.134031295776367 + avg_min: -2.980211339807237 + avg_max: 2.858599665124761 + mean: -0.008311334580423152 + std: 0.4794455076286656 + shape: (1280, 1024) + output: + min: -50.889991760253906 + max: 30.027212142944336 + avg_min: -2.4969031673468915 + avg_max: 2.3347478714216994 + mean: -0.0015736418809647336 + std: 0.25571878953430177 + shape: (1280, 1024) +decoder.rnn_layers.1.cells.0.eltwisemult_cell_input: + inputs: + 0: + min: 1.6614134512593637e-09 + max: 1.0 + avg_min: 0.0012407131223308894 + avg_max: 0.994309761283102 + mean: 0.3069548572800805 + std: 0.29030061738546137 + shape: (1280, 1024) + 1: + min: -1.0 + max: 1.0 + avg_min: -0.9999812895662334 + avg_max: 0.9999866421637916 + mean: -0.007454321779905927 + std: 0.8034238800947899 + shape: (1280, 1024) + output: + min: -1.0 + max: 1.0 + avg_min: -0.9808258289701476 + avg_max: 0.9891012119108362 + mean: -0.006563609112533941 + std: 0.3551230793477708 + shape: (1280, 1024) +decoder.rnn_layers.1.cells.0.eltwiseadd_cell: + inputs: + 0: + min: -50.889991760253906 + max: 30.027212142944336 + avg_min: -2.4969031673468915 + avg_max: 2.3347478714216994 + mean: -0.0015736418809647336 + std: 0.25571878953430177 + shape: (1280, 1024) + 1: + min: -1.0 + max: 1.0 + avg_min: -0.9808258289701476 + avg_max: 0.9891012119108362 + mean: -0.006563609112533941 + std: 0.3551230793477708 + shape: (1280, 1024) + output: + min: -51.780860900878906 + max: 30.70660400390625 + avg_min: -3.067318748858543 + avg_max: 2.8984877801343285 + mean: -0.008137250974908283 + std: 0.4813941324820594 + shape: (1280, 1024) +decoder.rnn_layers.1.cells.0.act_h: + inputs: + 0: + min: -51.780860900878906 + max: 30.70660400390625 + avg_min: -3.067318748858543 + avg_max: 2.8984877801343285 + 
mean: -0.008137250974908283 + std: 0.4813941324820594 + shape: (1280, 1024) + output: + min: -1.0 + max: 1.0 + avg_min: -0.9705970220016633 + avg_max: 0.9670456421509231 + mean: -0.0057490570247533766 + std: 0.345665687214933 + shape: (1280, 1024) +decoder.rnn_layers.1.cells.0.eltwisemult_hidden: + inputs: + 0: + min: 1.0446175036094019e-09 + max: 1.0 + avg_min: 0.0006933753792029073 + avg_max: 0.9977307225211285 + mean: 0.3514365402500282 + std: 0.32005181986355885 + shape: (1280, 1024) + 1: + min: -1.0 + max: 1.0 + avg_min: -0.9705970220016633 + avg_max: 0.9670456421509231 + mean: -0.0057490570247533766 + std: 0.345665687214933 + shape: (1280, 1024) + output: + min: -0.9998709559440613 + max: 0.9998664259910583 + avg_min: -0.7024310171101872 + avg_max: 0.743544528524527 + mean: 0.00032440792050908126 + std: 0.1314242965426167 + shape: (1280, 1024) +decoder.rnn_layers.1.dropout: + output: + min: 1.7976931348623157e+308 + max: -1.7976931348623157e+308 + avg_min: 0 + avg_max: 0 + mean: 0 + std: 0 + shape: '' +decoder.rnn_layers.2.cells.0.fc_gate_x: + inputs: + 0: + min: -2.748584032058716 + max: 2.7187154293060303 + avg_min: -1.1982355600662467 + avg_max: 1.1189356212870456 + mean: -0.00016460354038312733 + std: 0.1760178086847606 + shape: (1280, 2048) + output: + min: -24.67524528503418 + max: 28.180944442749023 + avg_min: -8.691330929284696 + avg_max: 8.328183968147563 + mean: -0.8204240057408129 + std: 2.453284969094989 + shape: (1280, 4096) +decoder.rnn_layers.2.cells.0.fc_gate_h: + inputs: + 0: + min: -0.9999999403953552 + max: 0.9999994039535522 + avg_min: -0.814193978671278 + avg_max: 0.8058140833056374 + mean: -0.0030742165989599258 + std: 0.16821435780985747 + shape: (1280, 1024) + output: + min: -18.816926956176758 + max: 14.035696983337402 + avg_min: -6.604132010829581 + avg_max: 5.89386603040307 + mean: -0.7366783961109576 + std: 1.5059158176081058 + shape: (1280, 4096) +decoder.rnn_layers.2.cells.0.eltwiseadd_gate: + inputs: + 0: + min: -24.67524528503418 + max: 28.180944442749023 + avg_min: -8.691330929284696 + avg_max: 8.328183968147563 + mean: -0.8204240057408129 + std: 2.453284969094989 + shape: (1280, 4096) + 1: + min: -18.816926956176758 + max: 14.035696983337402 + avg_min: -6.604132010829581 + avg_max: 5.89386603040307 + mean: -0.7366783961109576 + std: 1.5059158176081058 + shape: (1280, 4096) + output: + min: -29.283180236816406 + max: 26.588882446289062 + avg_min: -12.15073210014387 + avg_max: 10.381155111280712 + mean: -1.557102406895562 + std: 3.2988025400040257 + shape: (1280, 4096) +decoder.rnn_layers.2.cells.0.act_f: + inputs: + 0: + min: -29.283180236816406 + max: 23.37925148010254 + avg_min: -12.019422794459913 + avg_max: 7.526722565527717 + mean: -3.6089404096094384 + std: 3.0091061671078934 + shape: (1280, 1024) + output: + min: 1.9163568843773293e-13 + max: 1.0 + avg_min: 0.00010279474620529197 + avg_max: 0.9962542612231184 + mean: 0.13819005828811198 + std: 0.24204215179724617 + shape: (1280, 1024) +decoder.rnn_layers.2.cells.0.act_i: + inputs: + 0: + min: -24.22443962097168 + max: 24.299768447875977 + avg_min: -9.08453069510085 + avg_max: 7.826490100581997 + mean: -0.91120249788078 + std: 2.6741725568824375 + shape: (1280, 1024) + output: + min: 3.0161959735375277e-11 + max: 1.0 + avg_min: 0.0005223408363930552 + avg_max: 0.9983030597815357 + mean: 0.38171999553281244 + std: 0.3367532650457234 + shape: (1280, 1024) +decoder.rnn_layers.2.cells.0.act_o: + inputs: + 0: + min: -25.566701889038086 + max: 26.588882446289062 + avg_min: -10.116214546192888 + 
avg_max: 8.086509210607966 + mean: -1.6018465338798051 + std: 2.822778696064494 + shape: (1280, 1024) + output: + min: 7.879931950005581e-12 + max: 1.0 + avg_min: 0.00020159611936281438 + avg_max: 0.9988285692890035 + mean: 0.30748807730969463 + std: 0.3263006812590941 + shape: (1280, 1024) +decoder.rnn_layers.2.cells.0.act_g: + inputs: + 0: + min: -26.263608932495117 + max: 23.93798828125 + avg_min: -10.2838829126251 + avg_max: 10.15937283199823 + mean: -0.10642016830597752 + std: 3.5525409631159355 + shape: (1280, 1024) + output: + min: -1.0 + max: 1.0 + avg_min: -0.9999990369161875 + avg_max: 0.999999153747988 + mean: -0.0269237342450161 + std: 0.879016886211502 + shape: (1280, 1024) +decoder.rnn_layers.2.cells.0.eltwisemult_cell_forget: + inputs: + 0: + min: 1.9163568843773293e-13 + max: 1.0 + avg_min: 0.00010279474620529197 + avg_max: 0.9962542612231184 + mean: 0.13819005828811198 + std: 0.24204215179724617 + shape: (1280, 1024) + 1: + min: -50.77392578125 + max: 57.88676452636719 + avg_min: -4.026170471076211 + avg_max: 3.778732706154328 + mean: -0.010375739289892385 + std: 0.6429531313365376 + shape: (1280, 1024) + output: + min: -49.94257354736328 + max: 57.657936096191406 + avg_min: -3.367933605661549 + avg_max: 3.1268611868780627 + mean: -0.00047371524729856 + std: 0.358745687372424 + shape: (1280, 1024) +decoder.rnn_layers.2.cells.0.eltwisemult_cell_input: + inputs: + 0: + min: 3.0161959735375277e-11 + max: 1.0 + avg_min: 0.0005223408363930552 + avg_max: 0.9983030597815357 + mean: 0.38171999553281244 + std: 0.3367532650457234 + shape: (1280, 1024) + 1: + min: -1.0 + max: 1.0 + avg_min: -0.9999990369161875 + avg_max: 0.999999153747988 + mean: -0.0269237342450161 + std: 0.879016886211502 + shape: (1280, 1024) + output: + min: -1.0 + max: 1.0 + avg_min: -0.9962779700086355 + avg_max: 0.9960211656066814 + mean: -0.01004450515169632 + std: 0.4632697182775279 + shape: (1280, 1024) +decoder.rnn_layers.2.cells.0.eltwiseadd_cell: + inputs: + 0: + min: -49.94257354736328 + max: 57.657936096191406 + avg_min: -3.367933605661549 + avg_max: 3.1268611868780627 + mean: -0.00047371524729856 + std: 0.358745687372424 + shape: (1280, 1024) + 1: + min: -1.0 + max: 1.0 + avg_min: -0.9962779700086355 + avg_max: 0.9960211656066814 + mean: -0.01004450515169632 + std: 0.4632697182775279 + shape: (1280, 1024) + output: + min: -50.93840026855469 + max: 58.63222122192383 + avg_min: -4.101547773835359 + avg_max: 3.8555672365293057 + mean: -0.010518220426035512 + std: 0.6487821329571182 + shape: (1280, 1024) +decoder.rnn_layers.2.cells.0.act_h: + inputs: + 0: + min: -50.93840026855469 + max: 58.63222122192383 + avg_min: -4.101547773835359 + avg_max: 3.8555672365293057 + mean: -0.010518220426035512 + std: 0.6487821329571182 + shape: (1280, 1024) + output: + min: -1.0 + max: 1.0 + avg_min: -0.985261068585214 + avg_max: 0.9829632449016149 + mean: -0.008678492148203425 + std: 0.4269098495511477 + shape: (1280, 1024) +decoder.rnn_layers.2.cells.0.eltwisemult_hidden: + inputs: + 0: + min: 7.879931950005581e-12 + max: 1.0 + avg_min: 0.00020159611936281438 + avg_max: 0.9988285692890035 + mean: 0.30748807730969463 + std: 0.3263006812590941 + shape: (1280, 1024) + 1: + min: -1.0 + max: 1.0 + avg_min: -0.985261068585214 + avg_max: 0.9829632449016149 + mean: -0.008678492148203425 + std: 0.4269098495511477 + shape: (1280, 1024) + output: + min: -0.9999999403953552 + max: 0.9999994039535522 + avg_min: -0.8279972857303842 + avg_max: 0.8186234360665419 + mean: -0.0031441434696200784 + std: 0.1694763855118817 + shape: (1280, 
1024) +decoder.rnn_layers.2.dropout: + output: + min: 1.7976931348623157e+308 + max: -1.7976931348623157e+308 + avg_min: 0 + avg_max: 0 + mean: 0 + std: 0 + shape: '' +decoder.classifier.classifier: + inputs: + 0: + min: -2.9619157314300537 + max: 2.9662914276123047 + avg_min: -1.3161585628316654 + avg_max: 1.2828886617770356 + mean: -0.003905788197230935 + std: 0.2593071699498008 + shape: (1280, 1, 1024) + output: + min: -19.020658493041992 + max: 20.08169937133789 + avg_min: -9.04961023920039 + avg_max: 8.032180467616302 + mean: -3.873352058817827 + std: 1.696852165054472 + shape: (1280, 1, 32317) +decoder.dropout: + inputs: + 0: + min: -1.9940483570098877 + max: 1.9767669439315796 + avg_min: -0.9086800352129599 + avg_max: 0.9094174789020626 + mean: 0.002039626916454153 + std: 0.22761719307769354 + shape: (1280, 1, 1024) + output: + min: -1.9940483570098877 + max: 1.9767669439315796 + avg_min: -0.9086800352129599 + avg_max: 0.9094174789020626 + mean: 0.002039626916454153 + std: 0.22761719307769354 + shape: (1280, 1, 1024) +decoder.eltwiseadd_residuals.0: + inputs: + 0: + min: -0.9998709559440613 + max: 0.9998664259910583 + avg_min: -0.7024310171101872 + avg_max: 0.743544528524527 + mean: 0.00032440792050908126 + std: 0.1314242965426167 + shape: (1280, 1, 1024) + 1: + min: -0.9999914765357971 + max: 0.9999985694885254 + avg_min: -0.7889921560716086 + avg_max: 0.7594633606377601 + mean: -0.0010860526439966734 + std: 0.13179261237432402 + shape: (1280, 1, 1024) + output: + min: -1.9940483570098877 + max: 1.9767669439315796 + avg_min: -0.9524248770448603 + avg_max: 0.9735018678260661 + mean: -0.0007616447304464125 + std: 0.1854876875589218 + shape: (1280, 1, 1024) +decoder.eltwiseadd_residuals.1: + inputs: + 0: + min: -0.9999999403953552 + max: 0.9999994039535522 + avg_min: -0.8279972857303842 + avg_max: 0.8186234360665419 + mean: -0.0031441434696200784 + std: 0.1694763855118817 + shape: (1280, 1, 1024) + 1: + min: -1.9940483570098877 + max: 1.9767669439315796 + avg_min: -0.9524248770448603 + avg_max: 0.9735018678260661 + mean: -0.0007616447304464125 + std: 0.1854876875589218 + shape: (1280, 1, 1024) + output: + min: -2.9619157314300537 + max: 2.9662914276123047 + avg_min: -1.3161585628316654 + avg_max: 1.2828886617770356 + mean: -0.003905788197230935 + std: 0.2593071699498008 + shape: (1280, 1, 1024) +decoder.attention_concats.0: + inputs: + 0: + min: -1.0 + max: 1.0 + avg_min: -0.9846230725224109 + avg_max: 0.9952872082423628 + mean: 0.007966578123805529 + std: 0.32187307320175484 + shape: (1280, 1, 1024) + 1: + min: -2.748584032058716 + max: 2.7187154293060303 + avg_min: -1.0244215909684593 + avg_max: 0.9287193868937127 + mean: 0.0004324376445557447 + std: 0.16600608798290178 + shape: (1280, 1, 1024) + output: + min: -2.748584032058716 + max: 2.7187154293060303 + avg_min: -1.200279578834437 + avg_max: 1.1144092399752523 + mean: 0.004199507878490362 + std: 0.2561140031003674 + shape: (1280, 1, 2048) +decoder.attention_concats.1: + inputs: + 0: + min: -0.9999914765357971 + max: 0.9999985694885254 + avg_min: -0.7889921560716086 + avg_max: 0.7594633606377601 + mean: -0.0010860526439966734 + std: 0.13179261237432402 + shape: (1280, 1, 1024) + 1: + min: -2.748584032058716 + max: 2.7187154293060303 + avg_min: -1.0244215909684593 + avg_max: 0.9287193868937127 + mean: 0.0004324376445557447 + std: 0.16600608798290178 + shape: (1280, 1, 1024) + output: + min: -2.748584032058716 + max: 2.7187154293060303 + avg_min: -1.1142985966098458 + avg_max: 0.996185907755004 + mean: -0.00032680749604648707 + 
std: 0.14988080842341334 + shape: (1280, 1, 2048) +decoder.attention_concats.2: + inputs: + 0: + min: -1.9940483570098877 + max: 1.9767669439315796 + avg_min: -0.9524248770448603 + avg_max: 0.9735018678260661 + mean: -0.0007616447304464125 + std: 0.1854876875589218 + shape: (1280, 1, 1024) + 1: + min: -2.748584032058716 + max: 2.7187154293060303 + avg_min: -1.0244215909684593 + avg_max: 0.9287193868937127 + mean: 0.0004324376445557447 + std: 0.16600608798290178 + shape: (1280, 1, 1024) + output: + min: -2.748584032058716 + max: 2.7187154293060303 + avg_min: -1.1982355600662467 + avg_max: 1.1189356212870456 + mean: -0.00016460354038312733 + std: 0.1760178086847606 + shape: (1280, 1, 2048) diff --git a/examples/GNMT/quantize_gnmt.ipynb b/examples/GNMT/quantize_gnmt.ipynb new file mode 100644 index 0000000..a2904d8 --- /dev/null +++ b/examples/GNMT/quantize_gnmt.ipynb @@ -0,0 +1,1069 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Quantizing Neural Machine Translation Models" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We continue our quest to quantize every Neural Network! \n", + "In this chapter: __Google's Neural Machine Translation model__. \n", + "A brief summary: using stacked LSTMs and an attention mechanism, this model encodes a sentence into a list of vectors, then decodes it into target-language tokens until an end token is reached. \n", + "For more details, refer to <a id=\"ref-1\" href=\"#cite-wu2016google\">Google's paper</a>." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Table of Contents\n", + "* [Quantizing Neural Machine Translation Models](#Quantizing-Neural-Machine-Translation-Models)\n", + "\t* [Getting the resources](#Getting-the-resources)\n", + "\t* [Preparing The Model For Quantization](#Preparing-The-Model-For-Quantization)\n", + "\t* [Loading the model](#Loading-the-model)\n", + "\t* [Evaluation of the model](#Evaluation-of-the-model)\n", + "\t* [Quantizing the model](#Quantizing-the-model)\n", + "\t\t* [Collecting the statistics](#Collecting-the-statistics)\n", + "\t\t* [Defining the Quantizer](#Defining-the-Quantizer)\n", + "\t\t* [Quantizing the model](#Quantizing-the-model)\n", + "\t\t* [Evaluating the quantized model](#Evaluating-the-quantized-model)\n", + "\t\t* [Finding the right quantization](#Finding-the-right-quantization)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Getting the resources" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this project, we modified the [`mlperf/training/rnn_translator`](https://github.com/mlperf/training/tree/master/rnn_translator) project to enable quantization of the GNMT model. \n", + "The instructions to download and set up the required environment for this task are in `README.md` (located in the current directory). \n", + "Download the pretrained model using the command:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# Uncomment the line below to download the pretrained model:\n", + "#! wget https://zenodo.org/record/2581623/files/model_best.pth" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "At this point, you should have everything ready to start quantizing!"
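Before moving on, it can be worth a quick sanity check that the checkpoint deserializes. The snippet below is a hypothetical addition, not part of the original notebook; the keys it prints are the ones the loading code further down relies on:

```python
import torch

# Hypothetical sanity check: confirm the downloaded checkpoint loads on CPU.
checkpoint = torch.load('model_best.pth', map_location='cpu')
print(list(checkpoint.keys()))  # expect entries such as 'state_dict', 'tokenizer', 'config'
```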
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Preparing The Model For Quantization\n", + "\n", + "In order to be able to fully quantize the model, we modify it according to the instructions laid out in the Distiller [documentation](https://nervanasystems.github.io/distiller/prepare_model_quant.html). This mostly amounts to making sure every quantize-able operation is invoked via a dedicated PyTorch Module. You can compare the code under `seq2seq/models` in this example with the [original](https://github.com/mlperf/training/tree/master/rnn_translator/pytorch/seq2seq/models).\n", + "\n", + "For example, in `seq2seq/models/attention.py`, we added the following code to the `__init__` function of the `BahdanauAttention` class:\n", + "\n", + "```python\n", + "# Adding submodules for basic ops to allow quantization:\n", + "self.eltwiseadd_qk = EltwiseAdd()\n", + "self.eltwiseadd_norm_bias = EltwiseAdd()\n", + "self.eltwisemul_norm_scaler = EltwiseMult()\n", + "self.matmul_score = Matmul()\n", + "self.context_matmul = BatchMatmul()\n", + "```\n", + "\n", + "We're creating modules for operations that were invoked directly in the `forward` function in the original code. This enables Distiller to detect these operations and replace them with quantized counterparts.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Loading the model" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import distiller\n", + "from distiller.modules import DistillerLSTM\n", + "from distiller.quantization import PostTrainLinearQuantizer\n", + "from ast import literal_eval\n", + "from itertools import zip_longest\n", + "from copy import deepcopy\n", + "\n", + "from seq2seq import models\n", + "from seq2seq.inference.inference import Translator\n", + "from seq2seq.utils import AverageMeter\n", + "import subprocess\n", + "import os\n", + "import seq2seq.data.config as config\n", + "from seq2seq.data.dataset import ParallelDataset\n", + "import logging\n", + "# Import utilities from the example:\n", + "from translate import grouper, write_output, checkpoint_from_distributed, unwrap_distributed\n", + "from itertools import takewhile\n", + "from tqdm import tqdm\n", + "logging.disable(logging.INFO) # Disables mlperf output" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# Define some constants\n", + "batch_first=True\n", + "batch_size=128\n", + "beam_size=10\n", + "cov_penalty_factor=0.1\n", + "dataset_dir='./data'\n", + "input='./data/newstest2014.tok.clean.bpe.32000.en'\n", + "len_norm_const=5.0\n", + "len_norm_factor=0.6\n", + "max_seq_len=80\n", + "model='model_best.pth'\n", + "output='output_file'\n", + "print_freq=1\n", + "reference='./data/newstest2014.de'" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "# Loading the model\n", + "checkpoint = torch.load('./model_best.pth', map_location={'cuda:0': 'cpu'})\n", + "vocab_size = checkpoint['tokenizer'].vocab_size\n", + "model_config = dict(vocab_size=vocab_size, math=checkpoint['config'].math,\n", + " **literal_eval(checkpoint['config'].model_config))\n", + "model_config['batch_first'] = batch_first\n", + "model = models.GNMT(**model_config)" + ] + }, + { + "cell_type": "code", + "execution_count": 5,
"metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "GNMT(\n", + " (encoder): ResidualRecurrentEncoder(\n", + " (rnn_layers): ModuleList(\n", + " (0): LSTM(1024, 1024, batch_first=True, bidirectional=True)\n", + " (1): LSTM(2048, 1024, batch_first=True)\n", + " (2): LSTM(1024, 1024, batch_first=True)\n", + " (3): LSTM(1024, 1024, batch_first=True)\n", + " )\n", + " (dropout): Dropout(p=0.2)\n", + " (embedder): Embedding(32317, 1024, padding_idx=0)\n", + " (eltwiseadd_residuals): ModuleList(\n", + " (0): EltwiseAdd()\n", + " (1): EltwiseAdd()\n", + " )\n", + " )\n", + " (decoder): ResidualRecurrentDecoder(\n", + " (att_rnn): RecurrentAttention(\n", + " (rnn): LSTM(1024, 1024, batch_first=True)\n", + " (attn): BahdanauAttention(\n", + " (linear_q): Linear(in_features=1024, out_features=1024, bias=False)\n", + " (linear_k): Linear(in_features=1024, out_features=1024, bias=False)\n", + " (dropout): Dropout(p=0)\n", + " (eltwiseadd_qk): EltwiseAdd()\n", + " (eltwiseadd_norm_bias): EltwiseAdd()\n", + " (eltwisemul_norm_scaler): EltwiseMult()\n", + " (tanh): Tanh()\n", + " (matmul_score): Matmul()\n", + " (softmax_att): Softmax()\n", + " (context_matmul): BatchMatmul()\n", + " )\n", + " (dropout): Dropout(p=0)\n", + " )\n", + " (rnn_layers): ModuleList(\n", + " (0): LSTM(2048, 1024, batch_first=True)\n", + " (1): LSTM(2048, 1024, batch_first=True)\n", + " (2): LSTM(2048, 1024, batch_first=True)\n", + " )\n", + " (embedder): Embedding(32317, 1024, padding_idx=0)\n", + " (classifier): Classifier(\n", + " (classifier): Linear(in_features=1024, out_features=32317, bias=True)\n", + " )\n", + " (dropout): Dropout(p=0.2)\n", + " (eltwiseadd_residuals): ModuleList(\n", + " (0): EltwiseAdd()\n", + " (1): EltwiseAdd()\n", + " )\n", + " (attention_concats): ModuleList(\n", + " (0): Concat()\n", + " (1): Concat()\n", + " (2): Concat()\n", + " )\n", + " )\n", + ")" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "state_dict = checkpoint['state_dict']\n", + "if checkpoint_from_distributed(state_dict):\n", + " state_dict = unwrap_distributed(state_dict)\n", + "\n", + "model.load_state_dict(state_dict)\n", + "torch.cuda.set_device(0)\n", + "model = model.cuda()\n", + "model.eval()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Evaulation of the model" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "tokenizer = checkpoint['tokenizer']\n", + "\n", + "\n", + "test_data = ParallelDataset(\n", + " src_fname=os.path.join(dataset_dir, config.SRC_TEST_FNAME),\n", + " tgt_fname=os.path.join(dataset_dir, config.TGT_TEST_FNAME),\n", + " tokenizer=tokenizer,\n", + " min_len=0,\n", + " max_len=150,\n", + " sort=False)\n", + "\n", + "def get_loader():\n", + " return test_data.get_loader(batch_size=batch_size,\n", + " batch_first=True,\n", + " shuffle=False,\n", + " num_workers=0,\n", + " drop_last=False,\n", + " distributed=False)\n", + "def get_translator(model):\n", + " return Translator(model,\n", + " tokenizer,\n", + " beam_size=beam_size,\n", + " max_seq_len=max_seq_len,\n", + " len_norm_factor=len_norm_factor,\n", + " len_norm_const=len_norm_const,\n", + " cov_penalty_factor=cov_penalty_factor,\n", + " cuda=True)\n", + "torch.cuda.empty_cache()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "def evaluate(model, test_path):\n", + " test_file = open(test_path, 'w', encoding='UTF-8')\n", + 
" model.eval()\n", + " translator = get_translator(model)\n", + " stats = {}\n", + " iterations = enumerate(tqdm(get_loader()))\n", + " for i, (src, tgt, indices) in iterations:\n", + " src, src_length = src\n", + " if translator.batch_first:\n", + " batch_size = src.size(0)\n", + " else:\n", + " batch_size = src.size(1)\n", + " bos = [translator.insert_target_start] * (batch_size * beam_size)\n", + " bos = torch.LongTensor(bos)\n", + " if translator.batch_first:\n", + " bos = bos.view(-1, 1)\n", + " else:\n", + " bos = bos.view(1, -1)\n", + " src_length = torch.LongTensor(src_length)\n", + " stats['total_enc_len'] = int(src_length.sum())\n", + " src = src.cuda()\n", + " src_length = src_length.cuda()\n", + " bos = bos.cuda()\n", + " with torch.no_grad():\n", + " context = translator.model.encode(src, src_length)\n", + " context = [context, src_length, None]\n", + " if beam_size == 1:\n", + " generator = translator.generator.greedy_search\n", + " else:\n", + " generator = translator.generator.beam_search\n", + " preds, lengths, counter = generator(batch_size, bos, context)\n", + " stats['total_dec_len'] = lengths.sum().item()\n", + " stats['iters'] = counter\n", + " preds = preds.cpu()\n", + " lengths = lengths.cpu()\n", + " output = []\n", + " for idx, pred in enumerate(preds):\n", + " end = lengths[idx] - 1\n", + " pred = pred[1: end]\n", + " pred = pred.tolist()\n", + " out = translator.tok.detokenize(pred)\n", + " output.append(out)\n", + " output = [output[indices.index(i)] for i in range(len(output))]\n", + " for line in output:\n", + " test_file.write(line)\n", + " test_file.write('\\n')\n", + " total_tokens = stats['total_dec_len'] + stats['total_enc_len']\n", + " test_file.close()\n", + " # run moses detokenizer\n", + " detok_path = os.path.join(dataset_dir, config.DETOKENIZER)\n", + " detok_test_path = test_path + '.detok'\n", + "\n", + " with open(detok_test_path, 'w') as detok_test_file, \\\n", + " open(test_path, 'r') as test_file:\n", + " subprocess.run(['perl', detok_path], stdin=test_file,\n", + " stdout=detok_test_file, stderr=subprocess.DEVNULL)\n", + " # run sacrebleu\n", + " reference_path = os.path.join(dataset_dir,\n", + " config.TGT_TEST_TARGET_FNAME)\n", + " sacrebleu = subprocess.run(['sacrebleu --input {} {} --score-only -lc --tokenize intl'.\n", + " format(detok_test_path, reference_path)],\n", + " stdout=subprocess.PIPE, shell=True)\n", + " bleu = float(sacrebleu.stdout.strip())\n", + " print('BLEU on test dataset: {}'.format(bleu))" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 24/24 [00:54<00:00, 2.03s/it]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "BLEU on test dataset: 22.16\n" + ] + } + ], + "source": [ + "evaluate(model, output)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Quantizing the model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As we already noted, we modified the model from `mlperf` to a modular implementation so we can quantize each and every operation in the graph. \n", + "However, the default `nn.LSTM` was implemented in C++/CUDA, and we don't have usual access to it's operations hence we can't quantize it properly. This is why we'll convert the `nn.LSTM` to a `DistillerLSTM`, which is an entirely modular implementation of the LSTM - identical in functionality to the original `nn.LSTM`. 
\n", + "This is done by simply calling `DistillerLSTM.from_pytorch_impl` for a single `nn.LSTM` and \n", + "`convert_model_to_distiller_lstm` for an entire model containing multiple different LSTMs.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 24/24 [02:21<00:00, 5.05s/it]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "BLEU on test dataset: 22.16\n" + ] + } + ], + "source": [ + "from distiller.modules import convert_model_to_distiller_lstm\n", + "model = convert_model_to_distiller_lstm(model)\n", + "evaluate(model, output)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Collecting the statistics" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The quantizer uses statistics to define the range of the quantization. We collect these statistics using a `QuantCalibrationStatsCollector` instance like this:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 24/24 [58:48<00:00, 130.39s/it]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "BLEU on test dataset: 22.16\n" + ] + } + ], + "source": [ + "import os\n", + "from distiller.data_loggers import QuantCalibrationStatsCollector, collector_context\n", + "\n", + "stats_file = './model_stats.yaml'\n", + "\n", + "if not os.path.isfile(stats_file): # Collect stats.\n", + " model_copy = deepcopy(model)\n", + " distiller.utils.assign_layer_fq_names(model_copy)\n", + " collector = QuantCalibrationStatsCollector(model_copy)\n", + " with collector_context(collector):\n", + " val_loss = evaluate(model_copy, output + '.temp')\n", + " collector.save(stats_file)\n", + " del model_copy\n", + " torch.cuda.empty_cache()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Defining the Quantizer" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "A distiller `Quantizer` object replaces each submodule in a model with its quantized counterpart, using a \n", + "`replacement_factory`. \n", + "`Quantizer.replacement_factory` is a dictionary which maps from a module type (e.g. `nn.Linear` and `nn.Conv`) to a function. This function takes a module and quantization configuration, and returns a quantized version of the same module." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Replacing 'Matmul' modules using 'replace_non_param_layer' function\n", + "Replacing 'EltwiseMult' modules using 'replace_non_param_layer' function\n", + "Replacing 'EltwiseAdd' modules using 'replace_non_param_layer' function\n", + "Replacing 'BatchMatmul' modules using 'replace_non_param_layer' function\n", + "Replacing 'Linear' modules using 'replace_param_layer' function\n", + "Replacing 'Embedding' modules using 'replace_embedding' function\n", + "Replacing 'Concat' modules using 'replace_non_param_layer' function\n", + "Replacing 'Conv2d' modules using 'replace_param_layer' function\n" + ] + } + ], + "source": [ + "# Basic quantizer definition\n", + "quantizer = PostTrainLinearQuantizer(deepcopy(model), \n", + " mode=\"SYMMETRIC\", # As suggested in GNMT's paper\n", + " model_activation_stats=stats_file)\n", + "# We take a look at the replacement factory:\n", + "for t, rf in quantizer.replacement_factory.items():\n", + " print(\"Replacing '{}' modules using '{}' function\".format(t.__name__, rf.__name__))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Quantizing the model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This is done by simply calling `quantizer.prepare_model()`." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING: Logging before flag parsing goes to stderr.\n", + "W0807 15:01:57.225715 140098610173696 range_linear.py:1063] /home/cvds_lab/lev_z/distiller/distiller/quantization/range_linear.py:1063: UserWarning: Model contains a bidirectional DistillerLSTM module. Automatic BN folding and statistics optimization based on tracing is not yet supported for models containing such modules.\n", + "Will perform specific optimization for the DistillerLSTM modules, but any other potential opportunities for optimization in the model will be ignored.\n", + " 'opportunities for optimization in the model will be ignored.', UserWarning)\n", + "\n", + "W0807 15:01:57.301426 140098610173696 quantizer.py:270] /home/cvds_lab/lev_z/distiller/distiller/quantization/quantizer.py:270: UserWarning: Module 'decoder.embedder' references to same module as 'encoder.embedder'.
Replacing with reference the same wrapper.\n", + " UserWarning)\n", + "\n" + ] + }, + { + "data": { + "text/plain": [ + "GNMT(\n", + " (encoder): ResidualRecurrentEncoder(\n", + " (rnn_layers): ModuleList(\n", + " (0): DistillerLSTM(1024, 1024, num_layers=1, dropout=0.00, bidirectional=True)\n", + " (1): DistillerLSTM(2048, 1024, num_layers=1, dropout=0.00, bidirectional=False)\n", + " (2): DistillerLSTM(1024, 1024, num_layers=1, dropout=0.00, bidirectional=False)\n", + " (3): DistillerLSTM(1024, 1024, num_layers=1, dropout=0.00, bidirectional=False)\n", + " )\n", + " (dropout): Dropout(p=0.2)\n", + " (embedder): RangeLinearEmbeddingWrapper(\n", + " (wrapped_module): Embedding(32317, 1024, padding_idx=0)\n", + " )\n", + " (eltwiseadd_residuals): ModuleList(\n", + " (0): RangeLinearQuantEltwiseAddWrapper(\n", + " mode=SYMMETRIC, num_bits_acts=8, num_bits_accum=32, clip_acts=NONE, scale_approx_mult_bits=None\n", + " preset_activation_stats=True\n", + " in_0_scale=127.00788879394531, in_0_zero_point=0.0\n", + " in_1_scale=127.00006103515625, in_1_zero_point=0.0\n", + " out_scale=63.530452728271484, out_zero_point=0.0\n", + " (wrapped_module): EltwiseAdd()\n", + " )\n", + " (1): RangeLinearQuantEltwiseAddWrapper(\n", + " mode=SYMMETRIC, num_bits_acts=8, num_bits_accum=32, clip_acts=NONE, scale_approx_mult_bits=None\n", + " preset_activation_stats=True\n", + " in_0_scale=127.00004577636719, in_0_zero_point=0.0\n", + " in_1_scale=63.530452728271484, in_1_zero_point=0.0\n", + " out_scale=42.43059158325195, out_zero_point=0.0\n", + " (wrapped_module): EltwiseAdd()\n", + " )\n", + " )\n", + " )\n", + " (decoder): ResidualRecurrentDecoder(\n", + " (att_rnn): RecurrentAttention(\n", + " (rnn): DistillerLSTM(1024, 1024, num_layers=1, dropout=0.00, bidirectional=False)\n", + " (attn): BahdanauAttention(\n", + " (linear_q): RangeLinearQuantParamLayerWrapper(\n", + " mode=SYMMETRIC, num_bits_acts=8, num_bits_params=8, num_bits_accum=32, clip_acts=NONE, per_channel_wts=False, scale_approx_mult_bits=None\n", + " preset_activation_stats=True\n", + " w_scale=53.0649, w_zero_point=0.0000\n", + " in_scale=127.0000, in_zero_point=0.0000\n", + " out_scale=4.2200, out_zero_point=0.0000\n", + " (wrapped_module): Linear(in_features=1024, out_features=1024, bias=False)\n", + " )\n", + " (linear_k): RangeLinearQuantParamLayerWrapper(\n", + " mode=SYMMETRIC, num_bits_acts=8, num_bits_params=8, num_bits_accum=32, clip_acts=NONE, per_channel_wts=False, scale_approx_mult_bits=None\n", + " preset_activation_stats=True\n", + " w_scale=40.9795, w_zero_point=0.0000\n", + " in_scale=42.4306, in_zero_point=0.0000\n", + " out_scale=4.5466, out_zero_point=0.0000\n", + " (wrapped_module): Linear(in_features=1024, out_features=1024, bias=False)\n", + " )\n", + " (dropout): Dropout(p=0)\n", + " (eltwiseadd_qk): RangeLinearQuantEltwiseAddWrapper(\n", + " mode=SYMMETRIC, num_bits_acts=8, num_bits_accum=32, clip_acts=NONE, scale_approx_mult_bits=None\n", + " preset_activation_stats=True\n", + " in_0_scale=4.219996452331543, in_0_zero_point=0.0\n", + " in_1_scale=4.546571731567383, in_1_zero_point=0.0\n", + " out_scale=2.974982261657715, out_zero_point=0.0\n", + " (wrapped_module): EltwiseAdd()\n", + " )\n", + " (eltwiseadd_norm_bias): RangeLinearQuantEltwiseAddWrapper(\n", + " mode=SYMMETRIC, num_bits_acts=8, num_bits_accum=32, clip_acts=NONE, scale_approx_mult_bits=None\n", + " preset_activation_stats=True\n", + " in_0_scale=2.974982261657715, in_0_zero_point=0.0\n", + " in_1_scale=109.55178833007812, in_1_zero_point=0.0\n", + " 
out_scale=2.961179256439209, out_zero_point=0.0\n", + " (wrapped_module): EltwiseAdd()\n", + " )\n", + " (eltwisemul_norm_scaler): RangeLinearQuantEltwiseMultWrapper(\n", + " mode=SYMMETRIC, num_bits_acts=8, num_bits_accum=32, clip_acts=NONE, scale_approx_mult_bits=None\n", + " preset_activation_stats=True\n", + " in_0_scale=771.8616943359375, in_0_zero_point=0.0\n", + " in_1_scale=100.56507873535156, in_1_zero_point=0.0\n", + " out_scale=611.1994018554688, out_zero_point=0.0\n", + " (wrapped_module): EltwiseMult()\n", + " )\n", + " (tanh): Tanh()\n", + " (matmul_score): RangeLinearQuantMatmulWrapper(\n", + " mode=SYMMETRIC, num_bits_acts=8, num_bits_accum=32, clip_acts=NONE, scale_approx_mult_bits=None\n", + " preset_activation_stats=True\n", + " in_0_scale=127.0000, in_0_zero_point=0.0000\n", + " in_1_scale=611.1994, in_1_zero_point=0.0000\n", + " out_scale=6.3274, out_zero_point=0.0000\n", + " (wrapped_module): Matmul()\n", + " )\n", + " (softmax_att): Softmax()\n", + " (context_matmul): RangeLinearQuantMatmulWrapper(\n", + " mode=SYMMETRIC, num_bits_acts=8, num_bits_accum=32, clip_acts=NONE, scale_approx_mult_bits=None\n", + " preset_activation_stats=True\n", + " in_0_scale=128.1420, in_0_zero_point=0.0000\n", + " in_1_scale=42.4306, in_1_zero_point=0.0000\n", + " out_scale=46.2056, out_zero_point=0.0000\n", + " (wrapped_module): BatchMatmul()\n", + " )\n", + " )\n", + " (dropout): Dropout(p=0)\n", + " )\n", + " (rnn_layers): ModuleList(\n", + " (0): DistillerLSTM(2048, 1024, num_layers=1, dropout=0.00, bidirectional=False)\n", + " (1): DistillerLSTM(2048, 1024, num_layers=1, dropout=0.00, bidirectional=False)\n", + " (2): DistillerLSTM(2048, 1024, num_layers=1, dropout=0.00, bidirectional=False)\n", + " )\n", + " (embedder): RangeLinearEmbeddingWrapper(\n", + " (wrapped_module): Embedding(32317, 1024, padding_idx=0)\n", + " )\n", + " (classifier): Classifier(\n", + " (classifier): RangeLinearQuantParamLayerWrapper(\n", + " mode=SYMMETRIC, num_bits_acts=8, num_bits_params=8, num_bits_accum=32, clip_acts=NONE, per_channel_wts=False, scale_approx_mult_bits=None\n", + " preset_activation_stats=True\n", + " w_scale=40.8311, w_zero_point=0.0000\n", + " in_scale=42.8144, in_zero_point=0.0000\n", + " out_scale=6.3242, out_zero_point=0.0000\n", + " (wrapped_module): Linear(in_features=1024, out_features=32317, bias=True)\n", + " )\n", + " )\n", + " (dropout): Dropout(p=0.2)\n", + " (eltwiseadd_residuals): ModuleList(\n", + " (0): RangeLinearQuantEltwiseAddWrapper(\n", + " mode=SYMMETRIC, num_bits_acts=8, num_bits_accum=32, clip_acts=NONE, scale_approx_mult_bits=None\n", + " preset_activation_stats=True\n", + " in_0_scale=127.01639556884766, in_0_zero_point=0.0\n", + " in_1_scale=127.00018310546875, in_1_zero_point=0.0\n", + " out_scale=63.68953323364258, out_zero_point=0.0\n", + " (wrapped_module): EltwiseAdd()\n", + " )\n", + " (1): RangeLinearQuantEltwiseAddWrapper(\n", + " mode=SYMMETRIC, num_bits_acts=8, num_bits_accum=32, clip_acts=NONE, scale_approx_mult_bits=None\n", + " preset_activation_stats=True\n", + " in_0_scale=127.00001525878906, in_0_zero_point=0.0\n", + " in_1_scale=63.68953323364258, in_1_zero_point=0.0\n", + " out_scale=42.81440353393555, out_zero_point=0.0\n", + " (wrapped_module): EltwiseAdd()\n", + " )\n", + " )\n", + " (attention_concats): ModuleList(\n", + " (0): RangeLinearQuantConcatWrapper(\n", + " mode=SYMMETRIC, num_bits_acts=8, num_bits_accum=32, clip_acts=NONE, scale_approx_mult_bits=None\n", + " preset_activation_stats=True\n", + " in_0_scale=127.0, 
in_0_zero_point=0.0\n", + " in_1_scale=46.20560836791992, in_1_zero_point=0.0\n", + " out_scale=46.20560836791992, out_zero_point=0.0\n", + " (wrapped_module): Concat()\n", + " )\n", + " (1): RangeLinearQuantConcatWrapper(\n", + " mode=SYMMETRIC, num_bits_acts=8, num_bits_accum=32, clip_acts=NONE, scale_approx_mult_bits=None\n", + " preset_activation_stats=True\n", + " in_0_scale=127.00018310546875, in_0_zero_point=0.0\n", + " in_1_scale=46.20560836791992, in_1_zero_point=0.0\n", + " out_scale=46.20560836791992, out_zero_point=0.0\n", + " (wrapped_module): Concat()\n", + " )\n", + " (2): RangeLinearQuantConcatWrapper(\n", + " mode=SYMMETRIC, num_bits_acts=8, num_bits_accum=32, clip_acts=NONE, scale_approx_mult_bits=None\n", + " preset_activation_stats=True\n", + " in_0_scale=63.68953323364258, in_0_zero_point=0.0\n", + " in_1_scale=46.20560836791992, in_1_zero_point=0.0\n", + " out_scale=46.20560836791992, out_zero_point=0.0\n", + " (wrapped_module): Concat()\n", + " )\n", + " )\n", + " )\n", + ")" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dummy_input = (torch.ones(1, 2).to(dtype=torch.long),\n", + " torch.ones(1).to(dtype=torch.long),\n", + " torch.ones(1, 2).to(dtype=torch.long))\n", + "quantizer.prepare_model(dummy_input)\n", + "quantizer.model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If you'd like to know how these functions replace the modules - I recommend reading the source code for them in \n", + "`{DISTILLER_ROOT}/distiller/quantization/range_linear.py:PostTrainLinearQuantizer`." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Evaluating the quantized model" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 24/24 [13:49<00:00, 31.91s/it]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "BLEU on test dataset: 18.05\n" + ] + } + ], + "source": [ + "#torch.cuda.empty_cache()\n", + "evaluate(quantizer.model, output)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Finding the right quantization" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As we can see here, we quantized our model entirely and it lost some accuracy, so we want to apply more strategies to quantize better. \n", + "Symmetric quantization means our range is the biggest we can hold our activations in:\n", + "$$\n", + " M = \\max \\{ |\\text{acts}|\\},\\, \\text{range}_{symmetric} = [-M, M]\n", + "$$\n", + "This way we waste resolution. However, if we use assymetric quantization - we may get better results:" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "W0807 15:15:53.630417 140098610173696 range_linear.py:1063] /home/cvds_lab/lev_z/distiller/distiller/quantization/range_linear.py:1063: UserWarning: Model contains a bidirectional DistillerLSTM module. 
Automatic BN folding and statistics optimization based on tracing is not yet supported for models containing such modules.\n", + "Will perform specific optimization for the DistillerLSTM modules, but any other potential opportunities for optimization in the model will be ignored.\n", + " 'opportunities for optimization in the model will be ignored.', UserWarning)\n", + "\n", + "W0807 15:15:53.755373 140098610173696 quantizer.py:270] /home/cvds_lab/lev_z/distiller/distiller/quantization/quantizer.py:270: UserWarning: Module 'decoder.embedder' references to same module as 'encoder.embedder'. Replacing with reference the same wrapper.\n", + " UserWarning)\n", + "\n", + "100%|██████████| 24/24 [14:25<00:00, 28.42s/it]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "BLEU on test dataset: 18.52\n" + ] + } + ], + "source": [ + "# Basic quantizer definition\n", + "quantizer = PostTrainLinearQuantizer(deepcopy(model), \n", + " mode=\"ASYMMETRIC_SIGNED\", \n", + " model_activation_stats=stats_file)\n", + "quantizer.prepare_model()\n", + "evaluate(quantizer.model, output)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here we quantized asymmetrically, meaning the range still holds all the activations, but it is tighter than in the symmetric case. \n", + "The formula is:\n", + "$$\n", + " \\text{range}_{asymmetric} = \\left[\\min\\{ \\text{acts}\\}, \\max \\{ \\text{acts}\\}\\right] \n", + " \\subset \\text{range}_{symmetric}\n", + "$$\n", + "And we indeed got a slightly better result. \n", + "However, some of the activations observed during evaluation are outliers - they lie far outside the range of the vast majority of values. We'll address this in two ways:\n", + "1. Quantize the weights of each channel separately, so each channel gets a range that fits it more tightly. We'll add the argument `per_channel_wts=True`.\n", + "2. Limit the quantization range to a smaller one, thus clamping these outliers." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We'll try the same technique as in `quantize_lstm.ipynb` - clipping the activations according to the average range recorded for them: \n", + "\n", + "\n", + "$$\n", + " m = \\underset{b\\in\\text{batches}}{\\text{avg}}\\left\\{\\min_{b}\\{\\text{acts}\\}\\right\\},\\,\n", + " M = \\underset{b\\in\\text{batches}}{\\text{avg}}\\left\\{\\max_{b}\\{\\text{acts}\\}\\right\\}\n", + "$$\n", + "\n", + "\n", + "$$\n", + " \\text{range}_{clipped} = [m,M] \\subset \\text{range}_{asymmetric} \\subset \\text{range}_{symmetric}\n", + "$$\n", + "\n", + "This is done by specifying `clip_acts=\"AVG\"` in the quantizer. " + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "W0807 15:30:25.722338 140098610173696 range_linear.py:1063] /home/cvds_lab/lev_z/distiller/distiller/quantization/range_linear.py:1063: UserWarning: Model contains a bidirectional DistillerLSTM module. 
Automatic BN folding and statistics optimization based on tracing is not yet supported for models containing such modules.\n", + "Will perform specific optimization for the DistillerLSTM modules, but any other potential opportunities for optimization in the model will be ignored.\n", + " 'opportunities for optimization in the model will be ignored.', UserWarning)\n", + "\n", + "W0807 15:30:26.860130 140098610173696 quantizer.py:270] /home/cvds_lab/lev_z/distiller/distiller/quantization/quantizer.py:270: UserWarning: Module 'decoder.embedder' references to same module as 'encoder.embedder'. Replacing with reference the same wrapper.\n", + " UserWarning)\n", + "\n", + "100%|██████████| 24/24 [13:58<00:00, 29.49s/it]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "BLEU on test dataset: 9.63\n" + ] + } + ], + "source": [ + "# Basic quantizer definition\n", + "quantizer = PostTrainLinearQuantizer(deepcopy(model), \n", + " mode=\"ASYMMETRIC_SIGNED\", \n", + " model_activation_stats=stats_file,\n", + " per_channel_wts=True,\n", + " clip_acts=\"AVG\")\n", + "quantizer.prepare_model()\n", + "evaluate(quantizer.model, output)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Oh no, that's much worse! It turns out that by clamping the outliers we actually \"removed\" useful information from sensitive layers, such as the attention layer: its softmax relies on high values to assign the correct importance scores to features. Let's try clipping everywhere except in the attention layer:" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "W0807 15:44:32.041999 140098610173696 range_linear.py:1063] /home/cvds_lab/lev_z/distiller/distiller/quantization/range_linear.py:1063: UserWarning: Model contains a bidirectional DistillerLSTM module. Automatic BN folding and statistics optimization based on tracing is not yet supported for models containing such modules.\n", + "Will perform specific optimization for the DistillerLSTM modules, but any other potential opportunities for optimization in the model will be ignored.\n", + " 'opportunities for optimization in the model will be ignored.', UserWarning)\n", + "\n", + "W0807 15:44:33.205797 140098610173696 quantizer.py:270] /home/cvds_lab/lev_z/distiller/distiller/quantization/quantizer.py:270: UserWarning: Module 'decoder.embedder' references to same module as 'encoder.embedder'. 
Replacing with reference the same wrapper.\n", + " UserWarning)\n", + "\n", + "100%|██████████| 24/24 [13:54<00:00, 32.72s/it]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "BLEU on test dataset: 16.94\n" + ] + } + ], + "source": [ + "# No clipping in the attention layer\n", + "overrides_yaml = \"\"\"\n", + ".*att_rnn.attn.*:\n", + " clip_acts: NONE # Quantize without clipping\n", + "\"\"\"\n", + "overrides = distiller.utils.yaml_ordered_load(overrides_yaml)\n", + "# Basic quantizer definition\n", + "quantizer = PostTrainLinearQuantizer(deepcopy(model), \n", + " mode=\"ASYMMETRIC_SIGNED\", \n", + " model_activation_stats=stats_file,\n", + " overrides=overrides,\n", + " per_channel_wts=True,\n", + " clip_acts=\"AVG\")\n", + "quantizer.prepare_model()\n", + "evaluate(quantizer.model, output)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The accuracy is somewhat \"restored\", but we would still like a score as close to the original model's as possible. How about also leaving the final `classifier` without clipping?" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "W0807 15:58:34.769241 140098610173696 range_linear.py:1063] /home/cvds_lab/lev_z/distiller/distiller/quantization/range_linear.py:1063: UserWarning: Model contains a bidirectional DistillerLSTM module. Automatic BN folding and statistics optimization based on tracing is not yet supported for models containing such modules.\n", + "Will perform specific optimization for the DistillerLSTM modules, but any other potential opportunities for optimization in the model will be ignored.\n", + " 'opportunities for optimization in the model will be ignored.', UserWarning)\n", + "\n", + "W0807 15:58:35.894991 140098610173696 quantizer.py:270] /home/cvds_lab/lev_z/distiller/distiller/quantization/quantizer.py:270: UserWarning: Module 'decoder.embedder' references to same module as 'encoder.embedder'. Replacing with reference the same wrapper.\n", + " UserWarning)\n", + "\n", + "100%|██████████| 24/24 [13:14<00:00, 28.11s/it]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "BLEU on test dataset: 21.49\n" + ] + } + ], + "source": [ + "# No clipping in the attention layer and in the final classifier\n", + "overrides_yaml = \"\"\"\n", + ".*att_rnn.attn.*:\n", + " clip_acts: NONE # Quantize without clipping\n", + "decoder.classifier.classifier:\n", + " clip_acts: NONE # Quantize without clipping\n", + "\"\"\"\n", + "overrides = distiller.utils.yaml_ordered_load(overrides_yaml)\n", + "# Basic quantizer definition\n", + "quantizer = PostTrainLinearQuantizer(deepcopy(model), \n", + " mode=\"ASYMMETRIC_SIGNED\", \n", + " model_activation_stats=stats_file,\n", + " overrides=overrides,\n", + " per_channel_wts=True,\n", + " clip_acts=\"AVG\")\n", + "quantizer.prepare_model()\n", + "evaluate(quantizer.model, output)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally, some good results! We now have a better picture of which layers are sensitive to clipping and which benefit from it.
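\n", + "\n", + "A possible next step (not run in this notebook): the keys of `overrides` are regular expressions matched against module names, so the same mechanism can exempt any other group of layers from clipping - for example, an `.*embedder.*` entry would cover both the encoder and decoder embedders. 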
" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# References\n", + "\n", + "<a id=\"cite-wu2016google\"/><sup><a href=#ref-1>[^]</a></sup>Wu, Yonghui and Schuster, Mike and Chen, Zhifeng and Le, Quoc V and Norouzi, Mohammad and Macherey, Wolfgang and Krikun, Maxim and Cao, Yuan and Gao, Qin and Macherey, Klaus and others. 2016. _Google's neural machine translation system: Bridging the gap between human and machine translation_.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "<!--bibtex\n", + "\n", + "@article{wu2016google,\n", + " title={Google's neural machine translation system: Bridging the gap between human and machine translation},\n", + " author={Wu, Yonghui and Schuster, Mike and Chen, Zhifeng and Le, Quoc V and Norouzi, Mohammad and Macherey, Wolfgang and Krikun, Maxim and Cao, Yuan and Gao, Qin and Macherey, Klaus and others},\n", + " journal={arXiv preprint arXiv:1609.08144},\n", + " year={2016}\n", + "}\n", + "\n", + "-->" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.5.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/GNMT/requirements.txt b/examples/GNMT/requirements.txt new file mode 100644 index 0000000..5cf238b --- /dev/null +++ b/examples/GNMT/requirements.txt @@ -0,0 +1,3 @@ +sacrebleu==1.2.10 +# numpy==1.14.2 +# mlperf-compliance==0.0.4 diff --git a/examples/GNMT/scripts/filter_dataset.py b/examples/GNMT/scripts/filter_dataset.py new file mode 100644 index 0000000..3168d47 --- /dev/null +++ b/examples/GNMT/scripts/filter_dataset.py @@ -0,0 +1,79 @@ +import argparse +from collections import Counter + + +def parse_args(): + parser = argparse.ArgumentParser(description='Clean dataset') + parser.add_argument('-f1', '--file1', help='file1') + parser.add_argument('-f2', '--file2', help='file2') + return parser.parse_args() + + +def save_output(fname, data): + with open(fname, 'w') as f: + f.writelines(data) + + +def main(): + """ + Discards all pairs of sentences which can't be decoded by latin-1 encoder. + + It aims to filter out sentences with rare unicode glyphs and pairs which + are most likely not valid English-German sentences. + + Examples of discarded sentences: + + ✿★★★Hommage au king de la pop ★★★✿ ✿★★★Que son âme repos... + + Ð”Ð»Ñ Ð¸Ñ… оÑущеÑÑ‚Ð²Ð»ÐµÐ½Ð¸Ñ Ð½Ð°Ð¼, прежде вÑего, необходимо преодолеть + Ð²Ð¾Ð·Ñ€Ð°Ð¶ÐµÐ½Ð¸Ñ Ñ€Ñ‹Ð½Ð¾Ñ‡Ð½Ñ‹Ñ… фундаменталиÑтов, которые хотÑÑ‚ ликвидировать или + уменьшить роль МВФ. + + practised as a scientist in various medical departments of the ⇗Medical + University of Hanover , the ⇗University of Ulm , and the ⇗RWTH Aachen + (rheumatology, pharmacology, physiology, pathology, microbiology, + immunology and electron-microscopy). + + The same shift】 and press ã€ã€‘ ã€alt out with a smaller diameter + circle. + + Brought to you by ABMSUBS ♥leira(Coordinator/Translator) + ♥chibichan93(Timer/Typesetter) ♥ja... + + Some examples: &0u - ☺ &0U - ☻ &tel - ☠&PI - ¶ &SU - ☼ &cH- - ♥ &M2=♫ + &sn - ﺵ SGML maps SGML to unicode. 
+ """ + args = parse_args() + + c = Counter() + skipped = 0 + valid = 0 + data1 = [] + data2 = [] + + with open(args.file1) as f1, open(args.file2) as f2: + for idx, lines in enumerate(zip(f1, f2)): + line1, line2 = lines + if idx % 100000 == 1: + print('Processed {} lines'.format(idx)) + try: + line1.encode('latin1') + line2.encode('latin1') + except UnicodeEncodeError: + skipped += 1 + else: + data1.append(line1) + data2.append(line2) + valid += 1 + c.update(line1) + + ratio = valid / (skipped + valid) + print('Skipped: {}, Valid: {}, Valid ratio {}'.format(skipped, valid, ratio)) + print('Character frequency:', c) + + save_output(args.file1, data1) + save_output(args.file2, data2) + + +if __name__ == '__main__': + main() diff --git a/examples/GNMT/seq2seq/__init__.py b/examples/GNMT/seq2seq/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/GNMT/seq2seq/data/__init__.py b/examples/GNMT/seq2seq/data/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/GNMT/seq2seq/data/config.py b/examples/GNMT/seq2seq/data/config.py new file mode 100644 index 0000000..69af2a1 --- /dev/null +++ b/examples/GNMT/seq2seq/data/config.py @@ -0,0 +1,21 @@ +PAD_TOKEN = '<pad>' +UNK_TOKEN = '<unk>' +BOS_TOKEN = '<s>' +EOS_TOKEN = '<\s>' + +PAD, UNK, BOS, EOS = [0, 1, 2, 3] + +VOCAB_FNAME = 'vocab.bpe.32000' + +SRC_TRAIN_FNAME = 'train.tok.clean.bpe.32000.en' +TGT_TRAIN_FNAME = 'train.tok.clean.bpe.32000.de' + +SRC_VAL_FNAME = 'newstest_dev.tok.clean.bpe.32000.en' +TGT_VAL_FNAME = 'newstest_dev.tok.clean.bpe.32000.de' + +SRC_TEST_FNAME = 'newstest2014.tok.bpe.32000.en' +TGT_TEST_FNAME = 'newstest2014.tok.bpe.32000.de' + +TGT_TEST_TARGET_FNAME = 'newstest2014.de' + +DETOKENIZER = 'mosesdecoder/scripts/tokenizer/detokenizer.perl' diff --git a/examples/GNMT/seq2seq/data/dataset.py b/examples/GNMT/seq2seq/data/dataset.py new file mode 100644 index 0000000..0bf39b4 --- /dev/null +++ b/examples/GNMT/seq2seq/data/dataset.py @@ -0,0 +1,123 @@ +import logging + +import torch +from torch.utils.data import Dataset +from torch.utils.data.sampler import SequentialSampler, RandomSampler +from seq2seq.data.sampler import BucketingSampler +from torch.utils.data import DataLoader + +import seq2seq.data.config as config + + +def build_collate_fn(batch_first=False, sort=False): + def collate_seq(seq): + lengths = [len(s) for s in seq] + batch_length = max(lengths) + + shape = (batch_length, len(seq)) + seq_tensor = torch.full(shape, config.PAD, dtype=torch.int64) + + for i, s in enumerate(seq): + end_seq = lengths[i] + seq_tensor[:end_seq, i].copy_(s[:end_seq]) + + if batch_first: + seq_tensor = seq_tensor.t() + + return (seq_tensor, lengths) + + def collate(seqs): + src_seqs, tgt_seqs = zip(*seqs) + if sort: + key = lambda item: len(item[1]) + indices, src_seqs = zip(*sorted(enumerate(src_seqs), key=key, + reverse=True)) + tgt_seqs = [tgt_seqs[idx] for idx in indices] + else: + indices = range(len(src_seqs)) + + return tuple([collate_seq(s) for s in [src_seqs, tgt_seqs]] + [indices]) + + return collate + + +class ParallelDataset(Dataset): + def __init__(self, src_fname, tgt_fname, tokenizer, + min_len, max_len, sort=False, max_size=None): + + self.min_len = min_len + self.max_len = max_len + + self.src = self.process_data(src_fname, tokenizer, max_size) + self.tgt = self.process_data(tgt_fname, tokenizer, max_size) + assert len(self.src) == len(self.tgt) + + self.filter_data(min_len, max_len) + assert len(self.src) == len(self.tgt) + + lengths = [len(s) + len(t) for (s, t) in 
zip(self.src, self.tgt)] + self.lengths = torch.tensor(lengths) + + if sort: + self.sort_by_length() + + def sort_by_length(self): + self.lengths, indices = self.lengths.sort(descending=True) + + self.src = [self.src[idx] for idx in indices] + self.tgt = [self.tgt[idx] for idx in indices] + + def filter_data(self, min_len, max_len): + logging.info('filtering data, min len: {}, max len: {}'.format(min_len, max_len)) + + initial_len = len(self.src) + + filtered_src = [] + filtered_tgt = [] + for src, tgt in zip(self.src, self.tgt): + if min_len <= len(src) <= max_len and \ + min_len <= len(tgt) <= max_len: + filtered_src.append(src) + filtered_tgt.append(tgt) + + self.src = filtered_src + self.tgt = filtered_tgt + + filtered_len = len(self.src) + logging.info('pairs before: {}, after: {}'.format(initial_len, filtered_len)) + + def process_data(self, fname, tokenizer, max_size): + logging.info('processing data from {}'.format(fname)) + data = [] + with open(fname) as dfile: + for idx, line in enumerate(dfile): + if max_size and idx == max_size: + break + entry = tokenizer.segment(line) + entry = torch.tensor(entry) + data.append(entry) + return data + + def __len__(self): + return len(self.src) + + def __getitem__(self, idx): + return self.src[idx], self.tgt[idx] + + def get_loader(self, batch_size=1, shuffle=False, num_workers=0, batch_first=False, + drop_last=False, distributed=False, bucket=True): + + collate_fn = build_collate_fn(batch_first, sort=True) + + if shuffle: + sampler = BucketingSampler(self, batch_size, bucket) + else: + sampler = SequentialSampler(self) + + return DataLoader(self, + batch_size=batch_size, + collate_fn=collate_fn, + sampler=sampler, + num_workers=num_workers, + pin_memory=False, + drop_last=drop_last) diff --git a/examples/GNMT/seq2seq/data/sampler.py b/examples/GNMT/seq2seq/data/sampler.py new file mode 100644 index 0000000..74085c4 --- /dev/null +++ b/examples/GNMT/seq2seq/data/sampler.py @@ -0,0 +1,85 @@ +import torch +from torch.utils.data.sampler import Sampler + +from seq2seq.utils import get_world_size, get_rank + + +class BucketingSampler(Sampler): + + def __init__(self, dataset, batch_size, bucket=True, world_size=None, rank=None): + if world_size is None: + world_size = get_world_size() + if rank is None: + rank = get_rank() + + self.dataset = dataset + self.world_size = world_size + self.rank = rank + self.epoch = 0 + self.bucket = bucket + + self.batch_size = batch_size + self.global_batch_size = batch_size * world_size + + self.data_len = len(self.dataset) + self.num_samples = self.data_len // self.global_batch_size \ + * self.global_batch_size + + def __iter__(self): + + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + + # generate permutation + indices = torch.randperm(self.data_len, generator=g) + # make indices evenly divisible by (batch_size * world_size) + indices = indices[:self.num_samples] + + + if self.bucket: + # begin shards + batches_in_shard = 80 + shard_size = self.global_batch_size * batches_in_shard + nshards = (self.num_samples + shard_size - 1) // shard_size + + lengths = self.dataset.lengths[indices] + + shards = [indices[i * shard_size:(i+1) * shard_size] for i in range(nshards)] + len_shards = [lengths[i * shard_size:(i+1) * shard_size] for i in range(nshards)] + + indices = [] + for len_shard in len_shards: + _, ind = len_shard.sort() + indices.append(ind) + + output = tuple(shard[idx] for shard,idx in zip(shards, indices)) + indices = torch.cat(output) + # global reshuffle + 
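# view as one global batch per row and then permute the rows, so batches stay bucketed by length but are visited in random order + 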
indices = indices.view(-1, self.global_batch_size) + order = torch.randperm(indices.shape[0], generator=g) + indices = indices[order, :] + indices = indices.view(-1) + # end shards + + + assert len(indices) == self.num_samples + + # build indices for each individual worker + # ranks are getting consecutive batches, + # default pytorch DistributedSampler assigns strided batches + # with offset = length / world_size + indices = indices.view(-1, self.batch_size) + indices = indices[self.rank::self.world_size].contiguous() + indices = indices.view(-1) + indices = indices.tolist() + + assert len(indices) == self.num_samples // self.world_size + + return iter(indices) + + def __len__(self): + return self.num_samples // self.world_size + + def set_epoch(self, epoch): + self.epoch = epoch diff --git a/examples/GNMT/seq2seq/data/tokenizer.py b/examples/GNMT/seq2seq/data/tokenizer.py new file mode 100644 index 0000000..09b827a --- /dev/null +++ b/examples/GNMT/seq2seq/data/tokenizer.py @@ -0,0 +1,44 @@ +import logging +from collections import defaultdict + +import seq2seq.data.config as config + +def default(): + return config.UNK + +class Tokenizer: + def __init__(self, vocab_fname, separator='@@'): + + self.separator = separator + + logging.info('building vocabulary from {}'.format(vocab_fname)) + vocab = [config.PAD_TOKEN, config.UNK_TOKEN, + config.BOS_TOKEN, config.EOS_TOKEN] + + with open(vocab_fname) as vfile: + for line in vfile: + vocab.append(line.strip()) + + logging.info('size of vocabulary: {}'.format(len(vocab))) + self.vocab_size = len(vocab) + + + self.tok2idx = defaultdict(default) + for idx, token in enumerate(vocab): + self.tok2idx[token] = idx + + self.idx2tok = {} + for key, value in self.tok2idx.items(): + self.idx2tok[value] = key + + def segment(self, line): + line = line.strip().split() + entry = [self.tok2idx[i] for i in line] + entry = [config.BOS] + entry + [config.EOS] + return entry + + def detokenize(self, inputs, delim=' '): + detok = delim.join([self.idx2tok[idx] for idx in inputs]) + detok = detok.replace( + self.separator+ ' ', '').replace(self.separator, '') + return detok diff --git a/examples/GNMT/seq2seq/inference/__init__.py b/examples/GNMT/seq2seq/inference/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/GNMT/seq2seq/inference/beam_search.py b/examples/GNMT/seq2seq/inference/beam_search.py new file mode 100644 index 0000000..90c46b7 --- /dev/null +++ b/examples/GNMT/seq2seq/inference/beam_search.py @@ -0,0 +1,245 @@ +import torch + +from seq2seq.data.config import BOS +from seq2seq.data.config import EOS + + +class SequenceGenerator(object): + def __init__(self, + model, + beam_size=5, + max_seq_len=100, + cuda=False, + len_norm_factor=0.6, + len_norm_const=5, + cov_penalty_factor=0.1): + + self.model = model + self.cuda = cuda + self.beam_size = beam_size + self.max_seq_len = max_seq_len + self.len_norm_factor = len_norm_factor + self.len_norm_const = len_norm_const + self.cov_penalty_factor = cov_penalty_factor + + self.batch_first = self.model.batch_first + + def greedy_search(self, batch_size, initial_input, initial_context=None): + max_seq_len = self.max_seq_len + + translation = torch.zeros(batch_size, max_seq_len, dtype=torch.int64) + lengths = torch.ones(batch_size, dtype=torch.int64) + active = torch.arange(0, batch_size, dtype=torch.int64) + base_mask = torch.arange(0, batch_size, dtype=torch.int64) + + if self.cuda: + translation = translation.cuda() + lengths = lengths.cuda() + active = active.cuda() + base_mask = 
base_mask.cuda() + + translation[:, 0] = BOS + words, context = initial_input, initial_context + + if self.batch_first: + word_view = (-1, 1) + ctx_batch_dim = 0 + else: + word_view = (1, -1) + ctx_batch_dim = 1 + + counter = 0 + for idx in range(1, max_seq_len): + if not len(active): + break + counter += 1 + + words = words.view(word_view) + words, logprobs, attn, context = self.model.generate(words, context, 1) + words = words.view(-1) + + translation[active, idx] = words + lengths[active] += 1 + + terminating = (words == EOS) + + if terminating.any(): + not_terminating = ~terminating + + mask = base_mask[:len(active)] + mask = mask.masked_select(not_terminating) + active = active.masked_select(not_terminating) + + words = words[mask] + context[0] = context[0].index_select(ctx_batch_dim, mask) + context[1] = context[1].index_select(0, mask) + context[2] = context[2].index_select(1, mask) + + return translation, lengths, counter + + def beam_search(self, batch_size, initial_input, initial_context=None): + beam_size = self.beam_size + norm_const = self.len_norm_const + norm_factor = self.len_norm_factor + max_seq_len = self.max_seq_len + cov_penalty_factor = self.cov_penalty_factor + + translation = torch.zeros(batch_size * beam_size, max_seq_len, dtype=torch.int64) + lengths = torch.ones(batch_size * beam_size, dtype=torch.int64) + scores = torch.zeros(batch_size * beam_size, dtype=torch.float32) + + active = torch.arange(0, batch_size * beam_size, dtype=torch.int64) + base_mask = torch.arange(0, batch_size * beam_size, dtype=torch.int64) + global_offset = torch.arange(0, batch_size * beam_size, beam_size, dtype=torch.int64) + + eos_beam_fill = torch.tensor([0] + (beam_size - 1) * [float('-inf')]) + + if self.cuda: + translation = translation.cuda() + lengths = lengths.cuda() + active = active.cuda() + base_mask = base_mask.cuda() + scores = scores.cuda() + global_offset = global_offset.cuda() + eos_beam_fill = eos_beam_fill.cuda() + + translation[:, 0] = BOS + + words, context = initial_input, initial_context + + if self.batch_first: + word_view = (-1, 1) + ctx_batch_dim = 0 + attn_query_dim = 1 + else: + word_view = (1, -1) + ctx_batch_dim = 1 + attn_query_dim = 0 + + # replicate context + if self.batch_first: + # context[0] (encoder state): (batch, seq, feature) + _, seq, feature = context[0].shape + context[0] = context[0].unsqueeze(1) + context[0] = context[0].expand(-1, beam_size, -1, -1) + context[0] = context[0].contiguous().view(batch_size * beam_size, seq, feature) + # context[0]: (batch * beam, seq, feature) + else: + # context[0] (encoder state): (seq, batch, feature) + seq, _, feature = context[0].shape + context[0] = context[0].unsqueeze(2) + context[0] = context[0].expand(-1, -1, beam_size, -1) + context[0] = context[0].contiguous().view(seq, batch_size * beam_size, feature) + # context[0]: (seq, batch * beam, feature) + + #context[1] (encoder seq length): (batch) + context[1] = context[1].unsqueeze(1) + context[1] = context[1].expand(-1, beam_size) + context[1] = context[1].contiguous().view(batch_size * beam_size) + #context[1]: (batch * beam) + + accu_attn_scores = torch.zeros(batch_size * beam_size, seq) + if self.cuda: + accu_attn_scores = accu_attn_scores.cuda() + + counter = 0 + for idx in range(1, self.max_seq_len): + if not len(active): + break + counter += 1 + + eos_mask = (words == EOS) + eos_mask = eos_mask.view(-1, beam_size) + + terminating, _ = eos_mask.min(dim=1) + + lengths[active[~eos_mask.view(-1)]] += 1 + + words, logprobs, attn, context = 
self.model.generate(words, context, beam_size) + + attn = attn.float().squeeze(attn_query_dim) + attn = attn.masked_fill(eos_mask.view(-1).unsqueeze(1), 0) + accu_attn_scores[active] += attn + + # words: (batch, beam, k) + words = words.view(-1, beam_size, beam_size) + words = words.masked_fill(eos_mask.unsqueeze(2), EOS) + + # logprobs: (batch, beam, k) + logprobs = logprobs.float().view(-1, beam_size, beam_size) + + if eos_mask.any(): + logprobs[eos_mask] = eos_beam_fill + + active_scores = scores[active].view(-1, beam_size) + # new_scores: (batch, beam, k) + new_scores = active_scores.unsqueeze(2) + logprobs + + if idx == 1: + new_scores[:, 1:, :].fill_(float('-inf')) + + new_scores = new_scores.view(-1, beam_size * beam_size) + # index: (batch, beam) + _, index = new_scores.topk(beam_size, dim=1) + source_beam = index / beam_size + + new_scores = new_scores.view(-1, beam_size * beam_size) + best_scores = torch.gather(new_scores, 1, index) + scores[active] = best_scores.view(-1) + + words = words.view(-1, beam_size * beam_size) + words = torch.gather(words, 1, index) + + # words: (1, batch * beam) + words = words.view(word_view) + + offset = global_offset[:source_beam.shape[0]] + source_beam += offset.unsqueeze(1) + + translation[active, :] = translation[active[source_beam.view(-1)], :] + translation[active, idx] = words.view(-1) + + lengths[active] = lengths[active[source_beam.view(-1)]] + + context[2] = context[2].index_select(1, source_beam.view(-1)) + + if terminating.any(): + not_terminating = ~terminating + not_terminating = not_terminating.unsqueeze(1) + not_terminating = not_terminating.expand(-1, beam_size).contiguous() + + normalization_mask = active.view(-1, beam_size)[terminating] + + # length normalization + norm = lengths[normalization_mask].float() + norm = (norm_const + norm) / (norm_const + 1.0) + norm = norm ** norm_factor + + scores[normalization_mask] /= norm + + # coverage penalty + penalty = accu_attn_scores[normalization_mask] + penalty = penalty.clamp(0, 1) + penalty = penalty.log() + penalty[penalty == float('-inf')] = 0 + penalty = penalty.sum(dim=-1) + + scores[normalization_mask] += cov_penalty_factor * penalty + + mask = base_mask[:len(active)] + mask = mask.masked_select(not_terminating.view(-1)) + + words = words.index_select(ctx_batch_dim, mask) + context[0] = context[0].index_select(ctx_batch_dim, mask) + context[1] = context[1].index_select(0, mask) + context[2] = context[2].index_select(1, mask) + + active = active.masked_select(not_terminating.view(-1)) + + scores = scores.view(batch_size, beam_size) + _, idx = scores.max(dim=1) + + translation = translation[idx + global_offset, :] + lengths = lengths[idx + global_offset] + + return translation, lengths, counter diff --git a/examples/GNMT/seq2seq/inference/inference.py b/examples/GNMT/seq2seq/inference/inference.py new file mode 100644 index 0000000..e0ddbd9 --- /dev/null +++ b/examples/GNMT/seq2seq/inference/inference.py @@ -0,0 +1,88 @@ +import torch + +from seq2seq.data.config import BOS +from seq2seq.data.config import EOS +from seq2seq.inference.beam_search import SequenceGenerator +from seq2seq.utils import batch_padded_sequences + + +class Translator(object): + + def __init__(self, model, tok, + beam_size=5, + len_norm_factor=0.6, + len_norm_const=5.0, + cov_penalty_factor=0.1, + max_seq_len=50, + cuda=False): + + self.model = model + self.tok = tok + self.insert_target_start = [BOS] + self.insert_src_start = [BOS] + self.insert_src_end = [EOS] + self.batch_first = model.batch_first + self.cuda 
= cuda + self.beam_size = beam_size + + self.generator = SequenceGenerator( + model=self.model, + beam_size=beam_size, + max_seq_len=max_seq_len, + cuda=cuda, + len_norm_factor=len_norm_factor, + len_norm_const=len_norm_const, + cov_penalty_factor=cov_penalty_factor) + + def translate(self, input_sentences): + stats = {} + batch_size = len(input_sentences) + beam_size = self.beam_size + + src_tok = [torch.tensor(self.tok.segment(line)) for line in input_sentences] + + bos = [self.insert_target_start] * (batch_size * beam_size) + bos = torch.LongTensor(bos) + if self.batch_first: + bos = bos.view(-1, 1) + else: + bos = bos.view(1, -1) + + src = batch_padded_sequences(src_tok, self.batch_first, sort=True) + src, src_length, indices = src + + src_length = torch.LongTensor(src_length) + stats['total_enc_len'] = int(src_length.sum()) + + if self.cuda: + src = src.cuda() + src_length = src_length.cuda() + bos = bos.cuda() + + with torch.no_grad(): + context = self.model.encode(src, src_length) + context = [context, src_length, None] + + if beam_size == 1: + generator = self.generator.greedy_search + else: + generator = self.generator.beam_search + + preds, lengths, counter = generator(batch_size, bos, context) + + preds = preds.cpu() + lengths = lengths.cpu() + + output = [] + for idx, pred in enumerate(preds): + end = lengths[idx] - 1 + pred = pred[1: end] + pred = pred.tolist() + out = self.tok.detokenize(pred) + output.append(out) + + stats['total_dec_len'] = int(lengths.sum()) + stats['iters'] = counter + + output = [output[indices.index(i)] for i in range(len(output))] + return output, stats diff --git a/examples/GNMT/seq2seq/models/__init__.py b/examples/GNMT/seq2seq/models/__init__.py new file mode 100644 index 0000000..6217b23 --- /dev/null +++ b/examples/GNMT/seq2seq/models/__init__.py @@ -0,0 +1,4 @@ +from .seq2seq_base import Seq2Seq +from .gnmt import GNMT, ResidualRecurrentDecoder, ResidualRecurrentEncoder + +__all__ = ['GNMT'] diff --git a/examples/GNMT/seq2seq/models/attention.py b/examples/GNMT/seq2seq/models/attention.py new file mode 100644 index 0000000..3a5a0ee --- /dev/null +++ b/examples/GNMT/seq2seq/models/attention.py @@ -0,0 +1,164 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.parameter import Parameter +from distiller.modules import * + + +class BahdanauAttention(nn.Module): + """ + It should be very similar to tf.contrib.seq2seq.BahdanauAttention + """ + + def __init__(self, query_size, key_size, num_units, normalize=False, + dropout=0, batch_first=False): + + super(BahdanauAttention, self).__init__() + + self.normalize = normalize + self.batch_first = batch_first + self.num_units = num_units + + self.linear_q = nn.Linear(query_size, num_units, bias=False) + self.linear_k = nn.Linear(key_size, num_units, bias=False) + + self.linear_att = Parameter(torch.Tensor(num_units)) + + self.dropout = nn.Dropout(dropout) + self.mask = None + + # Adding submodules for basic ops to allow quantization: + self.eltwiseadd_qk = EltwiseAdd() + self.eltwiseadd_norm_bias = EltwiseAdd() + self.eltwisemul_norm_scaler = EltwiseMult() + self.tanh = nn.Tanh() + self.matmul_score = Matmul() + self.softmax_att = nn.Softmax(dim=-1) + self.context_matmul = BatchMatmul() + + if self.normalize: + self.normalize_scalar = Parameter(torch.Tensor(1)) + self.normalize_bias = Parameter(torch.Tensor(num_units)) + else: + self.register_parameter('normalize_scalar', None) + self.register_parameter('normalize_bias', None) + + self.reset_parameters() + + 
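# Initializes self.linear_att uniformly in [-1/sqrt(num_units), 1/sqrt(num_units)]; with normalize=True, the scale parameter starts at that same stdv and the bias at zero. + 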
def reset_parameters(self): + stdv = 1. / math.sqrt(self.num_units) + self.linear_att.data.uniform_(-stdv, stdv) + + if self.normalize: + self.normalize_scalar.data.fill_(stdv) + self.normalize_bias.data.zero_() + + def set_mask(self, context_len, context): + """ + sets self.mask which is applied before softmax + ones for inactive context fields, zeros for active context fields + + :param context_len: b + :param context: if batch_first: (b x t_k x n) else: (t_k x b x n) + + self.mask: (b x t_k) + """ + + if self.batch_first: + max_len = context.size(1) + else: + max_len = context.size(0) + + indices = torch.arange(0, max_len, dtype=torch.int64, device=context.device) + self.mask = indices >= (context_len.unsqueeze(1)) + + def calc_score(self, att_query, att_keys): + """ + Calculate Bahdanau score + + :param att_query: b x t_q x n + :param att_keys: b x t_k x n + + return b x t_q x t_k scores + """ + + b, t_k, n = att_keys.size() + t_q = att_query.size(1) + + att_query = att_query.unsqueeze(2).expand(b, t_q, t_k, n) + att_keys = att_keys.unsqueeze(1).expand(b, t_q, t_k, n) + sum_qk = self.eltwiseadd_qk(att_query, att_keys) + + if self.normalize: + sum_qk = self.eltwiseadd_norm_bias(sum_qk, self.normalize_bias) + + tmp = self.linear_att.to(torch.float32) + linear_att = tmp / tmp.norm() + linear_att = linear_att.to(self.normalize_scalar) + + linear_att = self.eltwisemul_norm_scaler(linear_att, self.normalize_scalar) + else: + linear_att = self.linear_att + + out = self.matmul_score(self.tanh(sum_qk),linear_att) + return out + + def forward(self, query, keys): + """ + + :param query: if batch_first: (b x t_q x n) else: (t_q x b x n) + :param keys: if batch_first: (b x t_k x n) else (t_k x b x n) + + :returns: (context, scores_normalized) + context: if batch_first: (b x t_q x n) else (t_q x b x n) + scores_normalized: if batch_first (b x t_q x t_k) else (t_q x b x t_k) + """ + + # first dim of keys and query has to be 'batch', it's needed for bmm + if not self.batch_first: + keys = keys.transpose(0, 1) + if query.dim() == 3: + query = query.transpose(0, 1) + + if query.dim() == 2: + single_query = True + query = query.unsqueeze(1) + else: + single_query = False + + b = query.size(0) + t_k = keys.size(1) + t_q = query.size(1) + + # FC layers to transform query and key + processed_query = self.linear_q(query) + # TODO move this out of decoder for efficiency during inference + processed_key = self.linear_k(keys) + + # scores: (b x t_q x t_k) + scores = self.calc_score(processed_query, processed_key) + + if self.mask is not None: + mask = self.mask.unsqueeze(1).expand(b, t_q, t_k) + # TODO I can't use -INF because of overflow check in pytorch + scores.data.masked_fill_(mask, -65504.0) + + # Normalize the scores, softmax over t_k + scores_normalized = self.softmax_att(scores) + + # Calculate the weighted average of the attention inputs according to + # the scores + scores_normalized = self.dropout(scores_normalized) + # context: (b x t_q x n) + context = self.context_matmul(scores_normalized, keys) + + if single_query: + context = context.squeeze(1) + scores_normalized = scores_normalized.squeeze(1) + elif not self.batch_first: + context = context.transpose(0, 1) + scores_normalized = scores_normalized.transpose(0, 1) + + return context, scores_normalized diff --git a/examples/GNMT/seq2seq/models/decoder.py b/examples/GNMT/seq2seq/models/decoder.py new file mode 100644 index 0000000..66d1333 --- /dev/null +++ b/examples/GNMT/seq2seq/models/decoder.py @@ -0,0 +1,140 @@ +import itertools + +import 
torch +import torch.nn as nn + +from seq2seq.models.attention import BahdanauAttention +import seq2seq.data.config as config +from distiller.modules import * + + +class RecurrentAttention(nn.Module): + + def __init__(self, input_size, context_size, hidden_size, num_layers=1, + bias=True, batch_first=False, dropout=0): + + super(RecurrentAttention, self).__init__() + + self.rnn = nn.LSTM(input_size, hidden_size, num_layers, bias, + batch_first) + + self.attn = BahdanauAttention(hidden_size, context_size, context_size, + normalize=True, batch_first=batch_first) + + self.dropout = nn.Dropout(dropout) + + def forward(self, inputs, hidden, context, context_len): + # set attention mask, sequences have different lengths, this mask + # allows to include only valid elements of context in attention's + # softmax + self.attn.set_mask(context_len, context) + + rnn_outputs, hidden = self.rnn(inputs, hidden) + attn_outputs, scores = self.attn(rnn_outputs, context) + rnn_outputs = self.dropout(rnn_outputs) + + return rnn_outputs, hidden, attn_outputs, scores + + +class Classifier(nn.Module): + + def __init__(self, in_features, out_features, math='fp32'): + super(Classifier, self).__init__() + + self.out_features = out_features + + # padding required to trigger HMMA kernels + if math == 'fp16': + out_features = (out_features + 7) // 8 * 8 + + self.classifier = nn.Linear(in_features, out_features) + + def forward(self, x): + out = self.classifier(x) + out = out[..., :self.out_features] + return out + + +class ResidualRecurrentDecoder(nn.Module): + + def __init__(self, vocab_size, hidden_size=128, num_layers=8, bias=True, + dropout=0, batch_first=False, math='fp32', embedder=None): + + super(ResidualRecurrentDecoder, self).__init__() + + self.num_layers = num_layers + + self.att_rnn = RecurrentAttention(hidden_size, hidden_size, + hidden_size, num_layers=1, + batch_first=batch_first) + + self.rnn_layers = nn.ModuleList() + for _ in range(num_layers - 1): + self.rnn_layers.append( + nn.LSTM(2 * hidden_size, hidden_size, num_layers=1, bias=bias, + batch_first=batch_first)) + + if embedder is not None: + self.embedder = embedder + else: + self.embedder = nn.Embedding(vocab_size, hidden_size, + padding_idx=config.PAD) + + self.classifier = Classifier(hidden_size, vocab_size, math) + self.dropout = nn.Dropout(p=dropout) + + # Adding submodules for basic ops to allow quantization: + self.eltwiseadd_residuals = nn.ModuleList([EltwiseAdd() for _ in range(1, len(self.rnn_layers))]) + self.attention_concats = nn.ModuleList([Concat(2) for _ in range(len(self.rnn_layers))]) + + def init_hidden(self, hidden): + if hidden is not None: + # per-layer chunks + hidden = hidden.chunk(self.num_layers) + # (h, c) chunks for LSTM layer + hidden = tuple(i.chunk(2) for i in hidden) + else: + hidden = [None] * self.num_layers + + self.next_hidden = [] + return hidden + + def append_hidden(self, h): + if self.inference: + self.next_hidden.append(h) + + def package_hidden(self): + if self.inference: + hidden = torch.cat(tuple(itertools.chain(*self.next_hidden))) + else: + hidden = None + return hidden + + def forward(self, inputs, context, inference=False): + self.inference = inference + + enc_context, enc_len, hidden = context + hidden = self.init_hidden(hidden) + + x = self.embedder(inputs) + + x, h, attn, scores = self.att_rnn(x, hidden[0], enc_context, enc_len) + self.append_hidden(h) + + x = self.dropout(x) + x = self.attention_concats[0](x, attn) + x, h = self.rnn_layers[0](x, hidden[1]) + self.append_hidden(h) + + for i in 
range(1, len(self.rnn_layers)): + residual = x + x = self.dropout(x) + x = self.attention_concats[i](x, attn) + x, h = self.rnn_layers[i](x, hidden[i + 1]) + self.append_hidden(h) + x = self.eltwiseadd_residuals[i-1](x, residual) + + x = self.classifier(x) + hidden = self.package_hidden() + + return x, scores, [enc_context, enc_len, hidden] diff --git a/examples/GNMT/seq2seq/models/encoder.py b/examples/GNMT/seq2seq/models/encoder.py new file mode 100644 index 0000000..c7cf1bd --- /dev/null +++ b/examples/GNMT/seq2seq/models/encoder.py @@ -0,0 +1,62 @@ +import torch.nn as nn +from torch.nn.utils.rnn import pack_padded_sequence +from torch.nn.utils.rnn import pad_packed_sequence +from distiller.modules import * + +import seq2seq.data.config as config + + +class ResidualRecurrentEncoder(nn.Module): + + def __init__(self, vocab_size, hidden_size=128, num_layers=8, bias=True, + dropout=0, batch_first=False, embedder=None): + + super(ResidualRecurrentEncoder, self).__init__() + self.batch_first = batch_first + self.rnn_layers = nn.ModuleList() + self.rnn_layers.append( + nn.LSTM(hidden_size, hidden_size, num_layers=1, bias=bias, + batch_first=batch_first, bidirectional=True)) + + self.rnn_layers.append( + nn.LSTM((2 * hidden_size), hidden_size, num_layers=1, bias=bias, + batch_first=batch_first)) + + for _ in range(num_layers - 2): + self.rnn_layers.append( + nn.LSTM(hidden_size, hidden_size, num_layers=1, bias=bias, + batch_first=batch_first)) + + self.dropout = nn.Dropout(p=dropout) + + if embedder is not None: + self.embedder = embedder + else: + self.embedder = nn.Embedding(vocab_size, hidden_size, + padding_idx=config.PAD) + + # Adding submodules for basic ops to allow quantization: + self.eltwiseadd_residuals = nn.ModuleList([EltwiseAdd() for _ in range(2, len(self.rnn_layers))]) + + def forward(self, inputs, lengths): + x = self.embedder(inputs) + + # bidirectional layer + x = pack_padded_sequence(x, lengths.cpu().numpy(), + batch_first=self.batch_first) + x, _ = self.rnn_layers[0](x) + x, _ = pad_packed_sequence(x, batch_first=self.batch_first) + + # 1st unidirectional layer + x = self.dropout(x) + x, _ = self.rnn_layers[1](x) + + # the rest of unidirectional layers, + # with residual connections starting from 3rd layer + for i in range(2, len(self.rnn_layers)): + residual = x + x = self.dropout(x) + x, _ = self.rnn_layers[i](x) + x = self.eltwiseadd_residuals[i-2](x, residual) + + return x diff --git a/examples/GNMT/seq2seq/models/gnmt.py b/examples/GNMT/seq2seq/models/gnmt.py new file mode 100644 index 0000000..c576029 --- /dev/null +++ b/examples/GNMT/seq2seq/models/gnmt.py @@ -0,0 +1,36 @@ +import torch.nn as nn + +import seq2seq.data.config as config +from .seq2seq_base import Seq2Seq +from .decoder import ResidualRecurrentDecoder +from .encoder import ResidualRecurrentEncoder + + +class GNMT(Seq2Seq): + def __init__(self, vocab_size, hidden_size=512, num_layers=8, bias=True, + dropout=0.2, batch_first=False, math='fp32', + share_embedding=False): + + super(GNMT, self).__init__(batch_first=batch_first) + + if share_embedding: + embedder = nn.Embedding(vocab_size, hidden_size, padding_idx=config.PAD) + else: + embedder = None + + self.encoder = ResidualRecurrentEncoder(vocab_size, hidden_size, + num_layers, bias, dropout, + batch_first, embedder) + + self.decoder = ResidualRecurrentDecoder(vocab_size, hidden_size, + num_layers, bias, dropout, + batch_first, math, embedder) + + + + def forward(self, input_encoder, input_enc_len, input_decoder): + context = self.encode(input_encoder, 
input_enc_len) + context = (context, input_enc_len, None) + output, _, _ = self.decode(input_decoder, context) + + return output diff --git a/examples/GNMT/seq2seq/models/seq2seq_base.py b/examples/GNMT/seq2seq/models/seq2seq_base.py new file mode 100644 index 0000000..844a3c4 --- /dev/null +++ b/examples/GNMT/seq2seq/models/seq2seq_base.py @@ -0,0 +1,22 @@ +import torch.nn as nn +from torch.nn.functional import log_softmax + + +class Seq2Seq(nn.Module): + def __init__(self, encoder=None, decoder=None, batch_first=False): + super(Seq2Seq, self).__init__() + self.encoder = encoder + self.decoder = decoder + self.batch_first = batch_first + + def encode(self, inputs, lengths): + return self.encoder(inputs, lengths) + + def decode(self, inputs, context, inference=False): + return self.decoder(inputs, context, inference) + + def generate(self, inputs, context, beam_size): + logits, scores, new_context = self.decode(inputs, context, True) + logprobs = log_softmax(logits, dim=-1) + logprobs, words = logprobs.topk(beam_size, dim=-1) + return words, logprobs, scores, new_context diff --git a/examples/GNMT/seq2seq/utils.py b/examples/GNMT/seq2seq/utils.py new file mode 100644 index 0000000..8629447 --- /dev/null +++ b/examples/GNMT/seq2seq/utils.py @@ -0,0 +1,116 @@ +from contextlib import contextmanager +import os +import logging.config + +import numpy as np +import torch +from torch.nn.utils.rnn import pack_padded_sequence + +import seq2seq.data.config as config + + +def barrier(): + """ Calls all_reduce on dummy tensor.""" + if torch.distributed.is_initialized(): + torch.distributed.all_reduce(torch.cuda.FloatTensor(1)) + torch.cuda.synchronize() + + +def get_rank(): + if torch.distributed.is_initialized(): + rank = torch.distributed.get_rank() + else: + rank = 0 + return rank + +def get_world_size(): + if torch.distributed.is_initialized(): + world_size = torch.distributed.get_world_size() + else: + world_size = 1 + return world_size + + +@contextmanager +def sync_workers(): + """ Gets distributed rank and synchronizes workers at exit""" + rank = get_rank() + yield rank + barrier() + + +def setup_logging(log_file='log.log'): + """Setup logging configuration + """ + class RankFilter(logging.Filter): + def __init__(self, rank): + self.rank = rank + + def filter(self, record): + record.rank = self.rank + return True + + rank = get_rank() + rank_filter = RankFilter(rank) + + logging.basicConfig(level=logging.DEBUG, + format="%(asctime)s - %(levelname)s - %(rank)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + filename=log_file, + filemode='w') + console = logging.StreamHandler() + console.setLevel(logging.INFO) + formatter = logging.Formatter('%(rank)s: %(message)s') + console.setFormatter(formatter) + logging.getLogger('').addHandler(console) + logging.getLogger('').addFilter(rank_filter) + + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self, skip_first=True): + self.reset() + self.skip = skip_first + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + + if self.skip: + self.skip = False + else: + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def batch_padded_sequences(seq, batch_first=False, sort=False): + if sort: + key = lambda item: len(item[1]) + indices, seq = zip(*sorted(enumerate(seq), key=key, reverse=True)) + else: + indices = range(len(seq)) + + lengths = [len(sentence) for sentence in seq] + batch_length = 
+    seq_tensor = torch.LongTensor(batch_length, len(seq)).fill_(config.PAD)
+    for idx, sentence in enumerate(seq):
+        end_seq = lengths[idx]
+        seq_tensor[:end_seq, idx].copy_(sentence[:end_seq])
+    if batch_first:
+        seq_tensor = seq_tensor.t()
+    return seq_tensor, lengths, indices
+
+
+def debug_tensor(tensor, name):
+    logging.info(name)
+    tensor = tensor.float().cpu().numpy()
+    logging.info('MIN: {min} MAX: {max} AVG: {mean} STD: {std} NAN: {nans} INF: {infs}'
+                 .format(min=tensor.min(), max=tensor.max(), mean=tensor.mean(),
+                         std=tensor.std(), nans=np.isnan(tensor).sum(), infs=np.isinf(tensor).sum()))
diff --git a/examples/GNMT/translate.py b/examples/GNMT/translate.py
new file mode 100644
index 0000000..75275cf
--- /dev/null
+++ b/examples/GNMT/translate.py
@@ -0,0 +1,333 @@
+# Copyright 2019 The MLPerf Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+import argparse
+import time
+import warnings
+from ast import literal_eval
+from itertools import zip_longest
+
+import torch
+
+from seq2seq import models
+from seq2seq.inference.inference import Translator
+from seq2seq.utils import AverageMeter
+import subprocess
+import os
+import seq2seq.data.config as config
+from seq2seq.data.dataset import ParallelDataset
+import logging
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='GNMT Translate',
+                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    # data
+    dataset = parser.add_argument_group('data setup')
+    dataset.add_argument('--dataset-dir', default=None, required=True,
+                         help='path to directory with input data')
+    dataset.add_argument('-i', '--input', required=True,
+                         help='input file (tokenized)')
+    dataset.add_argument('-o', '--output', required=True,
+                         help='output file (tokenized)')
+    dataset.add_argument('-m', '--model', required=True,
+                         help='model checkpoint file')
+    dataset.add_argument('-r', '--reference', default=None,
+                         help='full path to the file with reference \
+                         translations (for sacrebleu)')
+
+    # parameters
+    params = parser.add_argument_group('inference setup')
+    params.add_argument('--batch-size', default=128, type=int,
+                        help='batch size')
+    params.add_argument('--beam-size', default=5, type=int,
+                        help='beam size')
+    params.add_argument('--max-seq-len', default=80, type=int,
+                        help='maximum prediction sequence length')
+    params.add_argument('--len-norm-factor', default=0.6, type=float,
+                        help='length normalization factor')
+    params.add_argument('--cov-penalty-factor', default=0.1, type=float,
+                        help='coverage penalty factor')
+    params.add_argument('--len-norm-const', default=5.0, type=float,
+                        help='length normalization constant')
+    # general setup
+    general = parser.add_argument_group('general setup')
+
+    general.add_argument('--mode', default='accuracy', choices=['accuracy',
+                         'performance'], help='test in accuracy or performance mode')
+
+    general.add_argument('--math', default='fp16', choices=['fp32', 'fp16'],
+                         help='arithmetic type')
+
+    batch_first_parser = general.add_mutually_exclusive_group(required=False)
+    batch_first_parser.add_argument('--batch-first', dest='batch_first',
+                                    action='store_true',
+                                    help='uses (batch, seq, feature) data \
+                                    format for RNNs')
+    batch_first_parser.add_argument('--seq-first', dest='batch_first',
+                                    action='store_false',
+                                    help='uses (seq, batch, feature) data \
+                                    format for RNNs')
+    batch_first_parser.set_defaults(batch_first=True)
+
+    cuda_parser = general.add_mutually_exclusive_group(required=False)
+    cuda_parser.add_argument('--cuda', dest='cuda', action='store_true',
+                             help='enables cuda (use \'--no-cuda\' to disable)')
+    cuda_parser.add_argument('--no-cuda', dest='cuda', action='store_false',
+                             help=argparse.SUPPRESS)
+    cuda_parser.set_defaults(cuda=True)
+
+    cudnn_parser = general.add_mutually_exclusive_group(required=False)
+    cudnn_parser.add_argument('--cudnn', dest='cudnn', action='store_true',
+                              help='enables cudnn (use \'--no-cudnn\' to disable)')
+    cudnn_parser.add_argument('--no-cudnn', dest='cudnn', action='store_false',
+                              help=argparse.SUPPRESS)
+    cudnn_parser.set_defaults(cudnn=True)
+
+    general.add_argument('--print-freq', '-p', default=1, type=int,
+                         help='print log every PRINT_FREQ batches')
+
+    return parser.parse_args()
+
+
+def grouper(iterable, size, fillvalue=None):
+    args = [iter(iterable)] * size
+    return zip_longest(*args, fillvalue=fillvalue)
+
+
+def write_output(output_file, lines):
+    for line in lines:
+        output_file.write(line)
+        output_file.write('\n')
+
+
+def checkpoint_from_distributed(state_dict):
+    ret = False
+    for key, _ in state_dict.items():
+        if key.find('module.') != -1:
+            ret = True
+            break
+    return ret
+
+
+def unwrap_distributed(state_dict):
+    new_state_dict = {}
+    for key, value in state_dict.items():
+        new_key = key.replace('module.', '')
+        new_state_dict[new_key] = value
+
+    return new_state_dict
+
+
+def main():
+    args = parse_args()
+    print(args)
+
+    if args.cuda:
+        torch.cuda.set_device(0)
+    if not args.cuda and torch.cuda.is_available():
+        warnings.warn('cuda is available but not enabled')
+    if args.math == 'fp16' and not args.cuda:
+        raise RuntimeError('fp16 requires cuda')
+    if not args.cudnn:
+        torch.backends.cudnn.enabled = False
+
+    checkpoint = torch.load(args.model, map_location={'cuda:0': 'cpu'})
+
+    vocab_size = checkpoint['tokenizer'].vocab_size
+    model_config = dict(vocab_size=vocab_size, math=checkpoint['config'].math,
+                        **literal_eval(checkpoint['config'].model_config))
+    model_config['batch_first'] = args.batch_first
+    model = models.GNMT(**model_config)
+
+    state_dict = checkpoint['state_dict']
+    if checkpoint_from_distributed(state_dict):
+        state_dict = unwrap_distributed(state_dict)
+
+    model.load_state_dict(state_dict)
+
+    if args.math == 'fp32':
+        dtype = torch.FloatTensor
+    elif args.math == 'fp16':
+        dtype = torch.HalfTensor
+
+    model.type(dtype)
+    if args.cuda:
+        model = model.cuda()
+    model.eval()
+
+    tokenizer = checkpoint['tokenizer']
+
+
+    test_data = ParallelDataset(
+        src_fname=os.path.join(args.dataset_dir, config.SRC_TEST_FNAME),
+        tgt_fname=os.path.join(args.dataset_dir, config.TGT_TEST_FNAME),
+        tokenizer=tokenizer,
+        min_len=0,
+        max_len=150,
+        sort=False)
+
+    test_loader = test_data.get_loader(batch_size=args.batch_size,
+                                       batch_first=True,
+                                       shuffle=False,
+                                       num_workers=0,
+                                       drop_last=False,
+                                       distributed=False)
+
+    translator = Translator(model,
+                            tokenizer,
+                            beam_size=args.beam_size,
+                            max_seq_len=args.max_seq_len,
+                            len_norm_factor=args.len_norm_factor,
+                            len_norm_const=args.len_norm_const,
+                            cov_penalty_factor=args.cov_penalty_factor,
+                            cuda=args.cuda)
+
+    model.eval()
+    torch.cuda.empty_cache()
+
+    # only write the output to file in accuracy mode
+    if args.mode == 'accuracy':
+        test_file = open(args.output, 'w', encoding='UTF-8')
+
+    batch_time = AverageMeter(False)
+    tot_tok_per_sec = AverageMeter(False)
+    iterations = AverageMeter(False)
+    enc_seq_len = AverageMeter(False)
+    dec_seq_len = AverageMeter(False)
+    stats = {}
+
+    for i, (src, tgt, indices) in enumerate(test_loader):
+        translate_timer = time.time()
+        src, src_length = src
+
+        if translator.batch_first:
+            batch_size = src.size(0)
+        else:
+            batch_size = src.size(1)
+        beam_size = args.beam_size
+
+        bos = [translator.insert_target_start] * (batch_size * beam_size)
+        bos = torch.LongTensor(bos)
+        if translator.batch_first:
+            bos = bos.view(-1, 1)
+        else:
+            bos = bos.view(1, -1)
+
+        src_length = torch.LongTensor(src_length)
+        stats['total_enc_len'] = int(src_length.sum())
+
+        if args.cuda:
+            src = src.cuda()
+            src_length = src_length.cuda()
+            bos = bos.cuda()
+
+        with torch.no_grad():
+            context = translator.model.encode(src, src_length)
+            context = [context, src_length, None]
+
+            if beam_size == 1:
+                generator = translator.generator.greedy_search
+            else:
+                generator = translator.generator.beam_search
+            preds, lengths, counter = generator(batch_size, bos, context)
+
+        stats['total_dec_len'] = lengths.sum().item()
+        stats['iters'] = counter
+
+        preds = preds.cpu()
+        lengths = lengths.cpu()
+
+        output = []
+        for idx, pred in enumerate(preds):
+            end = lengths[idx] - 1
+            pred = pred[1: end]
+            pred = pred.tolist()
+            out = translator.tok.detokenize(pred)
+            output.append(out)
+
+        # only write the output to file in accuracy mode
+        if args.mode == 'accuracy':
+            output = [output[indices.index(i)] for i in range(len(output))]
+            for line in output:
+                test_file.write(line)
+                test_file.write('\n')
+
+
+        # Get timing
+        elapsed = time.time() - translate_timer
+        batch_time.update(elapsed, batch_size)
+
+        total_tokens = stats['total_dec_len'] + stats['total_enc_len']
+        ttps = total_tokens / elapsed
+        tot_tok_per_sec.update(ttps, batch_size)
+
+        iterations.update(stats['iters'])
+        enc_seq_len.update(stats['total_enc_len'] / batch_size, batch_size)
+        dec_seq_len.update(stats['total_dec_len'] / batch_size, batch_size)
+
+        if i % 5 == 0:
+            log = []
+            log.append('TEST ')
+            log.append('Time {:.3f} ({:.3f})\t'.format(batch_time.val, batch_time.avg))
+            log.append('Decoder iters {:.1f} ({:.1f})\t'.format(iterations.val, iterations.avg))
+            log.append('Tok/s {:.0f} ({:.0f})'.format(tot_tok_per_sec.val, tot_tok_per_sec.avg))
+            log = ''.join(log)
+            print(log)
+
+
+    # summary timing
+    time_per_sentence = (batch_time.avg / batch_size)
+    log = []
+    log.append('TEST SUMMARY:\n')
+    log.append('Lines translated: {}\t'.format(len(test_loader.dataset)))
+    log.append('Avg total tokens/s: {:.0f}\n'.format(tot_tok_per_sec.avg))
+    log.append('Avg time per batch: {:.3f} s\t'.format(batch_time.avg))
+    log.append('Avg time per sentence: {:.3f} ms\n'.format(1000 * time_per_sentence))
+    log.append('Avg encoder seq len: {:.2f}\t'.format(enc_seq_len.avg))
+    log.append('Avg decoder seq len: {:.2f}\t'.format(dec_seq_len.avg))
+    log.append('Total decoder iterations: {}'.format(int(iterations.sum)))
+    log = ''.join(log)
+    print(log)
+
+    # only write the output to file in accuracy mode
+    if args.mode == 'accuracy':
+        test_file.close()
+
+        test_path = args.output
+        # run moses detokenizer
+        detok_path = os.path.join(args.dataset_dir, config.DETOKENIZER)
+        detok_test_path = test_path + '.detok'
+
+        with open(detok_test_path, 'w') as detok_test_file, \
+                open(test_path, 'r') as test_file:
+            subprocess.run(['perl', detok_path], stdin=test_file,
+                           stdout=detok_test_file, stderr=subprocess.DEVNULL)
+
+
+        # run sacrebleu
+        reference_path = os.path.join(args.dataset_dir,
+                                      config.TGT_TEST_TARGET_FNAME)
+        sacrebleu = subprocess.run('sacrebleu --input {} {} --score-only -lc --tokenize intl'.format(detok_test_path,
+                                   reference_path),
+                                   stdout=subprocess.PIPE, shell=True)
+        bleu = float(sacrebleu.stdout.strip())
+
+        print('BLEU on test dataset: {}'.format(bleu))
+
+    print('Finished evaluation on test set')
+
+if __name__ == '__main__':
+    main()
diff --git a/examples/GNMT/verify_dataset.sh b/examples/GNMT/verify_dataset.sh
new file mode 100644
index 0000000..d149e2a
--- /dev/null
+++ b/examples/GNMT/verify_dataset.sh
@@ -0,0 +1,73 @@
+#!/bin/bash
+
+set -e
+
+ACTUAL_SRC_TRAIN=`cat data/train.tok.clean.bpe.32000.en |md5sum`
+EXPECTED_SRC_TRAIN='b7482095b787264a310d4933d197a134  -'
+if [[ $ACTUAL_SRC_TRAIN = $EXPECTED_SRC_TRAIN ]]; then
+    echo "OK: correct data/train.tok.clean.bpe.32000.en"
+else
+    echo "ERROR: incorrect data/train.tok.clean.bpe.32000.en"
+    echo "ERROR: expected $EXPECTED_SRC_TRAIN"
+    echo "ERROR: found $ACTUAL_SRC_TRAIN"
+fi
+
+ACTUAL_TGT_TRAIN=`cat data/train.tok.clean.bpe.32000.de |md5sum`
+EXPECTED_TGT_TRAIN='409064aedaef5b7c458ff19a7beda462  -'
+if [[ $ACTUAL_TGT_TRAIN = $EXPECTED_TGT_TRAIN ]]; then
+    echo "OK: correct data/train.tok.clean.bpe.32000.de"
+else
+    echo "ERROR: incorrect data/train.tok.clean.bpe.32000.de"
+    echo "ERROR: expected $EXPECTED_TGT_TRAIN"
+    echo "ERROR: found $ACTUAL_TGT_TRAIN"
+fi
+
+ACTUAL_SRC_VAL=`cat data/newstest_dev.tok.clean.bpe.32000.en |md5sum`
+EXPECTED_SRC_VAL='704c4ba8c8b63df1f6678d32b91438b5  -'
+if [[ $ACTUAL_SRC_VAL = $EXPECTED_SRC_VAL ]]; then
+    echo "OK: correct data/newstest_dev.tok.clean.bpe.32000.en"
+else
+    echo "ERROR: incorrect data/newstest_dev.tok.clean.bpe.32000.en"
+    echo "ERROR: expected $EXPECTED_SRC_VAL"
+    echo "ERROR: found $ACTUAL_SRC_VAL"
+fi
+
+ACTUAL_TGT_VAL=`cat data/newstest_dev.tok.clean.bpe.32000.de |md5sum`
+EXPECTED_TGT_VAL='d27f5a64c839e20c5caa8b9d60075dde  -'
+if [[ $ACTUAL_TGT_VAL = $EXPECTED_TGT_VAL ]]; then
+    echo "OK: correct data/newstest_dev.tok.clean.bpe.32000.de"
+else
+    echo "ERROR: incorrect data/newstest_dev.tok.clean.bpe.32000.de"
+    echo "ERROR: expected $EXPECTED_TGT_VAL"
+    echo "ERROR: found $ACTUAL_TGT_VAL"
+fi
+
+ACTUAL_SRC_TEST=`cat data/newstest2014.tok.bpe.32000.en |md5sum`
+EXPECTED_SRC_TEST='cb014e2509f86cd81d5a87c240c07464  -'
+if [[ $ACTUAL_SRC_TEST = $EXPECTED_SRC_TEST ]]; then
+    echo "OK: correct data/newstest2014.tok.bpe.32000.en"
+else
+    echo "ERROR: incorrect data/newstest2014.tok.bpe.32000.en"
+    echo "ERROR: expected $EXPECTED_SRC_TEST"
+    echo "ERROR: found $ACTUAL_SRC_TEST"
+fi
+
+ACTUAL_TGT_TEST=`cat data/newstest2014.tok.bpe.32000.de |md5sum`
+EXPECTED_TGT_TEST='d616740f6026dc493e66efdf9ac1cb04  -'
+if [[ $ACTUAL_TGT_TEST = $EXPECTED_TGT_TEST ]]; then
+    echo "OK: correct data/newstest2014.tok.bpe.32000.de"
+else
+    echo "ERROR: incorrect data/newstest2014.tok.bpe.32000.de"
+    echo "ERROR: expected $EXPECTED_TGT_TEST"
+    echo "ERROR: found $ACTUAL_TGT_TEST"
+fi
+
+ACTUAL_TGT_TEST_TARGET=`cat data/newstest2014.de |md5sum`
+EXPECTED_TGT_TEST_TARGET='f6c3818b477e4a25cad68b61cc883c17  -'
+if [[ $ACTUAL_TGT_TEST_TARGET = $EXPECTED_TGT_TEST_TARGET ]]; then
+    echo "OK: correct data/newstest2014.de"
+else
+    echo "ERROR: incorrect data/newstest2014.de"
+    echo "ERROR: expected $EXPECTED_TGT_TEST_TARGET"
+    echo "ERROR: found $ACTUAL_TGT_TEST_TARGET"
+fi
--
GitLab
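The model classes introduced by this patch (ResidualRecurrentEncoder, ResidualRecurrentDecoder, GNMT) can be smoke-tested without a trained checkpoint. The following is an illustrative sketch, not part of the patch: it assumes examples/GNMT is on sys.path so that seq2seq is importable, and uses an arbitrary vocabulary size and dummy token ids in place of real data.

    import torch
    from seq2seq.models.gnmt import GNMT

    vocab_size = 1000  # arbitrary; normally taken from checkpoint['tokenizer'].vocab_size
    model = GNMT(vocab_size=vocab_size, hidden_size=256, num_layers=4)
    model.eval()

    seq_len, batch_size = 7, 2
    src = torch.randint(1, vocab_size, (seq_len, batch_size))       # seq-first (batch_first=False)
    src_len = torch.full((batch_size,), seq_len, dtype=torch.long)  # equal lengths, so already sorted
    tgt = torch.randint(1, vocab_size, (seq_len, batch_size))

    with torch.no_grad():
        logits = model(src, src_len, tgt)
    print(logits.shape)  # expected: (seq_len, batch_size, vocab_size)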
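seq2seq.utils.batch_padded_sequences pads a list of variable-length 1-D LongTensors of token ids into a single PAD-filled tensor and returns the original lengths and positions. A small illustration with made-up values:

    import torch
    from seq2seq.utils import batch_padded_sequences

    sentences = [torch.LongTensor([4, 8, 15, 16]), torch.LongTensor([23, 42])]
    padded, lengths, indices = batch_padded_sequences(sentences, batch_first=False, sort=True)
    print(padded.shape)   # torch.Size([4, 2]) -- (max_len, batch), filled with config.PAD
    print(lengths)        # [4, 2] -- lengths after the descending-length sort
    print(list(indices))  # [0, 1] -- original positions; translate.py uses these to restore order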
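The checkpoint helpers in translate.py handle state dicts saved from a DistributedDataParallel model, whose keys carry a 'module.' prefix that must be stripped before loading into a plain model. A quick sanity check with dummy values (assumes the example's dependencies are installed, since importing translate pulls in torch and seq2seq):

    from translate import checkpoint_from_distributed, unwrap_distributed

    # dummy state dict with the 'module.' prefix that DistributedDataParallel adds
    state_dict = {'module.encoder.embedder.weight': 'w0',
                  'module.decoder.classifier.weight': 'w1'}
    assert checkpoint_from_distributed(state_dict)
    print(unwrap_distributed(state_dict))
    # -> {'encoder.embedder.weight': 'w0', 'decoder.classifier.weight': 'w1'}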
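Note on the timing statistics in translate.py: the meters are created as AverageMeter(False), i.e. with skip_first disabled. With the default skip_first=True, the first update is excluded from the running average, which is useful for ignoring a warm-up iteration. A behavior sketch:

    from seq2seq.utils import AverageMeter

    meter = AverageMeter()       # skip_first=True by default
    meter.update(100.0)          # warm-up value: recorded in .val but skipped in the average
    meter.update(2.0)
    meter.update(4.0)
    print(meter.val, meter.avg)  # 4.0 3.0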