diff --git a/demo/demo3.py b/demo/demo3.py index c0d57fba325d3328533727a5568e5cf78bdf0e09..dd1bb151fb050fcd5528fef802730fe90a4d873d 100644 --- a/demo/demo3.py +++ b/demo/demo3.py @@ -1,10 +1,11 @@ -from dryvr_plus_plus.example.example_agent.car_agent import CarAgent +from dryvr_plus_plus.example.example_agent.car_agent import CarAgent, NPCAgent from dryvr_plus_plus.example.example_agent.car_agent import CarAgent from dryvr_plus_plus.scene_verifier.scenario.scenario import Scenario from dryvr_plus_plus.example.example_map.simple_map2 import SimpleMap2, SimpleMap3, SimpleMap5, SimpleMap6 from dryvr_plus_plus.plotter.plotter2D import * from dryvr_plus_plus.example.example_sensor.fake_sensor import FakeSensor3 +import matplotlib.pyplot as plt import plotly.graph_objects as go import numpy as np from enum import Enum, auto @@ -46,11 +47,11 @@ if __name__ == "__main__": car = CarAgent('car1', file_name=input_code_name) scenario.add_agent(car) - car = CarAgent('car2', file_name=input_code_name) + car = NPCAgent('car2') scenario.add_agent(car) - car = CarAgent('car3', file_name=input_code_name) + car = NPCAgent('car3') scenario.add_agent(car) - car = CarAgent('car4', file_name=input_code_name) + car = NPCAgent('car4') scenario.add_agent(car) tmp_map = SimpleMap3() scenario.set_map(tmp_map) @@ -74,7 +75,7 @@ if __name__ == "__main__": # fig = plt.figure(2) # fig = plot_map(tmp_map, 'g', fig) - # fig = plot_reachtube_tree(traces, 'car1', 0, [1], 'b', fig) + # fig = plot_reachtube_tree(traces, 'car1', 1, [2], 'b', fig) # fig = plot_reachtube_tree(traces, 'car2', 1, [2], 'r', fig) # fig = plot_reachtube_tree(traces, 'car3', 1, [2], 'r', fig) # fig = plot_reachtube_tree(traces, 'car4', 1, [2], 'r', fig) @@ -84,8 +85,8 @@ if __name__ == "__main__": # fig = plot_simulation_tree(traces, 'car2', 1, [2], 'r', fig) # fig = plot_simulation_tree(traces, 'car3', 1, [2], 'r', fig) # fig = plot_simulation_tree(traces, 'car4', 1, [2], 'r', fig) - # plt.show() + fig = go.Figure() fig = plotly_simulation_anime(traces, tmp_map, fig) diff --git a/demo/demo4.py b/demo/demo4.py index 39c718c602a1f797b6ea3ea9f7cdbf268cc7aac6..2a9843f25ab4fb06ff8c031e3571f82b65acd670 100644 --- a/demo/demo4.py +++ b/demo/demo4.py @@ -5,7 +5,7 @@ from dryvr_plus_plus.example.example_map.simple_map2 import SimpleMap2, SimpleMa from dryvr_plus_plus.plotter.plotter2D import * from dryvr_plus_plus.example.example_sensor.fake_sensor import FakeSensor3 -# import matplotlib.pyplot as plt +import matplotlib.pyplot as plt import plotly.graph_objects as go import numpy as np from enum import Enum, auto @@ -84,12 +84,12 @@ if __name__ == "__main__": # fig = plt.figure(2) # fig = plot_map(tmp_map, 'g', fig) - # # fig = plot_reachtube_tree(traces, 'car1', 1, [2], 'b', fig) - # # fig = plot_reachtube_tree(traces, 'car2', 1, [2], 'r', fig) - # # fig = plot_reachtube_tree(traces, 'car3', 1, [2], 'r', fig) - # # fig = plot_reachtube_tree(traces, 'car4', 1, [2], 'r', fig) - # # fig = plot_reachtube_tree(traces, 'car5', 1, [2], 'r', fig) - # # fig = plot_reachtube_tree(traces, 'car6', 1, [2], 'r', fig) + # fig = plot_reachtube_tree(traces, 'car1', 1, [2], 'b', fig) + # fig = plot_reachtube_tree(traces, 'car2', 1, [2], 'r', fig) + # fig = plot_reachtube_tree(traces, 'car3', 1, [2], 'r', fig) + # fig = plot_reachtube_tree(traces, 'car4', 1, [2], 'r', fig) + # fig = plot_reachtube_tree(traces, 'car5', 1, [2], 'r', fig) + # fig = plot_reachtube_tree(traces, 'car6', 1, [2], 'r', fig) # for traces in res_list: # # generate_simulation_anime(traces, tmp_map, fig) # fig = plot_simulation_tree(traces, 'car1', 1, [2], 'b', fig) diff --git a/dryvr_plus_plus/scene_verifier/automaton/guard.py b/dryvr_plus_plus/scene_verifier/automaton/guard.py index 541233310a96df8df9ec19ede97e3b04ebfe9068..0370c8371ac540a4ca3e785c6d38874adb6e42d1 100644 --- a/dryvr_plus_plus/scene_verifier/automaton/guard.py +++ b/dryvr_plus_plus/scene_verifier/automaton/guard.py @@ -28,8 +28,10 @@ class NodeSubstituter(ast.NodeTransformer): def visit_Call(self, node: ast.Call) -> Any: if node == self.old_node: + self.generic_visit(node) return self.new_node else: + self.generic_visit(node) return node class ValueSubstituter(ast.NodeTransformer): @@ -39,6 +41,7 @@ class ValueSubstituter(ast.NodeTransformer): self.node = node def visit_Attribute(self, node: ast.Attribute) -> Any: + # Substitute attribute node in the ast if node == self.node: return ast.Name( id = self.val, @@ -47,6 +50,7 @@ class ValueSubstituter(ast.NodeTransformer): return node def visit_Name(self, node: ast.Attribute) -> Any: + # Substitute name node in the ast if node == self.node: return ast.Name( id = self.val, @@ -55,19 +59,24 @@ class ValueSubstituter(ast.NodeTransformer): return node def visit_Call(self, node: ast.Call) -> Any: + # Substitute call node in the ast if node == self.node: if len(self.val) == 1: + self.generic_visit(node) return self.val[0] elif node.func.id == 'any': + self.generic_visit(node) return ast.BoolOp( op = ast.Or(), values = self.val ) elif node.func.id == 'all': + self.generic_visit(node) return ast.BoolOp( op = ast.And(), values = self.val ) + self.generic_visit(node) return node @@ -811,6 +820,9 @@ class GuardExpressionAst: # Find the value of the tmp variable in the cont/disc_var_dict # Add the tmp variables into the cont/disc_var_dict + # NOTE: At each time step, for each agent, the variable value mapping and their + # sequence in the list is single. Therefore, for the same key, we will always rewrite + # its content. variable_name = iter_name + '.' + node.attr variable_val = None if variable_name in cont_var_dict: @@ -820,7 +832,7 @@ class GuardExpressionAst: variable_val = disc_var_dict[variable_name][iter_pos] disc_var_dict[tmp_variable_name] = variable_val - if isinstance(node, ast.Name): + elif isinstance(node, ast.Name): if node.id in targ_name_list: node:ast.Name # Find corresponding targ_name in the targ_name_list diff --git a/parse.py b/parse.py deleted file mode 100644 index e25b97bcd4eb387064e60803cf3834897e4b3412..0000000000000000000000000000000000000000 --- a/parse.py +++ /dev/null @@ -1,141 +0,0 @@ -import ast, copy -from typing import List, Dict, Optional -import astunparse - - -class Scope: - scopes: List[Dict[str, ast.AST]] - def __init__(self): - self.scopes = [{}] - - def push(self): - self.scopes = [{}] + self.scopes - - def pop(self): - self.scopes = self.scopes[1:] - - def lookup(self, key): - for scope in self.scopes: - if key in scope: - return scope[key] - return None - - def set(self, key, val): - for scope in self.scopes: - if key in scope: - scope[key] = val - return - self.scopes[0][key] = val - - def dump(self, dump=False): - ast_dump = lambda node: ast.dump(node, indent=2) if dump else astunparse.unparse(node) - for scope in self.scopes: - for k, node in scope.items(): - print(f"{k}: ", end="") - if isinstance(node, Function): - print(f"Function args: {node.args} body: {ast_dump(node.body)}") - else: - print(ast_dump(node)) - print("===") - -class VarSubstituter(ast.NodeTransformer): - args: Dict[str, ast.expr] - def __init__(self, args): - super().__init__() - self.args = args - - def visit_Name(self, node): - if isinstance(node.ctx, ast.Load) and node.id in self.args: - return self.args[node.id] - self.generic_visit(node) - return node # XXX needed? - -class Function: - args: List[str] - body: ast.expr - def __init__(self, args, body): - self.args = args - self.body = body - - @staticmethod - def from_func_def(fd: ast.FunctionDef, scope: Scope) -> "Function": - args = [a.arg for a in fd.args.args] - scope.push() - for a in args: - scope.set(a, ast.Name(a, ctx=ast.Load())) - ret = None - for node in fd.body: - ret = proc(node, scope) - scope.pop() - return Function(args, ret) - - def apply(self, args: List[ast.expr]) -> ast.expr: - ret = copy.deepcopy(self.body) - args = {k: v for k, v in zip(self.args, args)} - return VarSubstituter(args).visit(ret) - -def proc(node: ast.AST, scope: Scope): - if isinstance(node, ast.Module): - for node in node.body: - proc(node, scope) - elif isinstance(node, ast.For) or isinstance(node, ast.While): - raise NotImplementedError("loops not supported") - elif isinstance(node, ast.If): - node.test = proc(node.test, scope) - for node in node.body: # FIXME properly handle branching - proc(node, scope) - for node in node.orelse: - proc(node, scope) - elif isinstance(node, ast.Assign): - if len(node.targets) == 1: - target = node.targets[0] - if isinstance(target, ast.Name): - scope.set(target.id, proc(node.value, scope)) - elif isinstance(target, ast.Attribute): - raise NotImplementedError("assign.attr") - else: - raise NotImplementedError("unpacking not supported") - elif isinstance(node, ast.Name) and isinstance(node.ctx, ast.Load): - return scope.lookup(node.id) - elif isinstance(node, ast.FunctionDef): - scope.set(node.name, Function.from_func_def(node, scope)) - elif isinstance(node, ast.ClassDef): - pass - elif isinstance(node, ast.UnaryOp): - return ast.UnaryOp(node.op, proc(node.operand, scope)) - elif isinstance(node, ast.BinOp): - return ast.BinOp(proc(node.left, scope), node.op, proc(node.right, scope)) - elif isinstance(node, ast.BoolOp): - return ast.BoolOp(node.op, [proc(val, scope) for val in node.values]) - elif isinstance(node, ast.Compare): - if len(node.ops) > 1 or len(node.comparators) > 1: - raise NotImplementedError("too many comparisons") - return ast.Compare(proc(node.left, scope), node.ops, [proc(node.comparators[0], scope)]) - elif isinstance(node, ast.Call): - fun = proc(node.func, scope) - if not isinstance(fun, Function): - raise Exception("???") - return fun.apply([proc(a, scope) for a in node.args]) - elif isinstance(node, ast.List): - return ast.List([proc(e, scope) for e in node.elts]) - elif isinstance(node, ast.Tuple): - return ast.Tuple([proc(e, scope) for e in node.elts]) - elif isinstance(node, ast.Return): - return proc(node.value, scope) - elif isinstance(node, ast.Constant): - return node # XXX simplification? - -def parse(fn: str): - with open(fn) as f: - cont = f.read() - root = ast.parse(cont, fn) - scope = Scope() - proc(root, scope) - scope.dump() - -if __name__ == "__main__": - import sys - if len(sys.argv) != 2: - print("usage: parse.py <file.py>") - sys.exit(1) - parse(sys.argv[1]) diff --git a/test.py b/test.py deleted file mode 100644 index 873f5c159b68c4589690be596e7eb045db41b064..0000000000000000000000000000000000000000 --- a/test.py +++ /dev/null @@ -1,7 +0,0 @@ -def f(x, b): - c = x < 3 - d = x + 2 * b - return d > 10 or c -x = 10 -y = 20 + x -z = f(y, x) \ No newline at end of file