diff --git a/distiller/modules/__init__.py b/distiller/modules/__init__.py index fd892ea60ff6cf54ee37e00f8cd1ca175b84c5c9..282d7fde0c4bf0f387a19d731d932444508abf47 100644 --- a/distiller/modules/__init__.py +++ b/distiller/modules/__init__.py @@ -15,6 +15,9 @@ # from .eltwise import EltwiseAdd, EltwiseMult -from .concat import Concat +from .grouping import * +from .rnn import DistillerLSTM, DistillerLSTMCell -__all__ = ['EltwiseAdd', 'EltwiseMult', 'Concat'] +__all__ = ['EltwiseAdd', 'EltwiseMult', + 'Concat', 'Chunk', 'Split', 'Stack', + 'DistillerLSTMCell', 'DistillerLSTM'] diff --git a/distiller/modules/concat.py b/distiller/modules/grouping.py similarity index 52% rename from distiller/modules/concat.py rename to distiller/modules/grouping.py index 1a43416acd238fefbaa530845ee3dddd3a3cc1ca..011eb5f082af47ba697a8536d878dd659593ca99 100644 --- a/distiller/modules/concat.py +++ b/distiller/modules/grouping.py @@ -25,3 +25,32 @@ class Concat(nn.Module): def forward(self, *seq): return torch.cat(seq, dim=self.dim) + + +class Chunk(nn.Module): + def __init__(self, chunks, dim=0): + super(Chunk, self).__init__() + self.chunks = chunks + self.dim = dim + + def forward(self, tensor): + return tensor.chunk(self.chunks, dim=self.dim) + + +class Split(nn.Module): + def __init__(self, split_size_or_sections, dim=0): + super(Split, self).__init__() + self.split_size_or_sections = split_size_or_sections + self.dim = dim + + def forward(self, tensor): + return torch.split(tensor, self.split_size_or_sections, dim=self.dim) + + +class Stack(nn.Module): + def __init__(self, dim=0): + super(Stack, self).__init__() + self.dim = dim + + def forward(self, seq): + return torch.stack(seq, dim=self.dim) diff --git a/distiller/modules/rnn.py b/distiller/modules/rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..7471f32fdf90f20f6e8367f3e8dd14faa49778b9 --- /dev/null +++ b/distiller/modules/rnn.py @@ -0,0 +1,402 @@ +# +# Copyright (c) 2019 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import torch +import torch.nn as nn +import numpy as np +from .eltwise import EltwiseAdd, EltwiseMult +from itertools import product + +__all__ = ['DistillerLSTMCell', 'DistillerLSTM'] + +class DistillerLSTMCell(nn.Module): + """ + A single LSTM block. + The calculation of the output takes into account the input and the previous output and cell state: + https://pytorch.org/docs/stable/nn.html#lstmcell + Args: + input_size (int): the size of the input + hidden_size (int): the size of the hidden state / output + bias (bool): use bias. default: True + + """ + def __init__(self, input_size, hidden_size, bias=True): + super(DistillerLSTMCell, self).__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.bias = bias + + # Treat f,i,o,c_ex as one single object: + self.fc_gate_x = nn.Linear(input_size, hidden_size * 4, bias=bias) + self.fc_gate_h = nn.Linear(hidden_size, hidden_size * 4, bias=bias) + self.eltwiseadd_gate = EltwiseAdd() + # Apply activations separately: + self.act_f = nn.Sigmoid() + self.act_i = nn.Sigmoid() + self.act_o = nn.Sigmoid() + self.act_g = nn.Tanh() + # Calculate cell: + self.eltwisemult_cell_forget = EltwiseMult() + self.eltwisemult_cell_input = EltwiseMult() + self.eltwiseadd_cell = EltwiseAdd() + # Calculate hidden: + self.act_h = nn.Tanh() + self.eltwisemult_hidden = EltwiseMult() + self.init_weights() + + def forward(self, x, h=None): + """ + Implemented as defined in https://pytorch.org/docs/stable/nn.html#lstmcell. + """ + x_bsz, x_device = x.size(1), x.device + if h is None: + h = self.init_hidden(x_bsz, device=x_device) + + h_prev, c_prev = h + fc_gate = self.eltwiseadd_gate(self.fc_gate_x(x), self.fc_gate_h(h_prev)) + i, f, g, o = torch.chunk(fc_gate, 4, dim=1) + i, f, g, o = self.act_i(i), self.act_f(f), self.act_g(g), self.act_o(o) + cf, ci = self.eltwisemult_cell_forget(f, c_prev), self.eltwisemult_cell_input(i, g) + c = self.eltwiseadd_cell(cf, ci) + h = self.eltwisemult_hidden(o, self.act_h(c)) + return h, c + + def init_hidden(self, batch_size, device='cuda:0'): + h_0 = torch.zeros(batch_size, self.hidden_size).to(device) + c_0 = torch.zeros(batch_size, self.hidden_size).to(device) + return h_0, c_0 + + def init_weights(self): + initrange = 1 / np.sqrt(self.hidden_size) + self.fc_gate_x.weight.data.uniform_(-initrange, initrange) + self.fc_gate_h.weight.data.uniform_(-initrange, initrange) + + def to_pytorch_impl(self): + module = nn.LSTMCell(self.input_size, self.hidden_size, self.bias) + module.weight_hh, module.weight_ih = \ + nn.Parameter(self.fc_gate_h.weight.clone().detach()), \ + nn.Parameter(self.fc_gate_x.weight.clone().detach()) + if self.bias: + module.bias_hh, module.bias_ih = \ + nn.Parameter(self.fc_gate_h.bias.clone().detach()), \ + nn.Parameter(self.fc_gate_x.bias.clone().detach()) + return module + + @staticmethod + def from_pytorch_impl(lstmcell: nn.LSTMCell): + module = DistillerLSTMCell(input_size=lstmcell.input_size, hidden_size=lstmcell.hidden_size, bias=lstmcell.bias) + module.fc_gate_x.weight = nn.Parameter(lstmcell.weight_ih.clone().detach()) + module.fc_gate_h.weight = nn.Parameter(lstmcell.weight_hh.clone().detach()) + if lstmcell.bias: + module.fc_gate_x.bias = nn.Parameter(lstmcell.bias_ih.clone().detach()) + module.fc_gate_h.bias = nn.Parameter(lstmcell.bias_hh.clone().detach()) + + return module + + def __repr__(self): + return "%s(%d, %d)" % (self.__class__.__name__, self.input_size, self.hidden_size) + + +def process_sequence_wise(cell, x, h=None): + """ + Process the entire sequence through an LSTMCell. + Args: + cell (DistillerLSTMCell): the cell. + x (torch.Tensor): the input + h (tuple of torch.Tensor-s): the hidden states of the LSTMCell. + Returns: + y (torch.Tensor): the output + h (tuple of torch.Tensor-s): the new hidden states of the LSTMCell. + """ + results = [] + for step in x: + y, h = cell(step, h) + results.append(y) + h = (y, h) + return torch.stack(results), h + + +def _repackage_hidden_unidirectional(h): + """ + Repackages the hidden state into nn.LSTM format. (unidirectional use) + """ + h_all = [t[0] for t in h] + c_all = [t[1] for t in h] + return torch.stack(h_all, 0), torch.stack(c_all, 0) + + +def _repackage_hidden_bidirectional(h_result): + """ + Repackages the hidden state into nn.LSTM format. (bidirectional use) + """ + h_all = [t[0] for t in h_result] + c_all = [t[1] for t in h_result] + return torch.cat(h_all, dim=0), torch.cat(c_all, dim=0) + + +def _unpack_bidirectional_input_h(h): + """ + Unpack the bidirectional hidden states into states of the 2 separate directions. + """ + h_t, c_t = h + h_front, h_back = h_t[::2], h_t[1::2] + c_front, c_back = c_t[::2], c_t[1::2] + h_front = (h_front, c_front) + h_back = (h_back, c_back) + return h_front, h_back + + +class DistillerLSTM(nn.Module): + """ + A modular implementation of an LSTM module. + Args: + input_size (int): size of the input + hidden_size (int): size of the hidden connections and output. + num_layers (int): number of LSTMCells + bias (bool): use bias + batch_first (bool): the format of the sequence is (batch_size, seq_len, dim). default: False + dropout : dropout factor + bidirectional (bool): Whether or not the LSTM is bidirectional. default: False (unidirectional). + bidirectional_type (int): 1 or 2, corresponds to type 1 and type 2 as per + https://github.com/pytorch/pytorch/issues/4930. default: 2 + """ + def __init__(self, input_size, hidden_size, num_layers, bias=True, batch_first=False, + dropout=0.5, bidirectional=False, bidirectional_type=2): + super(DistillerLSTM, self).__init__() + if num_layers < 1: + raise ValueError("Number of layers has to be at least 1.") + self.input_size = input_size + self.hidden_size = hidden_size + self.num_layers = num_layers + self.bidirectional = bidirectional + self.bias = bias + self.batch_first = batch_first + self.bidirectional_type = bidirectional_type + + if bidirectional: + # Following https://github.com/pytorch/pytorch/issues/4930 - + if bidirectional_type == 1: + raise NotImplementedError + # # Process each timestep at the entire layers chain - + # # each timestep is forwarded through `front` and `back` chains independently, + # # similarily to a unidirectional LSTM. + # self.cells = nn.ModuleList([LSTMCell(input_size, hidden_size, bias)] + + # [LSTMCell(hidden_size, hidden_size, bias) + # for _ in range(1, num_layers)]) + # + # self.cells_reverse = nn.ModuleList([LSTMCell(input_size, hidden_size, bias)] + + # [LSTMCell(hidden_size, hidden_size, bias) + # for _ in range(1, num_layers)]) + # self.forward_fn = self.process_layer_wise + # self.layer_chain_fn = self._layer_chain_bidirectional_type1 + + elif bidirectional_type == 2: + # Process the entire sequence at each layer consecutively - + # the output of one layer is the sequence processed through the `front` and `back` cells + # and the input to the next layers are both `output_front` and `output_back`. + self.cells = nn.ModuleList([DistillerLSTMCell(input_size, hidden_size, bias)] + + [DistillerLSTMCell(2 * hidden_size, hidden_size, bias) + for _ in range(1, num_layers)]) + + self.cells_reverse = nn.ModuleList([DistillerLSTMCell(input_size, hidden_size, bias)] + + [DistillerLSTMCell(2 * hidden_size, hidden_size, bias) + for _ in range(1, num_layers)]) + self.forward_fn = self._bidirectional_type2_forward + + else: + raise ValueError("The only allowed types are [1, 2].") + else: + self.cells = nn.ModuleList([DistillerLSTMCell(input_size, hidden_size, bias)] + + [DistillerLSTMCell(hidden_size, hidden_size, bias) + for _ in range(1, num_layers)]) + self.forward_fn = self.process_layer_wise + self.layer_chain_fn = self._layer_chain_unidirectional + + self.dropout = nn.Dropout(dropout) + self.dropout_factor = dropout + + def forward(self, x, h=None): + is_packed_seq = isinstance(x, nn.utils.rnn.PackedSequence) + if is_packed_seq: + x, lengths = nn.utils.rnn.pad_packed_sequence(x, self.batch_first) + + elif self.batch_first: + # Transpose to sequence_first format + x = x.transpose(0, 1) + x_bsz = x.size(1) + + if h is None: + h = self.init_hidden(x_bsz) + + y, h = self.forward_fn(x, h) + if is_packed_seq: + y = nn.utils.rnn.pack_padded_sequence(y, lengths, self.batch_first) + + elif self.batch_first: + # Transpose back to batch_first format + y = y.transpose(0, 1) + return y, h + + def process_layer_wise(self, x, h): + results = [] + for step in x: + y, h = self.layer_chain_fn(step, h) + results.append(y) + return torch.stack(results), h + + def _bidirectional_type2_forward(self, x, h): + """ + Processes the entire sequence through a layer and passes the output sequence to the next layer. + """ + out = x + h_h_result = [] + h_c_result = [] + (h_front_all, c_front_all), (h_back_all, c_back_all) = _unpack_bidirectional_input_h(h) + for i, (cell_front, cell_back) in enumerate(zip(self.cells, self.cells_reverse)): + h_front, h_back = (h_front_all[i], c_front_all[i]), (h_back_all[i], c_back_all[i]) + + # Sequence treatment: + out_front, h_front = process_sequence_wise(cell_front, out, h_front) + out_back, h_back = process_sequence_wise(cell_back, out.flip([0]), h_back) + out = torch.cat([out_front, out_back.flip([0])], dim=-1) + + h_h_result += [h_front[0], h_back[0]] + h_c_result += [h_front[1], h_back[1]] + if i < self.num_layers-1: + out = self.dropout(out) + h = torch.stack(h_h_result, dim=0), torch.stack(h_c_result, dim=0) + return out, h + + def _layer_chain_bidirectional_type1(self, x, h): + # """ + # Process a single timestep through the entire bidirectional layer chain. + # """ + # (h_front_all, c_front_all), (h_back_all, c_back_all) = _repackage_bidirectional_input_h(h) + # h_result = [] + # out_front, out_back = x, x.flip([0]) + # for i, (cell_front, cell_back) in enumerate(zip(self.cells, self.cells_reverse)): + # h_front, h_back = (h_front_all[i], c_front_all[i]), (h_back_all[i], c_back_all[i]) + # h_front, c_front = cell_front(out_front, h_front) + # h_back, c_back = cell_back(out_back, h_back) + # out_front, out_back = h_front, h_back + # if i < self.num_layers-1: + # out_front, out_back = self.dropout(out_front), self.dropout(out_back) + # h_current = torch.stack([h_front, h_back]), torch.stack([c_front, c_back]) + # h_result.append(h_current) + # h_result = _repackage_hidden_bidirectional(h_result) + # return torch.cat([out_front, out_back], dim=-1), h_result + raise NotImplementedError + + def _layer_chain_unidirectional(self, step, h): + """ + Process a single timestep through the entire unidirectional layer chain. + """ + h_all, c_all = h + h_result = [] + out = step + for i, cell in enumerate(self.cells): + h = h_all[i], c_all[i] + out, hid = cell(out, h) + if i < self.num_layers-1: + out = self.dropout(out) + h_result.append((out, hid)) + h_result = _repackage_hidden_unidirectional(h_result) + return out, h_result + + def init_hidden(self, batch_size): + weight = next(self.parameters()) + n_dir = 2 if self.bidirectional else 1 + return (weight.new_zeros(self.num_layers * n_dir, batch_size, self.hidden_size), + weight.new_zeros(self.num_layers * n_dir, batch_size, self.hidden_size)) + + def init_weights(self): + for cell in self.hidden_cells: + cell.init_weights() + + def flatten_parameters(self): + pass + + def to_pytorch_impl(self): + if self.bidirectional and self.bidirectional_type == 1: + raise TypeError("Pytorch implementation of bidirectional LSTM doesn't support type 1.") + + module = nn.LSTM(input_size=self.input_size, + hidden_size=self.hidden_size, + num_layers=self.num_layers, + dropout=self.dropout_factor, + bias=self.bias, + batch_first=self.batch_first, + bidirectional=self.bidirectional) + param_gates = ['i', 'h'] + + param_types = ['weight'] + if self.bias: + param_types.append('bias') + + suffixes = [''] + if self.bidirectional: + suffixes.append('_reverse') + + for i in range(self.num_layers): + for ptype, pgate, psuffix in product(param_types, param_gates, suffixes): + cell = self.cells[i] if psuffix == '' else self.cells_reverse[i] + lstm_pth_param_name = "%s_%sh_l%d%s" % (ptype, pgate, i, psuffix) # e.g. `weight_ih_l0` + gate_name = "fc_gate_%s" % ('x' if pgate == 'i' else 'h') # `fc_gate_x` or `fc_gate_h` + gate = getattr(cell, gate_name) # e.g. `cell.fc_gate_x` + param_tensor = getattr(gate, ptype).clone().detach() + + # same as `module.weight_ih_l0 = nn.Parameter(param_tensor)`: + setattr(module, lstm_pth_param_name, nn.Parameter(param_tensor)) + + module.flatten_parameters() + return module + + @staticmethod + def from_pytorch_impl(lstm: nn.LSTM): + bidirectional = lstm.bidirectional + + module = DistillerLSTM(lstm.input_size, lstm.hidden_size, lstm.num_layers, bias=lstm.bias, + batch_first=lstm.batch_first, + dropout=lstm.dropout, bidirectional=bidirectional) + param_gates = ['i', 'h'] + + param_types = ['weight'] + if lstm.bias: + param_types.append('bias') + + suffixes = [''] + if bidirectional: + suffixes.append('_reverse') + + for i in range(lstm.num_layers): + for ptype, pgate, psuffix in product(param_types, param_gates, suffixes): + cell = module.cells[i] if psuffix == '' else module.cells_reverse[i] + lstm_pth_param_name = "%s_%sh_l%d%s" % (ptype, pgate, i, psuffix) # e.g. `weight_ih_l0` + gate_name = "fc_gate_%s" % ('x' if pgate == 'i' else 'h') # `fc_gate_x` or `fc_gate_h` + gate = getattr(cell, gate_name) # e.g. `cell.fc_gate_x` + param_tensor = getattr(lstm, lstm_pth_param_name).clone().detach() # e.g. `lstm.weight_ih_l0.detach()` + setattr(gate, ptype, nn.Parameter(param_tensor)) + + return module + + def __repr__(self): + return "%s(%d, %d, num_layers=%d, dropout=%.2f, bidirectional=%s)" % \ + (self.__class__.__name__, + self.input_size, + self.hidden_size, + self.num_layers, + self.dropout_factor, + self.bidirectional) diff --git a/distiller/utils.py b/distiller/utils.py index e719d23b17ff41f3ed4abcee20688b66f1171b24..99557f51b4149e745221d39f7c1e80bd3f47ca01 100755 --- a/distiller/utils.py +++ b/distiller/utils.py @@ -643,6 +643,7 @@ def filter_kwargs(dict_to_filter, function_to_call): invalid_args[key] = dict_to_filter[key] return valid_args, invalid_args + def convert_tensors_recursively_to(val, *args, **kwargs): """ Applies `.to(*args, **kwargs)` to each tensor inside val tree. Other values remain the same.""" if isinstance(val, torch.Tensor): diff --git a/examples/word_language_model/manual_lstm_pretrained_stats.yaml b/examples/word_language_model/manual_lstm_pretrained_stats.yaml new file mode 100644 index 0000000000000000000000000000000000000000..30874cc52f13dc58660d8b2fdb7cfd950175da7b --- /dev/null +++ b/examples/word_language_model/manual_lstm_pretrained_stats.yaml @@ -0,0 +1,566 @@ +encoder: + inputs: + 0: + min: 0 + max: 33270 + avg_min: 10.219235186035833 + avg_max: 14172.569453376202 + mean: 2843.6353892971997 + std: 5516.717805093252 + shape: (35, 10) + output: + min: -1.0055707693099976 + max: 0.8849266171455383 + avg_min: -0.7984857493083201 + avg_max: 0.7398376246357271 + mean: 0.0011389567903235516 + std: 0.106637374597835 + shape: (35, 10, 1500) +rnn.cells.0.fc_gate_x: + inputs: + 0: + min: -1.0055707693099976 + max: 0.8849266171455383 + avg_min: -0.47863298946239113 + avg_max: 0.45979668644318944 + mean: 0.0011388687080601616 + std: 0.10664238753626366 + shape: (10, 1500) + output: + min: -7.170577049255371 + max: 6.01923131942749 + avg_min: -4.400209629704961 + avg_max: 3.8546500646612807 + mean: -0.5324024169690869 + std: 1.136934240306631 + shape: (10, 6000) +rnn.cells.0.fc_gate_h: + inputs: + 0: + min: -0.9941717386245728 + max: 0.9952908158302307 + avg_min: -0.5738822829100009 + avg_max: 0.5650099425972491 + mean: 0.00019907150779485835 + std: 0.10291856745491001 + shape: (10, 1500) + output: + min: -12.353250503540039 + max: 12.154502868652344 + avg_min: -4.178367384987595 + avg_max: 3.8204386773444914 + mean: -0.24742181252201967 + std: 0.6138123563333803 + shape: (10, 6000) +rnn.cells.0.eltwiseadd_gate: + inputs: + 0: + min: -7.170577049255371 + max: 6.01923131942749 + avg_min: -4.400209629704961 + avg_max: 3.8546500646612807 + mean: -0.5324024169690869 + std: 1.136934240306631 + shape: (10, 6000) + 1: + min: -12.353250503540039 + max: 12.154502868652344 + avg_min: -4.178367384987595 + avg_max: 3.8204386773444914 + mean: -0.24742181252201967 + std: 0.6138123563333803 + shape: (10, 6000) + output: + min: -15.612003326416016 + max: 15.450967788696289 + avg_min: -6.393781979404043 + avg_max: 5.522592915806375 + mean: -0.7798242293363614 + std: 1.3385719958875721 + shape: (10, 6000) +rnn.cells.0.act_f: + inputs: + 0: + min: -15.612003326416016 + max: 12.804134368896484 + avg_min: -3.588743975214784 + avg_max: 5.171411778303171 + mean: -0.8314223118150013 + std: 0.9994718826025869 + shape: (10, 1500) + output: + min: 1.6587961226832704e-07 + max: 0.9999972581863403 + avg_min: 0.03998190438239164 + avg_max: 0.9902226688948058 + mean: 0.31776015875831237 + std: 0.18128372322608863 + shape: (10, 1500) +rnn.cells.0.act_i: + inputs: + 0: + min: -12.639559745788574 + max: 11.856657028198242 + avg_min: -6.246054550997017 + avg_max: 3.4008014973694394 + mean: -1.353105740375623 + std: 1.4006327442395972 + shape: (10, 1500) + output: + min: 3.2412126529379748e-06 + max: 0.9999929666519165 + avg_min: 0.0034082192363447438 + avg_max: 0.9486433887111401 + mean: 0.26847494655177206 + std: 0.2133546549199177 + shape: (10, 1500) +rnn.cells.0.act_o: + inputs: + 0: + min: -9.940855979919434 + max: 8.575215339660645 + avg_min: -4.700205807891148 + avg_max: 3.172773839445486 + mean: -0.9314432533239825 + std: 1.108967504745895 + shape: (10, 1500) + output: + min: 4.816373984795064e-05 + max: 0.9998113512992859 + avg_min: 0.012553264552863666 + avg_max: 0.9470582806313093 + mean: 0.318497704179476 + std: 0.19847815427571672 + shape: (10, 1500) +rnn.cells.0.act_c_ex: + inputs: + 0: + min: -11.641252517700195 + max: 15.450967788696289 + avg_min: -4.806999210727877 + avg_max: 4.59695213219979 + mean: -0.003325613286218699 + std: 1.4208848461012338 + shape: (10, 1500) + output: + min: -1.0 + max: 1.0 + avg_min: -0.9993293769645618 + avg_max: 0.9990078361602005 + mean: -6.53550379788874e-05 + std: 0.7375218759945332 + shape: (10, 1500) +rnn.cells.0.eltwisemult_cell_forget: + inputs: + 0: + min: 1.6587961226832704e-07 + max: 0.9999972581863403 + avg_min: 0.03998190438239164 + avg_max: 0.9902226688948058 + mean: 0.31776015875831237 + std: 0.18128372322608863 + shape: (10, 1500) + 1: + min: -18.4894962310791 + max: 12.76415729522705 + avg_min: -1.5611846597817076 + avg_max: 1.5773179970730573 + mean: -0.000943341485093316 + std: 0.3030234156804856 + shape: (10, 1500) + output: + min: -18.10362434387207 + max: 12.600543975830078 + avg_min: -1.2997477680276228 + avg_max: 1.2838877794119525 + mean: -3.0967179330471856e-05 + std: 0.13354148372150612 + shape: (10, 1500) +rnn.cells.0.eltwisemult_cell_input: + inputs: + 0: + min: 3.2412126529379748e-06 + max: 0.9999929666519165 + avg_min: 0.0034082192363447438 + avg_max: 0.9486433887111401 + mean: 0.26847494655177206 + std: 0.2133546549199177 + shape: (10, 1500) + 1: + min: -1.0 + max: 1.0 + avg_min: -0.9993293769645618 + avg_max: 0.9990078361602005 + mean: -6.53550379788874e-05 + std: 0.7375218759945332 + shape: (10, 1500) + output: + min: -0.9999427199363708 + max: 0.9998571276664734 + avg_min: -0.9042230041123592 + avg_max: 0.8991517046765057 + mean: -0.0009125015388671264 + std: 0.2702341313401001 + shape: (10, 1500) +rnn.cells.0.eltwiseadd_cell: + inputs: + 0: + min: -18.10362434387207 + max: 12.600543975830078 + avg_min: -1.2997477680276228 + avg_max: 1.2838877794119525 + mean: -3.0967179330471856e-05 + std: 0.13354148372150612 + shape: (10, 1500) + 1: + min: -0.9999427199363708 + max: 0.9998571276664734 + avg_min: -0.9042230041123592 + avg_max: 0.8991517046765057 + mean: -0.0009125015388671264 + std: 0.2702341313401001 + shape: (10, 1500) + output: + min: -18.4894962310791 + max: 12.76415729522705 + avg_min: -1.5612795371970447 + avg_max: 1.5773962020150967 + mean: -0.0009434687136402496 + std: 0.3030323653891493 + shape: (10, 1500) +rnn.cells.0.act_h: + inputs: + 0: + min: -18.4894962310791 + max: 12.76415729522705 + avg_min: -1.5612795371970447 + avg_max: 1.5773962020150967 + mean: -0.0009434687136402496 + std: 0.3030323653891493 + shape: (10, 1500) + output: + min: -1.0 + max: 1.0 + avg_min: -0.8615034911275732 + avg_max: 0.8620808821404143 + mean: -0.0007417835671436678 + std: 0.26512230259603614 + shape: (10, 1500) +rnn.cells.0.eltwisemult_hidden: + inputs: + 0: + min: 4.816373984795064e-05 + max: 0.9998113512992859 + avg_min: 0.012553264552863666 + avg_max: 0.9470582806313093 + mean: 0.318497704179476 + std: 0.19847815427571672 + shape: (10, 1500) + 1: + min: -1.0 + max: 1.0 + avg_min: -0.8615034911275732 + avg_max: 0.8620808821404143 + mean: -0.0007417835671436678 + std: 0.26512230259603614 + shape: (10, 1500) + output: + min: -0.9941717386245728 + max: 0.9952908158302307 + avg_min: -0.5739075602421044 + avg_max: 0.5650357380101981 + mean: 0.00019907324964651022 + std: 0.10292063063382816 + shape: (10, 1500) +rnn.cells.1.fc_gate_x: + inputs: + 0: + min: -0.9941717386245728 + max: 0.9952908158302307 + avg_min: -0.5739075602421044 + avg_max: 0.5650357380101981 + mean: 0.00019907324964651022 + std: 0.10292063063382816 + shape: (10, 1500) + output: + min: -14.241074562072754 + max: 14.333085060119629 + avg_min: -4.85037961406248 + avg_max: 4.998203643021349 + mean: -0.2753187046971623 + std: 1.2196254520335272 + shape: (10, 6000) +rnn.cells.1.fc_gate_h: + inputs: + 0: + min: -0.9998947978019714 + max: 0.9999969601631165 + avg_min: -0.9060620619319191 + avg_max: 0.9108070881413983 + mean: 0.0004922378678580036 + std: 0.2089721818735985 + shape: (10, 1500) + output: + min: -11.289525985717773 + max: 12.210380554199219 + avg_min: -5.1187274851617595 + avg_max: 5.199231204853789 + mean: -0.2947964010912462 + std: 1.2404614338320914 + shape: (10, 6000) +rnn.cells.1.eltwiseadd_gate: + inputs: + 0: + min: -14.241074562072754 + max: 14.333085060119629 + avg_min: -4.85037961406248 + avg_max: 4.998203643021349 + mean: -0.2753187046971623 + std: 1.2196254520335272 + shape: (10, 6000) + 1: + min: -11.289525985717773 + max: 12.210380554199219 + avg_min: -5.1187274851617595 + avg_max: 5.199231204853789 + mean: -0.2947964010912462 + std: 1.2404614338320914 + shape: (10, 6000) + output: + min: -18.20545196533203 + max: 17.143177032470703 + avg_min: -7.381283136345936 + avg_max: 7.463963007097637 + mean: -0.5701151059095978 + std: 1.9359252436683256 + shape: (10, 6000) +rnn.cells.1.act_f: + inputs: + 0: + min: -9.876046180725098 + max: 13.716924667358398 + avg_min: -4.492908775294287 + avg_max: 6.802375712015881 + mean: -0.3530231521011865 + std: 1.531050922142447 + shape: (10, 1500) + output: + min: 5.1388426072662696e-05 + max: 0.999998927116394 + avg_min: 0.013935547709276378 + avg_max: 0.9985429286978663 + mean: 0.41393780546997105 + std: 0.25393672173390425 + shape: (10, 1500) +rnn.cells.1.act_i: + inputs: + 0: + min: -16.290634155273438 + max: 15.543155670166016 + avg_min: -6.920624311107358 + avg_max: 5.485540903277943 + mean: -1.2801286601129584 + std: 1.8144594321741832 + shape: (10, 1500) + output: + min: 8.415258179184093e-08 + max: 0.9999998807907104 + avg_min: 0.0013664462433857262 + avg_max: 0.9920014234082808 + mean: 0.2965666530026901 + std: 0.2653741353156098 + shape: (10, 1500) +rnn.cells.1.act_o: + inputs: + 0: + min: -16.701549530029297 + max: 17.143177032470703 + avg_min: -6.635049015350669 + avg_max: 6.535849047371007 + mean: -0.665266030305992 + std: 1.9653585984229183 + shape: (10, 1500) + output: + min: 5.579678230560603e-08 + max: 1.0 + avg_min: 0.0020782844380410146 + avg_max: 0.9973914314131413 + mean: 0.3911302530009079 + std: 0.30228075066032956 + shape: (10, 1500) +rnn.cells.1.act_c_ex: + inputs: + 0: + min: -18.20545196533203 + max: 16.333803176879883 + avg_min: -6.664151172013942 + avg_max: 6.65981226445493 + mean: 0.017957419423132157 + std: 2.1413243397436355 + shape: (10, 1500) + output: + min: -1.0 + max: 1.0 + avg_min: -0.9999723520145487 + avg_max: 0.9999761455298802 + mean: 0.006299477315604894 + std: 0.8122228369780807 + shape: (10, 1500) +rnn.cells.1.eltwisemult_cell_forget: + inputs: + 0: + min: 5.1388426072662696e-05 + max: 0.999998927116394 + avg_min: 0.013935547709276378 + avg_max: 0.9985429286978663 + mean: 0.41393780546997105 + std: 0.25393672173390425 + shape: (10, 1500) + 1: + min: -48.280601501464844 + max: 71.20462036132812 + avg_min: -5.697866783911992 + avg_max: 8.45928254428721 + mean: 0.005575296806025323 + std: 0.6651294887644327 + shape: (10, 1500) + output: + min: -48.25635528564453 + max: 70.38892364501953 + avg_min: -5.401933133100925 + avg_max: 8.148694984845944 + mean: 0.004142155777030078 + std: 0.5223340602367401 + shape: (10, 1500) +rnn.cells.1.eltwisemult_cell_input: + inputs: + 0: + min: 8.415258179184093e-08 + max: 0.9999998807907104 + avg_min: 0.0013664462433857262 + avg_max: 0.9920014234082808 + mean: 0.2965666530026901 + std: 0.2653741353156098 + shape: (10, 1500) + 1: + min: -1.0 + max: 1.0 + avg_min: -0.9999723520145487 + avg_max: 0.9999761455298802 + mean: 0.006299477315604894 + std: 0.8122228369780807 + shape: (10, 1500) + output: + min: -0.9999998807907104 + max: 0.9999983906745911 + avg_min: -0.9792872371932335 + avg_max: 0.9801845271534885 + mean: 0.0014335625686553053 + std: 0.3403960596000532 + shape: (10, 1500) +rnn.cells.1.eltwiseadd_cell: + inputs: + 0: + min: -48.25635528564453 + max: 70.38892364501953 + avg_min: -5.401933133100925 + avg_max: 8.148694984845944 + mean: 0.004142155777030078 + std: 0.5223340602367401 + shape: (10, 1500) + 1: + min: -0.9999998807907104 + max: 0.9999983906745911 + avg_min: -0.9792872371932335 + avg_max: 0.9801845271534885 + mean: 0.0014335625686553053 + std: 0.3403960596000532 + shape: (10, 1500) + output: + min: -48.280601501464844 + max: 71.20462036132812 + avg_min: -5.698131194071056 + avg_max: 8.459744671469375 + mean: 0.0055757183429154394 + std: 0.6651466271437165 + shape: (10, 1500) +rnn.cells.1.act_h: + inputs: + 0: + min: -48.280601501464844 + max: 71.20462036132812 + avg_min: -5.698131194071056 + avg_max: 8.459744671469375 + mean: 0.0055757183429154394 + std: 0.6651466271437165 + shape: (10, 1500) + output: + min: -1.0 + max: 1.0 + avg_min: -0.992246895147991 + avg_max: 0.9949495861078208 + mean: 0.0021621192841517175 + std: 0.38166194375973717 + shape: (10, 1500) +rnn.cells.1.eltwisemult_hidden: + inputs: + 0: + min: 5.579678230560603e-08 + max: 1.0 + avg_min: 0.0020782844380410146 + avg_max: 0.9973914314131413 + mean: 0.3911302530009079 + std: 0.30228075066032956 + shape: (10, 1500) + 1: + min: -1.0 + max: 1.0 + avg_min: -0.992246895147991 + avg_max: 0.9949495861078208 + mean: 0.0021621192841517175 + std: 0.38166194375973717 + shape: (10, 1500) + output: + min: -0.9998947978019714 + max: 0.9999969601631165 + avg_min: -0.9061040613617598 + avg_max: 0.9108494260767619 + mean: 0.0004923663002494507 + std: 0.20897672763986075 + shape: (10, 1500) +rnn.dropout: + inputs: + 0: + min: -0.9941717386245728 + max: 0.9952908158302307 + avg_min: -0.5739075602421044 + avg_max: 0.5650357380101981 + mean: 0.00019907324964651022 + std: 0.10292063063382816 + shape: (10, 1500) + output: + min: -0.9941717386245728 + max: 0.9952908158302307 + avg_min: -0.5739075602421044 + avg_max: 0.5650357380101981 + mean: 0.00019907324964651022 + std: 0.10292063063382816 + shape: (10, 1500) +decoder: + inputs: + 0: + min: -0.9998947978019714 + max: 0.9999969601631165 + avg_min: -0.9841988293687626 + avg_max: 0.9854154916438258 + mean: 0.000492518011768678 + std: 0.20896516693829006 + shape: (35, 10, 1500) + output: + min: -16.018861770629883 + max: 34.73637390136719 + avg_min: -7.234289410796578 + avg_max: 18.421518564990844 + mean: -0.019850742917843275 + std: 1.6108417678641131 + shape: (35, 10, 33278) diff --git a/examples/word_language_model/manual_lstm_pretrained_stats_new.yaml b/examples/word_language_model/manual_lstm_pretrained_stats_new.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a9e9127ea9dd7b2a9ab0bc2920feda13c5215dfb --- /dev/null +++ b/examples/word_language_model/manual_lstm_pretrained_stats_new.yaml @@ -0,0 +1,566 @@ +encoder: + inputs: + 0: + min: 0 + max: 33270 + avg_min: 10.219235186035833 + avg_max: 14172.569453376202 + mean: 2843.6353892971997 + std: 5516.717805093252 + shape: (35, 10) + output: + min: -1.0055707693099976 + max: 0.8849266171455383 + avg_min: -0.7984857493083201 + avg_max: 0.7398376246357271 + mean: 0.0011389567903235516 + std: 0.106637374597835 + shape: (35, 10, 1500) +rnn.cells.0.fc_gate_x: + inputs: + 0: + min: -1.0055707693099976 + max: 0.8849266171455383 + avg_min: -0.47863298946239113 + avg_max: 0.45979668644318944 + mean: 0.0011388687080601616 + std: 0.10664238753626366 + shape: (10, 1500) + output: + min: -7.170577049255371 + max: 6.01923131942749 + avg_min: -4.400209629704961 + avg_max: 3.8546500646612807 + mean: -0.5324024169690869 + std: 1.136934240306631 + shape: (10, 6000) +rnn.cells.0.fc_gate_h: + inputs: + 0: + min: -0.9941717386245728 + max: 0.9952908158302307 + avg_min: -0.5738822829100009 + avg_max: 0.5650099425972491 + mean: 0.00019907150779485835 + std: 0.10291856745491001 + shape: (10, 1500) + output: + min: -12.353250503540039 + max: 12.154502868652344 + avg_min: -4.178367384987595 + avg_max: 3.8204386773444914 + mean: -0.24742181252201967 + std: 0.6138123563333803 + shape: (10, 6000) +rnn.cells.0.eltwiseadd_gate: + inputs: + 0: + min: -7.170577049255371 + max: 6.01923131942749 + avg_min: -4.400209629704961 + avg_max: 3.8546500646612807 + mean: -0.5324024169690869 + std: 1.136934240306631 + shape: (10, 6000) + 1: + min: -12.353250503540039 + max: 12.154502868652344 + avg_min: -4.178367384987595 + avg_max: 3.8204386773444914 + mean: -0.24742181252201967 + std: 0.6138123563333803 + shape: (10, 6000) + output: + min: -6.0 + max: 6.0 + avg_min: -6.393781979404043 + avg_max: 5.522592915806375 + mean: -0.7798242293363614 + std: 1.3385719958875721 + shape: (10, 6000) +rnn.cells.0.act_f: + inputs: + 0: + min: -6.0 + max: 6.0 + avg_min: -3.588743975214784 + avg_max: 5.171411778303171 + mean: -0.8314223118150013 + std: 0.9994718826025869 + shape: (10, 1500) + output: + min: 1.6587961226832704e-07 + max: 0.9999972581863403 + avg_min: 0.03998190438239164 + avg_max: 0.9902226688948058 + mean: 0.31776015875831237 + std: 0.18128372322608863 + shape: (10, 1500) +rnn.cells.0.act_i: + inputs: + 0: + min: -6.0 + max: 6 + avg_min: -6.246054550997017 + avg_max: 3.4008014973694394 + mean: -1.353105740375623 + std: 1.4006327442395972 + shape: (10, 1500) + output: + min: 3.2412126529379748e-06 + max: 0.9999929666519165 + avg_min: 0.0034082192363447438 + avg_max: 0.9486433887111401 + mean: 0.26847494655177206 + std: 0.2133546549199177 + shape: (10, 1500) +rnn.cells.0.act_o: + inputs: + 0: + min: -6 + max: 6 + avg_min: -4.700205807891148 + avg_max: 3.172773839445486 + mean: -0.9314432533239825 + std: 1.108967504745895 + shape: (10, 1500) + output: + min: 4.816373984795064e-05 + max: 0.9998113512992859 + avg_min: 0.012553264552863666 + avg_max: 0.9470582806313093 + mean: 0.318497704179476 + std: 0.19847815427571672 + shape: (10, 1500) +rnn.cells.0.act_c_ex: + inputs: + 0: + min: -6.0 + max: 6.0 + avg_min: -4.806999210727877 + avg_max: 4.59695213219979 + mean: -0.003325613286218699 + std: 1.4208848461012338 + shape: (10, 1500) + output: + min: -1.0 + max: 1.0 + avg_min: -0.9993293769645618 + avg_max: 0.9990078361602005 + mean: -6.53550379788874e-05 + std: 0.7375218759945332 + shape: (10, 1500) +rnn.cells.0.eltwisemult_cell_forget: + inputs: + 0: + min: 1.6587961226832704e-07 + max: 0.9999972581863403 + avg_min: 0.03998190438239164 + avg_max: 0.9902226688948058 + mean: 0.31776015875831237 + std: 0.18128372322608863 + shape: (10, 1500) + 1: + min: -18.4894962310791 + max: 12.76415729522705 + avg_min: -1.5611846597817076 + avg_max: 1.5773179970730573 + mean: -0.000943341485093316 + std: 0.3030234156804856 + shape: (10, 1500) + output: + min: -18.10362434387207 + max: 12.600543975830078 + avg_min: -1.2997477680276228 + avg_max: 1.2838877794119525 + mean: -3.0967179330471856e-05 + std: 0.13354148372150612 + shape: (10, 1500) +rnn.cells.0.eltwisemult_cell_input: + inputs: + 0: + min: 3.2412126529379748e-06 + max: 0.9999929666519165 + avg_min: 0.0034082192363447438 + avg_max: 0.9486433887111401 + mean: 0.26847494655177206 + std: 0.2133546549199177 + shape: (10, 1500) + 1: + min: -1.0 + max: 1.0 + avg_min: -0.9993293769645618 + avg_max: 0.9990078361602005 + mean: -6.53550379788874e-05 + std: 0.7375218759945332 + shape: (10, 1500) + output: + min: -0.9999427199363708 + max: 0.9998571276664734 + avg_min: -0.9042230041123592 + avg_max: 0.8991517046765057 + mean: -0.0009125015388671264 + std: 0.2702341313401001 + shape: (10, 1500) +rnn.cells.0.eltwiseadd_cell: + inputs: + 0: + min: -18.10362434387207 + max: 12.600543975830078 + avg_min: -1.2997477680276228 + avg_max: 1.2838877794119525 + mean: -3.0967179330471856e-05 + std: 0.13354148372150612 + shape: (10, 1500) + 1: + min: -0.9999427199363708 + max: 0.9998571276664734 + avg_min: -0.9042230041123592 + avg_max: 0.8991517046765057 + mean: -0.0009125015388671264 + std: 0.2702341313401001 + shape: (10, 1500) + output: + min: -18.4894962310791 + max: 12.76415729522705 + avg_min: -1.5612795371970447 + avg_max: 1.5773962020150967 + mean: -0.0009434687136402496 + std: 0.3030323653891493 + shape: (10, 1500) +rnn.cells.0.act_h: + inputs: + 0: + min: -6.0 + max: 6.0 + avg_min: -1.5612795371970447 + avg_max: 1.5773962020150967 + mean: -0.0009434687136402496 + std: 0.3030323653891493 + shape: (10, 1500) + output: + min: -1.0 + max: 1.0 + avg_min: -0.8615034911275732 + avg_max: 0.8620808821404143 + mean: -0.0007417835671436678 + std: 0.26512230259603614 + shape: (10, 1500) +rnn.cells.0.eltwisemult_hidden: + inputs: + 0: + min: 4.816373984795064e-05 + max: 0.9998113512992859 + avg_min: 0.012553264552863666 + avg_max: 0.9470582806313093 + mean: 0.318497704179476 + std: 0.19847815427571672 + shape: (10, 1500) + 1: + min: -1.0 + max: 1.0 + avg_min: -0.8615034911275732 + avg_max: 0.8620808821404143 + mean: -0.0007417835671436678 + std: 0.26512230259603614 + shape: (10, 1500) + output: + min: -0.9941717386245728 + max: 0.9952908158302307 + avg_min: -0.5739075602421044 + avg_max: 0.5650357380101981 + mean: 0.00019907324964651022 + std: 0.10292063063382816 + shape: (10, 1500) +rnn.cells.1.fc_gate_x: + inputs: + 0: + min: -0.9941717386245728 + max: 0.9952908158302307 + avg_min: -0.5739075602421044 + avg_max: 0.5650357380101981 + mean: 0.00019907324964651022 + std: 0.10292063063382816 + shape: (10, 1500) + output: + min: -14.241074562072754 + max: 14.333085060119629 + avg_min: -4.85037961406248 + avg_max: 4.998203643021349 + mean: -0.2753187046971623 + std: 1.2196254520335272 + shape: (10, 6000) +rnn.cells.1.fc_gate_h: + inputs: + 0: + min: -0.9998947978019714 + max: 0.9999969601631165 + avg_min: -0.9060620619319191 + avg_max: 0.9108070881413983 + mean: 0.0004922378678580036 + std: 0.2089721818735985 + shape: (10, 1500) + output: + min: -11.289525985717773 + max: 12.210380554199219 + avg_min: -5.1187274851617595 + avg_max: 5.199231204853789 + mean: -0.2947964010912462 + std: 1.2404614338320914 + shape: (10, 6000) +rnn.cells.1.eltwiseadd_gate: + inputs: + 0: + min: -14.241074562072754 + max: 14.333085060119629 + avg_min: -4.85037961406248 + avg_max: 4.998203643021349 + mean: -0.2753187046971623 + std: 1.2196254520335272 + shape: (10, 6000) + 1: + min: -11.289525985717773 + max: 12.210380554199219 + avg_min: -5.1187274851617595 + avg_max: 5.199231204853789 + mean: -0.2947964010912462 + std: 1.2404614338320914 + shape: (10, 6000) + output: + min: -18.20545196533203 + max: 17.143177032470703 + avg_min: -7.381283136345936 + avg_max: 7.463963007097637 + mean: -0.5701151059095978 + std: 1.9359252436683256 + shape: (10, 6000) +rnn.cells.1.act_f: + inputs: + 0: + min: -6.0 + max: 6.0 + avg_min: -4.492908775294287 + avg_max: 6.802375712015881 + mean: -0.3530231521011865 + std: 1.531050922142447 + shape: (10, 1500) + output: + min: 5.1388426072662696e-05 + max: 0.999998927116394 + avg_min: 0.013935547709276378 + avg_max: 0.9985429286978663 + mean: 0.41393780546997105 + std: 0.25393672173390425 + shape: (10, 1500) +rnn.cells.1.act_i: + inputs: + 0: + min: -6.0 + max: 6.0 + avg_min: -6.920624311107358 + avg_max: 5.485540903277943 + mean: -1.2801286601129584 + std: 1.8144594321741832 + shape: (10, 1500) + output: + min: 8.415258179184093e-08 + max: 0.9999998807907104 + avg_min: 0.0013664462433857262 + avg_max: 0.9920014234082808 + mean: 0.2965666530026901 + std: 0.2653741353156098 + shape: (10, 1500) +rnn.cells.1.act_o: + inputs: + 0: + min: -6.0 + max: 6.0 + avg_min: -6.635049015350669 + avg_max: 6.535849047371007 + mean: -0.665266030305992 + std: 1.9653585984229183 + shape: (10, 1500) + output: + min: 5.579678230560603e-08 + max: 1.0 + avg_min: 0.0020782844380410146 + avg_max: 0.9973914314131413 + mean: 0.3911302530009079 + std: 0.30228075066032956 + shape: (10, 1500) +rnn.cells.1.act_c_ex: + inputs: + 0: + min: -6.0 + max: 6.0 + avg_min: -6.664151172013942 + avg_max: 6.65981226445493 + mean: 0.017957419423132157 + std: 2.1413243397436355 + shape: (10, 1500) + output: + min: -1.0 + max: 1.0 + avg_min: -0.9999723520145487 + avg_max: 0.9999761455298802 + mean: 0.006299477315604894 + std: 0.8122228369780807 + shape: (10, 1500) +rnn.cells.1.eltwisemult_cell_forget: + inputs: + 0: + min: 5.1388426072662696e-05 + max: 0.999998927116394 + avg_min: 0.013935547709276378 + avg_max: 0.9985429286978663 + mean: 0.41393780546997105 + std: 0.25393672173390425 + shape: (10, 1500) + 1: + min: -48.280601501464844 + max: 71.20462036132812 + avg_min: -5.697866783911992 + avg_max: 8.45928254428721 + mean: 0.005575296806025323 + std: 0.6651294887644327 + shape: (10, 1500) + output: + min: -48.25635528564453 + max: 70.38892364501953 + avg_min: -5.401933133100925 + avg_max: 8.148694984845944 + mean: 0.004142155777030078 + std: 0.5223340602367401 + shape: (10, 1500) +rnn.cells.1.eltwisemult_cell_input: + inputs: + 0: + min: 8.415258179184093e-08 + max: 0.9999998807907104 + avg_min: 0.0013664462433857262 + avg_max: 0.9920014234082808 + mean: 0.2965666530026901 + std: 0.2653741353156098 + shape: (10, 1500) + 1: + min: -1.0 + max: 1.0 + avg_min: -0.9999723520145487 + avg_max: 0.9999761455298802 + mean: 0.006299477315604894 + std: 0.8122228369780807 + shape: (10, 1500) + output: + min: -0.9999998807907104 + max: 0.9999983906745911 + avg_min: -0.9792872371932335 + avg_max: 0.9801845271534885 + mean: 0.0014335625686553053 + std: 0.3403960596000532 + shape: (10, 1500) +rnn.cells.1.eltwiseadd_cell: + inputs: + 0: + min: -48.25635528564453 + max: 70.38892364501953 + avg_min: -5.401933133100925 + avg_max: 8.148694984845944 + mean: 0.004142155777030078 + std: 0.5223340602367401 + shape: (10, 1500) + 1: + min: -0.9999998807907104 + max: 0.9999983906745911 + avg_min: -0.9792872371932335 + avg_max: 0.9801845271534885 + mean: 0.0014335625686553053 + std: 0.3403960596000532 + shape: (10, 1500) + output: + min: -48.280601501464844 + max: 71.20462036132812 + avg_min: -5.698131194071056 + avg_max: 8.459744671469375 + mean: 0.0055757183429154394 + std: 0.6651466271437165 + shape: (10, 1500) +rnn.cells.1.act_h: + inputs: + 0: + min: -6.0 + max: 6.0 + avg_min: -5.698131194071056 + avg_max: 8.459744671469375 + mean: 0.0055757183429154394 + std: 0.6651466271437165 + shape: (10, 1500) + output: + min: -1.0 + max: 1.0 + avg_min: -0.992246895147991 + avg_max: 0.9949495861078208 + mean: 0.0021621192841517175 + std: 0.38166194375973717 + shape: (10, 1500) +rnn.cells.1.eltwisemult_hidden: + inputs: + 0: + min: 5.579678230560603e-08 + max: 1.0 + avg_min: 0.0020782844380410146 + avg_max: 0.9973914314131413 + mean: 0.3911302530009079 + std: 0.30228075066032956 + shape: (10, 1500) + 1: + min: -1.0 + max: 1.0 + avg_min: -0.992246895147991 + avg_max: 0.9949495861078208 + mean: 0.0021621192841517175 + std: 0.38166194375973717 + shape: (10, 1500) + output: + min: -0.9998947978019714 + max: 0.9999969601631165 + avg_min: -0.9061040613617598 + avg_max: 0.9108494260767619 + mean: 0.0004923663002494507 + std: 0.20897672763986075 + shape: (10, 1500) +rnn.dropout: + inputs: + 0: + min: -0.9941717386245728 + max: 0.9952908158302307 + avg_min: -0.5739075602421044 + avg_max: 0.5650357380101981 + mean: 0.00019907324964651022 + std: 0.10292063063382816 + shape: (10, 1500) + output: + min: -0.9941717386245728 + max: 0.9952908158302307 + avg_min: -0.5739075602421044 + avg_max: 0.5650357380101981 + mean: 0.00019907324964651022 + std: 0.10292063063382816 + shape: (10, 1500) +decoder: + inputs: + 0: + min: -0.9998947978019714 + max: 0.9999969601631165 + avg_min: -0.9841988293687626 + avg_max: 0.9854154916438258 + mean: 0.000492518011768678 + std: 0.20896516693829006 + shape: (35, 10, 1500) + output: + min: -16.018861770629883 + max: 34.73637390136719 + avg_min: -7.234289410796578 + avg_max: 18.421518564990844 + mean: -0.019850742917843275 + std: 1.6108417678641131 + shape: (35, 10, 33278) diff --git a/examples/word_language_model/model.py b/examples/word_language_model/model.py index 612f0960c50bd34849294d90596d3b05f4bcbbd1..d52dc009c2959ca32319a77517713026c9d93340 100755 --- a/examples/word_language_model/model.py +++ b/examples/word_language_model/model.py @@ -1,10 +1,18 @@ +import torch import torch.nn as nn +from distiller.modules import * + class RNNModel(nn.Module): """Container module with an encoder, a recurrent module, and a decoder.""" def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers, dropout=0.5, tie_weights=False): super(RNNModel, self).__init__() + self.ntoken = ntoken + self.ninp = ninp + self.nhid = nhid + self.nlayers = nlayers + self.tie_weights = tie_weights self.drop = nn.Dropout(dropout) self.encoder = nn.Embedding(ntoken, ninp) if rnn_type in ['LSTM', 'GRU']: @@ -36,7 +44,6 @@ class RNNModel(nn.Module): self.nhid = nhid self.nlayers = nlayers - def init_weights(self): initrange = 0.1 self.encoder.weight.data.uniform_(-initrange, initrange) @@ -58,3 +65,35 @@ class RNNModel(nn.Module): weight.new_zeros(self.nlayers, bsz, self.nhid)) else: return weight.new_zeros(self.nlayers, bsz, self.nhid) + + +class DistillerRNNModel(nn.Module): + """This is the distiller version of RNNModel, which uses DistillerLSTM instead of nn.LSTM.""" + def __init__(self, ntoken, ninp, nhid, nlayers, dropout=0.5, tie_weights=False): + super(DistillerRNNModel, self).__init__() + self.ntoken = ntoken + self.ninp = ninp + self.nhid = nhid + self.nlayers = nlayers + self.encoder = nn.Embedding(ntoken, ninp) + self.rnn = DistillerLSTM(ninp, nhid, nlayers, dropout=dropout) + self.decoder = nn.Linear(nhid, ntoken) + self.init_weights() + if tie_weights: + if nhid != ninp: + raise ValueError('When using the tied flag, nhid must be equal to emsize') + self.decoder.weight = self.encoder.weight + + def forward(self, x, h): + emb = self.encoder(x) + y, h = self.rnn(emb, h) + decoded = self.decoder(y) + return decoded, h + + def init_weights(self): + initrange = 0.1 + self.encoder.weight.data.uniform_(-initrange, initrange) + self.decoder.weight.data.uniform_(-initrange, initrange) + + def init_hidden(self, batch_size): + return self.rnn.init_hidden(batch_size) diff --git a/examples/word_language_model/quantize_lstm.ipynb b/examples/word_language_model/quantize_lstm.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..d12df739c7015f4c0531f633a37ee849f44976e5 --- /dev/null +++ b/examples/word_language_model/quantize_lstm.ipynb @@ -0,0 +1,833 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Quantizing RNN Models" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this example, we show how to quantize recurrent models. \n", + "Using a pretrained model `model.RNNModel`, we convert the built-in pytorch implementation of LSTM to our own, modular implementation. \n", + "The pretrained model was generated with: \n", + "```time python3 main.py --cuda --emsize 1500 --nhid 1500 --dropout 0.65 --tied --wd=1e-6``` \n", + "The reason we replace the LSTM that is because the inner operations in the pytorch implementation are not accessible to us, but we still want to quantize these operations. <br />\n", + "Afterwards we can try different techniques to quantize the whole model. \n", + "\n", + "_NOTE_: We use `tqdm` to plot progress bars, since it's not in `requirements.txt` you should install it using \n", + "`pip install tqdm`." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from model import DistillerRNNModel, RNNModel\n", + "from data import Corpus\n", + "import torch\n", + "from torch import nn\n", + "import distiller\n", + "from distiller.modules import DistillerLSTM as LSTM\n", + "from tqdm import tqdm # for pretty progress bar\n", + "import numpy as np" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Preprocess the data:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "corpus = Corpus('./data/wikitext-2/')" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "def batchify(data, bsz):\n", + " # Work out how cleanly we can divide the dataset into bsz parts.\n", + " nbatch = data.size(0) // bsz\n", + " # Trim off any extra elements that wouldn't cleanly fit (remainders).\n", + " data = data.narrow(0, 0, nbatch * bsz)\n", + " # Evenly divide the data across the bsz batches.\n", + " data = data.view(bsz, -1).t().contiguous()\n", + " return data.to(device)\n", + "device = 'cuda:0'\n", + "batch_size = 20\n", + "eval_batch_size = 10\n", + "train_data = batchify(corpus.train, batch_size)\n", + "val_data = batchify(corpus.valid, eval_batch_size)\n", + "test_data = batchify(corpus.test, eval_batch_size)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Loading the model and converting to our own implementation." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "RNNModel(\n", + " (drop): Dropout(p=0.65)\n", + " (encoder): Embedding(33278, 1500)\n", + " (rnn): LSTM(1500, 1500, num_layers=2, dropout=0.65)\n", + " (decoder): Linear(in_features=1500, out_features=33278, bias=True)\n", + ")" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "rnn_model = torch.load('./checkpoint.pth.tar.best')\n", + "rnn_model = rnn_model.to(device)\n", + "rnn_model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here we convert the pytorch LSTM implementation to our own, by calling `LSTM.from_pytorch_impl`:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "DistillerRNNModel(\n", + " (encoder): Embedding(33278, 1500)\n", + " (rnn): DistillerLSTM(1500, 1500, num_layers=2, dropout=0.65, bidirectional=False)\n", + " (decoder): Linear(in_features=1500, out_features=33278, bias=True)\n", + ")" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "def manual_model(pytorch_model_: RNNModel):\n", + " nlayers, ninp, nhid, ntoken, tie_weights = \\\n", + " pytorch_model_.nlayers, \\\n", + " pytorch_model_.ninp, \\\n", + " pytorch_model_.nhid, \\\n", + " pytorch_model_.ntoken, \\\n", + " pytorch_model_.tie_weights\n", + "\n", + " model = DistillerRNNModel(nlayers=nlayers, ninp=ninp, nhid=nhid, ntoken=ntoken, tie_weights=tie_weights).to(device)\n", + " model.eval()\n", + " model.encoder.weight = nn.Parameter(pytorch_model_.encoder.weight.clone().detach())\n", + " model.decoder.weight = nn.Parameter(pytorch_model_.decoder.weight.clone().detach())\n", + " model.decoder.bias = nn.Parameter(pytorch_model_.decoder.bias.clone().detach())\n", + " model.rnn = LSTM.from_pytorch_impl(pytorch_model_.rnn)\n", + "\n", + " return model\n", + "\n", + "man_model = manual_model(rnn_model)\n", + "torch.save(man_model, 'manual.checkpoint.pth.tar')\n", + "man_model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Batching the data for evaluation:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "sequence_len = 35\n", + "def get_batch(source, i):\n", + " seq_len = min(sequence_len, len(source) - 1 - i)\n", + " data = source[i:i+seq_len]\n", + " target = source[i+1:i+1+seq_len].view(-1)\n", + " return data, target\n", + "\n", + "hidden = rnn_model.init_hidden(eval_batch_size)\n", + "data, targets = get_batch(test_data, 0)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Check that the convertion has succeeded:" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Max error in y: 0.000011\n" + ] + } + ], + "source": [ + "rnn_model.eval()\n", + "man_model.eval()\n", + "y_t, h_t = rnn_model(data, hidden)\n", + "y_p, h_p = man_model(data, hidden)\n", + "\n", + "print(\"Max error in y: %f\" % (y_t-y_p).abs().max().item())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Defining the evaluation:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "criterion = nn.CrossEntropyLoss()\n", + "def repackage_hidden(h):\n", + " \"\"\"Wraps hidden states in new Tensors, to detach them from their history.\"\"\"\n", + " if isinstance(h, torch.Tensor):\n", + " return h.detach()\n", + " else:\n", + " return tuple(repackage_hidden(v) for v in h)\n", + " \n", + "\n", + "def evaluate(model, data_source):\n", + " # Turn on evaluation mode which disables dropout.\n", + " model.eval()\n", + " total_loss = 0.\n", + " ntokens = len(corpus.dictionary)\n", + " hidden = model.init_hidden(eval_batch_size)\n", + " with torch.no_grad():\n", + " # The line below was fixed as per: https://github.com/pytorch/examples/issues/214\n", + " for i in tqdm(range(0, data_source.size(0), sequence_len)):\n", + " data, targets = get_batch(data_source, i)\n", + " output, hidden = model(data, hidden)\n", + " output_flat = output.view(-1, ntokens)\n", + " total_loss += len(data) * criterion(output_flat, targets).item()\n", + " hidden = repackage_hidden(hidden)\n", + " return total_loss / len(data_source)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Quantizing the model:" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Collect activation statistics:" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The model uses activation statistics to determine how big the quantization range is. The bigger the range - the larger the round off error after quantization which leads to accuracy drop. \n", + "Our goal is to minimize the range s.t. it contains the absolute most of our data. \n", + "After that, we divide the range into chunks of equal size, according to the number of bits, and transform the data according to this scale factor. \n", + "Read more on scale factor calculation [in our docs](https://nervanasystems.github.io/distiller/algo_quantization.html).\n", + "\n", + "The class `QuantCalibrationStatsCollector` collects the statistics for defining the range $r = max - min$. \n", + "\n", + "Each forward pass, the collector records the values of inputs and outputs, for each layer:\n", + "- absolute over all batches min, max (stored in `min`, `max`)\n", + "- average over batches, per batch min, max (stored in `avg_min`, `avg_max`)\n", + "- mean\n", + "- std\n", + "- shape of output tensor \n", + "\n", + "All these values can be used to define the range of quantization, e.g. we can use the absolute `min`, `max` to define the range." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from distiller.data_loggers import QuantCalibrationStatsCollector, collector_context\n", + "\n", + "man_model = torch.load('./manual.checkpoint.pth.tar')\n", + "distiller.utils.assign_layer_fq_names(man_model)\n", + "collector = QuantCalibrationStatsCollector(man_model)\n", + "\n", + "if not os.path.isfile('manual_lstm_pretrained_stats.yaml'):\n", + " with collector_context(collector) as collector:\n", + " val_loss = evaluate(man_model, val_data)\n", + " collector.save('manual_lstm_pretrained_stats.yaml')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Quantize Model:\n", + " \n", + "We quantize the model after the training has completed. \n", + "Here we check the baseline model perplexity, to have an idea how good the quantization is." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 622/622 [00:23<00:00, 26.72it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "val_loss: 4.46\t|\t ppl: 86.78\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "from distiller.quantization import PostTrainLinearQuantizer, LinearQuantMode\n", + "from copy import deepcopy\n", + "\n", + "# Load and evaluate the baseline model.\n", + "man_model = torch.load('./manual.checkpoint.pth.tar')\n", + "val_loss = evaluate(man_model, val_data)\n", + "print('val_loss:%8.2f\\t|\\t ppl:%8.2f' % (val_loss, np.exp(val_loss)))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we do our magic - __Quantizing the model__. \n", + "The quantizer replaces the layers in out model with their quantized versions. \n", + "We can see that our model has changed:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "# Define the quantizer\n", + "quantizer = PostTrainLinearQuantizer(\n", + " deepcopy(man_model),\n", + " model_activation_stats='./manual_lstm_pretrained_stats.yaml')\n", + "\n", + "# Quantizer magic:\n", + "quantizer.prepare_model()" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "DistillerRNNModel(\n", + " (encoder): RangeLinearEmbeddingWrapper(\n", + " (wrapped_module): Embedding(33278, 1500)\n", + " )\n", + " (rnn): DistillerLSTM(1500, 1500, num_layers=2, dropout=0.65, bidirectional=False)\n", + " (decoder): RangeLinearQuantParamLayerWrapper(\n", + " mode=SYMMETRIC, num_bits_acts=8, num_bits_params=8, num_bits_accum=32, clip_acts=NONE, per_channel_wts=False\n", + " preset_activation_stats=True\n", + " w_scale=126.2964, w_zero_point=0.0000\n", + " in_scale=127.0004, in_zero_point=0.0000\n", + " out_scale=3.6561, out_zero_point=0.0000\n", + " (wrapped_module): Linear(in_features=1500, out_features=33278, bias=True)\n", + " )\n", + ")" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "quantizer.model" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 622/622 [02:06<00:00, 5.71it/s]\n" + ] + } + ], + "source": [ + "val_loss = evaluate(quantizer.model.to(device), val_data)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "val_loss: 4.64\t|\t ppl: 103.26\n" + ] + } + ], + "source": [ + "print('val_loss:%8.2f\\t|\\t ppl:%8.2f' % (val_loss, np.exp(val_loss)))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As we can see here, the perplexity has increased much - meaning our quantization has damaged the accuracy of our model. \n", + "Let's try quantizing each channel separately, and making the range of the quantization asymmetric. \n", + "Also - we replaced the `min`, `max` boundaries manually in the file. \n", + "The idea is - the quantizer takes the absolute `min`, `max` boundaries by default, and in the original file many of the activations had a very large range that makes our quants very big - while we want to minimize their size since each quant corresponds to a roundoff error. \n", + "The activations in every LSTM are either `sigmoid` or `tanh`, and since these are bounded respectively by\n", + "$[0,1]$, $[-1,1]$ and they saturate very quickly - we can clip the inputs to be between in the range of $[-6,6]$." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "DistillerRNNModel(\n", + " (encoder): RangeLinearEmbeddingWrapper(\n", + " (wrapped_module): Embedding(33278, 1500)\n", + " )\n", + " (rnn): DistillerLSTM(1500, 1500, num_layers=2, dropout=0.65, bidirectional=False)\n", + " (decoder): RangeLinearQuantParamLayerWrapper(\n", + " mode=ASYMMETRIC_SIGNED, num_bits_acts=8, num_bits_params=8, num_bits_accum=32, clip_acts=NONE, per_channel_wts=True\n", + " preset_activation_stats=True\n", + " w_scale=PerCh, w_zero_point=PerCh\n", + " in_scale=127.5069, in_zero_point=1.0000\n", + " out_scale=5.0241, out_zero_point=48.0000\n", + " (wrapped_module): Linear(in_features=1500, out_features=33278, bias=True)\n", + " )\n", + ")" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "quantizer = PostTrainLinearQuantizer(\n", + " deepcopy(man_model),\n", + " model_activation_stats='./manual_lstm_pretrained_stats_new.yaml',\n", + " mode=LinearQuantMode.ASYMMETRIC_SIGNED,\n", + " per_channel_wts=True\n", + ")\n", + "quantizer.prepare_model()\n", + "quantizer.model" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 622/622 [02:09<00:00, 5.13it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "val_loss: 4.61\t|\t ppl: 100.92\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "val_loss = evaluate(quantizer.model.to(device), val_data)\n", + "print('val_loss:%8.2f\\t|\\t ppl:%8.2f' % (val_loss, np.exp(val_loss)))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "A tiny bit better, but still no good. Let us try the half precision version of the model:" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 622/622 [00:29<00:00, 21.17it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "val_loss: 4.463242\t|\t ppl: 86.77\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "model_fp16 = deepcopy(man_model).half()\n", + "val_loss = evaluate(model_fp16, val_data)\n", + "print('val_loss: %8.6f\\t|\\t ppl:%8.2f' % (val_loss, np.exp(val_loss)))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The result is very close to our original model! That means that the roundoff when quantizing lineary is what hurts our accuracy. Let's try then quantizing everything except elemtentwise operations, as stated in \n", + "[`Effective Quantization Methods for Recurrent Neural Networks`](https://arxiv.org/abs/1611.10176) :" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 622/622 [01:20<00:00, 8.19it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "val_loss:4.463708\t|\t ppl: 86.81\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "overrides_yaml = \"\"\"\n", + ".*eltwise.*:\n", + " fp16: true\n", + "encoder:\n", + " fp16: true\n", + "decoder:\n", + " fp16: true\n", + "\"\"\"\n", + "overrides = distiller.utils.yaml_ordered_load(overrides_yaml)\n", + "quantizer = PostTrainLinearQuantizer(\n", + " deepcopy(man_model),\n", + " model_activation_stats='./manual_lstm_pretrained_stats_new.yaml',\n", + " mode=LinearQuantMode.ASYMMETRIC_SIGNED,\n", + " overrides=overrides,\n", + " per_channel_wts=True\n", + ")\n", + "quantizer.prepare_model()\n", + "val_loss = evaluate(quantizer.model.to(device), val_data)\n", + "print('val_loss:%8.6f\\t|\\t ppl:%8.2f' % (val_loss, np.exp(val_loss)))" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "DistillerRNNModel(\n", + " (encoder): FP16Wrapper(\n", + " (wrapped_module): Embedding(33278, 1500)\n", + " )\n", + " (rnn): DistillerLSTM(1500, 1500, num_layers=2, dropout=0.65, bidirectional=False)\n", + " (decoder): FP16Wrapper(\n", + " (wrapped_module): Linear(in_features=1500, out_features=33278, bias=True)\n", + " )\n", + ")" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "quantizer.model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The accuracy is still holding up very well, even though we quantized the inner linear layers! \n", + "Now, lets try to choose different boundaries for `min`, `max` - \n", + "Instead of using absolute ones, we take the average of all batches (`avg_min`, `avg_max`), which is an indication of where usually most of the boundaries lie. This is done by specifying the `clip_acts` parameter to `ClipMode.AVG` or `\"AVG\"` in the quantizer ctor:" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 622/622 [02:31<00:00, 3.80it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "val_loss:4.487813\t|\t ppl: 88.93\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "overrides_yaml = \"\"\"\n", + "encoder:\n", + " fp16: true\n", + "decoder:\n", + " fp16: true\n", + "\"\"\"\n", + "overrides = distiller.utils.yaml_ordered_load(overrides_yaml)\n", + "quantizer = PostTrainLinearQuantizer(\n", + " deepcopy(man_model),\n", + " model_activation_stats='./manual_lstm_pretrained_stats.yaml',\n", + " mode=LinearQuantMode.ASYMMETRIC_SIGNED,\n", + " overrides=overrides,\n", + " per_channel_wts=True,\n", + " clip_acts=\"AVG\"\n", + ")\n", + "quantizer.prepare_model()\n", + "val_loss = evaluate(quantizer.model.to(device), val_data)\n", + "print('val_loss:%8.6f\\t|\\t ppl:%8.2f' % (val_loss, np.exp(val_loss)))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Great! Even though we quantized all of the layers except the embedding and the decoder - we got almost no accuracy penalty. Lets try quantizing them as well:" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 622/622 [02:24<00:00, 4.84it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "val_loss:4.487492\t|\t ppl: 88.90\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "quantizer = PostTrainLinearQuantizer(\n", + " deepcopy(man_model),\n", + " model_activation_stats='./manual_lstm_pretrained_stats_new.yaml',\n", + " mode=LinearQuantMode.ASYMMETRIC_SIGNED,\n", + " per_channel_wts=True,\n", + " clip_acts=\"AVG\"\n", + ")\n", + "quantizer.prepare_model()\n", + "val_loss = evaluate(quantizer.model.to(device), val_data)\n", + "print('val_loss:%8.6f\\t|\\t ppl:%8.2f' % (val_loss, np.exp(val_loss)))" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "DistillerRNNModel(\n", + " (encoder): RangeLinearEmbeddingWrapper(\n", + " (wrapped_module): Embedding(33278, 1500)\n", + " )\n", + " (rnn): DistillerLSTM(1500, 1500, num_layers=2, dropout=0.65, bidirectional=False)\n", + " (decoder): RangeLinearQuantParamLayerWrapper(\n", + " mode=ASYMMETRIC_SIGNED, num_bits_acts=8, num_bits_params=8, num_bits_accum=32, clip_acts=AVG, per_channel_wts=True\n", + " preset_activation_stats=True\n", + " w_scale=PerCh, w_zero_point=PerCh\n", + " in_scale=129.4670, in_zero_point=1.0000\n", + " out_scale=9.9393, out_zero_point=56.0000\n", + " (wrapped_module): Linear(in_features=1500, out_features=33278, bias=True)\n", + " )\n", + ")" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "quantizer.model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here we see that sometimes quantizing with the right boundaries gives better results than actually using floating point operations (even though they are half precision). " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Conclusion\n", + "\n", + "Choosing the right boundaries for quantization was crucial for achieving almost no degradation in accrucay of LSTM. \n", + " \n", + "Here we showed how to use the distiller quantization API to quantize an RNN model, by converting the pytorch implementation into a modular one and then quantizing each layer separately." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.5.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tests/test_lstm_impl.py b/tests/test_lstm_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..319015d69a6ccc342b1fe9c661164254c5af048c --- /dev/null +++ b/tests/test_lstm_impl.py @@ -0,0 +1,93 @@ +import pytest + +import distiller +from distiller.modules import DistillerLSTM, DistillerLSTMCell +import torch +import torch.nn as nn + +ACCEPTABLE_ERROR = 5e-5 +BATCH_SIZE = 32 +SEQUENCE_SIZE = 35 + + +def test_basic(): + lstmcell = DistillerLSTMCell(3, 5) + assert lstmcell.fc_gate_x.weight.shape == (5 * 4, 3) + assert lstmcell.fc_gate_h.weight.shape == (5 * 4, 5) + assert lstmcell.fc_gate_x.bias.shape == (5 * 4,) + assert lstmcell.fc_gate_h.bias.shape == (5 * 4,) + + lstm = DistillerLSTM(3, 5, 4, False, False, 0.0, True) + assert lstm.bidirectional_type == 2 + assert lstm.cells[0].fc_gate_x.weight.shape == (5 * 4, 3) + assert lstm.cells[1].fc_gate_x.weight.shape == (5 * 4, 5 * 2) + + +def test_conversion(): + lc_man = DistillerLSTMCell(3, 5) + lc_pth = lc_man.to_pytorch_impl() + lc_man1 = DistillerLSTMCell.from_pytorch_impl(lc_pth) + + assert (lc_man.fc_gate_x.weight == lc_man1.fc_gate_x.weight).all() + assert (lc_man.fc_gate_h.weight == lc_man1.fc_gate_h.weight).all() + + l_man = DistillerLSTM(3, 5, 2) + l_pth = l_man.to_pytorch_impl() + l_man1 = DistillerLSTM.from_pytorch_impl(l_pth) + + for i in range(l_man.num_layers): + assert (l_man1.cells[i].fc_gate_x.weight == l_man.cells[i].fc_gate_x.weight).all() + assert (l_man1.cells[i].fc_gate_h.weight == l_man.cells[i].fc_gate_h.weight).all() + assert (l_man1.cells[i].fc_gate_x.bias == l_man.cells[i].fc_gate_x.bias).all() + assert (l_man1.cells[i].fc_gate_h.bias == l_man.cells[i].fc_gate_h.bias).all() + + +def assert_output(out_true, out_pred): + y_t, h_t = out_true + y_p, h_p = out_pred + assert (y_t - y_p).abs().max().item() < ACCEPTABLE_ERROR + h_h_t, h_c_t = h_t + h_h_p, h_c_p = h_p + assert (h_h_t - h_h_p).abs().max().item() < ACCEPTABLE_ERROR + assert (h_c_t - h_c_p).abs().max().item() < ACCEPTABLE_ERROR + + +@pytest.fixture(name='bidirectional', params=[False, True], ids=['bidirectional_off', 'bidirectional_on']) +def fixture_bidirectional(request): + return request.param + + +@pytest.mark.parametrize( + "input_size, hidden_size, num_layers", + [ + (1, 1, 2), + (3, 5, 7), + (1500, 1500, 5) + ] +) +def test_forward_lstm(input_size, hidden_size, num_layers, bidirectional): + # Test conversion from pytorch implementation + lstm = nn.LSTM(input_size, hidden_size, num_layers, bidirectional=bidirectional) + lstm_man = DistillerLSTM.from_pytorch_impl(lstm) + lstm.eval() + lstm_man.eval() + + h = lstm_man.init_hidden(BATCH_SIZE) + x = torch.rand(SEQUENCE_SIZE, BATCH_SIZE, input_size) + + out_true = lstm(x, h) + out_pred = lstm_man(x, h) + assert_output(out_true, out_pred) + # Test conversion to pytorch implementation + lstm_man = DistillerLSTM(input_size, hidden_size, num_layers, bidirectional=bidirectional) + lstm = lstm_man.to_pytorch_impl() + + lstm.eval() + lstm_man.eval() + + h = lstm_man.init_hidden(BATCH_SIZE) + x = torch.rand(SEQUENCE_SIZE, BATCH_SIZE, input_size) + + out_true = lstm(x, h) + out_pred = lstm_man(x, h) + assert_output(out_true, out_pred)