diff --git a/example_two_car_lane_switch.py b/example_two_car_lane_switch.py index 69b9b9250725ea088b49dd9df42f3763f6e2f498..74e9a88c74605a34324a09e584f122dad8910467 100644 --- a/example_two_car_lane_switch.py +++ b/example_two_car_lane_switch.py @@ -71,6 +71,7 @@ if __name__ == "__main__": ] ) # simulator = Simulator() + # traces = scenario.simulate(40) traces = scenario.verify(40) plt.plot([0, 40], [3, 3], 'g') diff --git a/ourtool/analysis/verifier.py b/ourtool/analysis/verifier.py index bb395303389b6fd0d2532fb0deae36b4ad26399c..31cfd7f4e312b12fe3bb654d4e55cc60409e5700 100644 --- a/ourtool/analysis/verifier.py +++ b/ourtool/analysis/verifier.py @@ -61,24 +61,46 @@ class Verifier: trace = np.array(cur_bloated_tube) trace[:,0] += node.start_time node.trace[agent_id] = trace.tolist() - print("here") + # print("here") + + # Check safety conditions here - trace_length = int(len(list(node.trace.values())[0])/2) - guard_hits = [] - for idx in range(0,trace_length): - # For each trace, check with the guard to see if there's any possible transition - # Store all possible transition in a list - # A transition is defined by (agent, src_mode, dest_mode, corresponding reset, transit idx) - # Here we enforce that only one agent transit at a time - all_agent_state = {} - for agent_id in node.agent: - all_agent_state[agent_id] = (node.trace[agent_id][idx*2:idx*2+2], node.mode[agent_id]) - guards, resets, is_contain = transition_graph.check_guard_hit(all_agent_state) - if possible_transitions != []: - for agent_idx, src_mode, dest_mode, next_init, contained in possible_transitions: - transitions.append((agent_idx, src_mode, dest_mode, next_init, idx)) - any_contained = any_contained or contained - if any_contained: - break - pass + # Get all possible transitions to next mode + all_possible_transitions = transition_graph.get_all_transition_set(node) + max_end_idx = 0 + for transition in all_possible_transitions: + transit_agent_idx, src_mode, dest_mode, next_init, idx = transition + start_idx, end_idx = idx + + truncated_trace = {} + for agent_idx in node.agent: + truncated_trace[agent_idx] = node.trace[agent_idx][start_idx*2:] + if end_idx > max_end_idx: + max_end_idx = end_idx + next_node_mode = copy.deepcopy(node.mode) + next_node_mode[transit_agent_idx] = dest_mode + next_node_agent = node.agent + next_node_start_time = list(truncated_trace.values())[0][0][0] + next_node_init = {} + next_node_trace = {} + for agent_idx in next_node_agent: + if agent_idx == transit_agent_idx: + next_node_init[agent_idx] = next_init + else: + next_node_trace[agent_idx] = truncated_trace[agent_idx] + + tmp = AnalysisTreeNode( + trace = next_node_trace, + init = next_node_init, + mode = next_node_mode, + agent = next_node_agent, + child = [], + start_time = next_node_start_time + ) + node.child.append(tmp) + verification_queue.append(tmp) + """Truncate trace of current node based on max_end_idx""" + for agent_idx in node.agent: + node.trace[agent_idx] = node.trace[agent_idx][:(max_end_idx+1)*2] + \ No newline at end of file diff --git a/ourtool/automaton/guard.py b/ourtool/automaton/guard.py index a68e5e0aaa4ef674b4e49ec31c3cfafff6934e18..5ff33274a1b4633b7d38f159164ead81d71fea1d 100644 --- a/ourtool/automaton/guard.py +++ b/ourtool/automaton/guard.py @@ -6,7 +6,9 @@ import pickle # from pythonparser import Guard import ast -from pkg_resources import compatible_platforms +from z3 import * +import sympy + import astunparse class LogicTreeNode: @@ -16,354 +18,212 @@ class LogicTreeNode: self.val = val self.mode_guard = mode_guard -''' -class GuardExpression: - def __init__(self, root:LogicTreeNode=None, logic_str:str=None, guard_list=None): - self._func_dict = {} +class GuardExpressionAst: + def __init__(self, guard_list): + self.ast_list = [] + for guard in guard_list: + self.ast_list.append(guard.ast) + self.cont_variables = {} + self.varDict = {'t':Real('t')} - self.logic_tree_root = root - self.logic_string = logic_str + def _build_guard(self, guard_str, agent): + """ + Build solver for current guard based on guard string - if self.logic_tree_root is None and logic_str is not None: - self.construct_tree_from_str(logic_str) - elif guard_list is not None: - self.construct_tree_from_list(guard_list) + Args: + guard_str (str): the guard string. + For example:"And(v>=40-0.1*u, v-40+0.1*u<=0)" - def construct_tree_from_list(self, guard_list:List[Guard]): - # guard_list = ['('+elem.code+')' for elem in guard_list] - tmp = [] - func_count = 0 - for guard in guard_list: - if guard.func is not None: - func_identifier = f'func{func_count}' - self._func_dict[func_identifier] = guard.code - tmp.append(f'({func_identifier})') - else: - tmp.append('('+guard.code+')') - - guard_str = ' and '.join(tmp) - self.construct_tree_from_str(guard_str) - - def logic_string_split(self, logic_string): - # Input: - # logic_string: str, a python logic expression - # Output: - # List[str], a list of string containing atomics, logic operator and brackets - # The function take a python logic expression and split the expression into brackets, atomics and logic operators - # logic_string = logic_string.replace(' ','') - res = re.split('( and )',logic_string) - - tmp = [] - for sub_str in res: - tmp += re.split('( or )',sub_str) - res = tmp - - tmp = [] - for sub_str in res: - tmp += re.split('(\()',sub_str) - res = tmp - - tmp = [] - for sub_str in res: - tmp += re.split('(\))',sub_str) - res = tmp - - while("" in res) : - res.remove("") - while(" " in res): - res.remove(" ") - for i,sub_str in enumerate(res): - res[i]= sub_str.strip(' ') - - # Handle spurious brackets in the splitted string - # Get all the index of brackets pairs in the splitted string - # Construct brackets tree - # class BracketTreeNode: - # def __init__(self): - # self.left_idx = None - # self.right_idx = None - # self.child = [] - bracket_stack = [] - for i in range(len(res)): - if res[i] == "(": - bracket_stack.append(i) - elif res[i] == ")": - left_idx = bracket_stack.pop() - sub_list = res[left_idx:i+1] - # Check for each brackets pairs if there's any logic operators in between - # If no, combine things in between and the brackets together, reconstruct the list - if "and" not in sub_list and "or" not in sub_list: - res[left_idx] = "".join(sub_list) - for j in range(left_idx+1,i+1): - res[j] = "" - - # For each pair of logic operator - start_idx = 0 - end_idx = 0 - for i in range(len(res)): - if res[i]!="(": - start_idx = i - break + Returns: + A Z3 Solver obj that check for guard. + A symbol index dic obj that indicates the index + of variables that involved in the guard. + """ + cur_solver = Solver() + # This magic line here is because SymPy will evaluate == to be False + # Therefore we are not be able to get free symbols from it + # Thus we need to replace "==" to something else + sympy_guard_str = guard_str.replace("==", ">=") + for vars in self.cont_variables: + sympy_guard_str = sympy_guard_str.replace(vars, self.cont_variables[vars]) + + symbols = list(sympy.sympify(sympy_guard_str, evaluate=False).free_symbols) + symbols = [str(s) for s in symbols] + tmp = list(self.cont_variables.values()) + symbols_map = {} + for s in symbols: + if s in tmp: + key = list(self.cont_variables.keys())[list(self.cont_variables.values()).index(s)] + symbols_map[s] = key + + for vars in reversed(self.cont_variables): + guard_str = guard_str.replace(vars, self.cont_variables[vars]) + guard_str = self._handleReplace(guard_str) + cur_solver.add(eval(guard_str)) # TODO use an object instead of `eval` a string + return cur_solver, symbols_map + + def _handleReplace(self, input_str): + """ + Replace variable in inputStr to self.varDic["variable"] + For example: + input + And(y<=0,t>=0.2,v>=-0.1) + output: + And(self.varDic["y"]<=0,self.varDic["t"]>=0.2,self.varDic["v"]>=-0.1) - for i in range(len(res)): - if res[i] == "and" or res[i] == "or": - end_idx = i - sub_list = res[start_idx:end_idx] - # Check if there's any dangling brackents in between. - # If no, combine things between logic operators - if "(" not in sub_list and ")" not in sub_list: - res[start_idx] = "".join(sub_list) - for j in range(start_idx+1, end_idx): - res[j] = "" - start_idx = end_idx + 1 - while("" in res) : - res.remove("") - - # Put back functions - for i in range(len(res)): - for key in self._func_dict: - if key in res[i]: - res[i] = res[i].replace(key, self._func_dict[key]) - # if res[i] in self._func_dict: - # res[i] = self._func_dict[res[i]] - return res - - def construct_tree_from_str(self, logic_string:str): - # Convert an infix expression notation to an expression tree - # https://www.geeksforgeeks.org/program-to-convert-infix-notation-to-expression-tree/ - - self.logic_string = logic_string - logic_string = "(" + logic_string + ")" - s = self.logic_string_split(logic_string) - - stN = [] - stC = [] - p = {} - p["and"] = 1 - p["or"] = 1 - p[")"] = 0 - - for i in range(len(s)): - if s[i] == "(": - stC.append(s[i]) - - elif s[i] not in p: - t = LogicTreeNode(s[i]) - stN.append(t) - - elif(p[s[i]]>0): - while (len(stC) != 0 and stC[-1] != '(' and p[stC[-1]] >= p[s[i]]): - # Get and remove the top element - # from the character stack - t = LogicTreeNode(stC[-1]) - stC.pop() - - # Get and remove the top element - # from the node stack - t1 = stN[-1] - stN.pop() - - # Get and remove the currently top - # element from the node stack - t2 = stN[-1] - stN.pop() - - # Update the tree - t.child = [t1, t2] - - # Push the node to the node stack - stN.append(t) - stC.append(s[i]) - elif (s[i] == ')'): - while (len(stC) != 0 and stC[-1] != '('): - # from the character stack - t = LogicTreeNode(stC[-1]) - stC.pop() - - # Get and remove the top element - # from the node stack - t1 = stN[-1] - stN.pop() - - # Get and remove the currently top - # element from the node stack - t2 = stN[-1] - stN.pop() - - # Update the tree - t.child = [t1, t2] - - # Push the node to the node stack - stN.append(t) - stC.pop() - t = stN[-1] - self.logic_tree_root = t - - def generate_guard_string_python(self): - return self._generate_guard_string_python(self.logic_tree_root) - - def _generate_guard_string_python(self, root: LogicTreeNode)->str: - if root.data!="and" and root.data!="or": - return root.data - else: - data1 = self._generate_guard_string_python(root.child[0]) - data2 = self._generate_guard_string_python(root.child[1]) - return f"({data1} {root.data} {data2})" + Args: + input_str (str): original string need to be replaced + keys (list): list of variable strings - def generate_guard_string(self): - return self._generate_guard_string(self.logic_tree_root) + Returns: + str: a string that all variables have been replaced into a desire form - def _generate_guard_string(self, root: LogicTreeNode)->str: - if root.data!="and" and root.data!="or": - return root.data - else: - data1 = self._generate_guard_string(root.child[0]) - data2 = self._generate_guard_string(root.child[1]) - if root.data == "and": - return f"And({data1},{data2})" - elif root.data == "or": - return f"Or({data1},{data2})" + """ + idxes = [] + i = 0 + original = input_str + keys = list(self.varDict.keys()) + + keys.sort(key=lambda s: len(s)) + for key in keys[::-1]: + for i in range(len(input_str)): + if input_str[i:].startswith(key): + idxes.append((i, i + len(key))) + input_str = input_str[:i] + "@" * len(key) + input_str[i + len(key):] + + idxes = sorted(idxes) + + input_str = original + for idx in idxes[::-1]: + key = input_str[idx[0]:idx[1]] + target = 'self.varDict["' + key + '"]' + input_str = input_str[:idx[0]] + target + input_str[idx[1]:] + return input_str - def evaluate_guard(self, agent, continuous_variable_dict, discrete_variable_dict, lane_map): - res = self._evaluate_guard(self.logic_tree_root, agent, continuous_variable_dict, discrete_variable_dict, lane_map) - return res + def evaluate_guard_cont(self, agent, continuous_variable_dict, lane_map): + res = False + is_contained = False - def _evaluate_guard(self, root, agent, cnts_var_dict, disc_var_dict, lane_map): - if root.child == []: - expr = root.data - # Check if the root is a function - if 'map' in expr: - # tmp = re.split('\(|\)',expr) - # while "" in tmp: - # tmp.remove("") - # for arg in tmp[1:]: - # if arg in disc_var_dict: - # expr = expr.replace(arg,f'"{disc_var_dict[arg]}"') - # res = eval(expr) - for arg in disc_var_dict: - expr = expr.replace(arg, f'"{disc_var_dict[arg]}"') - for arg in cnts_var_dict: - expr = expr.replace(arg, str(cnts_var_dict[arg])) - res = eval(expr) - return res - # Elif check if the root contain any discrete data + for cont_vars in continuous_variable_dict: + self.cont_variables[cont_vars] = cont_vars.replace('.','_') + self.varDict[cont_vars.replace('.','_')] = Real(cont_vars.replace('.','_')) + + z3_string = self.generate_z3_expression() + if isinstance(z3_string, bool): + if z3_string: + return True, True else: - is_mode_guard = False - for key in disc_var_dict: - if key in expr: - is_mode_guard = True - val = disc_var_dict[key] - for mode_name in agent.controller.modes: - if val in agent.controller.modes[mode_name]: - val = mode_name+'.'+val - break - expr = expr.replace(key, val) - if is_mode_guard: - # Execute guard, assign type and and return result - root.mode_guard = True - expr = expr.strip('(') - expr = expr.strip(')') - expr = expr.replace(' ','') - expr = expr.split('==') - res = expr[0] == expr[1] - # res = eval(expr) - root.val = res - return res - # Elif have cnts variable guard handle cnts variable guard - else: - for key in cnts_var_dict: - expr = expr.replace(key, str(cnts_var_dict[key])) - res = eval(expr) - return res - # For the two children, call _execute_guard and collect result - res1 = self._evaluate_guard(root.child[0],agent,cnts_var_dict, disc_var_dict, lane_map) - res2 = self._evaluate_guard(root.child[1],agent,cnts_var_dict, disc_var_dict, lane_map) - # Evaluate result for current node - if root.data == "and": - res = res1 and res2 - elif root.data == "or": - res = res1 or res2 - else: - raise ValueError(f"Invalid root data {root.data}") - return res + return False, False + + cur_solver, symbols = self._build_guard(z3_string, agent) + cur_solver.push() + for symbol in symbols: + cur_solver.add(self.varDict[symbol] >= continuous_variable_dict[symbols[symbol]][0]) + cur_solver.add(self.varDict[symbol] <= continuous_variable_dict[symbols[symbol]][1]) + if cur_solver.check() == sat: + # The reachtube hits the guard + cur_solver.pop() + res = True + + # TODO: If the reachtube completely fall inside guard, break + tmp_solver = Solver() + tmp_solver.add(Not(cur_solver.assertions()[0])) + for symbol in symbols: + tmp_solver.add(self.varDict[symbol] >= continuous_variable_dict[symbols[symbol]][0]) + tmp_solver.add(self.varDict[symbol] <= continuous_variable_dict[symbols[symbol]][1]) + if tmp_solver.check() == unsat: + print("Full intersect, break") + is_contained = True - def execute_guard(self, discrete_variable_dict:Dict) -> bool: - # This function will execute guard, and remove guard related to mode from the tree - # We can do this recursively - res = self._execute_guard(self.logic_tree_root, discrete_variable_dict) - - return res + return res, is_contained - def _execute_guard(self, root:LogicTreeNode, discrete_variable_dict:Dict) -> bool: - # If is tree leaf - if root.child == []: - # Check if the expression involves mode - expr = root.data - is_mode_guard = False - for key in discrete_variable_dict: - if key in expr: - is_mode_guard = True - expr = expr.replace(key, discrete_variable_dict[key]) - if is_mode_guard: - # Execute guard, assign type and and return result - root.mode_guard = True - expr = expr.strip('(') - expr = expr.strip(')') - expr = expr.replace(' ','') - expr = expr.split('==') - res = expr[0] == expr[1] - # res = eval(expr) - root.val = res - return res - # Otherwise, return True - else: - root.mode_guard = False - root.val = True - return True - # For the two children, call _execute_guard and collect result - res1 = self._execute_guard(root.child[0],discrete_variable_dict) - res2 = self._execute_guard(root.child[1],discrete_variable_dict) - # Evaluate result for current node - if root.data == "and": - res = res1 and res2 - elif root.data == "or": - res = res1 or res2 - else: - raise ValueError(f"Invalid root data {root.data}") - - # If the result is False, return False - if not res: - return False - # Else if any child have false result, remove that child - else: - if not res1 or root.child[0].mode_guard: - root.data = root.child[1].data - root.val = root.child[1].val - root.mode_guard = root.child[1].mode_guard - root.child = root.child[1].child - elif not res2 or root.child[1].mode_guard: - root.data = root.child[0].data - root.val = root.child[0].val - root.mode_guard = root.child[0].mode_guard - root.child = root.child[0].child - return True -''' + def generate_z3_expression(self): + """ + The return value of this function will be a bool/str -class GuardExpressionAst: - def __init__(self, guard_list): - self.ast_list = [] - for guard in guard_list: - self.ast_list.append(guard.ast) + If without evaluating the continuous variables the result is True, then + the guard will automatically be satisfied and is_contained will be True - def evaluate_guard_cont(self, agent, continuous_variable_dict, lane_map): - res = True - is_contained = True - # TODO + If without evaluating the continuous variables the result is False, th- + en the guard will automatically be unsatisfied - return res, is_contained + If the result is a string, then continuous variables will be checked to + see if the guard can be satisfied + """ + res = [] + for node in self.ast_list: + tmp = self._generate_z3_expression_node(node) + if isinstance(tmp, bool): + if not tmp: + return False + else: + continue + res.append(tmp) + if res == []: + return True + elif len(res) == 1: + return res[0] + res = "And("+",".join(res)+")" + return res - # def _evaluate_guard_cont(self, root, agent, cont_var_dict, lane_map): - # return False + def _generate_z3_expression_node(self, node): + """ + Perform a DFS over expression ast and generate the guard expression + The return value of this function can be a bool/str + + If without evaluating the continuous variables the result is True, then + the guard condition will automatically be satisfied + + If without evaluating the continuous variables the result is False, then + the guard condition will not be satisfied + + If the result is a string, then continuous variables will be checked to + see if the guard can be satisfied + """ + if isinstance(node, ast.BoolOp): + # Check the operator + # For each value in the boolop, check results + if isinstance(node.op, ast.And): + z3_str = [] + for i,val in enumerate(node.values): + tmp = self._generate_z3_expression_node(val) + if isinstance(tmp, bool): + if tmp: + continue + else: + return False + z3_str.append(tmp) + z3_str = 'And('+','.join(z3_str)+')' + return z3_str + elif isinstance(node.op, ast.Or): + z3_str = [] + for val in node.values: + tmp = self._generate_z3_expression_node(val) + if isinstance(tmp, bool): + if tmp: + return True + else: + continue + z3_str.append(tmp) + z3_str = 'Or('+','.join(z3_str)+')' + return z3_str + # If string, construct string + # If bool, check result and discard/evaluate result according to operator + pass + elif isinstance(node, ast.Constant): + # If is bool, return boolean result + if isinstance(node.value, bool): + return node.value + # Else, return raw expression + else: + expr = astunparse.unparse(node) + expr = expr.strip('\n') + return expr + else: + # For other cases, we can return the expression directly + expr = astunparse.unparse(node) + expr = expr.strip('\n') + return expr def evaluate_guard_disc(self, agent, discrete_variable_dict, lane_map): """ diff --git a/ourtool/automaton/reset.py b/ourtool/automaton/reset.py new file mode 100644 index 0000000000000000000000000000000000000000..440f993b0bfa0dc70000f2c7fe57102aa34a689f --- /dev/null +++ b/ourtool/automaton/reset.py @@ -0,0 +1,78 @@ +import numpy as np + +class ResetExpression: + def __init__(self, reset_list): + self.ast_list = [] + for reset in reset_list: + self.ast_list.append(reset.ast) + self.expr_list = [] + for reset in reset_list: + self.expr_list.append(reset.code) + + def apply_reset_continuous(self, agent, continuous_variable_dict, lane_map): + agent_state_lower = [] + agent_state_upper = [] + for var in agent.controller.vars_dict['ego']['cont']: + agent_state_lower.append(continuous_variable_dict['ego.'+var][0]) + agent_state_upper.append(continuous_variable_dict['ego.'+var][1]) + assert len(agent_state_lower) == len(agent_state_upper) == len(agent.controller.vars_dict['ego']['cont']) + for expr in self.expr_list: + if 'mode' not in expr: + tmp = expr.split('=') + lhs, rhs = tmp[0], tmp[1] + for lhs_idx, cts_variable in enumerate(agent.controller.vars_dict['ego']['cont']): + if "output."+cts_variable == lhs: + break + + lower = float('inf') + upper = -float('inf') + + symbols = [] + for var in continuous_variable_dict: + if var in expr: + symbols.append(var) + + combinations = self._get_combinations(symbols, continuous_variable_dict) + # for cts_variable in continuous_variable_dict: + # tmp = tmp.replace(cts_variable, str(continuous_variable_dict[cts_variable])) + # next_init[i] = eval(tmp) + for i in combinations.shape[0]: + comb = combinations[i,:] + for j in range(len(symbols)): + tmp = rhs.replace(symbols[j], str(comb[i,j])) + tmp = min(tmp, lower) + tmp = max(tmp, upper) + + agent_state_lower[lhs_idx] = lower + agent_state_upper[lhs_idx] = upper + + return [agent_state_lower, agent_state_upper] + + def _get_combinations(self, symbols, cont_var_dict): + all_vars = [] + for symbol in symbols: + all_vars.append(cont_var_dict[symbol]) + comb_array = np.array(np.meshgrid(*all_vars)).T.reshape(-1, len(symbols)) + return comb_array + + def get_dest(self, agent, agent_state, discrete_variable_dict, lane_map) -> str: + agent_mode = agent_state[1] + dest = agent_mode.split(',') + possible_dest = [[elem] for elem in dest] + for reset in self.expr_list: + if "mode" in reset: + for i, discrete_variable_ego in enumerate(agent.controller.vars_dict['ego']['disc']): + if discrete_variable_ego in reset: + break + tmp = reset.split('=') + if 'map' in tmp[1]: + tmp = tmp[1] + for var in discrete_variable_dict: + tmp = tmp.replace(var, f"'{discrete_variable_dict[var]}'") + possible_dest[i] = eval(tmp) + else: + tmp = tmp[1].split('.') + if tmp[0].strip(' ') in agent.controller.modes: + possible_dest[i] = [tmp[1]] + + return possible_dest \ No newline at end of file diff --git a/ourtool/scenario/scenario.py b/ourtool/scenario/scenario.py index 1cd218bd72f86b41b1056e3b8e599cfe788e129f..d6284c13d4bf72cfcf9088561184249452539fa1 100644 --- a/ourtool/scenario/scenario.py +++ b/ourtool/scenario/scenario.py @@ -1,4 +1,4 @@ -from typing import Dict, List +from typing import Dict, List, Tuple import copy import itertools import ast @@ -7,6 +7,7 @@ import numpy as np from ourtool.agents.base_agent import BaseAgent from ourtool.automaton.guard import GuardExpressionAst +from ourtool.automaton.reset import ResetExpression from pythonparser import Guard from pythonparser import Reset from ourtool.analysis.simulator import Simulator @@ -176,8 +177,68 @@ class Scenario: continue guard_satisfied, is_contained = guard_expression.evaluate_guard_cont(agent, continuous_variable_dict, self.map) if guard_satisfied: - guard_hits.append(agent_id, guard_list, reset_list) - return guard_hits, is_conatined + guard_hits.append((agent_id, guard_list, reset_list)) + return guard_hits, is_contained + + def get_all_transition_set(self, node): + possible_transitions = [] + trace_length = int(len(list(node.trace.values())[0])/2) + guard_hits = [] + guard_hit_bool = False + for idx in range(0,trace_length): + # For each trace, check with the guard to see if there's any possible transition + # Store all possible transition in a list + # A transition is defined by (agent, src_mode, dest_mode, corresponding reset, transit idx) + # Here we enforce that only one agent transit at a time + all_agent_state = {} + for agent_id in node.agent: + all_agent_state[agent_id] = (node.trace[agent_id][idx*2:idx*2+2], node.mode[agent_id]) + hits, is_contain = self.check_guard_hit(all_agent_state) + if hits != []: + guard_hits.append((hits, all_agent_state, idx)) + guard_hit_bool = True + if hits == [] and guard_hit_bool: + break + if is_contain: + break + + reset_dict = {} + reset_idx_dict = {} + for hits, all_agent_state, hit_idx in guard_hits: + for agent_id, guard_list, reset_list in hits: + dest,reset_rect = self.apply_reset(node.agent[agent_id], reset_list, all_agent_state) + if agent_id not in reset_dict: + reset_dict[agent_id] = {} + reset_idx_dict[agent_id] = {} + if dest not in reset_dict[agent_id]: + reset_dict[agent_id][dest] = [] + reset_idx_dict[agent_id][dest] = [] + reset_dict[agent_id][dest].append(reset_rect) + reset_idx_dict[agent_id][dest].append(hit_idx) + + # Combine reset rects and construct transitions + for agent in reset_dict: + for dest in reset_dict[agent]: + combined_rect = None + for rect in reset_dict[agent][dest]: + if combined_rect is None: + combined_rect = rect + else: + combined_rect[0,:] = np.minimum(combined_rect[0,:], rect[0,:]) + combined_rect[1,:] = np.maximum(combined_rect[1,:], rect[1,:]) + min_idx = min(reset_idx_dict[agent][dest]) + max_idx = max(reset_idx_dict[agent][dest]) + transition = (agent, node.mode[agent], dest, combined_rect, (min_idx, max_idx)) + possible_transitions.append(transition) + # Return result + return possible_transitions + + def apply_reset(self, agent, reset_list, all_agent_state) -> Tuple[str, np.ndarray]: + reset_expr = ResetExpression(reset_list) + continuous_variable_dict, discrete_variable_dict = self.sensor.sense(self, agent, all_agent_state, self.map) + dest = reset_expr.get_dest(agent, all_agent_state[agent.id], discrete_variable_dict, self.map) + rect = reset_expr.apply_reset_continuous(agent, continuous_variable_dict, self.map) + return dest, rect def get_all_transition(self, state_dict): lane_map = self.map @@ -196,7 +257,7 @@ class Scenario: if isinstance(item, Guard): guard_list.append(item) elif isinstance(item, Reset): - reset_list.append(item.code) + reset_list.append(item) # guard_expression = GuardExpression(guard_list=guard_list) guard_expression = GuardExpressionAst(guard_list) # Map the values to variables using sensor @@ -214,6 +275,7 @@ class Scenario: possible_dest = [[elem] for elem in dest] for reset in reset_list: # Specify the destination mode + reset = reset.code if "mode" in reset: for i, discrete_variable_ego in enumerate(agent.controller.vars_dict['ego']['disc']): if discrete_variable_ego in reset: