diff --git a/distiller/scheduler.py b/distiller/scheduler.py index fe4e4637d2845d73772983cb4f80e1811ae4c0f8..d052d51bc6a563ccad9427dc72d3a7dc2c844c83 100755 --- a/distiller/scheduler.py +++ b/distiller/scheduler.py @@ -53,6 +53,15 @@ class ParameterMasker(object): return tensor +def create_model_masks_dict(model): + """A convinience function to create a dictionary of paramter maskers for a model""" + zeros_mask_dict = {} + for name, param in model.named_parameters(): + masker = ParameterMasker(name) + zeros_mask_dict[name] = masker + return zeros_mask_dict + + class CompressionScheduler(object): """Responsible for scheduling pruning and masking parameters.