diff --git a/distiller/pruning/splicing_pruner.py b/distiller/pruning/splicing_pruner.py index dfe7e254e8f2c03730b215f1fcccf89d10bd6a4b..25d57db22ddc0bafab8c34ec42599b43840db6b9 100755 --- a/distiller/pruning/splicing_pruner.py +++ b/distiller/pruning/splicing_pruner.py @@ -81,8 +81,10 @@ class SplicingPruner(_ParameterPruner): # We followed the example implementation from Yiwen Guo in Caffe, and used the # weight tensor's starting mean and std. # This is very similar to the initialization performed by distiller.SensitivityPruner. - - masked_weights = param.mul(zeros_mask_dict[param_name].mask).abs() - a = masked_weights.ge(threshold_low) - b = a & zeros_mask_dict[param_name].mask.type(torch.cuda.ByteTensor) - zeros_mask_dict[param_name].mask = (b | masked_weights.ge(threshold_hi)).type(torch.cuda.FloatTensor) + + mask = zeros_mask_dict[param_name].mask + zeros, ones = torch.tensor([0]).type(mask.type()), torch.tensor([1]).type(mask.type()) + weights_abs = param.abs() + new_mask = torch.where(threshold_low > weights_abs, zeros, mask) + new_mask = torch.where(threshold_hi <= weights_abs, ones, new_mask) + zeros_mask_dict[param_name].mask = new_mask