diff --git a/distiller/pruning/greedy_filter_pruning.py b/distiller/pruning/greedy_filter_pruning.py index 574b80db69a83cc974575a525c3561bbefb740cd..60202ae16327bde20fdb3052478946f60f91429a 100755 --- a/distiller/pruning/greedy_filter_pruning.py +++ b/distiller/pruning/greedy_filter_pruning.py @@ -222,6 +222,7 @@ def record_network_details(fields): writer.writerow(fields) +# This is a temporary hack! resnet50_params = ["module.layer1.0.conv1.weight", "module.layer1.0.conv2.weight", "module.layer1.1.conv1.weight", "module.layer1.1.conv2.weight", "module.layer1.2.conv1.weight", "module.layer1.2.conv2.weight", @@ -239,19 +240,21 @@ resnet50_params = ["module.layer1.0.conv1.weight", "module.layer1.0.conv2.weight "module.layer4.1.conv1.weight", "module.layer4.1.conv2.weight", "module.layer4.2.conv1.weight", "module.layer4.2.conv2.weight"] -resnet20_params = ["module.layer1.0.conv1.weight", "module.layer2.0.conv1.weight", "module.layer3.0.conv1.weight", - "module.layer1.1.conv1.weight", "module.layer2.1.conv1.weight", "module.layer3.1.conv1.weight", - "module.layer1.2.conv1.weight", "module.layer2.2.conv1.weight", "module.layer3.2.conv1.weight"] - -resnet56_params = [ "module.layer1.0.conv1.weight", "module.layer2.0.conv1.weight", "module.layer3.0.conv1.weight", - "module.layer1.1.conv1.weight", "module.layer2.1.conv1.weight", "module.layer3.1.conv1.weight", - "module.layer1.2.conv1.weight", "module.layer2.2.conv1.weight", "module.layer3.2.conv1.weight", - "module.layer1.3.conv1.weight", "module.layer2.3.conv1.weight", "module.layer3.3.conv1.weight", - "module.layer1.4.conv1.weight", "module.layer2.4.conv1.weight", "module.layer3.4.conv1.weight", - "module.layer1.5.conv1.weight", "module.layer2.5.conv1.weight", "module.layer3.5.conv1.weight", - "module.layer1.6.conv1.weight", "module.layer2.6.conv1.weight", "module.layer3.6.conv1.weight", - "module.layer1.7.conv1.weight", "module.layer2.7.conv1.weight", "module.layer3.7.conv1.weight", - "module.layer1.8.conv1.weight", "module.layer2.8.conv1.weight", "module.layer3.8.conv1.weight"] +resnet20_params = ["module.layer1.0.conv1.weight", "module.layer1.1.conv1.weight", "module.layer1.2.conv1.weight", + "module.layer2.0.conv1.weight", "module.layer2.1.conv1.weight", "module.layer2.2.conv1.weight", + "module.layer3.0.conv1.weight", "module.layer3.1.conv1.weight", "module.layer3.2.conv1.weight"] + +resnet56_params = ["module.layer1.0.conv1.weight", "module.layer1.1.conv1.weight", "module.layer1.2.conv1.weight", + "module.layer1.3.conv1.weight", "module.layer1.4.conv1.weight", "module.layer1.5.conv1.weight", + "module.layer1.6.conv1.weight", "module.layer1.7.conv1.weight", "module.layer1.8.conv1.weight", + + "module.layer2.0.conv1.weight", "module.layer2.1.conv1.weight", "module.layer2.2.conv1.weight", + "module.layer2.3.conv1.weight", "module.layer2.4.conv1.weight", "module.layer2.5.conv1.weight", + "module.layer2.6.conv1.weight", "module.layer2.7.conv1.weight", "module.layer2.8.conv1.weight", + + "module.layer3.0.conv1.weight", "module.layer3.1.conv1.weight", "module.layer3.2.conv1.weight", + "module.layer3.3.conv1.weight", "module.layer3.4.conv1.weight", "module.layer3.5.conv1.weight", + "module.layer3.6.conv1.weight", "module.layer3.7.conv1.weight", "module.layer3.8.conv1.weight"] def greedy_pruner(pruned_model, app_args, fraction_to_prune, pruning_step, test_fn, train_fn):