From c9abf1f9f8575fa29a01fdd3e56b686bcb3cbb95 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Wed, 25 Jul 2018 13:04:13 +0300
Subject: [PATCH] create_model_masks_dict: added create_model_masks_dict

This is a convinence function used by customers of the scheduler,
and might change location in the future.
---
 distiller/scheduler.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/distiller/scheduler.py b/distiller/scheduler.py
index fe4e463..d052d51 100755
--- a/distiller/scheduler.py
+++ b/distiller/scheduler.py
@@ -53,6 +53,15 @@ class ParameterMasker(object):
         return tensor
 
 
+def create_model_masks_dict(model):
+    """A convinience function to create a dictionary of paramter maskers for a model"""
+    zeros_mask_dict = {}
+    for name, param in model.named_parameters():
+        masker = ParameterMasker(name)
+        zeros_mask_dict[name] = masker
+    return zeros_mask_dict
+
+
 class CompressionScheduler(object):
     """Responsible for scheduling pruning and masking parameters.
 
-- 
GitLab