diff --git a/distiller/modules/__init__.py b/distiller/modules/__init__.py index 282d7fde0c4bf0f387a19d731d932444508abf47..5bd7d5cfd5a2bcd8e688e9901f7cf7d4224c5469 100644 --- a/distiller/modules/__init__.py +++ b/distiller/modules/__init__.py @@ -16,8 +16,8 @@ from .eltwise import EltwiseAdd, EltwiseMult from .grouping import * -from .rnn import DistillerLSTM, DistillerLSTMCell +from .rnn import DistillerLSTM, DistillerLSTMCell, convert_model_to_distiller_lstm __all__ = ['EltwiseAdd', 'EltwiseMult', 'Concat', 'Chunk', 'Split', 'Stack', - 'DistillerLSTMCell', 'DistillerLSTM'] + 'DistillerLSTMCell', 'DistillerLSTM', 'convert_model_to_distiller_lstm'] diff --git a/distiller/modules/rnn.py b/distiller/modules/rnn.py index 633c088f4def2e635324d4fd25e476a1e7779d47..37886696d908834b0a892cb987aa2016561b71c2 100644 --- a/distiller/modules/rnn.py +++ b/distiller/modules/rnn.py @@ -20,7 +20,7 @@ import numpy as np from .eltwise import EltwiseAdd, EltwiseMult from itertools import product -__all__ = ['DistillerLSTMCell', 'DistillerLSTM'] +__all__ = ['DistillerLSTMCell', 'DistillerLSTM', 'convert_model_to_distiller_lstm'] class DistillerLSTMCell(nn.Module): @@ -430,3 +430,20 @@ class DistillerLSTM(nn.Module): self.num_layers, self.dropout_factor, self.bidirectional) + + +def convert_model_to_distiller_lstm(model: nn.Module): + """ + Replaces all `nn.LSTM`s and `nn.LSTMCell`s in the model with distiller versions. + Args: + model (nn.Module): the model + """ + if isinstance(model, nn.LSTMCell): + return DistillerLSTMCell.from_pytorch_impl(model) + if isinstance(model, nn.LSTM): + return DistillerLSTM.from_pytorch_impl(model) + for name, module in model.named_children(): + module = convert_model_to_distiller_lstm(module) + setattr(model, name, module) + + return model