Skip to content
Snippets Groups Projects
Commit b6624fe4 authored by keyis2's avatar keyis2
Browse files

checker

parent b5a30e24
No related branches found
No related tags found
2 merge requests!16Parallel,!21Arch2023 merged
...@@ -227,55 +227,3 @@ class AnalysisTree: ...@@ -227,55 +227,3 @@ class AnalysisTree:
return False return False
return self.root.check_inclusion(other.root) return self.root.check_inclusion(other.root)
def contains(self, other: "AnalysisTree", strict: bool = True, tol: Optional[float] = None) -> bool:
"""
Returns, for reachability, whether the current tree fully contains the other tree or not;
for simulation, whether the other tree is close enough to the current tree.
strict: requires set of agents to be the same
"""
tol = _EPSILON if tol == None else tol
cur_agents = set(self.nodes[0].agent.keys())
other_agents = set(other.nodes[0].agent.keys())
min_agents = list(other_agents)
types = list(set([n.type for n in self.nodes + other.nodes]))
assert len(types) == 1, f"Different types of nodes: {types}"
if not ((strict and cur_agents == other_agents) or (not strict and cur_agents.issuperset(other_agents))):
return False
if types[0] == "simtrace": # Simulation
if len(self.nodes) != len(other.nodes):
return False
def sim_seg_contains(a: Dict[str, TraceType], b: Dict[str, TraceType]) -> bool:
return all(a[aid].shape == b[aid].shape and bool(np.all(np.abs(a[aid][:, 1:] - b[aid][:, 1:]) < tol)) for aid in min_agents)
def sim_node_contains(a: AnalysisTreeNode, b: AnalysisTreeNode) -> bool:
if not sim_seg_contains(a.trace, b.trace):
return False
if len(a.child) != len(b.child):
return False
child_num = len(a.child)
other_not_paired = set(range(child_num))
for i in range(child_num):
for j in other_not_paired:
if sim_node_contains(a.child[i], b.child[j]):
other_not_paired.remove(j)
break
else:
return False
return True
return sim_node_contains(self.root, other.root)
else: # Reachability
cont_num = len(other.nodes[0].trace[min_agents[0]][0]) - 1
def collect_ranges(n: AnalysisTreeNode) -> Dict[str, List[List[portion.Interval]]]:
trace_len = len(n.trace[min_agents[0]])
cur = {aid: [[portion.closed(n.trace[aid][i][j + 1], n.trace[aid][i + 1][j + 1]) for j in range(cont_num)] for i in range(trace_len)] for aid in min_agents}
if len(n.child) == 0:
return cur
else:
children = [collect_ranges(c) for c in n.child]
child_num = len(children)
trace_len = len(children[min_agents[0]][0])
combined = {aid: [[reduce(portion.Interval.union, (children[i][aid][j][k] for k in range(child_num))) for j in range(cont_num)] for i in range(trace_len)] for aid in other_agents}
return {aid: cur[aid] + combined[aid] for aid in other_agents}
this_tree, other_tree = collect_ranges(self.root), collect_ranges(other.root)
total_len = len(other_tree[min_agents[0]])
# bloat and containment
return all(other_tree[aid][i][j] in this_tree[aid][i][j].apply(lambda x: x.replace(lower=lambda v: v - tol, upper=lambda v: v + tol)) for aid in other_agents for i in range(total_len) for j in range(cont_num))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment