diff --git a/tests/full_flow_tests.py b/tests/full_flow_tests.py index 231a7b7a916d36e787c34fe5c2c878727d8ab2e6..20ee31561067f44ffd355f4093c46d726875e8b9 100755 --- a/tests/full_flow_tests.py +++ b/tests/full_flow_tests.py @@ -21,6 +21,7 @@ import re from collections import namedtuple import argparse import time +import math DS_CIFAR = 'cifar10' DS_MNIST = 'mnist' @@ -78,13 +79,13 @@ def success(string): # Checkers ########### -def compare_values(name, expected, actual): +def compare_values(name, expected, actual, rel_tol=5e-2): print('Comparing {0}: Expected = {1} ; Actual = {2}'.format(name, expected, actual)) - if expected != actual: + if math.isclose(expected, actual, rel_tol=rel_tol): + return True + else: error('Mismatch on {0}'.format(name)) return False - else: - return True def accuracy_checker(log, run_dir, expected_results):