Skip to content
Snippets Groups Projects
Commit 21445a2b authored by crides's avatar crides
Browse files

no trace

parent d8a5a6d7
No related branches found
No related tags found
1 merge request!9Tutorial
......@@ -213,7 +213,7 @@ class ReachTubeCache:
def add_tube(self, agent_id: str, init: Dict[str, List[List[float]]], node: AnalysisTreeNode, transit_agents: List[str], transition, trans_ind: int, run_num: int):
key = (agent_id,) + tuple(node.mode[agent_id])
tree = self.cache[key]
pp(('add seg', agent_id, node.mode[agent_id], init))
# pp(('add seg', agent_id, node.mode[agent_id], init))
assert_hits = node.assert_hits or {}
init = list(map(tuple, zip(*init[agent_id])))
for i, (low, high) in enumerate(init):
......@@ -251,7 +251,7 @@ class ReachTubeCache:
def num_trans_suit(e: CachedRTTrans) -> int:
return sum(1 if reach_trans_suit(t.inits, inits) else 0 for t in e.transitions)
entries = list(sorted([(e, -num_trans_suit(e)) for e in entries], key=lambda p: p[1]))
pp(("check hit entries", len(entries), entries[0][1]))
# pp(("check hit entries", len(entries), entries[0][1]))
return entries[0][0]
@staticmethod
......
......@@ -7,6 +7,7 @@ import pprint
from verse.agents.base_agent import BaseAgent
from verse.analysis.incremental import SimTraceCache, convert_sim_trans, to_simulate
from verse.analysis.utils import dedup
from verse.parser.parser import ModePath, find
pp = functools.partial(pprint.pprint, compact=True, width=130)
......@@ -132,16 +133,7 @@ class Simulator:
if agent_id in cached_segments:
cached_segments[agent_id].transitions.extend(convert_sim_trans(agent_id, transit_agents, node.init, transition, transition_idx))
pre_len = len(cached_segments[agent_id].transitions)
def dedup(l):
o = []
for i in l:
for j in o:
if i.disc == j.disc and i.cont == j.cont:
break
else:
o.append(i)
return o
cached_segments[agent_id].transitions = dedup(cached_segments[agent_id].transitions)
cached_segments[agent_id].transitions = dedup(cached_segments[agent_id].transitions, lambda i: (i.disc, i.cont))
# pp(("dedup!", pre_len, len(cached_segments[agent_id].transitions)))
else:
self.cache.add_segment(agent_id, node, transit_agents, full_traces[agent_id], transition, transition_idx, run_num)
......
......@@ -141,8 +141,8 @@ class Verifier:
while verification_queue != []:
node: AnalysisTreeNode = verification_queue.pop(0)
combined_inits = {a: combine_all(inits) for a, inits in node.init.items()}
print()
pp(("start sim", node.start_time, {a: (*node.mode[a], *combined_inits[a]) for a in node.mode}))
# print()
# pp(("start sim", node.start_time, {a: (*node.mode[a], *combined_inits[a]) for a in node.mode}))
remain_time = round(time_horizon - node.start_time, 10)
if remain_time <= 0:
continue
......@@ -155,7 +155,7 @@ class Verifier:
combined = combine_all(inits)
if self.config.incremental:
cached = self.trans_cache.check_hit(agent_id, mode, combined, node.init)
pp(("check hit", agent_id, mode, combined))
# pp(("check hit", agent_id, mode, combined))
if cached != None:
cached_tubes[agent_id] = cached
if agent_id not in node.trace:
......@@ -165,7 +165,7 @@ class Verifier:
# trace[:,0] += node.start_time
# node.trace[agent_id] = trace.tolist()
if reachability_method == "DRYVR":
pp(('tube', agent_id, mode, inits))
# pp(('tube', agent_id, mode, inits))
cur_bloated_tube = self.calculate_full_bloated_tube(agent_id,
mode,
inits,
......@@ -203,7 +203,7 @@ class Verifier:
trace = np.array(cur_bloated_tube)
trace[:, 0] += node.start_time
node.trace[agent_id] = trace.tolist()
pp(("cached tubes", cached_tubes.keys()))
# pp(("cached tubes", cached_tubes.keys()))
node_ids = list(set((s.run_num, s.node_id) for s in cached_tubes.values()))
# assert len(node_ids) <= 1, f"{node_ids}"
new_cache, paths_to_sim = {}, []
......@@ -213,11 +213,11 @@ class Verifier:
old_node = find(past_runs[old_run_num].nodes, lambda n: n.id == old_node_id)
assert old_node != None
new_cache, paths_to_sim = to_simulate(old_node.agent, node.agent, cached_tubes)
pp(("to sim", new_cache.keys(), len(paths_to_sim)))
# pp(("to sim", new_cache.keys(), len(paths_to_sim)))
# Get all possible transitions to next mode
asserts, all_possible_transitions = transition_graph.get_transition_verify_new(new_cache, paths_to_sim, node)
pp(("transitions:", [(t[0], t[2]) for t in all_possible_transitions]))
# pp(("transitions:", [(t[0], t[2]) for t in all_possible_transitions]))
node.assert_hits = asserts
if asserts != None:
asserts, idx = asserts
......@@ -227,7 +227,7 @@ class Verifier:
transit_map = {k: list(l) for k, l in itertools.groupby(all_possible_transitions, key=lambda p:p[0])}
transit_agents = transit_map.keys()
pp(("transit agents", transit_agents))
# pp(("transit agents", transit_agents))
if self.config.incremental and len(all_possible_transitions) > 0:
transit_ind = max(l[-2][-1] for l in all_possible_transitions)
for agent_id in node.agent:
......@@ -236,15 +236,15 @@ class Verifier:
cached_tubes[agent_id].transitions.extend(convert_reach_trans(agent_id, transit_agents, node.init, transition, transit_ind))
pre_len = len(cached_tubes[agent_id].transitions)
cached_tubes[agent_id].transitions = dedup(cached_tubes[agent_id].transitions, lambda i: (i.mode, i.dest, i.inits))
pp(("dedup!", pre_len, len(cached_tubes[agent_id].transitions)))
# pp(("dedup!", pre_len, len(cached_tubes[agent_id].transitions)))
else:
self.trans_cache.add_tube(agent_id, combined_inits, node, transit_agents, transition, transit_ind, run_num)
max_end_idx = 0
for transition in all_possible_transitions:
# Each transition will contain a list of rectangles and their corresponding indexes in the original list
if len(transition) != 6:
pp(("weird trans", transition))
# if len(transition) != 6:
# pp(("weird trans", transition))
transit_agent_idx, src_mode, dest_mode, next_init, idx, path = transition
start_idx, end_idx = idx[0], idx[-1]
......@@ -270,7 +270,7 @@ class Verifier:
next_node_init[agent_idx] = next_init
else:
next_node_init[agent_idx] = [[truncated_trace[agent_idx][0][1:], truncated_trace[agent_idx][1][1:]]]
pp(("infer init", agent_idx, next_node_init[agent_idx]))
# pp(("infer init", agent_idx, next_node_init[agent_idx]))
next_node_trace[agent_idx] = truncated_trace[agent_idx]
tmp = AnalysisTreeNode(
......
......@@ -502,8 +502,8 @@ class Scenario:
# print(red("full cache"))
# _transitions = [trans.transition for seg in cache.values() for trans in seg.transitions]
_transitions = [trans.transition for seg in cache.values() for trans in seg.transitions if reach_trans_suit(trans.inits, node.init) or pp(("not suit", trans.inits, node.init))]
pp(("cached trans", len(_transitions)))
_transitions = [trans.transition for seg in cache.values() for trans in seg.transitions if reach_trans_suit(trans.inits, node.init)]# or pp(("not suit", trans.inits, node.init))]
# pp(("cached trans", len(_transitions)))
if len(_transitions) == 0:
return None, []
transition = min(_transitions)
......@@ -540,9 +540,9 @@ class Scenario:
continue
cached_guards[agent_id].append((path, guard_expression, cont_var_updater, copy.deepcopy(discrete_variable_dict), reset, path_transitions[path.cond]))
for aid, trace in node.trace.items():
if len(trace) < 2:
pp(("weird state", aid, trace))
# for aid, trace in node.trace.items():
# if len(trace) < 2:
# pp(("weird state", aid, trace))
for agent, idx, path in paths:
if len(agent.controller.args) == 0:
continue
......@@ -580,8 +580,8 @@ class Scenario:
if len(agent.controller.args) == 0:
continue
agent_state, agent_mode, agent_static = state_dict[agent_id]
if np.array(agent_state).ndim != 2:
pp(("weird state", agent_id, agent_state))
# if np.array(agent_state).ndim != 2:
# pp(("weird state", agent_id, agent_state))
agent_state = agent_state[1:]
cont_vars, disc_vars, len_dict = self.sensor.sense(self, agent, state_dict, self.map)
resets = defaultdict(list)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment