From 2058990a7f5946af6a992c07817169b490629d69 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Thu, 31 Oct 2019 13:23:10 +0200
Subject: [PATCH] distiller/utils.py: add param_name_2_module_name

- add function `param_name_2_module_name` to help convert from a
module's .weight or .bias parameter tensor name, to a fully-qualified
module name

- remove dead code
---
 distiller/utils.py | 33 ++-------------------------------
 1 file changed, 2 insertions(+), 31 deletions(-)

diff --git a/distiller/utils.py b/distiller/utils.py
index 3b342b5..f50adca 100755
--- a/distiller/utils.py
+++ b/distiller/utils.py
@@ -740,34 +740,5 @@ def convert_tensors_recursively_to(val, *args, **kwargs):
     return val
 
 
-# TODO: Is this needed?
-def model_setattr(model, attr_name, val, register=False):
-    """
-    Sets attribute of a model, through the entire hierarchy.
-    Args:
-        model (nn.Module): the model.
-        attr_name (str): the attribute name as shown by model.named_<parameters/modules/buffers>()
-        val: the value of the attribute
-        register (bool): if True - register_buffer(val) if val is a torch.Tensor and
-          register_parameter(val) if it's an nn.Parameter.
-    """
-    def split_name(name):
-        if '.' in name:
-            return name.rsplit('.', 1)
-        else:
-            return '', name
-    modules_dict = OrderedDict(model.named_modules())
-    lowest_depth_container_name, lowest_depth_attr_name = split_name(attr_name)
-    while lowest_depth_container_name and lowest_depth_container_name not in modules_dict:
-        container_name, attr = split_name(lowest_depth_container_name)
-        lowest_depth_container_name = container_name
-        lowest_depth_attr_name = '%s%s' % (attr, lowest_depth_attr_name)
-    lowest_depth_container = modules_dict[lowest_depth_container_name]  # type: nn.Module
-
-    if register and torch.is_tensor(val):
-        if isinstance(val, nn.Parameter):
-            lowest_depth_container.register_parameter(lowest_depth_attr_name, val)
-        else:
-            lowest_depth_container.register_buffer(lowest_depth_attr_name, val)
-    else:
-        setattr(lowest_depth_container, lowest_depth_attr_name, val)
+def param_name_2_module_name(param_name):
+    return '.'.join(param_name.split('.')[:-1])
\ No newline at end of file
-- 
GitLab