diff --git a/tests/full_flow_tests.py b/tests/full_flow_tests.py
index 052f068728d68ec3bc21ca4bf16421044ba3c60c..6704f203d4237427ba9deaa0e92929ffd39485cf 100755
--- a/tests/full_flow_tests.py
+++ b/tests/full_flow_tests.py
@@ -30,6 +30,7 @@ examples_root = os.path.join(distiller_root, 'examples')
 script_path = os.path.realpath(os.path.join(examples_root, 'classifier_compression',
                                             'compress_classifier.py'))
 
+
 ###########
 # Some Basic Logging Mechanisms
 ###########
@@ -86,7 +87,8 @@ def compare_values(name, expected, actual):
     return True
 
 
-def accuracy_checker(log, run_dir, expected_top1, expected_top5):
+def accuracy_checker(log, run_dir, expected_results):
+    expected_top1, expected_top5 = expected_results
     tops = re.findall(r"Top1: (?P<top1>\d*\.\d*) *Top5: (?P<top5>\d*\.\d*)", log)
     if not tops:
         error('No accuracy results in log')
@@ -96,7 +98,34 @@ def accuracy_checker(log, run_dir, expected_top1, expected_top5):
     return compare_values('Top-5', expected_top5, float(tops[-1][1]))
 
 
-def collateral_checker(log, run_dir, *collateral_list):
+def earlyexit_accuracy_checker(log, run_dir, expected_results):
+    regex_list = (r"Accuracy Stats for exit 0: top1 = (?P<top1>\d*\.\d*), top5 = (?P<top5>\d*\.\d*)",
+                  r"Accuracy Stats for exit 1: top1 = (?P<top1>\d*\.\d*), top5 = (?P<top5>\d*\.\d*)",
+                  r"Totals for entire network with early exits: top1 = (?P<top1>\d*\.\d*), top5 = (?P<top5>\d*\.\d*)",
+                  r"Top1: (?P<top1>\d*\.\d*) *Top5: (?P<top5>\d*\.\d*)")
+
+    for i, regex in enumerate(regex_list):
+        if not generic_results_checker(log, regex, expected_results[i]):
+            return False
+    return True
+
+
+def generic_results_checker(log, regex1, expected_results):
+    actual_results = re.findall(regex1, log)
+    if not actual_results:
+        error('No results in log')
+        return False
+
+    # Grab only the last line of printed results
+    actual_results = actual_results[-1]
+    # Perform the comparison between expected and actual results
+    for (actual_result, expected_result) in zip(actual_results, expected_results):
+        if not compare_values('Un-named', expected_result, float(actual_result)):
+            return False
+    return True
+
+
+def collateral_checker(log, run_dir, collateral_list):
     """Test that the test produced the expected collaterals.
 
     A collateral_list is a list of tuples, where tuple elements are:
@@ -119,9 +148,14 @@ def collateral_checker(log, run_dir, *collateral_list):
 ###########
 TestConfig = namedtuple('TestConfig', ['args', 'dataset', 'checker_fn', 'checker_args'])
 
+
 test_configs = [
     TestConfig('--arch resnet20_cifar_earlyexit --lr=0.3 --epochs=180 --earlyexit_thresholds 0.9 '
-               '--earlyexit_lossweights 0.3 --epochs 2', DS_CIFAR, accuracy_checker, [17.04, 64.42]),
+               '--earlyexit_lossweights 0.3 --epochs 2 -p 50', DS_CIFAR, earlyexit_accuracy_checker,
+               [(99.115, 100.),
+                (17.235, 64.914),
+                (18.160, 65.310),
+                (17.04, 64.42)]),
     TestConfig('--arch simplenet_cifar --epochs 2', DS_CIFAR, accuracy_checker, [44.460, 91.230]),
     TestConfig('-a resnet20_cifar --resume {0} --quantize-eval --evaluate --qe-clip-acts avg --qe-no-clip-layers {1}'.
                format(os.path.join(examples_root, 'ssl', 'checkpoints', 'checkpoint_trained_dense.pth.tar'), 'fc'),
@@ -214,7 +248,7 @@ def run_tests():
                             format(p.returncode), idx, cmd, log_path, failed_tests, log)
             continue
         test_progress('Running checker: ' + colorize(tc.checker_fn.__name__, Colors.YELLOW))
-        if not tc.checker_fn(log, os.path.split(log_path)[0], *tc.checker_args):
+        if not tc.checker_fn(log, os.path.split(log_path)[0], tc.checker_args):
             process_failure('Checker failed', idx, cmd, log_path, failed_tests, log)
             continue
         success('TEST PASSED')