Skip to content
Snippets Groups Projects
Commit b088099e authored by unknown's avatar unknown
Browse files

get a prototype implementation to handle any

parent acfe68c0
No related branches found
No related tags found
No related merge requests found
...@@ -45,52 +45,98 @@ def _parse_elt(root, cont_var_dict, disc_var_dict, iter_name_list, targ_name_lis ...@@ -45,52 +45,98 @@ def _parse_elt(root, cont_var_dict, disc_var_dict, iter_name_list, targ_name_lis
if node.value.id in targ_name_list: if node.value.id in targ_name_list:
# Find corresponding targ_name in the targ_name_list # Find corresponding targ_name in the targ_name_list
targ_name = node.value.id targ_name = node.value.id
var_index = targ_name_list.find(targ_name) var_index = targ_name_list.index(targ_name)
# Find the corresponding iter_name in the iter_name_list # Find the corresponding iter_name in the iter_name_list
iter_name = iter_name_list[var_index] iter_name = iter_name_list[var_index]
# Create the name for the tmp variable # Create the name for the tmp variable
iter_pos = iter_pos_list[var_index] iter_pos = iter_pos_list[var_index]
tmp_variable_name = f"{iter_name}_{iter_pos}" tmp_variable_name = f"{iter_name}_{iter_pos}.{node.attr}"
# Replace variables in the etl by using tmp variables # Replace variables in the etl by using tmp variables
AttributeSubstituter(tmp_variable_name).visit(node) root = AttributeNameSubstituter(tmp_variable_name, node).visit(root)
# Find the value of the tmp variable in the cont/disc_var_dict # Find the value of the tmp variable in the cont/disc_var_dict
# Add the tmp variables into the cont/disc_var_dict # Add the tmp variables into the cont/disc_var_dict
variable_name = iter_name + '.' + node.attr
variable_val = None
if variable_name in cont_var_dict:
variable_val = cont_var_dict[variable_name][iter_pos]
cont_var_dict[tmp_variable_name] = variable_val
elif variable_name in disc_var_dict:
variable_val = disc_var_dict[variable_name][iter_pos]
disc_var_dict[tmp_variable_name] = variable_val
if isinstance(node, ast.Name):
if node.id in targ_name_list:
node:ast.Name
# Find corresponding targ_name in the targ_name_list
targ_name = node.id
var_index = targ_name_list.index(targ_name)
# Find the corresponding iter_name in the iter_name_list
iter_name = iter_name_list[var_index]
# Create the name for the tmp variable
iter_pos = iter_pos_list[var_index]
tmp_variable_name = f"{iter_name}_{iter_pos}"
# Return the modified node # Replace variables in the etl by using tmp variables
pass root = AttributeNameSubstituter(tmp_variable_name, node).visit(root)
class AttributeSubstituter(ast.NodeTransformer): # Find the value of the tmp variable in the cont/disc_var_dict
def __init__(self, name:str): # Add the tmp variables into the cont/disc_var_dict
variable_name = iter_name
variable_val = None
if variable_name in cont_var_dict:
variable_val = cont_var_dict[variable_name][iter_pos]
cont_var_dict[tmp_variable_name] = variable_val
elif variable_name in disc_var_dict:
variable_val = disc_var_dict[variable_name][iter_pos]
disc_var_dict[tmp_variable_name] = variable_val
# Return the modified node
return root
class AttributeNameSubstituter(ast.NodeTransformer):
def __init__(self, name:str, node):
super().__init__() super().__init__()
self.name = name self.name = name
self.node = node
def visit_Attribute(self, node: ast.Attribute) -> Any: def visit_Attribute(self, node: ast.Attribute) -> Any:
return ast.Name( if node == self.node:
id = self.name, return ast.Name(
ctx = ast.Load() id = self.name,
) ctx = ast.Load()
)
return node
def visit_Name(self, node: ast.Attribute) -> Any:
if node == self.node:
return ast.Name(
id = self.name,
ctx = ast.Load
)
return node
class FunctionCallSubstituter(ast.NodeTransformer): class FunctionCallSubstituter(ast.NodeTransformer):
def __init__(self, values:List[ast.Expr]): def __init__(self, values:List[ast.Expr], node):
super().__init__() super().__init__()
self.values = values self.values = values
self.node = node
def visit_Call(self, node: ast.Call) -> Any: def visit_Call(self, node: ast.Call) -> Any:
if node.func.id == 'any': if node == self.node:
raise ast.BoolOp( if node.func.id == 'any':
op = ast.Or(), return ast.BoolOp(
values = self.values op = ast.Or(),
) values = self.values
elif node.func.id == 'all': )
raise NotImplementedError elif node.func.id == 'all':
else: raise NotImplementedError
return node return node
def parse_any( def parse_any(
node: ast.Call, node: ast.Call,
...@@ -106,15 +152,15 @@ def parse_any( ...@@ -106,15 +152,15 @@ def parse_any(
iter_len_list = [] iter_len_list = []
# Get all the iter, targets and the length of iter list # Get all the iter, targets and the length of iter list
for generator in parse_arg.generators: for generator in parse_arg.generators:
iter_name_list.append(generator.iter.name) # a_list iter_name_list.append(generator.iter.id) # a_list
targ_name_list.append(generator.target.name) # a targ_name_list.append(generator.target.id) # a
iter_len_list.append(range(len_dict[generator.iter.name])) # len(a_list) iter_len_list.append(range(len_dict[generator.iter.id])) # len(a_list)
elt = generator.elt elt = parse_arg.elt
expand_elt_ast_list = [] expand_elt_ast_list = []
iter_len_list = list(itertools.product(*iter_len_list)) iter_len_list = list(itertools.product(*iter_len_list))
# Loop through all possible combination of iter value # Loop through all possible combination of iter value
for i in range(iter_len_list): for i in range(len(iter_len_list)):
changed_elt = copy.deepcopy(elt) changed_elt = copy.deepcopy(elt)
iter_pos_list = iter_len_list[i] iter_pos_list = iter_len_list[i]
# substitute temporary variable in each of the elt and add corresponding variables in the variable dicts # substitute temporary variable in each of the elt and add corresponding variables in the variable dicts
...@@ -122,7 +168,7 @@ def parse_any( ...@@ -122,7 +168,7 @@ def parse_any(
# Add the expanded elt into the list # Add the expanded elt into the list
expand_elt_ast_list.append(parsed_elt) expand_elt_ast_list.append(parsed_elt)
# Create the new boolop (or) node based on the list of expanded elt # Create the new boolop (or) node based on the list of expanded elt
return FunctionCallSubstituter(expand_elt_ast_list).visit(node) return FunctionCallSubstituter(expand_elt_ast_list, node).visit(node)
pass pass
if __name__ == "__main__": if __name__ == "__main__":
...@@ -130,5 +176,18 @@ if __name__ == "__main__": ...@@ -130,5 +176,18 @@ if __name__ == "__main__":
ego = State() ego = State()
code_any = "any((other.x -ego.x > 5 and other.type==Vehicle) for other in others)" code_any = "any((other.x -ego.x > 5 and other.type==Vehicle) for other in others)"
ast_any = ast.parse(code_any).body[0].value ast_any = ast.parse(code_any).body[0].value
parse_any(ast_any) cont_var_dict = {
"others.x":[1,2],
"others.y":[3,4],
"ego.x":5,
"ego.y":2
}
disc_var_dict = {
"others.type":['Vehicle','Sign'],
"ego.type":'Ped'
}
len_dict = {
"others":2
}
res = parse_any(ast_any, cont_var_dict, disc_var_dict, len_dict)
print(ast_any) print(ast_any)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment