Skip to content
Snippets Groups Projects
Commit 49a2a967 authored by Neta Zmora's avatar Neta Zmora
Browse files

tests/full_flow_tests.py: improve early-exit test robustness

Early-exit (EE) runs emit more statistics than the regular classification
pipeline, so the test is more robust when it validates more of the log
output.
parent a7473c95
No related branches found
No related tags found
No related merge requests found
...@@ -30,6 +30,7 @@ examples_root = os.path.join(distiller_root, 'examples') ...@@ -30,6 +30,7 @@ examples_root = os.path.join(distiller_root, 'examples')
script_path = os.path.realpath(os.path.join(examples_root, 'classifier_compression', script_path = os.path.realpath(os.path.join(examples_root, 'classifier_compression',
'compress_classifier.py')) 'compress_classifier.py'))
########### ###########
# Some Basic Logging Mechanisms # Some Basic Logging Mechanisms
########### ###########
...@@ -86,7 +87,8 @@ def compare_values(name, expected, actual): ...@@ -86,7 +87,8 @@ def compare_values(name, expected, actual):
return True return True
def accuracy_checker(log, run_dir, expected_top1, expected_top5): def accuracy_checker(log, run_dir, expected_results):
expected_top1, expected_top5 = expected_results
tops = re.findall(r"Top1: (?P<top1>\d*\.\d*) *Top5: (?P<top5>\d*\.\d*)", log) tops = re.findall(r"Top1: (?P<top1>\d*\.\d*) *Top5: (?P<top5>\d*\.\d*)", log)
if not tops: if not tops:
error('No accuracy results in log') error('No accuracy results in log')
...@@ -96,7 +98,34 @@ def accuracy_checker(log, run_dir, expected_top1, expected_top5): ...@@ -96,7 +98,34 @@ def accuracy_checker(log, run_dir, expected_top1, expected_top5):
return compare_values('Top-5', expected_top5, float(tops[-1][1])) return compare_values('Top-5', expected_top5, float(tops[-1][1]))
def earlyexit_accuracy_checker(log, run_dir, expected_results):
    """Validate the accuracy statistics that an early-exit run prints to its log.

    Four result lines are checked, in this order: the stats for exit 0, the
    stats for exit 1, the totals for the entire network with early exits, and
    the regular "Top1 ... Top5" line.

    Args:
        log: text of the run log to search.
        run_dir: unused here; kept so all checkers share the same signature.
        expected_results: sequence with one entry per checked result line (in
            the order above); each entry is a sequence of expected values, one
            per capture group of the corresponding regular expression.

    Returns:
        True if every result line is found and matches its expected values,
        False otherwise.
    """
    regex_list = (r"Accuracy Stats for exit 0: top1 = (?P<top1>\d*\.\d*), top5 = (?P<top5>\d*\.\d*)",
                  r"Accuracy Stats for exit 1: top1 = (?P<top1>\d*\.\d*), top5 = (?P<top5>\d*\.\d*)",
                  r"Totals for entire network with early exits: top1 = (?P<top1>\d*\.\d*), top5 = (?P<top5>\d*\.\d*)",
                  r"Top1: (?P<top1>\d*\.\d*) *Top5: (?P<top5>\d*\.\d*)")
    # A mis-sized expectation list is a test-configuration error: report it as a
    # checker failure instead of raising IndexError (too few entries) or
    # silently ignoring the surplus (too many entries).
    if len(expected_results) != len(regex_list):
        error('Expected {} sets of early-exit results, but {} were provided'.format(
            len(regex_list), len(expected_results)))
        return False
    # all() stops at the first failing checker, like an explicit early return.
    return all(generic_results_checker(log, regex, expected)
               for regex, expected in zip(regex_list, expected_results))
