Skip to content
Snippets Groups Projects
Commit b088099e authored by unknown's avatar unknown
Browse files

get a prototype implementation to handle any

parent acfe68c0
No related branches found
No related tags found
No related merge requests found
......@@ -45,52 +45,98 @@ def _parse_elt(root, cont_var_dict, disc_var_dict, iter_name_list, targ_name_lis
if node.value.id in targ_name_list:
# Find corresponding targ_name in the targ_name_list
targ_name = node.value.id
var_index = targ_name_list.find(targ_name)
var_index = targ_name_list.index(targ_name)
# Find the corresponding iter_name in the iter_name_list
iter_name = iter_name_list[var_index]
# Create the name for the tmp variable
iter_pos = iter_pos_list[var_index]
tmp_variable_name = f"{iter_name}_{iter_pos}"
tmp_variable_name = f"{iter_name}_{iter_pos}.{node.attr}"
# Replace variables in the etl by using tmp variables
AttributeSubstituter(tmp_variable_name).visit(node)
root = AttributeNameSubstituter(tmp_variable_name, node).visit(root)
# Find the value of the tmp variable in the cont/disc_var_dict
# Add the tmp variables into the cont/disc_var_dict
variable_name = iter_name + '.' + node.attr
variable_val = None
if variable_name in cont_var_dict:
variable_val = cont_var_dict[variable_name][iter_pos]
cont_var_dict[tmp_variable_name] = variable_val
elif variable_name in disc_var_dict:
variable_val = disc_var_dict[variable_name][iter_pos]
disc_var_dict[tmp_variable_name] = variable_val
if isinstance(node, ast.Name):
if node.id in targ_name_list:
node:ast.Name
# Find corresponding targ_name in the targ_name_list
targ_name = node.id
var_index = targ_name_list.index(targ_name)
# Find the corresponding iter_name in the iter_name_list
iter_name = iter_name_list[var_index]
# Create the name for the tmp variable
iter_pos = iter_pos_list[var_index]
tmp_variable_name = f"{iter_name}_{iter_pos}"
# Return the modified node
pass
# Replace variables in the etl by using tmp variables
root = AttributeNameSubstituter(tmp_variable_name, node).visit(root)
class AttributeSubstituter(ast.NodeTransformer):
def __init__(self, name:str):
# Find the value of the tmp variable in the cont/disc_var_dict
# Add the tmp variables into the cont/disc_var_dict
variable_name = iter_name
variable_val = None
if variable_name in cont_var_dict:
variable_val = cont_var_dict[variable_name][iter_pos]
cont_var_dict[tmp_variable_name] = variable_val
elif variable_name in disc_var_dict:
variable_val = disc_var_dict[variable_name][iter_pos]
disc_var_dict[tmp_variable_name] = variable_val
# Return the modified node
return root
class AttributeNameSubstituter(ast.NodeTransformer):
def __init__(self, name:str, node):
super().__init__()
self.name = name
self.node = node
def visit_Attribute(self, node: ast.Attribute) -> Any:
return ast.Name(
id = self.name,
ctx = ast.Load()
)
if node == self.node:
return ast.Name(
id = self.name,
ctx = ast.Load()
)
return node
def visit_Name(self, node: ast.Attribute) -> Any:
if node == self.node:
return ast.Name(
id = self.name,
ctx = ast.Load
)
return node
class FunctionCallSubstituter(ast.NodeTransformer):
def __init__(self, values:List[ast.Expr]):
def __init__(self, values:List[ast.Expr], node):
super().__init__()
self.values = values
self.node = node
def visit_Call(self, node: ast.Call) -> Any:
if node.func.id == 'any':
raise ast.BoolOp(
op = ast.Or(),
values = self.values
)
elif node.func.id == 'all':
raise NotImplementedError
else:
return node
if node == self.node:
if node.func.id == 'any':
return ast.BoolOp(
op = ast.Or(),
values = self.values
)
elif node.func.id == 'all':
raise NotImplementedError
return node
def parse_any(
node: ast.Call,
......@@ -106,15 +152,15 @@ def parse_any(
iter_len_list = []
# Get all the iter, targets and the length of iter list
for generator in parse_arg.generators:
iter_name_list.append(generator.iter.name) # a_list
targ_name_list.append(generator.target.name) # a
iter_len_list.append(range(len_dict[generator.iter.name])) # len(a_list)
iter_name_list.append(generator.iter.id) # a_list
targ_name_list.append(generator.target.id) # a
iter_len_list.append(range(len_dict[generator.iter.id])) # len(a_list)
elt = generator.elt
elt = parse_arg.elt
expand_elt_ast_list = []
iter_len_list = list(itertools.product(*iter_len_list))
# Loop through all possible combination of iter value
for i in range(iter_len_list):
for i in range(len(iter_len_list)):
changed_elt = copy.deepcopy(elt)
iter_pos_list = iter_len_list[i]
# substitute temporary variable in each of the elt and add corresponding variables in the variable dicts
......@@ -122,7 +168,7 @@ def parse_any(
# Add the expanded elt into the list
expand_elt_ast_list.append(parsed_elt)
# Create the new boolop (or) node based on the list of expanded elt
return FunctionCallSubstituter(expand_elt_ast_list).visit(node)
return FunctionCallSubstituter(expand_elt_ast_list, node).visit(node)
pass
if __name__ == "__main__":
......@@ -130,5 +176,18 @@ if __name__ == "__main__":
ego = State()
code_any = "any((other.x -ego.x > 5 and other.type==Vehicle) for other in others)"
ast_any = ast.parse(code_any).body[0].value
parse_any(ast_any)
cont_var_dict = {
"others.x":[1,2],
"others.y":[3,4],
"ego.x":5,
"ego.y":2
}
disc_var_dict = {
"others.type":['Vehicle','Sign'],
"ego.type":'Ped'
}
len_dict = {
"others":2
}
res = parse_any(ast_any, cont_var_dict, disc_var_dict, len_dict)
print(ast_any)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment