diff --git a/distiller/modules/__init__.py b/distiller/modules/__init__.py
index 282d7fde0c4bf0f387a19d731d932444508abf47..5bd7d5cfd5a2bcd8e688e9901f7cf7d4224c5469 100644
--- a/distiller/modules/__init__.py
+++ b/distiller/modules/__init__.py
@@ -16,8 +16,8 @@
 
 from .eltwise import EltwiseAdd, EltwiseMult
 from .grouping import *
-from .rnn import DistillerLSTM, DistillerLSTMCell
+from .rnn import DistillerLSTM, DistillerLSTMCell, convert_model_to_distiller_lstm
 
 __all__ = ['EltwiseAdd', 'EltwiseMult',
            'Concat', 'Chunk', 'Split', 'Stack',
-           'DistillerLSTMCell', 'DistillerLSTM']
+           'DistillerLSTMCell', 'DistillerLSTM', 'convert_model_to_distiller_lstm']
diff --git a/distiller/modules/rnn.py b/distiller/modules/rnn.py
index 633c088f4def2e635324d4fd25e476a1e7779d47..37886696d908834b0a892cb987aa2016561b71c2 100644
--- a/distiller/modules/rnn.py
+++ b/distiller/modules/rnn.py
@@ -20,7 +20,7 @@ import numpy as np
 from .eltwise import EltwiseAdd, EltwiseMult
 from itertools import product
 
-__all__ = ['DistillerLSTMCell', 'DistillerLSTM']
+__all__ = ['DistillerLSTMCell', 'DistillerLSTM', 'convert_model_to_distiller_lstm']
 
 
 class DistillerLSTMCell(nn.Module):
@@ -430,3 +430,20 @@ class DistillerLSTM(nn.Module):
                 self.num_layers,
                 self.dropout_factor,
                 self.bidirectional)
+
+
+def convert_model_to_distiller_lstm(model: nn.Module):
+    """
+    Replaces all `nn.LSTM`s and `nn.LSTMCell`s in the model with distiller versions.
+    Args:
+        model (nn.Module): the model
+    """
+    if isinstance(model, nn.LSTMCell):
+        return DistillerLSTMCell.from_pytorch_impl(model)
+    if isinstance(model, nn.LSTM):
+        return DistillerLSTM.from_pytorch_impl(model)
+    for name, module in model.named_children():
+        module = convert_model_to_distiller_lstm(module)
+        setattr(model, name, module)
+
+    return model