Skip to content
Snippets Groups Projects
Commit 1706c3b5 authored by crides's avatar crides
Browse files

unsafe: simple simulation tree visualization

parent db4644ae
No related branches found
No related tags found
2 merge requests!7Keyi tmp,!1Merge refactor to main
......@@ -7,6 +7,7 @@ import numpy as np
from math import pi
import plotly.graph_objects as go
from typing import List, Tuple
from plotly.graph_objs.scatter import Marker
import copy
# from dryvr_plus_plus.scene_verifier.analysis.analysis_tree_node import AnalysisTreeNode
......@@ -210,6 +211,7 @@ def reachtube_tree(root, map=None, fig=go.Figure(), x_dim: int = 1, y_dim=2, map
while queue != []:
node = queue.pop(0)
traces = node.trace
print({k: len(v) for k, v in traces.items()})
i = 0
for agent_id in traces:
trace = np.array(traces[agent_id])
......@@ -260,7 +262,6 @@ def simulation_tree(root, map=None, fig=None, x_dim: int = 1, y_dim=2, map_type=
root, agent_id, fig, x_dim, y_dim, scheme_list[i], print_dim_list)
i = (i+1) % 12
if scale_type == 'trace':
queue = [root]
x_min, x_max = float('inf'), -float('inf')
y_min, y_max = float('inf'), -float('inf')
i = 0
......@@ -280,22 +281,30 @@ def simulation_tree(root, map=None, fig=None, x_dim: int = 1, y_dim=2, map_type=
x_max = max(x_max, max(trace[:, x_dim]))
y_min = min(y_min, min(trace[:, y_dim]))
y_max = max(y_max, max(trace[:, y_dim]))
i = agent_list.index(agent_id)
mode_point_color = colors[agent_list.index(agent_id) % 12][0]
if previous_mode[agent_id] != node.mode[agent_id]:
text_pos, text = get_text_pos(node.mode[agent_id][0])
texts = [f"{agent_id}: {text}" for _ in trace]
mark_colors = [mode_point_color for _ in trace]
mark_sizes = [0 for _ in trace]
if node.assert_hits != None and agent_id in node.assert_hits:
mark_colors[-1] = "black"
mark_sizes[-1] = 10
texts[-1] = "BOOM!!!\nAssertions hit:\n" + "\n".join(" " + a for a in node.assert_hits[agent_id])
marker = Marker(color=mark_colors, size=mark_sizes)
fig.add_trace(go.Scatter(x=[trace[0, x_dim]], y=[trace[0, y_dim]],
mode='markers+text',
mode='markers+lines',
line_color=mode_point_color,
opacity=0.5,
text=str(agent_id)+': ' + text,
text=texts,
marker=marker,
textposition=text_pos,
textfont=dict(
size=text_size,
color=mode_text_color
),
showlegend=False,
))
size=text_size,
color=mode_text_color
),
showlegend=False,
))
previous_mode[agent_id] = node.mode[agent_id]
queue += node.child
if scale_type == 'trace':
......@@ -706,7 +715,7 @@ def reachtube_tree_single(root, agent_id, fig=go.Figure(), x_dim: int = 1, y_dim
return fig
def simulation_tree_single(root, agent_id, fig: go.Figure() = go.Figure(), x_dim: int = 1, y_dim: int = 2, color=None, print_dim_list=None):
def simulation_tree_single(root, agent_id, fig: go.Figure = go.Figure(), x_dim: int = 1, y_dim: int = 2, color=None, print_dim_list=None):
"""It statically shows the simulation traces of one given agent."""
global color_cnt
queue = [root]
......@@ -882,7 +891,7 @@ def draw_map(map, color='rgba(0,0,0,1)', fig: go.Figure() = go.Figure(), fill_ty
return fig
def check_dim(num_dim: int, x_dim: int = 1, y_dim: int = 2, print_dim_list: List(int) = None):
def check_dim(num_dim: int, x_dim: int = 1, y_dim: int = 2, print_dim_list: List[int] = None):
if x_dim <= 0 or x_dim >= num_dim:
raise ValueError(f'wrong x dimension value {x_dim}')
if y_dim <= 0 or y_dim >= num_dim:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment