diff --git a/distiller/thinning.py b/distiller/thinning.py
index c92f880a714f174451bc22edac544889cec0367a..94758369c3b051e3c0ae5e0a5fea3e0c29e6c20e 100755
--- a/distiller/thinning.py
+++ b/distiller/thinning.py
@@ -84,7 +84,14 @@ def get_normalized_recipe(recipe):
 
 
 def param_name_2_layer_name(param_name):
-    return param_name[:-len('weights')]
+    """Convert a weights tensor's name to the name of the layer using the tensor.
+
+    By convention, PyTorch modules name their weights parameters as self.weight
+    (see for example: torch.nn.modules.conv), which means that their fully-qualified
+    name when enumerating a model's parameters is the module's name followed by '.weight'.
+    We exploit this convention to convert a weights tensor name to the fully-qualified
+    module name."""
+    return param_name[:-len('.weight')]
 
 
 def directives_equal(d1, d2):
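
A minimal sketch (not part of the patch), using a hypothetical toy model, of the naming convention the helper relies on: each module's weight parameter appears in named_parameters() as '<module name>.weight', so stripping that suffix recovers the module's fully-qualified name.

import torch.nn as nn

def param_name_2_layer_name(param_name):
    # Strip the trailing '.weight' to recover the owning module's name.
    return param_name[:-len('.weight')]

# Toy model for illustration only: nn.Sequential names its children '0', '1', '2', ...
model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), nn.Linear(16, 10))
for name, _ in model.named_parameters():
    if name.endswith('.weight'):
        print(name, '->', param_name_2_layer_name(name))   # e.g. '0.weight' -> '0'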