From 15d3ad1b171cee7d4c60fbcfa5ff001c5d0c0484 Mon Sep 17 00:00:00 2001 From: crides <zhuhaoqing@live.cn> Date: Thu, 13 Oct 2022 23:04:53 -0500 Subject: [PATCH] veri: port fix --- verse/scenario/scenario.py | 67 +++++++++++++++++++------------------- 1 file changed, 33 insertions(+), 34 deletions(-) diff --git a/verse/scenario/scenario.py b/verse/scenario/scenario.py index 2f1848b4..1a445064 100644 --- a/verse/scenario/scenario.py +++ b/verse/scenario/scenario.py @@ -503,42 +503,41 @@ class Scenario: # _transitions = [trans.transition for seg in cache.values() for trans in seg.transitions] _transitions = [(aid, trans) for aid, seg in cache.items() for trans in seg.transitions if reach_trans_suit(trans.inits, node.init)] # pp(("cached trans", len(_transitions))) - if len(_transitions) == 0: - return None, [] - min_trans_ind = min([t.transition for _, t in _transitions]) - # TODO: check for asserts - cached_trans = [(aid, tran.mode, tran.dest, tran.reset, tran.reset_idx, tran.paths) for aid, tran in dedup(_transitions, lambda p: (p[0], p[1].mode, p[1].dest)) if tran.transition == min_trans_ind] - if len(paths) == 0: - # print(red("full cache")) - return None, cached_trans + if len(_transitions) > 0: + min_trans_ind = min([t.transition for _, t in _transitions]) + # TODO: check for asserts + cached_trans = [(aid, tran.mode, tran.dest, tran.reset, tran.reset_idx, tran.paths) for aid, tran in dedup(_transitions, lambda p: (p[0], p[1].mode, p[1].dest)) if tran.transition == min_trans_ind] + if len(paths) == 0: + # print(red("full cache")) + return None, cached_trans - path_transitions = defaultdict(int) - for seg in cache.values(): - for tran in seg.transitions: - for p in tran.paths: - path_transitions[p.cond] = max(path_transitions[p.cond], tran.transition) - for agent_id, segment in cache.items(): - agent = node.agent[agent_id] - if len(agent.controller.args) == 0: - continue - state_dict = {aid: (node.trace[aid][0], node.mode[aid], node.static[aid]) for aid in node.agent} - - agent_paths = dedup([p for tran in segment.transitions for p in tran.paths], lambda i: (i.var, i.cond, i.val)) - for path in agent_paths: - cont_var_dict_template, discrete_variable_dict, length_dict = self.sensor.sense( - self, agent, state_dict, self.map) - reset = (path.var, path.val_veri) - guard_expression = GuardExpressionAst([path.cond_veri]) - - cont_var_updater = guard_expression.parse_any_all_new( - cont_var_dict_template, discrete_variable_dict, length_dict) - self.apply_cont_var_updater( - cont_var_dict_template, cont_var_updater) - guard_can_satisfied = guard_expression.evaluate_guard_disc( - agent, discrete_variable_dict, cont_var_dict_template, self.map) - if not guard_can_satisfied: + path_transitions = defaultdict(int) + for seg in cache.values(): + for tran in seg.transitions: + for p in tran.paths: + path_transitions[p.cond] = max(path_transitions[p.cond], tran.transition) + for agent_id, segment in cache.items(): + agent = node.agent[agent_id] + if len(agent.controller.args) == 0: continue - cached_guards[agent_id].append((path, guard_expression, cont_var_updater, copy.deepcopy(discrete_variable_dict), reset, path_transitions[path.cond])) + state_dict = {aid: (node.trace[aid][0], node.mode[aid], node.static[aid]) for aid in node.agent} + + agent_paths = dedup([p for tran in segment.transitions for p in tran.paths], lambda i: (i.var, i.cond, i.val)) + for path in agent_paths: + cont_var_dict_template, discrete_variable_dict, length_dict = self.sensor.sense( + self, agent, state_dict, self.map) + reset = (path.var, path.val_veri) + guard_expression = GuardExpressionAst([path.cond_veri]) + + cont_var_updater = guard_expression.parse_any_all_new( + cont_var_dict_template, discrete_variable_dict, length_dict) + self.apply_cont_var_updater( + cont_var_dict_template, cont_var_updater) + guard_can_satisfied = guard_expression.evaluate_guard_disc( + agent, discrete_variable_dict, cont_var_dict_template, self.map) + if not guard_can_satisfied: + continue + cached_guards[agent_id].append((path, guard_expression, cont_var_updater, copy.deepcopy(discrete_variable_dict), reset, path_transitions[path.cond])) # for aid, trace in node.trace.items(): # if len(trace) < 2: -- GitLab