Skip to content
Snippets Groups Projects
Commit c9abf1f9 authored by Neta Zmora's avatar Neta Zmora
Browse files

create_model_masks_dict: added create_model_masks_dict

This is a convinence function used by customers of the scheduler,
and might change location in the future.
parent e46e196e
No related branches found
No related tags found
No related merge requests found
......@@ -53,6 +53,15 @@ class ParameterMasker(object):
return tensor
def create_model_masks_dict(model):
"""A convinience function to create a dictionary of paramter maskers for a model"""
zeros_mask_dict = {}
for name, param in model.named_parameters():
masker = ParameterMasker(name)
zeros_mask_dict[name] = masker
return zeros_mask_dict
class CompressionScheduler(object):
"""Responsible for scheduling pruning and masking parameters.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment