Skip to content
Snippets Groups Projects
Commit 6782ccae authored by Lev Zlotnik's avatar Lev Zlotnik Committed by Guy Jacob
Browse files

Save dummy_input in quantizer metadata (#333)

And use it when calling prepare_model when loading from a checkpoint
parent 88474af2
No related branches found
No related tags found
No related merge requests found
......@@ -161,7 +161,7 @@ def load_checkpoint(model, chkpt_file, optimizer=None, model_device=None, *,
msglogger.info('Loaded quantizer metadata from the checkpoint')
qmd = checkpoint['quantizer_metadata']
quantizer = qmd['type'](model, **qmd['params'])
quantizer.prepare_model()
quantizer.prepare_model(qmd['dummy_input'])
if normalize_dataparallel_keys:
checkpoint['state_dict'] = {normalize_module_name(k): v for k, v in checkpoint['state_dict'].items()}
......
......@@ -213,6 +213,7 @@ class Quantizer(object):
"""
msglogger.info('Preparing model for quantization using {0}'.format(self.__class__.__name__))
self.model.quantizer_metadata["dummy_input"] = dummy_input
if dummy_input is not None:
summary_graph = distiller.SummaryGraph(self.model, dummy_input)
self.adjacency_map = summary_graph.adjacency_map(dedicated_modules_only=False)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment