Skip to content
Snippets Groups Projects
Commit 67db927d authored by Neta Zmora's avatar Neta Zmora
Browse files

full_flow_tests.py: Added a pruning test

This test uses MNIST for faster execution and test various
pruners and their scheduling.
parent da4bcbfc
No related branches found
No related tags found
No related merge requests found
......@@ -23,6 +23,7 @@ import argparse
import time
DS_CIFAR = 'cifar10'
DS_MNIST = 'mnist'
distiller_root = os.path.realpath('..')
examples_root = os.path.join(distiller_root, 'examples')
......@@ -128,7 +129,10 @@ test_configs = [
DS_CIFAR, accuracy_checker, [44.370, 89.640]),
TestConfig('-a resnet20_cifar --resume {0} --sense=filter --sense-range 0 0.10 0.05'.
format(os.path.join(examples_root, 'ssl', 'checkpoints', 'checkpoint_trained_dense.pth.tar')),
DS_CIFAR, collateral_checker, [('sensitivity.csv', 3175), ('sensitivity.png', 96157)])
DS_CIFAR, collateral_checker, [('sensitivity.csv', 3175), ('sensitivity.png', 96157)]),
TestConfig('--arch simplenet_mnist --epochs 3 -p=50 --compress={0}'.
format(os.path.join('full_flow_tests', 'simplenet_mnist_pruning.yaml')),
DS_MNIST, accuracy_checker, [98.78, 100.]),
]
......@@ -169,11 +173,13 @@ def validate_dataset_path(path, default, name):
def run_tests():
parser = argparse.ArgumentParser()
parser.add_argument('--cifar10-path', dest='cifar10_path', metavar='DIR', help='Path to CIFAR-10 dataset')
parser.add_argument('--mnist-path', dest='mnist_path', metavar='DIR', help='Path to MNIST dataset')
args = parser.parse_args()
cifar10_path = validate_dataset_path(args.cifar10_path, default='data.cifar10', name='CIFAR-10')
mnist_path = validate_dataset_path(args.mnist_path, default='data.mnist', name='MNIST')
datasets = {DS_CIFAR: cifar10_path}
datasets = {DS_CIFAR: cifar10_path, DS_MNIST: mnist_path}
total_configs = len(test_configs)
failed_tests = []
for idx, tc in enumerate(test_configs):
......
#
# A YAML file for testing various pruners and scheduling configurations
#
version: 1
pruners:
filter_pruner:
class: 'L1RankedStructureParameterPruner'
group_type: Filters
desired_sparsity: 0.1
weights: [module.conv1.weight]
filter_pruner_agp:
class: 'L1RankedStructureParameterPruner_AGP'
group_type: Filters
initial_sparsity: 0.05
final_sparsity: 0.20
weights: [module.conv2.weight]
gemm_pruner_agp:
class: 'AutomatedGradualPruner'
initial_sparsity: 0.02
final_sparsity: 0.15
weights: [module.fc2.weight]
extensions:
net_thinner:
class: 'FilterRemover'
thinning_func_str: remove_filters
arch: 'simplenet_mnist'
dataset: 'mnist'
policies:
- pruner:
instance_name: filter_pruner
epochs: [0,1]
- pruner:
instance_name: filter_pruner_agp
starting_epoch: 0
ending_epoch: 2
frequency: 1
- pruner:
instance_name: gemm_pruner_agp
starting_epoch: 0
ending_epoch: 2
frequency: 1
- extension:
instance_name: net_thinner
epochs: [2]
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment