From 2bb9689fe58d196ccbccd3f2f44ac27192eb64e1 Mon Sep 17 00:00:00 2001
From: Robert Muchsel <16564465+rotx-maxim@users.noreply.github.com>
Date: Thu, 5 Jul 2018 06:16:07 -0500
Subject: [PATCH] Allow quantize_bias to work in DorefaQuantizer and
 checkpoints (#17)

---
 distiller/quantization/clipped_linear.py | 5 +++--
 distiller/quantization/quantizer.py      | 3 ++-
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/distiller/quantization/clipped_linear.py b/distiller/quantization/clipped_linear.py
index 7711a05..734cf5f 100644
--- a/distiller/quantization/clipped_linear.py
+++ b/distiller/quantization/clipped_linear.py
@@ -86,9 +86,10 @@ class DorefaQuantizer(Quantizer):
         1. Gradients quantization not supported yet
         2. The paper defines special handling for 1-bit weights which isn't supported here yet
     """
-    def __init__(self, model, bits_activations=32, bits_weights=32, bits_overrides={}):
+    def __init__(self, model, bits_activations=32, bits_weights=32, bits_overrides={}, quantize_bias=False):
         super(DorefaQuantizer, self).__init__(model, bits_activations=bits_activations, bits_weights=bits_weights,
-                                              bits_overrides=bits_overrides, train_with_fp_copy=True)
+                                              bits_overrides=bits_overrides, train_with_fp_copy=True,
+                                              quantize_bias=quantize_bias)
 
         def dorefa_quantize_param(param_fp, num_bits):
             scale_factor = asymmetric_linear_quantization_scale_factor(num_bits, 0, 1)
diff --git a/distiller/quantization/quantizer.py b/distiller/quantization/quantizer.py
index 6d1850b..854845b 100644
--- a/distiller/quantization/quantizer.py
+++ b/distiller/quantization/quantizer.py
@@ -81,7 +81,8 @@ class Quantizer(object):
         self.model.quantizer_metadata = {'type': type(self),
                                          'params': {'bits_activations': bits_activations,
                                                     'bits_weights': bits_weights,
-                                                    'bits_overrides': copy.deepcopy(bits_overrides)}}
+                                                    'bits_overrides': copy.deepcopy(bits_overrides),
+                                                    'quantize_bias': quantize_bias}}
 
         for k, v in bits_overrides.items():
             qbits = QBits(acts=v.get('acts', self.default_qbits.acts), wts=v.get('wts', self.default_qbits.wts))
-- 
GitLab