diff --git a/distiller/quantization/clipped_linear.py b/distiller/quantization/clipped_linear.py
index 7711a05cbe186dabc59bc8eabc5116df928e2d98..734cf5f72ff72932cdfce453487b172a1d68ff6a 100644
--- a/distiller/quantization/clipped_linear.py
+++ b/distiller/quantization/clipped_linear.py
@@ -86,9 +86,10 @@ class DorefaQuantizer(Quantizer):
         1. Gradients quantization not supported yet
         2. The paper defines special handling for 1-bit weights which isn't supported here yet
     """
-    def __init__(self, model, bits_activations=32, bits_weights=32, bits_overrides={}):
+    def __init__(self, model, bits_activations=32, bits_weights=32, bits_overrides={}, quantize_bias=False):
         super(DorefaQuantizer, self).__init__(model, bits_activations=bits_activations, bits_weights=bits_weights,
-                                              bits_overrides=bits_overrides, train_with_fp_copy=True)
+                                              bits_overrides=bits_overrides, train_with_fp_copy=True,
+                                              quantize_bias=quantize_bias)
 
         def dorefa_quantize_param(param_fp, num_bits):
             scale_factor = asymmetric_linear_quantization_scale_factor(num_bits, 0, 1)
diff --git a/distiller/quantization/quantizer.py b/distiller/quantization/quantizer.py
index 6d1850bb19e6904cc88c64c15a36575be39b9588..854845b10a8e9501808580da627716ff2a47787b 100644
--- a/distiller/quantization/quantizer.py
+++ b/distiller/quantization/quantizer.py
@@ -81,7 +81,8 @@ class Quantizer(object):
         self.model.quantizer_metadata = {'type': type(self),
                                          'params': {'bits_activations': bits_activations,
                                                     'bits_weights': bits_weights,
-                                                    'bits_overrides': copy.deepcopy(bits_overrides)}}
+                                                    'bits_overrides': copy.deepcopy(bits_overrides),
+                                                    'quantize_bias': quantize_bias}}
 
         for k, v in bits_overrides.items():
             qbits = QBits(acts=v.get('acts', self.default_qbits.acts), wts=v.get('wts', self.default_qbits.wts))
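
A minimal usage sketch of the new quantize_bias flag (the model choice, bit widths, import path, and the prepare_model() call below are illustrative assumptions, not part of this diff):

    import torchvision.models as models
    from distiller.quantization import DorefaQuantizer

    model = models.resnet18()  # any torch.nn.Module; resnet18 is only an example

    # quantize_bias defaults to False, so biases stay in full precision unless requested.
    quantizer = DorefaQuantizer(model,
                                bits_activations=8,
                                bits_weights=3,
                                quantize_bias=True)
    quantizer.prepare_model()  # assumed entry point that applies the configured quantization to the model

Since the flag is also stored in model.quantizer_metadata['params'], a checkpoint that saves this metadata should be able to recreate the quantizer with the same bias setting.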