diff --git a/distiller/scheduler.py b/distiller/scheduler.py index 4bd04d465b75d2c59488606bc99d981ef0fd68d2..e6bcd551d07c05642bc6090284e36fa94b9dc0f2 100755 --- a/distiller/scheduler.py +++ b/distiller/scheduler.py @@ -224,7 +224,7 @@ class CompressionScheduler(object): Currently the scheduler state is comprised only of the set of pruning masks. - Arguments: + Args: state_dict (dict): scheduler state. Should be an object returned from a call to :meth:`state_dict`. It is a dictionary of parameter names (keys) and parameter masks (values). @@ -251,6 +251,21 @@ class CompressionScheduler(object): if masker.mask is not None: masker.mask = masker.mask.to(device) + def init_from_masks_dict(self, masks_dict, normalize_dataparallel_keys=False): + """This is a convenience function to initialize a CompressionScheduler from a dictionary + + Args: + masks_dict (list): A dictionary formatted as {parameter_name: 4D mask tensor} + normalize_dataparallel_keys (bool): indicates if we should convert the keys from + DataParallel format. + """ + for name, mask in self.zeros_mask_dict.items(): + if name not in masks_dict: + masks_dict[name] = None + state = {'masks_dict': masks_dict} + + self.load_state_dict(state, normalize_dataparallel_keys) + @staticmethod def verify_policy_loss(policy_loss): if not isinstance(policy_loss, PolicyLoss):