diff --git a/verse/analysis/analysis_tree.py b/verse/analysis/analysis_tree.py index 9c7c07d1094feea91439ae191b8264752e4a890e..20d9339f3562fe6b80d6aa787d6f53f5b8ccfa9f 100644 --- a/verse/analysis/analysis_tree.py +++ b/verse/analysis/analysis_tree.py @@ -227,55 +227,3 @@ class AnalysisTree: return False return self.root.check_inclusion(other.root) - def contains(self, other: "AnalysisTree", strict: bool = True, tol: Optional[float] = None) -> bool: - """ - Returns, for reachability, whether the current tree fully contains the other tree or not; - for simulation, whether the other tree is close enough to the current tree. - strict: requires set of agents to be the same - """ - tol = _EPSILON if tol == None else tol - cur_agents = set(self.nodes[0].agent.keys()) - other_agents = set(other.nodes[0].agent.keys()) - min_agents = list(other_agents) - types = list(set([n.type for n in self.nodes + other.nodes])) - assert len(types) == 1, f"Different types of nodes: {types}" - if not ((strict and cur_agents == other_agents) or (not strict and cur_agents.issuperset(other_agents))): - return False - if types[0] == "simtrace": # Simulation - if len(self.nodes) != len(other.nodes): - return False - def sim_seg_contains(a: Dict[str, TraceType], b: Dict[str, TraceType]) -> bool: - return all(a[aid].shape == b[aid].shape and bool(np.all(np.abs(a[aid][:, 1:] - b[aid][:, 1:]) < tol)) for aid in min_agents) - def sim_node_contains(a: AnalysisTreeNode, b: AnalysisTreeNode) -> bool: - if not sim_seg_contains(a.trace, b.trace): - return False - if len(a.child) != len(b.child): - return False - child_num = len(a.child) - other_not_paired = set(range(child_num)) - for i in range(child_num): - for j in other_not_paired: - if sim_node_contains(a.child[i], b.child[j]): - other_not_paired.remove(j) - break - else: - return False - return True - return sim_node_contains(self.root, other.root) - else: # Reachability - cont_num = len(other.nodes[0].trace[min_agents[0]][0]) - 1 - def collect_ranges(n: AnalysisTreeNode) -> Dict[str, List[List[portion.Interval]]]: - trace_len = len(n.trace[min_agents[0]]) - cur = {aid: [[portion.closed(n.trace[aid][i][j + 1], n.trace[aid][i + 1][j + 1]) for j in range(cont_num)] for i in range(trace_len)] for aid in min_agents} - if len(n.child) == 0: - return cur - else: - children = [collect_ranges(c) for c in n.child] - child_num = len(children) - trace_len = len(children[min_agents[0]][0]) - combined = {aid: [[reduce(portion.Interval.union, (children[i][aid][j][k] for k in range(child_num))) for j in range(cont_num)] for i in range(trace_len)] for aid in other_agents} - return {aid: cur[aid] + combined[aid] for aid in other_agents} - this_tree, other_tree = collect_ranges(self.root), collect_ranges(other.root) - total_len = len(other_tree[min_agents[0]]) - # bloat and containment - return all(other_tree[aid][i][j] in this_tree[aid][i][j].apply(lambda x: x.replace(lower=lambda v: v - tol, upper=lambda v: v + tol)) for aid in other_agents for i in range(total_len) for j in range(cont_num))