diff --git a/distiller/quantization/clipped_linear.py b/distiller/quantization/clipped_linear.py
index 2451bfae06e05b4f232ae101d486a2851d81c33d..5827554191f1d14424463f15387fc6e87e4fd15a 100644
--- a/distiller/quantization/clipped_linear.py
+++ b/distiller/quantization/clipped_linear.py
@@ -16,6 +16,7 @@
 
 from collections import OrderedDict
 import torch.nn as nn
+import torch.nn.functional as F
 
 from .quantizer import Quantizer
 from .q_utils import *
@@ -27,34 +28,6 @@ msglogger = logging.getLogger()
 ###
 
 
-class LearnedClippedLinearQuantizeSTE(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, input, clip_val, num_bits, dequantize, inplace):
-        ctx.save_for_backward(input, clip_val)
-        if inplace:
-            ctx.mark_dirty(input)
-        scale, zero_point = asymmetric_linear_quantization_params(num_bits, 0, clip_val.data[0], signed=False)
-        output = clamp(input, 0, clip_val.data[0], inplace)
-        output = linear_quantize(output, scale, zero_point, inplace)
-        if dequantize:
-            output = linear_dequantize(output, scale, zero_point, inplace)
-        return output
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        input, clip_val = ctx.saved_tensors
-        grad_input = grad_output.clone()
-        grad_input[input.le(0)] = 0
-        grad_input[input.ge(clip_val.data[0])] = 0
-
-        grad_alpha = grad_output.clone()
-        grad_alpha[input.lt(clip_val.data[0])] = 0
-        grad_alpha = grad_alpha.sum().expand_as(clip_val)
-
-        # Straight-through estimator for the scale factor calculation
-        return grad_input, grad_alpha, None, None, None
-
-
 class ClippedLinearQuantization(nn.Module):
     def __init__(self, num_bits, clip_val, dequantize=True, inplace=False):
         super(ClippedLinearQuantization, self).__init__()
@@ -84,13 +57,18 @@ class LearnedClippedLinearQuantization(nn.Module):
         self.inplace = inplace
 
     def forward(self, input):
-        input = LearnedClippedLinearQuantizeSTE.apply(input, self.clip_val, self.num_bits,
-                                                      self.dequantize, self.inplace)
+        # Clip between 0 to the learned clip_val
+        input = F.relu(input, self.inplace)
+        # Using the 'where' operation as follows gives us the correct gradient with respect to clip_val
+        input = torch.where(input < self.clip_val, input, self.clip_val)
+        with torch.no_grad():
+            scale, zero_point = asymmetric_linear_quantization_params(self.num_bits, 0, self.clip_val, signed=False)
+        input = LinearQuantizeSTE.apply(input, scale, zero_point, self.dequantize, self.inplace)
         return input
 
     def __repr__(self):
         inplace_str = ', inplace' if self.inplace else ''
-        return '{0}(num_bits={1}, clip_val={2}{3})'.format(self.__class__.__name__, self.num_bits, self.clip_val,
+        return '{0}(num_bits={1}, clip_val={2}{3})'.format(self.__class__.__name__, self.num_bits, self.clip_val.item(),
                                                            inplace_str)
 
 
@@ -126,6 +104,7 @@ class WRPNQuantizer(Quantizer):
 
         self.replacement_factory[nn.ReLU] = relu_replace_fn
 
+
 def dorefa_quantize_param(param_fp, param_meta):
     if param_meta.num_bits == 1:
         out = DorefaParamsBinarizationSTE.apply(param_fp)
@@ -137,6 +116,7 @@ def dorefa_quantize_param(param_fp, param_meta):
         out = 2 * out - 1
     return out
 
+
 class DorefaParamsBinarizationSTE(torch.autograd.Function):
     @staticmethod
     def forward(ctx, input, inplace=False):
@@ -150,6 +130,7 @@ class DorefaParamsBinarizationSTE(torch.autograd.Function):
     def backward(ctx, grad_output):
         return grad_output, None
 
+
 class DorefaQuantizer(Quantizer):
     """
     Quantizer using the DoReFa scheme, as defined in: