From d8a5a6d7cfb28fa993e14b79bc115c924446cdd1 Mon Sep 17 00:00:00 2001
From: crides <zhuhaoqing@live.cn>
Date: Sun, 9 Oct 2022 18:37:47 -0500
Subject: [PATCH] demo7 reach working

---
 verse/analysis/incremental.py | 38 +++++++++++++++++++++------
 verse/analysis/utils.py       | 11 ++++++++
 verse/analysis/verifier.py    | 20 +++++---------
 verse/scenario/scenario.py    | 49 +++++++----------------------------
 4 files changed, 58 insertions(+), 60 deletions(-)

diff --git a/verse/analysis/incremental.py b/verse/analysis/incremental.py
index 0e16d548..99f1d429 100644
--- a/verse/analysis/incremental.py
+++ b/verse/analysis/incremental.py
@@ -103,6 +103,15 @@ def combine_all(inits):
     return [[min(a) for a in np.transpose(np.array(inits)[:, 0])],
             [max(a) for a in np.transpose(np.array(inits)[:, 1])]]
 
+def reach_trans_suit(a: Dict[str, List[List[List[float]]]], b: Dict[str, List[List[List[float]]]]) -> bool:
+    assert set(a.keys()) == set(b.keys())
+    def transp(a):
+        return list(map(list, zip(*a)))
+    def suits(a: List[List[float]], b: List[List[float]]) -> bool:
+        at, bt = transp(a), transp(b)
+        return all(al <= bl and ah >= bh for (al, ah), (bl, bh) in zip(at, bt))
+    return all(suits(av, bv) for aid in a.keys() for av, bv in zip(a[aid], b[aid]))
+
 @dataclass
 class CachedTube:
     tube: List[List[List[float]]]
@@ -219,18 +228,31 @@ class ReachTubeCache:
                 tree = next_level_tree
         raise Exception("???")
 
-    def check_hit(self, agent_id: str, mode: Tuple[str], init: List[float]) -> Optional[CachedRTTrans]:
+    @staticmethod
+    def query_cont(tree: IntervalTree, cont: List[Tuple[float, float]]) -> List[CachedRTTrans]:
+        assert isinstance(tree, IntervalTree)
+        low, high = cont[0]
+        next_level_entries = [t.data for t in tree[low:high + _EPSILON] if t.begin <= low and high <= t.end]
+        if len(cont) == 1:
+            for ent in next_level_entries:
+                assert isinstance(ent, CachedRTTrans)
+            return next_level_entries
+        else:
+            return [ent for t in next_level_entries for ent in ReachTubeCache.query_cont(t, cont[1:])]
+
+    def check_hit(self, agent_id: str, mode: Tuple[str], init: List[float], inits: Dict[str, List[List[List[float]]]]) -> Optional[CachedRTTrans]:
         key = (agent_id,) + tuple(mode)
         if key not in self.cache:
             return None
         tree = self.cache[key]
-        for low, high in list(map(tuple, zip(*init))):
-            next_level_entries = [t for t in tree[low:high + _EPSILON] if t.begin <= low and high <= t.end]
-            if len(next_level_entries) == 0:
-                return None
-            tree = min(next_level_entries, key=lambda e: low - e.begin + e.end - high).data
-        assert isinstance(tree, CachedRTTrans)
-        return tree
+        entries = self.query_cont(tree, list(map(tuple, zip(*init))))
+        if len(entries) == 0:
+            return None
+        def num_trans_suit(e: CachedRTTrans) -> int:
+            return sum(1 if reach_trans_suit(t.inits, inits) else 0 for t in e.transitions)
+        entries = list(sorted([(e, -num_trans_suit(e)) for e in entries], key=lambda p: p[1]))
+        pp(("check hit entries", len(entries), entries[0][1]))
+        return entries[0][0]
 
     @staticmethod
     def iter_tree(tree, depth: int) -> List[List[float]]:
diff --git a/verse/analysis/utils.py b/verse/analysis/utils.py
index 47da88c6..6778df8c 100644
--- a/verse/analysis/utils.py
+++ b/verse/analysis/utils.py
@@ -382,3 +382,14 @@ def sample_rect(rect: List[List[float]]) -> List[float]:
     # for i in range(len(rect[0])):
     res = np.random.uniform(rect[0], rect[1]).tolist()
     return res
+
+def dedup(l, f=lambda a:a):
+    o = []
+    l = [(i, f(i)) for i in l]
+    for i, k in l:
+        for _, k_ in o:
+            if k == k_:
+                break
+        else:
+            o.append((i, k))
+    return [i for i, _ in o]
diff --git a/verse/analysis/verifier.py b/verse/analysis/verifier.py
index bfb144d2..30689e13 100644
--- a/verse/analysis/verifier.py
+++ b/verse/analysis/verifier.py
@@ -11,6 +11,7 @@ from verse.analysis.analysis_tree import AnalysisTreeNode, AnalysisTree
 from verse.analysis.dryvr import calc_bloated_tube, SIMTRACENUM
 from verse.analysis.mixmonotone import calculate_bloated_tube_mixmono_cont, calculate_bloated_tube_mixmono_disc
 from verse.analysis.incremental import ReachTubeCache, TubeCache, convert_reach_trans, to_simulate, combine_all
+from verse.analysis.utils import dedup
 from verse.parser.parser import find
 pp = functools.partial(pprint.pprint, compact=True, width=130)
 
@@ -151,8 +152,10 @@ class Verifier:
             for agent_id in node.agent:
                 mode = node.mode[agent_id]
                 inits = node.init[agent_id]
+                combined = combine_all(inits)
                 if self.config.incremental:
-                    cached = self.trans_cache.check_hit(agent_id, mode, combined_inits[agent_id])
+                    cached = self.trans_cache.check_hit(agent_id, mode, combined, node.init)
+                    pp(("check hit", agent_id, mode, combined))
                     if cached != None:
                         cached_tubes[agent_id] = cached
                 if agent_id not in node.trace:
@@ -232,17 +235,8 @@ class Verifier:
                     if agent_id in cached_tubes:
                         cached_tubes[agent_id].transitions.extend(convert_reach_trans(agent_id, transit_agents, node.init, transition, transit_ind))
                         pre_len = len(cached_tubes[agent_id].transitions)
-                        def dedup(l):
-                            o = []
-                            for i in l:
-                                for j in o:
-                                    if i.mode == j.mode and i.dest == j.dest:
-                                        break
-                                else:
-                                    o.append(i)
-                            return o
-                        cached_tubes[agent_id].transitions = dedup(cached_tubes[agent_id].transitions)
-                        # pp(("dedup!", pre_len, len(cached_tubes[agent_id].transitions)))
+                        cached_tubes[agent_id].transitions = dedup(cached_tubes[agent_id].transitions, lambda i: (i.mode, i.dest, i.inits))
+                        pp(("dedup!", pre_len, len(cached_tubes[agent_id].transitions)))
                     else:
                         self.trans_cache.add_tube(agent_id, combined_inits, node, transit_agents, transition, transit_ind, run_num)
 
@@ -276,7 +270,7 @@ class Verifier:
                         next_node_init[agent_idx] = next_init
                     else:
                         next_node_init[agent_idx] = [[truncated_trace[agent_idx][0][1:], truncated_trace[agent_idx][1][1:]]]
-                        pp(("infer init", agent_idx, next_node_init[agent_idx], truncated_trace[agent_idx][:8]))
+                        pp(("infer init", agent_idx, next_node_init[agent_idx]))
                         next_node_trace[agent_idx] = truncated_trace[agent_idx]
 
                 tmp = AnalysisTreeNode(
diff --git a/verse/scenario/scenario.py b/verse/scenario/scenario.py
index ef2646bd..48d30ffe 100644
--- a/verse/scenario/scenario.py
+++ b/verse/scenario/scenario.py
@@ -11,11 +11,11 @@ import numpy as np
 
 from verse.agents.base_agent import BaseAgent
 from verse.analysis.dryvr import _EPSILON
-from verse.analysis.incremental import CachedRTTrans, CachedSegment, combine_all
+from verse.analysis.incremental import CachedRTTrans, CachedSegment, combine_all, reach_trans_suit
 from verse.analysis.simulator import PathDiffs
 from verse.automaton import GuardExpressionAst, ResetExpression
 from verse.analysis import Simulator, Verifier, AnalysisTreeNode, AnalysisTree
-from verse.analysis.utils import sample_rect
+from verse.analysis.utils import dedup, sample_rect
 from verse.parser import astunparser
 from verse.parser.parser import ControllerIR, ModePath, find
 from verse.sensor.base_sensor import BaseSensor
@@ -444,16 +444,7 @@ class Scenario:
                 if len(agent.controller.args) == 0:
                     continue
                 state_dict = {aid: (node.trace[aid][0], node.mode[aid], node.static[aid]) for aid in node.agent}
-                def dedup(l):
-                    o = []
-                    for i in l:
-                        for j in o:
-                            if i.var == j.var and i.cond == j.cond and i.val == j.val:
-                                break
-                        else:
-                            o.append(i)
-                    return o
-                agent_paths = dedup([p for tran in segment.transitions for p in tran.paths])
+                agent_paths = dedup([p for tran in segment.transitions for p in tran.paths], lambda i: (i.var, i.cond, i.val))
                 cont_var_dict_template, discrete_variable_dict, len_dict = self.sensor.sense(self, agent, state_dict, self.map)
                 for path in agent_paths:
                     cached_guards[agent_id].append((path, discrete_variable_dict, path_transitions[path.cond]))
@@ -509,28 +500,16 @@ class Scenario:
         else:
             if len(paths) == 0:
                 # print(red("full cache"))
-                def trans_suit(a: Dict[str, List[List[List[float]]]], b: Dict[str, List[List[List[float]]]]) -> bool:
-                    assert set(a.keys()) == set(b.keys())
-                    def transp(a):
-                        return list(map(list, zip(*a)))
-                    def suits(a: List[List[float]], b: List[List[float]]) -> bool:
-                        at, bt = transp(a), transp(b)
-                        return all(al <= bl and ah >= bh for (al, ah), (bl, bh) in zip(at, bt))
-                    return all(suits(av, bv) for aid in a.keys() for av, bv in zip(a[aid], b[aid]))
 
                 # _transitions = [trans.transition for seg in cache.values() for trans in seg.transitions]
-                _transitions = [trans.transition for seg in cache.values() for trans in seg.transitions if trans_suit(trans.inits, node.init)]
-                # pp(("cached trans", _transitions))
+                _transitions = [trans.transition for seg in cache.values() for trans in seg.transitions if reach_trans_suit(trans.inits, node.init) or pp(("not suit", trans.inits, node.init))]
+                pp(("cached trans", len(_transitions)))
                 if len(_transitions) == 0:
                     return None, []
                 transition = min(_transitions)
-                transitions = []
-                for agent_id, seg in cache.items():
-                    # TODO: check for asserts
-                    for tran in seg.transitions:
-                        if tran.transition == transition:
-                            # pp(("chosen tran", agent_id, tran))
-                            transitions.append((agent_id, tran.mode, tran.dest, tran.reset, tran.reset_idx, tran.paths))
+                transitions = [(agent_id, tran) for agent_id, seg in cache.items() for tran in seg.transitions if tran.transition == transition]
+                # TODO: check for asserts
+                transitions = [(aid, tran.mode, tran.dest, tran.reset, tran.reset_idx, tran.paths) for aid, tran in dedup(transitions, lambda p: (p[0], p[1].mode, p[1].dest))]
                 return None, transitions
 
             path_transitions = defaultdict(int)
@@ -543,16 +522,8 @@ class Scenario:
                 if len(agent.controller.args) == 0:
                     continue
                 state_dict = {aid: (node.trace[aid][0], node.mode[aid], node.static[aid]) for aid in node.agent}
-                def dedup(l):
-                    o = []
-                    for i in l:
-                        for j in o:
-                            if i.var == j.var and i.cond == j.cond and i.val == j.val:
-                                break
-                        else:
-                            o.append(i)
-                    return o
-                agent_paths = dedup([p for tran in segment.transitions for p in tran.paths])
+
+                agent_paths = dedup([p for tran in segment.transitions for p in tran.paths], lambda i: (i.var, i.cond, i.val))
                 for path in agent_paths:
                     cont_var_dict_template, discrete_variable_dict, length_dict = self.sensor.sense(
                         self, agent, state_dict, self.map)
-- 
GitLab