From b6b6f8174ac3307fcd1c4339cb3926abe921ec64 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Sun, 17 Mar 2019 13:49:15 +0200
Subject: [PATCH] Add two stochastic pruners
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

BernoulliFilterPruner – drops each filter independently, with a Bernoulli
probability equal to the desired sparsity.
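
A minimal usage sketch (the weight name and sparsity value below are
hypothetical placeholders; the constructor signature follows the code in
this patch):

    from distiller.pruning.ranked_structures_pruner import BernoulliFilterPruner

    # Drop each filter of conv1 independently with probability 0.1.
    # Surviving filters are scaled by 1/keep_fraction so the expected
    # output magnitude is preserved, as in inverted dropout.
    pruner = BernoulliFilterPruner(name='drop_filter',
                                   group_type='Filters',
                                   desired_sparsity=0.1,
                                   weights=['module.conv1.weight'])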

RandomLevelStructureParameterPruner – draws the pruning level used by an
L1-norm structure pruner from a uniform distribution over a given range.
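
A similar hypothetical sketch for the random-level pruner (again, the
layer name and sparsity range are placeholders):

    from distiller.pruning.ranked_structures_pruner import (
        RandomLevelStructureParameterPruner)

    # Each invocation draws a pruning level uniformly from [0.2, 0.6],
    # then prunes the filters with the lowest mean L1-norm up to that
    # level.
    pruner = RandomLevelStructureParameterPruner(
        name='rand_level',
        group_type='Filters',
        sparsity_range=(0.2, 0.6),
        weights=['module.conv1.weight'])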
---
 distiller/pruning/ranked_structures_pruner.py | 65 +++++++++++++++++--
 1 file changed, 60 insertions(+), 5 deletions(-)

diff --git a/distiller/pruning/ranked_structures_pruner.py b/distiller/pruning/ranked_structures_pruner.py
index 8be38b5..2ee79cb 100755
--- a/distiller/pruning/ranked_structures_pruner.py
+++ b/distiller/pruning/ranked_structures_pruner.py
@@ -18,6 +18,7 @@ from functools import partial
 import numpy as np
 import logging
 import torch
+from random import uniform
 import distiller
 from .pruner import _ParameterPruner
 msglogger = logging.getLogger()
@@ -174,7 +175,7 @@ class LpRankedStructureParameterPruner(RankedStructureParameterPruner):
     @staticmethod
     def rank_and_prune_filters(fraction_to_prune, param, param_name,
                                zeros_mask_dict, model=None, binary_map=None, magnitude_fn=l1_magnitude):
-        assert param.dim() == 4, "This thresholding is only supported for 4D weights"
+        assert param.dim() == 4, "This pruning is only supported for 4D weights"
 
         threshold = None
         if binary_map is None:
@@ -217,7 +218,7 @@ class LpRankedStructureParameterPruner(RankedStructureParameterPruner):
         away from each other in memory, and this means poor use of caches and system memory.
         """
 
-        assert param.dim() == 2, "This thresholding is only supported for 2D weights"
+        assert param.dim() == 2, "This pruning is only supported for 2D weights"
         ROWS_DIM = 0
         THRESHOLD_DIM = 'Cols'
         rows_mags = magnitude_fn(param, dim=ROWS_DIM)
@@ -338,6 +339,23 @@ class L2RankedStructureParameterPruner(LpRankedStructureParameterPruner):
                          group_dependency, kwargs, magnitude_fn=l2_magnitude)
 
 
+class RandomLevelStructureParameterPruner(L1RankedStructureParameterPruner):
+    """Uses mean L1-norm to rank and prune structures, with a random pruning regimen.
+
+    This class sets the pruning level to a random value in the range sparsity_range,
+    and chooses which structures to remove using L1-norm ranking.
+    The idea is similar to DropFilter, but instead of randomly choosing filters,
+    we randomly choose a sparsity level and then prune according to magnitude.
+    """
+    def __init__(self, name, group_type, sparsity_range, weights,
+                 group_dependency=None, kwargs=None):
+        self.sparsity_range = sparsity_range
+        super().__init__(name, group_type, 0, weights, group_dependency, kwargs)
+
+    def fraction_to_prune(self, param_name):
+        return uniform(self.sparsity_range[0], self.sparsity_range[1])
+
+
 def mask_from_filter_order(filters_ordered_by_criterion, param, num_filters, binary_map):
     if binary_map is None:
         binary_map = torch.zeros(num_filters).cuda()
@@ -365,7 +383,7 @@ class ActivationRankedFilterPruner(RankedStructureParameterPruner):
         return binary_map
 
     def rank_and_prune_filters(self, fraction_to_prune, param, param_name, zeros_mask_dict, model, binary_map=None):
-        assert param.dim() == 4, "This thresholding is only supported for 4D weights"
+        assert param.dim() == 4, "This pruning is only supported for 4D weights"
 
         # Use the parameter name to locate the module that has the activation sparsity statistics
         fq_name = param_name.replace(".conv", ".relu")[:-len(".weight")]
@@ -438,7 +456,7 @@ class RandomRankedFilterPruner(RankedStructureParameterPruner):
         return binary_map
 
     def rank_and_prune_filters(self, fraction_to_prune, param, param_name, zeros_mask_dict, model, binary_map=None):
-        assert param.dim() == 4, "This thresholding is only supported for 4D weights"
+        assert param.dim() == 4, "This pruning is only supported for 4D weights"
         num_filters = param.size(0)
         num_filters_to_prune = int(fraction_to_prune * num_filters)
 
@@ -457,6 +475,43 @@ class RandomRankedFilterPruner(RankedStructureParameterPruner):
         return binary_map
 
 
+class BernoulliFilterPruner(RankedStructureParameterPruner):
+    """A Bernoulli probability for dropping each filter.
+
+    This is can be used for random filter-dropping algorithms (e.g. DropFilter)
+    """
+    def __init__(self, name, group_type, desired_sparsity, weights, group_dependency=None):
+        super().__init__(name, group_type, desired_sparsity, weights, group_dependency)
+
+    def prune_group(self, fraction_to_prune, param, param_name, zeros_mask_dict, model=None, binary_map=None):
+        if fraction_to_prune == 0:
+            return
+        binary_map = self.rank_and_prune_filters(fraction_to_prune, param, param_name,
+                                                 zeros_mask_dict, model, binary_map)
+        return binary_map
+
+    def rank_and_prune_filters(self, fraction_to_prune, param, param_name, zeros_mask_dict, model, binary_map=None):
+        assert param.dim() == 4, "This pruning is only supported for 4D weights"
+        num_filters = param.size(0)
+        num_filters_to_prune = int(fraction_to_prune * num_filters)
+
+        keep_prob = 1 - fraction_to_prune
+        if binary_map is None:
+            binary_map = torch.bernoulli(torch.as_tensor([keep_prob] * num_filters))
+        mask, _ = mask_from_filter_order(None, param, num_filters, binary_map)
+        # Move the mask to the parameter's device (binary_map was created on the CPU)
+        mask = mask.to(param.device)
+        # Compensate for dropped filters by scaling survivors by 1/keep_fraction (cf. inverted dropout)
+        keep_fraction = binary_map.sum() / num_filters
+        mask.div_(keep_fraction)
+        zeros_mask_dict[param_name].mask = mask
+        msglogger.debug("BernoulliFilterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
+                        param_name,
+                        distiller.sparsity_3D(zeros_mask_dict[param_name].mask),
+                        fraction_to_prune, num_filters_to_prune, num_filters)
+        return binary_map
+
+
 class GradientRankedFilterPruner(RankedStructureParameterPruner):
     """
     """
@@ -471,7 +526,7 @@ class GradientRankedFilterPruner(RankedStructureParameterPruner):
         return binary_map
 
     def rank_and_prune_filters(self, fraction_to_prune, param, param_name, zeros_mask_dict, model, binary_map=None):
-        assert param.dim() == 4, "This thresholding is only supported for 4D weights"
+        assert param.dim() == 4, "This pruning is only supported for 4D weights"
         if param.grad is None:
             msglogger.info("Skipping gradient pruning of %s because it does not have a gradient yet", param_name)
             return
-- 
GitLab