From b135701bed6392a55d6d23f04f81495a7c66799b Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Thu, 28 Feb 2019 14:36:19 +0200
Subject: [PATCH] Added DropFilter as a separate regularizer

Move the experimental Conv2dWithMask module and the replace_conv2d() helper out of
distiller/models/__init__.py and into a new DropFilterRegularizer in
distiller/regularization/drop_filter.py, and drop the commented-out filter-drop
probability annealing experiment from compress_classifier.py.
---
 distiller/models/__init__.py                  | 69 -------------------
 distiller/regularization/drop_filter.py       | 70 ++++++++++++++++
 .../compress_classifier.py                    | 12 ----
 3 files changed, 70 insertions(+), 81 deletions(-)
 create mode 100755 distiller/regularization/drop_filter.py

diff --git a/distiller/models/__init__.py b/distiller/models/__init__.py
index 43e28ab..49f5e7b 100755
--- a/distiller/models/__init__.py
+++ b/distiller/models/__init__.py
@@ -42,74 +42,6 @@ CIFAR10_MODEL_NAMES = sorted(name for name in cifar10_models.__dict__
 ALL_MODEL_NAMES = sorted(map(lambda s: s.lower(), set(IMAGENET_MODEL_NAMES + CIFAR10_MODEL_NAMES)))
 
 
-import torch.nn as nn
-import torch.nn.functional as F
-import numpy as np
-
-class Conv2dWithMask(nn.Conv2d):
-    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
-                 padding=0, dilation=1, groups=1, bias=True):
-
-        super(Conv2dWithMask, self).__init__(
-            in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
-            padding=padding, dilation=dilation, groups=groups, bias=bias)
-
-        self.test_mask = None
-        self.p_mask = 1.0
-        self.frequency = 16
-
-    def forward(self, input):
-        if self.training:
-            #prob = torch.distributions.binomial.Binomial(total_count=1, probs=[0.9]*self.out_channels)
-            #mask = prob.sample()
-            self.frequency -= 1
-            if self.frequency == 0:
-                sample = np.random.binomial(n=1, p=self.p_mask, size=self.out_channels)
-                param = self.weight
-                l1norm = param.detach().view(param.size(0), -1).norm(p=1, dim=1)
-                mask = torch.tensor(sample)
-                #print(mask.sum().item())
-
-                mask = mask.expand(param.size(1) * param.size(2) * param.size(3), param.size(0)).t().contiguous()
-                mask = mask.view(self.weight.shape).to(param.device)
-                mask = mask.type(param.type())
-                #print(mask.sum().item())
-                #pruning_factor = self.p_mask
-                masked_weights = self.weight * mask
-                masked_l1norm = masked_weights.detach().view(param.size(0), -1).norm(p=1, dim=1)
-                pruning_factor = (masked_l1norm.sum() / l1norm.sum()).item()
-                # print(pruning_factor)
-                pruning_factor = max(0.2, pruning_factor)
-                weight = masked_weights / pruning_factor
-                self.frequency = 16
-            else:
-                weight = self.weight
-            #self.test_mask = mask
-        # elif self.mask is not None:
-        #     mask = self.mask.view(-1, 1, 1, 1)
-        #     mask = mask.expand(self.weight.shape)
-        #     mask = mask.to(self.weight.device)
-        #     weight = self.weight * mask
-        else:
-            weight = self.weight# * self.test_mask
-
-        return F.conv2d(input, weight, self.bias, self.stride,
-                        self.padding, self.dilation, self.groups)
-
-
-# replaces all conv2d layers in target`s model with 'Conv2dWithMask'
-def replace_conv2d(container):
-    for name, module in container.named_children(): #for name, module in model.named_modules():
-        if (isinstance(module, nn.Conv2d)):
-            print("replacing: ", name)
-            new_module = Conv2dWithMask(in_channels=module.in_channels,
-                                        out_channels=module.out_channels,
-                                        kernel_size=module.kernel_size, padding=module.padding,
-                                        stride=module.stride, bias=module.bias)
-            setattr(container, name, new_module)
-        replace_conv2d(module)
-
-
 def create_model(pretrained, dataset, arch, parallel=True, device_ids=None):
     """Create a pytorch model based on the model architecture and dataset
 
@@ -157,5 +89,4 @@ def create_model(pretrained, dataset, arch, parallel=True, device_ids=None):
     else:
         device = 'cpu'
 
-    #replace_conv2d(model)
     return model.to(device)
diff --git a/distiller/regularization/drop_filter.py b/distiller/regularization/drop_filter.py
new file mode 100755
index 0000000..a909335
--- /dev/null
+++ b/distiller/regularization/drop_filter.py
@@ -0,0 +1,70 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from .regularizer import _Regularizer
+
+
+class Conv2dWithMask(nn.Conv2d):
+    """nn.Conv2d variant implementing DropFilter: on one out of every `frequency` training
+    forward passes, each output filter is kept with probability `p_mask`; dropped filters
+    are zeroed for that pass and the survivors are rescaled to roughly preserve the L1 norm."""
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1, bias=True):
+
+        super(Conv2dWithMask, self).__init__(
+            in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
+            padding=padding, dilation=dilation, groups=groups, bias=bias)
+
+        self.test_mask = None
+        self.p_mask = 1.0    # probability of keeping each filter (can be annealed externally)
+        self.frequency = 16  # countdown: a mask is applied on one of every 16 training passes
+
+    def forward(self, input):
+        if self.training:
+            self.frequency -= 1
+            if self.frequency == 0:
+                # Draw an independent keep (1) / drop (0) decision for each output filter
+                sample = np.random.binomial(n=1, p=self.p_mask, size=self.out_channels)
+                param = self.weight
+                l1norm = param.detach().view(param.size(0), -1).norm(p=1, dim=1)
+                # Broadcast the per-filter mask to the shape of the weight tensor
+                mask = torch.tensor(sample)
+                mask = mask.expand(param.size(1) * param.size(2) * param.size(3), param.size(0)).t().contiguous()
+                mask = mask.view(self.weight.shape).to(param.device)
+                mask = mask.type(param.type())
+                masked_weights = self.weight * mask
+                # Scale up the surviving weights by dividing by the masked/unmasked L1-norm ratio (clipped below at 0.2)
+                masked_l1norm = masked_weights.detach().view(param.size(0), -1).norm(p=1, dim=1)
+                pruning_factor = (masked_l1norm.sum() / l1norm.sum()).item()
+                pruning_factor = max(0.2, pruning_factor)
+                weight = masked_weights / pruning_factor
+                self.frequency = 16
+            else:
+                weight = self.weight
+        else:
+            weight = self.weight
+        return F.conv2d(input, weight, self.bias, self.stride,
+                        self.padding, self.dilation, self.groups)
+
+
+# Recursively replaces every nn.Conv2d layer in the target model with a Conv2dWithMask.
+# Note that the replacement modules are freshly initialized; the original weights are not copied.
+def replace_conv2d(container):
+    for name, module in container.named_children():
+        if isinstance(module, nn.Conv2d):
+            print("replacing: ", name)
+            new_module = Conv2dWithMask(in_channels=module.in_channels,
+                                        out_channels=module.out_channels,
+                                        kernel_size=module.kernel_size, padding=module.padding,
+                                        stride=module.stride, dilation=module.dilation,
+                                        groups=module.groups, bias=module.bias is not None)
+            setattr(container, name, new_module)
+        replace_conv2d(module)
+
+
+class DropFilterRegularizer(_Regularizer):
+    """Enables DropFilter by replacing all of the model's nn.Conv2d layers with Conv2dWithMask."""
+    def __init__(self, name, model, reg_regims, threshold_criteria=None):
+        super().__init__(name, model, reg_regims, threshold_criteria)
+        replace_conv2d(model)
diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py
index e40cf3c..d0092e1 100755
--- a/examples/classifier_compression/compress_classifier.py
+++ b/examples/classifier_compression/compress_classifier.py
@@ -257,18 +257,6 @@ def main():
 
         # Train for one epoch
         with collectors_context(activations_collectors["train"]) as collectors:
-
-
-
-            # if epoch > 15:
-            #     for name, module in model.named_modules():
-            #         if (isinstance(module, nn.Conv2d)):
-            #             module.p_mask = max(0.6, module.p_mask-0.005)
-            #             #module.p_mask = max(0.5, module.p_mask-0.02)
-            #             msglogger.info("setting filter drop probability to %.2f", module.p_mask)
-
-
-
             train(train_loader, model, criterion, optimizer, epoch, compression_scheduler,
                   loggers=[tflogger, pylogger], args=args)
             distiller.log_weights_sparsity(model, epoch, loggers=[tflogger, pylogger])
-- 
GitLab
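
Usage sketch: a minimal, hypothetical example of constructing the new regularizer directly on a
model. The patch does not wire DropFilterRegularizer into Distiller's YAML scheduler, so direct
construction is assumed here; the arch name 'resnet20_cifar' and the empty reg_regims dict are
illustrative only. The p_mask annealing loop mirrors the commented-out experiment removed from
compress_classifier.py.

from distiller.models import create_model
from distiller.regularization.drop_filter import Conv2dWithMask, DropFilterRegularizer

# Build a model; constructing the regularizer swaps its nn.Conv2d layers for Conv2dWithMask.
model = create_model(pretrained=False, dataset='cifar10', arch='resnet20_cifar', parallel=False)
regularizer = DropFilterRegularizer('drop_filter', model, reg_regims={})

# During training, each Conv2dWithMask periodically zeroes a random subset of its filters and
# rescales the survivors. The keep probability p_mask can be annealed downward between epochs
# (increasing the effective drop rate), e.g.:
for module in model.modules():
    if isinstance(module, Conv2dWithMask):
        module.p_mask = max(0.6, module.p_mask - 0.005)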