Skip to content
Snippets Groups Projects
Commit 6d9afab5 authored by levzlotnik's avatar levzlotnik
Browse files

Added model_setattr - sets an parameter/buffer/module

of a model by name relative to the root of the model.
parent 0520ffaf
No related branches found
No related tags found
No related merge requests found
...@@ -752,5 +752,37 @@ def convert_tensors_recursively_to(val, *args, **kwargs): ...@@ -752,5 +752,37 @@ def convert_tensors_recursively_to(val, *args, **kwargs):
return val return val
def model_setattr(model, attr_name, val, register=False):
"""
Sets attribute of a model, through the entire hierarchy.
Args:
model (nn.Module): the model.
attr_name (str): the attribute name as shown by model.named_<parameters/modules/buffers>()
val: the value of the attribute
register (bool): if True - register_buffer(val) if val is a torch.Tensor and
register_parameter(val) if it's an nn.Parameter.
"""
def split_name(name):
if '.' in name:
return name.rsplit('.', 1)
else:
return '', name
modules_dict = OrderedDict(model.named_modules())
lowest_depth_container_name, lowest_depth_attr_name = split_name(attr_name)
while lowest_depth_container_name and lowest_depth_container_name not in modules_dict:
container_name, attr = split_name(lowest_depth_container_name)
lowest_depth_container_name = container_name
lowest_depth_attr_name = '%s%s' % (attr, lowest_depth_attr_name)
lowest_depth_container = modules_dict[lowest_depth_container_name] # type: nn.Module
if register and torch.is_tensor(val):
if isinstance(val, nn.Parameter):
lowest_depth_container.register_parameter(lowest_depth_attr_name, val)
else:
lowest_depth_container.register_buffer(lowest_depth_attr_name, val)
else:
setattr(lowest_depth_container, lowest_depth_attr_name, val)
def param_name_2_module_name(param_name): def param_name_2_module_name(param_name):
return '.'.join(param_name.split('.')[:-1]) return '.'.join(param_name.split('.')[:-1])
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment