diff --git a/distiller/quantization/clipped_linear.py b/distiller/quantization/clipped_linear.py
index 5827554191f1d14424463f15387fc6e87e4fd15a..c7159ab7f205ae6cedcf43fa70c0be6b0fbf31ba 100644
--- a/distiller/quantization/clipped_linear.py
+++ b/distiller/quantization/clipped_linear.py
@@ -192,9 +192,8 @@ class PACTQuantizer(Quantizer):
 
     # In PACT, LearnedClippedLinearQuantization is used for activation, which contains a learnt 'clip_val' parameter
     # We optimize this value separately from the main model parameters
-    def _get_updated_optimizer_params_groups(self):
-        base_group = {'params': [param for name, param in self.model.named_parameters() if 'clip_val' not in name]}
+    def _get_new_optimizer_params_groups(self):
         clip_val_group = {'params': [param for name, param in self.model.named_parameters() if 'clip_val' in name]}
         if self.act_clip_decay is not None:
             clip_val_group['weight_decay'] = self.act_clip_decay
-        return [base_group, clip_val_group]
+        return [clip_val_group]
diff --git a/distiller/quantization/quantizer.py b/distiller/quantization/quantizer.py
index db152ff1db709e58a1ce86debd17d0d48ada0bb4..369271cb1ba0499dc035c92cb79bc18251dc3a86 100644
--- a/distiller/quantization/quantizer.py
+++ b/distiller/quantization/quantizer.py
@@ -37,12 +37,14 @@ def has_bias(module):
 
 def hack_float_backup_parameter(module, name, num_bits):
     try:
-        data = dict(module.named_parameters())[name].data
+        param = dict(module.named_parameters())[name]
+        param_id = id(param)
     except KeyError:
         raise ValueError('Module has no Parameter named ' + name)
-    module.register_parameter(FP_BKP_PREFIX + name, nn.Parameter(data))
+    module.register_parameter(FP_BKP_PREFIX + name, param)
+    assert id(getattr(module, FP_BKP_PREFIX + name)) == param_id
     delattr(module, name)
-    module.register_buffer(name, torch.zeros_like(data))
+    module.register_buffer(name, torch.zeros_like(param))
 
     first = False
     if not hasattr(module, 'repr_mod'):
@@ -240,9 +242,8 @@ class Quantizer(object):
 
         # If an optimizer was passed, assume we need to update it
         if self.optimizer:
-            optimizer_type = type(self.optimizer)
-            new_optimizer = optimizer_type(self._get_updated_optimizer_params_groups(), **self.optimizer.defaults)
-            self.optimizer.__setstate__({'param_groups': new_optimizer.param_groups})
+            for pg in self._get_new_optimizer_params_groups():
+                self.optimizer.add_param_group(pg)
 
         self._post_prepare_model()
 
@@ -308,19 +309,18 @@ class Quantizer(object):
                 # For container we call recursively
                 self._pre_process_container(module, full_name + '.')
 
-    def _get_updated_optimizer_params_groups(self):
+    def _get_new_optimizer_params_groups(self):
         """
-        Returns a list of model parameter groups and optimizer hyper-parameter overrides,
-        as expected by the __init__ function of torch.optim.Optimizer.
-        This is called after all model changes were made in prepare_model, in case an Optimizer instance was
-        passed to __init__.
+        If the quantizer adds new trainable parameters to the model, this function should return a list of one
+        or more parameter groups pertaining to these parameters. Each parameter group is expected to be a dict
+        in the format expected by torch.optim.Optimizer.
+        For details, see https://pytorch.org/docs/stable/optim.html#per-parameter-options
 
         Subclasses which add parameters to the model should override as needed.
 
         :return: List of parameter groups
         """
-        # Default implementation - just return all model parameters as one group
-        return [{'params': self.model.parameters()}]
+        return list()
 
     def _post_prepare_model(self):
         pass
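
The following is a minimal, self-contained sketch (plain PyTorch, not Distiller code) of the flow this patch switches to: the pre-existing optimizer keeps its original parameter group for the model weights, and only the newly created 'clip_val' parameters are appended as an extra group via torch.optim.Optimizer.add_param_group(), optionally with their own weight decay. The model, module index and hyper-parameter values below are hypothetical.

# Minimal sketch of the new optimizer-update flow (assumed setup, not Distiller code).
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=1e-4)

# Simulate the quantizer registering a learnable clipping value on an activation module,
# analogous to the 'clip_val' held by LearnedClippedLinearQuantization in PACT.
model[1].register_parameter('clip_val', nn.Parameter(torch.tensor(6.0)))

# Equivalent of the new PACTQuantizer._get_new_optimizer_params_groups(): return only the
# newly added parameters, with optional per-group hyper-parameter overrides.
act_clip_decay = 5e-4  # hypothetical value for the act_clip_decay setting
clip_val_group = {'params': [p for n, p in model.named_parameters() if 'clip_val' in n],
                  'weight_decay': act_clip_decay}

# Equivalent of the new Quantizer.prepare_model() behavior: append the group to the
# existing optimizer instead of rebuilding it, leaving the original group(s) untouched.
optimizer.add_param_group(clip_val_group)

print(len(optimizer.param_groups))                # 2: original group + clip_val group
print(optimizer.param_groups[1]['weight_decay'])  # 0.0005

Compared with reconstructing the optimizer and overwriting its param_groups via __setstate__, add_param_group() only appends, so any per-group settings already configured on the original optimizer are preserved.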