diff --git a/tests/test_quantizer.py b/tests/test_quantizer.py index 49f4cd58a5dadc6814ccca739006522c69e7837c..6865d5759b1e986808f65fefe9eebf9b5c022885 100644 --- a/tests/test_quantizer.py +++ b/tests/test_quantizer.py @@ -317,15 +317,17 @@ def test_param_quantization(model, optimizer, qbits, bits_overrides, explicit_ex if has_children(pre_quant_module): continue - num_bits = expected_qbits[name].wts + num_qbits = expected_qbits[name].wts for param_name, pre_quant_param in pre_quant_module.named_parameters(): - quantizable = num_bits is not None + quantizable = num_qbits is not None if param_name.endswith('bias'): quantizable = quantizable and quantize_bias # Bias number of bits is hard-coded to 32 for now... if quantizable: num_bits = 32 + else: + num_bits = num_qbits if quantizable and train_with_fp_copy: # "param_name" and "pre_quant_param" refer to the float copy