diff --git a/distiller/modules/__init__.py b/distiller/modules/__init__.py
index e46bc81d1a80dba96cf9b976ff742bb3a1b8eb63..d8a5352cfefcf0bcf1a961a910ee7402965f8ec9 100644
--- a/distiller/modules/__init__.py
+++ b/distiller/modules/__init__.py
@@ -21,7 +21,7 @@ from .rnn import *
 from .aggregate import *
 from .topology import *
 
-__all__ = ['EltwiseAdd', 'EltwiseMult', 'EltwiseDiv', 'Matmul', 'BatchMatmul',
+__all__ = ['EltwiseAdd', 'EltwiseSub', 'EltwiseMult', 'EltwiseDiv', 'Matmul', 'BatchMatmul',
            'Concat', 'Chunk', 'Split', 'Stack',
            'DistillerLSTMCell', 'DistillerLSTM', 'convert_model_to_distiller_lstm',
            'Norm', 'Mean', 'BranchPoint', 'Print']
diff --git a/distiller/modules/eltwise.py b/distiller/modules/eltwise.py
index 84350593f84fd97a7888eb83798389f0e90dd540..6f5895df2591093ae9a38678213c74eeba9c9b50 100644
--- a/distiller/modules/eltwise.py
+++ b/distiller/modules/eltwise.py
@@ -19,8 +19,8 @@ import torch.nn as nn
 
 class EltwiseAdd(nn.Module):
     def __init__(self, inplace=False):
-        super(EltwiseAdd, self).__init__()
-
+        """Element-wise addition"""
+        super().__init__()
         self.inplace = inplace
 
     def forward(self, *input):
@@ -34,9 +34,27 @@ class EltwiseAdd(nn.Module):
         return res
 
 
+class EltwiseSub(nn.Module):
+    def __init__(self, inplace=False):
+        """Element-wise subtraction"""
+        super().__init__()
+        self.inplace = inplace
+
+    def forward(self, *input):
+        res = input[0]
+        if self.inplace:
+            for t in input[1:]:
+                res -= t
+        else:
+            for t in input[1:]:
+                res = res - t
+        return res
+
+
 class EltwiseMult(nn.Module):
     def __init__(self, inplace=False):
-        super(EltwiseMult, self).__init__()
+        """Element-wise multiplication"""
+        super().__init__()
         self.inplace = inplace
 
     def forward(self, *input):
@@ -52,7 +70,8 @@ class EltwiseMult(nn.Module):
 
 class EltwiseDiv(nn.Module):
     def __init__(self, inplace=False):
-        super(EltwiseDiv, self).__init__()
+        """Element-wise division"""
+        super().__init__()
         self.inplace = inplace
 
     def forward(self, x: torch.Tensor, y):