diff --git a/distiller/regularization/regularizer.py b/distiller/regularization/regularizer.py index 182e3846f97bae02fec09b250df72753b9117276..ec4666e2e49abbab105ccae4643af58bf84654bd 100755 --- a/distiller/regularization/regularizer.py +++ b/distiller/regularization/regularizer.py @@ -13,16 +13,21 @@ # See the License for the specific language governing permissions and # limitations under the License. # - +import torch +import torch.nn as nn EPSILON = 1e-8 + class _Regularizer(object): def __init__(self, name, model, reg_regims, threshold_criteria): """Regularization base class. Args: - reg_regims: regularization regiment. A dictionary of - reg_regims[<param-name>] = [ lambda, structure-type] + name (str): the name of the regularizer. + model (nn.Module): the model on which to apply regularization. + reg_regims (dict[str, float or tuple[float, Any]]): regularization regiment. A dictionary of + reg_regims[<param-name>] = [ lambda[, additional_configuration]] + threshold_criteria (str): the criterion for which to calculate the threshold. """ self.name = name self.model = model @@ -30,7 +35,24 @@ class _Regularizer(object): self.threshold_criteria = threshold_criteria def loss(self, param, param_name, regularizer_loss, zeros_mask_dict): + """ + Applies the regularization loss onto regularization loss. + Args: + param (nn.Parameter): the parameter on which to calculate the regularization + param_name (str): the name of the parameter relative to top level module. + regularizer_loss (torch.Tensor): the previous regularization loss calculated, + zeros_mask_dict (dict): the masks configuration. + Returns: + torch.Tensor: regularization_loss after applying the additional loss from current parameter. + """ raise NotImplementedError def threshold(self, param, param_name, zeros_mask_dict): + """ + Calculates the threshold for pruning. + Args: + param (nn.Parameter): the parameter on which to calculate the regularization + param_name (str): the name of the parameter relative to top level module. + zeros_mask_dict (dict): the masks configuration. + """ raise NotImplementedError