From b088099ef7befa1cc5418215b1a90de87db187dd Mon Sep 17 00:00:00 2001 From: unknown <lyg1597@gmail.com> Date: Tue, 7 Jun 2022 21:51:09 -0500 Subject: [PATCH] get a prototype implementation to handle any --- develop/parse_any_all.py | 119 +++++++++++++++++++++++++++++---------- 1 file changed, 89 insertions(+), 30 deletions(-) diff --git a/develop/parse_any_all.py b/develop/parse_any_all.py index 132725a4..2caa5cba 100644 --- a/develop/parse_any_all.py +++ b/develop/parse_any_all.py @@ -45,52 +45,98 @@ def _parse_elt(root, cont_var_dict, disc_var_dict, iter_name_list, targ_name_lis if node.value.id in targ_name_list: # Find corresponding targ_name in the targ_name_list targ_name = node.value.id - var_index = targ_name_list.find(targ_name) + var_index = targ_name_list.index(targ_name) # Find the corresponding iter_name in the iter_name_list iter_name = iter_name_list[var_index] # Create the name for the tmp variable iter_pos = iter_pos_list[var_index] - tmp_variable_name = f"{iter_name}_{iter_pos}" + tmp_variable_name = f"{iter_name}_{iter_pos}.{node.attr}" # Replace variables in the etl by using tmp variables - AttributeSubstituter(tmp_variable_name).visit(node) + root = AttributeNameSubstituter(tmp_variable_name, node).visit(root) # Find the value of the tmp variable in the cont/disc_var_dict - # Add the tmp variables into the cont/disc_var_dict + variable_name = iter_name + '.' + node.attr + variable_val = None + if variable_name in cont_var_dict: + variable_val = cont_var_dict[variable_name][iter_pos] + cont_var_dict[tmp_variable_name] = variable_val + elif variable_name in disc_var_dict: + variable_val = disc_var_dict[variable_name][iter_pos] + disc_var_dict[tmp_variable_name] = variable_val + + if isinstance(node, ast.Name): + if node.id in targ_name_list: + node:ast.Name + # Find corresponding targ_name in the targ_name_list + targ_name = node.id + var_index = targ_name_list.index(targ_name) + + # Find the corresponding iter_name in the iter_name_list + iter_name = iter_name_list[var_index] + + # Create the name for the tmp variable + iter_pos = iter_pos_list[var_index] + tmp_variable_name = f"{iter_name}_{iter_pos}" - # Return the modified node - pass + # Replace variables in the etl by using tmp variables + root = AttributeNameSubstituter(tmp_variable_name, node).visit(root) -class AttributeSubstituter(ast.NodeTransformer): - def __init__(self, name:str): + # Find the value of the tmp variable in the cont/disc_var_dict + # Add the tmp variables into the cont/disc_var_dict + variable_name = iter_name + variable_val = None + if variable_name in cont_var_dict: + variable_val = cont_var_dict[variable_name][iter_pos] + cont_var_dict[tmp_variable_name] = variable_val + elif variable_name in disc_var_dict: + variable_val = disc_var_dict[variable_name][iter_pos] + disc_var_dict[tmp_variable_name] = variable_val + + # Return the modified node + return root + +class AttributeNameSubstituter(ast.NodeTransformer): + def __init__(self, name:str, node): super().__init__() self.name = name + self.node = node def visit_Attribute(self, node: ast.Attribute) -> Any: - return ast.Name( - id = self.name, - ctx = ast.Load() - ) - + if node == self.node: + return ast.Name( + id = self.name, + ctx = ast.Load() + ) + return node + + def visit_Name(self, node: ast.Attribute) -> Any: + if node == self.node: + return ast.Name( + id = self.name, + ctx = ast.Load + ) + return node class FunctionCallSubstituter(ast.NodeTransformer): - def __init__(self, values:List[ast.Expr]): + def __init__(self, values:List[ast.Expr], node): super().__init__() self.values = values + self.node = node def visit_Call(self, node: ast.Call) -> Any: - if node.func.id == 'any': - raise ast.BoolOp( - op = ast.Or(), - values = self.values - ) - elif node.func.id == 'all': - raise NotImplementedError - else: - return node + if node == self.node: + if node.func.id == 'any': + return ast.BoolOp( + op = ast.Or(), + values = self.values + ) + elif node.func.id == 'all': + raise NotImplementedError + return node def parse_any( node: ast.Call, @@ -106,15 +152,15 @@ def parse_any( iter_len_list = [] # Get all the iter, targets and the length of iter list for generator in parse_arg.generators: - iter_name_list.append(generator.iter.name) # a_list - targ_name_list.append(generator.target.name) # a - iter_len_list.append(range(len_dict[generator.iter.name])) # len(a_list) + iter_name_list.append(generator.iter.id) # a_list + targ_name_list.append(generator.target.id) # a + iter_len_list.append(range(len_dict[generator.iter.id])) # len(a_list) - elt = generator.elt + elt = parse_arg.elt expand_elt_ast_list = [] iter_len_list = list(itertools.product(*iter_len_list)) # Loop through all possible combination of iter value - for i in range(iter_len_list): + for i in range(len(iter_len_list)): changed_elt = copy.deepcopy(elt) iter_pos_list = iter_len_list[i] # substitute temporary variable in each of the elt and add corresponding variables in the variable dicts @@ -122,7 +168,7 @@ def parse_any( # Add the expanded elt into the list expand_elt_ast_list.append(parsed_elt) # Create the new boolop (or) node based on the list of expanded elt - return FunctionCallSubstituter(expand_elt_ast_list).visit(node) + return FunctionCallSubstituter(expand_elt_ast_list, node).visit(node) pass if __name__ == "__main__": @@ -130,5 +176,18 @@ if __name__ == "__main__": ego = State() code_any = "any((other.x -ego.x > 5 and other.type==Vehicle) for other in others)" ast_any = ast.parse(code_any).body[0].value - parse_any(ast_any) + cont_var_dict = { + "others.x":[1,2], + "others.y":[3,4], + "ego.x":5, + "ego.y":2 + } + disc_var_dict = { + "others.type":['Vehicle','Sign'], + "ego.type":'Ped' + } + len_dict = { + "others":2 + } + res = parse_any(ast_any, cont_var_dict, disc_var_dict, len_dict) print(ast_any) -- GitLab