From 87d7c6cea158a5585ea23c254f6f5cfad4c84b5b Mon Sep 17 00:00:00 2001
From: tacker-oh <tk04m9@gmail.com>
Date: Mon, 8 Apr 2019 20:46:24 +0900
Subject: [PATCH] Proper handling of 0s in DoReFa 1-bit weights (#205)

Fixes #198.
Previously, torch.sign() mapped zero-valued weights to 0, effectively
introducing a third quantization level alongside -E and +E. This fix
maps zeros to 1, keeping the binarization strictly two-level.
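
For reference, a minimal sketch of the before/after behaviour. The
example tensor and variable names are illustrative only; the fixed
expression is the one from the patch below:

    import torch

    w = torch.tensor([-0.5, 0.0, 0.5])
    E = w.abs().mean()  # scale factor: mean absolute weight, here 1/3

    # Before: torch.sign(0) == 0, so zero weights yield a third level
    old = w.sign() * E
    # tensor([-0.3333,  0.0000,  0.3333])

    # After: zeros are mapped to +1, so the output is two-level
    new = torch.where(w == 0, torch.ones_like(w), torch.sign(w)) * E
    # tensor([-0.3333,  0.3333,  0.3333])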
---
 distiller/quantization/clipped_linear.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/distiller/quantization/clipped_linear.py b/distiller/quantization/clipped_linear.py
index ad6079c..2451bfa 100644
--- a/distiller/quantization/clipped_linear.py
+++ b/distiller/quantization/clipped_linear.py
@@ -143,7 +143,7 @@ class DorefaParamsBinarizationSTE(torch.autograd.Function):
         if inplace:
             ctx.mark_dirty(input)
         E = input.abs().mean()
-        output = input.sign() * E
+        output = torch.where(input == 0, torch.ones_like(input), torch.sign(input)) * E
         return output
     
     @staticmethod
@@ -158,7 +158,6 @@ class DorefaQuantizer(Quantizer):
 
     Notes:
         1. Gradients quantization not supported yet
-        2. The paper defines special handling for 1-bit weights which isn't supported here yet
     """
     def __init__(self, model, optimizer,
                  bits_activations=32, bits_weights=32, bits_bias=None,
-- 
GitLab