From fd544c7b108b97df16e0625b8c41f69f4cfbb2f3 Mon Sep 17 00:00:00 2001
From: Benoit Brummer <trougnouf@gmail.com>
Date: Thu, 8 Aug 2019 14:51:29 +0200
Subject: [PATCH] LpRankedStructureParameterPruner: work with 3D filters (ie
 Conv1d) (#348)

* LpRankedStructureParameterPruner: work with 3D filters (ie Conv1d)
---
 distiller/pruning/ranked_structures_pruner.py | 2 +-
 distiller/thresholding.py                     | 5 +++--
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/distiller/pruning/ranked_structures_pruner.py b/distiller/pruning/ranked_structures_pruner.py
index 7b47e8e..56454e7 100755
--- a/distiller/pruning/ranked_structures_pruner.py
+++ b/distiller/pruning/ranked_structures_pruner.py
@@ -220,7 +220,7 @@ class LpRankedStructureParameterPruner(_RankedStructureParameterPruner):
     def rank_and_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict, 
                                model=None, binary_map=None, magnitude_fn=l1_magnitude, 
                                noise=0.0, group_size=1, rounding_fn=math.floor):
-        assert param.dim() == 4, "This pruning is only supported for 4D weights"
+        assert param.dim() == 4 or param.dim() == 3, "This pruning is only supported for 3D and 4D weights"
 
         threshold = None
         threshold_type = 'L1' if magnitude_fn == l1_magnitude else 'L2'
diff --git a/distiller/thresholding.py b/distiller/thresholding.py
index 52761ad..55be4a1 100755
--- a/distiller/thresholding.py
+++ b/distiller/thresholding.py
@@ -19,6 +19,7 @@
 The code below supports fine-grained tensor thresholding and group-wise thresholding.
 """
 import torch
+import numpy as np
 
 
 def threshold_mask(weights, threshold):
@@ -85,7 +86,7 @@ def group_threshold_binary_map(param, group_type, threshold, threshold_criteria)
         return binary_map
 
     elif group_type == '3D' or group_type == 'Filters':
-        assert param.dim() == 4, "This thresholding is only supported for 4D weights"
+        assert param.dim() == 4 or param.dim() == 3, "This pruning is only supported for 3D and 4D weights"
         view_filters = param.view(param.size(0), -1)
         thresholds = torch.Tensor([threshold] * param.size(0)).to(param.device)
         binary_map = threshold_policy(view_filters, thresholds, threshold_criteria)
@@ -153,7 +154,7 @@ def group_threshold_mask(param, group_type, threshold, threshold_criteria, binar
     elif group_type == '3D' or group_type == 'Filters':
         if binary_map is None:
             binary_map = group_threshold_binary_map(param, group_type, threshold, threshold_criteria)
-        a = binary_map.expand(param.size(1) * param.size(2) * param.size(3), param.size(0)).t()
+        a = binary_map.expand(np.prod(param.shape[1:]), param.size(0)).t()
         return a.view(*param.shape), binary_map
 
     elif group_type == '4D':
-- 
GitLab