Skip to content
Snippets Groups Projects
Commit 994e58d0 authored by inner's avatar inner Committed by Neta Zmora
Browse files

Fix pruner base class bug. (#131)

parent 24381169
No related branches found
No related tags found
No related merge requests found
...@@ -36,5 +36,5 @@ def threshold_model(model, threshold): ...@@ -36,5 +36,5 @@ def threshold_model(model, threshold):
""" """
for name, p in model.named_parameters(): for name, p in model.named_parameters():
if 'weight' in name: if 'weight' in name:
mask = distiller.threshold_mask(param.data, threshold) mask = distiller.threshold_mask(p.data, threshold)
p.data = p.data.mul_(mask) p.data = p.data.mul_(mask)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment