plotter2D.py 6.34 KiB
"""
This file consist main plotter code for DryVR reachtube output
"""
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
from typing import List
from PIL import Image, ImageDraw
import io
colors = ['red', 'green', 'blue', 'yellow', 'black']
def plot(
data,
x_dim: int = 0,
y_dim_list: List[int] = [1],
color = 'b',
fig = None,
x_lim = None,
y_lim = None
):
if fig is None:
fig = plt.figure()
ax = fig.gca()
if x_lim is None:
x_lim = ax.get_xlim()
if y_lim is None:
y_lim = ax.get_ylim()
x_min, x_max = x_lim
y_min, y_max = y_lim
for rect in data:
lb = rect[0]
ub = rect[1]
for y_dim in y_dim_list:
rect_patch = patches.Rectangle((lb[x_dim], lb[y_dim]), ub[x_dim]-lb[x_dim], ub[y_dim]-lb[y_dim], color = color)
ax.add_patch(rect_patch)
x_min = min(lb[x_dim], x_min)
y_min = min(lb[y_dim], y_min)
x_max = max(ub[x_dim], x_max)
y_max = max(ub[y_dim], y_max)
ax.set_xlim([x_min-1, x_max+1])
ax.set_ylim([y_min-1, y_max+1])
return fig, (x_min, x_max), (y_min, y_max)
def plot_reachtube_tree(root, agent_id, x_dim: int=0, y_dim_list: List[int]=[1], color='b', fig = None, x_lim = None, y_lim = None):
if fig is None:
fig = plt.figure()
ax = fig.gca()
if x_lim is None:
x_lim = ax.get_xlim()
if y_lim is None:
y_lim = ax.get_ylim()
queue = [root]
while queue != []:
node = queue.pop(0)
traces = node.trace
trace = traces[agent_id]
data = []
for i in range(0,len(trace),2):
data.append([trace[i], trace[i+1]])
fig, x_lim, y_lim = plot(data, x_dim, y_dim_list, color, fig, x_lim, y_lim)
queue += node.child
return fig
def plot_map(map, color = 'b', fig = None, x_lim = None,y_lim = None):
if fig is None:
fig = plt.figure()
ax = fig.gca()
if x_lim is None:
x_lim = ax.get_xlim()
if y_lim is None:
y_lim = ax.get_ylim()
for lane_idx in map.lane_dict:
lane = map.lane_dict[lane_idx]
for lane_seg in lane.segment_list:
if lane_seg.type == 'Straight':
start1 = lane_seg.start + lane_seg.width/2 * lane_seg.direction_lateral
end1 = lane_seg.end + lane_seg.width/2 * lane_seg.direction_lateral
ax.plot([start1[0], end1[0]],[start1[1], end1[1]], color)
start2 = lane_seg.start - lane_seg.width/2 * lane_seg.direction_lateral
end2 = lane_seg.end - lane_seg.width/2 * lane_seg.direction_lateral
ax.plot([start2[0], end2[0]],[start2[1], end2[1]], color)
elif lane_seg.type == "Circular":
phase_array = np.linspace(start=lane_seg.start_phase, stop=lane_seg.end_phase, num=100)
r1 = lane_seg.radius - lane_seg.width/2
x = np.cos(phase_array)*r1 + lane_seg.center[0]
y = np.sin(phase_array)*r1 + lane_seg.center[1]
ax.plot(x,y,color)
r2 = lane_seg.radius + lane_seg.width/2
x = np.cos(phase_array)*r2 + lane_seg.center[0]
y = np.sin(phase_array)*r2 + lane_seg.center[1]
ax.plot(x,y,color)
else:
raise ValueError(f'Unknown lane segment type {lane_seg.type}')
return fig
def plot_simulation_tree(root, agent_id, x_dim: int=0, y_dim_list: List[int]=[1], color='b', fig = None, x_lim = None, y_lim = None):
if fig is None:
fig = plt.figure()
ax = fig.gca()
if x_lim is None:
x_lim = ax.get_xlim()
if y_lim is None:
y_lim = ax.get_ylim()
x_min, x_max = x_lim
y_min, y_max = y_lim
queue = [root]
while queue != []:
node = queue.pop(0)
traces = node.trace
trace = np.array(traces[agent_id])
for y_dim in y_dim_list:
ax.plot(trace[:,x_dim], trace[:,y_dim], color)
x_min = min(x_min, trace[:,x_dim].min())
x_max = max(x_max, trace[:,x_dim].max())
y_min = min(y_min, trace[:,y_dim].min())
y_max = max(y_max, trace[:,y_dim].max())
queue += node.child
ax.set_xlim([x_min-1, x_max+1])
ax.set_ylim([y_min-1, y_max+1])
return fig
def generate_simulation_anime(root, map, fig = None):
if fig is None:
fig = plt.figure()
fig = plot_map(map, 'g', fig)
timed_point_dict = {}
stack = [root]
ax = fig.gca()
x_min, x_max = float('inf'), -float('inf')
y_min, y_max = ax.get_ylim()
while stack != []:
node = stack.pop()
traces = node.trace
for agent_id in traces:
trace = traces[agent_id]
color = 'b'
if agent_id == 'car2':
color = 'r'
for i in range(len(trace)):
x_min = min(x_min, trace[i][1])
x_max = max(x_max, trace[i][1])
y_min = min(y_min, trace[i][2])
y_max = max(y_max, trace[i][2])
if round(trace[i][0],5) not in timed_point_dict:
timed_point_dict[round(trace[i][0],5)] = [(trace[i][1:],color)]
else:
timed_point_dict[round(trace[i][0],5)].append((trace[i][1:],color))
stack += node.child
frames = []
for time_point in timed_point_dict:
point_list = timed_point_dict[time_point]
plt.xlim((x_min-2, x_max+2))
plt.ylim((y_min-2, y_max+2))
plot_map(map,color = 'g', fig = fig)
for data in point_list:
point = data[0]
color = data[1]
ax = plt.gca()
ax.plot([point[0]], [point[1]], markerfacecolor = color, markeredgecolor = color, marker = '.', markersize = 20)
x_tail = point[0]
y_tail = point[1]
dx = np.cos(point[2])*point[3]
dy = np.sin(point[2])*point[3]
ax.arrow(x_tail, y_tail, dx, dy, head_width = 1, head_length = 0.5)
plt.pause(0.05)
plt.clf()
# img_buf = io.BytesIO()
# plt.savefig(img_buf, format = 'png')
# im = Image.open(img_buf)
# frames.append(im)
# plt.clf()
# frame_one = frames[0]
# frame_one.save(fn, format = "GIF", append_images = frames, save_all = True, duration = 100, loop = 0)