Commit 2bb9689f authored by Robert Muchsel, committed by Guy Jacob
Allow quantize_bias to work in DorefaQuantizer and checkpoints (#17)

parent e7c7d94f
@@ -86,9 +86,10 @@ class DorefaQuantizer(Quantizer):
     1. Gradients quantization not supported yet
     2. The paper defines special handling for 1-bit weights which isn't supported here yet
     """
-    def __init__(self, model, bits_activations=32, bits_weights=32, bits_overrides={}):
+    def __init__(self, model, bits_activations=32, bits_weights=32, bits_overrides={}, quantize_bias=False):
         super(DorefaQuantizer, self).__init__(model, bits_activations=bits_activations, bits_weights=bits_weights,
-                                              bits_overrides=bits_overrides, train_with_fp_copy=True)
+                                              bits_overrides=bits_overrides, train_with_fp_copy=True,
+                                              quantize_bias=quantize_bias)
 
         def dorefa_quantize_param(param_fp, num_bits):
             scale_factor = asymmetric_linear_quantization_scale_factor(num_bits, 0, 1)
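With the new keyword argument in place, bias quantization can be requested when the quantizer is built. A minimal usage sketch based on the signature above; the import path, the bit-widths, and the prepare_model() call are assumptions about the surrounding library, not part of this commit:

    import torchvision.models as models
    from distiller.quantization import DorefaQuantizer  # assumed import path

    model = models.resnet18()  # illustrative model

    # quantize_bias defaults to False, so existing callers keep the old
    # behavior; passing True forwards the flag to the base Quantizer
    # through super().__init__.
    quantizer = DorefaQuantizer(model, bits_activations=8, bits_weights=4,
                                quantize_bias=True)
    quantizer.prepare_model()  # assumed: applies quantization wrappers in place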
......@@ -81,7 +81,8 @@ class Quantizer(object):
self.model.quantizer_metadata = {'type': type(self),
'params': {'bits_activations': bits_activations,
'bits_weights': bits_weights,
'bits_overrides': copy.deepcopy(bits_overrides)}}
'bits_overrides': copy.deepcopy(bits_overrides),
'quantize_bias': quantize_bias}}
for k, v in bits_overrides.items():
qbits = QBits(acts=v.get('acts', self.default_qbits.acts), wts=v.get('wts', self.default_qbits.wts))
......
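Recording quantize_bias in quantizer_metadata is what makes the checkpoint half of the change work: the saved settings can be used to rebuild an identically configured quantizer at load time. A minimal sketch, assuming the metadata dict is stored in the checkpoint under a 'quantizer_metadata' key; the key names, file path, and loader pattern here are illustrative, not code from this commit:

    import torch
    import torchvision.models as models

    model = models.resnet18()  # same architecture as at training time
    checkpoint = torch.load('checkpoint.pth.tar')  # illustrative path
    model.load_state_dict(checkpoint['state_dict'])

    # 'type' holds the quantizer class and 'params' now carries quantize_bias,
    # so the rebuilt quantizer matches the one used during training.
    qmd = checkpoint['quantizer_metadata']  # assumed checkpoint key
    quantizer = qmd['type'](model, **qmd['params'])
    quantizer.prepare_model()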