From 2058990a7f5946af6a992c07817169b490629d69 Mon Sep 17 00:00:00 2001 From: Neta Zmora <neta.zmora@intel.com> Date: Thu, 31 Oct 2019 13:23:10 +0200 Subject: [PATCH] distiller/utils.py: add param_name_2_module_name - add function `param_name_2_module_name` to help convert from a module's .weight or .bias parameter tensor name, to a fully-qualified module name - remove dead code --- distiller/utils.py | 33 ++------------------------------- 1 file changed, 2 insertions(+), 31 deletions(-) diff --git a/distiller/utils.py b/distiller/utils.py index 3b342b5..f50adca 100755 --- a/distiller/utils.py +++ b/distiller/utils.py @@ -740,34 +740,5 @@ def convert_tensors_recursively_to(val, *args, **kwargs): return val -# TODO: Is this needed? -def model_setattr(model, attr_name, val, register=False): - """ - Sets attribute of a model, through the entire hierarchy. - Args: - model (nn.Module): the model. - attr_name (str): the attribute name as shown by model.named_<parameters/modules/buffers>() - val: the value of the attribute - register (bool): if True - register_buffer(val) if val is a torch.Tensor and - register_parameter(val) if it's an nn.Parameter. - """ - def split_name(name): - if '.' in name: - return name.rsplit('.', 1) - else: - return '', name - modules_dict = OrderedDict(model.named_modules()) - lowest_depth_container_name, lowest_depth_attr_name = split_name(attr_name) - while lowest_depth_container_name and lowest_depth_container_name not in modules_dict: - container_name, attr = split_name(lowest_depth_container_name) - lowest_depth_container_name = container_name - lowest_depth_attr_name = '%s%s' % (attr, lowest_depth_attr_name) - lowest_depth_container = modules_dict[lowest_depth_container_name] # type: nn.Module - - if register and torch.is_tensor(val): - if isinstance(val, nn.Parameter): - lowest_depth_container.register_parameter(lowest_depth_attr_name, val) - else: - lowest_depth_container.register_buffer(lowest_depth_attr_name, val) - else: - setattr(lowest_depth_container, lowest_depth_attr_name, val) +def param_name_2_module_name(param_name): + return '.'.join(param_name.split('.')[:-1]) \ No newline at end of file -- GitLab