From d8a5a6d7cfb28fa993e14b79bc115c924446cdd1 Mon Sep 17 00:00:00 2001 From: crides <zhuhaoqing@live.cn> Date: Sun, 9 Oct 2022 18:37:47 -0500 Subject: [PATCH] demo7 reach working --- verse/analysis/incremental.py | 38 +++++++++++++++++++++------ verse/analysis/utils.py | 11 ++++++++ verse/analysis/verifier.py | 20 +++++--------- verse/scenario/scenario.py | 49 +++++++---------------------------- 4 files changed, 58 insertions(+), 60 deletions(-) diff --git a/verse/analysis/incremental.py b/verse/analysis/incremental.py index 0e16d548..99f1d429 100644 --- a/verse/analysis/incremental.py +++ b/verse/analysis/incremental.py @@ -103,6 +103,15 @@ def combine_all(inits): return [[min(a) for a in np.transpose(np.array(inits)[:, 0])], [max(a) for a in np.transpose(np.array(inits)[:, 1])]] +def reach_trans_suit(a: Dict[str, List[List[List[float]]]], b: Dict[str, List[List[List[float]]]]) -> bool: + assert set(a.keys()) == set(b.keys()) + def transp(a): + return list(map(list, zip(*a))) + def suits(a: List[List[float]], b: List[List[float]]) -> bool: + at, bt = transp(a), transp(b) + return all(al <= bl and ah >= bh for (al, ah), (bl, bh) in zip(at, bt)) + return all(suits(av, bv) for aid in a.keys() for av, bv in zip(a[aid], b[aid])) + @dataclass class CachedTube: tube: List[List[List[float]]] @@ -219,18 +228,31 @@ class ReachTubeCache: tree = next_level_tree raise Exception("???") - def check_hit(self, agent_id: str, mode: Tuple[str], init: List[float]) -> Optional[CachedRTTrans]: + @staticmethod + def query_cont(tree: IntervalTree, cont: List[Tuple[float, float]]) -> List[CachedRTTrans]: + assert isinstance(tree, IntervalTree) + low, high = cont[0] + next_level_entries = [t.data for t in tree[low:high + _EPSILON] if t.begin <= low and high <= t.end] + if len(cont) == 1: + for ent in next_level_entries: + assert isinstance(ent, CachedRTTrans) + return next_level_entries + else: + return [ent for t in next_level_entries for ent in ReachTubeCache.query_cont(t, cont[1:])] + + def check_hit(self, agent_id: str, mode: Tuple[str], init: List[float], inits: Dict[str, List[List[List[float]]]]) -> Optional[CachedRTTrans]: key = (agent_id,) + tuple(mode) if key not in self.cache: return None tree = self.cache[key] - for low, high in list(map(tuple, zip(*init))): - next_level_entries = [t for t in tree[low:high + _EPSILON] if t.begin <= low and high <= t.end] - if len(next_level_entries) == 0: - return None - tree = min(next_level_entries, key=lambda e: low - e.begin + e.end - high).data - assert isinstance(tree, CachedRTTrans) - return tree + entries = self.query_cont(tree, list(map(tuple, zip(*init)))) + if len(entries) == 0: + return None + def num_trans_suit(e: CachedRTTrans) -> int: + return sum(1 if reach_trans_suit(t.inits, inits) else 0 for t in e.transitions) + entries = list(sorted([(e, -num_trans_suit(e)) for e in entries], key=lambda p: p[1])) + pp(("check hit entries", len(entries), entries[0][1])) + return entries[0][0] @staticmethod def iter_tree(tree, depth: int) -> List[List[float]]: diff --git a/verse/analysis/utils.py b/verse/analysis/utils.py index 47da88c6..6778df8c 100644 --- a/verse/analysis/utils.py +++ b/verse/analysis/utils.py @@ -382,3 +382,14 @@ def sample_rect(rect: List[List[float]]) -> List[float]: # for i in range(len(rect[0])): res = np.random.uniform(rect[0], rect[1]).tolist() return res + +def dedup(l, f=lambda a:a): + o = [] + l = [(i, f(i)) for i in l] + for i, k in l: + for _, k_ in o: + if k == k_: + break + else: + o.append((i, k)) + return [i for i, _ in o] diff --git a/verse/analysis/verifier.py b/verse/analysis/verifier.py index bfb144d2..30689e13 100644 --- a/verse/analysis/verifier.py +++ b/verse/analysis/verifier.py @@ -11,6 +11,7 @@ from verse.analysis.analysis_tree import AnalysisTreeNode, AnalysisTree from verse.analysis.dryvr import calc_bloated_tube, SIMTRACENUM from verse.analysis.mixmonotone import calculate_bloated_tube_mixmono_cont, calculate_bloated_tube_mixmono_disc from verse.analysis.incremental import ReachTubeCache, TubeCache, convert_reach_trans, to_simulate, combine_all +from verse.analysis.utils import dedup from verse.parser.parser import find pp = functools.partial(pprint.pprint, compact=True, width=130) @@ -151,8 +152,10 @@ class Verifier: for agent_id in node.agent: mode = node.mode[agent_id] inits = node.init[agent_id] + combined = combine_all(inits) if self.config.incremental: - cached = self.trans_cache.check_hit(agent_id, mode, combined_inits[agent_id]) + cached = self.trans_cache.check_hit(agent_id, mode, combined, node.init) + pp(("check hit", agent_id, mode, combined)) if cached != None: cached_tubes[agent_id] = cached if agent_id not in node.trace: @@ -232,17 +235,8 @@ class Verifier: if agent_id in cached_tubes: cached_tubes[agent_id].transitions.extend(convert_reach_trans(agent_id, transit_agents, node.init, transition, transit_ind)) pre_len = len(cached_tubes[agent_id].transitions) - def dedup(l): - o = [] - for i in l: - for j in o: - if i.mode == j.mode and i.dest == j.dest: - break - else: - o.append(i) - return o - cached_tubes[agent_id].transitions = dedup(cached_tubes[agent_id].transitions) - # pp(("dedup!", pre_len, len(cached_tubes[agent_id].transitions))) + cached_tubes[agent_id].transitions = dedup(cached_tubes[agent_id].transitions, lambda i: (i.mode, i.dest, i.inits)) + pp(("dedup!", pre_len, len(cached_tubes[agent_id].transitions))) else: self.trans_cache.add_tube(agent_id, combined_inits, node, transit_agents, transition, transit_ind, run_num) @@ -276,7 +270,7 @@ class Verifier: next_node_init[agent_idx] = next_init else: next_node_init[agent_idx] = [[truncated_trace[agent_idx][0][1:], truncated_trace[agent_idx][1][1:]]] - pp(("infer init", agent_idx, next_node_init[agent_idx], truncated_trace[agent_idx][:8])) + pp(("infer init", agent_idx, next_node_init[agent_idx])) next_node_trace[agent_idx] = truncated_trace[agent_idx] tmp = AnalysisTreeNode( diff --git a/verse/scenario/scenario.py b/verse/scenario/scenario.py index ef2646bd..48d30ffe 100644 --- a/verse/scenario/scenario.py +++ b/verse/scenario/scenario.py @@ -11,11 +11,11 @@ import numpy as np from verse.agents.base_agent import BaseAgent from verse.analysis.dryvr import _EPSILON -from verse.analysis.incremental import CachedRTTrans, CachedSegment, combine_all +from verse.analysis.incremental import CachedRTTrans, CachedSegment, combine_all, reach_trans_suit from verse.analysis.simulator import PathDiffs from verse.automaton import GuardExpressionAst, ResetExpression from verse.analysis import Simulator, Verifier, AnalysisTreeNode, AnalysisTree -from verse.analysis.utils import sample_rect +from verse.analysis.utils import dedup, sample_rect from verse.parser import astunparser from verse.parser.parser import ControllerIR, ModePath, find from verse.sensor.base_sensor import BaseSensor @@ -444,16 +444,7 @@ class Scenario: if len(agent.controller.args) == 0: continue state_dict = {aid: (node.trace[aid][0], node.mode[aid], node.static[aid]) for aid in node.agent} - def dedup(l): - o = [] - for i in l: - for j in o: - if i.var == j.var and i.cond == j.cond and i.val == j.val: - break - else: - o.append(i) - return o - agent_paths = dedup([p for tran in segment.transitions for p in tran.paths]) + agent_paths = dedup([p for tran in segment.transitions for p in tran.paths], lambda i: (i.var, i.cond, i.val)) cont_var_dict_template, discrete_variable_dict, len_dict = self.sensor.sense(self, agent, state_dict, self.map) for path in agent_paths: cached_guards[agent_id].append((path, discrete_variable_dict, path_transitions[path.cond])) @@ -509,28 +500,16 @@ class Scenario: else: if len(paths) == 0: # print(red("full cache")) - def trans_suit(a: Dict[str, List[List[List[float]]]], b: Dict[str, List[List[List[float]]]]) -> bool: - assert set(a.keys()) == set(b.keys()) - def transp(a): - return list(map(list, zip(*a))) - def suits(a: List[List[float]], b: List[List[float]]) -> bool: - at, bt = transp(a), transp(b) - return all(al <= bl and ah >= bh for (al, ah), (bl, bh) in zip(at, bt)) - return all(suits(av, bv) for aid in a.keys() for av, bv in zip(a[aid], b[aid])) # _transitions = [trans.transition for seg in cache.values() for trans in seg.transitions] - _transitions = [trans.transition for seg in cache.values() for trans in seg.transitions if trans_suit(trans.inits, node.init)] - # pp(("cached trans", _transitions)) + _transitions = [trans.transition for seg in cache.values() for trans in seg.transitions if reach_trans_suit(trans.inits, node.init) or pp(("not suit", trans.inits, node.init))] + pp(("cached trans", len(_transitions))) if len(_transitions) == 0: return None, [] transition = min(_transitions) - transitions = [] - for agent_id, seg in cache.items(): - # TODO: check for asserts - for tran in seg.transitions: - if tran.transition == transition: - # pp(("chosen tran", agent_id, tran)) - transitions.append((agent_id, tran.mode, tran.dest, tran.reset, tran.reset_idx, tran.paths)) + transitions = [(agent_id, tran) for agent_id, seg in cache.items() for tran in seg.transitions if tran.transition == transition] + # TODO: check for asserts + transitions = [(aid, tran.mode, tran.dest, tran.reset, tran.reset_idx, tran.paths) for aid, tran in dedup(transitions, lambda p: (p[0], p[1].mode, p[1].dest))] return None, transitions path_transitions = defaultdict(int) @@ -543,16 +522,8 @@ class Scenario: if len(agent.controller.args) == 0: continue state_dict = {aid: (node.trace[aid][0], node.mode[aid], node.static[aid]) for aid in node.agent} - def dedup(l): - o = [] - for i in l: - for j in o: - if i.var == j.var and i.cond == j.cond and i.val == j.val: - break - else: - o.append(i) - return o - agent_paths = dedup([p for tran in segment.transitions for p in tran.paths]) + + agent_paths = dedup([p for tran in segment.transitions for p in tran.paths], lambda i: (i.var, i.cond, i.val)) for path in agent_paths: cont_var_dict_template, discrete_variable_dict, length_dict = self.sensor.sense( self, agent, state_dict, self.map) -- GitLab