From 12d6b7e2cf1b1256534396abc835b7aaa4e1540c Mon Sep 17 00:00:00 2001
From: Lev Zlotnik <46742999+levzlotnik@users.noreply.github.com>
Date: Wed, 15 May 2019 16:53:03 +0300
Subject: [PATCH] Added convert_model_to_distiller_lstm (#259)

* Recursively traverses the entire model and replaces all submodules of type `nn.LSTM` and `nn.LSTMCell` with their distiller versions
---
 distiller/modules/__init__.py |  4 ++--
 distiller/modules/rnn.py      | 19 ++++++++++++++++++-
 2 files changed, 20 insertions(+), 3 deletions(-)

diff --git a/distiller/modules/__init__.py b/distiller/modules/__init__.py
index 282d7fd..5bd7d5c 100644
--- a/distiller/modules/__init__.py
+++ b/distiller/modules/__init__.py
@@ -16,8 +16,8 @@
 
 from .eltwise import EltwiseAdd, EltwiseMult
 from .grouping import *
-from .rnn import DistillerLSTM, DistillerLSTMCell
+from .rnn import DistillerLSTM, DistillerLSTMCell, convert_model_to_distiller_lstm
 
 __all__ = ['EltwiseAdd', 'EltwiseMult',
            'Concat', 'Chunk', 'Split', 'Stack',
-           'DistillerLSTMCell', 'DistillerLSTM']
+           'DistillerLSTMCell', 'DistillerLSTM', 'convert_model_to_distiller_lstm']
diff --git a/distiller/modules/rnn.py b/distiller/modules/rnn.py
index 633c088..3788669 100644
--- a/distiller/modules/rnn.py
+++ b/distiller/modules/rnn.py
@@ -20,7 +20,7 @@ import numpy as np
 from .eltwise import EltwiseAdd, EltwiseMult
 from itertools import product
 
-__all__ = ['DistillerLSTMCell', 'DistillerLSTM']
+__all__ = ['DistillerLSTMCell', 'DistillerLSTM', 'convert_model_to_distiller_lstm']
 
 
 class DistillerLSTMCell(nn.Module):
@@ -430,3 +430,20 @@ class DistillerLSTM(nn.Module):
                 self.num_layers,
                 self.dropout_factor,
                 self.bidirectional)
+
+
+def convert_model_to_distiller_lstm(model: nn.Module):
+    """
+    Swaps every `nn.LSTM` / `nn.LSTMCell` found under `model` for its distiller version.
+    Args:
+        model (nn.Module): root module; its children are re-assigned on it via `setattr`
+    Returns:
+        `model`, or the result of `from_pytorch_impl` when `model` is itself an LSTM/LSTMCell
+    """
+    if isinstance(model, nn.LSTMCell):
+        return DistillerLSTMCell.from_pytorch_impl(model)
+    if isinstance(model, nn.LSTM):
+        return DistillerLSTM.from_pytorch_impl(model)
+    for child_name, child in model.named_children():
+        setattr(model, child_name, convert_model_to_distiller_lstm(child))
+    return model
-- 
GitLab