Commit 87d7c6ce authored by tacker-oh, committed by Guy Jacob

Proper handling of 0s in DoReFa 1-bit weights (#205)

Fixes #198.
Previously 0s were being mapped to 0, effectively yielding a third 
quantization level. This fix maps 0s to 1.
parent 72ef9160
@@ -143,7 +143,7 @@ class DorefaParamsBinarizationSTE(torch.autograd.Function):
         if inplace:
             ctx.mark_dirty(input)
         E = input.abs().mean()
-        output = input.sign() * E
+        output = torch.where(input == 0, torch.ones_like(input), torch.sign(input)) * E
         return output
 
     @staticmethod
@@ -158,7 +158,6 @@ class DorefaQuantizer(Quantizer):
     Notes:
         1. Gradients quantization not supported yet
-        2. The paper defines special handling for 1-bit weights which isn't supported here yet
     """
     def __init__(self, model, optimizer,
                  bits_activations=32, bits_weights=32, bits_bias=None,
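The effect of the change is easy to see on a small tensor. The following is a minimal standalone sketch (the tensor values are illustrative, not from the repository) contrasting plain torch.sign(), which leaves zeros at 0 and so produces a third quantization level, with the torch.where() form introduced by this commit, which maps zeros to +1 so the binarized weights take exactly two values:

import torch

# Illustrative weights; one element is exactly zero.
w = torch.tensor([-0.5, 0.0, 0.25])
E = w.abs().mean()  # scaling factor, as computed in the forward pass

# Old behavior: sign() maps 0 -> 0, giving three distinct output values.
old = torch.sign(w) * E
print(old)  # tensor([-0.2500,  0.0000,  0.2500])

# New behavior: zeros are mapped to +1, so the output is strictly binary (+/- E).
new = torch.where(w == 0, torch.ones_like(w), torch.sign(w)) * E
print(new)  # tensor([-0.2500,  0.2500,  0.2500])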