diff --git a/verse/analysis/incremental.py b/verse/analysis/incremental.py index 99f1d429511c9a38a793226afe049d79ea5189c0..b42ef510c75cf0599cfb5bee0f3513a943d85ac6 100644 --- a/verse/analysis/incremental.py +++ b/verse/analysis/incremental.py @@ -213,7 +213,7 @@ class ReachTubeCache: def add_tube(self, agent_id: str, init: Dict[str, List[List[float]]], node: AnalysisTreeNode, transit_agents: List[str], transition, trans_ind: int, run_num: int): key = (agent_id,) + tuple(node.mode[agent_id]) tree = self.cache[key] - pp(('add seg', agent_id, node.mode[agent_id], init)) + # pp(('add seg', agent_id, node.mode[agent_id], init)) assert_hits = node.assert_hits or {} init = list(map(tuple, zip(*init[agent_id]))) for i, (low, high) in enumerate(init): @@ -251,7 +251,7 @@ class ReachTubeCache: def num_trans_suit(e: CachedRTTrans) -> int: return sum(1 if reach_trans_suit(t.inits, inits) else 0 for t in e.transitions) entries = list(sorted([(e, -num_trans_suit(e)) for e in entries], key=lambda p: p[1])) - pp(("check hit entries", len(entries), entries[0][1])) + # pp(("check hit entries", len(entries), entries[0][1])) return entries[0][0] @staticmethod diff --git a/verse/analysis/simulator.py b/verse/analysis/simulator.py index 99781b2471ca718caa48c7c08d2377f4399a6feb..aa8bb74c418028e2a445f6e51cea8f6ad7a76896 100644 --- a/verse/analysis/simulator.py +++ b/verse/analysis/simulator.py @@ -7,6 +7,7 @@ import pprint from verse.agents.base_agent import BaseAgent from verse.analysis.incremental import SimTraceCache, convert_sim_trans, to_simulate +from verse.analysis.utils import dedup from verse.parser.parser import ModePath, find pp = functools.partial(pprint.pprint, compact=True, width=130) @@ -132,16 +133,7 @@ class Simulator: if agent_id in cached_segments: cached_segments[agent_id].transitions.extend(convert_sim_trans(agent_id, transit_agents, node.init, transition, transition_idx)) pre_len = len(cached_segments[agent_id].transitions) - def dedup(l): - o = [] - for i in l: - for j in o: - if i.disc == j.disc and i.cont == j.cont: - break - else: - o.append(i) - return o - cached_segments[agent_id].transitions = dedup(cached_segments[agent_id].transitions) + cached_segments[agent_id].transitions = dedup(cached_segments[agent_id].transitions, lambda i: (i.disc, i.cont)) # pp(("dedup!", pre_len, len(cached_segments[agent_id].transitions))) else: self.cache.add_segment(agent_id, node, transit_agents, full_traces[agent_id], transition, transition_idx, run_num) diff --git a/verse/analysis/verifier.py b/verse/analysis/verifier.py index 30689e130f193912e6fef763ed6fa6b0576d68a8..99ff0d063a34a54708e88bd2bc15a14c3eb8c892 100644 --- a/verse/analysis/verifier.py +++ b/verse/analysis/verifier.py @@ -141,8 +141,8 @@ class Verifier: while verification_queue != []: node: AnalysisTreeNode = verification_queue.pop(0) combined_inits = {a: combine_all(inits) for a, inits in node.init.items()} - print() - pp(("start sim", node.start_time, {a: (*node.mode[a], *combined_inits[a]) for a in node.mode})) + # print() + # pp(("start sim", node.start_time, {a: (*node.mode[a], *combined_inits[a]) for a in node.mode})) remain_time = round(time_horizon - node.start_time, 10) if remain_time <= 0: continue @@ -155,7 +155,7 @@ class Verifier: combined = combine_all(inits) if self.config.incremental: cached = self.trans_cache.check_hit(agent_id, mode, combined, node.init) - pp(("check hit", agent_id, mode, combined)) + # pp(("check hit", agent_id, mode, combined)) if cached != None: cached_tubes[agent_id] = cached if agent_id not in node.trace: @@ -165,7 +165,7 @@ class Verifier: # trace[:,0] += node.start_time # node.trace[agent_id] = trace.tolist() if reachability_method == "DRYVR": - pp(('tube', agent_id, mode, inits)) + # pp(('tube', agent_id, mode, inits)) cur_bloated_tube = self.calculate_full_bloated_tube(agent_id, mode, inits, @@ -203,7 +203,7 @@ class Verifier: trace = np.array(cur_bloated_tube) trace[:, 0] += node.start_time node.trace[agent_id] = trace.tolist() - pp(("cached tubes", cached_tubes.keys())) + # pp(("cached tubes", cached_tubes.keys())) node_ids = list(set((s.run_num, s.node_id) for s in cached_tubes.values())) # assert len(node_ids) <= 1, f"{node_ids}" new_cache, paths_to_sim = {}, [] @@ -213,11 +213,11 @@ class Verifier: old_node = find(past_runs[old_run_num].nodes, lambda n: n.id == old_node_id) assert old_node != None new_cache, paths_to_sim = to_simulate(old_node.agent, node.agent, cached_tubes) - pp(("to sim", new_cache.keys(), len(paths_to_sim))) + # pp(("to sim", new_cache.keys(), len(paths_to_sim))) # Get all possible transitions to next mode asserts, all_possible_transitions = transition_graph.get_transition_verify_new(new_cache, paths_to_sim, node) - pp(("transitions:", [(t[0], t[2]) for t in all_possible_transitions])) + # pp(("transitions:", [(t[0], t[2]) for t in all_possible_transitions])) node.assert_hits = asserts if asserts != None: asserts, idx = asserts @@ -227,7 +227,7 @@ class Verifier: transit_map = {k: list(l) for k, l in itertools.groupby(all_possible_transitions, key=lambda p:p[0])} transit_agents = transit_map.keys() - pp(("transit agents", transit_agents)) + # pp(("transit agents", transit_agents)) if self.config.incremental and len(all_possible_transitions) > 0: transit_ind = max(l[-2][-1] for l in all_possible_transitions) for agent_id in node.agent: @@ -236,15 +236,15 @@ class Verifier: cached_tubes[agent_id].transitions.extend(convert_reach_trans(agent_id, transit_agents, node.init, transition, transit_ind)) pre_len = len(cached_tubes[agent_id].transitions) cached_tubes[agent_id].transitions = dedup(cached_tubes[agent_id].transitions, lambda i: (i.mode, i.dest, i.inits)) - pp(("dedup!", pre_len, len(cached_tubes[agent_id].transitions))) + # pp(("dedup!", pre_len, len(cached_tubes[agent_id].transitions))) else: self.trans_cache.add_tube(agent_id, combined_inits, node, transit_agents, transition, transit_ind, run_num) max_end_idx = 0 for transition in all_possible_transitions: # Each transition will contain a list of rectangles and their corresponding indexes in the original list - if len(transition) != 6: - pp(("weird trans", transition)) + # if len(transition) != 6: + # pp(("weird trans", transition)) transit_agent_idx, src_mode, dest_mode, next_init, idx, path = transition start_idx, end_idx = idx[0], idx[-1] @@ -270,7 +270,7 @@ class Verifier: next_node_init[agent_idx] = next_init else: next_node_init[agent_idx] = [[truncated_trace[agent_idx][0][1:], truncated_trace[agent_idx][1][1:]]] - pp(("infer init", agent_idx, next_node_init[agent_idx])) + # pp(("infer init", agent_idx, next_node_init[agent_idx])) next_node_trace[agent_idx] = truncated_trace[agent_idx] tmp = AnalysisTreeNode( diff --git a/verse/scenario/scenario.py b/verse/scenario/scenario.py index 48d30ffe590e2d7ba5ab8aba9555283719b4f292..7134f835f7ce882b6d6e112ae15f5b1618ed45c8 100644 --- a/verse/scenario/scenario.py +++ b/verse/scenario/scenario.py @@ -502,8 +502,8 @@ class Scenario: # print(red("full cache")) # _transitions = [trans.transition for seg in cache.values() for trans in seg.transitions] - _transitions = [trans.transition for seg in cache.values() for trans in seg.transitions if reach_trans_suit(trans.inits, node.init) or pp(("not suit", trans.inits, node.init))] - pp(("cached trans", len(_transitions))) + _transitions = [trans.transition for seg in cache.values() for trans in seg.transitions if reach_trans_suit(trans.inits, node.init)]# or pp(("not suit", trans.inits, node.init))] + # pp(("cached trans", len(_transitions))) if len(_transitions) == 0: return None, [] transition = min(_transitions) @@ -540,9 +540,9 @@ class Scenario: continue cached_guards[agent_id].append((path, guard_expression, cont_var_updater, copy.deepcopy(discrete_variable_dict), reset, path_transitions[path.cond])) - for aid, trace in node.trace.items(): - if len(trace) < 2: - pp(("weird state", aid, trace)) + # for aid, trace in node.trace.items(): + # if len(trace) < 2: + # pp(("weird state", aid, trace)) for agent, idx, path in paths: if len(agent.controller.args) == 0: continue @@ -580,8 +580,8 @@ class Scenario: if len(agent.controller.args) == 0: continue agent_state, agent_mode, agent_static = state_dict[agent_id] - if np.array(agent_state).ndim != 2: - pp(("weird state", agent_id, agent_state)) + # if np.array(agent_state).ndim != 2: + # pp(("weird state", agent_id, agent_state)) agent_state = agent_state[1:] cont_vars, disc_vars, len_dict = self.sensor.sense(self, agent, state_dict, self.map) resets = defaultdict(list)