Skip to content
Snippets Groups Projects
Commit a7bdfd6f authored by keyis2's avatar keyis2
Browse files

plotly-based visualization

parent f7e6c2b2
No related branches found
No related tags found
No related merge requests found
import matplotlib.pyplot as plt
from dryvr_plus_plus.example.example_sensor.fake_sensor import FakeSensor2
from dryvr_plus_plus.plotter.plotter2D import plot_reachtube_tree, plot_simulation_tree, generate_simulation_anime, plot_map
from dryvr_plus_plus.example.example_map.simple_map2 import SimpleMap3
from dryvr_plus_plus.scene_verifier.scenario.scenario import Scenario
from dryvr_plus_plus.example.example_agent.sign_agent import SignAgent
from dryvr_plus_plus.example.example_agent.car_agent import CarAgent
from enum import Enum, auto
import copy
class LaneObjectMode(Enum):
Vehicle = auto()
Ped = auto() # Pedestrians
......@@ -8,6 +16,7 @@ class LaneObjectMode(Enum):
Signal = auto() # Traffic lights
Obstacle = auto() # Static (to road/lane) obstacles
class VehicleMode(Enum):
Normal = auto()
SwitchLeft = auto()
......@@ -20,6 +29,7 @@ class LaneMode(Enum):
Lane1 = auto()
Lane2 = auto()
class State:
x: float
y: float
......@@ -40,24 +50,25 @@ class State:
# self.lane_mode = lane_mode
# self.obj_mode = obj_mode
def controller(ego: State, other: State, sign: State, lane_map):
output = copy.deepcopy(ego)
if ego.vehicle_mode == VehicleMode.Normal:
if sign.type == LaneObjectMode.Obstacle and sign.x - ego.x < 3 and sign.x - ego.x > 0 and ego.lane_mode == sign.lane_mode:
output.vehicle_mode = VehicleMode.SwitchLeft
return output
if lane_map.get_longitudinal_position(other.lane_mode, [other.x,other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x,ego.y]) > 3 \
and lane_map.get_longitudinal_position(other.lane_mode, [other.x,other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x,ego.y]) < 5 \
and ego.lane_mode == other.lane_mode:
if lane_map.get_longitudinal_position(other.lane_mode, [other.x, other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x, ego.y]) > 3 \
and lane_map.get_longitudinal_position(other.lane_mode, [other.x, other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x, ego.y]) < 5 \
and ego.lane_mode == other.lane_mode:
if lane_map.has_left(ego.lane_mode):
output.vehicle_mode = VehicleMode.SwitchLeft
if lane_map.get_longitudinal_position(other.lane_mode, [other.x,other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x,ego.y]) > 3 \
and lane_map.get_longitudinal_position(other.lane_mode, [other.x,other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x,ego.y]) < 5 \
and ego.lane_mode == other.lane_mode:
if lane_map.get_longitudinal_position(other.lane_mode, [other.x, other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x, ego.y]) > 3 \
and lane_map.get_longitudinal_position(other.lane_mode, [other.x, other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x, ego.y]) < 5 \
and ego.lane_mode == other.lane_mode:
if lane_map.has_right(ego.lane_mode):
output.vehicle_mode = VehicleMode.SwitchRight
if ego.vehicle_mode == VehicleMode.SwitchLeft:
if lane_map.get_lateral_distance(ego.lane_mode, [ego.x, ego.y]) >= 2.5:
if lane_map.get_lateral_distance(ego.lane_mode, [ego.x, ego.y]) >= 2.5:
output.vehicle_mode = VehicleMode.Normal
output.lane_mode = lane_map.left_lane(ego.lane_mode)
if ego.vehicle_mode == VehicleMode.SwitchRight:
......@@ -68,15 +79,6 @@ def controller(ego: State, other: State, sign: State, lane_map):
return output
from dryvr_plus_plus.example.example_agent.car_agent import CarAgent
from dryvr_plus_plus.example.example_agent.sign_agent import SignAgent
from dryvr_plus_plus.scene_verifier.scenario.scenario import Scenario
from dryvr_plus_plus.example.example_map.simple_map2 import SimpleMap3
from dryvr_plus_plus.plotter.plotter2D import plot_reachtube_tree, plot_simulation_tree
from dryvr_plus_plus.example.example_sensor.fake_sensor import FakeSensor2
import matplotlib.pyplot as plt
if __name__ == "__main__":
import sys
input_code_name = sys.argv[0]
......@@ -91,9 +93,9 @@ if __name__ == "__main__":
scenario.set_sensor(FakeSensor2())
scenario.set_init(
[
[[0, -0.2, 0, 1.0],[0.2, 0.2, 0, 1.0]],
[[10, 0, 0, 0.5],[10, 0, 0, 0.5]],
[[20, 0, 0, 0],[20, 0, 0, 0]],
[[0, -0.2, 0, 1.0], [0.2, 0.2, 0, 1.0]],
[[10, 0, 0, 0.5], [10, 0, 0, 0.5]],
[[20, 0, 0, 0], [20, 0, 0, 0]],
],
[
(VehicleMode.Normal, LaneMode.Lane1, LaneObjectMode.Vehicle),
......@@ -106,8 +108,11 @@ if __name__ == "__main__":
# traces = scenario.verify(40)
fig = plt.figure()
fig = plot_map(SimpleMap3(), 'g', fig)
fig = plot_simulation_tree(traces, 'car1', 1, [2], 'b', fig)
fig = plot_simulation_tree(traces, 'car2', 1, [2], 'r', fig)
# fig = plot_reachtube_tree(traces, 'car1', 1, [2], 'b', fig)
# fig = plot_reachtube_tree(traces, 'car2', 1, [2], 'r', fig)
# generate_simulation_anime(traces, SimpleMap3(), fig)
plt.show()
from dash import Dash, dcc, html, Input, Output
import plotly.express as px
app = Dash(__name__)
app.layout = html.Div([
html.H4('Animated GDP and population over decades'),
html.P("Select an animation:"),
dcc.RadioItems(
id='selection',
options=["GDP - Scatter", "Population - Bar"],
value='GDP - Scatter',
),
dcc.Loading(dcc.Graph(id="graph"), type="cube")
])
@app.callback(
Output("graph", "figure"),
Input("selection", "value"))
def display_animated_graph(selection):
df = px.data.gapminder() # replace with your own data source
animations = {
'GDP - Scatter': px.scatter(
df, x="gdpPercap", y="lifeExp", animation_frame="year",
animation_group="country", size="pop", color="continent",
hover_name="country", log_x=True, size_max=55,
range_x=[100, 100000], range_y=[25, 90]),
'Population - Bar': px.bar(
df, x="continent", y="pop", color="continent",
animation_frame="year", animation_group="country",
range_y=[0, 4000000000]),
}
return animations[selection]
app.run_server()
import plotly.graph_objects as go
import pandas as pd
url = "https://raw.githubusercontent.com/plotly/datasets/master/gapminderDataFiveYear.csv"
dataset = pd.read_csv(url)
years = ["1952", "1962", "1967", "1972", "1977", "1982", "1987", "1992", "1997", "2002",
"2007"]
# make list of continents
continents = []
for continent in dataset["continent"]:
if continent not in continents:
continents.append(continent)
# make figure
fig_dict = {
"data": [],
"layout": {},
"frames": []
}
# fill in most of layout
fig_dict["layout"]["xaxis"] = {"range": [30, 85], "title": "Life Expectancy"}
fig_dict["layout"]["yaxis"] = {"title": "GDP per Capita", "type": "log"}
fig_dict["layout"]["hovermode"] = "closest"
fig_dict["layout"]["updatemenus"] = [
{
"buttons": [
{
"args": [None, {"frame": {"duration": 500, "redraw": False},
"fromcurrent": True, "transition": {"duration": 300,
"easing": "quadratic-in-out"}}],
"label": "Play",
"method": "animate"
},
{
"args": [[None], {"frame": {"duration": 0, "redraw": False},
"mode": "immediate",
"transition": {"duration": 0}}],
"label": "Pause",
"method": "animate"
}
],
"direction": "left",
"pad": {"r": 10, "t": 87},
"showactive": False,
"type": "buttons",
"x": 0.1,
"xanchor": "right",
"y": 0,
"yanchor": "top"
}
]
sliders_dict = {
"active": 0,
"yanchor": "top",
"xanchor": "left",
"currentvalue": {
"font": {"size": 20},
"prefix": "Year:",
"visible": False,
"xanchor": "right"
},
"transition": {"duration": 300, "easing": "cubic-in-out"},
"pad": {"b": 10, "t": 50},
"len": 0.9,
"x": 0.1,
"y": 0,
"steps": []
}
# make data
year = 1952
for continent in continents:
dataset_by_year = dataset[dataset["year"] == year]
dataset_by_year_and_cont = dataset_by_year[
dataset_by_year["continent"] == continent]
data_dict = {
"x": list(dataset_by_year_and_cont["lifeExp"]),
"y": list(dataset_by_year_and_cont["gdpPercap"]),
"mode": "lines",
"text": list(dataset_by_year_and_cont["country"]),
"marker": {
"sizemode": "area",
"sizeref": 200000,
"size": list(dataset_by_year_and_cont["pop"])
},
"name": continent
}
fig_dict["data"].append(data_dict)
# make frames
for year in years:
frame = {"data": [], "name": str(year)}
for continent in continents:
dataset_by_year = dataset[dataset["year"] == int(year)]
dataset_by_year_and_cont = dataset_by_year[
dataset_by_year["continent"] == continent]
data_dict = {
"x": list(dataset_by_year_and_cont["lifeExp"]),
"y": list(dataset_by_year_and_cont["gdpPercap"]),
"mode": "lines",
"text": list(dataset_by_year_and_cont["country"]),
"marker": {
"sizemode": "area",
"sizeref": 200000,
"size": list(dataset_by_year_and_cont["pop"])
},
"name": continent
}
frame["data"].append(data_dict)
fig_dict["frames"].append(frame)
slider_step = {"args": [
[year],
{"frame": {"duration": 300, "redraw": False},
"mode": "immediate",
"transition": {"duration": 300}}
],
"label": year,
"method": "animate"}
sliders_dict["steps"].append(slider_step)
fig_dict["layout"]["sliders"] = [sliders_dict]
fig = go.Figure(fig_dict)
fig.show()
import plotly.graph_objects as go
import numpy as np
x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
x_rev = x[::-1]
# Line 1
y1 = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
y1_upper = [2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
y1_lower = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
y1_lower = y1_lower[::-1]
# Line 2
y2 = [5, 2.5, 5, 7.5, 5, 2.5, 7.5, 4.5, 5.5, 5]
y2_upper = [5.5, 3, 5.5, 8, 6, 3, 8, 5, 6, 5.5]
y2_lower = [4.5, 2, 4.4, 7, 4, 2, 7, 4, 5, 4.75]
y2_lower = y2_lower[::-1]
# Line 3
y3 = [10, 8, 6, 4, 2, 0, 2, 4, 2, 0]
y3_upper = [11, 9, 7, 5, 3, 1, 3, 5, 3, 1]
y3_lower = [9, 7, 5, 3, 1, -.5, 1, 3, 1, -1]
y3_lower = y3_lower[::-1]
fig = go.Figure()
fig.add_trace(go.Scatter(
x=x+x_rev,
y=y1_upper+y1_lower,
# fill='toself',
# fillcolor='rgba(0,100,80,0.2)',
# line_color='rgba(255,255,255,0)',
# showlegend=False,
name='Fair',
))
# fig.add_trace(go.Scatter(
# x=x+x_rev,
# y=y2_upper+y2_lower,
# fill='toself',
# fillcolor='rgba(0,176,246,0.2)',
# line_color='rgba(255,255,255,0)',
# name='Premium',
# showlegend=False,
# ))
# fig.add_trace(go.Scatter(
# x=x+x_rev,
# y=y3_upper+y3_lower,
# fill='toself',
# fillcolor='rgba(231,107,243,0.2)',
# line_color='rgba(255,255,255,0)',
# showlegend=False,
# name='Ideal',
# ))
fig.add_trace(go.Scatter(
x=x, y=y1,
line_color='rgb(0,100,80)',
name='Fair',
))
# fig.add_trace(go.Scatter(
# x=x, y=y2,
# line_color='rgb(0,176,246)',
# name='Premium',
# ))
# fig.add_trace(go.Scatter(
# x=x, y=y3,
# line_color='rgb(231,107,243)',
# name='Ideal',
# ))
fig.update_traces(mode='lines')
fig.show()
print(x+x_rev)
print(y1_upper+y1_lower)
"""
This file consist main plotter code for DryVR reachtube output
"""
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
from typing import List
from PIL import Image, ImageDraw
import io
colors = ['red', 'green', 'blue', 'yellow', 'black']
def plot(
data,
x_dim: int = 0,
y_dim_list: List[int] = [1],
color = 'b',
fig = None,
x_lim = None,
y_lim = None
):
if fig is None:
fig = plt.figure()
ax = fig.gca()
if x_lim is None:
x_lim = ax.get_xlim()
if y_lim is None:
y_lim = ax.get_ylim()
x_min, x_max = x_lim
y_min, y_max = y_lim
for rect in data:
lb = rect[0]
ub = rect[1]
for y_dim in y_dim_list:
rect_patch = patches.Rectangle((lb[x_dim], lb[y_dim]), ub[x_dim]-lb[x_dim], ub[y_dim]-lb[y_dim], color = color)
ax.add_patch(rect_patch)
x_min = min(lb[x_dim], x_min)
y_min = min(lb[y_dim], y_min)
x_max = max(ub[x_dim], x_max)
y_max = max(ub[y_dim], y_max)
ax.set_xlim([x_min-1, x_max+1])
ax.set_ylim([y_min-1, y_max+1])
return fig, (x_min, x_max), (y_min, y_max)
def plot_reachtube_tree(root, agent_id, x_dim: int=0, y_dim_list: List[int]=[1], color='b', fig = None, x_lim = None, y_lim = None):
if fig is None:
fig = plt.figure()
ax = fig.gca()
if x_lim is None:
x_lim = ax.get_xlim()
if y_lim is None:
y_lim = ax.get_ylim()
queue = [root]
while queue != []:
node = queue.pop(0)
traces = node.trace
trace = traces[agent_id]
data = []
for i in range(0,len(trace),2):
data.append([trace[i], trace[i+1]])
fig, x_lim, y_lim = plot(data, x_dim, y_dim_list, color, fig, x_lim, y_lim)
queue += node.child
return fig
def plot_map(map, color = 'b', fig = None, x_lim = None,y_lim = None):
if fig is None:
fig = plt.figure()
ax = fig.gca()
if x_lim is None:
x_lim = ax.get_xlim()
if y_lim is None:
y_lim = ax.get_ylim()
for lane_idx in map.lane_dict:
lane = map.lane_dict[lane_idx]
for lane_seg in lane.segment_list:
if lane_seg.type == 'Straight':
start1 = lane_seg.start + lane_seg.width/2 * lane_seg.direction_lateral
end1 = lane_seg.end + lane_seg.width/2 * lane_seg.direction_lateral
ax.plot([start1[0], end1[0]],[start1[1], end1[1]], color)
start2 = lane_seg.start - lane_seg.width/2 * lane_seg.direction_lateral
end2 = lane_seg.end - lane_seg.width/2 * lane_seg.direction_lateral
ax.plot([start2[0], end2[0]],[start2[1], end2[1]], color)
elif lane_seg.type == "Circular":
phase_array = np.linspace(start=lane_seg.start_phase, stop=lane_seg.end_phase, num=100)
r1 = lane_seg.radius - lane_seg.width/2
x = np.cos(phase_array)*r1 + lane_seg.center[0]
y = np.sin(phase_array)*r1 + lane_seg.center[1]
ax.plot(x,y,color)
r2 = lane_seg.radius + lane_seg.width/2
x = np.cos(phase_array)*r2 + lane_seg.center[0]
y = np.sin(phase_array)*r2 + lane_seg.center[1]
ax.plot(x,y,color)
else:
raise ValueError(f'Unknown lane segment type {lane_seg.type}')
return fig
def plot_simulation_tree(root, agent_id, x_dim: int=0, y_dim_list: List[int]=[1], color='b', fig = None, x_lim = None, y_lim = None):
if fig is None:
fig = plt.figure()
ax = fig.gca()
if x_lim is None:
x_lim = ax.get_xlim()
if y_lim is None:
y_lim = ax.get_ylim()
x_min, x_max = x_lim
y_min, y_max = y_lim
queue = [root]
while queue != []:
node = queue.pop(0)
traces = node.trace
trace = np.array(traces[agent_id])
for y_dim in y_dim_list:
ax.plot(trace[:,x_dim], trace[:,y_dim], color)
x_min = min(x_min, trace[:,x_dim].min())
x_max = max(x_max, trace[:,x_dim].max())
y_min = min(y_min, trace[:,y_dim].min())
y_max = max(y_max, trace[:,y_dim].max())
queue += node.child
ax.set_xlim([x_min-1, x_max+1])
ax.set_ylim([y_min-1, y_max+1])
return fig
def generate_simulation_anime(root, map, fig = None):
if fig is None:
fig = plt.figure()
fig = plot_map(map, 'g', fig)
timed_point_dict = {}
stack = [root]
ax = fig.gca()
x_min, x_max = float('inf'), -float('inf')
y_min, y_max = ax.get_ylim()
while stack != []:
node = stack.pop()
traces = node.trace
for agent_id in traces:
trace = traces[agent_id]
color = 'b'
if agent_id == 'car2':
color = 'r'
for i in range(len(trace)):
x_min = min(x_min, trace[i][1])
x_max = max(x_max, trace[i][1])
y_min = min(y_min, trace[i][2])
y_max = max(y_max, trace[i][2])
if round(trace[i][0],5) not in timed_point_dict:
timed_point_dict[round(trace[i][0],5)] = [(trace[i][1:],color)]
else:
timed_point_dict[round(trace[i][0],5)].append((trace[i][1:],color))
stack += node.child
frames = []
for time_point in timed_point_dict:
point_list = timed_point_dict[time_point]
plt.xlim((x_min-2, x_max+2))
plt.ylim((y_min-2, y_max+2))
plot_map(map,color = 'g', fig = fig)
for data in point_list:
point = data[0]
color = data[1]
ax = plt.gca()
ax.plot([point[0]], [point[1]], markerfacecolor = color, markeredgecolor = color, marker = '.', markersize = 20)
x_tail = point[0]
y_tail = point[1]
dx = np.cos(point[2])*point[3]
dy = np.sin(point[2])*point[3]
ax.arrow(x_tail, y_tail, dx, dy, head_width = 1, head_length = 0.5)
plt.pause(0.05)
plt.clf()
# img_buf = io.BytesIO()
# plt.savefig(img_buf, format = 'png')
# im = Image.open(img_buf)
# frames.append(im)
# plt.clf()
# frame_one = frames[0]
# frame_one.save(fn, format = "GIF", append_images = frames, save_all = True, duration = 100, loop = 0)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment