diff --git a/develop/parse_any_all.py b/develop/parse_any_all.py new file mode 100644 index 0000000000000000000000000000000000000000..132725a4195286b39c8ff7960f9f888c0698bc63 --- /dev/null +++ b/develop/parse_any_all.py @@ -0,0 +1,134 @@ +import ast +from email import generator +from re import M +import astunparse +from enum import Enum, auto +import itertools +import copy +from typing import Any, Dict, List, Tuple + +class VehicleMode(Enum): + Normal = auto() + SwitchLeft = auto() + SwitchRight = auto() + Brake = auto() + +class LaneMode(Enum): + Lane0 = auto() + Lane1 = auto() + Lane2 = auto() + +class LaneObjectMode(Enum): + Vehicle = auto() + Ped = auto() # Pedestrians + Sign = auto() # Signs, stop signs, merge, yield etc. + Signal = auto() # Traffic lights + Obstacle = auto() # Static (to road/lane) obstacles + +class State: + x: float + y: float + theta: float + v: float + vehicle_mode: VehicleMode + lane_mode: LaneMode + type: LaneObjectMode + + def __init__(self, x: float = 0, y: float = 0, theta: float = 0, v: float = 0, vehicle_mode: VehicleMode = VehicleMode.Normal, lane_mode: LaneMode = LaneMode.Lane0, type: LaneObjectMode = LaneObjectMode.Vehicle): + pass + +def _parse_elt(root, cont_var_dict, disc_var_dict, iter_name_list, targ_name_list, iter_pos_list) -> Any: + # Loop through all node in the elt ast + for node in ast.walk(root): + # If the node is an attribute + if isinstance(node, ast.Attribute): + if node.value.id in targ_name_list: + # Find corresponding targ_name in the targ_name_list + targ_name = node.value.id + var_index = targ_name_list.find(targ_name) + + # Find the corresponding iter_name in the iter_name_list + iter_name = iter_name_list[var_index] + + # Create the name for the tmp variable + iter_pos = iter_pos_list[var_index] + tmp_variable_name = f"{iter_name}_{iter_pos}" + + # Replace variables in the etl by using tmp variables + AttributeSubstituter(tmp_variable_name).visit(node) + + # Find the value of the tmp variable in the cont/disc_var_dict + + # Add the tmp variables into the cont/disc_var_dict + + # Return the modified node + pass + +class AttributeSubstituter(ast.NodeTransformer): + def __init__(self, name:str): + super().__init__() + self.name = name + + def visit_Attribute(self, node: ast.Attribute) -> Any: + return ast.Name( + id = self.name, + ctx = ast.Load() + ) + + +class FunctionCallSubstituter(ast.NodeTransformer): + def __init__(self, values:List[ast.Expr]): + super().__init__() + self.values = values + + def visit_Call(self, node: ast.Call) -> Any: + if node.func.id == 'any': + raise ast.BoolOp( + op = ast.Or(), + values = self.values + ) + elif node.func.id == 'all': + raise NotImplementedError + else: + return node + +def parse_any( + node: ast.Call, + cont_var_dict: Dict[str, float], + disc_var_dict: Dict[str, float], + len_dict: Dict[str, int] +) -> ast.BoolOp: + + parse_arg = node.args[0] + if isinstance(parse_arg, ast.GeneratorExp): + iter_name_list = [] + targ_name_list = [] + iter_len_list = [] + # Get all the iter, targets and the length of iter list + for generator in parse_arg.generators: + iter_name_list.append(generator.iter.name) # a_list + targ_name_list.append(generator.target.name) # a + iter_len_list.append(range(len_dict[generator.iter.name])) # len(a_list) + + elt = generator.elt + expand_elt_ast_list = [] + iter_len_list = list(itertools.product(*iter_len_list)) + # Loop through all possible combination of iter value + for i in range(iter_len_list): + changed_elt = copy.deepcopy(elt) + iter_pos_list = iter_len_list[i] + # substitute temporary variable in each of the elt and add corresponding variables in the variable dicts + parsed_elt = _parse_elt(changed_elt, cont_var_dict, disc_var_dict, iter_name_list, targ_name_list, iter_pos_list) + # Add the expanded elt into the list + expand_elt_ast_list.append(parsed_elt) + # Create the new boolop (or) node based on the list of expanded elt + return FunctionCallSubstituter(expand_elt_ast_list).visit(node) + pass + +if __name__ == "__main__": + others = [State(), State()] + ego = State() + code_any = "any((other.x -ego.x > 5 and other.type==Vehicle) for other in others)" + ast_any = ast.parse(code_any).body[0].value + parse_any(ast_any) + print(ast_any) diff --git a/parse.py b/parse.py new file mode 100644 index 0000000000000000000000000000000000000000..e25b97bcd4eb387064e60803cf3834897e4b3412 --- /dev/null +++ b/parse.py @@ -0,0 +1,141 @@ +import ast, copy +from typing import List, Dict, Optional +import astunparse + + +class Scope: + scopes: List[Dict[str, ast.AST]] + def __init__(self): + self.scopes = [{}] + + def push(self): + self.scopes = [{}] + self.scopes + + def pop(self): + self.scopes = self.scopes[1:] + + def lookup(self, key): + for scope in self.scopes: + if key in scope: + return scope[key] + return None + + def set(self, key, val): + for scope in self.scopes: + if key in scope: + scope[key] = val + return + self.scopes[0][key] = val + + def dump(self, dump=False): + ast_dump = lambda node: ast.dump(node, indent=2) if dump else astunparse.unparse(node) + for scope in self.scopes: + for k, node in scope.items(): + print(f"{k}: ", end="") + if isinstance(node, Function): + print(f"Function args: {node.args} body: {ast_dump(node.body)}") + else: + print(ast_dump(node)) + print("===") + +class VarSubstituter(ast.NodeTransformer): + args: Dict[str, ast.expr] + def __init__(self, args): + super().__init__() + self.args = args + + def visit_Name(self, node): + if isinstance(node.ctx, ast.Load) and node.id in self.args: + return self.args[node.id] + self.generic_visit(node) + return node # XXX needed? + +class Function: + args: List[str] + body: ast.expr + def __init__(self, args, body): + self.args = args + self.body = body + + @staticmethod + def from_func_def(fd: ast.FunctionDef, scope: Scope) -> "Function": + args = [a.arg for a in fd.args.args] + scope.push() + for a in args: + scope.set(a, ast.Name(a, ctx=ast.Load())) + ret = None + for node in fd.body: + ret = proc(node, scope) + scope.pop() + return Function(args, ret) + + def apply(self, args: List[ast.expr]) -> ast.expr: + ret = copy.deepcopy(self.body) + args = {k: v for k, v in zip(self.args, args)} + return VarSubstituter(args).visit(ret) + +def proc(node: ast.AST, scope: Scope): + if isinstance(node, ast.Module): + for node in node.body: + proc(node, scope) + elif isinstance(node, ast.For) or isinstance(node, ast.While): + raise NotImplementedError("loops not supported") + elif isinstance(node, ast.If): + node.test = proc(node.test, scope) + for node in node.body: # FIXME properly handle branching + proc(node, scope) + for node in node.orelse: + proc(node, scope) + elif isinstance(node, ast.Assign): + if len(node.targets) == 1: + target = node.targets[0] + if isinstance(target, ast.Name): + scope.set(target.id, proc(node.value, scope)) + elif isinstance(target, ast.Attribute): + raise NotImplementedError("assign.attr") + else: + raise NotImplementedError("unpacking not supported") + elif isinstance(node, ast.Name) and isinstance(node.ctx, ast.Load): + return scope.lookup(node.id) + elif isinstance(node, ast.FunctionDef): + scope.set(node.name, Function.from_func_def(node, scope)) + elif isinstance(node, ast.ClassDef): + pass + elif isinstance(node, ast.UnaryOp): + return ast.UnaryOp(node.op, proc(node.operand, scope)) + elif isinstance(node, ast.BinOp): + return ast.BinOp(proc(node.left, scope), node.op, proc(node.right, scope)) + elif isinstance(node, ast.BoolOp): + return ast.BoolOp(node.op, [proc(val, scope) for val in node.values]) + elif isinstance(node, ast.Compare): + if len(node.ops) > 1 or len(node.comparators) > 1: + raise NotImplementedError("too many comparisons") + return ast.Compare(proc(node.left, scope), node.ops, [proc(node.comparators[0], scope)]) + elif isinstance(node, ast.Call): + fun = proc(node.func, scope) + if not isinstance(fun, Function): + raise Exception("???") + return fun.apply([proc(a, scope) for a in node.args]) + elif isinstance(node, ast.List): + return ast.List([proc(e, scope) for e in node.elts]) + elif isinstance(node, ast.Tuple): + return ast.Tuple([proc(e, scope) for e in node.elts]) + elif isinstance(node, ast.Return): + return proc(node.value, scope) + elif isinstance(node, ast.Constant): + return node # XXX simplification? + +def parse(fn: str): + with open(fn) as f: + cont = f.read() + root = ast.parse(cont, fn) + scope = Scope() + proc(root, scope) + scope.dump() + +if __name__ == "__main__": + import sys + if len(sys.argv) != 2: + print("usage: parse.py <file.py>") + sys.exit(1) + parse(sys.argv[1]) diff --git a/test.py b/test.py new file mode 100644 index 0000000000000000000000000000000000000000..873f5c159b68c4589690be596e7eb045db41b064 --- /dev/null +++ b/test.py @@ -0,0 +1,7 @@ +def f(x, b): + c = x < 3 + d = x + 2 * b + return d > 10 or c +x = 10 +y = 20 + x +z = f(y, x) \ No newline at end of file