Skip to content
Snippets Groups Projects
Commit 748cb056 authored by Neta Zmora's avatar Neta Zmora
Browse files

Add EltwiseSub

As requested in issue #496
parent 410a059b
No related branches found
No related tags found
No related merge requests found
......@@ -21,7 +21,7 @@ from .rnn import *
from .aggregate import *
from .topology import *
__all__ = ['EltwiseAdd', 'EltwiseMult', 'EltwiseDiv', 'Matmul', 'BatchMatmul',
__all__ = ['EltwiseAdd', 'EltwiseSub', 'EltwiseMult', 'EltwiseDiv', 'Matmul', 'BatchMatmul',
'Concat', 'Chunk', 'Split', 'Stack',
'DistillerLSTMCell', 'DistillerLSTM', 'convert_model_to_distiller_lstm',
'Norm', 'Mean', 'BranchPoint', 'Print']
......
......@@ -19,8 +19,8 @@ import torch.nn as nn
class EltwiseAdd(nn.Module):
def __init__(self, inplace=False):
super(EltwiseAdd, self).__init__()
"""Element-wise addition"""
super().__init__()
self.inplace = inplace
def forward(self, *input):
......@@ -34,9 +34,27 @@ class EltwiseAdd(nn.Module):
return res
class EltwiseSub(nn.Module):
def __init__(self, inplace=False):
"""Element-wise subtraction"""
super().__init__()
self.inplace = inplace
def forward(self, *input):
res = input[0]
if self.inplace:
for t in input[1:]:
res -= t
else:
for t in input[1:]:
res = res - t
return res
class EltwiseMult(nn.Module):
def __init__(self, inplace=False):
super(EltwiseMult, self).__init__()
"""Element-wise multiplication"""
super().__init__()
self.inplace = inplace
def forward(self, *input):
......@@ -52,7 +70,8 @@ class EltwiseMult(nn.Module):
class EltwiseDiv(nn.Module):
def __init__(self, inplace=False):
super(EltwiseDiv, self).__init__()
"""Element-wise division"""
super().__init__()
self.inplace = inplace
def forward(self, x: torch.Tensor, y):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment