From 748cb0564de4b746945c59122e06871826fedcf9 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Thu, 16 Apr 2020 16:37:08 +0300
Subject: [PATCH] Add EltwiseSub

As requested in issue #496
---
 distiller/modules/__init__.py |  2 +-
 distiller/modules/eltwise.py  | 27 +++++++++++++++++++++++----
 2 files changed, 24 insertions(+), 5 deletions(-)

diff --git a/distiller/modules/__init__.py b/distiller/modules/__init__.py
index e46bc81..d8a5352 100644
--- a/distiller/modules/__init__.py
+++ b/distiller/modules/__init__.py
@@ -21,7 +21,7 @@ from .rnn import *
 from .aggregate import *
 from .topology import *
 
-__all__ = ['EltwiseAdd', 'EltwiseMult', 'EltwiseDiv', 'Matmul', 'BatchMatmul',
+__all__ = ['EltwiseAdd', 'EltwiseSub', 'EltwiseMult', 'EltwiseDiv', 'Matmul', 'BatchMatmul',
            'Concat', 'Chunk', 'Split', 'Stack',
            'DistillerLSTMCell', 'DistillerLSTM', 'convert_model_to_distiller_lstm',
            'Norm', 'Mean', 'BranchPoint', 'Print']
diff --git a/distiller/modules/eltwise.py b/distiller/modules/eltwise.py
index 8435059..6f5895d 100644
--- a/distiller/modules/eltwise.py
+++ b/distiller/modules/eltwise.py
@@ -19,8 +19,8 @@ import torch.nn as nn
 
 class EltwiseAdd(nn.Module):
     def __init__(self, inplace=False):
-        super(EltwiseAdd, self).__init__()
-
+        """Element-wise addition"""
+        super().__init__()
         self.inplace = inplace
 
     def forward(self, *input):
@@ -34,9 +34,27 @@ class EltwiseAdd(nn.Module):
         return res
 
 
+class EltwiseSub(nn.Module):
+    def __init__(self, inplace=False):
+        """Element-wise subtraction"""
+        super().__init__()
+        self.inplace = inplace
+
+    def forward(self, *input):
+        res = input[0]
+        if self.inplace:
+            for t in input[1:]:
+                res -= t
+        else:
+            for t in input[1:]:
+                res = res - t
+        return res
+
+
 class EltwiseMult(nn.Module):
     def __init__(self, inplace=False):
-        super(EltwiseMult, self).__init__()
+        """Element-wise multiplication"""
+        super().__init__()
         self.inplace = inplace
 
     def forward(self, *input):
@@ -52,7 +70,8 @@ class EltwiseMult(nn.Module):
 
 class EltwiseDiv(nn.Module):
     def __init__(self, inplace=False):
-        super(EltwiseDiv, self).__init__()
+        """Element-wise division"""
+        super().__init__()
         self.inplace = inplace
 
     def forward(self, x: torch.Tensor, y):
-- 
GitLab