Skip to content
Snippets Groups Projects
Commit df4d39c9 authored by Neta Zmora's avatar Neta Zmora
Browse files

Greedy filter pruning: add mobilenet_v1 greedy pruning

This is discussed in issue #282, although there @Bowenwu1 was
interested in mobilenet for CIFAR, not ImageNet.

Note that the implementation of the Greedy filter pruning algorithm is
not generic (but it is easily extensible) and supports only a subset
of the models.

An example invocation:
time python3 compress_classifier.py --arch=mobilenet PATH-TO-IMAGENET_DS  --resume=mobilenet_sgd_68.848.pth.tar --greedy --greedy-target-density=0.5 --vs=0 -p=50 --lr=0.1 --gpu=0 --greedy-pruning-step=0.15 --effective-train-size=0.01
parent 33419dcf
No related branches found
No related tags found
No related merge requests found
......@@ -21,6 +21,11 @@ an open-source package that runs on many different hardware platforms.
TODO: this code requires refactoring as is makes assumptions about applicative layers (e.g. the names of certain
members of the application `args` variable) and this is both tight-coupling and reverse-dependency.
An example invocation:
$ time python3 compress_classifier.py --arch=mobilenet PATH-TO-IMAGENET_DS --resume=mobilenet_sgd_68.848.pth.tar
--greedy --greedy-target-density=0.5 --vs=0 -p=50 --lr=0.1 --gpu=0 --greedy-pruning-step=0.1 --effective-train-size=0.1
References:
[1] Structural Compression of Convolutional Neural Networks Based on Greedy Filter Pruning
Reza Abbasi-Asl, Bin Yu https://arxiv.org/abs/1705.07356
......@@ -256,6 +261,13 @@ resnet56_params = ["module.layer1.0.conv1.weight", "module.layer1.1.conv1.weight
"module.layer3.3.conv1.weight", "module.layer3.4.conv1.weight", "module.layer3.5.conv1.weight",
"module.layer3.6.conv1.weight", "module.layer3.7.conv1.weight", "module.layer3.8.conv1.weight"]
mobilenet_params = [
#"module.model.0.0.weight",
"module.model.1.3.weight", "module.model.2.3.weight",
"module.model.3.3.weight", "module.model.4.3.weight", "module.model.5.3.weight",
"module.model.6.3.weight", "module.model.7.3.weight", "module.model.8.3.weight",
"module.model.9.3.weight", "module.model.10.3.weight", "module.model.11.3.weight",
"module.model.12.3.weight", "module.model.13.3.weight"]
def greedy_pruner(pruned_model, app_args, fraction_to_prune, pruning_step, test_fn, train_fn):
dataset = app_args.dataset
......@@ -263,18 +275,19 @@ def greedy_pruner(pruned_model, app_args, fraction_to_prune, pruning_step, test_
create_network_record_file()
# Temporary ugly hack!
resnet_layers = None
resnet_params = None
model_layers, model_params = None, None
if arch == "resnet20_cifar":
resnet_params = resnet20_params
model_params = resnet20_params
elif arch == "resnet56_cifar":
resnet_params = resnet56_params
model_params = resnet56_params
elif arch == "resnet50":
resnet_params = resnet50_params
if resnet_params is not None:
resnet_layers = [param[:-len(".weight")] for param in resnet_params]
model_params = resnet50_params
elif arch == "mobilenet":
model_params = mobilenet_params
if model_params is not None:
model_layers = [param[:-len(".weight")] for param in model_params]
total_macs = dense_total_macs = get_model_compute_budget(pruned_model, dataset, resnet_layers)
total_macs = dense_total_macs = get_model_compute_budget(pruned_model, dataset, model_layers)
iteration = 0
model = pruned_model
......@@ -288,10 +301,10 @@ def greedy_pruner(pruned_model, app_args, fraction_to_prune, pruning_step, test_
prec1, prec5, param_name, pruned_model, zeros_mask_dict = find_most_robust_layer(iteration, pruned_model,
pruning_step,
test_fn, train_fn,
app_args, resnet_params,
app_args, model_params,
effective_train_size)
total_macs = get_model_compute_budget(pruned_model, dataset, resnet_layers)
densities = get_param_densities(model, pruned_model, resnet_params)
total_macs = get_model_compute_budget(pruned_model, dataset, model_layers)
densities = get_param_densities(model, pruned_model, model_params)
compute_density = total_macs/dense_total_macs
results = (iteration, prec1, param_name, compute_density, total_macs, densities)
record_network_details(results)
......@@ -300,6 +313,7 @@ def greedy_pruner(pruned_model, app_args, fraction_to_prune, pruning_step, test_
name="greedy__{}__{:.1f}__{:.1f}".format(str(iteration).zfill(3), compute_density*100, prec1),
dir=msglogger.logdir)
del scheduler
del zeros_mask_dict
msglogger.info("Iteration {}: top1-{:.2f} {} compute-{:.2f}".format(*results[0:4]))
assert iteration > 0
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment