Skip to content
Snippets Groups Projects
Commit 6dfa8747 authored by Guy Jacob's avatar Guy Jacob
Browse files

Quantization misc. fixes

parent 8cffe6c9
No related branches found
No related tags found
No related merge requests found
@@ -29,7 +29,6 @@ from tabulate import tabulate
 import torch
 import distiller
 from distiller.utils import normalize_module_name
-import distiller.quantization as quantization
 msglogger = logging.getLogger()
@@ -227,7 +226,7 @@ def load_checkpoint(model, chkpt_file, optimizer=None,
     if qmd.get('pytorch_convert', False):
         msglogger.info('Converting Distiller PTQ model to PyTorch quantization API')
-        model = quantization.convert_distiller_ptq_model_to_pytorch(model, dummy_input=qmd['dummy_input'])
+        model = quantizer.convert_to_pytorch(qmd['dummy_input'], backend=qmd.get('pytorch_convert_backend', None))
     if normalize_dataparallel_keys:
         checkpoint['state_dict'] = {normalize_module_name(k): v for k, v in checkpoint['state_dict'].items()}
......
@@ -317,6 +317,7 @@ def convert_distiller_ptq_model_to_pytorch(model, dummy_input, backend='fbgemm')
     # This is used when loading the model from a checkpoint, to indicate that conversion needs to be applied
     quantizer_metadata['pytorch_convert'] = True
+    quantizer_metadata['pytorch_convert_backend'] = backend
     model.quantizer_metadata = quantizer_metadata
     return model
......
@@ -223,6 +223,9 @@ class Quantizer(object):
         with a reference to 'new_relu1'. Any override configuration made specifically for 'self.relu2'
         will be ignored. A warning message will be shown.
         """
+        if self.prepared:
+            raise RuntimeError('prepare_model can be called only once')
         msglogger.info('Preparing model for quantization using {0}'.format(self.__class__.__name__))
         self.model.quantizer_metadata["dummy_input"] = dummy_input
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment