diff --git a/distiller/scheduler.py b/distiller/scheduler.py
index e6bcd551d07c05642bc6090284e36fa94b9dc0f2..a0b7490991db206448b2b733a7d24874369203f6 100755
--- a/distiller/scheduler.py
+++ b/distiller/scheduler.py
@@ -24,58 +24,10 @@ import torch
 from .quantization.quantizer import FP_BKP_PREFIX
 from .policy import PolicyLoss, LossComponent
 from .utils import model_device, normalize_module_name
-msglogger = logging.getLogger()
-import distiller
-
-
-class ParameterMasker(object):
-    def __init__(self, param_name):
-        msglogger.debug('Created masker for parameter {0}'.format(param_name))
-        self.mask = None                # Mask lazily initialized by pruners
-        self.param_name = param_name    # For debug/logging purposes
-        self.is_regularization_mask = False
-        self.use_double_copies = False
-        self.mask_on_forward_only = False
-        self.unmasked_copy = None
-        self.backward_hook_handle = None
-
-    def apply_mask(self, parameter):
-        """Apply a mask on the weights tensor (parameter)."""
-        if self.mask is None:
-            msglogger.debug('No mask for parameter {0}'.format(self.param_name))
-            return
-        if self.use_double_copies:
-            self.unmasked_copy = parameter.clone().detach()
-        self.mask_tensor(parameter)
-        if self.is_regularization_mask:
-            self.mask = None
-        return parameter
-
-    def mask_tensor(self, tensor):
-        if self.mask is not None:
-            tensor.data.mul_(self.mask)
-
-    def mask_gradient(self, gradient):
-        if self.mask is not None:
-            return gradient.mul(self.mask)
-
-    def revert_weights(self, parameter):
-        if not self.use_double_copies or self.unmasked_copy is None:
-            msglogger.debug('Parameter {0} does not maintain double copies'.format(self.param_name))
-            return
-        #msglogger.info('Parameter {} before {}'.format(self.param_name, distiller.sparsity(parameter)))
-        parameter.data.copy_(self.unmasked_copy)
-        #msglogger.info('Parameter {} after {}'.format(self.param_name, distiller.sparsity(parameter)))
-        self.unmasked_copy = None
-
-
-def create_model_masks_dict(model):
-    """A convinience function to create a dictionary of paramter maskers for a model"""
-    zeros_mask_dict = {}
-    for name, param in model.named_parameters():
-        masker = ParameterMasker(name)
-        zeros_mask_dict[name] = masker
-    return zeros_mask_dict
+__all__ = ["CompressionScheduler", "ParameterMasker", "create_model_masks_dict"]
+msglogger = logging.getLogger()
 
 
 class CompressionScheduler(object):
@@ -87,11 +39,8 @@ class CompressionScheduler(object):
         self.device = device
         self.policies = {}
         self.sched_metadata = {}
-        self.zeros_mask_dict = {}
-        # Create the masker objects and place them in a dictionary indexed by the parameter name
-        for name, param in self.model.named_parameters():
-            masker = ParameterMasker(name)
-            self.zeros_mask_dict[name] = masker
+        self.zeros_mask_dict = create_model_masks_dict(model)
 
     def add_policy(self, policy, epochs=None, starting_epoch=0, ending_epoch=1, frequency=1):
         """Add a new policy to the schedule.
@@ -278,3 +227,54 @@ class CompressionScheduler(object):
         raise TypeError("Expected an instance of " + LossComponent.__name__ +
                         " or a list of such instances")
     return curr_loss_components
+
+
+class ParameterMasker(object):
+    """A ParameterMasker can mask a parameter tensor or a gradients tensor.
+
+    It is used when pruning DNN weights.
+    """
+    def __init__(self, param_name):
+        self.mask = None                # Mask lazily initialized by pruners
+        self.param_name = param_name    # For debug/logging purposes
+        self.is_regularization_mask = False
+        self.use_double_copies = False
+        self.mask_on_forward_only = False
+        self.unmasked_copy = None
+        self.backward_hook_handle = None
+
+    def apply_mask(self, parameter):
+        """Apply a mask on the weights tensor (parameter)."""
+        if self.mask is None:
+            msglogger.debug('No mask for parameter {0}'.format(self.param_name))
+            return
+        if self.use_double_copies:
+            self.unmasked_copy = parameter.clone().detach()
+        self.mask_tensor(parameter)
+        if self.is_regularization_mask:
+            self.mask = None
+        return parameter
+
+    def mask_tensor(self, tensor):
+        if self.mask is not None:
+            tensor.data.mul_(self.mask)
+
+    def mask_gradient(self, gradient):
+        if self.mask is not None:
+            return gradient.mul(self.mask)
+
+    def revert_weights(self, parameter):
+        if not self.use_double_copies or self.unmasked_copy is None:
+            msglogger.debug('Parameter {0} does not maintain double copies'.format(self.param_name))
+            return
+        parameter.data.copy_(self.unmasked_copy)
+        self.unmasked_copy = None
+
+
+def create_model_masks_dict(model):
+    """A convenience function to create a dictionary of parameter maskers for a model"""
+    zeros_mask_dict = {}
+    for name, param in model.named_parameters():
+        masker = ParameterMasker(name)
+        zeros_mask_dict[name] = masker
+    return zeros_mask_dict
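Below is a minimal usage sketch for reviewers (not part of the diff itself). It assumes a toy nn.Linear model on CPU; the every-other-element mask, the variable names, and the CompressionScheduler(model, device=...) call are illustrative assumptions inferred from the signatures visible in this change, not an official example from the repository.

# Usage sketch only -- assumes distiller is importable and that
# CompressionScheduler accepts (model, device=...), as the hunks above suggest.
import torch
import torch.nn as nn

from distiller.scheduler import CompressionScheduler, create_model_masks_dict

model = nn.Linear(4, 2)                      # toy model; any nn.Module works

# Build maskers keyed by parameter name, as CompressionScheduler.__init__ now does.
zeros_mask_dict = create_model_masks_dict(model)

# Attach a hypothetical mask that zeroes every other weight, then apply it.
masker = zeros_mask_dict['weight']
masker.mask = torch.ones_like(model.weight)
masker.mask.view(-1)[::2] = 0
masker.apply_mask(model.weight)              # in place: weight.data *= mask
masked_grad = masker.mask_gradient(torch.randn_like(model.weight))

# The scheduler now builds the same dictionary via create_model_masks_dict().
scheduler = CompressionScheduler(model, device=torch.device('cpu'))
assert set(scheduler.zeros_mask_dict) == {'weight', 'bias'}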