diff --git a/distiller/quantization/clipped_linear.py b/distiller/quantization/clipped_linear.py
index d7fedf94c322d034d2c48fed3e9c86b644627c8f..d48f0b9d740583f4cd5a853a681d90993cadf47c 100644
--- a/distiller/quantization/clipped_linear.py
+++ b/distiller/quantization/clipped_linear.py
@@ -125,15 +125,29 @@ class WRPNQuantizer(Quantizer):
         self.replacement_factory[nn.ReLU] = relu_replace_fn
 
 
 def dorefa_quantize_param(param_fp, param_meta):
-    scale, zero_point = asymmetric_linear_quantization_params(param_meta.num_bits, 0, 1, signed=False)
-    out = param_fp.tanh()
-    out = out / (2 * out.abs().max()) + 0.5
-    out = LinearQuantizeSTE.apply(out, scale, zero_point, True, False)
-    out = 2 * out - 1
+    if param_meta.num_bits == 1:
+        out = DorefaParamsBinarizationSTE.apply(param_fp)
+    else:
+        scale, zero_point = asymmetric_linear_quantization_params(param_meta.num_bits, 0, 1, signed=False)
+        out = param_fp.tanh()
+        out = out / (2 * out.abs().max()) + 0.5
+        out = LinearQuantizeSTE.apply(out, scale, zero_point, True, False)
+        out = 2 * out - 1
     return out
 
 
+class DorefaParamsBinarizationSTE(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input, inplace=False):
+        if inplace:
+            ctx.mark_dirty(input)
+        E = input.abs().mean()
+        output = input.sign() * E
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return grad_output, None
+
+
 class DorefaQuantizer(Quantizer):
     """
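
For context: the new DorefaParamsBinarizationSTE autograd function handles the 1-bit weight case of DoReFa quantization. The forward pass replaces each weight with its sign scaled by the mean absolute value E = mean(|w|); the backward pass is a straight-through estimator (STE) that passes the incoming gradient through unchanged. Below is a minimal usage sketch, assuming the class definition from the hunk above is in scope (e.g. imported from distiller.quantization.clipped_linear); the tensor values are illustrative only.

    import torch

    # Assumes DorefaParamsBinarizationSTE (added in the diff above) is in scope.
    w = torch.tensor([0.3, -1.2, 0.05, -0.7], requires_grad=True)

    # Forward: sign(w) * mean(|w|); here mean(|w|) = 0.5625
    w_b = DorefaParamsBinarizationSTE.apply(w)
    print(w_b)     # tensor([ 0.5625, -0.5625,  0.5625, -0.5625], grad_fn=...)

    # Backward: straight-through estimator, so the gradient of sum() is all ones
    w_b.sum().backward()
    print(w.grad)  # tensor([1., 1., 1., 1.])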