From 6782ccaeae0c02b542c4050d55abc585e34dd1bd Mon Sep 17 00:00:00 2001
From: Lev Zlotnik <46742999+levzlotnik@users.noreply.github.com>
Date: Tue, 23 Jul 2019 18:19:47 +0300
Subject: [PATCH] Save dummy_input in quantizer metadata (#333)

And use it when calling prepare_model while loading from a checkpoint
---
 distiller/apputils/checkpoint.py    | 2 +-
 distiller/quantization/quantizer.py | 1 +
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/distiller/apputils/checkpoint.py b/distiller/apputils/checkpoint.py
index b8f71c1..e1e710d 100755
--- a/distiller/apputils/checkpoint.py
+++ b/distiller/apputils/checkpoint.py
@@ -161,7 +161,7 @@ def load_checkpoint(model, chkpt_file, optimizer=None, model_device=None, *,
         msglogger.info('Loaded quantizer metadata from the checkpoint')
         qmd = checkpoint['quantizer_metadata']
         quantizer = qmd['type'](model, **qmd['params'])
-        quantizer.prepare_model()
+        quantizer.prepare_model(qmd['dummy_input'])
 
     if normalize_dataparallel_keys:
             checkpoint['state_dict'] = {normalize_module_name(k): v for k, v in checkpoint['state_dict'].items()}
diff --git a/distiller/quantization/quantizer.py b/distiller/quantization/quantizer.py
index 242b19a..fcbd934 100644
--- a/distiller/quantization/quantizer.py
+++ b/distiller/quantization/quantizer.py
@@ -213,6 +213,7 @@ class Quantizer(object):
         """
         msglogger.info('Preparing model for quantization using {0}'.format(self.__class__.__name__))
 
+        self.model.quantizer_metadata["dummy_input"] = dummy_input
         if dummy_input is not None:
             summary_graph = distiller.SummaryGraph(self.model, dummy_input)
             self.adjacency_map = summary_graph.adjacency_map(dedicated_modules_only=False)
-- 
GitLab