diff --git a/demo.py b/demo.py index 827c97bf6aace862aef68a55868e3a5339115e9e..c0a62d944dda4dd3f13f2920330e195483d36223 100644 --- a/demo.py +++ b/demo.py @@ -54,13 +54,13 @@ if __name__ == "__main__": res_list = scenario.simulate_multi(40,1) # traces = scenario.verify(40) - fig = plt.figure(2) - fig,x_lim,y_lim = plot_map(tmp_map, 'g', fig) + # fig = plt.figure(2) + # fig,x_lim,y_lim = plot_map(tmp_map, 'g', fig) # fig = plot_reachtube_tree(traces, 'car1', 1, [2], 'b', fig) # fig = plot_reachtube_tree(traces, 'car2', 1, [2], 'r', fig) for traces in res_list: - fig,x_lim,y_lim = plot_simulation_tree(traces, 'car1', 1, [2], 'b', fig,x_lim,y_lim) - fig,x_lim,y_lim = plot_simulation_tree(traces, 'car2', 1, [2], 'r', fig,x_lim,y_lim) - + generate_simulation_anime(traces, tmp_map) + # fig,x_lim,y_lim = plot_simulation_tree(traces, 'car1', 1, [2], 'b', fig,x_lim,y_lim) + # fig,x_lim,y_lim = plot_simulation_tree(traces, 'car2', 1, [2], 'r', fig,x_lim,y_lim) plt.show() diff --git a/demo2.py b/demo2.py index d49571b53d472405d69fc53946fe652ce1a26e2e..4ec0bec812a8615dcfc1353dd710cce679a32b44 100644 --- a/demo2.py +++ b/demo2.py @@ -51,16 +51,18 @@ if __name__ == "__main__": (VehicleMode.Normal, LaneMode.Lane1), ] ) - res_list = scenario.simulate_multi(10,10) + res_list = scenario.simulate_multi(10,1) # traces = scenario.verify(10) - fig = plt.figure(2) + # fig = plt.figure(2) # fig,x_lim,y_lim = plot_map(tmp_map, 'g', fig) # fig,x_lim,y_lim = plot_reachtube_tree(traces, 'car1', 0, [1], 'b', fig) # fig,x_lim,y_lim = plot_reachtube_tree(traces, 'car2', 0, [1], 'r', fig,x_lim,y_lim) for traces in res_list: - fig,x_lim,y_lim = plot_simulation_tree(traces, 'car1', 0, [1], 'b', fig) - fig,x_lim,y_lim = plot_simulation_tree(traces, 'car2', 0, [1], 'r', fig, x_lim, y_lim) + # fig,x_lim,y_lim = plot_simulation_tree(traces, 'car1', 0, [1], 'b', fig) + # fig,x_lim,y_lim = plot_simulation_tree(traces, 'car2', 0, [1], 'r', fig, x_lim, y_lim) + generate_simulation_anime(traces, tmp_map) + plt.show() diff --git a/example_controller1.py b/example_controller1.py index 788d4b77980e7a691bd79cb0653f72122cc80931..5a870473b884ae50897f8fc7d4ea6fb409cd9216 100644 --- a/example_controller1.py +++ b/example_controller1.py @@ -2,6 +2,13 @@ from enum import Enum, auto import copy from src.scene_verifier.map.lane_map import LaneMap +class LaneObjectMode(Enum): + Vehicle = auto() + Ped = auto() # Pedestrians + Sign = auto() # Signs, stop signs, merge, yield etc. + Signal = auto() # Traffic lights + Obstacle = auto() # Static (to road/lane) obstacles + class VehicleMode(Enum): Normal = auto() SwitchLeft = auto() diff --git a/example_controller3.py b/example_controller3.py new file mode 100644 index 0000000000000000000000000000000000000000..504995e3601fa244237c2c91e5edfe4e6d8f3716 --- /dev/null +++ b/example_controller3.py @@ -0,0 +1,45 @@ +from enum import Enum, auto +import copy +from src.scene_verifier.map.lane_map import LaneMap + +class VehicleMode(Enum): + Normal = auto() + SwitchLeft = auto() + SwitchRight = auto() + Brake = auto() + +class LaneMode(Enum): + Lane0 = auto() + Lane1 = auto() + Lane2 = auto() + +class State: + x = 0.0 + y = 0.0 + theta = 0.0 + v = 0.0 + vehicle_mode: VehicleMode = VehicleMode.Normal + lane_mode: LaneMode = LaneMode.Lane0 + obj_mode: LaneObjectMode + + def __init__(self, x, y, theta, v, vehicle_mode: VehicleMode, lane_mode: LaneMode): + self.data = [] + +def controller(ego:State, other:State, sign:State, lane_map:LaneMap): + output = copy.deepcopy(ego) + if ego.vehicle_mode == VehicleMode.Normal: + if sign.x - ego.x < 3 and sign.x - ego.x > 0 and ego.lane_mode == sign.lane_mode: + output.vehicle_mode = VehicleMode.Brake + return output + + if lane_map.get_longitudinal_position(other.lane_mode, [other.x,other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x,ego.y]) > 0 \ + and lane_map.get_longitudinal_position(other.lane_mode, [other.x,other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x,ego.y]) < 3 \ + and ego.lane_mode == other.lane_mode: + output.vehicle_mode = VehicleMode.Brake + elif ego.vehicle_mode == VehicleMode.Brake: + if lane_map.get_longitudinal_position(other.lane_mode, [other.x,other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x,ego.y]) > 10 \ + or ego.lane_mode != other.lane_mode: + output.vehicle_mode = VehicleMode.Normal + + return output + diff --git a/example_two_car_sign_lane_switch.py b/example_two_car_sign_lane_switch.py index c0f5a2dfafbfd625d286a87acc192c363b1f8a7e..2593af3736e94a4208cfb72df97a3efe902b9320 100644 --- a/example_two_car_sign_lane_switch.py +++ b/example_two_car_sign_lane_switch.py @@ -81,7 +81,7 @@ if __name__ == "__main__": input_code_name = sys.argv[0] scenario = Scenario() - car = CarAgent('car1', file_name=input_code_name) + car = CarAgent('car1', file_name="example_controller3.py") scenario.add_agent(car) car = CarAgent('car2', file_name=input_code_name) scenario.add_agent(car) @@ -90,23 +90,23 @@ if __name__ == "__main__": scenario.set_sensor(FakeSensor2()) scenario.set_init( [ - [[10, -3, 0, 0.5],[10, -3, 0, 0.5]], - [[-0.2, -0.2, 0, 1.0],[-0.2, -0.2, 0, 1.0]], - [[20, 3, 0, 0],[20, 3, 0, 0]], + [[10, 0, 0, 0.5],[10, 0, 0, 0.5]], + [[0, -0.2, 0, 1.0],[0.2, -0.2, 0, 1.0]], + [[20, 0, 0, 0],[20, 0, 0, 0]], ], [ - (VehicleMode.Normal, LaneMode.Lane2), + (VehicleMode.Normal, LaneMode.Lane1), (VehicleMode.Normal, LaneMode.Lane1), (VehicleMode.Normal, LaneMode.Lane1), ] ) # simulator = Simulator() - # traces = scenario.simulate(40) - traces = scenario.verify(40) + traces = scenario.simulate(40) + # traces = scenario.verify(40) fig = plt.figure() - fig, xlim, ylim = plot_reachtube_tree(traces, 'car1', 1, [2], 'b', fig) - fig, xlim, ylim = plot_reachtube_tree(traces, 'car2', 1, [2], 'r', fig, xlim, ylim) + fig, xlim, ylim = plot_simulation_tree(traces, 'car1', 1, [2], 'b', fig) + fig, xlim, ylim = plot_simulation_tree(traces, 'car2', 1, [2], 'r', fig, xlim, ylim) plt.show() diff --git a/src/plotter/plotter2D.py b/src/plotter/plotter2D.py index f1722480acc43f7fc6badabcd461350d6d9891bb..a89c77c1357d47e79b1b7b8219e9214e7451791a 100644 --- a/src/plotter/plotter2D.py +++ b/src/plotter/plotter2D.py @@ -6,6 +6,8 @@ import matplotlib.patches as patches import matplotlib.pyplot as plt import numpy as np from typing import List +from PIL import Image, ImageDraw +import io colors = ['red', 'green', 'blue', 'yellow', 'black'] @@ -99,3 +101,54 @@ def plot_simulation_tree(root, agent_id, x_dim: int=0, y_dim_list: List[int]=[1] ax.set_ylim([y_min-1, y_max+1]) return fig, ax.get_xlim(), ax.get_ylim() + +def generate_simulation_anime(root, map): + timed_point_dict = {} + stack = [root] + x_min, x_max = float('inf'), -float('inf') + y_min, y_max = float('inf'), -float('inf') + while stack != []: + node = stack.pop() + traces = node.trace + for agent_id in traces: + trace = traces[agent_id] + color = 'b' + if agent_id == 'car2': + color = 'r' + for i in range(len(trace)): + x_min = min(x_min, trace[i][1]) + x_max = max(x_max, trace[i][1]) + y_min = min(y_min, trace[i][2]) + y_max = max(y_max, trace[i][2]) + if round(trace[i][0],5) not in timed_point_dict: + timed_point_dict[round(trace[i][0],5)] = [(trace[i][1:],color)] + else: + timed_point_dict[round(trace[i][0],5)].append((trace[i][1:],color)) + stack += node.child + + frames = [] + fig = plt.figure() + for time_point in timed_point_dict: + point_list = timed_point_dict[time_point] + plt.xlim((x_min-1, x_max+1)) + plt.ylim((y_min-1, y_max+1)) + plot_map(map,color = 'g', fig = fig) + for data in point_list: + point = data[0] + color = data[1] + ax = plt.gca() + ax.plot([point[0]], [point[1]], markerfacecolor = color, markeredgecolor = color, marker = '.', markersize = 20) + x_tail = point[0] + y_tail = point[1] + dx = np.cos(point[2])*point[3] + dy = np.sin(point[2])*point[3] + ax.arrow(x_tail, y_tail, dx, dy, head_width = 1, head_length = 0.5) + plt.pause(0.05) + plt.clf() + # img_buf = io.BytesIO() + # plt.savefig(img_buf, format = 'png') + # im = Image.open(img_buf) + # frames.append(im) + # plt.clf() + # frame_one = frames[0] + # frame_one.save(fn, format = "GIF", append_images = frames, save_all = True, duration = 100, loop = 0)