diff --git a/distiller/utils.py b/distiller/utils.py index 63127ff8b818637cf97a7ea7de685217c4dc151f..3c5c6b3c23b389a2c25e87be0c275fec173b8bdc 100755 --- a/distiller/utils.py +++ b/distiller/utils.py @@ -752,5 +752,37 @@ def convert_tensors_recursively_to(val, *args, **kwargs): return val +def model_setattr(model, attr_name, val, register=False): + """ + Sets attribute of a model, through the entire hierarchy. + Args: + model (nn.Module): the model. + attr_name (str): the attribute name as shown by model.named_<parameters/modules/buffers>() + val: the value of the attribute + register (bool): if True - register_buffer(val) if val is a torch.Tensor and + register_parameter(val) if it's an nn.Parameter. + """ + def split_name(name): + if '.' in name: + return name.rsplit('.', 1) + else: + return '', name + modules_dict = OrderedDict(model.named_modules()) + lowest_depth_container_name, lowest_depth_attr_name = split_name(attr_name) + while lowest_depth_container_name and lowest_depth_container_name not in modules_dict: + container_name, attr = split_name(lowest_depth_container_name) + lowest_depth_container_name = container_name + lowest_depth_attr_name = '%s%s' % (attr, lowest_depth_attr_name) + lowest_depth_container = modules_dict[lowest_depth_container_name] # type: nn.Module + + if register and torch.is_tensor(val): + if isinstance(val, nn.Parameter): + lowest_depth_container.register_parameter(lowest_depth_attr_name, val) + else: + lowest_depth_container.register_buffer(lowest_depth_attr_name, val) + else: + setattr(lowest_depth_container, lowest_depth_attr_name, val) + + def param_name_2_module_name(param_name): return '.'.join(param_name.split('.')[:-1]) \ No newline at end of file