diff --git a/verse/analysis/analysis_tree.py b/verse/analysis/analysis_tree.py index 44093daa1d85bd67004a6ca97b4a45c23e130e3e..9c7c07d1094feea91439ae191b8264752e4a890e 100644 --- a/verse/analysis/analysis_tree.py +++ b/verse/analysis/analysis_tree.py @@ -95,46 +95,55 @@ class AnalysisTreeNode: type = data['type'], ) - def quick_check(self, __o: object) -> bool: - assert isinstance(__o, AnalysisTreeNode) - if not (self.init==__o.init and - self.mode==__o.mode and - self.agent==__o.agent and - self.start_time==__o.start_time and - self.assert_hits==__o.assert_hits and - self.type==__o.type and - self.static==__o.static and - self.uncertain_param==__o.uncertain_param and - len(self.child) == len(__o.child) and + def quick_check(self, other: "AnalysisTreeNode", atol = 1e-5) -> bool: + if not ( + # self.init==other.init and + self.mode==other.mode and + self.agent==other.agent and + self.start_time==other.start_time and + self.assert_hits==other.assert_hits and + self.type==other.type and + # self.static==other.static and + # self.uncertain_param==other.uncertain_param and + len(self.child) == len(other.child) and True - # self.id==__o.id + # self.id==other.id ): return False + # for agent in self.agent: + # if not np.allclose(trace, trace_other, 0, atol, equal_nan=True): + # return False return True - def __eq__(self, __o: object) -> bool: - if not self.quick_check(__o): + def check_inclusion(self, other: "AnalysisTreeNode", atol = 1e-5) -> bool: + # for simtrace, return True iff self and other are the same with the error toleration + # for reachtube, return True iff the trace of self contains trace of other with the same structure and error toleration + if not self.quick_check(other, atol): return False + # atol = pow(10, -num_digits) if self.type=='simtrace': for agent, trace in self.trace.items(): - trace_other = __o.trace[agent] - for (step, step_other) in zip(trace, trace_other): - if not np.allclose(step, step_other, equal_nan=True): - print("diff in trace:", step, step_other) - return False + trace_other = other.trace[agent] + assert trace.shape == trace_other.shape + # absolute(a - b) <= (atol + rtol * absolute(b)) + if not np.allclose(trace, trace_other, 0, atol, equal_nan=True): + return False elif self.type=='reachtube': for agent, trace in self.trace.items(): - trace_other = __o.trace[agent] - for (step, step_other) in zip(trace, trace_other): - if not np.allclose(step, step_other, equal_nan=True): - print("diff in trace:", step, step_other) - return False + # trace = np.around(trace, num_digits) + # trace_other = np.around(other.trace[agent], num_digits) + assert trace.shape == trace_other.shape + # when considering error, a - b >= -error is trusted as a>=b + ge_other = np.subtract(trace, trace_other) + atol >= 0 + le_other = np.subtract(trace, trace_other) - atol <= 0 + if not ((np.all(le_other[0::2]) == False) and (np.all(ge_other[1::2]) == True)): + return False else: raise ValueError child_match = 0 for child in self.child: - for child_other in __o.child: - if child == child_other: + for child_other in other.child: + if child.check_inclusion(child_other): child_match+=1 break if child_match == len(self.child): @@ -212,11 +221,61 @@ class AnalysisTree: return nid + 1 - def __eq__(self, __o: object) -> bool: - assert isinstance(__o, AnalysisTree) - if len(self.nodes) != len(__o.nodes): + def __eq__(self, other: object) -> bool: + assert isinstance(other, AnalysisTree) + if len(self.nodes) != len(other.nodes): return False - for (node, node_other) in zip(self.nodes, __o.nodes): - if not (node == node_other): - return False - return True + return self.root.check_inclusion(other.root) + + def contains(self, other: "AnalysisTree", strict: bool = True, tol: Optional[float] = None) -> bool: + """ + Returns, for reachability, whether the current tree fully contains the other tree or not; + for simulation, whether the other tree is close enough to the current tree. + strict: requires set of agents to be the same + """ + tol = _EPSILON if tol == None else tol + cur_agents = set(self.nodes[0].agent.keys()) + other_agents = set(other.nodes[0].agent.keys()) + min_agents = list(other_agents) + types = list(set([n.type for n in self.nodes + other.nodes])) + assert len(types) == 1, f"Different types of nodes: {types}" + if not ((strict and cur_agents == other_agents) or (not strict and cur_agents.issuperset(other_agents))): + return False + if types[0] == "simtrace": # Simulation + if len(self.nodes) != len(other.nodes): + return False + def sim_seg_contains(a: Dict[str, TraceType], b: Dict[str, TraceType]) -> bool: + return all(a[aid].shape == b[aid].shape and bool(np.all(np.abs(a[aid][:, 1:] - b[aid][:, 1:]) < tol)) for aid in min_agents) + def sim_node_contains(a: AnalysisTreeNode, b: AnalysisTreeNode) -> bool: + if not sim_seg_contains(a.trace, b.trace): + return False + if len(a.child) != len(b.child): + return False + child_num = len(a.child) + other_not_paired = set(range(child_num)) + for i in range(child_num): + for j in other_not_paired: + if sim_node_contains(a.child[i], b.child[j]): + other_not_paired.remove(j) + break + else: + return False + return True + return sim_node_contains(self.root, other.root) + else: # Reachability + cont_num = len(other.nodes[0].trace[min_agents[0]][0]) - 1 + def collect_ranges(n: AnalysisTreeNode) -> Dict[str, List[List[portion.Interval]]]: + trace_len = len(n.trace[min_agents[0]]) + cur = {aid: [[portion.closed(n.trace[aid][i][j + 1], n.trace[aid][i + 1][j + 1]) for j in range(cont_num)] for i in range(trace_len)] for aid in min_agents} + if len(n.child) == 0: + return cur + else: + children = [collect_ranges(c) for c in n.child] + child_num = len(children) + trace_len = len(children[min_agents[0]][0]) + combined = {aid: [[reduce(portion.Interval.union, (children[i][aid][j][k] for k in range(child_num))) for j in range(cont_num)] for i in range(trace_len)] for aid in other_agents} + return {aid: cur[aid] + combined[aid] for aid in other_agents} + this_tree, other_tree = collect_ranges(self.root), collect_ranges(other.root) + total_len = len(other_tree[min_agents[0]]) + # bloat and containment + return all(other_tree[aid][i][j] in this_tree[aid][i][j].apply(lambda x: x.replace(lower=lambda v: v - tol, upper=lambda v: v + tol)) for aid in other_agents for i in range(total_len) for j in range(cont_num))