Commit 87d7c6ce authored by tacker-oh, committed by Guy Jacob

Proper handling of 0s in DoReFa 1-bit weights (#205)

Fixes #198.
Previously 0s were being mapped to 0, effectively yielding a third 
quantization level. This fix maps 0s to 1.
parent 72ef9160
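
To illustrate the bug (a minimal PyTorch sketch, not part of the commit): torch.sign maps 0 to 0, so the old code produced three output levels instead of two.

import torch

# Old behavior: sign(0) == 0 leaks a third level {-E, 0, +E} into 1-bit weights
w = torch.tensor([-0.5, 0.0, 0.25])
E = w.abs().mean()                 # E = 0.25
print(w.sign() * E)                # tensor([-0.2500,  0.0000,  0.2500])

# Fixed behavior: 0s map to +1 before scaling, leaving exactly {-E, +E}
print(torch.where(w == 0, torch.ones_like(w), torch.sign(w)) * E)
# tensor([-0.2500,  0.2500,  0.2500])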
@@ -143,7 +143,7 @@ class DorefaParamsBinarizationSTE(torch.autograd.Function):
         if inplace:
             ctx.mark_dirty(input)
         E = input.abs().mean()
-        output = input.sign() * E
+        output = torch.where(input == 0, torch.ones_like(input), torch.sign(input)) * E
         return output
 
     @staticmethod
@@ -158,7 +158,6 @@ class DorefaQuantizer(Quantizer):
 
     Notes:
         1. Gradients quantization not supported yet
-        2. The paper defines special handling for 1-bit weights which isn't supported here yet
     """
     def __init__(self, model, optimizer,
                  bits_activations=32, bits_weights=32, bits_bias=None,
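
For context, here is a self-contained sketch of how the patched binarizer behaves as a straight-through estimator; the class name BinarizeSTE and the single-argument forward are simplifications of DorefaParamsBinarizationSTE above, which also takes an inplace flag.

import torch

class BinarizeSTE(torch.autograd.Function):   # hypothetical stand-in for DorefaParamsBinarizationSTE
    @staticmethod
    def forward(ctx, input):
        E = input.abs().mean()
        # Map 0s to +1 so the output takes exactly two values, -E and +E
        return torch.where(input == 0, torch.ones_like(input), torch.sign(input)) * E

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: gradient passes through unchanged
        return grad_output

w = torch.tensor([-0.5, 0.0, 0.25], requires_grad=True)
out = BinarizeSTE.apply(w)         # tensor([-0.2500,  0.2500,  0.2500], grad_fn=...)
out.sum().backward()
print(w.grad)                      # tensor([1., 1., 1.])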