diff --git a/dryvr_plus_plus/automaton/guard_backup.py b/dryvr_plus_plus/automaton/guard_backup.py deleted file mode 100644 index 8f40e1dcc8a1ef70ef7821a8c035069d36b296d0..0000000000000000000000000000000000000000 --- a/dryvr_plus_plus/automaton/guard_backup.py +++ /dev/null @@ -1,1082 +0,0 @@ -import enum -import re -from typing import List, Dict, Any -import pickle -import ast -import copy - -from z3 import * -import astunparse -import numpy as np - -from dryvr_plus_plus.map.lane_map import LaneMap -from dryvr_plus_plus.map.lane_segment import AbstractLane -from dryvr_plus_plus.utils.utils import * -from dryvr_plus_plus.code_parser.parser import Reduction, ReductionType - -class LogicTreeNode: - def __init__(self, data, child = [], val = None, mode_guard = None): - self.data = data - self.child = child - self.val = val - self.mode_guard = mode_guard - -class NodeSubstituter(ast.NodeTransformer): - def __init__(self, old_node, new_node): - super().__init__() - self.old_node = old_node - self.new_node = new_node - - def visit_Reduction(self, node: Reduction) -> Any: - if node == self.old_node: - self.generic_visit(node) - return self.new_node - else: - self.generic_visit(node) - return node - - # def visit_Call(self, node: ast.Call) -> Any: - # if node == self.old_node: - # self.generic_visit(node) - # return self.new_node - # else: - # self.generic_visit(node) - # return node - -class ValueSubstituter(ast.NodeTransformer): - def __init__(self, val:str, node): - super().__init__() - self.val = val - self.node = node - - def visit_Attribute(self, node: ast.Attribute) -> Any: - # Substitute attribute node in the ast - if isinstance(self.node, Reduction): - return node - if node == self.node: - return ast.Name( - id = self.val, - ctx = ast.Load() - ) - return node - - def visit_Name(self, node: ast.Attribute) -> Any: - # Substitute name node in the ast - if isinstance(self.node, Reduction): - return node - if node == self.node: - return ast.Name( - id = self.val, - ctx = ast.Load - ) - return node - - def visit_Reduction(self, node: Reduction) -> Any: - if node == self.node: - if len(self.val) == 1: - self.generic_visit(node) - return self.val[0] - elif node.op == ReductionType.Any: - self.generic_visit(node) - return ast.BoolOp( - op = ast.Or(), - values = self.val - ) - elif node.op == ReductionType.All: - self.generic_visit(node) - return ast.BoolOp( - op = ast.And(), - values = self.val - ) - self.generic_visit(node) - return node - -class GuardExpressionAst: - def __init__(self, guard_list): - self.ast_list = copy.deepcopy(guard_list) - self.cont_variables = {} - self.varDict = {} - - def _build_guard(self, guard_str, agent): - """ - Build solver for current guard based on guard string - - Args: - guard_str (str): the guard string. - For example:"And(v>=40-0.1*u, v-40+0.1*u<=0)" - - Returns: - A Z3 Solver obj that check for guard. - A symbol index dic obj that indicates the index - of variables that involved in the guard. - """ - cur_solver = Solver() - # This magic line here is because SymPy will evaluate == to be False - # Therefore we are not be able to get free symbols from it - # Thus we need to replace "==" to something else - - symbols_map = {v: k for k, v in self.cont_variables.items()} - - for vars in reversed(self.cont_variables): - guard_str = guard_str.replace(vars, self.cont_variables[vars]) - # XXX `locals` should override `globals` right? - cur_solver.add(eval(guard_str, globals(), self.varDict)) # TODO use an object instead of `eval` a string - return cur_solver, symbols_map - - def evaluate_guard_cont(self, agent, continuous_variable_dict, lane_map): - res = False - is_contained = False - - for cont_vars in continuous_variable_dict: - underscored = cont_vars.replace('.','_') - self.cont_variables[cont_vars] = underscored - self.varDict[underscored] = Real(underscored) - - z3_string = self.generate_z3_expression() - if isinstance(z3_string, bool): - return z3_string, z3_string - - cur_solver, symbols = self._build_guard(z3_string, agent) - cur_solver.push() - for symbol in symbols: - start, end = continuous_variable_dict[symbols[symbol]] - cur_solver.add(self.varDict[symbol] >= start, self.varDict[symbol] <= end) - if cur_solver.check() == sat: - # The reachtube hits the guard - cur_solver.pop() - res = True - - tmp_solver = Solver() - tmp_solver.add(Not(cur_solver.assertions()[0])) - for symbol in symbols: - start, end = continuous_variable_dict[symbols[symbol]] - tmp_solver.add(self.varDict[symbol] >= start, self.varDict[symbol] <= end) - if tmp_solver.check() == unsat: - print("Full intersect, break") - is_contained = True - - return res, is_contained - - def generate_z3_expression(self): - """ - The return value of this function will be a bool/str - - If without evaluating the continuous variables the result is True, then - the guard will automatically be satisfied and is_contained will be True - - If without evaluating the continuous variables the result is False, th- - en the guard will automatically be unsatisfied - - If the result is a string, then continuous variables will be checked to - see if the guard can be satisfied - """ - res = [] - for node in self.ast_list: - tmp = self._generate_z3_expression_node(node) - if isinstance(tmp, bool): - if not tmp: - return False - else: - continue - res.append(tmp) - if res == []: - return True - elif len(res) == 1: - return res[0] - res = "And("+",".join(res)+")" - return res - - def _generate_z3_expression_node(self, node): - """ - Perform a DFS over expression ast and generate the guard expression - The return value of this function can be a bool/str - - If without evaluating the continuous variables the result is True, then - the guard condition will automatically be satisfied - - If without evaluating the continuous variables the result is False, then - the guard condition will not be satisfied - - If the result is a string, then continuous variables will be checked to - see if the guard can be satisfied - """ - if isinstance(node, ast.BoolOp): - # Check the operator - # For each value in the boolop, check results - if isinstance(node.op, ast.And): - z3_str = [] - for i,val in enumerate(node.values): - tmp = self._generate_z3_expression_node(val) - if isinstance(tmp, bool): - if tmp: - continue - else: - return False - z3_str.append(tmp) - if len(z3_str) == 1: - z3_str = z3_str[0] - else: - z3_str = 'And('+','.join(z3_str)+')' - return z3_str - elif isinstance(node.op, ast.Or): - z3_str = [] - for val in node.values: - tmp = self._generate_z3_expression_node(val) - if isinstance(tmp, bool): - if tmp: - return True - else: - continue - z3_str.append(tmp) - if len(z3_str) == 1: - z3_str = z3_str[0] - else: - z3_str = 'Or('+','.join(z3_str)+')' - return z3_str - # If string, construct string - # If bool, check result and discard/evaluate result according to operator - pass - elif isinstance(node, ast.Constant): - # If is bool, return boolean result - if isinstance(node.value, bool): - return node.value - # Else, return raw expression - else: - expr = astunparse.unparse(node) - expr = expr.strip('\n') - return expr - elif isinstance(node, ast.UnaryOp): - # If is UnaryOp, - value = self._generate_z3_expression_node(node.operand) - if isinstance(node.op, ast.USub): - return -value - elif isinstance(node.op, ast.Not): - z3_str = 'Not('+value+')' - return z3_str - else: - raise NotImplementedError(f"UnaryOp {node.op} is not supported") - else: - # For other cases, we can return the expression directly - expr = astunparse.unparse(node) - expr = expr.strip('\n') - return expr - - def evaluate_guard_hybrid(self, agent, discrete_variable_dict, continuous_variable_dict, lane_map:LaneMap): - """ - Handle guard atomics that contains both continuous and hybrid variables - Especially, we want to handle function calls that need both continuous and - discrete variables as input - We will perform interval arithmetic based on the function calls to the input and replace the function calls - with temp constants with their values stored in the continuous variable dict - By doing this, all calls that need both continuous and discrete variables as input will now become only continuous - variables. We can then handle these using what we already have for the continous variables - """ - res = True - for i, node in enumerate(self.ast_list): - tmp, self.ast_list[i] = self._evaluate_guard_hybrid(node, agent, discrete_variable_dict, continuous_variable_dict, lane_map) - res = res and tmp - return res - - def _evaluate_guard_hybrid(self, root, agent, disc_var_dict, cont_var_dict, lane_map:LaneMap): - if isinstance(root, ast.Compare): - expr = astunparse.unparse(root) - left, root.left = self._evaluate_guard_hybrid(root.left, agent, disc_var_dict, cont_var_dict, lane_map) - right, root.comparators[0] = self._evaluate_guard_hybrid(root.comparators[0], agent, disc_var_dict, cont_var_dict, lane_map) - return True, root - elif isinstance(root, ast.BoolOp): - if isinstance(root.op, ast.And): - res = True - for i, val in enumerate(root.values): - tmp, root.values[i] = self._evaluate_guard_hybrid(val, agent, disc_var_dict, cont_var_dict, lane_map) - res = res and tmp - if not res: - break - return res, root - elif isinstance(root.op, ast.Or): - res = False - for val in root.values: - tmp,val = self._evaluate_guard_hybrid(val, agent, disc_var_dict, cont_var_dict, lane_map) - res = res or tmp - return res, root - elif isinstance(root, ast.BinOp): - left, root.left = self._evaluate_guard_hybrid(root.left, agent, disc_var_dict, cont_var_dict, lane_map) - right, root.right = self._evaluate_guard_hybrid(root.right, agent, disc_var_dict, cont_var_dict, lane_map) - return True, root - elif isinstance(root, ast.Call): - if isinstance(root.func, ast.Attribute): - func = root.func - if func.value.id == 'lane_map': - if func.attr == 'get_lateral_distance': - # Get function arguments - arg0_node = root.args[0] - arg1_node = root.args[1] - if isinstance(arg0_node, ast.Attribute): - arg0_var = arg0_node.value.id + '.' + arg0_node.attr - elif isinstance(arg0_node, ast.Name): - arg0_var = arg0_node.id - else: - raise ValueError(f"Node type {type(arg0_node)} is not supported") - vehicle_lane = disc_var_dict[arg0_var] - assert isinstance(arg1_node, ast.List) - arg1_lower = [] - arg1_upper = [] - for elt in arg1_node.elts: - if isinstance(elt, ast.Attribute): - var = elt.value.id + '.' + elt.attr - elif isinstance(elt, ast.Name): - var = elt.id - else: - raise ValueError(f"Node type {type(elt)} is not supported") - arg1_lower.append(cont_var_dict[var][0]) - arg1_upper.append(cont_var_dict[var][1]) - vehicle_pos = (arg1_lower, arg1_upper) - - # Get corresponding lane segments with respect to the set of vehicle pos - lane_seg1 = lane_map.get_lane_segment(vehicle_lane, arg1_lower) - lane_seg2 = lane_map.get_lane_segment(vehicle_lane, arg1_upper) - - # Compute the set of possible lateral values with respect to all possible segments - lateral_set1 = self._handle_lateral_set(lane_seg1, np.array(vehicle_pos)) - lateral_set2 = self._handle_lateral_set(lane_seg2, np.array(vehicle_pos)) - - # Use the union of two sets as the set of possible lateral positions - lateral_set = [min(lateral_set1[0], lateral_set2[0]), max(lateral_set1[1], lateral_set2[1])] - - # Construct the tmp variable - tmp_var_name = f'tmp_variable{len(cont_var_dict)+1}' - # Add the tmp variable to the cont var dict - cont_var_dict[tmp_var_name] = lateral_set - # Replace the corresponding function call in ast - root = ast.parse(tmp_var_name).body[0].value - return True, root - elif func.attr == 'get_longitudinal_position': - # Get function arguments - arg0_node = root.args[0] - arg1_node = root.args[1] - # assert isinstance(arg0_node, ast.Attribute) - if isinstance(arg0_node, ast.Attribute): - arg0_var = arg0_node.value.id + '.' + arg0_node.attr - elif isinstance(arg0_node, ast.Name): - arg0_var = arg0_node.id - else: - raise ValueError(f"Node type {type(arg0_node)} is not supported") - vehicle_lane = disc_var_dict[arg0_var] - assert isinstance(arg1_node, ast.List) - arg1_lower = [] - arg1_upper = [] - for elt in arg1_node.elts: - if isinstance(elt, ast.Attribute): - var = elt.value.id + '.' + elt.attr - elif isinstance(elt, ast.Name): - var = elt.id - else: - raise ValueError(f"Node type {type(elt)} is not supported") - arg1_lower.append(cont_var_dict[var][0]) - arg1_upper.append(cont_var_dict[var][1]) - vehicle_pos = (arg1_lower, arg1_upper) - - # Get corresponding lane segments with respect to the set of vehicle pos - lane_seg1 = lane_map.get_lane_segment(vehicle_lane, arg1_lower) - lane_seg2 = lane_map.get_lane_segment(vehicle_lane, arg1_upper) - - # Compute the set of possible longitudinal values with respect to all possible segments - longitudinal_set1 = self._handle_longitudinal_set(lane_seg1, np.array(vehicle_pos)) - longitudinal_set2 = self._handle_longitudinal_set(lane_seg2, np.array(vehicle_pos)) - - # Use the union of two sets as the set of possible longitudinal positions - longitudinal_set = [min(longitudinal_set1[0], longitudinal_set2[0]), max(longitudinal_set1[1], longitudinal_set2[1])] - - # Construct the tmp variable - tmp_var_name = f'tmp_variable{len(cont_var_dict)+1}' - # Add the tmp variable to the cont var dict - cont_var_dict[tmp_var_name] = longitudinal_set - # Replace the corresponding function call in ast - root = ast.parse(tmp_var_name).body[0].value - return True, root - else: - raise ValueError(f'Node type {func} from {astunparse.unparse(func)} is not supported') - else: - raise ValueError(f'Node type {func} from {astunparse.unparse(func)} is not supported') - else: - raise ValueError(f'Node type {root.func} from {astunparse.unparse(root.func)} is not supported') - elif isinstance(root, ast.Attribute): - return True, root - elif isinstance(root, ast.Constant): - return root.value, root - elif isinstance(root, ast.Name): - return True, root - elif isinstance(root, ast.UnaryOp): - if isinstance(root.op, ast.USub): - res, root.operand = self._evaluate_guard_hybrid(root.operand, agent, disc_var_dict, cont_var_dict, lane_map) - elif isinstance(root.op, ast.Not): - res, root.operand = self._evaluate_guard_hybrid(root.operand, agent, disc_var_dict, cont_var_dict, lane_map) - if not res: - root.operand = ast.parse('False').body[0].value - return True, ast.parse('True').body[0].value - else: - raise ValueError(f'Node type {root} from {astunparse.unparse(root)} is not supported') - return True, root - else: - raise ValueError(f'Node type {root} from {astunparse.unparse(root)} is not supported') - - def _handle_longitudinal_set(self, lane_seg: AbstractLane, position: np.ndarray) -> List[float]: - if lane_seg.type == "Straight": - # Delta lower - delta0 = position[0,:] - lane_seg.start - # Delta upper - delta1 = position[1,:] - lane_seg.start - - longitudinal_low = min(delta0[0]*lane_seg.direction[0], delta1[0]*lane_seg.direction[0]) + \ - min(delta0[1]*lane_seg.direction[1], delta1[1]*lane_seg.direction[1]) - longitudinal_high = max(delta0[0]*lane_seg.direction[0], delta1[0]*lane_seg.direction[0]) + \ - max(delta0[1]*lane_seg.direction[1], delta1[1]*lane_seg.direction[1]) - longitudinal_low += lane_seg.longitudinal_start - longitudinal_high += lane_seg.longitudinal_start - - assert longitudinal_high >= longitudinal_low - return longitudinal_low, longitudinal_high - elif lane_seg.type == "Circular": - # Delta lower - delta0 = position[0,:] - lane_seg.center - # Delta upper - delta1 = position[1,:] - lane_seg.center - - phi0 = np.min([ - np.arctan2(delta0[1], delta0[0]), - np.arctan2(delta0[1], delta1[0]), - np.arctan2(delta1[1], delta0[0]), - np.arctan2(delta1[1], delta1[0]), - ]) - phi1 = np.max([ - np.arctan2(delta0[1], delta0[0]), - np.arctan2(delta0[1], delta1[0]), - np.arctan2(delta1[1], delta0[0]), - np.arctan2(delta1[1], delta1[0]), - ]) - - phi0 = lane_seg.start_phase + wrap_to_pi(phi0 - lane_seg.start_phase) - phi1 = lane_seg.start_phase + wrap_to_pi(phi1 - lane_seg.start_phase) - longitudinal_low = min( - lane_seg.direction * (phi0 - lane_seg.start_phase)*lane_seg.radius, - lane_seg.direction * (phi1 - lane_seg.start_phase)*lane_seg.radius - ) + lane_seg.longitudinal_start - longitudinal_high = max( - lane_seg.direction * (phi0 - lane_seg.start_phase)*lane_seg.radius, - lane_seg.direction * (phi1 - lane_seg.start_phase)*lane_seg.radius - ) + lane_seg.longitudinal_start - - assert longitudinal_high >= longitudinal_low - return longitudinal_low, longitudinal_high - else: - raise ValueError(f'Lane segment with type {lane_seg.type} is not supported') - - def _handle_lateral_set(self, lane_seg: AbstractLane, position: np.ndarray) -> List[float]: - if lane_seg.type == "Straight": - # Delta lower - delta0 = position[0,:] - lane_seg.start - # Delta upper - delta1 = position[1,:] - lane_seg.start - - lateral_low = min(delta0[0]*lane_seg.direction_lateral[0], delta1[0]*lane_seg.direction_lateral[0]) + \ - min(delta0[1]*lane_seg.direction_lateral[1], delta1[1]*lane_seg.direction_lateral[1]) - lateral_high = max(delta0[0]*lane_seg.direction_lateral[0], delta1[0]*lane_seg.direction_lateral[0]) + \ - max(delta0[1]*lane_seg.direction_lateral[1], delta1[1]*lane_seg.direction_lateral[1]) - assert lateral_high >= lateral_low - return lateral_low, lateral_high - elif lane_seg.type == "Circular": - dx = np.max([position[0,0]-lane_seg.center[0],0,lane_seg.center[0]-position[1,0]]) - dy = np.max([position[0,1]-lane_seg.center[1],0,lane_seg.center[1]-position[1,1]]) - r_low = np.linalg.norm([dx, dy]) - - dx = np.max([np.abs(position[0,0]-lane_seg.center[0]),np.abs(position[1,0]-lane_seg.center[0])]) - dy = np.max([np.abs(position[0,1]-lane_seg.center[1]),np.abs(position[1,1]-lane_seg.center[1])]) - r_high = np.linalg.norm([dx, dy]) - lateral_low = min(lane_seg.direction*(lane_seg.radius - r_high),lane_seg.direction*(lane_seg.radius - r_low)) - lateral_high = max(lane_seg.direction*(lane_seg.radius - r_high),lane_seg.direction*(lane_seg.radius - r_low)) - # print(lateral_low, lateral_high) - assert lateral_high >= lateral_low - return lateral_low, lateral_high - else: - raise ValueError(f'Lane segment with type {lane_seg.type} is not supported') - - def evaluate_guard_disc(self, agent, discrete_variable_dict, continuous_variable_dict, lane_map): - """ - Evaluate guard that involves only discrete variables. - """ - res = True - for i, node in enumerate(self.ast_list): - tmp, self.ast_list[i] = self._evaluate_guard_disc(node, agent, discrete_variable_dict, continuous_variable_dict, lane_map) - res = res and tmp - return res - - def _evaluate_guard_disc(self, root, agent, disc_var_dict, cont_var_dict, lane_map): - """ - Recursively called function to evaluate guard with only discrete variables - The function will evaluate all guards with discrete variables and replace the nodes with discrete guards by - boolean constants - - :params: - :return: The return value will be a tuple. The first element in the tuple will either be a boolean value or a the evaluated value of of an expression involving guard - The second element in the tuple will be the updated ast node - """ - if isinstance(root, ast.Compare): - expr = astunparse.unparse(root) - left, root.left = self._evaluate_guard_disc(root.left, agent, disc_var_dict, cont_var_dict, lane_map) - right, root.comparators[0] = self._evaluate_guard_disc(root.comparators[0], agent, disc_var_dict, cont_var_dict, lane_map) - if isinstance(left, bool) or isinstance(right, bool): - return True, root - if isinstance(root.ops[0], ast.GtE): - res = left>=right - elif isinstance(root.ops[0], ast.Gt): - res = left>right - elif isinstance(root.ops[0], ast.Lt): - res = left<right - elif isinstance(root.ops[0], ast.LtE): - res = left<=right - elif isinstance(root.ops[0], ast.Eq): - res = left == right - elif isinstance(root.ops[0], ast.NotEq): - res = left != right - else: - raise ValueError(f'Node type {root} from {astunparse.unparse(root)} is not supported') - if res: - root = ast.parse('True').body[0].value - else: - root = ast.parse('False').body[0].value - return res, root - elif isinstance(root, ast.BoolOp): - if isinstance(root.op, ast.And): - res = True - for i,val in enumerate(root.values): - tmp,root.values[i] = self._evaluate_guard_disc(val, agent, disc_var_dict, cont_var_dict, lane_map) - res = res and tmp - if not res: - break - return res, root - elif isinstance(root.op, ast.Or): - res = False - for val in root.values: - tmp,val = self._evaluate_guard_disc(val, agent, disc_var_dict, cont_var_dict, lane_map) - res = res or tmp - return res, root - elif isinstance(root, ast.BinOp): - # Check left and right in the binop and replace all attributes involving discrete variables - left, root.left = self._evaluate_guard_disc(root.left, agent, disc_var_dict, cont_var_dict, lane_map) - right, root.right = self._evaluate_guard_disc(root.right, agent, disc_var_dict, cont_var_dict, lane_map) - return True, root - elif isinstance(root, ast.Call): - expr = astunparse.unparse(root) - # Check if the root is a function - if any([var in expr for var in disc_var_dict]) and all([var not in expr for var in cont_var_dict]): - # tmp = re.split('\(|\)',expr) - # while "" in tmp: - # tmp.remove("") - # for arg in tmp[1:]: - # if arg in disc_var_dict: - # expr = expr.replace(arg,f'"{disc_var_dict[arg]}"') - # res = eval(expr) - for arg in disc_var_dict: - expr = expr.replace(arg, f'"{disc_var_dict[arg]}"') - res = eval(expr) - if isinstance(res, bool): - if res: - root = ast.parse('True').body[0].value - else: - root = ast.parse('False').body[0].value - else: - # TODO-PARSER: Handle This - for mode_name in agent.controller.modes: - # TODO-PARSER: Handle This - if res in agent.controller.modes[mode_name]: - res = mode_name+'.'+res - break - root = ast.parse(str(res)).body[0].value - return res, root - else: - return True, root - elif isinstance(root, ast.Attribute): - expr = astunparse.unparse(root) - expr = expr.strip('\n') - if expr in disc_var_dict: - val = disc_var_dict[expr] - # TODO-PARSER: Handle This - for mode_name in agent.controller.modes: - # TODO-PARSER: Handle This - if val in agent.controller.modes[mode_name]: - val = mode_name+'.'+val - break - return val, root - # TODO-PARSER: Handle This - elif root.value.id in agent.controller.modes: - return expr, root - else: - return True, root - elif isinstance(root, ast.Constant): - return root.value, root - elif isinstance(root, ast.UnaryOp): - if isinstance(root.op, ast.USub): - res, root.operand = self._evaluate_guard_disc(root.operand, agent, disc_var_dict, cont_var_dict, lane_map) - elif isinstance(root.op, ast.Not): - res, root.operand = self._evaluate_guard_disc(root.operand, agent, disc_var_dict, cont_var_dict, lane_map) - if not res: - root.operand = ast.parse('False').body[0].value - return True, ast.parse('True').body[0].value - else: - raise ValueError(f'Node type {root} from {astunparse.unparse(root)} is not supported') - return True, root - elif isinstance(root, ast.Name): - expr = root.id - if expr in disc_var_dict: - val = disc_var_dict[expr] - # TODO-PARSER: Handle This - for mode_name in agent.controller.modes: - # TODO-PARSER: Handle This - if val in agent.controller.modes[mode_name]: - val = mode_name + '.' + val - break - return val, root - else: - return True, root - else: - raise ValueError(f'Node type {root} from {astunparse.unparse(root)} is not supported') - - def evaluate_guard_old(self, agent, continuous_variable_dict, discrete_variable_dict, lane_map): - res = True - for i, node in enumerate(self.ast_list): - tmp = self._evaluate_guard_old(node, agent, continuous_variable_dict, discrete_variable_dict, lane_map) - res = tmp and res - if not res: - break - return res - - def _evaluate_guard_old(self, root, agent, cnts_var_dict, disc_var_dict, lane_map): - if isinstance(root, ast.Compare): - left = self._evaluate_guard_old(root.left, agent, cnts_var_dict, disc_var_dict, lane_map) - right = self._evaluate_guard_old(root.comparators[0], agent, cnts_var_dict, disc_var_dict, lane_map) - if isinstance(root.ops[0], ast.GtE): - return left>=right - elif isinstance(root.ops[0], ast.Gt): - return left>right - elif isinstance(root.ops[0], ast.Lt): - return left<right - elif isinstance(root.ops[0], ast.LtE): - return left<=right - elif isinstance(root.ops[0], ast.Eq): - return left == right - elif isinstance(root.ops[0], ast.NotEq): - return left != right - else: - raise ValueError(f'Node type {root} from {astunparse.unparse(root)} is not supported') - - elif isinstance(root, ast.BoolOp): - if isinstance(root.op, ast.And): - res = True - for val in root.values: - tmp = self._evaluate_guard_old(val, agent, cnts_var_dict, disc_var_dict, lane_map) - res = res and tmp - if not res: - break - return res - elif isinstance(root.op, ast.Or): - res = False - for val in root.values: - tmp = self._evaluate_guard_old(val, agent, cnts_var_dict, disc_var_dict, lane_map) - res = res or tmp - if res: - break - return res - elif isinstance(root, ast.BinOp): - left = self._evaluate_guard_old(root.left, agent, cnts_var_dict, disc_var_dict, lane_map) - right = self._evaluate_guard_old(root.right, agent, cnts_var_dict, disc_var_dict, lane_map) - if isinstance(root.op, ast.Sub): - return left - right - elif isinstance(root.op, ast.Add): - return left + right - else: - raise ValueError(f'Node type {root} from {astunparse.unparse(root)} is not supported') - elif isinstance(root, ast.Call): - expr = astunparse.unparse(root) - # Check if the root is a function - if isinstance(root.func, ast.Attribute) and "map" in root.func.value.id: - # if 'map' in expr: - # tmp = re.split('\(|\)',expr) - # while "" in tmp: - # tmp.remove("") - # for arg in tmp[1:]: - # if arg in disc_var_dict: - # expr = expr.replace(arg,f'"{disc_var_dict[arg]}"') - # res = eval(expr) - for arg in disc_var_dict: - expr = expr.replace(arg, f'"{disc_var_dict[arg]}"') - for arg in cnts_var_dict: - expr = expr.replace(arg, str(cnts_var_dict[arg])) - res = eval(expr) - # TODO-PARSER: Handle This - for mode_name in agent.controller.modes: - # TODO-PARSER: Handle This - if res in agent.controller.modes[mode_name]: - res = mode_name+'.'+res - break - return res - elif isinstance(root, ast.Attribute): - expr = astunparse.unparse(root) - expr = expr.strip('\n') - if expr in disc_var_dict: - val = disc_var_dict[expr] - # TODO-PARSER: Handle This - for mode_name in agent.controller.modes: - # TODO-PARSER: Handle This - if val in agent.controller.modes[mode_name]: - val = mode_name+'.'+val - break - return val - elif expr in cnts_var_dict: - val = cnts_var_dict[expr] - return val - - # TODO-PARSER: Handle This - elif root.value.id in agent.controller.modes: - return expr - elif isinstance(root, ast.Constant): - return root.value - elif isinstance(root, ast.UnaryOp): - val = self._evaluate_guard_old(root.operand, agent, cnts_var_dict, disc_var_dict, lane_map) - if isinstance(root.op, ast.USub): - return -val - if isinstance(root.op, ast.Not): - return not val - else: - raise ValueError(f'Node type {root} from {astunparse.unparse(root)} is not supported') - elif isinstance(root, ast.Name): - variable = root.id - if variable in cnts_var_dict: - val = cnts_var_dict[variable] - return val - elif variable in disc_var_dict: - val = disc_var_dict[variable] - # TODO-PARSER: Handle This - for mode_name in agent.controller.modes: - # TODO-PARSER: Handle This - if val in agent.controller.modes[mode_name]: - val = mode_name+'.'+val - break - return val - else: - raise ValueError(f"{variable} doesn't exist in either continuous varibales or discrete variables") - else: - raise ValueError(f'Node type {root} from {astunparse.unparse(root)} is not supported') - - def parse_any_all_new(self, cont_var_dict: Dict[str, float], disc_var_dict: Dict[str, float], len_dict: Dict[str, int]) -> Dict[str, List[str]]: - cont_var_updater = {} - for i in range(len(self.ast_list)): - root = self.ast_list[i] - j = 0 - while j < sum(1 for _ in ast.walk(root)): - # TODO: Find a faster way to access nodes in the tree - node = list(ast.walk(root))[j] - if isinstance(node, Reduction): - new_node = self.unroll_any_all_new(node, cont_var_dict, disc_var_dict, len_dict, cont_var_updater) - root = NodeSubstituter(node, new_node).visit(root) - j += 1 - self.ast_list[i] = root - return cont_var_updater - - def unroll_any_all_new( - self, node: Reduction, - cont_var_dict: Dict[str, float], - disc_var_dict: Dict[str, float], - len_dict: Dict[str, float], - cont_var_updater: Dict[str, List[str]], - ) -> Tuple[ast.BoolOp, Dict[str, List[str]]]: - # if isinstance(parse_arg, ast.GeneratorExp): - iter_name = node.value.id - iter_name_list = [node.value.id] - targ_name = node.it - targ_name_list = [node.it] - # targ_var_list = [] - # for var in cont_var_dict: - # if var.startswith(iter_name + '.'): - # tmp = var.split('.')[1] - # targ_var_list.append(targ_name + '.' + tmp) - # for var in disc_var_dict: - # if var.startswith(iter_name + '.'): - # tmp = var.split('.')[1] - # targ_var_list.append(targ_name + '.' + tmp) - - iter_len_list = [range(len_dict[node.value.id])] - # Get all the iter, targets and the length of iter list - # for generator in parse_arg.generators: - # iter_name_list.append(generator.iter.id) # a_list - # targ_name_list.append(generator.target.id) # a - # iter_len_list.append(range(len_dict[generator.iter.id])) # len(a_list) - - elt = node.expr - expand_elt_ast_list = [] - iter_len_list = list(itertools.product(*iter_len_list)) - # Loop through all possible combination of iter value - for i in range(len(iter_len_list)): - changed_elt = copy.deepcopy(elt) - iter_pos_list = iter_len_list[i] - # substitute temporary variable in each of the elt and add corresponding variables in the variable dicts - parsed_elt = self._parse_elt_new(changed_elt, cont_var_dict, disc_var_dict, cont_var_updater, iter_name_list, targ_name_list, iter_pos_list) - # Add the expanded elt into the list - expand_elt_ast_list.append(parsed_elt) - # Create the new boolop (and/or) node based on the list of expanded elt - return ValueSubstituter(expand_elt_ast_list, node).visit(node) - - def _parse_elt_new(self, root, cont_var_dict, disc_var_dict, cont_var_updater, iter_name_list, targ_var_list, iter_pos_list) -> Any: - # Loop through all node in the elt ast - for node in ast.walk(root): - # If the node is an attribute - if isinstance(node, ast.Attribute): - if node.value.id in targ_var_list: - # Find corresponding targ_name in the targ_var_list - targ_name = node.value.id - var_index = targ_var_list.index(targ_name) - - # Find the corresponding iter_name in the iter_name_list - iter_name = iter_name_list[var_index] - - # Create the name for the tmp variable - iter_pos = iter_pos_list[var_index] - tmp_variable_name = f"{iter_name}_{iter_pos}.{node.attr}" - - # Replace variables in the etl by using tmp variables - root = ValueSubstituter(tmp_variable_name, node).visit(root) - - # Find the value of the tmp variable in the cont/disc_var_dict - # Add the tmp variables into the cont/disc_var_dict - # NOTE: At each time step, for each agent, the variable value mapping and their - # sequence in the list is single. Therefore, for the same key, we will always rewrite - # its content. - variable_name = iter_name + '.' + node.attr - variable_val = None - if variable_name in cont_var_dict: - # variable_val = cont_var_dict[variable_name][iter_pos] - # cont_var_dict[tmp_variable_name] = variable_val - if variable_name not in cont_var_updater: - cont_var_updater[variable_name] = [(tmp_variable_name, iter_pos)] - else: - if (tmp_variable_name, iter_pos) not in cont_var_updater[variable_name]: - cont_var_updater[variable_name].append((tmp_variable_name, iter_pos)) - elif variable_name in disc_var_dict: - variable_val = disc_var_dict[variable_name][iter_pos] - disc_var_dict[tmp_variable_name] = variable_val - # if variable_name not in disc_var_updater: - # disc_var_updater[variable_name] = [(tmp_variable_name, iter_pos)] - # else: - # if (tmp_variable_name, iter_pos) not in disc_var_updater[variable_name]: - # disc_var_updater[variable_name].append((tmp_variable_name, iter_pos)) - - elif isinstance(node, ast.Name): - targ_var = None - for tmp in targ_var_list: - if node.id.startswith(tmp+'.'): - targ_var = tmp - break - if targ_var is not None: - node:ast.Name - # Find corresponding targ_name in the targ_var_list - targ_name = targ_var - var_index = targ_var_list.index(targ_name) - attr = node.id.split('.')[1] - - # Find the corresponding iter_name in the iter_name_list - iter_name = iter_name_list[var_index] - - # Create the name for the tmp variable - iter_pos = iter_pos_list[var_index] - tmp_variable_name = f"{iter_name}_{iter_pos}.{attr}" - - # Replace variables in the etl by using tmp variables - root = ValueSubstituter(tmp_variable_name, node).visit(root) - - # Find the value of the tmp variable in the cont/disc_var_dict - # Add the tmp variables into the cont/disc_var_dict - variable_name = iter_name + '.' + attr - variable_val = None - if variable_name in cont_var_dict: - # variable_val = cont_var_dict[variable_name][iter_pos] - # cont_var_dict[tmp_variable_name] = variable_val - if variable_name not in cont_var_updater: - cont_var_updater[variable_name] = [(tmp_variable_name, iter_pos)] - else: - if (tmp_variable_name, iter_pos) not in cont_var_updater[variable_name]: - cont_var_updater[variable_name].append((tmp_variable_name, iter_pos)) - elif variable_name in disc_var_dict: - variable_val = disc_var_dict[variable_name][iter_pos] - disc_var_dict[tmp_variable_name] = variable_val - # if variable_name not in disc_var_updater: - # disc_var_updater[variable_name] = [(tmp_variable_name, iter_pos)] - # else: - # if (tmp_variable_name, iter_pos) not in disc_var_updater[variable_name]: - # disc_var_updater[variable_name].append((tmp_variable_name, iter_pos)) - - # Return the modified node - return root - - def evaluate_guard(self, agent, continuous_variable_dict, discrete_variable_dict, lane_map): - res = True - for i, node in enumerate(self.ast_list): - tmp = self._evaluate_guard(node, agent, continuous_variable_dict, discrete_variable_dict, lane_map) - res = tmp and res - if not res: - break - return res - - def _evaluate_guard(self, root, agent, cnts_var_dict, disc_var_dict, lane_map): - if isinstance(root, ast.Compare): - left = self._evaluate_guard(root.left, agent, cnts_var_dict, disc_var_dict, lane_map) - right = self._evaluate_guard(root.comparators[0], agent, cnts_var_dict, disc_var_dict, lane_map) - if isinstance(root.ops[0], ast.GtE): - return left>=right - elif isinstance(root.ops[0], ast.Gt): - return left>right - elif isinstance(root.ops[0], ast.Lt): - return left<right - elif isinstance(root.ops[0], ast.LtE): - return left<=right - elif isinstance(root.ops[0], ast.Eq): - return left == right - elif isinstance(root.ops[0], ast.NotEq): - return left != right - else: - raise ValueError(f'Node type {root} from {astunparse.unparse(root)} is not supported') - - elif isinstance(root, ast.BoolOp): - if isinstance(root.op, ast.And): - res = True - for val in root.values: - tmp = self._evaluate_guard(val, agent, cnts_var_dict, disc_var_dict, lane_map) - res = res and tmp - if not res: - break - return res - elif isinstance(root.op, ast.Or): - res = False - for val in root.values: - tmp = self._evaluate_guard(val, agent, cnts_var_dict, disc_var_dict, lane_map) - res = res or tmp - if res: - break - return res - elif isinstance(root, ast.BinOp): - left = self._evaluate_guard(root.left, agent, cnts_var_dict, disc_var_dict, lane_map) - right = self._evaluate_guard(root.right, agent, cnts_var_dict, disc_var_dict, lane_map) - if isinstance(root.op, ast.Sub): - return left - right - elif isinstance(root.op, ast.Add): - return left + right - else: - raise ValueError(f'Node type {root} from {astunparse.unparse(root)} is not supported') - elif isinstance(root, ast.Call): - expr = astunparse.unparse(root) - # Check if the root is a function - if isinstance(root.func, ast.Attribute) and "map" in root.func.value.id: - # if 'map' in expr: - # tmp = re.split('\(|\)',expr) - # while "" in tmp: - # tmp.remove("") - # for arg in tmp[1:]: - # if arg in disc_var_dict: - # expr = expr.replace(arg,f'"{disc_var_dict[arg]}"') - # res = eval(expr) - for arg in disc_var_dict: - expr = expr.replace(arg, f'"{disc_var_dict[arg]}"') - for arg in cnts_var_dict: - expr = expr.replace(arg, str(cnts_var_dict[arg])) - res = eval(expr) - for mode_name in agent.controller.mode_defs: - if res in agent.controller.mode_defs[mode_name].modes: - res = mode_name+'.'+res - break - return res - elif isinstance(root.func, ast.Name): - if '.' in root.func.id and "map" in root.func.id.split('.')[0]: - for arg in disc_var_dict: - expr = expr.replace(arg, f'"{disc_var_dict[arg]}"') - for arg in cnts_var_dict: - expr = expr.replace(arg, str(cnts_var_dict[arg])) - res = eval(expr) - for mode_name in agent.controller.mode_defs: - if res in agent.controller.mode_defs[mode_name].modes: - res = mode_name+'.'+res - break - return res - else: - raise ValueError(f'Unsupported function {astunparse.unparse(root)}') - else: - raise ValueError(f'Unsupported function {astunparse.unparse(root)}') - elif isinstance(root, ast.Attribute): - expr = astunparse.unparse(root) - expr = expr.strip('\n') - if expr in disc_var_dict: - val = disc_var_dict[expr] - for mode_name in agent.controller.mode_defs: - if val in agent.controller.mode_defs[mode_name].modes: - val = mode_name+'.'+val - break - return val - elif expr in cnts_var_dict: - val = cnts_var_dict[expr] - return val - elif root.value.id in agent.controller.mode_defs: - return expr - elif isinstance(root, ast.Constant): - return root.value - elif isinstance(root, ast.UnaryOp): - val = self._evaluate_guard(root.operand, agent, cnts_var_dict, disc_var_dict, lane_map) - if isinstance(root.op, ast.USub): - return -val - if isinstance(root.op, ast.Not): - return not val - else: - raise ValueError(f'Node type {root} from {astunparse.unparse(root)} is not supported') - elif isinstance(root, ast.Name): - variable = root.id - if variable in cnts_var_dict: - val = cnts_var_dict[variable] - return val - elif variable in disc_var_dict: - val = disc_var_dict[variable] - for mode_name in agent.controller.mode_defs: - if val in agent.controller.mode_defs[mode_name].modes: - val = mode_name+'.'+val - break - return val - elif '.' in variable: - cls_name, attr = variable.split('.') - if cls_name in agent.controller.mode_defs: - - else: - raise ValueError(f"{variable} doesn't exist in either continuous varibales or discrete variables") - else: - raise ValueError(f'Node type {root} from {astunparse.unparse(root)} is not supported') - - # def _evaluate_guard(self, root, agent, cont_var_dict, disc_var_dict, lane_map): - # cont_var_dict.update(disc_var_dict) - # for node in ast.walk(root): - # if isinstance(node, ast.Name): - # if node.id in cont_var_dict and '.' in node.id: - # node.id = node.id.replace('.','_') - # elif '.' in node.id: - # class_name, attr = node.id.split('.') - # if class_name in agent.controller.mode_defs: - # node.id = attr - - # var_dict = {} - # for var in cont_var_dict: - # if '.' in var: - # var_dict[var.replace('.','_')] = cont_var_dict[var] - # else: - # var_dict[var] = cont_var_dict[var] - # var_dict['lane_map'] = lane_map - # return eval(astunparse.unparse(root), {}, var_dict) - -if __name__ == "__main__": - with open('tmp.pickle','rb') as f: - guard_list = pickle.load(f) - tmp = GuardExpressionAst(guard_list) - # tmp.evaluate_guard() - # tmp.construct_tree_from_str('(other_x-ego_x<20) and other_x-ego_x>10 and other_vehicle_lane==ego_vehicle_lane') - print("stop") \ No newline at end of file