From 729b14a880f9f3ebf6b266ba48fbebff9701d4d1 Mon Sep 17 00:00:00 2001
From: crides <zhuhaoqing@live.cn>
Date: Fri, 7 Oct 2022 20:29:53 -0500
Subject: [PATCH] pend

---
 verse/analysis/incremental.py |  6 ++--
 verse/analysis/simulator.py   |  2 +-
 verse/analysis/verifier.py    | 12 +++++---
 verse/scenario/scenario.py    | 57 +++++++++++++++++++++--------------
 4 files changed, 46 insertions(+), 31 deletions(-)

diff --git a/verse/analysis/incremental.py b/verse/analysis/incremental.py
index 291be962..0e16d548 100644
--- a/verse/analysis/incremental.py
+++ b/verse/analysis/incremental.py
@@ -63,11 +63,11 @@ def to_simulate(old_agents: Dict[str, BaseAgent], new_agents: Dict[str, BaseAgen
             raise NotImplementedError("different variable outputs")
         for var, old_paths in old_grouped.items():
             new_paths = new_grouped[var]
-            for old, new in itertools.zip_longest(old_paths, new_paths):
+            for i, (old, new) in enumerate(itertools.zip_longest(old_paths, new_paths)):
                 if new == None:
                     removed_paths.append(old)
                 elif old.cond != new.cond:
-                    added_paths.append((new_agent, new))
+                    added_paths.append((new_agent, i, new))
                 elif old.val != new.val:
                     reset_changed_paths.append(new)
     new_cache = {}
@@ -204,8 +204,8 @@ class ReachTubeCache:
     def add_tube(self, agent_id: str, init: Dict[str, List[List[float]]], node: AnalysisTreeNode, transit_agents: List[str], transition, trans_ind: int, run_num: int):
         key = (agent_id,) + tuple(node.mode[agent_id])
         tree = self.cache[key]
+        pp(('add seg', agent_id, node.mode[agent_id], init))
         assert_hits = node.assert_hits or {}
-        # pp(('add seg', agent_id, *node.mode[agent_id], *init))
         init = list(map(tuple, zip(*init[agent_id])))
         for i, (low, high) in enumerate(init):
             if i == len(init) - 1:
diff --git a/verse/analysis/simulator.py b/verse/analysis/simulator.py
index eb65159c..99781b24 100644
--- a/verse/analysis/simulator.py
+++ b/verse/analysis/simulator.py
@@ -13,7 +13,7 @@ pp = functools.partial(pprint.pprint, compact=True, width=130)
 # from verse.agents.base_agent import BaseAgent
 from verse.analysis.analysis_tree import AnalysisTreeNode, AnalysisTree
 
-PathDiffs = List[Tuple[BaseAgent, ModePath]]
+PathDiffs = List[Tuple[BaseAgent, int, ModePath]]
 
 def red(s):
     return "\x1b[31m" + s + "\x1b[0m" #]]
diff --git a/verse/analysis/verifier.py b/verse/analysis/verifier.py
index 76dc5169..bfb144d2 100644
--- a/verse/analysis/verifier.py
+++ b/verse/analysis/verifier.py
@@ -140,6 +140,7 @@ class Verifier:
         while verification_queue != []:
             node: AnalysisTreeNode = verification_queue.pop(0)
             combined_inits = {a: combine_all(inits) for a, inits in node.init.items()}
+            print()
             pp(("start sim", node.start_time, {a: (*node.mode[a], *combined_inits[a]) for a in node.mode}))
             remain_time = round(time_horizon - node.start_time, 10)
             if remain_time <= 0:
@@ -150,9 +151,8 @@ class Verifier:
             for agent_id in node.agent:
                 mode = node.mode[agent_id]
                 inits = node.init[agent_id]
-                init = combined_inits[agent_id]
                 if self.config.incremental:
-                    cached = self.trans_cache.check_hit(agent_id, mode, init)
+                    cached = self.trans_cache.check_hit(agent_id, mode, combined_inits[agent_id])
                     if cached != None:
                         cached_tubes[agent_id] = cached
                 if agent_id not in node.trace:
@@ -200,7 +200,7 @@ class Verifier:
                     trace = np.array(cur_bloated_tube)
                     trace[:, 0] += node.start_time
                     node.trace[agent_id] = trace.tolist()
-            pp(("cached_segments", cached_tubes.keys()))
+            pp(("cached tubes", cached_tubes.keys()))
             node_ids = list(set((s.run_num, s.node_id) for s in cached_tubes.values()))
             # assert len(node_ids) <= 1, f"{node_ids}"
             new_cache, paths_to_sim = {}, []
@@ -214,7 +214,7 @@ class Verifier:
 
             # Get all possible transitions to next mode
             asserts, all_possible_transitions = transition_graph.get_transition_verify_new(new_cache, paths_to_sim, node)
-            pp(("transitions:", all_possible_transitions))
+            pp(("transitions:", [(t[0], t[2]) for t in all_possible_transitions]))
             node.assert_hits = asserts
             if asserts != None:
                 asserts, idx = asserts
@@ -249,6 +249,8 @@ class Verifier:
             max_end_idx = 0
             for transition in all_possible_transitions:
                 # Each transition will contain a list of rectangles and their corresponding indexes in the original list
+                if len(transition) != 6:
+                    pp(("weird trans", transition))
                 transit_agent_idx, src_mode, dest_mode, next_init, idx, path = transition
                 start_idx, end_idx = idx[0], idx[-1]
 
@@ -274,7 +276,7 @@ class Verifier:
                         next_node_init[agent_idx] = next_init
                     else:
                         next_node_init[agent_idx] = [[truncated_trace[agent_idx][0][1:], truncated_trace[agent_idx][1][1:]]]
-                    if agent_idx != transit_agent_idx:
+                        pp(("infer init", agent_idx, next_node_init[agent_idx], truncated_trace[agent_idx][:8]))
                         next_node_trace[agent_idx] = truncated_trace[agent_idx]
 
                 tmp = AnalysisTreeNode(
diff --git a/verse/scenario/scenario.py b/verse/scenario/scenario.py
index 35d7eb4d..ef2646bd 100644
--- a/verse/scenario/scenario.py
+++ b/verse/scenario/scenario.py
@@ -11,11 +11,12 @@ import numpy as np
 
 from verse.agents.base_agent import BaseAgent
 from verse.analysis.dryvr import _EPSILON
-from verse.analysis.incremental import CachedRTTrans, CachedSegment
+from verse.analysis.incremental import CachedRTTrans, CachedSegment, combine_all
 from verse.analysis.simulator import PathDiffs
 from verse.automaton import GuardExpressionAst, ResetExpression
 from verse.analysis import Simulator, Verifier, AnalysisTreeNode, AnalysisTree
 from verse.analysis.utils import sample_rect
+from verse.parser import astunparser
 from verse.parser.parser import ControllerIR, ModePath, find
 from verse.sensor.base_sensor import BaseSensor
 from verse.map.lane_map import LaneMap
@@ -307,7 +308,7 @@ class Scenario:
         # The reset_list here are all the resets for a single transition. Need to evaluate each of them
         # and then combine them together
         for reset_tuple in reset_list:
-            reset, disc_var_dict, cont_var_dict, _ = reset_tuple
+            reset, disc_var_dict, cont_var_dict, _, _p = reset_tuple
             reset_variable = reset.var
             expr = reset.expr
             # First get the transition destinations
@@ -457,7 +458,7 @@ class Scenario:
                 for path in agent_paths:
                     cached_guards[agent_id].append((path, discrete_variable_dict, path_transitions[path.cond]))
 
-        for agent, path in paths:
+        for agent, _idx, path in paths:
             # Get guard
             if len(agent.controller.args) == 0:
                 continue
@@ -496,7 +497,7 @@ class Scenario:
                 break
         return None, dict(transitions), idx
 
-    def get_transition_verify_new(self, cache: Dict[str, CachedRTTrans], paths: PathDiffs, node: AnalysisTreeNode) -> Tuple[Optional[Dict[str, List[str]]], Optional[Dict[str, List[Tuple[str, List[str], List[float]]]]], int]:
+    def get_transition_verify_new(self, cache: Dict[str, CachedRTTrans], paths: PathDiffs, node: AnalysisTreeNode) -> Tuple[Optional[Dict[str, List[str]]], Optional[Dict[str, List[Tuple[str, List[str], List[float]]]]]]:
         lane_map = self.map
 
         # For each agent
@@ -504,28 +505,33 @@ class Scenario:
         cached_guards = defaultdict(list)
 
         if not cache:
-            paths = [(agent, p) for agent in node.agent.values() for p in agent.controller.paths]
+            paths = [(agent, i, p) for agent in node.agent.values() for i, p in enumerate(agent.controller.paths)]
         else:
             if len(paths) == 0:
                 # print(red("full cache"))
-                def trans_close(a: Dict[str, List[float]], b: Dict[str, List[float]]) -> bool:
+                def trans_suit(a: Dict[str, List[List[List[float]]]], b: Dict[str, List[List[List[float]]]]) -> bool:
                     assert set(a.keys()) == set(b.keys())
-                    return all(abs(av - bv) < _EPSILON for aid in a.keys() for av, bv in zip(a[aid], b[aid]))
-
-                _transitions = [trans.transition for seg in cache.values() for trans in seg.transitions]
-                # _transitions = [trans.transition for seg in cache.values() for trans in seg.transitions if trans_close(trans.inits, node.init)]
+                    def transp(a):
+                        return list(map(list, zip(*a)))
+                    def suits(a: List[List[float]], b: List[List[float]]) -> bool:
+                        at, bt = transp(a), transp(b)
+                        return all(al <= bl and ah >= bh for (al, ah), (bl, bh) in zip(at, bt))
+                    return all(suits(av, bv) for aid in a.keys() for av, bv in zip(a[aid], b[aid]))
+
+                # _transitions = [trans.transition for seg in cache.values() for trans in seg.transitions]
+                _transitions = [trans.transition for seg in cache.values() for trans in seg.transitions if trans_suit(trans.inits, node.init)]
                 # pp(("cached trans", _transitions))
                 if len(_transitions) == 0:
-                    return None, None, 0
+                    return None, []
                 transition = min(_transitions)
-                transitions = defaultdict(list)
+                transitions = []
                 for agent_id, seg in cache.items():
                     # TODO: check for asserts
                     for tran in seg.transitions:
                         if tran.transition == transition:
                             # pp(("chosen tran", agent_id, tran))
-                            transitions[agent_id].append((agent_id, tran.mode, tran.dest, tran.reset, tran.reset_idx, tran.paths))
-                return None, dict(transitions)
+                            transitions.append((agent_id, tran.mode, tran.dest, tran.reset, tran.reset_idx, tran.paths))
+                return None, transitions
 
             path_transitions = defaultdict(int)
             for seg in cache.values():
@@ -563,7 +569,10 @@ class Scenario:
                         continue
                     cached_guards[agent_id].append((path, guard_expression, cont_var_updater, copy.deepcopy(discrete_variable_dict), reset, path_transitions[path.cond]))
 
-        for agent, path in paths:
+        for aid, trace in node.trace.items():
+            if len(trace) < 2:
+                pp(("weird state", aid, trace))
+        for agent, idx, path in paths:
             if len(agent.controller.args) == 0:
                 continue
             agent_id = agent.id
@@ -585,7 +594,8 @@ class Scenario:
             agent_guard_dict[agent_id].append(
                 (guard_expression, cont_var_updater, copy.deepcopy(discrete_variable_dict), path))
 
-        trace_length = int(len(list(node.trace.values())[0])/2)
+        trace_length = int(min(len(v) for v in node.trace.values()) // 2)
+        # pp(("trace len", trace_length, {a: len(t) for a, t in node.trace.items()}))
         guard_hits = []
         guard_hit = False
         for idx in range(trace_length):
@@ -599,6 +609,8 @@ class Scenario:
                 if len(agent.controller.args) == 0:
                     continue
                 agent_state, agent_mode, agent_static = state_dict[agent_id]
+                if np.array(agent_state).ndim != 2:
+                    pp(("weird state", agent_id, agent_state))
                 agent_state = agent_state[1:]
                 cont_vars, disc_vars, len_dict = self.sensor.sense(self, agent, state_dict, self.map)
                 resets = defaultdict(list)
@@ -673,8 +685,8 @@ class Scenario:
         for hits, all_agent_state, hit_idx in guard_hits:
             for agent_id, reset_idx, reset_list in hits:
                 # TODO: Need to change this function to handle the new reset expression and then I am done
-                dest_list, reset_rect = self.apply_reset(
-                    node.agent[agent_id], reset_list[:-1], all_agent_state)
+                dest_list, reset_rect = self.apply_reset(node.agent[agent_id], reset_list, all_agent_state)
+                # pp(("dests", dest_list, *[astunparser.unparse(reset[-1].val_veri) for reset in reset_list]))
                 if agent_id not in reset_dict:
                     reset_dict[agent_id] = {}
                 if not dest_list:
@@ -693,9 +705,10 @@ class Scenario:
         for agent in reset_dict:
             for reset_idx in reset_dict[agent]:
                 for dest in reset_dict[agent][reset_idx]:
-                    if list(dest) != list(node.mode[agent]):
-                        resets = tuple(map(list, zip(*reset_dict[agent][reset_idx][dest])))
-                        transition = (agent, node.mode[agent], dest, *resets)
-                        possible_transitions.append(transition)
+                    # reset_data = reset_dict[agent][reset_idx][dest]
+                    reset_data = tuple(map(list, zip(*reset_dict[agent][reset_idx][dest])))
+                    # pp(("resets", reset_data))
+                    transition = (agent, node.mode[agent], dest, *reset_data)
+                    possible_transitions.append(transition)
         # Return result
         return None, possible_transitions
-- 
GitLab