Skip to content
Snippets Groups Projects
Commit 67b2c83b authored by crides's avatar crides
Browse files

sim bug port

parent 825696d0
No related branches found
No related tags found
1 merge request!9Tutorial
......@@ -497,23 +497,25 @@ class Scenario:
# For each agent
agent_guard_dict = defaultdict(list)
cached_guards = defaultdict(list)
min_trans_ind = None
cached_trans = defaultdict(list)
if not cache:
paths = [(agent, p) for agent in node.agent.values() for p in agent.controller.paths]
else:
# _transitions = [trans.transition for seg in cache.values() for trans in seg.transitions]
_transitions = [trans.transition for seg in cache.values() for trans in seg.transitions if reach_trans_suit(trans.inits, node.init) or pp(("not suit", trans.inits, node.init))]
pp(("cached trans", len(_transitions)))
if len(_transitions) == 0:
return None, []
min_trans_ind = min(_transitions)
cached_trans = [(agent_id, tran) for agent_id, seg in cache.items() for tran in seg.transitions if tran.transition == min_trans_ind]
# TODO: check for asserts
cached_trans = [(aid, tran.mode, tran.dest, tran.reset, tran.reset_idx, tran.paths) for aid, tran in dedup(cached_trans, lambda p: (p[0], p[1].mode, p[1].dest))]
if len(paths) == 0:
# print(red("full cache"))
# _transitions = [trans.transition for seg in cache.values() for trans in seg.transitions]
_transitions = [trans.transition for seg in cache.values() for trans in seg.transitions if reach_trans_suit(trans.inits, node.init)]# or pp(("not suit", trans.inits, node.init))]
# pp(("cached trans", len(_transitions)))
if len(_transitions) == 0:
return None, []
transition = min(_transitions)
transitions = [(agent_id, tran) for agent_id, seg in cache.items() for tran in seg.transitions if tran.transition == transition]
# TODO: check for asserts
transitions = [(aid, tran.mode, tran.dest, tran.reset, tran.reset_idx, tran.paths) for aid, tran in dedup(transitions, lambda p: (p[0], p[1].mode, p[1].dest))]
return None, transitions
return None, cached_trans
path_transitions = defaultdict(int)
for seg in cache.values():
......@@ -573,6 +575,8 @@ class Scenario:
guard_hits = []
guard_hit = False
for idx in range(trace_length):
if min_trans_ind != None and idx >= min_trans_ind:
return None, cached_trans
any_contained = False
hits = []
state_dict = {aid: (node.trace[aid][idx*2:idx*2+2], node.mode[aid], node.static[aid]) for aid in node.agent}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment