Skip to content
Snippets Groups Projects
Commit 6dfa8747 authored by Guy Jacob's avatar Guy Jacob
Browse files

Quantization misc. fixes

parent 8cffe6c9
No related branches found
No related tags found
No related merge requests found
......@@ -29,7 +29,6 @@ from tabulate import tabulate
import torch
import distiller
from distiller.utils import normalize_module_name
import distiller.quantization as quantization
msglogger = logging.getLogger()
......@@ -227,7 +226,7 @@ def load_checkpoint(model, chkpt_file, optimizer=None,
if qmd.get('pytorch_convert', False):
msglogger.info('Converting Distiller PTQ model to PyTorch quantization API')
model = quantization.convert_distiller_ptq_model_to_pytorch(model, dummy_input=qmd['dummy_input'])
model = quantizer.convert_to_pytorch(qmd['dummy_input'], backend=qmd.get('pytorch_convert_backend', None))
if normalize_dataparallel_keys:
checkpoint['state_dict'] = {normalize_module_name(k): v for k, v in checkpoint['state_dict'].items()}
......
......@@ -317,6 +317,7 @@ def convert_distiller_ptq_model_to_pytorch(model, dummy_input, backend='fbgemm')
# This is used when loading the model from a checkpoint, to indicate that conversion needs to be applied
quantizer_metadata['pytorch_convert'] = True
quantizer_metadata['pytorch_convert_backend'] = backend
model.quantizer_metadata = quantizer_metadata
return model
......
......@@ -223,6 +223,9 @@ class Quantizer(object):
with a reference to 'new_relu1'. Any override configuration made specifically for 'self.relu2'
will be ignored. A warning message will be shown.
"""
if self.prepared:
raise RuntimeError('prepare_model can be called only once')
msglogger.info('Preparing model for quantization using {0}'.format(self.__class__.__name__))
self.model.quantizer_metadata["dummy_input"] = dummy_input
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment