From 556492a0a5434a6bcb21fb0301f8b68501ad2d21 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Thu, 7 Mar 2019 00:47:50 +0200
Subject: [PATCH] Greedy pruning: change the pruning order of ResNet layers

Group the conv1 parameters of ResNet-20 and ResNet-56 by stage (layer1, then
layer2, then layer3) instead of interleaving the stages.
---
 distiller/pruning/greedy_filter_pruning.py | 29 ++++++++++++----------
 1 file changed, 16 insertions(+), 13 deletions(-)

diff --git a/distiller/pruning/greedy_filter_pruning.py b/distiller/pruning/greedy_filter_pruning.py
index 574b80d..60202ae 100755
--- a/distiller/pruning/greedy_filter_pruning.py
+++ b/distiller/pruning/greedy_filter_pruning.py
@@ -222,6 +222,7 @@ def record_network_details(fields):
         writer.writerow(fields)
 
 
+# Temporary hack: hard-coded, stage-ordered lists of the conv parameters to prune for each ResNet variant.
 resnet50_params = ["module.layer1.0.conv1.weight", "module.layer1.0.conv2.weight",
                    "module.layer1.1.conv1.weight", "module.layer1.1.conv2.weight",
                    "module.layer1.2.conv1.weight", "module.layer1.2.conv2.weight",
@@ -239,19 +240,21 @@ resnet50_params = ["module.layer1.0.conv1.weight", "module.layer1.0.conv2.weight
                    "module.layer4.1.conv1.weight", "module.layer4.1.conv2.weight",
                    "module.layer4.2.conv1.weight", "module.layer4.2.conv2.weight"]
 
-resnet20_params = ["module.layer1.0.conv1.weight", "module.layer2.0.conv1.weight", "module.layer3.0.conv1.weight",
-                   "module.layer1.1.conv1.weight", "module.layer2.1.conv1.weight", "module.layer3.1.conv1.weight",
-                   "module.layer1.2.conv1.weight", "module.layer2.2.conv1.weight", "module.layer3.2.conv1.weight"]
-
-resnet56_params = [ "module.layer1.0.conv1.weight", "module.layer2.0.conv1.weight", "module.layer3.0.conv1.weight",
-                    "module.layer1.1.conv1.weight", "module.layer2.1.conv1.weight", "module.layer3.1.conv1.weight",
-                    "module.layer1.2.conv1.weight", "module.layer2.2.conv1.weight", "module.layer3.2.conv1.weight",
-                    "module.layer1.3.conv1.weight", "module.layer2.3.conv1.weight", "module.layer3.3.conv1.weight",
-                    "module.layer1.4.conv1.weight", "module.layer2.4.conv1.weight", "module.layer3.4.conv1.weight",
-                    "module.layer1.5.conv1.weight", "module.layer2.5.conv1.weight", "module.layer3.5.conv1.weight",
-                    "module.layer1.6.conv1.weight", "module.layer2.6.conv1.weight", "module.layer3.6.conv1.weight",
-                    "module.layer1.7.conv1.weight", "module.layer2.7.conv1.weight", "module.layer3.7.conv1.weight",
-                    "module.layer1.8.conv1.weight", "module.layer2.8.conv1.weight", "module.layer3.8.conv1.weight"]
+resnet20_params = ["module.layer1.0.conv1.weight", "module.layer1.1.conv1.weight", "module.layer1.2.conv1.weight",
+                   "module.layer2.0.conv1.weight", "module.layer2.1.conv1.weight", "module.layer2.2.conv1.weight",
+                   "module.layer3.0.conv1.weight", "module.layer3.1.conv1.weight", "module.layer3.2.conv1.weight"]
+
+resnet56_params = ["module.layer1.0.conv1.weight", "module.layer1.1.conv1.weight", "module.layer1.2.conv1.weight",
+                   "module.layer1.3.conv1.weight", "module.layer1.4.conv1.weight", "module.layer1.5.conv1.weight",
+                   "module.layer1.6.conv1.weight", "module.layer1.7.conv1.weight", "module.layer1.8.conv1.weight",
+
+                   "module.layer2.0.conv1.weight", "module.layer2.1.conv1.weight", "module.layer2.2.conv1.weight",
+                   "module.layer2.3.conv1.weight", "module.layer2.4.conv1.weight", "module.layer2.5.conv1.weight",
+                   "module.layer2.6.conv1.weight", "module.layer2.7.conv1.weight", "module.layer2.8.conv1.weight",
+
+                   "module.layer3.0.conv1.weight", "module.layer3.1.conv1.weight", "module.layer3.2.conv1.weight",
+                   "module.layer3.3.conv1.weight", "module.layer3.4.conv1.weight", "module.layer3.5.conv1.weight",
+                   "module.layer3.6.conv1.weight", "module.layer3.7.conv1.weight", "module.layer3.8.conv1.weight"]
 
 
 def greedy_pruner(pruned_model, app_args, fraction_to_prune, pruning_step, test_fn, train_fn):
-- 
GitLab
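
A minimal sketch (not part of the patch), assuming a torchvision-style ResNet
wrapped in nn.DataParallel so that parameter names look like
"module.layer<stage>.<block>.conv1.weight": the hard-coded, stage-ordered lists
above could instead be derived from the model itself. The helper name
ordered_conv1_params is hypothetical.

    import re

    def ordered_conv1_params(model):
        """Collect conv1 weight names, grouped by stage and then by block index."""
        pattern = re.compile(r"module\.layer(\d+)\.(\d+)\.conv1\.weight")
        found = []
        for name, _ in model.named_parameters():
            m = pattern.fullmatch(name)
            if m:
                stage, block = int(m.group(1)), int(m.group(2))
                found.append((stage, block, name))
        # Sort by stage first, then block, mirroring the ordering of
        # resnet20_params / resnet56_params in the patch above.
        return [name for _, _, name in sorted(found)]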