diff --git a/distiller/pruning/pruner.py b/distiller/pruning/pruner.py index a09f87d08ee614251b88fd6d6902f215c2404ba6..1dec7183ce394124d1d9d4d5b185ceee735de011 100755 --- a/distiller/pruning/pruner.py +++ b/distiller/pruning/pruner.py @@ -36,5 +36,5 @@ def threshold_model(model, threshold): """ for name, p in model.named_parameters(): if 'weight' in name: - mask = distiller.threshold_mask(param.data, threshold) + mask = distiller.threshold_mask(p.data, threshold) p.data = p.data.mul_(mask)