diff --git a/distiller/apputils/checkpoint.py b/distiller/apputils/checkpoint.py index b8f71c118a44ddb560bb98f3881e21f8a65fc396..e1e710d5eb29548adfec063f0027297b8a5181d4 100755 --- a/distiller/apputils/checkpoint.py +++ b/distiller/apputils/checkpoint.py @@ -161,7 +161,7 @@ def load_checkpoint(model, chkpt_file, optimizer=None, model_device=None, *, msglogger.info('Loaded quantizer metadata from the checkpoint') qmd = checkpoint['quantizer_metadata'] quantizer = qmd['type'](model, **qmd['params']) - quantizer.prepare_model() + quantizer.prepare_model(qmd['dummy_input']) if normalize_dataparallel_keys: checkpoint['state_dict'] = {normalize_module_name(k): v for k, v in checkpoint['state_dict'].items()} diff --git a/distiller/quantization/quantizer.py b/distiller/quantization/quantizer.py index 242b19a2da9cb1ffd7c18a291a12ac5713541b2a..fcbd9342e833cba0829cda581a4a41c01e8ef9ca 100644 --- a/distiller/quantization/quantizer.py +++ b/distiller/quantization/quantizer.py @@ -213,6 +213,7 @@ class Quantizer(object): """ msglogger.info('Preparing model for quantization using {0}'.format(self.__class__.__name__)) + self.model.quantizer_metadata["dummy_input"] = dummy_input if dummy_input is not None: summary_graph = distiller.SummaryGraph(self.model, dummy_input) self.adjacency_map = summary_graph.adjacency_map(dedicated_modules_only=False)