From 556492a0a5434a6bcb21fb0301f8b68501ad2d21 Mon Sep 17 00:00:00 2001 From: Neta Zmora <neta.zmora@intel.com> Date: Thu, 7 Mar 2019 00:47:50 +0200 Subject: [PATCH] Greedy pruning: change the pruning order of resnet layers --- distiller/pruning/greedy_filter_pruning.py | 29 ++++++++++++---------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/distiller/pruning/greedy_filter_pruning.py b/distiller/pruning/greedy_filter_pruning.py index 574b80d..60202ae 100755 --- a/distiller/pruning/greedy_filter_pruning.py +++ b/distiller/pruning/greedy_filter_pruning.py @@ -222,6 +222,7 @@ def record_network_details(fields): writer.writerow(fields) +# This is a temporary hack! resnet50_params = ["module.layer1.0.conv1.weight", "module.layer1.0.conv2.weight", "module.layer1.1.conv1.weight", "module.layer1.1.conv2.weight", "module.layer1.2.conv1.weight", "module.layer1.2.conv2.weight", @@ -239,19 +240,21 @@ resnet50_params = ["module.layer1.0.conv1.weight", "module.layer1.0.conv2.weight "module.layer4.1.conv1.weight", "module.layer4.1.conv2.weight", "module.layer4.2.conv1.weight", "module.layer4.2.conv2.weight"] -resnet20_params = ["module.layer1.0.conv1.weight", "module.layer2.0.conv1.weight", "module.layer3.0.conv1.weight", - "module.layer1.1.conv1.weight", "module.layer2.1.conv1.weight", "module.layer3.1.conv1.weight", - "module.layer1.2.conv1.weight", "module.layer2.2.conv1.weight", "module.layer3.2.conv1.weight"] - -resnet56_params = [ "module.layer1.0.conv1.weight", "module.layer2.0.conv1.weight", "module.layer3.0.conv1.weight", - "module.layer1.1.conv1.weight", "module.layer2.1.conv1.weight", "module.layer3.1.conv1.weight", - "module.layer1.2.conv1.weight", "module.layer2.2.conv1.weight", "module.layer3.2.conv1.weight", - "module.layer1.3.conv1.weight", "module.layer2.3.conv1.weight", "module.layer3.3.conv1.weight", - "module.layer1.4.conv1.weight", "module.layer2.4.conv1.weight", "module.layer3.4.conv1.weight", - "module.layer1.5.conv1.weight", "module.layer2.5.conv1.weight", "module.layer3.5.conv1.weight", - "module.layer1.6.conv1.weight", "module.layer2.6.conv1.weight", "module.layer3.6.conv1.weight", - "module.layer1.7.conv1.weight", "module.layer2.7.conv1.weight", "module.layer3.7.conv1.weight", - "module.layer1.8.conv1.weight", "module.layer2.8.conv1.weight", "module.layer3.8.conv1.weight"] +resnet20_params = ["module.layer1.0.conv1.weight", "module.layer1.1.conv1.weight", "module.layer1.2.conv1.weight", + "module.layer2.0.conv1.weight", "module.layer2.1.conv1.weight", "module.layer2.2.conv1.weight", + "module.layer3.0.conv1.weight", "module.layer3.1.conv1.weight", "module.layer3.2.conv1.weight"] + +resnet56_params = ["module.layer1.0.conv1.weight", "module.layer1.1.conv1.weight", "module.layer1.2.conv1.weight", + "module.layer1.3.conv1.weight", "module.layer1.4.conv1.weight", "module.layer1.5.conv1.weight", + "module.layer1.6.conv1.weight", "module.layer1.7.conv1.weight", "module.layer1.8.conv1.weight", + + "module.layer2.0.conv1.weight", "module.layer2.1.conv1.weight", "module.layer2.2.conv1.weight", + "module.layer2.3.conv1.weight", "module.layer2.4.conv1.weight", "module.layer2.5.conv1.weight", + "module.layer2.6.conv1.weight", "module.layer2.7.conv1.weight", "module.layer2.8.conv1.weight", + + "module.layer3.0.conv1.weight", "module.layer3.1.conv1.weight", "module.layer3.2.conv1.weight", + "module.layer3.3.conv1.weight", "module.layer3.4.conv1.weight", "module.layer3.5.conv1.weight", + "module.layer3.6.conv1.weight", "module.layer3.7.conv1.weight", "module.layer3.8.conv1.weight"] def greedy_pruner(pruned_model, app_args, fraction_to_prune, pruning_step, test_fn, train_fn): -- GitLab