From 87d7c6cea158a5585ea23c254f6f5cfad4c84b5b Mon Sep 17 00:00:00 2001
From: tacker-oh <tk04m9@gmail.com>
Date: Mon, 8 Apr 2019 20:46:24 +0900
Subject: [PATCH] Proper handling of 0s in DoReFa 1-bit weights (#205)

Fixes #198. Previously 0s were being mapped to 0, effectively yielding a
third quantization level. This fix maps 0s to 1.
---
 distiller/quantization/clipped_linear.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/distiller/quantization/clipped_linear.py b/distiller/quantization/clipped_linear.py
index ad6079c..2451bfa 100644
--- a/distiller/quantization/clipped_linear.py
+++ b/distiller/quantization/clipped_linear.py
@@ -143,7 +143,7 @@ class DorefaParamsBinarizationSTE(torch.autograd.Function):
         if inplace:
             ctx.mark_dirty(input)
         E = input.abs().mean()
-        output = input.sign() * E
+        output = torch.where(input == 0, torch.ones_like(input), torch.sign(input)) * E
         return output
 
     @staticmethod
@@ -158,7 +158,6 @@ class DorefaQuantizer(Quantizer):
 
     Notes:
         1. Gradients quantization not supported yet
-        2. The paper defines special handling for 1-bit weights which isn't supported here yet
     """
     def __init__(self, model, optimizer, bits_activations=32, bits_weights=32, bits_bias=None,
--
GitLab
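
For context, a minimal standalone sketch (not part of the patch) of the behavior being fixed: torch.sign returns 0 for an input of exactly 0, so the old expression emitted a third output level for zero-valued weights, while the torch.where form forces zeros to +1 before scaling by E, leaving exactly the two levels {-E, +E}.

    import torch

    # Illustrative example only; tensor values are made up for demonstration.
    w = torch.tensor([-0.5, 0.0, 0.25])
    E = w.abs().mean()  # E = 0.25

    # Old behavior: torch.sign(0) == 0, so 0 stays 0 -- three distinct levels.
    old = w.sign() * E
    # tensor([-0.2500, 0.0000, 0.2500])

    # Patched behavior: zeros are mapped to +1 -- exactly two levels, -E and +E.
    new = torch.where(w == 0, torch.ones_like(w), torch.sign(w)) * E
    # tensor([-0.2500, 0.2500, 0.2500])

    print(old, new)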