Skip to content
Snippets Groups Projects
Commit 6a6f80e5 authored by crides's avatar crides
Browse files

bug

parent 0ff22527
No related branches found
No related tags found
1 merge request!9Tutorial
......@@ -53,7 +53,7 @@ class Simulator:
# Perform BFS through the simulation tree to loop through all possible transitions
while simulation_queue != []:
node: AnalysisTreeNode = simulation_queue.pop(0)
# pp(("start sim", node.start_time, {a: (*node.mode[a], *node.init[a]) for a in node.mode}))
pp(("start sim", node.start_time, {a: (*node.mode[a], *node.init[a]) for a in node.mode}))
remain_time = round(time_horizon - node.start_time, 10)
if remain_time <= 0:
continue
......@@ -63,7 +63,9 @@ class Simulator:
mode = node.mode[agent_id]
init = node.init[agent_id]
if self.config.incremental:
pp(("check hit", agent_id, mode, init))
cached = self.cache.check_hit(agent_id, mode, init)
pp(("check hit res", agent_id, len(cached.transitions) if cached != None else None))
else:
cached = None
if agent_id in node.trace:
......@@ -83,7 +85,7 @@ class Simulator:
trace[:, 0] += node.start_time
trace = trace.tolist()
node.trace[agent_id] = trace
# pp(("cached_segments", cached_segments.keys()))
pp(("cached_segments", cached_segments.keys()))
# TODO: for now, make sure all the segments comes from the same node; maybe we can do
# something to combine results from different nodes in the future
node_ids = list(set((s.run_num, s.node_id) for s in cached_segments.values()))
......@@ -95,7 +97,7 @@ class Simulator:
old_node = find(past_runs[old_run_num].nodes, lambda n: n.id == old_node_id)
assert old_node != None
new_cache, paths_to_sim = to_simulate(old_node.agent, node.agent, cached_segments)
# pp(("to sim", new_cache.keys(), len(paths_to_sim)))
pp(("to sim", new_cache.keys(), len(paths_to_sim)))
# else:
# print("!!!")
......
......@@ -409,30 +409,31 @@ class Scenario:
# For each agent
agent_guard_dict = defaultdict(list)
cached_guards = defaultdict(list)
min_trans_ind = None
cached_trans = defaultdict(list)
if not cache:
paths = [(agent, p) for agent in node.agent.values() for p in agent.controller.paths]
else:
def trans_close(a: Dict[str, List[float]], b: Dict[str, List[float]]) -> bool:
assert set(a.keys()) == set(b.keys())
return all(abs(av - bv) < _EPSILON for aid in a.keys() for av, bv in zip(a[aid], b[aid]))
_transitions = [trans.transition for seg in cache.values() for trans in seg.transitions]
# _transitions = [trans.transition for seg in cache.values() for trans in seg.transitions if trans_close(trans.inits, node.init)]
pp(("cached trans", _transitions))
if len(_transitions) == 0:
return None, None, 0
min_trans_ind = min(_transitions)
for agent_id, seg in cache.items():
# TODO: check for asserts
for tran in seg.transitions:
if tran.transition == min_trans_ind:
# pp(("chosen tran", agent_id, tran))
cached_trans[agent_id].append((agent_id, tran.disc, tran.cont, tran.paths))
if len(paths) == 0:
# print(red("full cache"))
def trans_close(a: Dict[str, List[float]], b: Dict[str, List[float]]) -> bool:
assert set(a.keys()) == set(b.keys())
return all(abs(av - bv) < _EPSILON for aid in a.keys() for av, bv in zip(a[aid], b[aid]))
_transitions = [trans.transition for seg in cache.values() for trans in seg.transitions]
# _transitions = [trans.transition for seg in cache.values() for trans in seg.transitions if trans_close(trans.inits, node.init)]
# pp(("cached trans", _transitions))
if len(_transitions) == 0:
return None, None, 0
transition = min(_transitions)
transitions = defaultdict(list)
for agent_id, seg in cache.items():
# TODO: check for asserts
for tran in seg.transitions:
if tran.transition == transition:
# pp(("chosen tran", agent_id, tran))
transitions[agent_id].append((agent_id, tran.disc, tran.cont, tran.paths))
return None, dict(transitions), transition
print(red("full cache"))
return None, dict(cached_trans), min_trans_ind
path_transitions = defaultdict(int)
for seg in cache.values():
......@@ -462,6 +463,8 @@ class Scenario:
transitions = defaultdict(list)
# TODO: We can probably rewrite how guard hit are detected and resets are handled for simulation
for idx in range(trace_length):
if min_trans_ind != None and idx >= min_trans_ind:
return None, dict(cached_trans), min_trans_ind
satisfied_guard = []
all_asserts = defaultdict(list)
for agent_id in agent_guard_dict:
......@@ -516,7 +519,6 @@ class Scenario:
for seg in cache.values():
for tran in seg.transitions:
for p in tran.paths:
pp(p)
path_transitions[p.cond] = max(path_transitions[p.cond], tran.transition)
for agent_id, segment in cache.items():
agent = node.agent[agent_id]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment