From c26428783bbbee67330630309e4b8fab12bfa2a9 Mon Sep 17 00:00:00 2001 From: Neta Zmora <31280975+nzmora@users.noreply.github.com> Date: Mon, 25 Mar 2019 14:39:10 +0200 Subject: [PATCH] Splicing Pruner: simplify the splicing code Rewrote the splicing logic with simpler code. --- distiller/pruning/splicing_pruner.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/distiller/pruning/splicing_pruner.py b/distiller/pruning/splicing_pruner.py index dfe7e25..25d57db 100755 --- a/distiller/pruning/splicing_pruner.py +++ b/distiller/pruning/splicing_pruner.py @@ -81,8 +81,10 @@ class SplicingPruner(_ParameterPruner): # We followed the example implementation from Yiwen Guo in Caffe, and used the # weight tensor's starting mean and std. # This is very similar to the initialization performed by distiller.SensitivityPruner. - - masked_weights = param.mul(zeros_mask_dict[param_name].mask).abs() - a = masked_weights.ge(threshold_low) - b = a & zeros_mask_dict[param_name].mask.type(torch.cuda.ByteTensor) - zeros_mask_dict[param_name].mask = (b | masked_weights.ge(threshold_hi)).type(torch.cuda.FloatTensor) + + mask = zeros_mask_dict[param_name].mask + zeros, ones = torch.tensor([0]).type(mask.type()), torch.tensor([1]).type(mask.type()) + weights_abs = param.abs() + new_mask = torch.where(threshold_low > weights_abs, zeros, mask) + new_mask = torch.where(threshold_hi <= weights_abs, ones, new_mask) + zeros_mask_dict[param_name].mask = new_mask -- GitLab