Skip to content
Snippets Groups Projects
Commit 729b14a8 authored by crides's avatar crides
Browse files

pend

parent e0ac034a
No related branches found
No related tags found
1 merge request!9Tutorial
......@@ -63,11 +63,11 @@ def to_simulate(old_agents: Dict[str, BaseAgent], new_agents: Dict[str, BaseAgen
raise NotImplementedError("different variable outputs")
for var, old_paths in old_grouped.items():
new_paths = new_grouped[var]
for old, new in itertools.zip_longest(old_paths, new_paths):
for i, (old, new) in enumerate(itertools.zip_longest(old_paths, new_paths)):
if new == None:
removed_paths.append(old)
elif old.cond != new.cond:
added_paths.append((new_agent, new))
added_paths.append((new_agent, i, new))
elif old.val != new.val:
reset_changed_paths.append(new)
new_cache = {}
......@@ -204,8 +204,8 @@ class ReachTubeCache:
def add_tube(self, agent_id: str, init: Dict[str, List[List[float]]], node: AnalysisTreeNode, transit_agents: List[str], transition, trans_ind: int, run_num: int):
key = (agent_id,) + tuple(node.mode[agent_id])
tree = self.cache[key]
pp(('add seg', agent_id, node.mode[agent_id], init))
assert_hits = node.assert_hits or {}
# pp(('add seg', agent_id, *node.mode[agent_id], *init))
init = list(map(tuple, zip(*init[agent_id])))
for i, (low, high) in enumerate(init):
if i == len(init) - 1:
......
......@@ -13,7 +13,7 @@ pp = functools.partial(pprint.pprint, compact=True, width=130)
# from verse.agents.base_agent import BaseAgent
from verse.analysis.analysis_tree import AnalysisTreeNode, AnalysisTree
PathDiffs = List[Tuple[BaseAgent, ModePath]]
PathDiffs = List[Tuple[BaseAgent, int, ModePath]]
def red(s):
return "\x1b[31m" + s + "\x1b[0m" #]]
......
......@@ -140,6 +140,7 @@ class Verifier:
while verification_queue != []:
node: AnalysisTreeNode = verification_queue.pop(0)
combined_inits = {a: combine_all(inits) for a, inits in node.init.items()}
print()
pp(("start sim", node.start_time, {a: (*node.mode[a], *combined_inits[a]) for a in node.mode}))
remain_time = round(time_horizon - node.start_time, 10)
if remain_time <= 0:
......@@ -150,9 +151,8 @@ class Verifier:
for agent_id in node.agent:
mode = node.mode[agent_id]
inits = node.init[agent_id]
init = combined_inits[agent_id]
if self.config.incremental:
cached = self.trans_cache.check_hit(agent_id, mode, init)
cached = self.trans_cache.check_hit(agent_id, mode, combined_inits[agent_id])
if cached != None:
cached_tubes[agent_id] = cached
if agent_id not in node.trace:
......@@ -200,7 +200,7 @@ class Verifier:
trace = np.array(cur_bloated_tube)
trace[:, 0] += node.start_time
node.trace[agent_id] = trace.tolist()
pp(("cached_segments", cached_tubes.keys()))
pp(("cached tubes", cached_tubes.keys()))
node_ids = list(set((s.run_num, s.node_id) for s in cached_tubes.values()))
# assert len(node_ids) <= 1, f"{node_ids}"
new_cache, paths_to_sim = {}, []
......@@ -214,7 +214,7 @@ class Verifier:
# Get all possible transitions to next mode
asserts, all_possible_transitions = transition_graph.get_transition_verify_new(new_cache, paths_to_sim, node)
pp(("transitions:", all_possible_transitions))
pp(("transitions:", [(t[0], t[2]) for t in all_possible_transitions]))
node.assert_hits = asserts
if asserts != None:
asserts, idx = asserts
......@@ -249,6 +249,8 @@ class Verifier:
max_end_idx = 0
for transition in all_possible_transitions:
# Each transition will contain a list of rectangles and their corresponding indexes in the original list
if len(transition) != 6:
pp(("weird trans", transition))
transit_agent_idx, src_mode, dest_mode, next_init, idx, path = transition
start_idx, end_idx = idx[0], idx[-1]
......@@ -274,7 +276,7 @@ class Verifier:
next_node_init[agent_idx] = next_init
else:
next_node_init[agent_idx] = [[truncated_trace[agent_idx][0][1:], truncated_trace[agent_idx][1][1:]]]
if agent_idx != transit_agent_idx:
pp(("infer init", agent_idx, next_node_init[agent_idx], truncated_trace[agent_idx][:8]))
next_node_trace[agent_idx] = truncated_trace[agent_idx]
tmp = AnalysisTreeNode(
......
......@@ -11,11 +11,12 @@ import numpy as np
from verse.agents.base_agent import BaseAgent
from verse.analysis.dryvr import _EPSILON
from verse.analysis.incremental import CachedRTTrans, CachedSegment
from verse.analysis.incremental import CachedRTTrans, CachedSegment, combine_all
from verse.analysis.simulator import PathDiffs
from verse.automaton import GuardExpressionAst, ResetExpression
from verse.analysis import Simulator, Verifier, AnalysisTreeNode, AnalysisTree
from verse.analysis.utils import sample_rect
from verse.parser import astunparser
from verse.parser.parser import ControllerIR, ModePath, find
from verse.sensor.base_sensor import BaseSensor
from verse.map.lane_map import LaneMap
......@@ -307,7 +308,7 @@ class Scenario:
# The reset_list here are all the resets for a single transition. Need to evaluate each of them
# and then combine them together
for reset_tuple in reset_list:
reset, disc_var_dict, cont_var_dict, _ = reset_tuple
reset, disc_var_dict, cont_var_dict, _, _p = reset_tuple
reset_variable = reset.var
expr = reset.expr
# First get the transition destinations
......@@ -457,7 +458,7 @@ class Scenario:
for path in agent_paths:
cached_guards[agent_id].append((path, discrete_variable_dict, path_transitions[path.cond]))
for agent, path in paths:
for agent, _idx, path in paths:
# Get guard
if len(agent.controller.args) == 0:
continue
......@@ -496,7 +497,7 @@ class Scenario:
break
return None, dict(transitions), idx
def get_transition_verify_new(self, cache: Dict[str, CachedRTTrans], paths: PathDiffs, node: AnalysisTreeNode) -> Tuple[Optional[Dict[str, List[str]]], Optional[Dict[str, List[Tuple[str, List[str], List[float]]]]], int]:
def get_transition_verify_new(self, cache: Dict[str, CachedRTTrans], paths: PathDiffs, node: AnalysisTreeNode) -> Tuple[Optional[Dict[str, List[str]]], Optional[Dict[str, List[Tuple[str, List[str], List[float]]]]]]:
lane_map = self.map
# For each agent
......@@ -504,28 +505,33 @@ class Scenario:
cached_guards = defaultdict(list)
if not cache:
paths = [(agent, p) for agent in node.agent.values() for p in agent.controller.paths]
paths = [(agent, i, p) for agent in node.agent.values() for i, p in enumerate(agent.controller.paths)]
else:
if len(paths) == 0:
# print(red("full cache"))
def trans_close(a: Dict[str, List[float]], b: Dict[str, List[float]]) -> bool:
def trans_suit(a: Dict[str, List[List[List[float]]]], b: Dict[str, List[List[List[float]]]]) -> bool:
assert set(a.keys()) == set(b.keys())
return all(abs(av - bv) < _EPSILON for aid in a.keys() for av, bv in zip(a[aid], b[aid]))
_transitions = [trans.transition for seg in cache.values() for trans in seg.transitions]
# _transitions = [trans.transition for seg in cache.values() for trans in seg.transitions if trans_close(trans.inits, node.init)]
def transp(a):
return list(map(list, zip(*a)))
def suits(a: List[List[float]], b: List[List[float]]) -> bool:
at, bt = transp(a), transp(b)
return all(al <= bl and ah >= bh for (al, ah), (bl, bh) in zip(at, bt))
return all(suits(av, bv) for aid in a.keys() for av, bv in zip(a[aid], b[aid]))
# _transitions = [trans.transition for seg in cache.values() for trans in seg.transitions]
_transitions = [trans.transition for seg in cache.values() for trans in seg.transitions if trans_suit(trans.inits, node.init)]
# pp(("cached trans", _transitions))
if len(_transitions) == 0:
return None, None, 0
return None, []
transition = min(_transitions)
transitions = defaultdict(list)
transitions = []
for agent_id, seg in cache.items():
# TODO: check for asserts
for tran in seg.transitions:
if tran.transition == transition:
# pp(("chosen tran", agent_id, tran))
transitions[agent_id].append((agent_id, tran.mode, tran.dest, tran.reset, tran.reset_idx, tran.paths))
return None, dict(transitions)
transitions.append((agent_id, tran.mode, tran.dest, tran.reset, tran.reset_idx, tran.paths))
return None, transitions
path_transitions = defaultdict(int)
for seg in cache.values():
......@@ -563,7 +569,10 @@ class Scenario:
continue
cached_guards[agent_id].append((path, guard_expression, cont_var_updater, copy.deepcopy(discrete_variable_dict), reset, path_transitions[path.cond]))
for agent, path in paths:
for aid, trace in node.trace.items():
if len(trace) < 2:
pp(("weird state", aid, trace))
for agent, idx, path in paths:
if len(agent.controller.args) == 0:
continue
agent_id = agent.id
......@@ -585,7 +594,8 @@ class Scenario:
agent_guard_dict[agent_id].append(
(guard_expression, cont_var_updater, copy.deepcopy(discrete_variable_dict), path))
trace_length = int(len(list(node.trace.values())[0])/2)
trace_length = int(min(len(v) for v in node.trace.values()) // 2)
# pp(("trace len", trace_length, {a: len(t) for a, t in node.trace.items()}))
guard_hits = []
guard_hit = False
for idx in range(trace_length):
......@@ -599,6 +609,8 @@ class Scenario:
if len(agent.controller.args) == 0:
continue
agent_state, agent_mode, agent_static = state_dict[agent_id]
if np.array(agent_state).ndim != 2:
pp(("weird state", agent_id, agent_state))
agent_state = agent_state[1:]
cont_vars, disc_vars, len_dict = self.sensor.sense(self, agent, state_dict, self.map)
resets = defaultdict(list)
......@@ -673,8 +685,8 @@ class Scenario:
for hits, all_agent_state, hit_idx in guard_hits:
for agent_id, reset_idx, reset_list in hits:
# TODO: Need to change this function to handle the new reset expression and then I am done
dest_list, reset_rect = self.apply_reset(
node.agent[agent_id], reset_list[:-1], all_agent_state)
dest_list, reset_rect = self.apply_reset(node.agent[agent_id], reset_list, all_agent_state)
# pp(("dests", dest_list, *[astunparser.unparse(reset[-1].val_veri) for reset in reset_list]))
if agent_id not in reset_dict:
reset_dict[agent_id] = {}
if not dest_list:
......@@ -693,9 +705,10 @@ class Scenario:
for agent in reset_dict:
for reset_idx in reset_dict[agent]:
for dest in reset_dict[agent][reset_idx]:
if list(dest) != list(node.mode[agent]):
resets = tuple(map(list, zip(*reset_dict[agent][reset_idx][dest])))
transition = (agent, node.mode[agent], dest, *resets)
possible_transitions.append(transition)
# reset_data = reset_dict[agent][reset_idx][dest]
reset_data = tuple(map(list, zip(*reset_dict[agent][reset_idx][dest])))
# pp(("resets", reset_data))
transition = (agent, node.mode[agent], dest, *reset_data)
possible_transitions.append(transition)
# Return result
return None, possible_transitions
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment