Commit 4a331d73 authored by Neta Zmora

scheduler.py: code refactoring (non-functional changes)

parent 961bfc89
@@ -24,58 +24,10 @@ import torch
 from .quantization.quantizer import FP_BKP_PREFIX
 from .policy import PolicyLoss, LossComponent
 from .utils import model_device, normalize_module_name
-msglogger = logging.getLogger()
-import distiller
-
-
-class ParameterMasker(object):
-    def __init__(self, param_name):
-        msglogger.debug('Created masker for parameter {0}'.format(param_name))
-        self.mask = None                 # Mask lazily initialized by pruners
-        self.param_name = param_name     # For debug/logging purposes
-        self.is_regularization_mask = False
-        self.use_double_copies = False
-        self.mask_on_forward_only = False
-        self.unmasked_copy = None
-        self.backward_hook_handle = None
-
-    def apply_mask(self, parameter):
-        """Apply a mask on the weights tensor (parameter)."""
-        if self.mask is None:
-            msglogger.debug('No mask for parameter {0}'.format(self.param_name))
-            return
-        if self.use_double_copies:
-            self.unmasked_copy = parameter.clone().detach()
-        self.mask_tensor(parameter)
-        if self.is_regularization_mask:
-            self.mask = None
-        return parameter
-
-    def mask_tensor(self, tensor):
-        if self.mask is not None:
-            tensor.data.mul_(self.mask)
-
-    def mask_gradient(self, gradient):
-        if self.mask is not None:
-            return gradient.mul(self.mask)
-
-    def revert_weights(self, parameter):
-        if not self.use_double_copies or self.unmasked_copy is None:
-            msglogger.debug('Parameter {0} does not maintain double copies'.format(self.param_name))
-            return
-        #msglogger.info('Parameter {} before {}'.format(self.param_name, distiller.sparsity(parameter)))
-        parameter.data.copy_(self.unmasked_copy)
-        #msglogger.info('Parameter {} after {}'.format(self.param_name, distiller.sparsity(parameter)))
-        self.unmasked_copy = None
-
-
-def create_model_masks_dict(model):
-    """A convinience function to create a dictionary of paramter maskers for a model"""
-    zeros_mask_dict = {}
-    for name, param in model.named_parameters():
-        masker = ParameterMasker(name)
-        zeros_mask_dict[name] = masker
-    return zeros_mask_dict
+__all__ = ["CompressionScheduler", "ParameterMasker", "create_model_masks_dict"]
+msglogger = logging.getLogger()
 
 
 class CompressionScheduler(object):
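For context: the new __all__ line makes the module's public surface explicit, so a wildcard import binds only the three listed names. A minimal sketch of the effect in client code (the import path is assumed from this file's location in the distiller package; nothing below is part of the commit):

# Client code: a star-import now binds only the names listed in __all__.
from distiller.scheduler import *

scheduler_cls = CompressionScheduler   # exported via __all__
make_masks = create_model_masks_dict   # also exported
# msglogger is NOT bound here: it is module-level state, not public API.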
@@ -87,11 +39,8 @@ class CompressionScheduler(object):
         self.device = device
         self.policies = {}
         self.sched_metadata = {}
-        self.zeros_mask_dict = {}
         # Create the masker objects and place them in a dictionary indexed by the parameter name
-        for name, param in self.model.named_parameters():
-            masker = ParameterMasker(name)
-            self.zeros_mask_dict[name] = masker
+        self.zeros_mask_dict = create_model_masks_dict(model)
 
     def add_policy(self, policy, epochs=None, starting_epoch=0, ending_epoch=1, frequency=1):
         """Add a new policy to the schedule.
@@ -278,3 +227,54 @@ class CompressionScheduler(object):
             raise TypeError("Expected an instance of " + LossComponent.__name__ +
                             " or a list of such instances")
         return curr_loss_components
+
+
+class ParameterMasker(object):
+    """A ParameterMasker can mask a parameter tensor or a gradients tensor.
+    It is used when pruning DNN weights.
+    """
+    def __init__(self, param_name):
+        self.mask = None                 # Mask lazily initialized by pruners
+        self.param_name = param_name     # For debug/logging purposes
+        self.is_regularization_mask = False
+        self.use_double_copies = False
+        self.mask_on_forward_only = False
+        self.unmasked_copy = None
+        self.backward_hook_handle = None
+
+    def apply_mask(self, parameter):
+        """Apply a mask on the weights tensor (parameter)."""
+        if self.mask is None:
+            msglogger.debug('No mask for parameter {0}'.format(self.param_name))
+            return
+        if self.use_double_copies:
+            self.unmasked_copy = parameter.clone().detach()
+        self.mask_tensor(parameter)
+        if self.is_regularization_mask:
+            self.mask = None
+        return parameter
+
+    def mask_tensor(self, tensor):
+        if self.mask is not None:
+            tensor.data.mul_(self.mask)
+
+    def mask_gradient(self, gradient):
+        if self.mask is not None:
+            return gradient.mul(self.mask)
+
+    def revert_weights(self, parameter):
+        if not self.use_double_copies or self.unmasked_copy is None:
+            msglogger.debug('Parameter {0} does not maintain double copies'.format(self.param_name))
+            return
+        parameter.data.copy_(self.unmasked_copy)
+        self.unmasked_copy = None
+
+
+def create_model_masks_dict(model):
+    """A convenience function to create a dictionary of parameter maskers for a model"""
+    zeros_mask_dict = {}
+    for name, param in model.named_parameters():
+        masker = ParameterMasker(name)
+        zeros_mask_dict[name] = masker
+    return zeros_mask_dict
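Taken together, the relocated class and helper define the masking lifecycle a pruner drives: set a mask, apply it to the weights, optionally mask gradients, and (when double copies are enabled) revert. A minimal usage sketch; the parameter name, tensor shape, and 0.5 threshold are illustrative, and the import path is assumed from this file's location:

import torch
from distiller.scheduler import ParameterMasker

weight = torch.nn.Parameter(torch.randn(3, 3))

masker = ParameterMasker('fc.weight')
masker.mask = (weight.abs() > 0.5).float()  # a pruner would normally compute this
masker.use_double_copies = True             # keep an unmasked copy so revert_weights() works

masker.apply_mask(weight)                   # zeros the below-threshold weights in place
print((weight == 0).sum().item(), 'weights were pruned')

grad = torch.ones_like(weight)
masked_grad = masker.mask_gradient(grad)    # pruned positions get zero gradient

masker.revert_weights(weight)               # restore the pre-mask values from the saved copy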