From eebe7a7881d4b393d0c7afe770f4ad63b79fac2c Mon Sep 17 00:00:00 2001
From: JoyFreemanYan <JoyF.Yan@gmail.com>
Date: Sun, 27 Jan 2019 16:45:31 +0800
Subject: [PATCH] DoReFa 1-bit weights (#127)

---
 distiller/quantization/clipped_linear.py | 26 ++++++++++++++++++------
 1 file changed, 20 insertions(+), 6 deletions(-)

NOTE(review): this patch had been collapsed onto a single line and was not
appliable; the line structure below is reconstructed. It matches all stated
metadata exactly (hunk @@ -125,15 +125,29 @@: 9 context + 6 deletions old,
9 context + 20 additions new; diffstat 26 marks), but the placement of
blank-line context and the final context line are inferred - verify against
upstream commit eebe7a7 before applying.
NOTE(review): in the added DorefaParamsBinarizationSTE.forward, when
inplace=True the code calls ctx.mark_dirty(input) but then computes
input.sign() * E out of place; autograd rejects a tensor marked dirty that
was never actually modified in place. Presumably inplace is never True at
the call site (apply is called with the default) - confirm upstream.
(This text sits between the diffstat and the diff, so git am ignores it.)

diff --git a/distiller/quantization/clipped_linear.py b/distiller/quantization/clipped_linear.py
index d7fedf9..d48f0b9 100644
--- a/distiller/quantization/clipped_linear.py
+++ b/distiller/quantization/clipped_linear.py
@@ -125,15 +125,29 @@ class WRPNQuantizer(Quantizer):
         self.replacement_factory[nn.ReLU] = relu_replace_fn
 
-
 def dorefa_quantize_param(param_fp, param_meta):
-    scale, zero_point = asymmetric_linear_quantization_params(param_meta.num_bits, 0, 1, signed=False)
-    out = param_fp.tanh()
-    out = out / (2 * out.abs().max()) + 0.5
-    out = LinearQuantizeSTE.apply(out, scale, zero_point, True, False)
-    out = 2 * out - 1
+    if param_meta.num_bits == 1:
+        out = DorefaParamsBinarizationSTE.apply(param_fp)
+    else:
+        scale, zero_point = asymmetric_linear_quantization_params(param_meta.num_bits, 0, 1, signed=False)
+        out = param_fp.tanh()
+        out = out / (2 * out.abs().max()) + 0.5
+        out = LinearQuantizeSTE.apply(out, scale, zero_point, True, False)
+        out = 2 * out - 1
     return out
+class DorefaParamsBinarizationSTE(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input, inplace=False):
+        if inplace:
+            ctx.mark_dirty(input)
+        E = input.abs().mean()
+        output = input.sign() * E
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return grad_output, None
 
 
 class DorefaQuantizer(Quantizer):
     """
     Quantizer using the DoReFa scheme, as defined in:
-- 
GitLab