Skip to content
Snippets Groups Projects
Commit acfe68c0 authored by lyg1597's avatar lyg1597
Browse files

working on any/all function calls

parent c0aa25e6
No related branches found
No related tags found
No related merge requests found
import ast
from email import generator
from re import M
import astunparse
from enum import Enum, auto
import itertools
import copy
from typing import Any, Dict, List, Tuple
class VehicleMode(Enum):
Normal = auto()
SwitchLeft = auto()
SwitchRight = auto()
Brake = auto()
class LaneMode(Enum):
Lane0 = auto()
Lane1 = auto()
Lane2 = auto()
class LaneObjectMode(Enum):
Vehicle = auto()
Ped = auto() # Pedestrians
Sign = auto() # Signs, stop signs, merge, yield etc.
Signal = auto() # Traffic lights
Obstacle = auto() # Static (to road/lane) obstacles
class State:
x: float
y: float
theta: float
v: float
vehicle_mode: VehicleMode
lane_mode: LaneMode
type: LaneObjectMode
def __init__(self, x: float = 0, y: float = 0, theta: float = 0, v: float = 0, vehicle_mode: VehicleMode = VehicleMode.Normal, lane_mode: LaneMode = LaneMode.Lane0, type: LaneObjectMode = LaneObjectMode.Vehicle):
pass
def _parse_elt(root, cont_var_dict, disc_var_dict, iter_name_list, targ_name_list, iter_pos_list) -> Any:
# Loop through all node in the elt ast
for node in ast.walk(root):
# If the node is an attribute
if isinstance(node, ast.Attribute):
if node.value.id in targ_name_list:
# Find corresponding targ_name in the targ_name_list
targ_name = node.value.id
var_index = targ_name_list.find(targ_name)
# Find the corresponding iter_name in the iter_name_list
iter_name = iter_name_list[var_index]
# Create the name for the tmp variable
iter_pos = iter_pos_list[var_index]
tmp_variable_name = f"{iter_name}_{iter_pos}"
# Replace variables in the etl by using tmp variables
AttributeSubstituter(tmp_variable_name).visit(node)
# Find the value of the tmp variable in the cont/disc_var_dict
# Add the tmp variables into the cont/disc_var_dict
# Return the modified node
pass
class AttributeSubstituter(ast.NodeTransformer):
def __init__(self, name:str):
super().__init__()
self.name = name
def visit_Attribute(self, node: ast.Attribute) -> Any:
return ast.Name(
id = self.name,
ctx = ast.Load()
)
class FunctionCallSubstituter(ast.NodeTransformer):
def __init__(self, values:List[ast.Expr]):
super().__init__()
self.values = values
def visit_Call(self, node: ast.Call) -> Any:
if node.func.id == 'any':
raise ast.BoolOp(
op = ast.Or(),
values = self.values
)
elif node.func.id == 'all':
raise NotImplementedError
else:
return node
def parse_any(
node: ast.Call,
cont_var_dict: Dict[str, float],
disc_var_dict: Dict[str, float],
len_dict: Dict[str, int]
) -> ast.BoolOp:
parse_arg = node.args[0]
if isinstance(parse_arg, ast.GeneratorExp):
iter_name_list = []
targ_name_list = []
iter_len_list = []
# Get all the iter, targets and the length of iter list
for generator in parse_arg.generators:
iter_name_list.append(generator.iter.name) # a_list
targ_name_list.append(generator.target.name) # a
iter_len_list.append(range(len_dict[generator.iter.name])) # len(a_list)
elt = generator.elt
expand_elt_ast_list = []
iter_len_list = list(itertools.product(*iter_len_list))
# Loop through all possible combination of iter value
for i in range(iter_len_list):
changed_elt = copy.deepcopy(elt)
iter_pos_list = iter_len_list[i]
# substitute temporary variable in each of the elt and add corresponding variables in the variable dicts
parsed_elt = _parse_elt(changed_elt, cont_var_dict, disc_var_dict, iter_name_list, targ_name_list, iter_pos_list)
# Add the expanded elt into the list
expand_elt_ast_list.append(parsed_elt)
# Create the new boolop (or) node based on the list of expanded elt
return FunctionCallSubstituter(expand_elt_ast_list).visit(node)
pass
if __name__ == "__main__":
others = [State(), State()]
ego = State()
code_any = "any((other.x -ego.x > 5 and other.type==Vehicle) for other in others)"
ast_any = ast.parse(code_any).body[0].value
parse_any(ast_any)
print(ast_any)
parse.py 0 → 100644
import ast, copy
from typing import List, Dict, Optional
import astunparse
class Scope:
scopes: List[Dict[str, ast.AST]]
def __init__(self):
self.scopes = [{}]
def push(self):
self.scopes = [{}] + self.scopes
def pop(self):
self.scopes = self.scopes[1:]
def lookup(self, key):
for scope in self.scopes:
if key in scope:
return scope[key]
return None
def set(self, key, val):
for scope in self.scopes:
if key in scope:
scope[key] = val
return
self.scopes[0][key] = val
def dump(self, dump=False):
ast_dump = lambda node: ast.dump(node, indent=2) if dump else astunparse.unparse(node)
for scope in self.scopes:
for k, node in scope.items():
print(f"{k}: ", end="")
if isinstance(node, Function):
print(f"Function args: {node.args} body: {ast_dump(node.body)}")
else:
print(ast_dump(node))
print("===")
class VarSubstituter(ast.NodeTransformer):
args: Dict[str, ast.expr]
def __init__(self, args):
super().__init__()
self.args = args
def visit_Name(self, node):
if isinstance(node.ctx, ast.Load) and node.id in self.args:
return self.args[node.id]
self.generic_visit(node)
return node # XXX needed?
class Function:
args: List[str]
body: ast.expr
def __init__(self, args, body):
self.args = args
self.body = body
@staticmethod
def from_func_def(fd: ast.FunctionDef, scope: Scope) -> "Function":
args = [a.arg for a in fd.args.args]
scope.push()
for a in args:
scope.set(a, ast.Name(a, ctx=ast.Load()))
ret = None
for node in fd.body:
ret = proc(node, scope)
scope.pop()
return Function(args, ret)
def apply(self, args: List[ast.expr]) -> ast.expr:
ret = copy.deepcopy(self.body)
args = {k: v for k, v in zip(self.args, args)}
return VarSubstituter(args).visit(ret)
def proc(node: ast.AST, scope: Scope):
if isinstance(node, ast.Module):
for node in node.body:
proc(node, scope)
elif isinstance(node, ast.For) or isinstance(node, ast.While):
raise NotImplementedError("loops not supported")
elif isinstance(node, ast.If):
node.test = proc(node.test, scope)
for node in node.body: # FIXME properly handle branching
proc(node, scope)
for node in node.orelse:
proc(node, scope)
elif isinstance(node, ast.Assign):
if len(node.targets) == 1:
target = node.targets[0]
if isinstance(target, ast.Name):
scope.set(target.id, proc(node.value, scope))
elif isinstance(target, ast.Attribute):
raise NotImplementedError("assign.attr")
else:
raise NotImplementedError("unpacking not supported")
elif isinstance(node, ast.Name) and isinstance(node.ctx, ast.Load):
return scope.lookup(node.id)
elif isinstance(node, ast.FunctionDef):
scope.set(node.name, Function.from_func_def(node, scope))
elif isinstance(node, ast.ClassDef):
pass
elif isinstance(node, ast.UnaryOp):
return ast.UnaryOp(node.op, proc(node.operand, scope))
elif isinstance(node, ast.BinOp):
return ast.BinOp(proc(node.left, scope), node.op, proc(node.right, scope))
elif isinstance(node, ast.BoolOp):
return ast.BoolOp(node.op, [proc(val, scope) for val in node.values])
elif isinstance(node, ast.Compare):
if len(node.ops) > 1 or len(node.comparators) > 1:
raise NotImplementedError("too many comparisons")
return ast.Compare(proc(node.left, scope), node.ops, [proc(node.comparators[0], scope)])
elif isinstance(node, ast.Call):
fun = proc(node.func, scope)
if not isinstance(fun, Function):
raise Exception("???")
return fun.apply([proc(a, scope) for a in node.args])
elif isinstance(node, ast.List):
return ast.List([proc(e, scope) for e in node.elts])
elif isinstance(node, ast.Tuple):
return ast.Tuple([proc(e, scope) for e in node.elts])
elif isinstance(node, ast.Return):
return proc(node.value, scope)
elif isinstance(node, ast.Constant):
return node # XXX simplification?
def parse(fn: str):
with open(fn) as f:
cont = f.read()
root = ast.parse(cont, fn)
scope = Scope()
proc(root, scope)
scope.dump()
if __name__ == "__main__":
import sys
if len(sys.argv) != 2:
print("usage: parse.py <file.py>")
sys.exit(1)
parse(sys.argv[1])
def f(x, b):
c = x < 3
d = x + 2 * b
return d > 10 or c
x = 10
y = 20 + x
z = f(y, x)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment