diff --git a/distiller/apputils/checkpoint.py b/distiller/apputils/checkpoint.py
index b8f71c118a44ddb560bb98f3881e21f8a65fc396..e1e710d5eb29548adfec063f0027297b8a5181d4 100755
--- a/distiller/apputils/checkpoint.py
+++ b/distiller/apputils/checkpoint.py
@@ -161,7 +161,7 @@ def load_checkpoint(model, chkpt_file, optimizer=None, model_device=None, *,
         msglogger.info('Loaded quantizer metadata from the checkpoint')
         qmd = checkpoint['quantizer_metadata']
         quantizer = qmd['type'](model, **qmd['params'])
-        quantizer.prepare_model()
+        quantizer.prepare_model(qmd.get('dummy_input'))
 
     if normalize_dataparallel_keys:
             checkpoint['state_dict'] = {normalize_module_name(k): v for k, v in checkpoint['state_dict'].items()}
diff --git a/distiller/quantization/quantizer.py b/distiller/quantization/quantizer.py
index 242b19a2da9cb1ffd7c18a291a12ac5713541b2a..fcbd9342e833cba0829cda581a4a41c01e8ef9ca 100644
--- a/distiller/quantization/quantizer.py
+++ b/distiller/quantization/quantizer.py
@@ -213,6 +213,7 @@ class Quantizer(object):
         """
         msglogger.info('Preparing model for quantization using {0}'.format(self.__class__.__name__))
 
+        self.model.quantizer_metadata['dummy_input'] = dummy_input
         if dummy_input is not None:
             summary_graph = distiller.SummaryGraph(self.model, dummy_input)
             self.adjacency_map = summary_graph.adjacency_map(dedicated_modules_only=False)