diff --git a/distiller/quantization/quantizer.py b/distiller/quantization/quantizer.py index 369271cb1ba0499dc035c92cb79bc18251dc3a86..2b3ac788e21b034fd8d87d38ea464825926b6309 100644 --- a/distiller/quantization/quantizer.py +++ b/distiller/quantization/quantizer.py @@ -220,6 +220,8 @@ class Quantizer(object): summary_graph = distiller.SummaryGraph(self.model, dummy_input) self.adjacency_map = summary_graph.adjacency_map(dedicated_modules_only=False) + model_device = distiller.model_device(self.model) + self._pre_prepare_model(dummy_input) self._pre_process_container(self.model) @@ -247,6 +249,9 @@ class Quantizer(object): self._post_prepare_model() + # Re-transfer model to the device it was on, in case the quantizer created new parameters/buffers + self.model.to(model_device) + msglogger.info('Quantized model:\n\n{0}\n'.format(self.model)) def _pre_prepare_model(self, dummy_input):