From c9abf1f9f8575fa29a01fdd3e56b686bcb3cbb95 Mon Sep 17 00:00:00 2001 From: Neta Zmora <neta.zmora@intel.com> Date: Wed, 25 Jul 2018 13:04:13 +0300 Subject: [PATCH] create_model_masks_dict: added create_model_masks_dict This is a convinence function used by customers of the scheduler, and might change location in the future. --- distiller/scheduler.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/distiller/scheduler.py b/distiller/scheduler.py index fe4e463..d052d51 100755 --- a/distiller/scheduler.py +++ b/distiller/scheduler.py @@ -53,6 +53,15 @@ class ParameterMasker(object): return tensor +def create_model_masks_dict(model): + """A convinience function to create a dictionary of paramter maskers for a model""" + zeros_mask_dict = {} + for name, param in model.named_parameters(): + masker = ParameterMasker(name) + zeros_mask_dict[name] = masker + return zeros_mask_dict + + class CompressionScheduler(object): """Responsible for scheduling pruning and masking parameters. -- GitLab