From 6d9afab5528dff994816b85a4f3f7a8c5c2ffee7 Mon Sep 17 00:00:00 2001 From: levzlotnik <lev.zlotnik@intel.com> Date: Mon, 2 Dec 2019 15:41:15 +0200 Subject: [PATCH] Added model_setattr - sets an parameter/buffer/module of a model by name relative to the root of the model. --- distiller/utils.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/distiller/utils.py b/distiller/utils.py index 63127ff..3c5c6b3 100755 --- a/distiller/utils.py +++ b/distiller/utils.py @@ -752,5 +752,37 @@ def convert_tensors_recursively_to(val, *args, **kwargs): return val +def model_setattr(model, attr_name, val, register=False): + """ + Sets attribute of a model, through the entire hierarchy. + Args: + model (nn.Module): the model. + attr_name (str): the attribute name as shown by model.named_<parameters/modules/buffers>() + val: the value of the attribute + register (bool): if True - register_buffer(val) if val is a torch.Tensor and + register_parameter(val) if it's an nn.Parameter. + """ + def split_name(name): + if '.' in name: + return name.rsplit('.', 1) + else: + return '', name + modules_dict = OrderedDict(model.named_modules()) + lowest_depth_container_name, lowest_depth_attr_name = split_name(attr_name) + while lowest_depth_container_name and lowest_depth_container_name not in modules_dict: + container_name, attr = split_name(lowest_depth_container_name) + lowest_depth_container_name = container_name + lowest_depth_attr_name = '%s%s' % (attr, lowest_depth_attr_name) + lowest_depth_container = modules_dict[lowest_depth_container_name] # type: nn.Module + + if register and torch.is_tensor(val): + if isinstance(val, nn.Parameter): + lowest_depth_container.register_parameter(lowest_depth_attr_name, val) + else: + lowest_depth_container.register_buffer(lowest_depth_attr_name, val) + else: + setattr(lowest_depth_container, lowest_depth_attr_name, val) + + def param_name_2_module_name(param_name): return '.'.join(param_name.split('.')[:-1]) \ No newline at end of file -- GitLab