diff --git a/demo/demo1.py b/demo/demo1.py index a686374505c36ab2a53edc6c729da4166d24982d..2a776b9520a6638a8d33d5ece4a176e465738851 100644 --- a/demo/demo1.py +++ b/demo/demo1.py @@ -8,17 +8,20 @@ import matplotlib.pyplot as plt import numpy as np from enum import Enum, auto + class VehicleMode(Enum): Normal = auto() SwitchLeft = auto() SwitchRight = auto() Brake = auto() + class LaneMode(Enum): Lane0 = auto() Lane1 = auto() Lane2 = auto() + class State: x = 0.0 y = 0.0 @@ -30,6 +33,7 @@ class State: def __init__(self, x, y, theta, v, vehicle_mode: VehicleMode, lane_mode: LaneMode): self.data = [] + if __name__ == "__main__": input_code_name = 'example_controller1.py' scenario = Scenario() @@ -38,31 +42,38 @@ if __name__ == "__main__": scenario.add_agent(car) car = CarAgent('car2', file_name=input_code_name) scenario.add_agent(car) - tmp_map = SimpleMap3() + tmp_map = SimpleMap2() scenario.set_map(tmp_map) scenario.set_sensor(FakeSensor2()) scenario.set_init( [ - [[10.0, 0, 0, 0.5],[10.0, 0, 0, 0.5]], - [[5.0, -0.2, 0, 2.0],[6.0, 0.2, 0, 3.0]], + [[10.0, 0, 0, 0.5], [10.0, 0, 0, 0.5]], + [[5.0, -0.2, 0, 2.0], [6.0, 0.2, 0, 3.0]], ], [ (VehicleMode.Normal, LaneMode.Lane1), (VehicleMode.Normal, LaneMode.Lane1), ] ) - res_list = scenario.simulate_multi(10,1) + res_list = scenario.simulate_multi(10, 1) # traces = scenario.verify(10) - fig = plt.figure(2) + # fig = plt.figure(2) # fig = plot_map(tmp_map, 'g', fig) - # fig = plot_reachtube_tree(traces, 'car1', 0, [1], 'b', fig, (1000,-1000), (1000,-1000)) - # fig = plot_reachtube_tree(traces, 'car2', 0, [1], 'r', fig) - for traces in res_list: - # fig = plot_simulation_tree(traces, 'car1', 0, [1], 'b', fig, (1000,-1000), (1000,-1000)) - # fig = plot_simulation_tree(traces, 'car2', 0, [1], 'r', fig) - generate_simulation_anime(traces, tmp_map, fig) - + # plt.show() + # # fig = plot_reachtube_tree(traces, 'car1', 0, [1], 'b', fig, (1000,-1000), (1000,-1000)) + # # fig = plot_reachtube_tree(traces, 'car2', 0, [1], 'r', fig) + # # AnalysisTreeNode + # for traces in res_list: + # # fig = plot_simulation_tree( + # # traces, 'car1', 0, [1], 'b', fig, (1000, -1000), (1000, -1000)) + # fig = plot_simulation_tree(traces, 'car2', 1, [2], 'r', fig) + # # generate_simulation_anime(traces, tmp_map, fig) + # plt.show() - plt.show() + fig = go.Figure() + for traces in res_list: + # plotly_map(tmp_map, 'g', fig) + fig = plotly_simulation_anime(traces, tmp_map, fig) + fig.show() diff --git a/demo/demo2.py b/demo/demo2.py index 43df7dfd4b38ca6cb3b06acb3e93c63df43aa841..9031e1a3c69c66d2f1058e3fcf0701e903beccaf 100644 --- a/demo/demo2.py +++ b/demo/demo2.py @@ -3,22 +3,25 @@ from dryvr_plus_plus.scene_verifier.scenario.scenario import Scenario from dryvr_plus_plus.example.example_map.simple_map2 import SimpleMap2, SimpleMap3, SimpleMap5, SimpleMap6 from dryvr_plus_plus.plotter.plotter2D import * from dryvr_plus_plus.example.example_sensor.fake_sensor import FakeSensor2 - +import plotly.graph_objects as go import matplotlib.pyplot as plt import numpy as np from enum import Enum, auto + class VehicleMode(Enum): Normal = auto() SwitchLeft = auto() SwitchRight = auto() Brake = auto() + class LaneMode(Enum): Lane0 = auto() Lane1 = auto() Lane2 = auto() + class State: x = 0.0 y = 0.0 @@ -30,6 +33,7 @@ class State: def __init__(self, x, y, theta, v, vehicle_mode: VehicleMode, lane_mode: LaneMode): self.data = [] + if __name__ == "__main__": input_code_name = 'example_controller2.py' scenario = Scenario() @@ -43,8 +47,8 @@ if __name__ == "__main__": scenario.set_sensor(FakeSensor2()) scenario.set_init( [ - [[0, -0.2, 0, 1.0],[0.1, 0.2, 0, 1.0]], - [[10, 0, 0, 0.5],[10, 0, 0, 0.5]], + [[0, -0.2, 0, 1.0], [0.1, 0.2, 0, 1.0]], + [[10, 0, 0, 0.5], [10, 0, 0, 0.5]], ], [ (VehicleMode.Normal, LaneMode.Lane1), @@ -52,15 +56,22 @@ if __name__ == "__main__": ] ) # res_list = scenario.simulate_multi(40,1) - traces = scenario.verify(40) + # traces = scenario.verify(40) + traces = scenario.simulate(40) - fig = plt.figure(2) - fig = plot_map(tmp_map, 'g', fig) - fig = plot_reachtube_tree(traces, 'car1', 1, [2], 'b', fig) - fig = plot_reachtube_tree(traces, 'car2', 1, [2], 'r', fig) - # for traces in res_list: - # generate_simulation_anime(traces, tmp_map, fig) - # # fig,x_lim,y_lim = plot_simulation_tree(traces, 'car1', 1, [2], 'b', fig,x_lim,y_lim) - # # fig,x_lim,y_lim = plot_simulation_tree(traces, 'car2', 1, [2], 'r', fig,x_lim,y_lim) + # fig = plt.figure(2) + # fig = plot_map(tmp_map, 'g', fig) + # # fig = plot_simulation_tree(traces, 'car1', 1, [2], 'b', fig) + # # fig = plot_simulation_tree(traces, 'car2', 1, [2], 'r', fig) + # # # fig = plot_reachtube_tree(traces, 'car1', 1, [2], 'b', fig) + # # # fig = plot_reachtube_tree(traces, 'car2', 1, [2], 'r', fig) + # # plt.show() + # # # fig1 = plt.figure(2) + # fig = generate_simulation_anime(traces, tmp_map, fig) + # plt.show() - plt.show() + fig = go.Figure() + # fig = plotly_simulation_tree(traces, 'car1', 1, [2], 'b', fig) + # fig.show() + fig = plotly_simulation_anime(traces, tmp_map, fig) + fig.show() diff --git a/dryvr_plus_plus/__init__.py b/demo/dryvr_plus_plus/__init__.py similarity index 100% rename from dryvr_plus_plus/__init__.py rename to demo/dryvr_plus_plus/__init__.py diff --git a/dryvr_plus_plus/example/__init__.py b/demo/dryvr_plus_plus/example/__init__.py similarity index 100% rename from dryvr_plus_plus/example/__init__.py rename to demo/dryvr_plus_plus/example/__init__.py diff --git a/dryvr_plus_plus/example/example_agent/__init__.py b/demo/dryvr_plus_plus/example/example_agent/__init__.py similarity index 100% rename from dryvr_plus_plus/example/example_agent/__init__.py rename to demo/dryvr_plus_plus/example/example_agent/__init__.py diff --git a/dryvr_plus_plus/example/example_agent/car_agent.py b/demo/dryvr_plus_plus/example/example_agent/car_agent.py similarity index 73% rename from dryvr_plus_plus/example/example_agent/car_agent.py rename to demo/dryvr_plus_plus/example/example_agent/car_agent.py index 9d1212f514f13bff40bc9f6a7263602ab4f249b3..cfad53fbf357e106ecd6bf1ab21233d98f8ba03e 100644 --- a/dryvr_plus_plus/example/example_agent/car_agent.py +++ b/demo/dryvr_plus_plus/example/example_agent/car_agent.py @@ -1,12 +1,13 @@ from typing import Tuple, List -import numpy as np +import numpy as np from scipy.integrate import ode from dryvr_plus_plus.scene_verifier.agents.base_agent import BaseAgent from dryvr_plus_plus.scene_verifier.map.lane_map import LaneMap from dryvr_plus_plus.scene_verifier.code_parser.pythonparser import EmptyAst + class NPCAgent(BaseAgent): def __init__(self, id): self.id = id @@ -15,73 +16,77 @@ class NPCAgent(BaseAgent): @staticmethod def dynamic(t, state, u): x, y, theta, v = state - delta, a = u + delta, a = u x_dot = v*np.cos(theta+delta) y_dot = v*np.sin(theta+delta) theta_dot = v/1.75*np.sin(delta) - v_dot = a + v_dot = a return [x_dot, y_dot, theta_dot, v_dot] - def action_handler(self, mode, state, lane_map:LaneMap)->Tuple[float, float]: - x,y,theta,v = state + def action_handler(self, mode, state, lane_map: LaneMap) -> Tuple[float, float]: + x, y, theta, v = state vehicle_mode = mode[0] vehicle_lane = mode[1] - vehicle_pos = np.array([x,y]) + vehicle_pos = np.array([x, y]) d = -lane_map.get_lateral_distance(vehicle_lane, vehicle_pos) psi = lane_map.get_lane_heading(vehicle_lane, vehicle_pos)-theta steering = psi + np.arctan2(0.45*d, v) steering = np.clip(steering, -0.61, 0.61) a = 0 - return steering, a + return steering, a - def TC_simulate(self, mode: List[str], initialCondition, time_bound, lane_map:LaneMap=None)->np.ndarray: + def TC_simulate(self, mode: List[str], initialCondition, time_bound, lane_map: LaneMap = None) -> np.ndarray: time_step = 0.05 time_bound = float(time_bound) number_points = int(np.ceil(time_bound/time_step)) - t = [i*time_step for i in range(0,number_points)] + t = [i*time_step for i in range(0, number_points)] init = initialCondition trace = [[0]+init] for i in range(len(t)): steering, a = self.action_handler(mode, init, lane_map) - r = ode(self.dynamic) - r.set_initial_value(init).set_f_params([steering, a]) - res:np.ndarray = r.integrate(r.t + time_step) + r = ode(self.dynamic) + r.set_initial_value(init).set_f_params([steering, a]) + res: np.ndarray = r.integrate(r.t + time_step) init = res.flatten().tolist() - trace.append([t[i] + time_step] + init) + trace.append([t[i] + time_step] + init) return np.array(trace) + class CarAgent(BaseAgent): - def __init__(self, id, code = None, file_name = None): + def __init__(self, id, code=None, file_name=None): super().__init__(id, code, file_name) @staticmethod def dynamic(t, state, u): x, y, theta, v = state - delta, a = u + delta, a = u x_dot = v*np.cos(theta+delta) y_dot = v*np.sin(theta+delta) theta_dot = v/1.75*np.sin(delta) - v_dot = a + v_dot = a return [x_dot, y_dot, theta_dot, v_dot] - def action_handler(self, mode: List[str], state, lane_map:LaneMap)->Tuple[float, float]: - x,y,theta,v = state + def action_handler(self, mode: List[str], state, lane_map: LaneMap) -> Tuple[float, float]: + x, y, theta, v = state vehicle_mode = mode[0] vehicle_lane = mode[1] - vehicle_pos = np.array([x,y]) + vehicle_pos = np.array([x, y]) a = 0 if vehicle_mode == "Normal": d = -lane_map.get_lateral_distance(vehicle_lane, vehicle_pos) + # # keyi: just toy mod + # if v <= 2: + # a = 0.2 elif vehicle_mode == "SwitchLeft": d = -lane_map.get_lateral_distance(vehicle_lane, vehicle_pos) + 3 elif vehicle_mode == "SwitchRight": d = -lane_map.get_lateral_distance(vehicle_lane, vehicle_pos) - 3 elif vehicle_mode == "Brake": d = -lane_map.get_lateral_distance(vehicle_lane, vehicle_pos) - a = -1 - if v<=0.02: + a = -1 + if v <= 0.02: a = 0 elif vehicle_mode == 'Stop': d = -lane_map.get_lateral_distance(vehicle_lane, vehicle_pos) @@ -89,24 +94,25 @@ class CarAgent(BaseAgent): psi = lane_map.get_lane_heading(vehicle_lane, vehicle_pos)-theta steering = psi + np.arctan2(0.45*d, v) steering = np.clip(steering, -0.61, 0.61) - return steering, a + return steering, a - def TC_simulate(self, mode: List[str], initialCondition, time_bound, lane_map:LaneMap=None)->np.ndarray: + def TC_simulate(self, mode: List[str], initialCondition, time_bound, lane_map: LaneMap = None) -> np.ndarray: time_step = 0.05 time_bound = float(time_bound) number_points = int(np.ceil(time_bound/time_step)) - t = [i*time_step for i in range(0,number_points)] + t = [i*time_step for i in range(0, number_points)] init = initialCondition + # [time, x, y, theta, v] trace = [[0]+init] for i in range(len(t)): steering, a = self.action_handler(mode, init, lane_map) - r = ode(self.dynamic) - r.set_initial_value(init).set_f_params([steering, a]) - res:np.ndarray = r.integrate(r.t + time_step) + r = ode(self.dynamic) + r.set_initial_value(init).set_f_params([steering, a]) + res: np.ndarray = r.integrate(r.t + time_step) init = res.flatten().tolist() if init[3] < 0: init[3] = 0 - trace.append([t[i] + time_step] + init) + trace.append([t[i] + time_step] + init) return np.array(trace) diff --git a/dryvr_plus_plus/example/example_agent/sign_agent.py b/demo/dryvr_plus_plus/example/example_agent/sign_agent.py similarity index 100% rename from dryvr_plus_plus/example/example_agent/sign_agent.py rename to demo/dryvr_plus_plus/example/example_agent/sign_agent.py diff --git a/dryvr_plus_plus/example/example_map/__init__.py b/demo/dryvr_plus_plus/example/example_map/__init__.py similarity index 100% rename from dryvr_plus_plus/example/example_map/__init__.py rename to demo/dryvr_plus_plus/example/example_map/__init__.py diff --git a/dryvr_plus_plus/example/example_map/simple_map.py b/demo/dryvr_plus_plus/example/example_map/simple_map.py similarity index 100% rename from dryvr_plus_plus/example/example_map/simple_map.py rename to demo/dryvr_plus_plus/example/example_map/simple_map.py diff --git a/dryvr_plus_plus/example/example_map/simple_map2.py b/demo/dryvr_plus_plus/example/example_map/simple_map2.py similarity index 100% rename from dryvr_plus_plus/example/example_map/simple_map2.py rename to demo/dryvr_plus_plus/example/example_map/simple_map2.py diff --git a/dryvr_plus_plus/example/example_sensor/__init__.py b/demo/dryvr_plus_plus/example/example_sensor/__init__.py similarity index 100% rename from dryvr_plus_plus/example/example_sensor/__init__.py rename to demo/dryvr_plus_plus/example/example_sensor/__init__.py diff --git a/dryvr_plus_plus/example/example_sensor/fake_sensor.py b/demo/dryvr_plus_plus/example/example_sensor/fake_sensor.py similarity index 100% rename from dryvr_plus_plus/example/example_sensor/fake_sensor.py rename to demo/dryvr_plus_plus/example/example_sensor/fake_sensor.py diff --git a/dryvr_plus_plus/plotter/__init__.py b/demo/dryvr_plus_plus/plotter/__init__.py similarity index 100% rename from dryvr_plus_plus/plotter/__init__.py rename to demo/dryvr_plus_plus/plotter/__init__.py diff --git a/dryvr_plus_plus/plotter/parser.py b/demo/dryvr_plus_plus/plotter/parser.py similarity index 100% rename from dryvr_plus_plus/plotter/parser.py rename to demo/dryvr_plus_plus/plotter/parser.py diff --git a/demo/dryvr_plus_plus/plotter/plotter2D.py b/demo/dryvr_plus_plus/plotter/plotter2D.py new file mode 100644 index 0000000000000000000000000000000000000000..7540effb80b41017e95543373dae2e592a4b8b5b --- /dev/null +++ b/demo/dryvr_plus_plus/plotter/plotter2D.py @@ -0,0 +1,560 @@ +""" +This file consist main plotter code for DryVR reachtube output +""" + +from __future__ import annotations +import matplotlib.patches as patches +import matplotlib.pyplot as plt +import numpy as np +from math import pi +import plotly.express as px +import plotly.graph_objects as go +from typing import List +from PIL import Image, ImageDraw +import io +from collections import OrderedDict +from dryvr_plus_plus.scene_verifier.analysis.analysis_tree_node import AnalysisTreeNode + +colors = ['red', 'green', 'blue', 'yellow', 'black'] + + +def plot( + data, + x_dim: int = 0, + y_dim_list: List[int] = [1], + color='b', + fig=None, + x_lim=None, + y_lim=None +): + if fig is None: + fig = plt.figure() + + ax = fig.gca() + if x_lim is None: + x_lim = ax.get_xlim() + if y_lim is None: + y_lim = ax.get_ylim() + + x_min, x_max = x_lim + y_min, y_max = y_lim + for rect in data: + lb = rect[0] + ub = rect[1] + for y_dim in y_dim_list: + rect_patch = patches.Rectangle( + (lb[x_dim], lb[y_dim]), ub[x_dim]-lb[x_dim], ub[y_dim]-lb[y_dim], color=color) + ax.add_patch(rect_patch) + x_min = min(lb[x_dim], x_min) + y_min = min(lb[y_dim], y_min) + x_max = max(ub[x_dim], x_max) + y_max = max(ub[y_dim], y_max) + + ax.set_xlim([x_min-1, x_max+1]) + ax.set_ylim([y_min-1, y_max+1]) + return fig, (x_min, x_max), (y_min, y_max) + + +def plot_reachtube_tree(root, agent_id, x_dim: int = 0, y_dim_list: List[int] = [1], color='b', fig=None, x_lim=None, y_lim=None): + if fig is None: + fig = plt.figure() + + ax = fig.gca() + if x_lim is None: + x_lim = ax.get_xlim() + if y_lim is None: + y_lim = ax.get_ylim() + + queue = [root] + while queue != []: + node = queue.pop(0) + traces = node.trace + trace = traces[agent_id] + data = [] + for i in range(0, len(trace), 2): + data.append([trace[i], trace[i+1]]) + fig, x_lim, y_lim = plot( + data, x_dim, y_dim_list, color, fig, x_lim, y_lim) + + queue += node.child + + return fig + + +def plotly_map(map, color='b', fig=None, x_lim=None, y_lim=None): + if fig is None: + fig = go.figure() + + for lane_idx in map.lane_dict: + lane = map.lane_dict[lane_idx] + for lane_seg in lane.segment_list: + if lane_seg.type == 'Straight': + start1 = lane_seg.start + lane_seg.width/2 * lane_seg.direction_lateral + end1 = lane_seg.end + lane_seg.width/2 * lane_seg.direction_lateral + # fig.add_trace(go.Scatter(x=[start1[0], end1[0]], y=[start1[1], end1[1]], + # mode='lines', + # line_color='black', + # showlegend=False, + # # text=theta, + # name='lines')) + start2 = lane_seg.start - lane_seg.width/2 * lane_seg.direction_lateral + end2 = lane_seg.end - lane_seg.width/2 * lane_seg.direction_lateral + # fig.add_trace(go.Scatter(x=[start2[0], end2[0]], y=[start2[1], end2[1]], + # mode='lines', + # line_color='black', + # showlegend=False, + # # text=theta, + # name='lines')) + fig.add_trace(go.Scatter(x=[start1[0], end1[0], end2[0], start2[0]], y=[start1[1], end1[1], end2[1], start2[1]], + mode='lines', + line_color='black', + # fill='toself', + # fillcolor='rgba(255,255,255,0)', + # line_color='rgba(0,0,0,0)', + showlegend=False, + # text=theta, + name='lines')) + elif lane_seg.type == "Circular": + phase_array = np.linspace( + start=lane_seg.start_phase, stop=lane_seg.end_phase, num=100) + r1 = lane_seg.radius - lane_seg.width/2 + x = np.cos(phase_array)*r1 + lane_seg.center[0] + y = np.sin(phase_array)*r1 + lane_seg.center[1] + fig.add_trace(go.Scatter(x=x, y=y, + mode='lines', + line_color='black', + showlegend=False, + # text=theta, + name='lines')) + + r2 = lane_seg.radius + lane_seg.width/2 + x = np.cos(phase_array)*r2 + lane_seg.center[0] + y = np.sin(phase_array)*r2 + lane_seg.center[1] + fig.add_trace(go.Scatter(x=x, y=y, + mode='lines', + line_color='black', + showlegend=False, + # text=theta, + name='lines')) + else: + raise ValueError(f'Unknown lane segment type {lane_seg.type}') + return fig + + +def plot_map(map, color='b', fig=None, x_lim=None, y_lim=None): + if fig is None: + fig = plt.figure() + + ax = fig.gca() + if x_lim is None: + x_lim = ax.get_xlim() + if y_lim is None: + y_lim = ax.get_ylim() + + for lane_idx in map.lane_dict: + lane = map.lane_dict[lane_idx] + for lane_seg in lane.segment_list: + if lane_seg.type == 'Straight': + start1 = lane_seg.start + lane_seg.width/2 * lane_seg.direction_lateral + end1 = lane_seg.end + lane_seg.width/2 * lane_seg.direction_lateral + + ax.plot([start1[0], end1[0]], [start1[1], end1[1]], color) + start2 = lane_seg.start - lane_seg.width/2 * lane_seg.direction_lateral + end2 = lane_seg.end - lane_seg.width/2 * lane_seg.direction_lateral + ax.plot([start2[0], end2[0]], [start2[1], end2[1]], color) + elif lane_seg.type == "Circular": + phase_array = np.linspace( + start=lane_seg.start_phase, stop=lane_seg.end_phase, num=100) + r1 = lane_seg.radius - lane_seg.width/2 + x = np.cos(phase_array)*r1 + lane_seg.center[0] + y = np.sin(phase_array)*r1 + lane_seg.center[1] + ax.plot(x, y, color) + + r2 = lane_seg.radius + lane_seg.width/2 + x = np.cos(phase_array)*r2 + lane_seg.center[0] + y = np.sin(phase_array)*r2 + lane_seg.center[1] + ax.plot(x, y, color) + else: + raise ValueError(f'Unknown lane segment type {lane_seg.type}') + return fig + + +def plotly_simulation_tree(root: AnalysisTreeNode, agent_id, x_dim: int = 0, y_dim_list: List[int] = [1], color='b', fig=None, x_lim=None, y_lim=None): + if fig is None: + fig = go.Figure() + i = 0 + fg_color = ['rgb(31,119,180)', 'rgb(255,127,14)', 'rgb(44,160,44)', 'rgb(214,39,40)', 'rgb(148,103,189)', + 'rgb(140,86,75)', 'rgb(227,119,194)', 'rgb(127,127,127)', 'rgb(188,189,34)', 'rgb(23,190,207)'] + bg_color = ['rgba(31,119,180,0.2)', 'rgba(255,127,14,0.2)', 'rgba(44,160,44,0.2)', 'rgba(214,39,40,0.2)', 'rgba(148,103,189,0.2)', + 'rgba(140,86,75,0.2)', 'rgba(227,119,194,0.2)', 'rgba(127,127,127,0.2)', 'rgba(188,189,34,0.2)', 'rgba(23,190,207,0.2)'] + queue = [root] + while queue != []: + node = queue.pop(0) + traces = node.trace + print(node.mode) + # [[time,x,y,theta,v]...] + trace = np.array(traces[agent_id]) + # print(trace) + for y_dim in y_dim_list: + trace_y = trace[:, y_dim].tolist() + trace_x = trace[:, x_dim].tolist() + theta = [i/pi*180 for i in trace[:, x_dim+2]] + trace_x_rev = trace_x[::-1] + # print(trace_x) + trace_upper = [i+1 for i in trace_y] + trace_lower = [i-1 for i in trace_y] + trace_lower = trace_lower[::-1] + # print(trace_upper) + # print(trace[:, y_dim]) + fig.add_trace(go.Scatter(x=trace_x+trace_x_rev, y=trace_upper+trace_lower, + fill='toself', + fillcolor=bg_color[i % 10], + line_color='rgba(255,255,255,0)', + showlegend=False)) + fig.add_trace(go.Scatter(x=trace[:, x_dim], y=trace[:, y_dim], + mode='lines', + line_color=fg_color[i % 10], + text=theta, + name='lines')) + i += 1 + queue += node.child + fig.update_traces(mode='lines') + return fig + + +def plot_simulation_tree(root: AnalysisTreeNode, agent_id, x_dim: int = 0, y_dim_list: List[int] = [1], color='b', fig=None, x_lim=None, y_lim=None): + if fig is None: + fig = plt.figure() + + ax = fig.gca() + if x_lim is None: + x_lim = ax.get_xlim() + if y_lim is None: + y_lim = ax.get_ylim() + + x_min, x_max = x_lim + y_min, y_max = y_lim + + queue = [root] + while queue != []: + node = queue.pop(0) + traces = node.trace + print(node.mode) + # [[time,x,y,theta,v]...] + trace = np.array(traces[agent_id]) + # print(trace) + for y_dim in y_dim_list: + ax.plot(trace[:, x_dim], trace[:, y_dim], color) + x_min = min(x_min, trace[:, x_dim].min()) + x_max = max(x_max, trace[:, x_dim].max()) + + y_min = min(y_min, trace[:, y_dim].min()) + y_max = max(y_max, trace[:, y_dim].max()) + queue += node.child + ax.set_xlim([x_min-1, x_max+1]) + ax.set_ylim([y_min-1, y_max+1]) + # plt.show() + # generate_simulation_anime(root, None, fig) + return fig + + +def generate_simulation_anime(root, map=None, fig=None): + if fig is None: + fig = plt.figure() + # fig = plot_map(map, 'g', fig) + timed_point_dict = {} + stack = [root] + ax = fig.gca() + x_min, x_max = float('inf'), -float('inf') + y_min, y_max = ax.get_ylim() + while stack != []: + node = stack.pop() + traces = node.trace + for agent_id in traces: + trace = traces[agent_id] + color = 'b' + if agent_id == 'car2': + color = 'r' + for i in range(len(trace)): + x_min = min(x_min, trace[i][1]) + x_max = max(x_max, trace[i][1]) + y_min = min(y_min, trace[i][2]) + y_max = max(y_max, trace[i][2]) + if round(trace[i][0], 5) not in timed_point_dict: + timed_point_dict[round(trace[i][0], 5)] = [ + (trace[i][1:], color)] + else: + timed_point_dict[round(trace[i][0], 5)].append( + (trace[i][1:], color)) + stack += node.child + + frames = [] + for time_point in timed_point_dict: + point_list = timed_point_dict[time_point] + plt.xlim((x_min-2, x_max+2)) + plt.ylim((y_min-2, y_max+2)) + # plot_map(map, color='g', fig=fig) + for data in point_list: + point = data + color = data[1] + ax = plt.gca() + ax.plot([point[0]], [point[1]], markerfacecolor=color, + markeredgecolor=color, marker='.', markersize=20) + x_tail = point[0] + y_tail = point[1] + dx = np.cos(point[2])*point[3] + dy = np.sin(point[2])*point[3] + ax.arrow(x_tail, y_tail, dx, dy, head_width=1, head_length=0.5) + plt.pause(0.005) + plt.clf() + return fig + # img_buf = io.BytesIO() + # plt.savefig(img_buf, format = 'png') + # im = Image.open(img_buf) + # frames.append(im) + # plt.clf() + # frame_one = frames[0] + # frame_one.save(fn, format = "GIF", append_images = frames, save_all = True, duration = 100, loop = 0) + + +def plotly_simulation_anime(root, map=None, fig=None): + # make figure + fig_dict = { + "data": [], + "layout": {}, + "frames": [] + } + # fig = plot_map(map, 'g', fig) + timed_point_dict = {} + stack = [root] + x_min, x_max = float('inf'), -float('inf') + y_min, y_max = float('inf'), -float('inf') + while stack != []: + node = stack.pop() + traces = node.trace + for agent_id in traces: + trace = np.array(traces[agent_id]) + for i in range(len(trace)): + x_min = min(x_min, trace[i][1]) + x_max = max(x_max, trace[i][1]) + y_min = min(y_min, trace[i][2]) + y_max = max(y_max, trace[i][2]) + if round(trace[i][0], 2) not in timed_point_dict: + timed_point_dict[round(trace[i][0], 2)] = [ + trace[i][1:].tolist()] + else: + init = False + for record in timed_point_dict[round(trace[i][0], 2)]: + if record == trace[i][1:].tolist(): + init = True + break + if init == False: + timed_point_dict[round(trace[i][0], 2)].append( + trace[i][1:].tolist()) + time = round(trace[i][0], 2) + stack += node.child + # fill in most of layout + # print(time) + duration = int(600/time) + fig_dict["layout"]["xaxis"] = { + "range": [(x_min-10), (x_max+10)], + "title": "x position"} + fig_dict["layout"]["yaxis"] = { + "range": [(y_min-2), (y_max+2)], + "title": "y position"} + fig_dict["layout"]["hovermode"] = "closest" + fig_dict["layout"]["updatemenus"] = [ + { + "buttons": [ + { + "args": [None, {"frame": {"duration": duration, "redraw": False}, + "fromcurrent": True, "transition": {"duration": duration, + "easing": "quadratic-in-out"}}], + "label": "Play", + "method": "animate" + }, + { + "args": [[None], {"frame": {"duration": 0, "redraw": False}, + "mode": "immediate", + "transition": {"duration": 0}}], + "label": "Pause", + "method": "animate" + } + ], + "direction": "left", + "pad": {"r": 10, "t": 87}, + "showactive": False, + "type": "buttons", + "x": 0.1, + "xanchor": "right", + "y": 0, + "yanchor": "top" + } + ] + sliders_dict = { + "active": 0, + "yanchor": "top", + "xanchor": "left", + "currentvalue": { + "font": {"size": 20}, + "prefix": "time:", + "visible": True, + "xanchor": "right" + }, + "transition": {"duration": duration, "easing": "cubic-in-out"}, + "pad": {"b": 10, "t": 50}, + "len": 0.9, + "x": 0.1, + "y": 0, + "steps": [] + } + # make data + point_list = timed_point_dict[0] + # print(point_list) + data_dict = { + "x": [data[0] for data in point_list], + "y": [data[1] for data in point_list], + "mode": "markers + text", + "text": [(round(data[3], 2), round(data[2]/pi*180, 2)) for data in point_list], + "textposition": "bottom center", + # "marker": { + # "sizemode": "area", + # "sizeref": 200000, + # "size": 2 + # }, + "name": "Current Position" + } + fig_dict["data"].append(data_dict) + + # make frames + for time_point in timed_point_dict: + frame = {"data": [], "layout": { + "annotations": []}, "name": str(time_point)} + # print(timed_point_dict[time_point]) + point_list = timed_point_dict[time_point] + # point_list = list(OrderedDict.fromkeys(timed_point_dict[time_point])) + trace_x = [data[0] for data in point_list] + trace_y = [data[1] for data in point_list] + trace_theta = [data[2] for data in point_list] + trace_v = [data[3] for data in point_list] + data_dict = { + "x": trace_x, + "y": trace_y, + "mode": "markers + text", + "text": [(round(trace_theta[i]/pi*180, 2), round(trace_v[i], 2)) for i in range(len(trace_theta))], + "textposition": "bottom center", + # "marker": { + # "sizemode": "area", + # "sizeref": 200000, + # "size": 2 + # }, + "name": "current position" + } + frame["data"].append(data_dict) + for i in range(len(trace_x)): + ax = np.cos(trace_theta[i])*trace_v[i] + ay = np.sin(trace_theta[i])*trace_v[i] + # print(trace_x[i]+ax, trace_y[i]+ay) + annotations_dict = {"x": trace_x[i]+ax+0.1, "y": trace_y[i]+ay, + # "xshift": ax, "yshift": ay, + "ax": trace_x[i], "ay": trace_y[i], + "arrowwidth": 2, + # "arrowside": 'end', + "showarrow": True, + # "arrowsize": 1, + "xref": 'x', "yref": 'y', + "axref": 'x', "ayref": 'y', + # "text": "erver", + "arrowhead": 2, + "arrowcolor": "black"} + frame["layout"]["annotations"].append(annotations_dict) + + fig_dict["frames"].append(frame) + slider_step = {"args": [ + [time_point], + {"frame": {"duration": duration, "redraw": False}, + "mode": "immediate", + "transition": {"duration": duration}} + ], + "label": time_point, + "method": "animate"} + sliders_dict["steps"].append(slider_step) + # print(len(frame["layout"]["annotations"])) + + fig_dict["layout"]["sliders"] = [sliders_dict] + + fig = go.Figure(fig_dict) + fig = plotly_map(map, 'g', fig) + i = 0 + queue = [root] + while queue != []: + node = queue.pop(0) + traces = node.trace + print(node.mode) + # [[time,x,y,theta,v]...] + for agent_id in traces: + trace = np.array(traces[agent_id]) + # print(trace) + trace_y = trace[:, 2].tolist() + trace_x = trace[:, 1].tolist() + # theta = [i/pi*180 for i in trace[:, 3]] + color = 'green' + if agent_id == 'car2': + color = 'red' + fig.add_trace(go.Scatter(x=trace[:, 1], y=trace[:, 2], + mode='lines', + line_color=color, + text=[(round(trace[i, 3]/pi*180, 2), round(trace[i, 4], 2)) + for i in range(len(trace_y))], + showlegend=False) + # name='lines') + ) + i += 1 + queue += node.child + # fig.update_traces(mode='lines') + + return fig + # fig.show() + + +# The 'color' property is a color and may be specified as: +# - A hex string (e.g. '#ff0000') +# - An rgb/rgba string (e.g. 'rgb(255,0,0)') +# - An hsl/hsla string (e.g. 'hsl(0,100%,50%)') +# - An hsv/hsva string (e.g. 'hsv(0,100%,100%)') +# - A named CSS color: +# aliceblue, antiquewhite, aqua, aquamarine, azure, +# beige, bisque, black, blanchedalmond, blue, +# blueviolet, brown, burlywood, cadetblue, +# chartreuse, chocolate, coral, cornflowerblue, +# cornsilk, crimson, cyan, darkblue, darkcyan, +# darkgoldenrod, darkgray, darkgrey, darkgreen, +# darkkhaki, darkmagenta, darkolivegreen, darkorange, +# darkorchid, darkred, darksalmon, darkseagreen, +# darkslateblue, darkslategray, darkslategrey, +# darkturquoise, darkviolet, deeppink, deepskyblue, +# dimgray, dimgrey, dodgerblue, firebrick, +# floralwhite, forestgreen, fuchsia, gainsboro, +# ghostwhite, gold, goldenrod, gray, grey, green, +# greenyellow, honeydew, hotpink, indianred, indigo, +# ivory, khaki, lavender, lavenderblush, lawngreen, +# lemonchiffon, lightblue, lightcoral, lightcyan, +# lightgoldenrodyellow, lightgray, lightgrey, +# lightgreen, lightpink, lightsalmon, lightseagreen, +# lightskyblue, lightslategray, lightslategrey, +# lightsteelblue, lightyellow, lime, limegreen, +# linen, magenta, maroon, mediumaquamarine, +# mediumblue, mediumorchid, mediumpurple, +# mediumseagreen, mediumslateblue, mediumspringgreen, +# mediumturquoise, mediumvioletred, midnightblue, +# mintcream, mistyrose, moccasin, navajowhite, navy, +# oldlace, olive, olivedrab, orange, orangered, +# orchid, palegoldenrod, palegreen, paleturquoise, +# palevioletred, papayawhip, peachpuff, peru, pink, +# plum, powderblue, purple, red, rosybrown, +# royalblue, rebeccapurple, saddlebrown, salmon, +# sandybrown, seagreen, seashell, sienna, silver, +# skyblue, slateblue, slategray, slategrey, snow, +# springgreen, steelblue, tan, teal, thistle, tomato, +# turquoise, violet, wheat, white, whitesmoke, +# yellow, yellowgreen diff --git a/dryvr_plus_plus/plotter/plotter3D.py b/demo/dryvr_plus_plus/plotter/plotter3D.py similarity index 100% rename from dryvr_plus_plus/plotter/plotter3D.py rename to demo/dryvr_plus_plus/plotter/plotter3D.py diff --git a/dryvr_plus_plus/scene_verifier/__init__.py b/demo/dryvr_plus_plus/scene_verifier/__init__.py similarity index 100% rename from dryvr_plus_plus/scene_verifier/__init__.py rename to demo/dryvr_plus_plus/scene_verifier/__init__.py diff --git a/dryvr_plus_plus/scene_verifier/agents/__init__.py b/demo/dryvr_plus_plus/scene_verifier/agents/__init__.py similarity index 100% rename from dryvr_plus_plus/scene_verifier/agents/__init__.py rename to demo/dryvr_plus_plus/scene_verifier/agents/__init__.py diff --git a/dryvr_plus_plus/scene_verifier/agents/base_agent.py b/demo/dryvr_plus_plus/scene_verifier/agents/base_agent.py similarity index 100% rename from dryvr_plus_plus/scene_verifier/agents/base_agent.py rename to demo/dryvr_plus_plus/scene_verifier/agents/base_agent.py diff --git a/dryvr_plus_plus/scene_verifier/analysis/__init__.py b/demo/dryvr_plus_plus/scene_verifier/analysis/__init__.py similarity index 100% rename from dryvr_plus_plus/scene_verifier/analysis/__init__.py rename to demo/dryvr_plus_plus/scene_verifier/analysis/__init__.py diff --git a/dryvr_plus_plus/scene_verifier/analysis/analysis_tree_node.py b/demo/dryvr_plus_plus/scene_verifier/analysis/analysis_tree_node.py similarity index 65% rename from dryvr_plus_plus/scene_verifier/analysis/analysis_tree_node.py rename to demo/dryvr_plus_plus/scene_verifier/analysis/analysis_tree_node.py index f307e0313fc87a60dda4d109daa1ce1c8e7a9364..381de10d0e4095168d1112299cd889050e72423e 100644 --- a/dryvr_plus_plus/scene_verifier/analysis/analysis_tree_node.py +++ b/demo/dryvr_plus_plus/scene_verifier/analysis/analysis_tree_node.py @@ -1,13 +1,14 @@ from typing import List, Dict + class AnalysisTreeNode: """AnalysisTreeNode class A AnalysisTreeNode stores the continous execution of the system without transition happening""" trace: Dict """The trace for each agent. The key of the dict is the agent id and the value of the dict is simulated traces for each agent""" - init: Dict - + init: Dict + def __init__( self, trace={}, @@ -15,14 +16,14 @@ class AnalysisTreeNode: mode={}, agent={}, child=[], - start_time = 0, - ndigits = 10, - type = 'simtrace' + start_time=0, + ndigits=10, + type='simtrace' ): - self.trace:Dict = trace + self.trace: Dict = trace self.init: Dict[str, List[float]] = init self.mode: Dict[str, List[str]] = mode - self.agent:Dict = agent - self.child:List[AnalysisTreeNode] = child - self.start_time:float = round(start_time,ndigits) - self.type:str = type + self.agent: Dict = agent + self.child: List[AnalysisTreeNode] = child + self.start_time: float = round(start_time, ndigits) + self.type: str = type diff --git a/dryvr_plus_plus/scene_verifier/analysis/simulator.py b/demo/dryvr_plus_plus/scene_verifier/analysis/simulator.py similarity index 73% rename from dryvr_plus_plus/scene_verifier/analysis/simulator.py rename to demo/dryvr_plus_plus/scene_verifier/analysis/simulator.py index 7d064ff432cefe5ef58ba15659498b405e6df1fd..97514d88af347846c6c985ae2a340b505838c24d 100644 --- a/dryvr_plus_plus/scene_verifier/analysis/simulator.py +++ b/demo/dryvr_plus_plus/scene_verifier/analysis/simulator.py @@ -6,11 +6,12 @@ import numpy as np from dryvr_plus_plus.scene_verifier.agents.base_agent import BaseAgent from dryvr_plus_plus.scene_verifier.analysis.analysis_tree_node import AnalysisTreeNode + class Simulator: def __init__(self): self.simulation_tree_root = None - def simulate(self, init_list, init_mode_list, agent_list:List[BaseAgent], transition_graph, time_horizon, lane_map): + def simulate(self, init_list, init_mode_list, agent_list: List[BaseAgent], transition_graph, time_horizon, lane_map): # Setup the root of the simulation tree root = AnalysisTreeNode( trace={}, @@ -18,7 +19,7 @@ class Simulator: mode={}, agent={}, child=[], - start_time = 0, + start_time=0, ) for i, agent in enumerate(agent_list): root.init[agent.id] = init_list[i] @@ -31,7 +32,7 @@ class Simulator: simulation_queue.append(root) # Perform BFS through the simulation tree to loop through all possible transitions while simulation_queue != []: - node:AnalysisTreeNode = simulation_queue.pop(0) + node: AnalysisTreeNode = simulation_queue.pop(0) print(node.mode) remain_time = time_horizon - node.start_time if remain_time <= 0: @@ -40,14 +41,18 @@ class Simulator: for agent_id in node.agent: if agent_id not in node.trace: # Simulate the trace starting from initial condition + # [time, x, y, theta, v] mode = node.mode[agent_id] init = node.init[agent_id] - trace = node.agent[agent_id].TC_simulate(mode, init, remain_time,lane_map) - trace[:,0] += node.start_time + trace = node.agent[agent_id].TC_simulate( + mode, init, remain_time, lane_map) + trace[:, 0] += node.start_time node.trace[agent_id] = trace.tolist() trace_length = len(list(node.trace.values())[0]) + print(trace_length) transitions = [] + # time step for idx in range(trace_length): # For each trace, check with the guard to see if there's any possible transition # Store all possible transition in a list @@ -55,47 +60,54 @@ class Simulator: # Here we enforce that only one agent transit at a time all_agent_state = {} for agent_id in node.agent: - all_agent_state[agent_id] = (node.trace[agent_id][idx], node.mode[agent_id]) - possible_transitions = transition_graph.get_all_transition(all_agent_state) + all_agent_state[agent_id] = ( + node.trace[agent_id][idx], node.mode[agent_id]) + # keyi: whether it is possible to not call get_all_transition every time + possible_transitions = transition_graph.get_all_transition( + all_agent_state) if possible_transitions != []: + print(possible_transitions) for agent_idx, src_mode, dest_mode, next_init in possible_transitions: - transitions.append((agent_idx, src_mode, dest_mode, next_init, idx)) + transitions.append( + (agent_idx, src_mode, dest_mode, next_init, idx)) break # truncate the computed trajectories from idx and store the content after truncate truncated_trace = {} + print("idx", idx) for agent_idx in node.agent: truncated_trace[agent_idx] = node.trace[agent_idx][idx:] node.trace[agent_idx] = node.trace[agent_idx][:idx+1] - # For each possible transition, construct the new node. + # For each possible transition, construct the new node. # Obtain the new initial condition for agent having transition # copy the traces that are not under transition for transition in transitions: + # mode including veh and lane transit_agent_idx, src_mode, dest_mode, next_init, idx = transition if dest_mode is None: continue # next_node = AnalysisTreeNode(trace = {},init={},mode={},agent={}, child = [], start_time = 0) - next_node_mode = copy.deepcopy(node.mode) - next_node_mode[transit_agent_idx] = dest_mode - next_node_agent = node.agent + next_node_mode = copy.deepcopy(node.mode) + next_node_mode[transit_agent_idx] = dest_mode + next_node_agent = node.agent next_node_start_time = list(truncated_trace.values())[0][0][0] next_node_init = {} next_node_trace = {} for agent_idx in next_node_agent: if agent_idx == transit_agent_idx: - next_node_init[agent_idx] = next_init + next_node_init[agent_idx] = next_init else: next_node_trace[agent_idx] = truncated_trace[agent_idx] - + tmp = AnalysisTreeNode( - trace = next_node_trace, - init = next_node_init, - mode = next_node_mode, - agent = next_node_agent, - child = [], - start_time = next_node_start_time, - type = 'simtrace' + trace=next_node_trace, + init=next_node_init, + mode=next_node_mode, + agent=next_node_agent, + child=[], + start_time=next_node_start_time, + type='simtrace' ) node.child.append(tmp) simulation_queue.append(tmp) diff --git a/dryvr_plus_plus/scene_verifier/analysis/verifier.py b/demo/dryvr_plus_plus/scene_verifier/analysis/verifier.py similarity index 100% rename from dryvr_plus_plus/scene_verifier/analysis/verifier.py rename to demo/dryvr_plus_plus/scene_verifier/analysis/verifier.py diff --git a/dryvr_plus_plus/scene_verifier/automaton/__init__.py b/demo/dryvr_plus_plus/scene_verifier/automaton/__init__.py similarity index 100% rename from dryvr_plus_plus/scene_verifier/automaton/__init__.py rename to demo/dryvr_plus_plus/scene_verifier/automaton/__init__.py diff --git a/dryvr_plus_plus/scene_verifier/automaton/guard.py b/demo/dryvr_plus_plus/scene_verifier/automaton/guard.py similarity index 69% rename from dryvr_plus_plus/scene_verifier/automaton/guard.py rename to demo/dryvr_plus_plus/scene_verifier/automaton/guard.py index 80d2359e876970f4d1e258580a5cee90fbed7d4b..a96fd0554fef02d0b7bee38a3b58e08fc320206d 100644 --- a/dryvr_plus_plus/scene_verifier/automaton/guard.py +++ b/demo/dryvr_plus_plus/scene_verifier/automaton/guard.py @@ -15,20 +15,23 @@ import numpy as np from dryvr_plus_plus.scene_verifier.map.lane_map import LaneMap from dryvr_plus_plus.scene_verifier.map.lane_segment import AbstractLane from dryvr_plus_plus.scene_verifier.utils.utils import * + + class LogicTreeNode: - def __init__(self, data, child = [], val = None, mode_guard = None): - self.data = data + def __init__(self, data, child=[], val=None, mode_guard=None): + self.data = data self.child = child self.val = val self.mode_guard = mode_guard + class GuardExpressionAst: def __init__(self, guard_list): self.ast_list = [] for guard in guard_list: self.ast_list.append(copy.deepcopy(guard.ast)) self.cont_variables = {} - self.varDict = {'t':Real('t')} + self.varDict = {'t': Real('t')} def _build_guard(self, guard_str, agent): """ @@ -49,21 +52,25 @@ class GuardExpressionAst: # Thus we need to replace "==" to something else sympy_guard_str = guard_str.replace("==", ">=") for vars in self.cont_variables: - sympy_guard_str = sympy_guard_str.replace(vars, self.cont_variables[vars]) + sympy_guard_str = sympy_guard_str.replace( + vars, self.cont_variables[vars]) - symbols = list(sympy.sympify(sympy_guard_str, evaluate=False).free_symbols) + symbols = list(sympy.sympify( + sympy_guard_str, evaluate=False).free_symbols) symbols = [str(s) for s in symbols] tmp = list(self.cont_variables.values()) symbols_map = {} for s in symbols: if s in tmp: - key = list(self.cont_variables.keys())[list(self.cont_variables.values()).index(s)] + key = list(self.cont_variables.keys())[ + list(self.cont_variables.values()).index(s)] symbols_map[s] = key for vars in reversed(self.cont_variables): guard_str = guard_str.replace(vars, self.cont_variables[vars]) guard_str = self._handleReplace(guard_str) - cur_solver.add(eval(guard_str)) # TODO use an object instead of `eval` a string + # TODO use an object instead of `eval` a string + cur_solver.add(eval(guard_str)) return cur_solver, symbols_map def _handleReplace(self, input_str): @@ -74,7 +81,7 @@ class GuardExpressionAst: And(y<=0,t>=0.2,v>=-0.1) output: And(self.varDic["y"]<=0,self.varDic["t"]>=0.2,self.varDic["v"]>=-0.1) - + Args: input_str (str): original string need to be replaced keys (list): list of variable strings @@ -93,7 +100,8 @@ class GuardExpressionAst: for i in range(len(input_str)): if input_str[i:].startswith(key): idxes.append((i, i + len(key))) - input_str = input_str[:i] + "@" * len(key) + input_str[i + len(key):] + input_str = input_str[:i] + "@" * \ + len(key) + input_str[i + len(key):] idxes = sorted(idxes) @@ -109,32 +117,37 @@ class GuardExpressionAst: is_contained = False for cont_vars in continuous_variable_dict: - self.cont_variables[cont_vars] = cont_vars.replace('.','_') - self.varDict[cont_vars.replace('.','_')] = Real(cont_vars.replace('.','_')) + self.cont_variables[cont_vars] = cont_vars.replace('.', '_') + self.varDict[cont_vars.replace('.', '_')] = Real( + cont_vars.replace('.', '_')) - z3_string = self.generate_z3_expression() + z3_string = self.generate_z3_expression() if isinstance(z3_string, bool): if z3_string: - return True, True + return True, True else: return False, False cur_solver, symbols = self._build_guard(z3_string, agent) cur_solver.push() for symbol in symbols: - cur_solver.add(self.varDict[symbol] >= continuous_variable_dict[symbols[symbol]][0]) - cur_solver.add(self.varDict[symbol] <= continuous_variable_dict[symbols[symbol]][1]) + cur_solver.add( + self.varDict[symbol] >= continuous_variable_dict[symbols[symbol]][0]) + cur_solver.add( + self.varDict[symbol] <= continuous_variable_dict[symbols[symbol]][1]) if cur_solver.check() == sat: # The reachtube hits the guard cur_solver.pop() res = True - + # TODO: If the reachtube completely fall inside guard, break tmp_solver = Solver() tmp_solver.add(Not(cur_solver.assertions()[0])) for symbol in symbols: - tmp_solver.add(self.varDict[symbol] >= continuous_variable_dict[symbols[symbol]][0]) - tmp_solver.add(self.varDict[symbol] <= continuous_variable_dict[symbols[symbol]][1]) + tmp_solver.add( + self.varDict[symbol] >= continuous_variable_dict[symbols[symbol]][0]) + tmp_solver.add( + self.varDict[symbol] <= continuous_variable_dict[symbols[symbol]][1]) if tmp_solver.check() == unsat: print("Full intersect, break") is_contained = True @@ -177,7 +190,7 @@ class GuardExpressionAst: If without evaluating the continuous variables the result is True, then the guard condition will automatically be satisfied - + If without evaluating the continuous variables the result is False, then the guard condition will not be satisfied @@ -189,11 +202,11 @@ class GuardExpressionAst: # For each value in the boolop, check results if isinstance(node.op, ast.And): z3_str = [] - for i,val in enumerate(node.values): + for i, val in enumerate(node.values): tmp = self._generate_z3_expression_node(val) if isinstance(tmp, bool): if tmp: - continue + continue else: return False z3_str.append(tmp) @@ -213,7 +226,7 @@ class GuardExpressionAst: return z3_str # If string, construct string # If bool, check result and discard/evaluate result according to operator - pass + pass elif isinstance(node, ast.Constant): # If is bool, return boolean result if isinstance(node.value, bool): @@ -224,7 +237,7 @@ class GuardExpressionAst: expr = expr.strip('\n') return expr elif isinstance(node, ast.UnaryOp): - # If is UnaryOp, + # If is UnaryOp, value = self._generate_z3_expression_node(node.operand) if isinstance(node.op, ast.USub): return -value @@ -234,7 +247,7 @@ class GuardExpressionAst: expr = expr.strip('\n') return expr - def evaluate_guard_hybrid(self, agent, discrete_variable_dict, continuous_variable_dict, lane_map:LaneMap): + def evaluate_guard_hybrid(self, agent, discrete_variable_dict, continuous_variable_dict, lane_map: LaneMap): """ Handle guard atomics that contains both continuous and hybrid variables Especially, we want to handle function calls that need both continuous and @@ -244,41 +257,48 @@ class GuardExpressionAst: By doing this, all calls that need both continuous and discrete variables as input will now become only continuous variables. We can then handle these using what we already have for the continous variables """ - res = True + res = True for i, node in enumerate(self.ast_list): - tmp, self.ast_list[i] = self._evaluate_guard_hybrid(node, agent, discrete_variable_dict, continuous_variable_dict, lane_map) - res = res and tmp + tmp, self.ast_list[i] = self._evaluate_guard_hybrid( + node, agent, discrete_variable_dict, continuous_variable_dict, lane_map) + res = res and tmp return res - def _evaluate_guard_hybrid(self, root, agent, disc_var_dict, cont_var_dict, lane_map:LaneMap): - if isinstance(root, ast.Compare): + def _evaluate_guard_hybrid(self, root, agent, disc_var_dict, cont_var_dict, lane_map: LaneMap): + if isinstance(root, ast.Compare): expr = astunparse.unparse(root) - left, root.left = self._evaluate_guard_hybrid(root.left, agent, disc_var_dict, cont_var_dict, lane_map) - right, root.comparators[0] = self._evaluate_guard_hybrid(root.comparators[0], agent, disc_var_dict, cont_var_dict, lane_map) + left, root.left = self._evaluate_guard_hybrid( + root.left, agent, disc_var_dict, cont_var_dict, lane_map) + right, root.comparators[0] = self._evaluate_guard_hybrid( + root.comparators[0], agent, disc_var_dict, cont_var_dict, lane_map) return True, root elif isinstance(root, ast.BoolOp): if isinstance(root.op, ast.And): res = True for i, val in enumerate(root.values): - tmp, root.values[i] = self._evaluate_guard_hybrid(val, agent, disc_var_dict, cont_var_dict, lane_map) - res = res and tmp + tmp, root.values[i] = self._evaluate_guard_hybrid( + val, agent, disc_var_dict, cont_var_dict, lane_map) + res = res and tmp if not res: - break - return res, root + break + return res, root elif isinstance(root.op, ast.Or): for val in root.values: - tmp,val = self._evaluate_guard_hybrid(val, agent, disc_var_dict, cont_var_dict, lane_map) + tmp, val = self._evaluate_guard_hybrid( + val, agent, disc_var_dict, cont_var_dict, lane_map) res = res or tmp if res: break - return res, root + return res, root elif isinstance(root, ast.BinOp): - left, root.left = self._evaluate_guard_hybrid(root.left, agent, disc_var_dict, cont_var_dict, lane_map) - right, root.right = self._evaluate_guard_hybrid(root.right, agent, disc_var_dict, cont_var_dict, lane_map) + left, root.left = self._evaluate_guard_hybrid( + root.left, agent, disc_var_dict, cont_var_dict, lane_map) + right, root.right = self._evaluate_guard_hybrid( + root.right, agent, disc_var_dict, cont_var_dict, lane_map) return True, root elif isinstance(root, ast.Call): if isinstance(root.func, ast.Attribute): - func = root.func + func = root.func if func.value.id == 'lane_map': if func.attr == 'get_lateral_distance': # Get function arguments @@ -294,20 +314,25 @@ class GuardExpressionAst: if isinstance(elt, ast.Attribute): var = elt.value.id + '.' + elt.attr arg1_lower.append(cont_var_dict[var][0]) - arg1_upper.append(cont_var_dict[var][1]) + arg1_upper.append(cont_var_dict[var][1]) vehicle_pos = (arg1_lower, arg1_upper) # Get corresponding lane segments with respect to the set of vehicle pos - lane_seg1 = lane_map.get_lane_segment(vehicle_lane, arg1_lower) - lane_seg2 = lane_map.get_lane_segment(vehicle_lane, arg1_upper) + lane_seg1 = lane_map.get_lane_segment( + vehicle_lane, arg1_lower) + lane_seg2 = lane_map.get_lane_segment( + vehicle_lane, arg1_upper) # Compute the set of possible lateral values with respect to all possible segments - lateral_set1 = self._handle_lateral_set(lane_seg1, np.array(vehicle_pos)) - lateral_set2 = self._handle_lateral_set(lane_seg2, np.array(vehicle_pos)) + lateral_set1 = self._handle_lateral_set( + lane_seg1, np.array(vehicle_pos)) + lateral_set2 = self._handle_lateral_set( + lane_seg2, np.array(vehicle_pos)) # Use the union of two sets as the set of possible lateral positions - lateral_set = [min(lateral_set1[0], lateral_set2[0]), max(lateral_set1[1], lateral_set2[1])] - + lateral_set = [min(lateral_set1[0], lateral_set2[0]), max( + lateral_set1[1], lateral_set2[1])] + # Construct the tmp variable tmp_var_name = f'tmp_variable{len(cont_var_dict)+1}' # Add the tmp variable to the cont var dict @@ -329,20 +354,25 @@ class GuardExpressionAst: if isinstance(elt, ast.Attribute): var = elt.value.id + '.' + elt.attr arg1_lower.append(cont_var_dict[var][0]) - arg1_upper.append(cont_var_dict[var][1]) + arg1_upper.append(cont_var_dict[var][1]) vehicle_pos = (arg1_lower, arg1_upper) # Get corresponding lane segments with respect to the set of vehicle pos - lane_seg1 = lane_map.get_lane_segment(vehicle_lane, arg1_lower) - lane_seg2 = lane_map.get_lane_segment(vehicle_lane, arg1_upper) + lane_seg1 = lane_map.get_lane_segment( + vehicle_lane, arg1_lower) + lane_seg2 = lane_map.get_lane_segment( + vehicle_lane, arg1_upper) # Compute the set of possible longitudinal values with respect to all possible segments - longitudinal_set1 = self._handle_longitudinal_set(lane_seg1, np.array(vehicle_pos)) - longitudinal_set2 = self._handle_longitudinal_set(lane_seg2, np.array(vehicle_pos)) + longitudinal_set1 = self._handle_longitudinal_set( + lane_seg1, np.array(vehicle_pos)) + longitudinal_set2 = self._handle_longitudinal_set( + lane_seg2, np.array(vehicle_pos)) # Use the union of two sets as the set of possible longitudinal positions - longitudinal_set = [min(longitudinal_set1[0], longitudinal_set2[0]), max(longitudinal_set1[1], longitudinal_set2[1])] - + longitudinal_set = [min(longitudinal_set1[0], longitudinal_set2[0]), max( + longitudinal_set1[1], longitudinal_set2[1])] + # Construct the tmp variable tmp_var_name = f'tmp_variable{len(cont_var_dict)+1}' # Add the tmp variable to the cont var dict @@ -351,45 +381,53 @@ class GuardExpressionAst: root = ast.parse(tmp_var_name).body[0].value return True, root else: - raise ValueError(f'Node type {func} from {astunparse.unparse(func)} is not supported') + raise ValueError( + f'Node type {func} from {astunparse.unparse(func)} is not supported') else: - raise ValueError(f'Node type {func} from {astunparse.unparse(func)} is not supported') + raise ValueError( + f'Node type {func} from {astunparse.unparse(func)} is not supported') else: - raise ValueError(f'Node type {root.func} from {astunparse.unparse(root.func)} is not supported') + raise ValueError( + f'Node type {root.func} from {astunparse.unparse(root.func)} is not supported') elif isinstance(root, ast.Attribute): - return True, root + return True, root elif isinstance(root, ast.Constant): - return root.value, root + return root.value, root elif isinstance(root, ast.UnaryOp): if isinstance(root.op, ast.USub): - res, root.operand = self._evaluate_guard_hybrid(root.operand, agent, disc_var_dict, cont_var_dict, lane_map) + res, root.operand = self._evaluate_guard_hybrid( + root.operand, agent, disc_var_dict, cont_var_dict, lane_map) else: - raise ValueError(f'Node type {root} from {astunparse.unparse(root)} is not supported') - return True, root + raise ValueError( + f'Node type {root} from {astunparse.unparse(root)} is not supported') + return True, root else: - raise ValueError(f'Node type {root} from {astunparse.unparse(root)} is not supported') + raise ValueError( + f'Node type {root} from {astunparse.unparse(root)} is not supported') def _handle_longitudinal_set(self, lane_seg: AbstractLane, position: np.ndarray) -> List[float]: if lane_seg.type == "Straight": # Delta lower - delta0 = position[0,:] - lane_seg.start + delta0 = position[0, :] - lane_seg.start # Delta upper - delta1 = position[1,:] - lane_seg.start + delta1 = position[1, :] - lane_seg.start longitudinal_low = min(delta0[0]*lane_seg.direction[0], delta1[0]*lane_seg.direction[0]) + \ - min(delta0[1]*lane_seg.direction[1], delta1[1]*lane_seg.direction[1]) + min(delta0[1]*lane_seg.direction[1], + delta1[1]*lane_seg.direction[1]) longitudinal_high = max(delta0[0]*lane_seg.direction[0], delta1[0]*lane_seg.direction[0]) + \ - max(delta0[1]*lane_seg.direction[1], delta1[1]*lane_seg.direction[1]) + max(delta0[1]*lane_seg.direction[1], + delta1[1]*lane_seg.direction[1]) longitudinal_low += lane_seg.longitudinal_start longitudinal_high += lane_seg.longitudinal_start assert longitudinal_high >= longitudinal_low - return longitudinal_low, longitudinal_high + return longitudinal_low, longitudinal_high elif lane_seg.type == "Circular": # Delta lower - delta0 = position[0,:] - lane_seg.center + delta0 = position[0, :] - lane_seg.center # Delta upper - delta1 = position[1,:] - lane_seg.center + delta1 = position[1, :] - lane_seg.center phi0 = np.min([ np.arctan2(delta0[1], delta0[0]), @@ -404,50 +442,66 @@ class GuardExpressionAst: np.arctan2(delta1[1], delta1[0]), ]) - phi0 = lane_seg.start_phase + wrap_to_pi(phi0 - lane_seg.start_phase) - phi1 = lane_seg.start_phase + wrap_to_pi(phi1 - lane_seg.start_phase) + phi0 = lane_seg.start_phase + \ + wrap_to_pi(phi0 - lane_seg.start_phase) + phi1 = lane_seg.start_phase + \ + wrap_to_pi(phi1 - lane_seg.start_phase) longitudinal_low = min( - lane_seg.direction * (phi0 - lane_seg.start_phase)*lane_seg.radius, - lane_seg.direction * (phi1 - lane_seg.start_phase)*lane_seg.radius + lane_seg.direction * + (phi0 - lane_seg.start_phase)*lane_seg.radius, + lane_seg.direction * + (phi1 - lane_seg.start_phase)*lane_seg.radius ) + lane_seg.longitudinal_start longitudinal_high = max( - lane_seg.direction * (phi0 - lane_seg.start_phase)*lane_seg.radius, - lane_seg.direction * (phi1 - lane_seg.start_phase)*lane_seg.radius + lane_seg.direction * + (phi0 - lane_seg.start_phase)*lane_seg.radius, + lane_seg.direction * + (phi1 - lane_seg.start_phase)*lane_seg.radius ) + lane_seg.longitudinal_start assert longitudinal_high >= longitudinal_low return longitudinal_low, longitudinal_high else: - raise ValueError(f'Lane segment with type {lane_seg.type} is not supported') + raise ValueError( + f'Lane segment with type {lane_seg.type} is not supported') def _handle_lateral_set(self, lane_seg: AbstractLane, position: np.ndarray) -> List[float]: if lane_seg.type == "Straight": # Delta lower - delta0 = position[0,:] - lane_seg.start + delta0 = position[0, :] - lane_seg.start # Delta upper - delta1 = position[1,:] - lane_seg.start + delta1 = position[1, :] - lane_seg.start lateral_low = min(delta0[0]*lane_seg.direction_lateral[0], delta1[0]*lane_seg.direction_lateral[0]) + \ - min(delta0[1]*lane_seg.direction_lateral[1], delta1[1]*lane_seg.direction_lateral[1]) + min(delta0[1]*lane_seg.direction_lateral[1], + delta1[1]*lane_seg.direction_lateral[1]) lateral_high = max(delta0[0]*lane_seg.direction_lateral[0], delta1[0]*lane_seg.direction_lateral[0]) + \ - max(delta0[1]*lane_seg.direction_lateral[1], delta1[1]*lane_seg.direction_lateral[1]) + max(delta0[1]*lane_seg.direction_lateral[1], + delta1[1]*lane_seg.direction_lateral[1]) assert lateral_high >= lateral_low return lateral_low, lateral_high elif lane_seg.type == "Circular": - dx = np.max([position[0,0]-lane_seg.center[0],0,lane_seg.center[0]-position[1,0]]) - dy = np.max([position[0,1]-lane_seg.center[1],0,lane_seg.center[1]-position[1,1]]) + dx = np.max([position[0, 0]-lane_seg.center[0], + 0, lane_seg.center[0]-position[1, 0]]) + dy = np.max([position[0, 1]-lane_seg.center[1], + 0, lane_seg.center[1]-position[1, 1]]) r_low = np.linalg.norm([dx, dy]) - dx = np.max([np.abs(position[0,0]-lane_seg.center[0]),np.abs(position[1,0]-lane_seg.center[0])]) - dy = np.max([np.abs(position[0,1]-lane_seg.center[1]),np.abs(position[1,1]-lane_seg.center[1])]) + dx = np.max([np.abs(position[0, 0]-lane_seg.center[0]), + np.abs(position[1, 0]-lane_seg.center[0])]) + dy = np.max([np.abs(position[0, 1]-lane_seg.center[1]), + np.abs(position[1, 1]-lane_seg.center[1])]) r_high = np.linalg.norm([dx, dy]) - lateral_low = min(lane_seg.direction*(lane_seg.radius - r_high),lane_seg.direction*(lane_seg.radius - r_low)) - lateral_high = max(lane_seg.direction*(lane_seg.radius - r_high),lane_seg.direction*(lane_seg.radius - r_low)) + lateral_low = min(lane_seg.direction*(lane_seg.radius - r_high), + lane_seg.direction*(lane_seg.radius - r_low)) + lateral_high = max(lane_seg.direction*(lane_seg.radius - r_high), + lane_seg.direction*(lane_seg.radius - r_low)) # print(lateral_low, lateral_high) assert lateral_high >= lateral_low return lateral_low, lateral_high else: - raise ValueError(f'Lane segment with type {lane_seg.type} is not supported') + raise ValueError( + f'Lane segment with type {lane_seg.type} is not supported') def evaluate_guard_disc(self, agent, discrete_variable_dict, continuous_variable_dict, lane_map): """ @@ -455,50 +509,55 @@ class GuardExpressionAst: """ res = True for i, node in enumerate(self.ast_list): - tmp, self.ast_list[i] = self._evaluate_guard_disc(node, agent, discrete_variable_dict, continuous_variable_dict, lane_map) - res = res and tmp + tmp, self.ast_list[i] = self._evaluate_guard_disc( + node, agent, discrete_variable_dict, continuous_variable_dict, lane_map) + res = res and tmp return res - + def _evaluate_guard_disc(self, root, agent, disc_var_dict, cont_var_dict, lane_map): """ Recursively called function to evaluate guard with only discrete variables The function will evaluate all guards with discrete variables and replace the nodes with discrete guards by boolean constants - + :params: :return: The return value will be a tuple. The first element in the tuple will either be a boolean value or a the evaluated value of of an expression involving guard The second element in the tuple will be the updated ast node """ if isinstance(root, ast.Compare): expr = astunparse.unparse(root) - left, root.left = self._evaluate_guard_disc(root.left, agent, disc_var_dict, cont_var_dict, lane_map) - right, root.comparators[0] = self._evaluate_guard_disc(root.comparators[0], agent, disc_var_dict, cont_var_dict, lane_map) + left, root.left = self._evaluate_guard_disc( + root.left, agent, disc_var_dict, cont_var_dict, lane_map) + right, root.comparators[0] = self._evaluate_guard_disc( + root.comparators[0], agent, disc_var_dict, cont_var_dict, lane_map) if isinstance(left, bool) or isinstance(right, bool): return True, root if isinstance(root.ops[0], ast.GtE): - res = left>=right + res = left >= right elif isinstance(root.ops[0], ast.Gt): - res = left>right + res = left > right elif isinstance(root.ops[0], ast.Lt): - res = left<right + res = left < right elif isinstance(root.ops[0], ast.LtE): - res = left<=right + res = left <= right elif isinstance(root.ops[0], ast.Eq): - res = left == right + res = left == right elif isinstance(root.ops[0], ast.NotEq): - res = left != right + res = left != right else: - raise ValueError(f'Node type {root} from {astunparse.unparse(root)} is not supported') + raise ValueError( + f'Node type {root} from {astunparse.unparse(root)} is not supported') if res: root = ast.parse('True').body[0].value else: - root = ast.parse('False').body[0].value + root = ast.parse('False').body[0].value return res, root elif isinstance(root, ast.BoolOp): if isinstance(root.op, ast.And): res = True - for i,val in enumerate(root.values): - tmp,root.values[i] = self._evaluate_guard_disc(val, agent, disc_var_dict, cont_var_dict, lane_map) + for i, val in enumerate(root.values): + tmp, root.values[i] = self._evaluate_guard_disc( + val, agent, disc_var_dict, cont_var_dict, lane_map) res = res and tmp if not res: break @@ -506,15 +565,18 @@ class GuardExpressionAst: elif isinstance(root.op, ast.Or): res = False for val in root.values: - tmp,val = self._evaluate_guard_disc(val, agent, disc_var_dict, cont_var_dict, lane_map) + tmp, val = self._evaluate_guard_disc( + val, agent, disc_var_dict, cont_var_dict, lane_map) res = res or tmp if res: break - return res, root + return res, root elif isinstance(root, ast.BinOp): # Check left and right in the binop and replace all attributes involving discrete variables - left, root.left = self._evaluate_guard_disc(root.left, agent, disc_var_dict, cont_var_dict, lane_map) - right, root.right = self._evaluate_guard_disc(root.right, agent, disc_var_dict, cont_var_dict, lane_map) + left, root.left = self._evaluate_guard_disc( + root.left, agent, disc_var_dict, cont_var_dict, lane_map) + right, root.right = self._evaluate_guard_disc( + root.right, agent, disc_var_dict, cont_var_dict, lane_map) return True, root elif isinstance(root, ast.Call): expr = astunparse.unparse(root) @@ -534,7 +596,7 @@ class GuardExpressionAst: if res: root = ast.parse('True').body[0].value else: - root = ast.parse('False').body[0].value + root = ast.parse('False').body[0].value else: root = ast.parse(str(res)).body[0].value return res, root @@ -558,17 +620,21 @@ class GuardExpressionAst: return root.value, root elif isinstance(root, ast.UnaryOp): if isinstance(root.op, ast.USub): - res, root.operand = self._evaluate_guard_disc(root.operand, agent, disc_var_dict, cont_var_dict, lane_map) + res, root.operand = self._evaluate_guard_disc( + root.operand, agent, disc_var_dict, cont_var_dict, lane_map) else: - raise ValueError(f'Node type {root} from {astunparse.unparse(root)} is not supported') + raise ValueError( + f'Node type {root} from {astunparse.unparse(root)} is not supported') return True, root else: - raise ValueError(f'Node type {root} from {astunparse.unparse(root)} is not supported') + raise ValueError( + f'Node type {root} from {astunparse.unparse(root)} is not supported') def evaluate_guard(self, agent, continuous_variable_dict, discrete_variable_dict, lane_map): res = True for node in self.ast_list: - tmp = self._evaluate_guard(node, agent, continuous_variable_dict, discrete_variable_dict, lane_map) + tmp = self._evaluate_guard( + node, agent, continuous_variable_dict, discrete_variable_dict, lane_map) res = tmp and res if not res: break @@ -576,28 +642,32 @@ class GuardExpressionAst: def _evaluate_guard(self, root, agent, cnts_var_dict, disc_var_dict, lane_map): if isinstance(root, ast.Compare): - left = self._evaluate_guard(root.left, agent, cnts_var_dict, disc_var_dict, lane_map) - right = self._evaluate_guard(root.comparators[0], agent, cnts_var_dict, disc_var_dict, lane_map) + left = self._evaluate_guard( + root.left, agent, cnts_var_dict, disc_var_dict, lane_map) + right = self._evaluate_guard( + root.comparators[0], agent, cnts_var_dict, disc_var_dict, lane_map) if isinstance(root.ops[0], ast.GtE): - return left>=right + return left >= right elif isinstance(root.ops[0], ast.Gt): - return left>right + return left > right elif isinstance(root.ops[0], ast.Lt): - return left<right + return left < right elif isinstance(root.ops[0], ast.LtE): - return left<=right + return left <= right elif isinstance(root.ops[0], ast.Eq): - return left == right + return left == right elif isinstance(root.ops[0], ast.NotEq): - return left != right + return left != right else: - raise ValueError(f'Node type {root} from {astunparse.unparse(root)} is not supported') + raise ValueError( + f'Node type {root} from {astunparse.unparse(root)} is not supported') elif isinstance(root, ast.BoolOp): if isinstance(root.op, ast.And): res = True for val in root.values: - tmp = self._evaluate_guard(val, agent, cnts_var_dict, disc_var_dict, lane_map) + tmp = self._evaluate_guard( + val, agent, cnts_var_dict, disc_var_dict, lane_map) res = res and tmp if not res: break @@ -605,20 +675,24 @@ class GuardExpressionAst: elif isinstance(root.op, ast.Or): res = False for val in root.values: - tmp = self._evaluate_guard(val, agent, cnts_var_dict, disc_var_dict, lane_map) + tmp = self._evaluate_guard( + val, agent, cnts_var_dict, disc_var_dict, lane_map) res = res or tmp if res: break return res elif isinstance(root, ast.BinOp): - left = self._evaluate_guard(root.left, agent, cnts_var_dict, disc_var_dict, lane_map) - right = self._evaluate_guard(root.right, agent, cnts_var_dict, disc_var_dict, lane_map) + left = self._evaluate_guard( + root.left, agent, cnts_var_dict, disc_var_dict, lane_map) + right = self._evaluate_guard( + root.right, agent, cnts_var_dict, disc_var_dict, lane_map) if isinstance(root.op, ast.Sub): return left - right elif isinstance(root.op, ast.Add): return left + right else: - raise ValueError(f'Node type {root} from {astunparse.unparse(root)} is not supported') + raise ValueError( + f'Node type {root} from {astunparse.unparse(root)} is not supported') elif isinstance(root, ast.Call): expr = astunparse.unparse(root) # Check if the root is a function @@ -633,7 +707,7 @@ class GuardExpressionAst: for arg in disc_var_dict: expr = expr.replace(arg, f'"{disc_var_dict[arg]}"') for arg in cnts_var_dict: - expr = expr.replace(arg, str(cnts_var_dict[arg])) + expr = expr.replace(arg, str(cnts_var_dict[arg])) res = eval(expr) return res elif isinstance(root, ast.Attribute): @@ -654,18 +728,22 @@ class GuardExpressionAst: elif isinstance(root, ast.Constant): return root.value elif isinstance(root, ast.UnaryOp): - val = self._evaluate_guard(root.operand, agent, cnts_var_dict, disc_var_dict, lane_map) + val = self._evaluate_guard( + root.operand, agent, cnts_var_dict, disc_var_dict, lane_map) if isinstance(root.op, ast.USub): return -val else: - raise ValueError(f'Node type {root} from {astunparse.unparse(root)} is not supported') + raise ValueError( + f'Node type {root} from {astunparse.unparse(root)} is not supported') else: - raise ValueError(f'Node type {root} from {astunparse.unparse(root)} is not supported') + raise ValueError( + f'Node type {root} from {astunparse.unparse(root)} is not supported') + if __name__ == "__main__": - with open('tmp.pickle','rb') as f: + with open('tmp.pickle', 'rb') as f: guard_list = pickle.load(f) tmp = GuardExpressionAst(guard_list) # tmp.evaluate_guard() # tmp.construct_tree_from_str('(other_x-ego_x<20) and other_x-ego_x>10 and other_vehicle_lane==ego_vehicle_lane') - print("stop") \ No newline at end of file + print("stop") diff --git a/dryvr_plus_plus/scene_verifier/automaton/hybrid_automaton.py b/demo/dryvr_plus_plus/scene_verifier/automaton/hybrid_automaton.py similarity index 100% rename from dryvr_plus_plus/scene_verifier/automaton/hybrid_automaton.py rename to demo/dryvr_plus_plus/scene_verifier/automaton/hybrid_automaton.py diff --git a/dryvr_plus_plus/scene_verifier/automaton/hybrid_io_automaton.py b/demo/dryvr_plus_plus/scene_verifier/automaton/hybrid_io_automaton.py similarity index 100% rename from dryvr_plus_plus/scene_verifier/automaton/hybrid_io_automaton.py rename to demo/dryvr_plus_plus/scene_verifier/automaton/hybrid_io_automaton.py diff --git a/dryvr_plus_plus/scene_verifier/automaton/reset.py b/demo/dryvr_plus_plus/scene_verifier/automaton/reset.py similarity index 100% rename from dryvr_plus_plus/scene_verifier/automaton/reset.py rename to demo/dryvr_plus_plus/scene_verifier/automaton/reset.py diff --git a/dryvr_plus_plus/scene_verifier/code_parser/__init__.py b/demo/dryvr_plus_plus/scene_verifier/code_parser/__init__.py similarity index 100% rename from dryvr_plus_plus/scene_verifier/code_parser/__init__.py rename to demo/dryvr_plus_plus/scene_verifier/code_parser/__init__.py diff --git a/dryvr_plus_plus/scene_verifier/code_parser/pythonparser.py b/demo/dryvr_plus_plus/scene_verifier/code_parser/pythonparser.py similarity index 74% rename from dryvr_plus_plus/scene_verifier/code_parser/pythonparser.py rename to demo/dryvr_plus_plus/scene_verifier/code_parser/pythonparser.py index e3201422fda74b63148f60ef0444aac1d50fe521..cba535bffed1427a0e03ff6a6e6e90fc3bd4250d 100644 --- a/dryvr_plus_plus/scene_verifier/code_parser/pythonparser.py +++ b/demo/dryvr_plus_plus/scene_verifier/code_parser/pythonparser.py @@ -1,21 +1,23 @@ -#parse python file +# parse python file -#REQUIRES PYTHON 3.8! +# REQUIRES PYTHON 3.8! from cgitb import reset -#import clang.cindex +# import clang.cindex import typing import json import sys from typing import List, Tuple -import re +import re import itertools import ast - +import time from treelib import Node, Tree ''' Edge class: utility class to hold the source, dest, guards, and resets for a transition ''' + + class Edge: def __init__(self, source, dest, guards, resets): self.source = source @@ -23,18 +25,21 @@ class Edge: self.guards = guards self.resets = resets + ''' -Statement super class. Holds the code and mode information for a statement. +Statement super class. Holds the code and mode information for a statement. If there is no mode information, mode and modeType are None. ''' + + class Statement: - def __init__(self, code, mode, modeType, func = None, args = None): + def __init__(self, code, mode, modeType, func=None, args=None): self.code = code self.modeType = modeType self.mode = mode - self.func = func + self.func = func self.args = args - + def print(self): print(self.code) @@ -42,26 +47,28 @@ class Statement: ''' Guard class. Subclass of statement. ''' + + class Guard(Statement): def __init__(self, code, mode, modeType, inp_ast, func=None, args=None): super().__init__(code, mode, modeType, func, args) self.ast = inp_ast - ''' - Returns true if a guard is checking that we are in a mode. + Returns true if a guard is checking that we are in a mode. ''' + def isModeCheck(self): return self.modeType != None ''' Helper function to parse a node that contains a guard. Parses out the code and mode. - Returns a Guard. + Returns a Guard. TODO: needs to handle more complex guards. ''' def parseGuard(node, code): - #assume guard is a strict comparision (modeType == mode) - if isinstance(node.test, ast.Compare): + # assume guard is a strict comparision (modeType == mode) + if isinstance(node.test, ast.Compare): # == if isinstance(node.test.comparators[0], ast.Attribute): if ("Mode" in str(node.test.comparators[0].value.id)): modeType = str(node.test.comparators[0].value.id) @@ -69,40 +76,44 @@ class Guard(Statement): return Guard(ast.get_source_segment(code, node.test), mode, modeType, node.test) else: return Guard(ast.get_source_segment(code, node.test), None, None, node.test) - elif isinstance(node.test, ast.BoolOp): + elif isinstance(node.test, ast.BoolOp): # or and return Guard(ast.get_source_segment(code, node.test), None, None, node.test) - elif isinstance(node.test, ast.Call): + elif isinstance(node.test, ast.Call): # function not used source_segment = ast.get_source_segment(code, node.test) if "map" in source_segment: - func = node.test.func.value.id + '.' + node.test.func.attr + func = node.test.func.value.id + '.' + node.test.func.attr args = [] for arg in node.test.args: args.append(arg.value.id + '.' + arg.attr) return Guard(source_segment, None, None, node.test, func, args) + ''' Reset class. Subclass of statement. ''' + + class Reset(Statement): def __init__(self, code, mode, modeType, inp_ast): super().__init__(code, mode, modeType) self.ast = inp_ast ''' - Returns true if a reset is updating our mode. + Returns true if a reset is updating our mode. ''' + def isModeUpdate(self): return self.modeType != None ''' Helper function to parse a node that contains a reset. Parses out the code and mode. - Returns a reset. + Returns a reset. ''' def parseReset(node, code): - #assume reset is modeType = newMode + # assume reset is modeType = newMode if isinstance(node.value, ast.Attribute): - #print("resets " + str(node.value.value.id)) - #print("resets " + str(node.value.attr)) + # print("resets " + str(node.value.value.id)) + # print("resets " + str(node.value.attr)) if ("Mode" in str(node.value.value.id)): modeType = str(node.value.value.id) mode = str(node.value.attr) @@ -113,6 +124,8 @@ class Reset(Statement): ''' Util class to handle building transitions given a path. ''' + + class TransitionUtil: ''' Takes in a list of reset objects. Returns a string in the format json expected. @@ -120,13 +133,13 @@ class TransitionUtil: def resetString(resets): outstr = "" for reset in resets: - outstr+= reset.code + ";" + outstr += reset.code + ";" outstr = outstr.strip(";") return outstr ''' Takes in guard code. Returns a string in the format json expected. - TODO: needs to handle more complex guards. + TODO: needs to handle more complex guards. ''' def parseGuardCode(code): parts = code.split("and") @@ -138,33 +151,33 @@ class TransitionUtil: return out ''' - Helper function for parseGuardCode. + Helper function for parseGuardCode. ''' def guardString(guards): output = "" first = True - for guard in guards: - #print(type(condition)) + for guard in guards: + # print(type(condition)) if first: - output+= TransitionUtil.parseGuardCode(guard.code) + output += TransitionUtil.parseGuardCode(guard.code) else: - output = "And(" + TransitionUtil.parseGuardCode(guard.code) + ",(" + output + "))" + output = "And(" + TransitionUtil.parseGuardCode(guard.code) + \ + ",(" + output + "))" first = False return output - ''' Helper function to get the index of the vertex for a set of modes. Modes is a list of all modes in the current vertex. - Vertices is the list of vertices. + Vertices is the list of vertices. TODO: needs to be tested on more complex examples to see if ordering stays and we can use index function ''' def getIndex(modes, vertices): return vertices.index(tuple(modes)) ''' - Function that creates transitions given a path. - Will create multiple transitions if not all modeTypes are checked/set in the path. + Function that creates transitions given a path. + Will create multiple transitions if not all modeTypes are checked/set in the path. Returns a list of edges that correspond to the path. ''' def createTransition(path, vertices, modes): @@ -190,8 +203,8 @@ class TransitionUtil: for modeType in modes.keys(): foundMode = False for condition in modeChecks: - #print(condition.modeType) - #print(modeType) + # print(condition.modeType) + # print(modeType) if condition.modeType == modeType: sourceModes.append(condition.mode) foundMode = True @@ -233,6 +246,7 @@ class TransitionUtil: edges.append(Edge(sourceindex, destindex, guards, resets)) return edges + class ControllerAst(): ''' Initalizing function for a controllerAst object. @@ -240,17 +254,19 @@ class ControllerAst(): Statement tree is a tree of nodes that contain a list in their data. The list contains a single guard or a list of resets. Variables (inputs to the controller) are collected. Modes are collected from all enums that have the word "mode" in them. - Vertices are generated by taking the products of mode types. + Vertices are generated by taking the products of mode types. ''' - def __init__(self, code = None, file_name = None): + + def __init__(self, code=None, file_name=None): assert code is not None or file_name is not None if file_name is not None: - with open(file_name,'r') as f: - code = f.read() + with open(file_name, 'r') as f: + code = f.read() self.code = code self.tree = ast.parse(code) - self.statementtree, self.variables, self.modes, self.discrete_variables, self.state_object_dict, self.vars_dict, self.type_vars = self.initalwalktree(code, self.tree) + self.statementtree, self.variables, self.modes, self.discrete_variables, self.state_object_dict, self.vars_dict, self.type_vars = self.initalwalktree( + code, self.tree) self.vertices = [] self.vertexStrings = [] for vertex in itertools.product(*self.modes.values()): @@ -262,62 +278,65 @@ class ControllerAst(): ''' Function to populate paths variable with all paths of the controller. ''' + def getAllPaths(self): self.paths = self.getNextModes([], True) return self.paths - + ''' - getNextModes takes in a list of current modes. It should include all modes. + getNextModes takes in a list of current modes. It should include all modes. getNextModes returns a list of paths that can be followed when in the given mode. A path is a list of statements, all guards and resets along the path. They are in the order they are encountered in the code. TODO: should we not force all modes be listed? Or rerun for each unknown/don't care node? Or add them all to the list ''' - def getNextModes(self, currentModes: List[str], getAllPaths= False) -> List[str]: - #walk the tree and capture all paths that have modes that are listed. Path is a list of statements + + def getNextModes(self, currentModes: List[str], getAllPaths=False) -> List[str]: + # walk the tree and capture all paths that have modes that are listed. Path is a list of statements paths = [] rootid = self.statementtree.root currnode = self.statementtree.get_node(rootid) paths = self.walkstatements(currnode, currentModes, getAllPaths) - - return paths + + return paths ''' Helper function to walk the statement tree from parentnode and find paths that are allowed in the currentMode. - Returns a list of paths. + Returns a list of paths. ''' - def walkstatements(self, parentnode, currentModes, getAllPaths): - nextsPaths = [] + def walkstatements(self, parentnode: Node, currentModes, getAllPaths): + nextsPaths = [] + # print("walkstatements", parentnode.tag) for node in self.statementtree.children(parentnode.identifier): statement = node.data - + # if parentnode.tag == "ego.vehicle_mode == VehicleMode.Brake": + # print(statement[0]) if isinstance(statement[0], Guard) and statement[0].isModeCheck(): if getAllPaths or statement[0].mode in currentModes: - #print(statement.mode) - newPaths = self.walkstatements(node, currentModes, getAllPaths) + newPaths = self.walkstatements( + node, currentModes, getAllPaths) for path in newPaths: newpath = statement.copy() newpath.extend(path) nextsPaths.append(newpath) if len(nextsPaths) == 0: nextsPaths.append(statement) - + else: - newPaths =self.walkstatements(node, currentModes, getAllPaths) + newPaths = self.walkstatements(node, currentModes, getAllPaths) for path in newPaths: newpath = statement.copy() newpath.extend(path) nextsPaths.append(newpath) if len(nextsPaths) == 0: - nextsPaths.append(statement) - + nextsPaths.append(statement) return nextsPaths - ''' Function to create a json of the full graph. - Requires that paths class variables has been set. + Requires that paths class variables has been set. ''' + def create_json(self, input_file_name, output_file_name): if not self.paths: print("Cannot call create_json without calling getAllPaths") @@ -336,22 +355,23 @@ class ControllerAst(): resets = [] for path in self.paths: - transitions = TransitionUtil.createTransition(path, self.vertices, self.modes) + transitions = TransitionUtil.createTransition( + path, self.vertices, self.modes) for edge in transitions: edges.append([edge.source, edge.dest]) guards.append(TransitionUtil.guardString(edge.guards)) resets.append(TransitionUtil.resetString(edge.resets)) - + output_dict['vertex'] = self.vertexStrings - #print(vertices) + # print(vertices) output_dict['variables'] = self.variables # #add edge, transition(guards) and resets output_dict['edge'] = edges - #print(len(edges)) + # print(len(edges)) output_dict['guards'] = guards - #print(len(guards)) + # print(len(guards)) output_dict['resets'] = resets - #print(len(resets)) + # print(len(resets)) output_json = json.dumps(output_dict, indent=4) outfile = open(output_file_name, "w") @@ -360,11 +380,12 @@ class ControllerAst(): print("wrote json to " + output_file_name) - #inital tree walk, parse into a tree of resets/modes + # inital tree walk, parse into a tree of resets/modes ''' - Function called by init function. Walks python ast and parses to a statement tree. + Function called by init function. Walks python ast and parses to a statement tree. Returns a statement tree (nodes contain a list of either a single guard or muliple resets), the variables, and a mode dictionary ''' + def initalwalktree(self, code, tree): vars = [] discrete_vars = [] @@ -374,7 +395,8 @@ class ControllerAst(): state_object_dict = {} vars_dict = {} statementtree = Tree() - for node in ast.walk(tree): #don't think we want to walk the whole thing because lose ordering/depth + # don't think we want to walk the whole thing because lose ordering/depth + for node in ast.walk(tree): # Get all the modes if isinstance(node, ast.ClassDef): if "Mode" in node.name: @@ -385,25 +407,30 @@ class ControllerAst(): mode_dict[modeType] = modes if isinstance(node, ast.ClassDef): if "State" in node.name: - state_object_dict[node.name] = {"cont":[],"disc":[], "type": []} + state_object_dict[node.name] = { + "cont": [], "disc": [], "type": []} for item in node.body: if isinstance(item, ast.FunctionDef): if "init" in item.name: for arg in item.args.args: if "self" not in arg.arg: if "type" == arg.arg: - state_object_dict[node.name]["type"].append(arg.arg) + state_object_dict[node.name]["type"].append( + arg.arg) elif "mode" not in arg.arg: - state_object_dict[node.name]['cont'].append(arg.arg) + state_object_dict[node.name]['cont'].append( + arg.arg) # vars.append(arg.arg) else: - state_object_dict[node.name]['disc'].append(arg.arg) + state_object_dict[node.name]['disc'].append( + arg.arg) # discrete_vars.append(arg.arg) if isinstance(node, ast.FunctionDef): if node.name == 'controller': - #print(node.body) - statementtree = self.parsenodelist(code, node.body, False, Tree(), None) - #print(type(node.args)) + # print(node.body) + statementtree = self.parsenodelist( + code, node.body, False, Tree(), None) + # print(type(node.args)) args = node.args.args for arg in args: if arg.annotation is None: @@ -412,7 +439,8 @@ class ControllerAst(): continue arg_annotation = arg.annotation.id arg_name = arg.arg - vars_dict[arg_name] = {'cont':[], 'disc':[], "type": []} + vars_dict[arg_name] = { + 'cont': [], 'disc': [], "type": []} for var in state_object_dict[arg_annotation]['cont']: vars.append(arg_name+"."+var) vars_dict[arg_name]['cont'].append(var) @@ -428,17 +456,18 @@ class ControllerAst(): # #todo: what to add for return values # else: # discrete_vars.append(arg.arg) + statementtree.show() return [statementtree, vars, mode_dict, discrete_vars, state_object_dict, vars_dict, type_vars] - ''' Helper function for initalwalktree which parses the statements in the controller function into a statement tree ''' - def parsenodelist(self, code, nodes, addResets, tree, parent): - childrens_guards=[] - childrens_resets=[] + + def parsenodelist(self, code, nodes, addResets, tree: Tree, parent): + childrens_guards = [] + childrens_resets = [] recoutput = [] - #tree.show() + # tree.show() if parent == None: s = Statement("root", None, None) tree.create_node("root") @@ -447,57 +476,73 @@ class ControllerAst(): for childnode in nodes: if isinstance(childnode, ast.Assign) and addResets: reset = Reset.parseReset(childnode, code) - #print("found reset: " + reset.code) + print("found reset: " + reset.code) childrens_resets.append(reset) if isinstance(childnode, ast.If): guard = Guard.parseGuard(childnode, code) childrens_guards.append(guard) - #print("found if statement: " + guard.code) + print("found if statement: " + guard.code) newTree = Tree() - newTree.create_node(tag= guard.code, data = [guard]) - #print(self.nodect) - tempresults = self.parsenodelist(code, childnode.body, True, newTree, newTree.root) - #for result in tempresults: + newTree.create_node(tag=guard.code, data=[guard]) + # print(self.nodect) + tempresults = self.parsenodelist( + code, childnode.body, True, newTree, newTree.root) + # for result in tempresults: recoutput.append(tempresults) - - #pathsafterme = [] + # if (len(childnode.orelse) > 0): + # childnode = childnode.orelse[0] + # if isinstance(childnode, ast.If): + # guard = Guard.parseGuard(childnode, code) + # childrens_guards.append(guard) + # print("found if statement: " + guard.code) + # newTree = Tree() + # newTree.create_node(tag=guard.code, data=[guard]) + # # print(self.nodect) + # tempresults = self.parsenodelist( + # code, childnode.body, True, newTree, newTree.root) + # # for result in tempresults: + # recoutput.append(tempresults) + + # pathsafterme = [] if len(childrens_resets) > 0: - #print("adding node:" + str(self.nodect) + "with parent:" + str(parent)) - tree.create_node(tag = childrens_resets[0].code, data = childrens_resets, parent= parent) + # print("adding node:" + str(self.nodect) + "with parent:" + str(parent)) + tree.create_node( + tag=childrens_resets[0].code, data=childrens_resets, parent=parent) for subtree in recoutput: - #print("adding subtree:" + " to parent:" + str(parent)) + # print("adding subtree:" + " to parent:" + str(parent)) tree.paste(parent, subtree) - - + # tree.show() return tree + class EmptyAst(ControllerAst): def __init__(self): super().__init__(code="True", file_name=None) self.discrete_variables = [] self.modes = { - 'NullMode':['Null'], - 'LaneMode':['Normal'] + 'NullMode': ['Null'], + 'LaneMode': ['Normal'] } self.paths = None self.state_object_dict = { - 'State':{ - 'cont':[], - 'disc':[], - 'type':[] + 'State': { + 'cont': [], + 'disc': [], + 'type': [] } } self.type_vars = [] self.variables = [] self.vars_dict = [] self.vertexStrings = ['Null,Normal'] - self.vertices=[('Null','Normal')] + self.vertices = [('Null', 'Normal')] self.statementtree.create_node("root") + ##main code### if __name__ == "__main__": - #if len(sys.argv) < 4: + # if len(sys.argv) < 4: # print("incorrect usage. call createGraph.py program inputfile outputfilename") # quit() @@ -511,18 +556,18 @@ if __name__ == "__main__": output_dict = { } - #read in the controler code - f = open(input_code_name,'r') + # read in the controler code + f = open(input_code_name, 'r') code = f.read() - #parse the controller code into our controller ast objct + # parse the controller code into our controller ast objct controller_obj = ControllerAst(code) print(controller_obj.variables) - #demonstrate you can check getNextModes after only initalizing + # demonstrate you can check getNextModes after only initalizing paths = controller_obj.getNextModes("NormalA;Normal3") - + print("Results") for path in paths: for item in path: @@ -530,16 +575,11 @@ if __name__ == "__main__": print() print("Done") - #attempt to write to json, fail because we haven't populated paths yet + # attempt to write to json, fail because we haven't populated paths yet controller_obj.create_json(input_file_name, output_file_name) - #call function that gets all paths + # call function that gets all paths controller_obj.getAllPaths() - #write json with all paths + # write json with all paths controller_obj.create_json(input_file_name, output_file_name) - - - - - diff --git a/dryvr_plus_plus/scene_verifier/dryvr/__init__.py b/demo/dryvr_plus_plus/scene_verifier/dryvr/__init__.py similarity index 100% rename from dryvr_plus_plus/scene_verifier/dryvr/__init__.py rename to demo/dryvr_plus_plus/scene_verifier/dryvr/__init__.py diff --git a/dryvr_plus_plus/scene_verifier/dryvr/common/__init__.py b/demo/dryvr_plus_plus/scene_verifier/dryvr/common/__init__.py similarity index 100% rename from dryvr_plus_plus/scene_verifier/dryvr/common/__init__.py rename to demo/dryvr_plus_plus/scene_verifier/dryvr/common/__init__.py diff --git a/dryvr_plus_plus/scene_verifier/dryvr/common/config.py b/demo/dryvr_plus_plus/scene_verifier/dryvr/common/config.py similarity index 100% rename from dryvr_plus_plus/scene_verifier/dryvr/common/config.py rename to demo/dryvr_plus_plus/scene_verifier/dryvr/common/config.py diff --git a/dryvr_plus_plus/scene_verifier/dryvr/common/constant.py b/demo/dryvr_plus_plus/scene_verifier/dryvr/common/constant.py similarity index 100% rename from dryvr_plus_plus/scene_verifier/dryvr/common/constant.py rename to demo/dryvr_plus_plus/scene_verifier/dryvr/common/constant.py diff --git a/dryvr_plus_plus/scene_verifier/dryvr/common/io.py b/demo/dryvr_plus_plus/scene_verifier/dryvr/common/io.py similarity index 100% rename from dryvr_plus_plus/scene_verifier/dryvr/common/io.py rename to demo/dryvr_plus_plus/scene_verifier/dryvr/common/io.py diff --git a/dryvr_plus_plus/scene_verifier/dryvr/common/utils.py b/demo/dryvr_plus_plus/scene_verifier/dryvr/common/utils.py similarity index 100% rename from dryvr_plus_plus/scene_verifier/dryvr/common/utils.py rename to demo/dryvr_plus_plus/scene_verifier/dryvr/common/utils.py diff --git a/dryvr_plus_plus/scene_verifier/dryvr/core/__init__.py b/demo/dryvr_plus_plus/scene_verifier/dryvr/core/__init__.py similarity index 100% rename from dryvr_plus_plus/scene_verifier/dryvr/core/__init__.py rename to demo/dryvr_plus_plus/scene_verifier/dryvr/core/__init__.py diff --git a/dryvr_plus_plus/scene_verifier/dryvr/core/distance.py b/demo/dryvr_plus_plus/scene_verifier/dryvr/core/distance.py similarity index 100% rename from dryvr_plus_plus/scene_verifier/dryvr/core/distance.py rename to demo/dryvr_plus_plus/scene_verifier/dryvr/core/distance.py diff --git a/dryvr_plus_plus/scene_verifier/dryvr/core/dryvrcore.py b/demo/dryvr_plus_plus/scene_verifier/dryvr/core/dryvrcore.py similarity index 100% rename from dryvr_plus_plus/scene_verifier/dryvr/core/dryvrcore.py rename to demo/dryvr_plus_plus/scene_verifier/dryvr/core/dryvrcore.py diff --git a/dryvr_plus_plus/scene_verifier/dryvr/core/dryvrmain.py b/demo/dryvr_plus_plus/scene_verifier/dryvr/core/dryvrmain.py similarity index 100% rename from dryvr_plus_plus/scene_verifier/dryvr/core/dryvrmain.py rename to demo/dryvr_plus_plus/scene_verifier/dryvr/core/dryvrmain.py diff --git a/dryvr_plus_plus/scene_verifier/dryvr/core/goalchecker.py b/demo/dryvr_plus_plus/scene_verifier/dryvr/core/goalchecker.py similarity index 100% rename from dryvr_plus_plus/scene_verifier/dryvr/core/goalchecker.py rename to demo/dryvr_plus_plus/scene_verifier/dryvr/core/goalchecker.py diff --git a/dryvr_plus_plus/scene_verifier/dryvr/core/graph.py b/demo/dryvr_plus_plus/scene_verifier/dryvr/core/graph.py similarity index 100% rename from dryvr_plus_plus/scene_verifier/dryvr/core/graph.py rename to demo/dryvr_plus_plus/scene_verifier/dryvr/core/graph.py diff --git a/dryvr_plus_plus/scene_verifier/dryvr/core/guard.py b/demo/dryvr_plus_plus/scene_verifier/dryvr/core/guard.py similarity index 83% rename from dryvr_plus_plus/scene_verifier/dryvr/core/guard.py rename to demo/dryvr_plus_plus/scene_verifier/dryvr/core/guard.py index 2996f61bcc914234488f9957fe9c8c7549192856..579ab29b594b05e37aa37dbd79404aa1ccd5b437 100644 --- a/dryvr_plus_plus/scene_verifier/dryvr/core/guard.py +++ b/demo/dryvr_plus_plus/scene_verifier/dryvr/core/guard.py @@ -47,14 +47,17 @@ class Guard: # Thus we need to replace "==" to something else sympy_guard_str = guard_str.replace("==", ">=") - symbols = list(sympy.sympify(sympy_guard_str, evaluate=False).free_symbols) + symbols = list(sympy.sympify( + sympy_guard_str, evaluate=False).free_symbols) symbols = [str(s) for s in symbols] - symbols_idx = {s: self.variables.index(s) + 1 for s in symbols if s in self.variables} + symbols_idx = {s: self.variables.index( + s) + 1 for s in symbols if s in self.variables} if 't' in symbols: symbols_idx['t'] = 0 guard_str = handleReplace(guard_str, list(self.varDic.keys())) - cur_solver.add(eval(guard_str)) # TODO use an object instead of `eval` a string + # TODO use an object instead of `eval` a string + cur_solver.add(eval(guard_str)) return cur_solver, symbols_idx def guard_sim_trace(self, trace, guard_str): @@ -97,8 +100,10 @@ class Guard: upper = trace[idx + 1] cur_solver.push() for symbol in symbols: - cur_solver.add(self.varDic[symbol] >= min(lower[symbols[symbol]], upper[symbols[symbol]])) - cur_solver.add(self.varDic[symbol] <= max(lower[symbols[symbol]], upper[symbols[symbol]])) + cur_solver.add(self.varDic[symbol] >= min( + lower[symbols[symbol]], upper[symbols[symbol]])) + cur_solver.add(self.varDic[symbol] <= max( + lower[symbols[symbol]], upper[symbols[symbol]])) if cur_solver.check() == sat: cur_solver.pop() guard_set[idx] = upper @@ -163,8 +168,10 @@ class Guard: lower_bound = tube[i] upper_bound = tube[i + 1] for symbol in symbols: - cur_solver.add(self.varDic[symbol] >= lower_bound[symbols[symbol]]) - cur_solver.add(self.varDic[symbol] <= upper_bound[symbols[symbol]]) + cur_solver.add(self.varDic[symbol] >= + lower_bound[symbols[symbol]]) + cur_solver.add(self.varDic[symbol] <= + upper_bound[symbols[symbol]]) if cur_solver.check() == sat: # The reachtube hits the guard cur_solver.pop() @@ -175,8 +182,10 @@ class Guard: tmp_solver = Solver() tmp_solver.add(Not(cur_solver.assertions()[0])) for symbol in symbols: - tmp_solver.add(self.varDic[symbol] >= lower_bound[symbols[symbol]]) - tmp_solver.add(self.varDic[symbol] <= upper_bound[symbols[symbol]]) + tmp_solver.add( + self.varDic[symbol] >= lower_bound[symbols[symbol]]) + tmp_solver.add( + self.varDic[symbol] <= upper_bound[symbols[symbol]]) if tmp_solver.check() == unsat: print("Full intersect, break") break @@ -189,8 +198,10 @@ class Guard: init_upper = guard_set_upper[0][1:] for j in range(1, len(guard_set_lower)): for k in range(1, len(guard_set_lower[0])): - init_lower[k - 1] = min(init_lower[k - 1], guard_set_lower[j][k]) - init_upper[k - 1] = max(init_upper[k - 1], guard_set_upper[j][k]) + init_lower[k - 1] = min(init_lower[k - 1], + guard_set_lower[j][k]) + init_upper[k - 1] = max(init_upper[k - 1], + guard_set_upper[j][k]) # Return next initial Set, the result tube, and the true transit time return [init_lower, init_upper], tube[:i], guard_set_lower[0][0] @@ -202,11 +213,13 @@ class Guard: init_upper = guard_set_upper[0][1:] for j in range(1, len(guard_set_lower)): for k in range(1, len(guard_set_lower[0])): - init_lower[k - 1] = min(init_lower[k - 1], guard_set_lower[j][k]) - init_upper[k - 1] = max(init_upper[k - 1], guard_set_upper[j][k]) + init_lower[k - 1] = min(init_lower[k - 1], + guard_set_lower[j][k]) + init_upper[k - 1] = max(init_upper[k - 1], + guard_set_upper[j][k]) # TODO: Treat tau as a special clock variable that don't have variation # init_upper[0] = init_lower[0] - + # Return next initial Set, the result tube, and the true transit time return [init_lower, init_upper], tube[:i], guard_set_lower[0][0] diff --git a/dryvr_plus_plus/scene_verifier/dryvr/core/initialset.py b/demo/dryvr_plus_plus/scene_verifier/dryvr/core/initialset.py similarity index 100% rename from dryvr_plus_plus/scene_verifier/dryvr/core/initialset.py rename to demo/dryvr_plus_plus/scene_verifier/dryvr/core/initialset.py diff --git a/dryvr_plus_plus/scene_verifier/dryvr/core/initialsetstack.py b/demo/dryvr_plus_plus/scene_verifier/dryvr/core/initialsetstack.py similarity index 100% rename from dryvr_plus_plus/scene_verifier/dryvr/core/initialsetstack.py rename to demo/dryvr_plus_plus/scene_verifier/dryvr/core/initialsetstack.py diff --git a/dryvr_plus_plus/scene_verifier/dryvr/core/reachtube.py b/demo/dryvr_plus_plus/scene_verifier/dryvr/core/reachtube.py similarity index 100% rename from dryvr_plus_plus/scene_verifier/dryvr/core/reachtube.py rename to demo/dryvr_plus_plus/scene_verifier/dryvr/core/reachtube.py diff --git a/dryvr_plus_plus/scene_verifier/dryvr/core/reset.py b/demo/dryvr_plus_plus/scene_verifier/dryvr/core/reset.py similarity index 100% rename from dryvr_plus_plus/scene_verifier/dryvr/core/reset.py rename to demo/dryvr_plus_plus/scene_verifier/dryvr/core/reset.py diff --git a/dryvr_plus_plus/scene_verifier/dryvr/core/uniformchecker.py b/demo/dryvr_plus_plus/scene_verifier/dryvr/core/uniformchecker.py similarity index 98% rename from dryvr_plus_plus/scene_verifier/dryvr/core/uniformchecker.py rename to demo/dryvr_plus_plus/scene_verifier/dryvr/core/uniformchecker.py index dfa49b0a6d2d55b3c539029518cfbb8f1459a128..cce5c1a73199dda227b3ade0b8bbade9794c0d0b 100644 --- a/dryvr_plus_plus/scene_verifier/dryvr/core/uniformchecker.py +++ b/demo/dryvr_plus_plus/scene_verifier/dryvr/core/uniformchecker.py @@ -49,7 +49,8 @@ class UniformChecker: cond = cond.replace("==", ">=") symbols = list(sympy.sympify(cond).free_symbols) symbols = [str(s) for s in symbols] - symbols_idx = {s: self._variables.index(s) + 1 for s in symbols if s in self._variables} + symbols_idx = {s: self._variables.index( + s) + 1 for s in symbols if s in self._variables} if 't' in symbols: symbols_idx['t'] = 0 self._solver_dict[mode].append(symbols_idx) # TODO Fix typing diff --git a/dryvr_plus_plus/scene_verifier/dryvr/discrepancy/Global_Disc.py b/demo/dryvr_plus_plus/scene_verifier/dryvr/discrepancy/Global_Disc.py similarity index 100% rename from dryvr_plus_plus/scene_verifier/dryvr/discrepancy/Global_Disc.py rename to demo/dryvr_plus_plus/scene_verifier/dryvr/discrepancy/Global_Disc.py diff --git a/dryvr_plus_plus/scene_verifier/dryvr/discrepancy/PW_Discrepancy.py b/demo/dryvr_plus_plus/scene_verifier/dryvr/discrepancy/PW_Discrepancy.py similarity index 100% rename from dryvr_plus_plus/scene_verifier/dryvr/discrepancy/PW_Discrepancy.py rename to demo/dryvr_plus_plus/scene_verifier/dryvr/discrepancy/PW_Discrepancy.py diff --git a/dryvr_plus_plus/scene_verifier/dryvr/discrepancy/__init__.py b/demo/dryvr_plus_plus/scene_verifier/dryvr/discrepancy/__init__.py similarity index 100% rename from dryvr_plus_plus/scene_verifier/dryvr/discrepancy/__init__.py rename to demo/dryvr_plus_plus/scene_verifier/dryvr/discrepancy/__init__.py diff --git a/dryvr_plus_plus/scene_verifier/map/__init__.py b/demo/dryvr_plus_plus/scene_verifier/map/__init__.py similarity index 100% rename from dryvr_plus_plus/scene_verifier/map/__init__.py rename to demo/dryvr_plus_plus/scene_verifier/map/__init__.py diff --git a/dryvr_plus_plus/scene_verifier/map/lane.py b/demo/dryvr_plus_plus/scene_verifier/map/lane.py similarity index 78% rename from dryvr_plus_plus/scene_verifier/map/lane.py rename to demo/dryvr_plus_plus/scene_verifier/map/lane.py index e48e3f10cc14647bd5d6d3cfcba4d2198a4e7b25..da3adef5919ea2d6db8906369d287a9df7f352dd 100644 --- a/dryvr_plus_plus/scene_verifier/map/lane.py +++ b/demo/dryvr_plus_plus/scene_verifier/map/lane.py @@ -4,8 +4,10 @@ import numpy as np from dryvr_plus_plus.scene_verifier.map.lane_segment import AbstractLane + class Lane(): COMPENSATE = 3 + def __init__(self, id, seg_list: List[AbstractLane]): self.id = id self.segment_list: List[AbstractLane] = seg_list @@ -17,28 +19,29 @@ class Lane(): lane_seg.longitudinal_start = longitudinal_start longitudinal_start += lane_seg.length - def get_lane_segment(self, position:np.ndarray) -> AbstractLane: + def get_lane_segment(self, position: np.ndarray) -> AbstractLane: for seg_idx, segment in enumerate(self.segment_list): logitudinal, lateral = segment.local_coordinates(position) + # why COMPENSATE? why no lateral? use on_lane? is_on = 0-Lane.COMPENSATE <= logitudinal < segment.length if is_on: return seg_idx, segment - return -1,None + return -1, None - def get_heading(self, position:np.ndarray) -> float: + def get_heading(self, position: np.ndarray) -> float: seg_idx, segment = self.get_lane_segment(position) longitudinal, lateral = segment.local_coordinates(position) heading = segment.heading_at(longitudinal) return heading - def get_longitudinal_position(self, position:np.ndarray) -> float: + def get_longitudinal_position(self, position: np.ndarray) -> float: seg_idx, segment = self.get_lane_segment(position) longitudinal, lateral = segment.local_coordinates(position) - for i in range(seg_idx): + for i in range(seg_idx): longitudinal += self.segment_list[i].length return longitudinal - def get_lateral_distance(self, position:np.ndarray) -> float: + def get_lateral_distance(self, position: np.ndarray) -> float: seg_idx, segment = self.get_lane_segment(position) longitudinal, lateral = segment.local_coordinates(position) return lateral diff --git a/dryvr_plus_plus/scene_verifier/map/lane_map.py b/demo/dryvr_plus_plus/scene_verifier/map/lane_map.py similarity index 78% rename from dryvr_plus_plus/scene_verifier/map/lane_map.py rename to demo/dryvr_plus_plus/scene_verifier/map/lane_map.py index 13ddcb7ab3b59035ec0d5f561b0f811c3829fbe9..f17afe6f389ad59ae023616b87530baa8d41771b 100644 --- a/dryvr_plus_plus/scene_verifier/map/lane_map.py +++ b/demo/dryvr_plus_plus/scene_verifier/map/lane_map.py @@ -7,17 +7,19 @@ import numpy as np from dryvr_plus_plus.scene_verifier.map.lane_segment import AbstractLane from dryvr_plus_plus.scene_verifier.map.lane import Lane + class LaneMap: - def __init__(self, lane_seg_list:List[Lane] = []): - self.lane_dict:Dict[str, Lane] = {} - self.left_lane_dict:Dict[str, List[str]] = {} - self.right_lane_dict:Dict[str, List[str]] = {} + def __init__(self, lane_seg_list: List[Lane] = []): + self.lane_dict: Dict[str, Lane] = {} + self.left_lane_dict: Dict[str, List[str]] = {} + self.right_lane_dict: Dict[str, List[str]] = {} for lane_seg in lane_seg_list: self.lane_dict[lane_seg.id] = lane_seg self.left_lane_dict[lane_seg.id] = [] self.right_lane_dict[lane_seg.id] = [] - def add_lanes(self, lane_seg_list:List[AbstractLane]): + # why AbstractLane not Lane + def add_lanes(self, lane_seg_list: List[AbstractLane]): for lane_seg in lane_seg_list: self.lane_dict[lane_seg.id] = lane_seg self.left_lane_dict[lane_seg.id] = [] @@ -30,7 +32,7 @@ class LaneMap: Warning(f'lane {lane_idx} not available') return False left_lane_list = self.left_lane_dict[lane_idx] - return len(left_lane_list)>0 + return len(left_lane_list) > 0 def left_lane(self, lane_idx): assert all((elem in self.left_lane_dict) for elem in self.lane_dict) @@ -40,7 +42,7 @@ class LaneMap: raise ValueError(f"lane_idx {lane_idx} not in lane_dict") left_lane_list = self.left_lane_dict[lane_idx] return copy.deepcopy(left_lane_list) - + def has_right(self, lane_idx): if isinstance(lane_idx, Enum): lane_idx = lane_idx.name @@ -48,7 +50,7 @@ class LaneMap: Warning(f'lane {lane_idx} not available') return False right_lane_list = self.right_lane_dict[lane_idx] - return len(right_lane_list)>0 + return len(right_lane_list) > 0 def right_lane(self, lane_idx): assert all((elem in self.right_lane_dict) for elem in self.lane_dict) @@ -58,36 +60,36 @@ class LaneMap: raise ValueError(f"lane_idx {lane_idx} not in lane_dict") right_lane_list = self.right_lane_dict[lane_idx] return copy.deepcopy(right_lane_list) - + def lane_geometry(self, lane_idx): if isinstance(lane_idx, Enum): lane_idx = lane_idx.name return self.lane_dict[lane_idx].get_geometry() - def get_longitudinal_position(self, lane_idx:str, position:np.ndarray) -> float: + def get_longitudinal_position(self, lane_idx: str, position: np.ndarray) -> float: if not isinstance(position, np.ndarray): position = np.array(position) lane = self.lane_dict[lane_idx] return lane.get_longitudinal_position(position) - def get_lateral_distance(self, lane_idx:str, position:np.ndarray) -> float: + def get_lateral_distance(self, lane_idx: str, position: np.ndarray) -> float: if not isinstance(position, np.ndarray): position = np.array(position) lane = self.lane_dict[lane_idx] return lane.get_lateral_distance(position) - def get_altitude(self, lane_idx, position:np.ndarray) -> float: + def get_altitude(self, lane_idx, position: np.ndarray) -> float: raise NotImplementedError - def get_lane_heading(self, lane_idx:str, position: np.ndarray) -> float: + def get_lane_heading(self, lane_idx: str, position: np.ndarray) -> float: if not isinstance(position, np.ndarray): position = np.array(position) lane = self.lane_dict[lane_idx] return lane.get_heading(position) - def get_lane_segment(self, lane_idx:str, position: np.ndarray) -> AbstractLane: + def get_lane_segment(self, lane_idx: str, position: np.ndarray) -> AbstractLane: if not isinstance(position, np.ndarray): position = np.array(position) lane = self.lane_dict[lane_idx] seg_idx, segment = lane.get_lane_segment(position) - return segment \ No newline at end of file + return segment diff --git a/dryvr_plus_plus/scene_verifier/map/lane_segment.py b/demo/dryvr_plus_plus/scene_verifier/map/lane_segment.py similarity index 100% rename from dryvr_plus_plus/scene_verifier/map/lane_segment.py rename to demo/dryvr_plus_plus/scene_verifier/map/lane_segment.py diff --git a/dryvr_plus_plus/scene_verifier/scenario/__init__.py b/demo/dryvr_plus_plus/scene_verifier/scenario/__init__.py similarity index 100% rename from dryvr_plus_plus/scene_verifier/scenario/__init__.py rename to demo/dryvr_plus_plus/scene_verifier/scenario/__init__.py diff --git a/dryvr_plus_plus/scene_verifier/scenario/scenario.py b/demo/dryvr_plus_plus/scene_verifier/scenario/scenario.py similarity index 75% rename from dryvr_plus_plus/scene_verifier/scenario/scenario.py rename to demo/dryvr_plus_plus/scene_verifier/scenario/scenario.py index 715e836cc7648d079f558c745196dc6c3aff795c..2ad2454986a1e90884a4887c64fd19b3026c5929 100644 --- a/dryvr_plus_plus/scene_verifier/scenario/scenario.py +++ b/demo/dryvr_plus_plus/scene_verifier/scenario/scenario.py @@ -14,6 +14,7 @@ from dryvr_plus_plus.scene_verifier.analysis.verifier import Verifier from dryvr_plus_plus.scene_verifier.map.lane_map import LaneMap from dryvr_plus_plus.scene_verifier.utils.utils import * + class Scenario: def __init__(self): self.agent_dict = {} @@ -27,14 +28,14 @@ class Scenario: def set_sensor(self, sensor): self.sensor = sensor - def set_map(self, lane_map:LaneMap): + def set_map(self, lane_map: LaneMap): self.map = lane_map # Update the lane mode field in the agent for agent_id in self.agent_dict: agent = self.agent_dict[agent_id] self.update_agent_lane_mode(agent, lane_map) - def add_agent(self, agent:BaseAgent): + def add_agent(self, agent: BaseAgent): if self.map is not None: # Update the lane mode field in the agent self.update_agent_lane_mode(agent, self.map) @@ -46,12 +47,13 @@ class Scenario: agent.controller.modes['LaneMode'].append(lane_id) mode_vals = list(agent.controller.modes.values()) agent.controller.vertices = list(itertools.product(*mode_vals)) - agent.controller.vertexStrings = [','.join(elem) for elem in agent.controller.vertices] + agent.controller.vertexStrings = [ + ','.join(elem) for elem in agent.controller.vertices] def set_init(self, init_list, init_mode_list): assert len(init_list) == len(self.agent_dict) assert len(init_mode_list) == len(self.agent_dict) - for i,agent_id in enumerate(self.agent_dict.keys()): + for i, agent_id in enumerate(self.agent_dict.keys()): self.init_dict[agent_id] = copy.deepcopy(init_list[i]) self.init_mode_dict[agent_id] = copy.deepcopy(init_mode_list[i]) @@ -88,13 +90,13 @@ class Scenario: return self.verifier.compute_full_reachtube(init_list, init_mode_list, agent_list, self, time_horizon, self.map) def check_guard_hit(self, state_dict): - lane_map = self.map + lane_map = self.map guard_hits = [] any_contained = False # TODO: Handle this for agent_id in state_dict: - agent:BaseAgent = self.agent_dict[agent_id] + agent: BaseAgent = self.agent_dict[agent_id] agent_state, agent_mode = state_dict[agent_id] - + t = agent_state[0] agent_state = agent_state[1:] paths = agent.controller.getNextModes(agent_mode) @@ -110,21 +112,25 @@ class Scenario: # guard_expression = GuardExpression(guard_list=guard_list) guard_expression = GuardExpressionAst(guard_list) # Map the values to variables using sensor - continuous_variable_dict, discrete_variable_dict = self.sensor.sense(self, agent, state_dict, self.map) - + continuous_variable_dict, discrete_variable_dict = self.sensor.sense( + self, agent, state_dict, self.map) + # Check if the guard can be satisfied - # First Check if the discrete guards can be satisfied by actually evaluate the values + # First Check if the discrete guards can be satisfied by actually evaluate the values # since there's no uncertainty. If there's functions, actually execute the functions - guard_can_satisfied = guard_expression.evaluate_guard_disc(agent, discrete_variable_dict, continuous_variable_dict, self.map) + guard_can_satisfied = guard_expression.evaluate_guard_disc( + agent, discrete_variable_dict, continuous_variable_dict, self.map) if not guard_can_satisfied: continue - # TODO: Handle hybrid guards that involves both continuous and discrete dynamics + # TODO: Handle hybrid guards that involves both continuous and discrete dynamics # Will have to limit the amount of hybrid guards that we want to handle. The difficulty will be handle function guards. - guard_can_satisfied = guard_expression.evaluate_guard_hybrid(agent, discrete_variable_dict, continuous_variable_dict, self.map) + guard_can_satisfied = guard_expression.evaluate_guard_hybrid( + agent, discrete_variable_dict, continuous_variable_dict, self.map) # Handle guards realted only to continuous variables using SMT solvers. These types of guards can be pretty general - guard_satisfied, is_contained = guard_expression.evaluate_guard_cont(agent, continuous_variable_dict, self.map) + guard_satisfied, is_contained = guard_expression.evaluate_guard_cont( + agent, continuous_variable_dict, self.map) any_contained = any_contained or is_contained if guard_satisfied: guard_hits.append((agent_id, guard_list, reset_list)) @@ -137,14 +143,15 @@ class Scenario: guard_hit_bool = False # TODO: can add parallalization for this loop - for idx in range(0,trace_length): + for idx in range(0, trace_length): # For each trace, check with the guard to see if there's any possible transition # Store all possible transition in a list # A transition is defined by (agent, src_mode, dest_mode, corresponding reset, transit idx) # Here we enforce that only one agent transit at a time all_agent_state = {} for agent_id in node.agent: - all_agent_state[agent_id] = (node.trace[agent_id][idx*2:idx*2+2], node.mode[agent_id]) + all_agent_state[agent_id] = ( + node.trace[agent_id][idx*2:idx*2+2], node.mode[agent_id]) hits, is_contain = self.check_guard_hit(all_agent_state) # print(idx, is_contain) if hits != []: @@ -159,12 +166,14 @@ class Scenario: reset_idx_dict = {} for hits, all_agent_state, hit_idx in guard_hits: for agent_id, guard_list, reset_list in hits: - dest_list,reset_rect = self.apply_reset(node.agent[agent_id], reset_list, all_agent_state) + dest_list, reset_rect = self.apply_reset( + node.agent[agent_id], reset_list, all_agent_state) if agent_id not in reset_dict: reset_dict[agent_id] = {} reset_idx_dict[agent_id] = {} if not dest_list: - warnings.warn(f"Guard hit for mode {node.mode[agent_id]} for agent {agent_id} without available next mode") + warnings.warn( + f"Guard hit for mode {node.mode[agent_id]} for agent {agent_id} without available next mode") dest_list.append(None) for dest in dest_list: if dest not in reset_dict[agent_id]: @@ -172,38 +181,44 @@ class Scenario: reset_idx_dict[agent_id][dest] = [] reset_dict[agent_id][dest].append(reset_rect) reset_idx_dict[agent_id][dest].append(hit_idx) - + # Combine reset rects and construct transitions for agent in reset_dict: for dest in reset_dict[agent]: - combined_rect = None + combined_rect = None for rect in reset_dict[agent][dest]: rect = np.array(rect) if combined_rect is None: - combined_rect = rect + combined_rect = rect else: - combined_rect[0,:] = np.minimum(combined_rect[0,:], rect[0,:]) - combined_rect[1,:] = np.maximum(combined_rect[1,:], rect[1,:]) + combined_rect[0, :] = np.minimum( + combined_rect[0, :], rect[0, :]) + combined_rect[1, :] = np.maximum( + combined_rect[1, :], rect[1, :]) combined_rect = combined_rect.tolist() min_idx = min(reset_idx_dict[agent][dest]) max_idx = max(reset_idx_dict[agent][dest]) - transition = (agent, node.mode[agent], dest, combined_rect, (min_idx, max_idx)) + transition = ( + agent, node.mode[agent], dest, combined_rect, (min_idx, max_idx)) possible_transitions.append(transition) # Return result return possible_transitions def apply_reset(self, agent, reset_list, all_agent_state) -> Tuple[str, np.ndarray]: reset_expr = ResetExpression(reset_list) - continuous_variable_dict, discrete_variable_dict = self.sensor.sense(self, agent, all_agent_state, self.map) - dest = reset_expr.get_dest(agent, all_agent_state[agent.id], discrete_variable_dict, self.map) - rect = reset_expr.apply_reset_continuous(agent, continuous_variable_dict, self.map) + continuous_variable_dict, discrete_variable_dict = self.sensor.sense( + self, agent, all_agent_state, self.map) + dest = reset_expr.get_dest( + agent, all_agent_state[agent.id], discrete_variable_dict, self.map) + rect = reset_expr.apply_reset_continuous( + agent, continuous_variable_dict, self.map) return dest, rect def get_all_transition(self, state_dict: Dict[str, Tuple[List[float], List[str]]]): lane_map = self.map satisfied_guard = [] for agent_id in state_dict: - agent:BaseAgent = self.agent_dict[agent_id] + agent: BaseAgent = self.agent_dict[agent_id] agent_state, agent_mode = state_dict[agent_id] t = agent_state[0] agent_state = agent_state[1:] @@ -213,6 +228,7 @@ class Scenario: guard_list = [] reset_list = [] for item in path: + # print(item.code) if isinstance(item, Guard): guard_list.append(item) elif isinstance(item, Reset): @@ -220,52 +236,70 @@ class Scenario: # guard_expression = GuardExpression(guard_list=guard_list) guard_expression = GuardExpressionAst(guard_list) # Map the values to variables using sensor - continuous_variable_dict, discrete_variable_dict = self.sensor.sense(self, agent, state_dict, self.map) - + continuous_variable_dict, discrete_variable_dict = self.sensor.sense( + self, agent, state_dict, self.map) + # print(continuous_variable_dict) + # print(discrete_variable_dict) '''Execute functions related to map to see if the guard can be satisfied''' '''Check guards related to modes to see if the guards can be satisfied''' '''Actually plug in the values to see if the guards can be satisfied''' # Check if the guard can be satisfied - guard_satisfied = guard_expression.evaluate_guard(agent, continuous_variable_dict, discrete_variable_dict, self.map) + guard_satisfied = guard_expression.evaluate_guard( + agent, continuous_variable_dict, discrete_variable_dict, self.map) if guard_satisfied: # If the guard can be satisfied, handle resets next_init = agent_state dest = copy.deepcopy(agent_mode) possible_dest = [[elem] for elem in dest] + # like [['Normal'], ['Lane1']] + print('possible_dest', possible_dest) for reset in reset_list: # Specify the destination mode reset = reset.code if "mode" in reset: + print(agent.controller.vars_dict['ego']) + print(reset) + # why break + # vars_dict: {'cont': ['x', 'y', 'theta', 'v'], 'disc': ['vehicle_mode', 'lane_mode'], 'type': []} for i, discrete_variable_ego in enumerate(agent.controller.vars_dict['ego']['disc']): if discrete_variable_ego in reset: break tmp = reset.split('=') + # like output.lane_mode = lane_map.right_lane(ego.lane_mode) if 'map' in tmp[1]: tmp = tmp[1] for var in discrete_variable_dict: - tmp = tmp.replace(var, f"'{discrete_variable_dict[var]}'") + tmp = tmp.replace( + var, f"'{discrete_variable_dict[var]}'") possible_dest[i] = eval(tmp) + # like output.vehicle_mode = VehicleMode.Normal else: tmp = tmp[1].split('.') if tmp[0].strip(' ') in agent.controller.modes: - possible_dest[i] = [tmp[1]] - else: + possible_dest[i] = [tmp[1]] + else: + # for i, cts_variable in enumerate(agent.controller.vars_dict['ego']['cont']): if "output."+cts_variable in reset: - break + break tmp = reset.split('=') tmp = tmp[1] for cts_variable in continuous_variable_dict: - tmp = tmp.replace(cts_variable, str(continuous_variable_dict[cts_variable])) + tmp = tmp.replace(cts_variable, str( + continuous_variable_dict[cts_variable])) next_init[i] = eval(tmp) + # print('possible_dest', possible_dest) + # [['Brake'], ['Lane1']] -> [('Brake', 'Lane1')] all_dest = list(itertools.product(*possible_dest)) + # print('all_dest', all_dest) if not all_dest: - warnings.warn(f"Guard hit for mode {agent_mode} for agent {agent_id} without available next mode") + warnings.warn( + f"Guard hit for mode {agent_mode} for agent {agent_id} without available next mode") all_dest.append(None) for dest in all_dest: next_transition = ( - agent_id, agent_mode, dest, next_init, + agent_id, agent_mode, dest, next_init, ) + # print(next_transition) satisfied_guard.append(next_transition) - return satisfied_guard diff --git a/dryvr_plus_plus/scene_verifier/utils/__init__.py b/demo/dryvr_plus_plus/scene_verifier/utils/__init__.py similarity index 100% rename from dryvr_plus_plus/scene_verifier/utils/__init__.py rename to demo/dryvr_plus_plus/scene_verifier/utils/__init__.py diff --git a/dryvr_plus_plus/scene_verifier/utils/utils.py b/demo/dryvr_plus_plus/scene_verifier/utils/utils.py similarity index 100% rename from dryvr_plus_plus/scene_verifier/utils/utils.py rename to demo/dryvr_plus_plus/scene_verifier/utils/utils.py diff --git a/demo/example_controller1.py b/demo/example_controller1.py index bc3c11401fe4fbd7c648d09614ef97a091014f9a..3d467ccec825dd16d599d35eafd5cd1bbc83f2b9 100644 --- a/demo/example_controller1.py +++ b/demo/example_controller1.py @@ -1,17 +1,24 @@ from enum import Enum, auto import copy +from sympy import false + +from dryvr_plus_plus.scene_verifier.map.lane_map import LaneMap + + class VehicleMode(Enum): Normal = auto() SwitchLeft = auto() SwitchRight = auto() Brake = auto() + class LaneMode(Enum): Lane0 = auto() Lane1 = auto() Lane2 = auto() + class State: x = 0.0 y = 0.0 @@ -23,17 +30,17 @@ class State: def __init__(self, x, y, theta, v, vehicle_mode: VehicleMode, lane_mode: LaneMode): self.data = [] -def controller(ego:State, other:State, lane_map): + +def controller(ego: State, other: State, lane_map: LaneMap): output = copy.deepcopy(ego) if ego.vehicle_mode == VehicleMode.Normal: - if lane_map.get_longitudinal_position(other.lane_mode, [other.x,other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x,ego.y]) > 0 \ - and lane_map.get_longitudinal_position(other.lane_mode, [other.x,other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x,ego.y]) < 3 \ - and ego.lane_mode == other.lane_mode: + if lane_map.get_longitudinal_position(other.lane_mode, [other.x, other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x, ego.y]) > 0 \ + and lane_map.get_longitudinal_position(other.lane_mode, [other.x, other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x, ego.y]) < 3 \ + and ego.lane_mode == other.lane_mode: output.vehicle_mode = VehicleMode.Brake elif ego.vehicle_mode == VehicleMode.Brake: - if lane_map.get_longitudinal_position(other.lane_mode, [other.x,other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x,ego.y]) > 10 \ - or ego.lane_mode != other.lane_mode: + if lane_map.get_longitudinal_position(other.lane_mode, [other.x, other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x, ego.y]) > 3 \ + or ego.lane_mode != other.lane_mode: output.vehicle_mode = VehicleMode.Normal return output - diff --git a/demo/example_two_car_sign_lane_switch.py b/demo/example_two_car_sign_lane_switch.py index fcfaf9235f447795527bbcc46b3923130295e149..5b5e610226f4c8cbde402e16789de4602292a438 100644 --- a/demo/example_two_car_sign_lane_switch.py +++ b/demo/example_two_car_sign_lane_switch.py @@ -1,6 +1,14 @@ +import matplotlib.pyplot as plt +from dryvr_plus_plus.example.example_sensor.fake_sensor import FakeSensor2 +from dryvr_plus_plus.plotter.plotter2D import plot_reachtube_tree, plot_simulation_tree, generate_simulation_anime, plot_map +from dryvr_plus_plus.example.example_map.simple_map2 import SimpleMap3 +from dryvr_plus_plus.scene_verifier.scenario.scenario import Scenario +from dryvr_plus_plus.example.example_agent.sign_agent import SignAgent +from dryvr_plus_plus.example.example_agent.car_agent import CarAgent from enum import Enum, auto import copy + class LaneObjectMode(Enum): Vehicle = auto() Ped = auto() # Pedestrians @@ -8,6 +16,7 @@ class LaneObjectMode(Enum): Signal = auto() # Traffic lights Obstacle = auto() # Static (to road/lane) obstacles + class VehicleMode(Enum): Normal = auto() SwitchLeft = auto() @@ -20,6 +29,7 @@ class LaneMode(Enum): Lane1 = auto() Lane2 = auto() + class State: x: float y: float @@ -40,24 +50,25 @@ class State: # self.lane_mode = lane_mode # self.obj_mode = obj_mode + def controller(ego: State, other: State, sign: State, lane_map): output = copy.deepcopy(ego) if ego.vehicle_mode == VehicleMode.Normal: if sign.type == LaneObjectMode.Obstacle and sign.x - ego.x < 3 and sign.x - ego.x > 0 and ego.lane_mode == sign.lane_mode: output.vehicle_mode = VehicleMode.SwitchLeft return output - if lane_map.get_longitudinal_position(other.lane_mode, [other.x,other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x,ego.y]) > 3 \ - and lane_map.get_longitudinal_position(other.lane_mode, [other.x,other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x,ego.y]) < 5 \ - and ego.lane_mode == other.lane_mode: + if lane_map.get_longitudinal_position(other.lane_mode, [other.x, other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x, ego.y]) > 3 \ + and lane_map.get_longitudinal_position(other.lane_mode, [other.x, other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x, ego.y]) < 5 \ + and ego.lane_mode == other.lane_mode: if lane_map.has_left(ego.lane_mode): output.vehicle_mode = VehicleMode.SwitchLeft - if lane_map.get_longitudinal_position(other.lane_mode, [other.x,other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x,ego.y]) > 3 \ - and lane_map.get_longitudinal_position(other.lane_mode, [other.x,other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x,ego.y]) < 5 \ - and ego.lane_mode == other.lane_mode: + if lane_map.get_longitudinal_position(other.lane_mode, [other.x, other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x, ego.y]) > 3 \ + and lane_map.get_longitudinal_position(other.lane_mode, [other.x, other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x, ego.y]) < 5 \ + and ego.lane_mode == other.lane_mode: if lane_map.has_right(ego.lane_mode): output.vehicle_mode = VehicleMode.SwitchRight if ego.vehicle_mode == VehicleMode.SwitchLeft: - if lane_map.get_lateral_distance(ego.lane_mode, [ego.x, ego.y]) >= 2.5: + if lane_map.get_lateral_distance(ego.lane_mode, [ego.x, ego.y]) >= 2.5: output.vehicle_mode = VehicleMode.Normal output.lane_mode = lane_map.left_lane(ego.lane_mode) if ego.vehicle_mode == VehicleMode.SwitchRight: @@ -68,15 +79,6 @@ def controller(ego: State, other: State, sign: State, lane_map): return output -from dryvr_plus_plus.example.example_agent.car_agent import CarAgent -from dryvr_plus_plus.example.example_agent.sign_agent import SignAgent -from dryvr_plus_plus.scene_verifier.scenario.scenario import Scenario -from dryvr_plus_plus.example.example_map.simple_map2 import SimpleMap3 -from dryvr_plus_plus.plotter.plotter2D import plot_reachtube_tree, plot_simulation_tree -from dryvr_plus_plus.example.example_sensor.fake_sensor import FakeSensor2 - -import matplotlib.pyplot as plt - if __name__ == "__main__": import sys input_code_name = sys.argv[0] @@ -91,9 +93,9 @@ if __name__ == "__main__": scenario.set_sensor(FakeSensor2()) scenario.set_init( [ - [[0, -0.2, 0, 1.0],[0.2, 0.2, 0, 1.0]], - [[10, 0, 0, 0.5],[10, 0, 0, 0.5]], - [[20, 0, 0, 0],[20, 0, 0, 0]], + [[0, -0.2, 0, 1.0], [0.2, 0.2, 0, 1.0]], + [[10, 0, 0, 0.5], [10, 0, 0, 0.5]], + [[20, 0, 0, 0], [20, 0, 0, 0]], ], [ (VehicleMode.Normal, LaneMode.Lane1, LaneObjectMode.Vehicle), @@ -106,8 +108,11 @@ if __name__ == "__main__": # traces = scenario.verify(40) fig = plt.figure() + fig = plot_map(SimpleMap3(), 'g', fig) fig = plot_simulation_tree(traces, 'car1', 1, [2], 'b', fig) fig = plot_simulation_tree(traces, 'car2', 1, [2], 'r', fig) + # fig = plot_reachtube_tree(traces, 'car1', 1, [2], 'b', fig) + # fig = plot_reachtube_tree(traces, 'car2', 1, [2], 'r', fig) + # generate_simulation_anime(traces, SimpleMap3(), fig) plt.show() - diff --git a/demo/plot1.py b/demo/plot1.py new file mode 100644 index 0000000000000000000000000000000000000000..93733594d3c4ed05567c7eb974440aa8c0514e61 --- /dev/null +++ b/demo/plot1.py @@ -0,0 +1,38 @@ +from dash import Dash, dcc, html, Input, Output +import plotly.express as px + +app = Dash(__name__) + + +app.layout = html.Div([ + html.H4('Animated GDP and population over decades'), + html.P("Select an animation:"), + dcc.RadioItems( + id='selection', + options=["GDP - Scatter", "Population - Bar"], + value='GDP - Scatter', + ), + dcc.Loading(dcc.Graph(id="graph"), type="cube") +]) + + +@app.callback( + Output("graph", "figure"), + Input("selection", "value")) +def display_animated_graph(selection): + df = px.data.gapminder() # replace with your own data source + animations = { + 'GDP - Scatter': px.scatter( + df, x="gdpPercap", y="lifeExp", animation_frame="year", + animation_group="country", size="pop", color="continent", + hover_name="country", log_x=True, size_max=55, + range_x=[100, 100000], range_y=[25, 90]), + 'Population - Bar': px.bar( + df, x="continent", y="pop", color="continent", + animation_frame="year", animation_group="country", + range_y=[0, 4000000000]), + } + return animations[selection] + + +app.run_server() diff --git a/demo/plot2.py b/demo/plot2.py new file mode 100644 index 0000000000000000000000000000000000000000..2c8a74c6b000de18fc232bed841ea8b0c47a55c8 --- /dev/null +++ b/demo/plot2.py @@ -0,0 +1,133 @@ +import plotly.graph_objects as go + +import pandas as pd + +url = "https://raw.githubusercontent.com/plotly/datasets/master/gapminderDataFiveYear.csv" +dataset = pd.read_csv(url) + +years = ["1952", "1962", "1967", "1972", "1977", "1982", "1987", "1992", "1997", "2002", + "2007"] + +# make list of continents +continents = [] +for continent in dataset["continent"]: + if continent not in continents: + continents.append(continent) +# make figure +fig_dict = { + "data": [], + "layout": {}, + "frames": [] +} + +# fill in most of layout +fig_dict["layout"]["xaxis"] = {"range": [30, 85], "title": "Life Expectancy"} +fig_dict["layout"]["yaxis"] = {"title": "GDP per Capita", "type": "log"} +fig_dict["layout"]["hovermode"] = "closest" +fig_dict["layout"]["updatemenus"] = [ + { + "buttons": [ + { + "args": [None, {"frame": {"duration": 500, "redraw": False}, + "fromcurrent": True, "transition": {"duration": 300, + "easing": "quadratic-in-out"}}], + "label": "Play", + "method": "animate" + }, + { + "args": [[None], {"frame": {"duration": 0, "redraw": False}, + "mode": "immediate", + "transition": {"duration": 0}}], + "label": "Pause", + "method": "animate" + } + ], + "direction": "left", + "pad": {"r": 10, "t": 87}, + "showactive": False, + "type": "buttons", + "x": 0.1, + "xanchor": "right", + "y": 0, + "yanchor": "top" + } +] + +sliders_dict = { + "active": 0, + "yanchor": "top", + "xanchor": "left", + "currentvalue": { + "font": {"size": 20}, + "prefix": "Year:", + "visible": False, + "xanchor": "right" + }, + "transition": {"duration": 300, "easing": "cubic-in-out"}, + "pad": {"b": 10, "t": 50}, + "len": 0.9, + "x": 0.1, + "y": 0, + "steps": [] +} + +# make data +year = 1952 +for continent in continents: + dataset_by_year = dataset[dataset["year"] == year] + dataset_by_year_and_cont = dataset_by_year[ + dataset_by_year["continent"] == continent] + + data_dict = { + "x": list(dataset_by_year_and_cont["lifeExp"]), + "y": list(dataset_by_year_and_cont["gdpPercap"]), + "mode": "lines", + "text": list(dataset_by_year_and_cont["country"]), + "marker": { + "sizemode": "area", + "sizeref": 200000, + "size": list(dataset_by_year_and_cont["pop"]) + }, + "name": continent + } + fig_dict["data"].append(data_dict) + +# make frames +for year in years: + frame = {"data": [], "name": str(year)} + for continent in continents: + dataset_by_year = dataset[dataset["year"] == int(year)] + dataset_by_year_and_cont = dataset_by_year[ + dataset_by_year["continent"] == continent] + + data_dict = { + "x": list(dataset_by_year_and_cont["lifeExp"]), + "y": list(dataset_by_year_and_cont["gdpPercap"]), + "mode": "lines", + "text": list(dataset_by_year_and_cont["country"]), + "marker": { + "sizemode": "area", + "sizeref": 200000, + "size": list(dataset_by_year_and_cont["pop"]) + }, + "name": continent + } + frame["data"].append(data_dict) + + fig_dict["frames"].append(frame) + slider_step = {"args": [ + [year], + {"frame": {"duration": 300, "redraw": False}, + "mode": "immediate", + "transition": {"duration": 300}} + ], + "label": year, + "method": "animate"} + sliders_dict["steps"].append(slider_step) + + +fig_dict["layout"]["sliders"] = [sliders_dict] + +fig = go.Figure(fig_dict) + +fig.show() diff --git a/demo/plot_test.py b/demo/plot_test.py new file mode 100644 index 0000000000000000000000000000000000000000..66102f97aded78aded42a9eebc55faf3ae7b3ae6 --- /dev/null +++ b/demo/plot_test.py @@ -0,0 +1,75 @@ +import plotly.graph_objects as go +import numpy as np + + +x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] +x_rev = x[::-1] + +# Line 1 +y1 = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] +y1_upper = [2, 3, 4, 5, 6, 7, 8, 9, 10, 11] +y1_lower = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] +y1_lower = y1_lower[::-1] + +# Line 2 +y2 = [5, 2.5, 5, 7.5, 5, 2.5, 7.5, 4.5, 5.5, 5] +y2_upper = [5.5, 3, 5.5, 8, 6, 3, 8, 5, 6, 5.5] +y2_lower = [4.5, 2, 4.4, 7, 4, 2, 7, 4, 5, 4.75] +y2_lower = y2_lower[::-1] + +# Line 3 +y3 = [10, 8, 6, 4, 2, 0, 2, 4, 2, 0] +y3_upper = [11, 9, 7, 5, 3, 1, 3, 5, 3, 1] +y3_lower = [9, 7, 5, 3, 1, -.5, 1, 3, 1, -1] +y3_lower = y3_lower[::-1] + + +fig = go.Figure() + +fig.add_trace(go.Scatter( + x=x+x_rev, + y=y1_upper+y1_lower, + # fill='toself', + # fillcolor='rgba(0,100,80,0.2)', + # line_color='rgba(255,255,255,0)', + # showlegend=False, + name='Fair', +)) +# fig.add_trace(go.Scatter( +# x=x+x_rev, +# y=y2_upper+y2_lower, +# fill='toself', +# fillcolor='rgba(0,176,246,0.2)', +# line_color='rgba(255,255,255,0)', +# name='Premium', +# showlegend=False, +# )) +# fig.add_trace(go.Scatter( +# x=x+x_rev, +# y=y3_upper+y3_lower, +# fill='toself', +# fillcolor='rgba(231,107,243,0.2)', +# line_color='rgba(255,255,255,0)', +# showlegend=False, +# name='Ideal', +# )) +fig.add_trace(go.Scatter( + x=x, y=y1, + line_color='rgb(0,100,80)', + name='Fair', +)) +# fig.add_trace(go.Scatter( +# x=x, y=y2, +# line_color='rgb(0,176,246)', +# name='Premium', +# )) +# fig.add_trace(go.Scatter( +# x=x, y=y3, +# line_color='rgb(231,107,243)', +# name='Ideal', +# )) + +fig.update_traces(mode='lines') +fig.show() +print(x+x_rev) +print(y1_upper+y1_lower) diff --git a/dryvr_plus_plus/plotter/plotter2D.py b/dryvr_plus_plus/plotter/plotter2D.py deleted file mode 100644 index a43045dd41bc6059b132d7f93da21743029a80ab..0000000000000000000000000000000000000000 --- a/dryvr_plus_plus/plotter/plotter2D.py +++ /dev/null @@ -1,192 +0,0 @@ -""" -This file consist main plotter code for DryVR reachtube output -""" - -import matplotlib.patches as patches -import matplotlib.pyplot as plt -import numpy as np -from typing import List -from PIL import Image, ImageDraw -import io - -colors = ['red', 'green', 'blue', 'yellow', 'black'] - -def plot( - data, - x_dim: int = 0, - y_dim_list: List[int] = [1], - color = 'b', - fig = None, - x_lim = None, - y_lim = None -): - if fig is None: - fig = plt.figure() - - ax = fig.gca() - if x_lim is None: - x_lim = ax.get_xlim() - if y_lim is None: - y_lim = ax.get_ylim() - - x_min, x_max = x_lim - y_min, y_max = y_lim - for rect in data: - lb = rect[0] - ub = rect[1] - for y_dim in y_dim_list: - rect_patch = patches.Rectangle((lb[x_dim], lb[y_dim]), ub[x_dim]-lb[x_dim], ub[y_dim]-lb[y_dim], color = color) - ax.add_patch(rect_patch) - x_min = min(lb[x_dim], x_min) - y_min = min(lb[y_dim], y_min) - x_max = max(ub[x_dim], x_max) - y_max = max(ub[y_dim], y_max) - - ax.set_xlim([x_min-1, x_max+1]) - ax.set_ylim([y_min-1, y_max+1]) - return fig, (x_min, x_max), (y_min, y_max) - -def plot_reachtube_tree(root, agent_id, x_dim: int=0, y_dim_list: List[int]=[1], color='b', fig = None, x_lim = None, y_lim = None): - if fig is None: - fig = plt.figure() - - ax = fig.gca() - if x_lim is None: - x_lim = ax.get_xlim() - if y_lim is None: - y_lim = ax.get_ylim() - - queue = [root] - while queue != []: - node = queue.pop(0) - traces = node.trace - trace = traces[agent_id] - data = [] - for i in range(0,len(trace),2): - data.append([trace[i], trace[i+1]]) - fig, x_lim, y_lim = plot(data, x_dim, y_dim_list, color, fig, x_lim, y_lim) - - queue += node.child - - return fig - -def plot_map(map, color = 'b', fig = None, x_lim = None,y_lim = None): - if fig is None: - fig = plt.figure() - - ax = fig.gca() - if x_lim is None: - x_lim = ax.get_xlim() - if y_lim is None: - y_lim = ax.get_ylim() - - for lane_idx in map.lane_dict: - lane = map.lane_dict[lane_idx] - for lane_seg in lane.segment_list: - if lane_seg.type == 'Straight': - start1 = lane_seg.start + lane_seg.width/2 * lane_seg.direction_lateral - end1 = lane_seg.end + lane_seg.width/2 * lane_seg.direction_lateral - ax.plot([start1[0], end1[0]],[start1[1], end1[1]], color) - start2 = lane_seg.start - lane_seg.width/2 * lane_seg.direction_lateral - end2 = lane_seg.end - lane_seg.width/2 * lane_seg.direction_lateral - ax.plot([start2[0], end2[0]],[start2[1], end2[1]], color) - elif lane_seg.type == "Circular": - phase_array = np.linspace(start=lane_seg.start_phase, stop=lane_seg.end_phase, num=100) - r1 = lane_seg.radius - lane_seg.width/2 - x = np.cos(phase_array)*r1 + lane_seg.center[0] - y = np.sin(phase_array)*r1 + lane_seg.center[1] - ax.plot(x,y,color) - - r2 = lane_seg.radius + lane_seg.width/2 - x = np.cos(phase_array)*r2 + lane_seg.center[0] - y = np.sin(phase_array)*r2 + lane_seg.center[1] - ax.plot(x,y,color) - else: - raise ValueError(f'Unknown lane segment type {lane_seg.type}') - return fig - -def plot_simulation_tree(root, agent_id, x_dim: int=0, y_dim_list: List[int]=[1], color='b', fig = None, x_lim = None, y_lim = None): - if fig is None: - fig = plt.figure() - - ax = fig.gca() - if x_lim is None: - x_lim = ax.get_xlim() - if y_lim is None: - y_lim = ax.get_ylim() - - x_min, x_max = x_lim - y_min, y_max = y_lim - - queue = [root] - while queue != []: - node = queue.pop(0) - traces = node.trace - trace = np.array(traces[agent_id]) - for y_dim in y_dim_list: - ax.plot(trace[:,x_dim], trace[:,y_dim], color) - x_min = min(x_min, trace[:,x_dim].min()) - x_max = max(x_max, trace[:,x_dim].max()) - - y_min = min(y_min, trace[:,y_dim].min()) - y_max = max(y_max, trace[:,y_dim].max()) - - queue += node.child - ax.set_xlim([x_min-1, x_max+1]) - ax.set_ylim([y_min-1, y_max+1]) - - return fig - -def generate_simulation_anime(root, map, fig = None): - if fig is None: - fig = plt.figure() - fig = plot_map(map, 'g', fig) - timed_point_dict = {} - stack = [root] - ax = fig.gca() - x_min, x_max = float('inf'), -float('inf') - y_min, y_max = ax.get_ylim() - while stack != []: - node = stack.pop() - traces = node.trace - for agent_id in traces: - trace = traces[agent_id] - color = 'b' - if agent_id == 'car2': - color = 'r' - for i in range(len(trace)): - x_min = min(x_min, trace[i][1]) - x_max = max(x_max, trace[i][1]) - y_min = min(y_min, trace[i][2]) - y_max = max(y_max, trace[i][2]) - if round(trace[i][0],5) not in timed_point_dict: - timed_point_dict[round(trace[i][0],5)] = [(trace[i][1:],color)] - else: - timed_point_dict[round(trace[i][0],5)].append((trace[i][1:],color)) - stack += node.child - - frames = [] - for time_point in timed_point_dict: - point_list = timed_point_dict[time_point] - plt.xlim((x_min-2, x_max+2)) - plt.ylim((y_min-2, y_max+2)) - plot_map(map,color = 'g', fig = fig) - for data in point_list: - point = data[0] - color = data[1] - ax = plt.gca() - ax.plot([point[0]], [point[1]], markerfacecolor = color, markeredgecolor = color, marker = '.', markersize = 20) - x_tail = point[0] - y_tail = point[1] - dx = np.cos(point[2])*point[3] - dy = np.sin(point[2])*point[3] - ax.arrow(x_tail, y_tail, dx, dy, head_width = 1, head_length = 0.5) - plt.pause(0.05) - plt.clf() - # img_buf = io.BytesIO() - # plt.savefig(img_buf, format = 'png') - # im = Image.open(img_buf) - # frames.append(im) - # plt.clf() - # frame_one = frames[0] - # frame_one.save(fn, format = "GIF", append_images = frames, save_all = True, duration = 100, loop = 0)