diff --git a/verse/analysis/analysis_tree.py b/verse/analysis/analysis_tree.py index 8ad2f6116c2737a3e9a89a3389bf8c2a334cd593..102ea2c4fc07e18a03b32977397241dab3d1033c 100644 --- a/verse/analysis/analysis_tree.py +++ b/verse/analysis/analysis_tree.py @@ -99,62 +99,6 @@ class AnalysisTreeNode: type = data['type'], ) - def quick_check(self, other: "AnalysisTreeNode", atol = 1e-5) -> bool: - if not ( - # self.init==other.init and - self.mode==other.mode and - self.agent==other.agent and - self.start_time==other.start_time and - self.assert_hits==other.assert_hits and - self.type==other.type and - # self.static==other.static and - # self.uncertain_param==other.uncertain_param and - len(self.child) == len(other.child) and - True - # self.id==other.id - ): - return False - # for agent in self.agent: - # if not np.allclose(trace, trace_other, 0, atol, equal_nan=True): - # return False - return True - - def check_inclusion(self, other: "AnalysisTreeNode", atol = 1e-5) -> bool: - # for simtrace, return True iff self and other are the same with the error toleration - # for reachtube, return True iff the trace of self contains trace of other with the same structure and error toleration - if not self.quick_check(other, atol): - return False - # atol = pow(10, -num_digits) - if self.type=='simtrace': - for agent, trace in self.trace.items(): - trace_other = other.trace[agent] - assert trace.shape == trace_other.shape - # absolute(a - b) <= (atol + rtol * absolute(b)) - if not np.allclose(trace, trace_other, 0, atol, equal_nan=True): - return False - elif self.type=='reachtube': - for agent, trace in self.trace.items(): - # trace = np.around(trace, num_digits) - # trace_other = np.around(other.trace[agent], num_digits) - assert trace.shape == trace_other.shape - # when considering error, a - b >= -error is trusted as a>=b - ge_other = np.subtract(trace, trace_other) + atol >= 0 - le_other = np.subtract(trace, trace_other) - atol <= 0 - # exclude the time dim - if not ((np.all(le_other[0::2, 1:]) == False) and (np.all(ge_other[1::2, 1:]) == True)): - return False - else: - raise ValueError - child_match = 0 - for child in self.child: - for child_other in other.child: - if child.check_inclusion(child_other): - child_match+=1 - break - if child_match == len(self.child): - return True - return False - class AnalysisTree: def __init__(self, root): @@ -227,9 +171,3 @@ class AnalysisTree: return nid + 1 - def __eq__(self, other: object) -> bool: - assert isinstance(other, AnalysisTree) - if len(self.nodes) != len(other.nodes): - return False - return self.root.check_inclusion(other.root) -