diff --git a/distiller/quantization/clipped_linear.py b/distiller/quantization/clipped_linear.py
index 5827554191f1d14424463f15387fc6e87e4fd15a..c7159ab7f205ae6cedcf43fa70c0be6b0fbf31ba 100644
--- a/distiller/quantization/clipped_linear.py
+++ b/distiller/quantization/clipped_linear.py
@@ -192,9 +192,8 @@ class PACTQuantizer(Quantizer):
 
     # In PACT, LearnedClippedLinearQuantization is used for activation, which contains a learnt 'clip_val' parameter
     # We optimize this value separately from the main model parameters
-    def _get_updated_optimizer_params_groups(self):
-        base_group = {'params': [param for name, param in self.model.named_parameters() if 'clip_val' not in name]}
+    def _get_new_optimizer_params_groups(self):
         clip_val_group = {'params': [param for name, param in self.model.named_parameters() if 'clip_val' in name]}
         if self.act_clip_decay is not None:
             clip_val_group['weight_decay'] = self.act_clip_decay
-        return [base_group, clip_val_group]
+        return [clip_val_group]
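With this change, PACTQuantizer no longer rebuilds a parameter group covering the whole model; it only reports the clip_val parameters it introduced, and the base class appends them to the existing optimizer (see the quantizer.py hunks below). A minimal, self-contained sketch of what the override now produces, using a toy module and a made-up act_clip_decay value rather than Distiller's actual classes:

    import torch
    import torch.nn as nn

    # Toy stand-in for an activation replaced by LearnedClippedLinearQuantization,
    # which registers a learnable 'clip_val' Parameter.
    class ToyClippedAct(nn.Module):
        def __init__(self):
            super().__init__()
            self.clip_val = nn.Parameter(torch.tensor(8.0))

    model = nn.Sequential(nn.Linear(4, 4), ToyClippedAct())
    act_clip_decay = 5e-4  # hypothetical value for PACTQuantizer's act_clip_decay

    # Mirrors the trimmed-down _get_new_optimizer_params_groups(): only the new
    # clip_val parameters, with an optional weight-decay override -- no 'base_group'.
    clip_val_group = {'params': [p for n, p in model.named_parameters() if 'clip_val' in n]}
    if act_clip_decay is not None:
        clip_val_group['weight_decay'] = act_clip_decay
    new_groups = [clip_val_group]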
diff --git a/distiller/quantization/quantizer.py b/distiller/quantization/quantizer.py
index db152ff1db709e58a1ce86debd17d0d48ada0bb4..369271cb1ba0499dc035c92cb79bc18251dc3a86 100644
--- a/distiller/quantization/quantizer.py
+++ b/distiller/quantization/quantizer.py
@@ -37,12 +37,14 @@ def has_bias(module):
 
 def hack_float_backup_parameter(module, name, num_bits):
     try:
-        data = dict(module.named_parameters())[name].data
+        param = dict(module.named_parameters())[name]
+        param_id = id(param)
     except KeyError:
         raise ValueError('Module has no Parameter named ' + name)
-    module.register_parameter(FP_BKP_PREFIX + name, nn.Parameter(data))
+    module.register_parameter(FP_BKP_PREFIX + name, param)
+    assert id(getattr(module, FP_BKP_PREFIX + name)) == param_id
     delattr(module, name)
-    module.register_buffer(name, torch.zeros_like(data))
+    module.register_buffer(name, torch.zeros_like(param))
 
     first = False
     if not hasattr(module, 'repr_mod'):
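Re-registering the existing Parameter object (rather than wrapping its .data in a fresh nn.Parameter) preserves object identity, which is what the added assert checks: an optimizer built before prepare_model() holds references to the original Parameter objects, so only that same object will keep receiving updates. A standalone sketch of that identity point, using a plain nn.Linear rather than Distiller's own modules:

    import torch
    import torch.nn as nn

    FP_BKP_PREFIX = 'float_'  # assumed to match the prefix used in quantizer.py

    lin = nn.Linear(4, 4)
    opt = torch.optim.SGD(lin.parameters(), lr=0.1)

    weight = dict(lin.named_parameters())['weight']
    # Re-register the SAME Parameter under the backup name, then shadow the original
    # name with a buffer that will hold the quantized copy, as the patched function does.
    lin.register_parameter(FP_BKP_PREFIX + 'weight', weight)
    delattr(lin, 'weight')
    lin.register_buffer('weight', torch.zeros_like(weight))

    # No new nn.Parameter was created, so the optimizer still points at the float
    # backup and will keep updating it during training.
    assert getattr(lin, FP_BKP_PREFIX + 'weight') is weight
    assert any(p is weight for g in opt.param_groups for p in g['params'])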
@@ -240,9 +242,8 @@ class Quantizer(object):
 
         # If an optimizer was passed, assume we need to update it
         if self.optimizer:
-            optimizer_type = type(self.optimizer)
-            new_optimizer = optimizer_type(self._get_updated_optimizer_params_groups(), **self.optimizer.defaults)
-            self.optimizer.__setstate__({'param_groups': new_optimizer.param_groups})
+            for pg in self._get_new_optimizer_params_groups():
+                self.optimizer.add_param_group(pg)
 
         self._post_prepare_model()
 
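Rather than constructing a new optimizer of the same type and copying its param_groups back, the existing optimizer is now kept as-is and only the groups returned by the quantizer are appended. A rough illustration of the torch.optim.Optimizer.add_param_group semantics this relies on, with SGD standing in for whatever optimizer was passed in:

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 4)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)

    # A parameter introduced after the optimizer was created, e.g. PACT's clip_val.
    clip_val = nn.Parameter(torch.tensor(8.0))

    # prepare_model() now simply appends each group returned by the quantizer hook.
    optimizer.add_param_group({'params': [clip_val], 'weight_decay': 0.0})

    # The original group is untouched; hyper-parameters not specified in the new
    # group (lr, momentum, ...) are filled in from optimizer.defaults.
    assert len(optimizer.param_groups) == 2
    assert optimizer.param_groups[1]['lr'] == 0.01
    assert optimizer.param_groups[1]['weight_decay'] == 0.0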
@@ -308,19 +309,18 @@ class Quantizer(object):
                 # For container we call recursively
                 self._pre_process_container(module, full_name + '.')
 
-    def _get_updated_optimizer_params_groups(self):
+    def _get_new_optimizer_params_groups(self):
         """
-        Returns a list of model parameter groups and optimizer hyper-parameter overrides,
-        as expected by the __init__ function of torch.optim.Optimizer.
-        This is called after all model changes were made in prepare_model, in case an Optimizer instance was
-        passed to __init__.
+        If the quantizer adds new trainable parameters to the model, this function should return a list of one
+        or more parameter groups pertaining to those new parameters. Each parameter group should be a dict in
+        the format expected by torch.optim.Optimizer.
+        For details, see https://pytorch.org/docs/stable/optim.html#per-parameter-options
 
         Subclasses which add parameters to the model should override as needed.
 
         :return: List of parameter groups
         """
-        # Default implementation - just return all model parameters as one group
-        return [{'params': self.model.parameters()}]
+        return list()
 
     def _post_prepare_model(self):
         pass
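For a subclass that does add trainable parameters, the contract in the updated docstring amounts to returning per-parameter-options dicts covering only those new parameters. A hypothetical example (the 'scale' parameter name and the learning rate below are invented for illustration):

    from distiller.quantization.quantizer import Quantizer

    class MyScaleQuantizer(Quantizer):  # hypothetical subclass, not part of Distiller
        def _get_new_optimizer_params_groups(self):
            # Return groups only for parameters this quantizer added in prepare_model();
            # an empty list (the base-class default) leaves the optimizer unchanged.
            new_params = [p for n, p in self.model.named_parameters() if n.endswith('scale')]
            return [{'params': new_params, 'lr': 1e-4}] if new_params else []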