diff --git a/distiller/pruning/splicing_pruner.py b/distiller/pruning/splicing_pruner.py
index dfe7e254e8f2c03730b215f1fcccf89d10bd6a4b..25d57db22ddc0bafab8c34ec42599b43840db6b9 100755
--- a/distiller/pruning/splicing_pruner.py
+++ b/distiller/pruning/splicing_pruner.py
@@ -81,8 +81,10 @@ class SplicingPruner(_ParameterPruner):
         # We followed the example implementation from Yiwen Guo in Caffe, and used the
         # weight tensor's starting mean and std.
         # This is very similar to the initialization performed by distiller.SensitivityPruner.
-
-        masked_weights = param.mul(zeros_mask_dict[param_name].mask).abs()
-        a = masked_weights.ge(threshold_low)
-        b = a & zeros_mask_dict[param_name].mask.type(torch.cuda.ByteTensor)
-        zeros_mask_dict[param_name].mask = (b | masked_weights.ge(threshold_hi)).type(torch.cuda.FloatTensor)
+    
+        mask = zeros_mask_dict[param_name].mask
+        zeros, ones = torch.tensor([0]).type(mask.type()), torch.tensor([1]).type(mask.type())
+        weights_abs = param.abs()
+        new_mask = torch.where(threshold_low > weights_abs, zeros, mask)
+        new_mask = torch.where(threshold_hi <= weights_abs, ones, new_mask)
+        zeros_mask_dict[param_name].mask = new_mask