diff --git a/distiller/quantization/clipped_linear.py b/distiller/quantization/clipped_linear.py
index d7fedf94c322d034d2c48fed3e9c86b644627c8f..d48f0b9d740583f4cd5a853a681d90993cadf47c 100644
--- a/distiller/quantization/clipped_linear.py
+++ b/distiller/quantization/clipped_linear.py
@@ -125,15 +125,38 @@ class WRPNQuantizer(Quantizer):
 
         self.replacement_factory[nn.ReLU] = relu_replace_fn
 
-
 def dorefa_quantize_param(param_fp, param_meta):
-    scale, zero_point = asymmetric_linear_quantization_params(param_meta.num_bits, 0, 1, signed=False)
-    out = param_fp.tanh()
-    out = out / (2 * out.abs().max()) + 0.5
-    out = LinearQuantizeSTE.apply(out, scale, zero_point, True, False)
-    out = 2 * out - 1
+    if param_meta.num_bits == 1:
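+        # 1-bit weights: binarize to sign(w) scaled by the mean absolute value of w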
+        out = DorefaParamsBinarizationSTE.apply(param_fp)
+    else:
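+        # Multi-bit weights: tanh-normalize into [0, 1], quantize with a
+        # straight-through estimator, then map back to [-1, 1] (DoReFa-Net)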
+        scale, zero_point = asymmetric_linear_quantization_params(param_meta.num_bits, 0, 1, signed=False)
+        out = param_fp.tanh()
+        out = out / (2 * out.abs().max()) + 0.5
+        out = LinearQuantizeSTE.apply(out, scale, zero_point, True, False)
+        out = 2 * out - 1
     return out
 
+
+class DorefaParamsBinarizationSTE(torch.autograd.Function):
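+    """Binarize parameters as in DoReFa-Net: sign(w) scaled by E, the mean
+    absolute value of the full-precision weights. The gradient is passed
+    through unchanged (straight-through estimator).
+    """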
+    @staticmethod
+    def forward(ctx, input, inplace=False):
+        if inplace:
+            ctx.mark_dirty(input)
+        E = input.abs().mean()
+        output = input.sign() * E
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return grad_output, None
+
 
 class DorefaQuantizer(Quantizer):
     """