Skip to content
Snippets Groups Projects
Commit 109955eb authored by keyis2's avatar keyis2
Browse files

del equality checker.

parent 04eb6c0d
No related branches found
No related tags found
2 merge requests!16Parallel,!21Arch2023 merged
......@@ -99,62 +99,6 @@ class AnalysisTreeNode:
type = data['type'],
)
def quick_check(self, other: "AnalysisTreeNode", atol = 1e-5) -> bool:
if not (
# self.init==other.init and
self.mode==other.mode and
self.agent==other.agent and
self.start_time==other.start_time and
self.assert_hits==other.assert_hits and
self.type==other.type and
# self.static==other.static and
# self.uncertain_param==other.uncertain_param and
len(self.child) == len(other.child) and
True
# self.id==other.id
):
return False
# for agent in self.agent:
# if not np.allclose(trace, trace_other, 0, atol, equal_nan=True):
# return False
return True
def check_inclusion(self, other: "AnalysisTreeNode", atol = 1e-5) -> bool:
# for simtrace, return True iff self and other are the same with the error toleration
# for reachtube, return True iff the trace of self contains trace of other with the same structure and error toleration
if not self.quick_check(other, atol):
return False
# atol = pow(10, -num_digits)
if self.type=='simtrace':
for agent, trace in self.trace.items():
trace_other = other.trace[agent]
assert trace.shape == trace_other.shape
# absolute(a - b) <= (atol + rtol * absolute(b))
if not np.allclose(trace, trace_other, 0, atol, equal_nan=True):
return False
elif self.type=='reachtube':
for agent, trace in self.trace.items():
# trace = np.around(trace, num_digits)
# trace_other = np.around(other.trace[agent], num_digits)
assert trace.shape == trace_other.shape
# when considering error, a - b >= -error is trusted as a>=b
ge_other = np.subtract(trace, trace_other) + atol >= 0
le_other = np.subtract(trace, trace_other) - atol <= 0
# exclude the time dim
if not ((np.all(le_other[0::2, 1:]) == False) and (np.all(ge_other[1::2, 1:]) == True)):
return False
else:
raise ValueError
child_match = 0
for child in self.child:
for child_other in other.child:
if child.check_inclusion(child_other):
child_match+=1
break
if child_match == len(self.child):
return True
return False
class AnalysisTree:
def __init__(self, root):
......@@ -227,9 +171,3 @@ class AnalysisTree:
return nid + 1
def __eq__(self, other: object) -> bool:
assert isinstance(other, AnalysisTree)
if len(self.nodes) != len(other.nodes):
return False
return self.root.check_inclusion(other.root)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment