diff --git a/distiller/scheduler.py b/distiller/scheduler.py
index e6bcd551d07c05642bc6090284e36fa94b9dc0f2..a0b7490991db206448b2b733a7d24874369203f6 100755
--- a/distiller/scheduler.py
+++ b/distiller/scheduler.py
@@ -24,58 +24,10 @@ import torch
 from .quantization.quantizer import FP_BKP_PREFIX
 from .policy import PolicyLoss, LossComponent
 from .utils import model_device, normalize_module_name
-msglogger = logging.getLogger()
-import distiller
-
-
-class ParameterMasker(object):
-    def __init__(self, param_name):
-        msglogger.debug('Created masker for parameter {0}'.format(param_name))
-        self.mask = None                # Mask lazily initialized by pruners
-        self.param_name = param_name    # For debug/logging purposes
-        self.is_regularization_mask = False
-        self.use_double_copies = False
-        self.mask_on_forward_only = False
-        self.unmasked_copy = None
-        self.backward_hook_handle = None
-
-    def apply_mask(self, parameter):
-        """Apply a mask on the weights tensor (parameter)."""
-        if self.mask is None:
-            msglogger.debug('No mask for parameter {0}'.format(self.param_name))
-            return
-        if self.use_double_copies:
-            self.unmasked_copy = parameter.clone().detach()
-        self.mask_tensor(parameter)
-        if self.is_regularization_mask:
-            self.mask = None
-        return parameter
-
-    def mask_tensor(self, tensor):
-        if self.mask is not None:
-            tensor.data.mul_(self.mask)
-
-    def mask_gradient(self, gradient):
-        if self.mask is not None:
-            return gradient.mul(self.mask)
-
-    def revert_weights(self, parameter):
-        if not self.use_double_copies or self.unmasked_copy is None:
-            msglogger.debug('Parameter {0} does not maintain double copies'.format(self.param_name))
-            return
-        #msglogger.info('Parameter {} before {}'.format(self.param_name, distiller.sparsity(parameter)))
-        parameter.data.copy_(self.unmasked_copy)
-        #msglogger.info('Parameter {} after {}'.format(self.param_name, distiller.sparsity(parameter)))
-        self.unmasked_copy = None
 
 
-def create_model_masks_dict(model):
-    """A convinience function to create a dictionary of paramter maskers for a model"""
-    zeros_mask_dict = {}
-    for name, param in model.named_parameters():
-        masker = ParameterMasker(name)
-        zeros_mask_dict[name] = masker
-    return zeros_mask_dict
+__all__ = ["CompressionScheduler", "ParameterMasker", "create_model_masks_dict"]
+msglogger = logging.getLogger()
 
 
 class CompressionScheduler(object):
@@ -87,11 +39,8 @@ class CompressionScheduler(object):
         self.device = device
         self.policies = {}
         self.sched_metadata = {}
-        self.zeros_mask_dict = {}
         # Create the masker objects and place them in a dictionary indexed by the parameter name
-        for name, param in self.model.named_parameters():
-            masker = ParameterMasker(name)
-            self.zeros_mask_dict[name] = masker
+        self.zeros_mask_dict = create_model_masks_dict(model)
 
     def add_policy(self, policy, epochs=None, starting_epoch=0, ending_epoch=1, frequency=1):
         """Add a new policy to the schedule.
@@ -278,3 +227,54 @@ class CompressionScheduler(object):
             raise TypeError("Expected an instance of " + LossComponent.__name__ +
                             " or a list of such instances")
         return curr_loss_components
+
+
+class ParameterMasker(object):
+    """A ParameterMasker can mask a parameter tensor or a gradients tensor.
+
+    It is used when pruning DNN weights.
+    """
+    def __init__(self, param_name):
+        self.mask = None                # Mask lazily initialized by pruners
+        self.param_name = param_name    # For debug/logging purposes
+        self.is_regularization_mask = False
+        self.use_double_copies = False
+        self.mask_on_forward_only = False
+        self.unmasked_copy = None
+        self.backward_hook_handle = None
+
+    def apply_mask(self, parameter):
+        """Apply a mask on the weights tensor (parameter)."""
+        if self.mask is None:
+            msglogger.debug('No mask for parameter {0}'.format(self.param_name))
+            return
+        if self.use_double_copies:
+            self.unmasked_copy = parameter.clone().detach()
+        self.mask_tensor(parameter)
+        if self.is_regularization_mask:
+            self.mask = None
+        return parameter
+
+    def mask_tensor(self, tensor):
+        if self.mask is not None:
+            tensor.data.mul_(self.mask)
+
+    def mask_gradient(self, gradient):
+        if self.mask is not None:
+            return gradient.mul(self.mask)
+
+    def revert_weights(self, parameter):
+        if not self.use_double_copies or self.unmasked_copy is None:
+            msglogger.debug('Parameter {0} does not maintain double copies'.format(self.param_name))
+            return
+        parameter.data.copy_(self.unmasked_copy)
+        self.unmasked_copy = None
+
+
+def create_model_masks_dict(model):
+    """A convenience function to create a dictionary of parameter maskers for a model"""
+    zeros_mask_dict = {}
+    for name, param in model.named_parameters():
+        masker = ParameterMasker(name)
+        zeros_mask_dict[name] = masker
+    return zeros_mask_dict
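
For reference, a minimal usage sketch of the relocated helpers, assuming the module stays importable as distiller.scheduler; the toy model and the 0.5 magnitude threshold are illustrative only, since in practice a pruner policy computes each mask:

import torch.nn as nn

from distiller.scheduler import create_model_masks_dict

# Toy model; any nn.Module with named parameters works.
net = nn.Linear(4, 2)

# One ParameterMasker per named parameter, keyed by the parameter's name.
zeros_mask_dict = create_model_masks_dict(net)

# A pruner policy would normally compute each mask; a hand-rolled magnitude
# threshold stands in for it here.
for name, param in net.named_parameters():
    masker = zeros_mask_dict[name]
    masker.mask = (param.abs() > 0.5).float()
    masker.apply_mask(param)  # zeroes masked weights in place via tensor.data.mul_(mask)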