From 6dfa8747f1c39d5ab7af1d1f46ec26f450cbc006 Mon Sep 17 00:00:00 2001
From: Guy Jacob <guy.jacob@intel.com>
Date: Thu, 13 Feb 2020 12:21:12 +0200
Subject: [PATCH] Quantization misc. fixes

---
 distiller/apputils/checkpoint.py                   | 3 +--
 distiller/quantization/pytorch_quant_conversion.py | 1 +
 distiller/quantization/quantizer.py                | 3 +++
 3 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/distiller/apputils/checkpoint.py b/distiller/apputils/checkpoint.py
index 70e9de0..fc222e3 100755
--- a/distiller/apputils/checkpoint.py
+++ b/distiller/apputils/checkpoint.py
@@ -29,7 +29,6 @@ from tabulate import tabulate
 import torch
 import distiller
 from distiller.utils import normalize_module_name
-import distiller.quantization as quantization
 
 msglogger = logging.getLogger()
 
@@ -227,7 +226,7 @@ def load_checkpoint(model, chkpt_file, optimizer=None,
 
         if qmd.get('pytorch_convert', False):
             msglogger.info('Converting Distiller PTQ model to PyTorch quantization API')
-            model = quantization.convert_distiller_ptq_model_to_pytorch(model, dummy_input=qmd['dummy_input'])
+            model = quantizer.convert_to_pytorch(qmd['dummy_input'], backend=qmd.get('pytorch_convert_backend', None))
 
     if normalize_dataparallel_keys:
         checkpoint['state_dict'] = {normalize_module_name(k): v for k, v in checkpoint['state_dict'].items()}
diff --git a/distiller/quantization/pytorch_quant_conversion.py b/distiller/quantization/pytorch_quant_conversion.py
index 98501b4..bda477a 100644
--- a/distiller/quantization/pytorch_quant_conversion.py
+++ b/distiller/quantization/pytorch_quant_conversion.py
@@ -317,6 +317,7 @@ def convert_distiller_ptq_model_to_pytorch(model, dummy_input, backend='fbgemm')
 
     # This is used when loading the model from a checkpoint, to indicate that conversion needs to be applied
     quantizer_metadata['pytorch_convert'] = True
+    quantizer_metadata['pytorch_convert_backend'] = backend
     model.quantizer_metadata = quantizer_metadata
 
     return model
diff --git a/distiller/quantization/quantizer.py b/distiller/quantization/quantizer.py
index f4bc448..c3bd293 100644
--- a/distiller/quantization/quantizer.py
+++ b/distiller/quantization/quantizer.py
@@ -223,6 +223,9 @@ class Quantizer(object):
         with a reference to 'new_relu1'. Any override configuration made specifically for
         'self.relu2' will be ignored. A warning message will be shown.
         """
+        if self.prepared:
+            raise RuntimeError('prepare_model can be called only once')
+
         msglogger.info('Preparing model for quantization using {0}'.format(self.__class__.__name__))
 
         self.model.quantizer_metadata["dummy_input"] = dummy_input
-- 
GitLab