From 994e58d095e4f9591f12f28e217704e9bc69778e Mon Sep 17 00:00:00 2001 From: inner <10429190+innerNULL@users.noreply.github.com> Date: Wed, 23 Jan 2019 04:41:45 +0800 Subject: [PATCH] Fix pruner base class bug. (#131) --- distiller/pruning/pruner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distiller/pruning/pruner.py b/distiller/pruning/pruner.py index a09f87d..1dec718 100755 --- a/distiller/pruning/pruner.py +++ b/distiller/pruning/pruner.py @@ -36,5 +36,5 @@ def threshold_model(model, threshold): """ for name, p in model.named_parameters(): if 'weight' in name: - mask = distiller.threshold_mask(param.data, threshold) + mask = distiller.threshold_mask(p.data, threshold) p.data = p.data.mul_(mask) -- GitLab