diff --git a/distiller/modules/__init__.py b/distiller/modules/__init__.py index e46bc81d1a80dba96cf9b976ff742bb3a1b8eb63..d8a5352cfefcf0bcf1a961a910ee7402965f8ec9 100644 --- a/distiller/modules/__init__.py +++ b/distiller/modules/__init__.py @@ -21,7 +21,7 @@ from .rnn import * from .aggregate import * from .topology import * -__all__ = ['EltwiseAdd', 'EltwiseMult', 'EltwiseDiv', 'Matmul', 'BatchMatmul', +__all__ = ['EltwiseAdd', 'EltwiseSub', 'EltwiseMult', 'EltwiseDiv', 'Matmul', 'BatchMatmul', 'Concat', 'Chunk', 'Split', 'Stack', 'DistillerLSTMCell', 'DistillerLSTM', 'convert_model_to_distiller_lstm', 'Norm', 'Mean', 'BranchPoint', 'Print'] diff --git a/distiller/modules/eltwise.py b/distiller/modules/eltwise.py index 84350593f84fd97a7888eb83798389f0e90dd540..6f5895df2591093ae9a38678213c74eeba9c9b50 100644 --- a/distiller/modules/eltwise.py +++ b/distiller/modules/eltwise.py @@ -19,8 +19,8 @@ import torch.nn as nn class EltwiseAdd(nn.Module): def __init__(self, inplace=False): - super(EltwiseAdd, self).__init__() - + """Element-wise addition""" + super().__init__() self.inplace = inplace def forward(self, *input): @@ -34,9 +34,27 @@ class EltwiseAdd(nn.Module): return res +class EltwiseSub(nn.Module): + def __init__(self, inplace=False): + """Element-wise subtraction""" + super().__init__() + self.inplace = inplace + + def forward(self, *input): + res = input[0] + if self.inplace: + for t in input[1:]: + res -= t + else: + for t in input[1:]: + res = res - t + return res + + class EltwiseMult(nn.Module): def __init__(self, inplace=False): - super(EltwiseMult, self).__init__() + """Element-wise multiplication""" + super().__init__() self.inplace = inplace def forward(self, *input): @@ -52,7 +70,8 @@ class EltwiseMult(nn.Module): class EltwiseDiv(nn.Module): def __init__(self, inplace=False): - super(EltwiseDiv, self).__init__() + """Element-wise division""" + super().__init__() self.inplace = inplace def forward(self, x: torch.Tensor, y):