From 92fd0019e0b69afe8b2d7d6240124aa3195e1ac7 Mon Sep 17 00:00:00 2001
From: Lev Zlotnik <46742999+levzlotnik@users.noreply.github.com>
Date: Tue, 30 Apr 2019 12:25:27 +0300
Subject: [PATCH] Added PackedSequence functionality (#236)

* Update test_lstm_impl.py

* Added PackedSequence functionality

* Refactored forward implementation
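
A minimal usage sketch of the new path (not part of the patch itself; DistillerLSTM
and from_pytorch_impl are used the same way in the tests below, while the sizes and
sequence lengths here are purely illustrative):

    import torch
    import torch.nn as nn
    from torch.nn.utils.rnn import pack_sequence, pad_packed_sequence
    from distiller.modules import DistillerLSTM

    lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)
    lstm_man = DistillerLSTM.from_pytorch_impl(lstm)
    lstm_man.eval()

    # Variable-length sequences, longest first, as pack_sequence expects.
    x = pack_sequence([torch.rand(length, 10) for length in (7, 5, 3)])

    # The forward pass now accepts a PackedSequence and returns one,
    # together with the final (hidden, cell) states.
    y, (h_n, c_n) = lstm_man(x)
    y_padded, lengths = pad_packed_sequence(y)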
---
 distiller/modules/rnn.py | 40 +++++++++++++++++++++++++++++++++++-----
 tests/test_lstm_impl.py  | 36 ++++++++++++++++++++++++++++++++----
 2 files changed, 67 insertions(+), 9 deletions(-)

diff --git a/distiller/modules/rnn.py b/distiller/modules/rnn.py
index 7471f32..633c088 100644
--- a/distiller/modules/rnn.py
+++ b/distiller/modules/rnn.py
@@ -22,6 +22,7 @@ from itertools import product
 
 __all__ = ['DistillerLSTMCell', 'DistillerLSTM']
 
+
 class DistillerLSTMCell(nn.Module):
     """
     A single LSTM block.
@@ -231,9 +232,9 @@ class DistillerLSTM(nn.Module):
     def forward(self, x, h=None):
         is_packed_seq = isinstance(x, nn.utils.rnn.PackedSequence)
         if is_packed_seq:
-            x, lengths = nn.utils.rnn.pad_packed_sequence(x, self.batch_first)
+            return self.packed_sequence_forward(x, h)
 
-        elif self.batch_first:
+        if self.batch_first:
             # Transpose to sequence_first format
             x = x.transpose(0, 1)
         x_bsz = x.size(1)
@@ -242,14 +243,40 @@ class DistillerLSTM(nn.Module):
             h = self.init_hidden(x_bsz)
 
         y, h = self.forward_fn(x, h)
-        if is_packed_seq:
-            y = nn.utils.rnn.pack_padded_sequence(y, lengths, self.batch_first)
 
-        elif self.batch_first:
+        if self.batch_first:
             # Transpose back to batch_first format
             y = y.transpose(0, 1)
         return y, h
 
+    def packed_sequence_forward(self, x, h=None):
+        # Packed sequence handling -
+        # the sequences in the batch have different lengths, so we
+        # unpack into a padded tensor and run each row of the batch
+        # (i.e. each sequence) separately, trimmed to its true length.
+        x, lengths = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
+        x_bsz = x.size(0)
+        if h is None:
+            h = self.init_hidden(x_bsz)
+        y_results = []
+        h_results = []
+        for i, (sequence, seq_len) in enumerate(zip(x, lengths)):
+            # Take the hidden state corresponding to the current sequence in the batch.
+            # We unsqueeze to keep a 3D tensor with a batch dimension of 1.
+            h_current = (h[0][:, i, :].unsqueeze(1), h[1][:, i, :].unsqueeze(1))
+            # Take only the relevant timesteps according to seq_len
+            sequence = sequence[:seq_len].unsqueeze(1)  # sequence.shape = (seq_len, batch_size=1, input_dim)
+            # forward pass:
+            y, h_current = self.forward_fn(sequence, h_current)
+            # squeeze the batch dimension back out to get a single sequence
+            y_results.append(y.squeeze(1))
+            h_results.append(h_current)
+        # re-pack the per-sequence outputs into a PackedSequence
+        y = nn.utils.rnn.pack_sequence(y_results)
+        # concatenate the hidden and cell states along the batch dimension
+        h = torch.cat([t[0] for t in h_results], dim=1), torch.cat([t[1] for t in h_results], dim=1)
+        return y, h
+
     def process_layer_wise(self, x, h):
         results = []
         for step in x:
@@ -304,6 +331,9 @@ class DistillerLSTM(nn.Module):
         """
         Process a single timestep through the entire unidirectional layer chain.
         """
+        step_bsz = step.size(0)
+        if h is None:
+            h = self.init_hidden(step_bsz)
         h_all, c_all = h
         h_result = []
         out = step
diff --git a/tests/test_lstm_impl.py b/tests/test_lstm_impl.py
index 319015d..616aa6a 100644
--- a/tests/test_lstm_impl.py
+++ b/tests/test_lstm_impl.py
@@ -4,8 +4,11 @@ import distiller
 from distiller.modules import DistillerLSTM, DistillerLSTMCell
 import torch
 import torch.nn as nn
+from torch.nn.utils.rnn import pack_sequence, pad_packed_sequence, PackedSequence
+from torch.testing import assert_allclose
 
-ACCEPTABLE_ERROR = 5e-5
+ATOL = 5e-5
+RTOL = 1e-3
 BATCH_SIZE = 32
 SEQUENCE_SIZE = 35
 
@@ -45,11 +48,15 @@ def test_conversion():
 def assert_output(out_true, out_pred):
     y_t, h_t = out_true
     y_p, h_p = out_pred
-    assert (y_t - y_p).abs().max().item() < ACCEPTABLE_ERROR
+    if isinstance(y_t, PackedSequence):
+        y_t, lengths_t = pad_packed_sequence(y_t)
+        y_p, lengths_p = pad_packed_sequence(y_p)
+        assert all(lengths_t == lengths_p)
+    assert_allclose(y_p, y_t, RTOL, ATOL)
     h_h_t, h_c_t = h_t
     h_h_p, h_c_p = h_p
-    assert (h_h_t - h_h_p).abs().max().item() < ACCEPTABLE_ERROR
-    assert (h_c_t - h_c_p).abs().max().item() < ACCEPTABLE_ERROR
+    assert_allclose(h_h_p, h_h_t, RTOL, ATOL)
+    assert_allclose(h_c_p, h_c_t, RTOL, ATOL)
 
 
 @pytest.fixture(name='bidirectional', params=[False, True], ids=['bidirectional_off', 'bidirectional_on'])
@@ -91,3 +98,24 @@ def test_forward_lstm(input_size, hidden_size, num_layers, bidirectional):
     out_true = lstm(x, h)
     out_pred = lstm_man(x, h)
     assert_output(out_true, out_pred)
+
+
+@pytest.mark.parametrize(
+    "input_size, hidden_size, num_layers, input_lengths",
+    [
+        (1, 1, 2, [5, 4, 3]),
+        (3, 5, 7, [20, 15, 5]),
+        (500, 500, 5, [50, 35, 25])
+    ]
+)
+def test_packed_sequence(input_size, hidden_size, num_layers, input_lengths, bidirectional):
+    lstm = nn.LSTM(input_size, hidden_size, num_layers, bidirectional=bidirectional)
+    lstm_man = DistillerLSTM.from_pytorch_impl(lstm)
+    lstm.eval()
+    lstm_man.eval()
+
+    h = lstm_man.init_hidden(BATCH_SIZE)
+    x = pack_sequence([torch.rand(length, input_size) for length in input_lengths])
+    out_true = lstm(x)
+    out_pred = lstm_man(x)
+    assert_output(out_true, out_pred)
-- 
GitLab