From c26428783bbbee67330630309e4b8fab12bfa2a9 Mon Sep 17 00:00:00 2001
From: Neta Zmora <31280975+nzmora@users.noreply.github.com>
Date: Mon, 25 Mar 2019 14:39:10 +0200
Subject: [PATCH] Splicing Pruner: simplify the splicing code

Rewrote the splicing logic with simpler code.
---
 distiller/pruning/splicing_pruner.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/distiller/pruning/splicing_pruner.py b/distiller/pruning/splicing_pruner.py
index dfe7e25..25d57db 100755
--- a/distiller/pruning/splicing_pruner.py
+++ b/distiller/pruning/splicing_pruner.py
@@ -81,8 +81,10 @@ class SplicingPruner(_ParameterPruner):
         # We followed the example implementation from Yiwen Guo in Caffe, and used the
         # weight tensor's starting mean and std.
         # This is very similar to the initialization performed by distiller.SensitivityPruner.
-
-        masked_weights = param.mul(zeros_mask_dict[param_name].mask).abs()
-        a = masked_weights.ge(threshold_low)
-        b = a & zeros_mask_dict[param_name].mask.type(torch.cuda.ByteTensor)
-        zeros_mask_dict[param_name].mask = (b | masked_weights.ge(threshold_hi)).type(torch.cuda.FloatTensor)
+    
+        mask = zeros_mask_dict[param_name].mask
+        zeros, ones = torch.tensor([0]).type(mask.type()), torch.tensor([1]).type(mask.type())
+        weights_abs = param.abs()
+        new_mask = torch.where(threshold_low > weights_abs, zeros, mask)
+        new_mask = torch.where(threshold_hi <= weights_abs, ones, new_mask)
+        zeros_mask_dict[param_name].mask = new_mask
-- 
GitLab