Skip to content
Snippets Groups Projects
Commit 1706c3b5 authored by crides's avatar crides
Browse files

unsafe: simple simulation tree visualization

parent db4644ae
No related branches found
No related tags found
2 merge requests!7Keyi tmp,!1Merge refactor to main
...@@ -7,6 +7,7 @@ import numpy as np ...@@ -7,6 +7,7 @@ import numpy as np
from math import pi from math import pi
import plotly.graph_objects as go import plotly.graph_objects as go
from typing import List, Tuple from typing import List, Tuple
from plotly.graph_objs.scatter import Marker
import copy import copy
# from dryvr_plus_plus.scene_verifier.analysis.analysis_tree_node import AnalysisTreeNode # from dryvr_plus_plus.scene_verifier.analysis.analysis_tree_node import AnalysisTreeNode
...@@ -210,6 +211,7 @@ def reachtube_tree(root, map=None, fig=go.Figure(), x_dim: int = 1, y_dim=2, map ...@@ -210,6 +211,7 @@ def reachtube_tree(root, map=None, fig=go.Figure(), x_dim: int = 1, y_dim=2, map
while queue != []: while queue != []:
node = queue.pop(0) node = queue.pop(0)
traces = node.trace traces = node.trace
print({k: len(v) for k, v in traces.items()})
i = 0 i = 0
for agent_id in traces: for agent_id in traces:
trace = np.array(traces[agent_id]) trace = np.array(traces[agent_id])
...@@ -260,7 +262,6 @@ def simulation_tree(root, map=None, fig=None, x_dim: int = 1, y_dim=2, map_type= ...@@ -260,7 +262,6 @@ def simulation_tree(root, map=None, fig=None, x_dim: int = 1, y_dim=2, map_type=
root, agent_id, fig, x_dim, y_dim, scheme_list[i], print_dim_list) root, agent_id, fig, x_dim, y_dim, scheme_list[i], print_dim_list)
i = (i+1) % 12 i = (i+1) % 12
if scale_type == 'trace': if scale_type == 'trace':
queue = [root]
x_min, x_max = float('inf'), -float('inf') x_min, x_max = float('inf'), -float('inf')
y_min, y_max = float('inf'), -float('inf') y_min, y_max = float('inf'), -float('inf')
i = 0 i = 0
...@@ -280,22 +281,30 @@ def simulation_tree(root, map=None, fig=None, x_dim: int = 1, y_dim=2, map_type= ...@@ -280,22 +281,30 @@ def simulation_tree(root, map=None, fig=None, x_dim: int = 1, y_dim=2, map_type=
x_max = max(x_max, max(trace[:, x_dim])) x_max = max(x_max, max(trace[:, x_dim]))
y_min = min(y_min, min(trace[:, y_dim])) y_min = min(y_min, min(trace[:, y_dim]))
y_max = max(y_max, max(trace[:, y_dim])) y_max = max(y_max, max(trace[:, y_dim]))
i = agent_list.index(agent_id)
mode_point_color = colors[agent_list.index(agent_id) % 12][0] mode_point_color = colors[agent_list.index(agent_id) % 12][0]
if previous_mode[agent_id] != node.mode[agent_id]: if previous_mode[agent_id] != node.mode[agent_id]:
text_pos, text = get_text_pos(node.mode[agent_id][0]) text_pos, text = get_text_pos(node.mode[agent_id][0])
texts = [f"{agent_id}: {text}" for _ in trace]
mark_colors = [mode_point_color for _ in trace]
mark_sizes = [0 for _ in trace]
if node.assert_hits != None and agent_id in node.assert_hits:
mark_colors[-1] = "black"
mark_sizes[-1] = 10
texts[-1] = "BOOM!!!\nAssertions hit:\n" + "\n".join(" " + a for a in node.assert_hits[agent_id])
marker = Marker(color=mark_colors, size=mark_sizes)
fig.add_trace(go.Scatter(x=[trace[0, x_dim]], y=[trace[0, y_dim]], fig.add_trace(go.Scatter(x=[trace[0, x_dim]], y=[trace[0, y_dim]],
mode='markers+text', mode='markers+lines',
line_color=mode_point_color, line_color=mode_point_color,
opacity=0.5, opacity=0.5,
text=str(agent_id)+': ' + text, text=texts,
marker=marker,
textposition=text_pos, textposition=text_pos,
textfont=dict( textfont=dict(
size=text_size, size=text_size,
color=mode_text_color color=mode_text_color
), ),
showlegend=False, showlegend=False,
)) ))
previous_mode[agent_id] = node.mode[agent_id] previous_mode[agent_id] = node.mode[agent_id]
queue += node.child queue += node.child
if scale_type == 'trace': if scale_type == 'trace':
...@@ -706,7 +715,7 @@ def reachtube_tree_single(root, agent_id, fig=go.Figure(), x_dim: int = 1, y_dim ...@@ -706,7 +715,7 @@ def reachtube_tree_single(root, agent_id, fig=go.Figure(), x_dim: int = 1, y_dim
return fig return fig
def simulation_tree_single(root, agent_id, fig: go.Figure() = go.Figure(), x_dim: int = 1, y_dim: int = 2, color=None, print_dim_list=None): def simulation_tree_single(root, agent_id, fig: go.Figure = go.Figure(), x_dim: int = 1, y_dim: int = 2, color=None, print_dim_list=None):
"""It statically shows the simulation traces of one given agent.""" """It statically shows the simulation traces of one given agent."""
global color_cnt global color_cnt
queue = [root] queue = [root]
...@@ -882,7 +891,7 @@ def draw_map(map, color='rgba(0,0,0,1)', fig: go.Figure() = go.Figure(), fill_ty ...@@ -882,7 +891,7 @@ def draw_map(map, color='rgba(0,0,0,1)', fig: go.Figure() = go.Figure(), fill_ty
return fig return fig
def check_dim(num_dim: int, x_dim: int = 1, y_dim: int = 2, print_dim_list: List(int) = None): def check_dim(num_dim: int, x_dim: int = 1, y_dim: int = 2, print_dim_list: List[int] = None):
if x_dim <= 0 or x_dim >= num_dim: if x_dim <= 0 or x_dim >= num_dim:
raise ValueError(f'wrong x dimension value {x_dim}') raise ValueError(f'wrong x dimension value {x_dim}')
if y_dim <= 0 or y_dim >= num_dim: if y_dim <= 0 or y_dim >= num_dim:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment