diff --git a/tests/full_flow_tests.py b/tests/full_flow_tests.py index 0a5d2dacffa78838c8d4787595a0aef4fcbc0b4e..ef466198d1162a63fe8051224fbd92c9808be05e 100755 --- a/tests/full_flow_tests.py +++ b/tests/full_flow_tests.py @@ -23,6 +23,7 @@ import argparse import time DS_CIFAR = 'cifar10' +DS_MNIST = 'mnist' distiller_root = os.path.realpath('..') examples_root = os.path.join(distiller_root, 'examples') @@ -128,7 +129,10 @@ test_configs = [ DS_CIFAR, accuracy_checker, [44.370, 89.640]), TestConfig('-a resnet20_cifar --resume {0} --sense=filter --sense-range 0 0.10 0.05'. format(os.path.join(examples_root, 'ssl', 'checkpoints', 'checkpoint_trained_dense.pth.tar')), - DS_CIFAR, collateral_checker, [('sensitivity.csv', 3175), ('sensitivity.png', 96157)]) + DS_CIFAR, collateral_checker, [('sensitivity.csv', 3175), ('sensitivity.png', 96157)]), + TestConfig('--arch simplenet_mnist --epochs 3 -p=50 --compress={0}'. + format(os.path.join('full_flow_tests', 'simplenet_mnist_pruning.yaml')), + DS_MNIST, accuracy_checker, [98.78, 100.]), ] @@ -169,11 +173,13 @@ def validate_dataset_path(path, default, name): def run_tests(): parser = argparse.ArgumentParser() parser.add_argument('--cifar10-path', dest='cifar10_path', metavar='DIR', help='Path to CIFAR-10 dataset') + parser.add_argument('--mnist-path', dest='mnist_path', metavar='DIR', help='Path to MNIST dataset') args = parser.parse_args() cifar10_path = validate_dataset_path(args.cifar10_path, default='data.cifar10', name='CIFAR-10') + mnist_path = validate_dataset_path(args.mnist_path, default='data.mnist', name='MNIST') - datasets = {DS_CIFAR: cifar10_path} + datasets = {DS_CIFAR: cifar10_path, DS_MNIST: mnist_path} total_configs = len(test_configs) failed_tests = [] for idx, tc in enumerate(test_configs): diff --git a/tests/full_flow_tests/simplenet_mnist_pruning.yaml b/tests/full_flow_tests/simplenet_mnist_pruning.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1505e04991fa43dc01094ac83529213399f5dda9 --- /dev/null +++ b/tests/full_flow_tests/simplenet_mnist_pruning.yaml @@ -0,0 +1,53 @@ +# +# A YAML file for testing various pruners and scheduling configurations +# +version: 1 +pruners: + filter_pruner: + class: 'L1RankedStructureParameterPruner' + group_type: Filters + desired_sparsity: 0.1 + weights: [module.conv1.weight] + + filter_pruner_agp: + class: 'L1RankedStructureParameterPruner_AGP' + group_type: Filters + initial_sparsity: 0.05 + final_sparsity: 0.20 + weights: [module.conv2.weight] + + gemm_pruner_agp: + class: 'AutomatedGradualPruner' + initial_sparsity: 0.02 + final_sparsity: 0.15 + weights: [module.fc2.weight] + + +extensions: + net_thinner: + class: 'FilterRemover' + thinning_func_str: remove_filters + arch: 'simplenet_mnist' + dataset: 'mnist' + + +policies: + - pruner: + instance_name: filter_pruner + epochs: [0,1] + + - pruner: + instance_name: filter_pruner_agp + starting_epoch: 0 + ending_epoch: 2 + frequency: 1 + + - pruner: + instance_name: gemm_pruner_agp + starting_epoch: 0 + ending_epoch: 2 + frequency: 1 + + - extension: + instance_name: net_thinner + epochs: [2] \ No newline at end of file