diff --git a/distiller/quantization/clipped_linear.py b/distiller/quantization/clipped_linear.py
index 2451bfae06e05b4f232ae101d486a2851d81c33d..5827554191f1d14424463f15387fc6e87e4fd15a 100644
--- a/distiller/quantization/clipped_linear.py
+++ b/distiller/quantization/clipped_linear.py
@@ -16,6 +16,7 @@
 from collections import OrderedDict
 
 import torch.nn as nn
+import torch.nn.functional as F
 
 from .quantizer import Quantizer
 from .q_utils import *
@@ -27,34 +28,6 @@ msglogger = logging.getLogger()
 ###
 
 
-class LearnedClippedLinearQuantizeSTE(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, input, clip_val, num_bits, dequantize, inplace):
-        ctx.save_for_backward(input, clip_val)
-        if inplace:
-            ctx.mark_dirty(input)
-        scale, zero_point = asymmetric_linear_quantization_params(num_bits, 0, clip_val.data[0], signed=False)
-        output = clamp(input, 0, clip_val.data[0], inplace)
-        output = linear_quantize(output, scale, zero_point, inplace)
-        if dequantize:
-            output = linear_dequantize(output, scale, zero_point, inplace)
-        return output
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        input, clip_val = ctx.saved_tensors
-        grad_input = grad_output.clone()
-        grad_input[input.le(0)] = 0
-        grad_input[input.ge(clip_val.data[0])] = 0
-
-        grad_alpha = grad_output.clone()
-        grad_alpha[input.lt(clip_val.data[0])] = 0
-        grad_alpha = grad_alpha.sum().expand_as(clip_val)
-
-        # Straight-through estimator for the scale factor calculation
-        return grad_input, grad_alpha, None, None, None
-
-
 class ClippedLinearQuantization(nn.Module):
     def __init__(self, num_bits, clip_val, dequantize=True, inplace=False):
         super(ClippedLinearQuantization, self).__init__()
@@ -84,13 +57,18 @@ class LearnedClippedLinearQuantization(nn.Module):
         self.inplace = inplace
 
     def forward(self, input):
-        input = LearnedClippedLinearQuantizeSTE.apply(input, self.clip_val, self.num_bits,
-                                                      self.dequantize, self.inplace)
+        # Clip between 0 and the learned clip_val
+        input = F.relu(input, self.inplace)
+        # Using the 'where' operation as follows gives us the correct gradient with respect to clip_val
+        input = torch.where(input < self.clip_val, input, self.clip_val)
+        with torch.no_grad():
+            scale, zero_point = asymmetric_linear_quantization_params(self.num_bits, 0, self.clip_val, signed=False)
+        input = LinearQuantizeSTE.apply(input, scale, zero_point, self.dequantize, self.inplace)
         return input
 
     def __repr__(self):
         inplace_str = ', inplace' if self.inplace else ''
-        return '{0}(num_bits={1}, clip_val={2}{3})'.format(self.__class__.__name__, self.num_bits, self.clip_val,
+        return '{0}(num_bits={1}, clip_val={2}{3})'.format(self.__class__.__name__, self.num_bits, self.clip_val.item(),
                                                            inplace_str)
 
 
@@ -126,6 +104,7 @@
 
         self.replacement_factory[nn.ReLU] = relu_replace_fn
 
+
 def dorefa_quantize_param(param_fp, param_meta):
     if param_meta.num_bits == 1:
         out = DorefaParamsBinarizationSTE.apply(param_fp)
@@ -137,6 +116,7 @@ def dorefa_quantize_param(param_fp, param_meta):
         out = 2 * out - 1
     return out
 
+
class DorefaParamsBinarizationSTE(torch.autograd.Function):
     @staticmethod
     def forward(ctx, input, inplace=False):
@@ -150,6 +130,7 @@
     def backward(ctx, grad_output):
         return grad_output, None
 
+
 class DorefaQuantizer(Quantizer):
     """
     Quantizer using the DoReFa scheme, as defined in:
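
Note on the gradient fix (not part of the patch): the removed LearnedClippedLinearQuantizeSTE computed the gradient for clip_val by hand in backward(), whereas the new forward() lets autograd derive it from F.relu plus torch.where. The sketch below is a standalone sanity check using only stock PyTorch; the tensor values are arbitrary illustrations, not from the patch. It shows that the where-based clamp sends one unit of gradient to clip_val for every clipped element, which is exactly what the old code produced with grad_alpha (grad_output zeroed where input < clip_val, then summed).

import torch
import torch.nn.functional as F

# Standalone sanity check with illustrative values.
x = torch.tensor([-1.0, 0.5, 2.0, 3.0], requires_grad=True)
clip_val = torch.tensor([1.0], requires_grad=True)  # learned clipping point

# Same two ops as the new LearnedClippedLinearQuantization.forward:
out = F.relu(x)                                   # zero out negatives
out = torch.where(out < clip_val, out, clip_val)  # clamp the top at clip_val

out.sum().backward()  # all-ones upstream gradient

print(x.grad)         # tensor([0., 1., 0., 0.]) -> flows only for 0 < x < clip_val
print(clip_val.grad)  # tensor([2.]) -> one unit per clipped element (x = 2.0, 3.0),
                      # matching the removed backward()'s grad_alpha

The remaining quantize/dequantize step still goes through LinearQuantizeSTE, a straight-through estimator, and the scale is computed under torch.no_grad(), so clip_val's gradient comes only from the clamp above rather than from the scale-factor arithmetic.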