From 729b14a880f9f3ebf6b266ba48fbebff9701d4d1 Mon Sep 17 00:00:00 2001 From: crides <zhuhaoqing@live.cn> Date: Fri, 7 Oct 2022 20:29:53 -0500 Subject: [PATCH] pend --- verse/analysis/incremental.py | 6 ++-- verse/analysis/simulator.py | 2 +- verse/analysis/verifier.py | 12 +++++--- verse/scenario/scenario.py | 57 +++++++++++++++++++++-------------- 4 files changed, 46 insertions(+), 31 deletions(-) diff --git a/verse/analysis/incremental.py b/verse/analysis/incremental.py index 291be962..0e16d548 100644 --- a/verse/analysis/incremental.py +++ b/verse/analysis/incremental.py @@ -63,11 +63,11 @@ def to_simulate(old_agents: Dict[str, BaseAgent], new_agents: Dict[str, BaseAgen raise NotImplementedError("different variable outputs") for var, old_paths in old_grouped.items(): new_paths = new_grouped[var] - for old, new in itertools.zip_longest(old_paths, new_paths): + for i, (old, new) in enumerate(itertools.zip_longest(old_paths, new_paths)): if new == None: removed_paths.append(old) elif old.cond != new.cond: - added_paths.append((new_agent, new)) + added_paths.append((new_agent, i, new)) elif old.val != new.val: reset_changed_paths.append(new) new_cache = {} @@ -204,8 +204,8 @@ class ReachTubeCache: def add_tube(self, agent_id: str, init: Dict[str, List[List[float]]], node: AnalysisTreeNode, transit_agents: List[str], transition, trans_ind: int, run_num: int): key = (agent_id,) + tuple(node.mode[agent_id]) tree = self.cache[key] + pp(('add seg', agent_id, node.mode[agent_id], init)) assert_hits = node.assert_hits or {} - # pp(('add seg', agent_id, *node.mode[agent_id], *init)) init = list(map(tuple, zip(*init[agent_id]))) for i, (low, high) in enumerate(init): if i == len(init) - 1: diff --git a/verse/analysis/simulator.py b/verse/analysis/simulator.py index eb65159c..99781b24 100644 --- a/verse/analysis/simulator.py +++ b/verse/analysis/simulator.py @@ -13,7 +13,7 @@ pp = functools.partial(pprint.pprint, compact=True, width=130) # from verse.agents.base_agent import BaseAgent from verse.analysis.analysis_tree import AnalysisTreeNode, AnalysisTree -PathDiffs = List[Tuple[BaseAgent, ModePath]] +PathDiffs = List[Tuple[BaseAgent, int, ModePath]] def red(s): return "\x1b[31m" + s + "\x1b[0m" #]] diff --git a/verse/analysis/verifier.py b/verse/analysis/verifier.py index 76dc5169..bfb144d2 100644 --- a/verse/analysis/verifier.py +++ b/verse/analysis/verifier.py @@ -140,6 +140,7 @@ class Verifier: while verification_queue != []: node: AnalysisTreeNode = verification_queue.pop(0) combined_inits = {a: combine_all(inits) for a, inits in node.init.items()} + print() pp(("start sim", node.start_time, {a: (*node.mode[a], *combined_inits[a]) for a in node.mode})) remain_time = round(time_horizon - node.start_time, 10) if remain_time <= 0: @@ -150,9 +151,8 @@ class Verifier: for agent_id in node.agent: mode = node.mode[agent_id] inits = node.init[agent_id] - init = combined_inits[agent_id] if self.config.incremental: - cached = self.trans_cache.check_hit(agent_id, mode, init) + cached = self.trans_cache.check_hit(agent_id, mode, combined_inits[agent_id]) if cached != None: cached_tubes[agent_id] = cached if agent_id not in node.trace: @@ -200,7 +200,7 @@ class Verifier: trace = np.array(cur_bloated_tube) trace[:, 0] += node.start_time node.trace[agent_id] = trace.tolist() - pp(("cached_segments", cached_tubes.keys())) + pp(("cached tubes", cached_tubes.keys())) node_ids = list(set((s.run_num, s.node_id) for s in cached_tubes.values())) # assert len(node_ids) <= 1, f"{node_ids}" new_cache, paths_to_sim = {}, [] @@ -214,7 +214,7 @@ class Verifier: # Get all possible transitions to next mode asserts, all_possible_transitions = transition_graph.get_transition_verify_new(new_cache, paths_to_sim, node) - pp(("transitions:", all_possible_transitions)) + pp(("transitions:", [(t[0], t[2]) for t in all_possible_transitions])) node.assert_hits = asserts if asserts != None: asserts, idx = asserts @@ -249,6 +249,8 @@ class Verifier: max_end_idx = 0 for transition in all_possible_transitions: # Each transition will contain a list of rectangles and their corresponding indexes in the original list + if len(transition) != 6: + pp(("weird trans", transition)) transit_agent_idx, src_mode, dest_mode, next_init, idx, path = transition start_idx, end_idx = idx[0], idx[-1] @@ -274,7 +276,7 @@ class Verifier: next_node_init[agent_idx] = next_init else: next_node_init[agent_idx] = [[truncated_trace[agent_idx][0][1:], truncated_trace[agent_idx][1][1:]]] - if agent_idx != transit_agent_idx: + pp(("infer init", agent_idx, next_node_init[agent_idx], truncated_trace[agent_idx][:8])) next_node_trace[agent_idx] = truncated_trace[agent_idx] tmp = AnalysisTreeNode( diff --git a/verse/scenario/scenario.py b/verse/scenario/scenario.py index 35d7eb4d..ef2646bd 100644 --- a/verse/scenario/scenario.py +++ b/verse/scenario/scenario.py @@ -11,11 +11,12 @@ import numpy as np from verse.agents.base_agent import BaseAgent from verse.analysis.dryvr import _EPSILON -from verse.analysis.incremental import CachedRTTrans, CachedSegment +from verse.analysis.incremental import CachedRTTrans, CachedSegment, combine_all from verse.analysis.simulator import PathDiffs from verse.automaton import GuardExpressionAst, ResetExpression from verse.analysis import Simulator, Verifier, AnalysisTreeNode, AnalysisTree from verse.analysis.utils import sample_rect +from verse.parser import astunparser from verse.parser.parser import ControllerIR, ModePath, find from verse.sensor.base_sensor import BaseSensor from verse.map.lane_map import LaneMap @@ -307,7 +308,7 @@ class Scenario: # The reset_list here are all the resets for a single transition. Need to evaluate each of them # and then combine them together for reset_tuple in reset_list: - reset, disc_var_dict, cont_var_dict, _ = reset_tuple + reset, disc_var_dict, cont_var_dict, _, _p = reset_tuple reset_variable = reset.var expr = reset.expr # First get the transition destinations @@ -457,7 +458,7 @@ class Scenario: for path in agent_paths: cached_guards[agent_id].append((path, discrete_variable_dict, path_transitions[path.cond])) - for agent, path in paths: + for agent, _idx, path in paths: # Get guard if len(agent.controller.args) == 0: continue @@ -496,7 +497,7 @@ class Scenario: break return None, dict(transitions), idx - def get_transition_verify_new(self, cache: Dict[str, CachedRTTrans], paths: PathDiffs, node: AnalysisTreeNode) -> Tuple[Optional[Dict[str, List[str]]], Optional[Dict[str, List[Tuple[str, List[str], List[float]]]]], int]: + def get_transition_verify_new(self, cache: Dict[str, CachedRTTrans], paths: PathDiffs, node: AnalysisTreeNode) -> Tuple[Optional[Dict[str, List[str]]], Optional[Dict[str, List[Tuple[str, List[str], List[float]]]]]]: lane_map = self.map # For each agent @@ -504,28 +505,33 @@ class Scenario: cached_guards = defaultdict(list) if not cache: - paths = [(agent, p) for agent in node.agent.values() for p in agent.controller.paths] + paths = [(agent, i, p) for agent in node.agent.values() for i, p in enumerate(agent.controller.paths)] else: if len(paths) == 0: # print(red("full cache")) - def trans_close(a: Dict[str, List[float]], b: Dict[str, List[float]]) -> bool: + def trans_suit(a: Dict[str, List[List[List[float]]]], b: Dict[str, List[List[List[float]]]]) -> bool: assert set(a.keys()) == set(b.keys()) - return all(abs(av - bv) < _EPSILON for aid in a.keys() for av, bv in zip(a[aid], b[aid])) - - _transitions = [trans.transition for seg in cache.values() for trans in seg.transitions] - # _transitions = [trans.transition for seg in cache.values() for trans in seg.transitions if trans_close(trans.inits, node.init)] + def transp(a): + return list(map(list, zip(*a))) + def suits(a: List[List[float]], b: List[List[float]]) -> bool: + at, bt = transp(a), transp(b) + return all(al <= bl and ah >= bh for (al, ah), (bl, bh) in zip(at, bt)) + return all(suits(av, bv) for aid in a.keys() for av, bv in zip(a[aid], b[aid])) + + # _transitions = [trans.transition for seg in cache.values() for trans in seg.transitions] + _transitions = [trans.transition for seg in cache.values() for trans in seg.transitions if trans_suit(trans.inits, node.init)] # pp(("cached trans", _transitions)) if len(_transitions) == 0: - return None, None, 0 + return None, [] transition = min(_transitions) - transitions = defaultdict(list) + transitions = [] for agent_id, seg in cache.items(): # TODO: check for asserts for tran in seg.transitions: if tran.transition == transition: # pp(("chosen tran", agent_id, tran)) - transitions[agent_id].append((agent_id, tran.mode, tran.dest, tran.reset, tran.reset_idx, tran.paths)) - return None, dict(transitions) + transitions.append((agent_id, tran.mode, tran.dest, tran.reset, tran.reset_idx, tran.paths)) + return None, transitions path_transitions = defaultdict(int) for seg in cache.values(): @@ -563,7 +569,10 @@ class Scenario: continue cached_guards[agent_id].append((path, guard_expression, cont_var_updater, copy.deepcopy(discrete_variable_dict), reset, path_transitions[path.cond])) - for agent, path in paths: + for aid, trace in node.trace.items(): + if len(trace) < 2: + pp(("weird state", aid, trace)) + for agent, idx, path in paths: if len(agent.controller.args) == 0: continue agent_id = agent.id @@ -585,7 +594,8 @@ class Scenario: agent_guard_dict[agent_id].append( (guard_expression, cont_var_updater, copy.deepcopy(discrete_variable_dict), path)) - trace_length = int(len(list(node.trace.values())[0])/2) + trace_length = int(min(len(v) for v in node.trace.values()) // 2) + # pp(("trace len", trace_length, {a: len(t) for a, t in node.trace.items()})) guard_hits = [] guard_hit = False for idx in range(trace_length): @@ -599,6 +609,8 @@ class Scenario: if len(agent.controller.args) == 0: continue agent_state, agent_mode, agent_static = state_dict[agent_id] + if np.array(agent_state).ndim != 2: + pp(("weird state", agent_id, agent_state)) agent_state = agent_state[1:] cont_vars, disc_vars, len_dict = self.sensor.sense(self, agent, state_dict, self.map) resets = defaultdict(list) @@ -673,8 +685,8 @@ class Scenario: for hits, all_agent_state, hit_idx in guard_hits: for agent_id, reset_idx, reset_list in hits: # TODO: Need to change this function to handle the new reset expression and then I am done - dest_list, reset_rect = self.apply_reset( - node.agent[agent_id], reset_list[:-1], all_agent_state) + dest_list, reset_rect = self.apply_reset(node.agent[agent_id], reset_list, all_agent_state) + # pp(("dests", dest_list, *[astunparser.unparse(reset[-1].val_veri) for reset in reset_list])) if agent_id not in reset_dict: reset_dict[agent_id] = {} if not dest_list: @@ -693,9 +705,10 @@ class Scenario: for agent in reset_dict: for reset_idx in reset_dict[agent]: for dest in reset_dict[agent][reset_idx]: - if list(dest) != list(node.mode[agent]): - resets = tuple(map(list, zip(*reset_dict[agent][reset_idx][dest]))) - transition = (agent, node.mode[agent], dest, *resets) - possible_transitions.append(transition) + # reset_data = reset_dict[agent][reset_idx][dest] + reset_data = tuple(map(list, zip(*reset_dict[agent][reset_idx][dest]))) + # pp(("resets", reset_data)) + transition = (agent, node.mode[agent], dest, *reset_data) + possible_transitions.append(transition) # Return result return None, possible_transitions -- GitLab