diff --git a/distiller/scheduler.py b/distiller/scheduler.py
index 4bd04d465b75d2c59488606bc99d981ef0fd68d2..e6bcd551d07c05642bc6090284e36fa94b9dc0f2 100755
--- a/distiller/scheduler.py
+++ b/distiller/scheduler.py
@@ -224,7 +224,7 @@ class CompressionScheduler(object):
 
         Currently the scheduler state is comprised only of the set of pruning masks.
 
-        Arguments:
+        Args:
             state_dict (dict): scheduler state. Should be an object returned
                 from a call to :meth:`state_dict`. It is a dictionary of parameter
                 names (keys) and parameter masks (values).
@@ -251,6 +251,21 @@ class CompressionScheduler(object):
             if masker.mask is not None:
                 masker.mask = masker.mask.to(device)
 
+    def init_from_masks_dict(self, masks_dict, normalize_dataparallel_keys=False):
+        """This is a convenience function to initialize a CompressionScheduler from a dictionary
+
+        Args:
+            masks_dict (list): A dictionary formatted as {parameter_name: 4D mask tensor}
+            normalize_dataparallel_keys (bool): indicates if we should convert the keys from
+                DataParallel format.
+        """
+        for name, mask in self.zeros_mask_dict.items():
+            if name not in masks_dict:
+                masks_dict[name] = None
+        state = {'masks_dict': masks_dict}
+
+        self.load_state_dict(state, normalize_dataparallel_keys)
+
     @staticmethod
     def verify_policy_loss(policy_loss):
         if not isinstance(policy_loss, PolicyLoss):