Skip to content
Snippets Groups Projects
Commit 2bb9689f authored by Robert Muchsel, committed by Guy Jacob
Browse files

Allow quantize_bias to work in DorefaQuantizer and checkpoints (#17)

parent e7c7d94f
No related branches found
No related tags found
No related merge requests found
...@@ -86,9 +86,10 @@ class DorefaQuantizer(Quantizer): ...@@ -86,9 +86,10 @@ class DorefaQuantizer(Quantizer):
1. Gradients quantization not supported yet 1. Gradients quantization not supported yet
2. The paper defines special handling for 1-bit weights which isn't supported here yet 2. The paper defines special handling for 1-bit weights which isn't supported here yet
""" """
def __init__(self, model, bits_activations=32, bits_weights=32, bits_overrides={}): def __init__(self, model, bits_activations=32, bits_weights=32, bits_overrides={}, quantize_bias=False):
super(DorefaQuantizer, self).__init__(model, bits_activations=bits_activations, bits_weights=bits_weights, super(DorefaQuantizer, self).__init__(model, bits_activations=bits_activations, bits_weights=bits_weights,
bits_overrides=bits_overrides, train_with_fp_copy=True) bits_overrides=bits_overrides, train_with_fp_copy=True,
quantize_bias=quantize_bias)
def dorefa_quantize_param(param_fp, num_bits): def dorefa_quantize_param(param_fp, num_bits):
scale_factor = asymmetric_linear_quantization_scale_factor(num_bits, 0, 1) scale_factor = asymmetric_linear_quantization_scale_factor(num_bits, 0, 1)
......
...@@ -81,7 +81,8 @@ class Quantizer(object): ...@@ -81,7 +81,8 @@ class Quantizer(object):
self.model.quantizer_metadata = {'type': type(self), self.model.quantizer_metadata = {'type': type(self),
'params': {'bits_activations': bits_activations, 'params': {'bits_activations': bits_activations,
'bits_weights': bits_weights, 'bits_weights': bits_weights,
'bits_overrides': copy.deepcopy(bits_overrides)}} 'bits_overrides': copy.deepcopy(bits_overrides),
'quantize_bias': quantize_bias}}
for k, v in bits_overrides.items(): for k, v in bits_overrides.items():
qbits = QBits(acts=v.get('acts', self.default_qbits.acts), wts=v.get('wts', self.default_qbits.wts)) qbits = QBits(acts=v.get('acts', self.default_qbits.acts), wts=v.get('wts', self.default_qbits.wts))
......
0% Loading. Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment