diff --git a/develop/parse_any_all.py b/develop/parse_any_all.py index c9c2ff637ed30a8bbf9141bd47020f1764efa3d0..6dfc7f434f4a7d61e2cbc62ecb2336278fbc6601 100644 --- a/develop/parse_any_all.py +++ b/develop/parse_any_all.py @@ -123,7 +123,9 @@ class ValueSubstituter(ast.NodeTransformer): def visit_Call(self, node: ast.Call) -> Any: if node == self.node: - if node.func.id == 'any': + if len(self.val) == 1: + return self.val[0] + elif node.func.id == 'any': return ast.BoolOp( op = ast.Or(), values = self.val @@ -194,11 +196,11 @@ def unroll_any_all(root, cont_var_dict: Dict[str, float], disc_var_dict: Dict[st return root if __name__ == "__main__": - others = [State(), State(), State(), State()] + others = [State()] ego = State() - code_any = "any((any(other.y < 100 for other in others) and other.x -ego.x > 5 and other.type==Vehicle) for other in others)" - # code_any = "all((other.x -ego.x > 5 and other.type==Vehicle) for other in others)" + # code_any = "any((any(other.y < 100 for other in others) and other.x -ego.x > 5 and other.type==Vehicle) for other in others)" + code_any = "all((other.x -ego.x > 5 and other.type==Vehicle) for other in others)" ast_any = ast.parse(code_any).body[0].value cont_var_dict = { "others.x":[1,2,4,5],