Commit eebe7a78 authored by JoyFreemanYan, committed by Guy Jacob

DoReFa 1-bit weights (#127)

parent 8d694a03
@@ -125,15 +125,29 @@ class WRPNQuantizer(Quantizer):
         self.replacement_factory[nn.ReLU] = relu_replace_fn


 def dorefa_quantize_param(param_fp, param_meta):
-    scale, zero_point = asymmetric_linear_quantization_params(param_meta.num_bits, 0, 1, signed=False)
-    out = param_fp.tanh()
-    out = out / (2 * out.abs().max()) + 0.5
-    out = LinearQuantizeSTE.apply(out, scale, zero_point, True, False)
-    out = 2 * out - 1
+    if param_meta.num_bits == 1:
+        out = DorefaParamsBinarizationSTE.apply(param_fp)
+    else:
+        scale, zero_point = asymmetric_linear_quantization_params(param_meta.num_bits, 0, 1, signed=False)
+        out = param_fp.tanh()
+        out = out / (2 * out.abs().max()) + 0.5
+        out = LinearQuantizeSTE.apply(out, scale, zero_point, True, False)
+        out = 2 * out - 1
     return out


+class DorefaParamsBinarizationSTE(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input, inplace=False):
+        if inplace:
+            ctx.mark_dirty(input)
+        E = input.abs().mean()
+        output = input.sign() * E
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return grad_output, None
+
+
 class DorefaQuantizer(Quantizer):
     """
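For reference, a minimal sketch of what the new 1-bit path computes, assuming the DorefaParamsBinarizationSTE class from the hunk above is in scope (the surrounding module and imports are not shown in the diff). Every weight is mapped to +/-E with E = mean(|w|), the DoReFa-Net scheme for 1-bit weights, while the straight-through estimator passes the gradient back unchanged.

import torch

# Sketch only: assumes DorefaParamsBinarizationSTE (defined in the diff above) is available here.
w = torch.randn(4, 4, requires_grad=True)
w_b = DorefaParamsBinarizationSTE.apply(w)

# Forward: every element becomes +/-E, where E = mean(|w|).
E = w.abs().mean()
assert torch.allclose(w_b.abs(), E.expand_as(w))

# Backward: the straight-through estimator passes the gradient through unchanged.
w_b.sum().backward()
assert torch.equal(w.grad, torch.ones_like(w))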