diff --git a/distiller/quantization/clipped_linear.py b/distiller/quantization/clipped_linear.py index ad6079cd324286f0c6f409d718ac8589d45c0602..2451bfae06e05b4f232ae101d486a2851d81c33d 100644 --- a/distiller/quantization/clipped_linear.py +++ b/distiller/quantization/clipped_linear.py @@ -143,7 +143,7 @@ class DorefaParamsBinarizationSTE(torch.autograd.Function): if inplace: ctx.mark_dirty(input) E = input.abs().mean() - output = input.sign() * E + output = torch.where(input == 0, torch.ones_like(input), torch.sign(input)) * E return output @staticmethod @@ -158,7 +158,6 @@ class DorefaQuantizer(Quantizer): Notes: 1. Gradients quantization not supported yet - 2. The paper defines special handling for 1-bit weights which isn't supported here yet """ def __init__(self, model, optimizer, bits_activations=32, bits_weights=32, bits_bias=None,