Unverified commit 4b6b5b19, authored by Guy Jacob, committed by GitHub

QAT: Better handling of optimizer and of creation of fp32 weights copy (#399)

* Create the FP32 weights copy such that the actual tensor being learned
  stays the same
* This way the optimizer doesn't have to be re-created; we only need to
  add parameter groups if the algorithm requires them (e.g. PACT) (see the
  sketch below)
* This also means we no longer care about pre-existing parameter groups,
  as opposed to the previous implementation, which assumed a single
  existing group
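
The core of the change is that the FP32 backup is registered as the *same* `nn.Parameter` object that an existing optimizer already holds, so the optimizer keeps tracking the learned tensor and never has to be rebuilt. Below is a minimal standalone sketch of that mechanism (plain PyTorch, not the Distiller implementation; the `float_` backup name and the `clip_val` parameter are used purely for illustration):

```python
import torch
import torch.nn as nn

# Minimal sketch (standalone PyTorch, not the Distiller implementation):
# re-registering the *same* nn.Parameter under a backup name keeps an
# optimizer that was created beforehand pointing at the learned tensor.
m = nn.Linear(4, 4)
opt = torch.optim.SGD(m.parameters(), lr=0.1)

w = m.weight                                       # Parameter the optimizer already tracks
m.register_parameter('float_weight', w)            # same object, new (backup) name
delattr(m, 'weight')                               # remove the original parameter entry
m.register_buffer('weight', torch.zeros_like(w))   # quantized values will live here

assert m.float_weight is w                         # identity preserved, nothing was copied
assert any(p is w for g in opt.param_groups for p in g['params'])  # optimizer still valid

# Parameters that a quantizer introduces (e.g. PACT's 'clip_val') are simply
# appended to the existing optimizer as an extra parameter group:
clip_val = nn.Parameter(torch.tensor(6.0))
opt.add_param_group({'params': [clip_val], 'weight_decay': 0.0})
```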
parent 3710c464
@@ -192,9 +192,8 @@ class PACTQuantizer(Quantizer):
     # In PACT, LearnedClippedLinearQuantization is used for activation, which contains a learnt 'clip_val' parameter
     # We optimize this value separately from the main model parameters
-    def _get_updated_optimizer_params_groups(self):
-        base_group = {'params': [param for name, param in self.model.named_parameters() if 'clip_val' not in name]}
+    def _get_new_optimizer_params_groups(self):
         clip_val_group = {'params': [param for name, param in self.model.named_parameters() if 'clip_val' in name]}
         if self.act_clip_decay is not None:
             clip_val_group['weight_decay'] = self.act_clip_decay
-        return [base_group, clip_val_group]
+        return [clip_val_group]
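
Why the `base_group` goes away: the main model parameters are already present in the optimizer the user created, and `torch.optim.Optimizer.add_param_group` raises a `ValueError` if a parameter already belongs to an existing group, so only the genuinely new `clip_val` parameters are returned. A rough, self-contained illustration (model, optimizer and values are invented for the example; `weight_decay` here plays the role of `act_clip_decay`):

```python
import torch
import torch.nn as nn

# Illustration of the per-parameter options format, and of why only *new*
# parameters may be returned: add_param_group() rejects parameters that the
# optimizer already tracks. All names and values here are illustrative.
model = nn.Sequential(nn.Linear(4, 4))
model.register_parameter('clip_val', nn.Parameter(torch.tensor(6.0)))
optimizer = torch.optim.SGD((p for n, p in model.named_parameters() if 'clip_val' not in n), lr=0.1)

clip_val_group = {'params': [p for n, p in model.named_parameters() if 'clip_val' in n],
                  'weight_decay': 1e-4}     # stands in for act_clip_decay
optimizer.add_param_group(clip_val_group)   # OK: these parameters are new to the optimizer

base_group = {'params': [p for n, p in model.named_parameters() if 'clip_val' not in n]}
try:
    optimizer.add_param_group(base_group)   # the weights are already in a param group
except ValueError:
    pass                                    # this is why base_group is no longer returned
```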
@@ -37,12 +37,14 @@ def has_bias(module):
 def hack_float_backup_parameter(module, name, num_bits):
     try:
-        data = dict(module.named_parameters())[name].data
+        param = dict(module.named_parameters())[name]
+        param_id = id(param)
     except KeyError:
         raise ValueError('Module has no Parameter named ' + name)
-    module.register_parameter(FP_BKP_PREFIX + name, nn.Parameter(data))
+    module.register_parameter(FP_BKP_PREFIX + name, param)
+    assert id(getattr(module, FP_BKP_PREFIX + name)) == param_id
     delattr(module, name)
-    module.register_buffer(name, torch.zeros_like(data))
+    module.register_buffer(name, torch.zeros_like(param))
 
     first = False
     if not hasattr(module, 'repr_mod'):
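
The `id()` bookkeeping and the `assert` guard the property the whole change relies on: `register_parameter` must store the original `Parameter` object itself, whereas the previous `nn.Parameter(data)` call created a new object around the same storage. A tiny standalone check of that distinction (not Distiller code; names are illustrative):

```python
import torch
import torch.nn as nn

# Standalone check: wrapping .data in a new Parameter yields a different
# object, while re-registering keeps the original object (the property the
# assert in hack_float_backup_parameter verifies).
m = nn.Linear(2, 2)
orig = m.weight

copy_style = nn.Parameter(orig.data)          # old approach: a distinct Parameter object
assert copy_style is not orig

m.register_parameter('float_weight', orig)    # new approach: same object under a new name
assert id(getattr(m, 'float_weight')) == id(orig)
```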
@@ -240,9 +242,8 @@ class Quantizer(object):
         # If an optimizer was passed, assume we need to update it
         if self.optimizer:
-            optimizer_type = type(self.optimizer)
-            new_optimizer = optimizer_type(self._get_updated_optimizer_params_groups(), **self.optimizer.defaults)
-            self.optimizer.__setstate__({'param_groups': new_optimizer.param_groups})
+            for pg in self._get_new_optimizer_params_groups():
+                self.optimizer.add_param_group(pg)
 
         self._post_prepare_model()
@@ -308,19 +309,18 @@ class Quantizer(object):
                 # For container we call recursively
                 self._pre_process_container(module, full_name + '.')
 
-    def _get_updated_optimizer_params_groups(self):
+    def _get_new_optimizer_params_groups(self):
         """
-        Returns a list of model parameter groups and optimizer hyper-parameter overrides,
-        as expected by the __init__ function of torch.optim.Optimizer.
-        This is called after all model changes were made in prepare_model, in case an Optimizer instance was
-        passed to __init__.
+        If the quantizer adds new trainable parameters to the model, this function should return a list of one
+        or more parameter groups pertaining to those parameters. Each parameter group is expected to be a dict
+        in the format expected by torch.optim.Optimizer.
+        For details, see https://pytorch.org/docs/stable/optim.html#per-parameter-options
 
         Subclasses which add parameters to the model should override as needed.
 
         :return: List of parameter groups
         """
-        # Default implementation - just return all model parameters as one group
-        return [{'params': self.model.parameters()}]
+        return list()
 
     def _post_prepare_model(self):
         pass
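
To show the contract, here is a hypothetical override: a subclass that adds trainable parameters returns only the groups for those parameters, and `prepare_model` (see the hunk above) feeds each returned group to `optimizer.add_param_group`. `MyQuantizer` and the `learned_scale` parameter name are invented for illustration:

```python
# Hypothetical subclass sketch; 'learned_scale' is an invented parameter name,
# and Quantizer refers to the class shown in the diff above.
class MyQuantizer(Quantizer):
    def _get_new_optimizer_params_groups(self):
        scale_params = [p for n, p in self.model.named_parameters() if 'learned_scale' in n]
        # One group, with an optional per-group hyper-parameter override:
        return [{'params': scale_params, 'weight_decay': 0.0}]
```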