Unverified Commit 9bc69fef authored by Guy Jacob, committed by GitHub

Faster and more memory-efficient impl. of Learned-clipped quant (#336)

parent cb02798a
@@ -16,6 +16,7 @@
 from collections import OrderedDict
 import torch.nn as nn
+import torch.nn.functional as F
 from .quantizer import Quantizer
 from .q_utils import *
@@ -27,34 +28,6 @@ msglogger = logging.getLogger()
 ###

-class LearnedClippedLinearQuantizeSTE(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, input, clip_val, num_bits, dequantize, inplace):
-        ctx.save_for_backward(input, clip_val)
-        if inplace:
-            ctx.mark_dirty(input)
-        scale, zero_point = asymmetric_linear_quantization_params(num_bits, 0, clip_val.data[0], signed=False)
-        output = clamp(input, 0, clip_val.data[0], inplace)
-        output = linear_quantize(output, scale, zero_point, inplace)
-        if dequantize:
-            output = linear_dequantize(output, scale, zero_point, inplace)
-        return output
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        input, clip_val = ctx.saved_tensors
-        grad_input = grad_output.clone()
-        grad_input[input.le(0)] = 0
-        grad_input[input.ge(clip_val.data[0])] = 0
-
-        grad_alpha = grad_output.clone()
-        grad_alpha[input.lt(clip_val.data[0])] = 0
-        grad_alpha = grad_alpha.sum().expand_as(clip_val)
-
-        # Straight-through estimator for the scale factor calculation
-        return grad_input, grad_alpha, None, None, None
-
-
 class ClippedLinearQuantization(nn.Module):
     def __init__(self, num_bits, clip_val, dequantize=True, inplace=False):
         super(ClippedLinearQuantization, self).__init__()
@@ -84,13 +57,18 @@ class LearnedClippedLinearQuantization(nn.Module):
         self.inplace = inplace

     def forward(self, input):
-        input = LearnedClippedLinearQuantizeSTE.apply(input, self.clip_val, self.num_bits,
-                                                      self.dequantize, self.inplace)
+        # Clip between 0 to the learned clip_val
+        input = F.relu(input, self.inplace)
+        # Using the 'where' operation as follows gives us the correct gradient with respect to clip_val
+        input = torch.where(input < self.clip_val, input, self.clip_val)
+        with torch.no_grad():
+            scale, zero_point = asymmetric_linear_quantization_params(self.num_bits, 0, self.clip_val, signed=False)
+        input = LinearQuantizeSTE.apply(input, scale, zero_point, self.dequantize, self.inplace)
         return input

     def __repr__(self):
         inplace_str = ', inplace' if self.inplace else ''
-        return '{0}(num_bits={1}, clip_val={2}{3})'.format(self.__class__.__name__, self.num_bits, self.clip_val,
+        return '{0}(num_bits={1}, clip_val={2}{3})'.format(self.__class__.__name__, self.num_bits, self.clip_val.item(),
                                                            inplace_str)
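Note on the hunk above: the removed LearnedClippedLinearQuantizeSTE computed its gradients by hand, while the new forward() lets autograd derive the same gradients from standard ops. F.relu blocks the gradient for inputs below zero, torch.where sends the gradient to the input where it is below clip_val and to clip_val where it is clipped, LinearQuantizeSTE passes the gradient straight through the quantization step, and computing scale/zero_point under torch.no_grad() keeps clip_val from picking up a second gradient path through the quantization parameters. A minimal sketch (not part of this commit, plain PyTorch only) that checks the autograd-derived gradients match the removed hand-written backward on a toy tensor; the quantization step is omitted because its STE is gradient-transparent:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(8, requires_grad=True)
clip_val = torch.tensor([1.0], requires_grad=True)

# Clamp to [0, clip_val] with the same ops as the new forward().
y = F.relu(x)
y = torch.where(y < clip_val, y, clip_val)
y.backward(torch.ones_like(y))

# Gradients the removed custom backward() produced:
# d/dx is 1 inside (0, clip_val) and 0 outside; d/d(clip_val) is the sum of
# upstream gradients over the elements clipped at the top.
expected_gx = ((x > 0) & (x < clip_val)).float()
expected_gclip = (x >= clip_val).float().sum().expand_as(clip_val)

assert torch.allclose(x.grad, expected_gx)
assert torch.allclose(clip_val.grad, expected_gclip)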
@@ -126,6 +104,7 @@ class WRPNQuantizer(Quantizer):
         self.replacement_factory[nn.ReLU] = relu_replace_fn

+
 def dorefa_quantize_param(param_fp, param_meta):
     if param_meta.num_bits == 1:
         out = DorefaParamsBinarizationSTE.apply(param_fp)
@@ -137,6 +116,7 @@ def dorefa_quantize_param(param_fp, param_meta):
         out = 2 * out - 1
     return out

+
 class DorefaParamsBinarizationSTE(torch.autograd.Function):
     @staticmethod
     def forward(ctx, input, inplace=False):
@@ -150,6 +130,7 @@ class DorefaParamsBinarizationSTE(torch.autograd.Function):
     def backward(ctx, grad_output):
         return grad_output, None

+
 class DorefaQuantizer(Quantizer):
     """
     Quantizer using the DoReFa scheme, as defined in:
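The backward() shown above, which returns grad_output untouched, is the straight-through-estimator pattern this file relies on (LinearQuantizeSTE, DorefaParamsBinarizationSTE): the non-differentiable round/sign step is treated as the identity when back-propagating. A generic, self-contained sketch of the pattern, for illustration only and not the Distiller implementation:

import torch

class RoundSTE(torch.autograd.Function):
    """Round to the nearest integer in forward; treat the op as identity in backward."""

    @staticmethod
    def forward(ctx, input):
        return torch.round(input)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: the gradient skips the non-differentiable rounding.
        return grad_output

x = torch.tensor([0.2, 1.7, 2.5], requires_grad=True)
y = RoundSTE.apply(x)
y.sum().backward()
print(y)       # tensor([0., 2., 2.])  (2.5 rounds to even)
print(x.grad)  # tensor([1., 1., 1.])

In the commit above, this is what lets the clipped activation keep a useful gradient even though the quantized forward pass is piecewise constant.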