def generic_results_checker(log, regex1, expected_results):
    """Compare the last results in `log` that match `regex1` with `expected_results`.

    Args:
        log: text of the run log to search.
        regex1: regular expression whose capture groups each capture one
            numeric result.
        expected_results: sequence of expected values, one per captured value,
            in capture order.

    Returns:
        True if a match is found, the number of captured values equals the
        number of expected values, and every pair compares equal according to
        compare_values(); False otherwise.
    """
    actual_results = re.findall(regex1, log)
    if not actual_results:
        error('No results in log')
        return False
    # Grab only the last line of printed results
    actual_results = actual_results[-1]
    # re.findall() yields plain strings (not tuples) when the pattern has fewer
    # than two groups; wrap them so the loop below never iterates characters.
    if isinstance(actual_results, str):
        actual_results = (actual_results,)
    # zip() would silently drop unmatched items, letting a check pass without
    # comparing every value, so a count mismatch is reported as a failure.
    if len(actual_results) != len(expected_results):
        error('Found {} results in log, but expected {}'.format(
            len(actual_results), len(expected_results)))
        return False
    # Perform the comparison between expected and actual results
    for (actual_result, expected_result) in zip(actual_results, expected_results):
        if not compare_values('Un-named', expected_result, float(actual_result)):
            return False
    return True
def collateral_checker(log, run_dir, collateral_list):
"""Test that the test produced the expected collaterals. """Test that the test produced the expected collaterals.
A collateral_list is a list of tuples, where tuple elements are: A collateral_list is a list of tuples, where tuple elements are:
...@@ -119,9 +148,14 @@ def collateral_checker(log, run_dir, *collateral_list): ...@@ -119,9 +148,14 @@ def collateral_checker(log, run_dir, *collateral_list):
########### ###########
TestConfig = namedtuple('TestConfig', ['args', 'dataset', 'checker_fn', 'checker_args']) TestConfig = namedtuple('TestConfig', ['args', 'dataset', 'checker_fn', 'checker_args'])
test_configs = [ test_configs = [
TestConfig('--arch resnet20_cifar_earlyexit --lr=0.3 --epochs=180 --earlyexit_thresholds 0.9 ' TestConfig('--arch resnet20_cifar_earlyexit --lr=0.3 --epochs=180 --earlyexit_thresholds 0.9 '
'--earlyexit_lossweights 0.3 --epochs 2', DS_CIFAR, accuracy_checker, [17.04, 64.42]), '--earlyexit_lossweights 0.3 --epochs 2 -p 50', DS_CIFAR, earlyexit_accuracy_checker,
[(99.115, 100.),
(17.235, 64.914),
(18.160, 65.310),
(17.04, 64.42)]),
TestConfig('--arch simplenet_cifar --epochs 2', DS_CIFAR, accuracy_checker, [44.460, 91.230]), TestConfig('--arch simplenet_cifar --epochs 2', DS_CIFAR, accuracy_checker, [44.460, 91.230]),
TestConfig('-a resnet20_cifar --resume {0} --quantize-eval --evaluate --qe-clip-acts avg --qe-no-clip-layers {1}'. TestConfig('-a resnet20_cifar --resume {0} --quantize-eval --evaluate --qe-clip-acts avg --qe-no-clip-layers {1}'.
format(os.path.join(examples_root, 'ssl', 'checkpoints', 'checkpoint_trained_dense.pth.tar'), 'fc'), format(os.path.join(examples_root, 'ssl', 'checkpoints', 'checkpoint_trained_dense.pth.tar'), 'fc'),
...@@ -214,7 +248,7 @@ def run_tests(): ...@@ -214,7 +248,7 @@ def run_tests():
format(p.returncode), idx, cmd, log_path, failed_tests, log) format(p.returncode), idx, cmd, log_path, failed_tests, log)
continue continue
test_progress('Running checker: ' + colorize(tc.checker_fn.__name__, Colors.YELLOW)) test_progress('Running checker: ' + colorize(tc.checker_fn.__name__, Colors.YELLOW))
if not tc.checker_fn(log, os.path.split(log_path)[0], *tc.checker_args): if not tc.checker_fn(log, os.path.split(log_path)[0], tc.checker_args):
process_failure('Checker failed', idx, cmd, log_path, failed_tests, log) process_failure('Checker failed', idx, cmd, log_path, failed_tests, log)
continue continue
success('TEST PASSED') success('TEST PASSED')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment