diff --git a/distiller/apputils/checkpoint.py b/distiller/apputils/checkpoint.py
index 70e9de0deccc088539d4fd8f820fdb47220b2d7c..fc222e3ce74ecb794cb66830915517a5e02293cb 100755
--- a/distiller/apputils/checkpoint.py
+++ b/distiller/apputils/checkpoint.py
@@ -29,7 +29,6 @@ from tabulate import tabulate
 import torch
 import distiller
 from distiller.utils import normalize_module_name
-import distiller.quantization as quantization
 
 msglogger = logging.getLogger()
 
@@ -227,7 +226,7 @@ def load_checkpoint(model, chkpt_file, optimizer=None,
 
         if qmd.get('pytorch_convert', False):
             msglogger.info('Converting Distiller PTQ model to PyTorch quantization API')
-            model = quantization.convert_distiller_ptq_model_to_pytorch(model, dummy_input=qmd['dummy_input'])
+            model = quantizer.convert_to_pytorch(qmd['dummy_input'], backend=qmd.get('pytorch_convert_backend', None))
 
     if normalize_dataparallel_keys:
         checkpoint['state_dict'] = {normalize_module_name(k): v for k, v in checkpoint['state_dict'].items()}
diff --git a/distiller/quantization/pytorch_quant_conversion.py b/distiller/quantization/pytorch_quant_conversion.py
index 98501b4f956f853db6d6237ca01897ec12b9eaac..bda477a2099c9b3309038c68c8207023f88464e7 100644
--- a/distiller/quantization/pytorch_quant_conversion.py
+++ b/distiller/quantization/pytorch_quant_conversion.py
@@ -317,6 +317,7 @@ def convert_distiller_ptq_model_to_pytorch(model, dummy_input, backend='fbgemm')
 
     # This is used when loading the model from a checkpoint, to indicate that conversion needs to be applied
     quantizer_metadata['pytorch_convert'] = True
+    quantizer_metadata['pytorch_convert_backend'] = backend
     model.quantizer_metadata = quantizer_metadata
 
     return model
diff --git a/distiller/quantization/quantizer.py b/distiller/quantization/quantizer.py
index f4bc4488c571a1c322696e7294c8ad3d4e1f0c56..c3bd293e17bfc6757638a50e1f54314b9aedc51d 100644
--- a/distiller/quantization/quantizer.py
+++ b/distiller/quantization/quantizer.py
@@ -223,6 +223,9 @@ class Quantizer(object):
             with a reference to 'new_relu1'. Any override configuration made specifically for 'self.relu2'
             will be ignored. A warning message will be shown.
         """
+        if self.prepared:
+            raise RuntimeError('prepare_model can be called only once')
+
         msglogger.info('Preparing model for quantization using {0}'.format(self.__class__.__name__))
 
         self.model.quantizer_metadata["dummy_input"] = dummy_input