From 12d6b7e2cf1b1256534396abc835b7aaa4e1540c Mon Sep 17 00:00:00 2001 From: Lev Zlotnik <46742999+levzlotnik@users.noreply.github.com> Date: Wed, 15 May 2019 16:53:03 +0300 Subject: [PATCH] Added convert_model_to_distiller_lstm (#259) * Traverses recursively through entire model and replaces all submodules of type `nn.LSTM` and `nn.LSTMCell` with distiller versions --- distiller/modules/__init__.py | 4 ++-- distiller/modules/rnn.py | 19 ++++++++++++++++++- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/distiller/modules/__init__.py b/distiller/modules/__init__.py index 282d7fd..5bd7d5c 100644 --- a/distiller/modules/__init__.py +++ b/distiller/modules/__init__.py @@ -16,8 +16,8 @@ from .eltwise import EltwiseAdd, EltwiseMult from .grouping import * -from .rnn import DistillerLSTM, DistillerLSTMCell +from .rnn import DistillerLSTM, DistillerLSTMCell, convert_model_to_distiller_lstm __all__ = ['EltwiseAdd', 'EltwiseMult', 'Concat', 'Chunk', 'Split', 'Stack', - 'DistillerLSTMCell', 'DistillerLSTM'] + 'DistillerLSTMCell', 'DistillerLSTM', 'convert_model_to_distiller_lstm'] diff --git a/distiller/modules/rnn.py b/distiller/modules/rnn.py index 633c088..3788669 100644 --- a/distiller/modules/rnn.py +++ b/distiller/modules/rnn.py @@ -20,7 +20,7 @@ import numpy as np from .eltwise import EltwiseAdd, EltwiseMult from itertools import product -__all__ = ['DistillerLSTMCell', 'DistillerLSTM'] +__all__ = ['DistillerLSTMCell', 'DistillerLSTM', 'convert_model_to_distiller_lstm'] class DistillerLSTMCell(nn.Module): @@ -430,3 +430,20 @@ class DistillerLSTM(nn.Module): self.num_layers, self.dropout_factor, self.bidirectional) + + +def convert_model_to_distiller_lstm(model: nn.Module): + """ + Replaces all `nn.LSTM`s and `nn.LSTMCell`s in the model with distiller versions. + Args: + model (nn.Module): the model + """ + if isinstance(model, nn.LSTMCell): + return DistillerLSTMCell.from_pytorch_impl(model) + if isinstance(model, nn.LSTM): + return DistillerLSTM.from_pytorch_impl(model) + for name, module in model.named_children(): + module = convert_model_to_distiller_lstm(module) + setattr(model, name, module) + + return model -- GitLab