diff --git a/distiller/quantization/quantizer.py b/distiller/quantization/quantizer.py
index 369271cb1ba0499dc035c92cb79bc18251dc3a86..2b3ac788e21b034fd8d87d38ea464825926b6309 100644
--- a/distiller/quantization/quantizer.py
+++ b/distiller/quantization/quantizer.py
@@ -220,6 +220,8 @@ class Quantizer(object):
             summary_graph = distiller.SummaryGraph(self.model, dummy_input)
             self.adjacency_map = summary_graph.adjacency_map(dedicated_modules_only=False)
 
+        model_device = distiller.model_device(self.model)
+
         self._pre_prepare_model(dummy_input)
 
         self._pre_process_container(self.model)
@@ -247,6 +249,9 @@ class Quantizer(object):
 
         self._post_prepare_model()
 
+        # Restore the model to its pre-preparation device, in case the quantizer created new parameters/buffers
+        self.model.to(model_device)
+
         msglogger.info('Quantized model:\n\n{0}\n'.format(self.model))
 
     def _pre_prepare_model(self, dummy_input):