From 6782ccaeae0c02b542c4050d55abc585e34dd1bd Mon Sep 17 00:00:00 2001 From: Lev Zlotnik <46742999+levzlotnik@users.noreply.github.com> Date: Tue, 23 Jul 2019 18:19:47 +0300 Subject: [PATCH] Save dummy_input in quantizer metadata (#333) And use it when calling prepare_model when loading from a checkpoint --- distiller/apputils/checkpoint.py | 2 +- distiller/quantization/quantizer.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/distiller/apputils/checkpoint.py b/distiller/apputils/checkpoint.py index b8f71c1..e1e710d 100755 --- a/distiller/apputils/checkpoint.py +++ b/distiller/apputils/checkpoint.py @@ -161,7 +161,7 @@ def load_checkpoint(model, chkpt_file, optimizer=None, model_device=None, *, msglogger.info('Loaded quantizer metadata from the checkpoint') qmd = checkpoint['quantizer_metadata'] quantizer = qmd['type'](model, **qmd['params']) - quantizer.prepare_model() + quantizer.prepare_model(qmd['dummy_input']) if normalize_dataparallel_keys: checkpoint['state_dict'] = {normalize_module_name(k): v for k, v in checkpoint['state_dict'].items()} diff --git a/distiller/quantization/quantizer.py b/distiller/quantization/quantizer.py index 242b19a..fcbd934 100644 --- a/distiller/quantization/quantizer.py +++ b/distiller/quantization/quantizer.py @@ -213,6 +213,7 @@ class Quantizer(object): """ msglogger.info('Preparing model for quantization using {0}'.format(self.__class__.__name__)) + self.model.quantizer_metadata["dummy_input"] = dummy_input if dummy_input is not None: summary_graph = distiller.SummaryGraph(self.model, dummy_input) self.adjacency_map = summary_graph.adjacency_map(dedicated_modules_only=False) -- GitLab