From 1706c3b5344115ef711cafcee4c2eeda34d52667 Mon Sep 17 00:00:00 2001 From: crides <zhuhaoqing@live.cn> Date: Mon, 18 Jul 2022 16:25:15 -0500 Subject: [PATCH] unsafe: simple simulation tree visualization --- dryvr_plus_plus/plotter/plotter2D.py | 31 ++++++++++++++++++---------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/dryvr_plus_plus/plotter/plotter2D.py b/dryvr_plus_plus/plotter/plotter2D.py index 0f8cd26c..6bd9e996 100644 --- a/dryvr_plus_plus/plotter/plotter2D.py +++ b/dryvr_plus_plus/plotter/plotter2D.py @@ -7,6 +7,7 @@ import numpy as np from math import pi import plotly.graph_objects as go from typing import List, Tuple +from plotly.graph_objs.scatter import Marker import copy # from dryvr_plus_plus.scene_verifier.analysis.analysis_tree_node import AnalysisTreeNode @@ -210,6 +211,7 @@ def reachtube_tree(root, map=None, fig=go.Figure(), x_dim: int = 1, y_dim=2, map while queue != []: node = queue.pop(0) traces = node.trace + print({k: len(v) for k, v in traces.items()}) i = 0 for agent_id in traces: trace = np.array(traces[agent_id]) @@ -260,7 +262,6 @@ def simulation_tree(root, map=None, fig=None, x_dim: int = 1, y_dim=2, map_type= root, agent_id, fig, x_dim, y_dim, scheme_list[i], print_dim_list) i = (i+1) % 12 if scale_type == 'trace': - queue = [root] x_min, x_max = float('inf'), -float('inf') y_min, y_max = float('inf'), -float('inf') i = 0 @@ -280,22 +281,30 @@ def simulation_tree(root, map=None, fig=None, x_dim: int = 1, y_dim=2, map_type= x_max = max(x_max, max(trace[:, x_dim])) y_min = min(y_min, min(trace[:, y_dim])) y_max = max(y_max, max(trace[:, y_dim])) - i = agent_list.index(agent_id) mode_point_color = colors[agent_list.index(agent_id) % 12][0] if previous_mode[agent_id] != node.mode[agent_id]: text_pos, text = get_text_pos(node.mode[agent_id][0]) + texts = [f"{agent_id}: {text}" for _ in trace] + mark_colors = [mode_point_color for _ in trace] + mark_sizes = [0 for _ in trace] + if node.assert_hits != None and agent_id in node.assert_hits: + mark_colors[-1] = "black" + mark_sizes[-1] = 10 + texts[-1] = "BOOM!!!\nAssertions hit:\n" + "\n".join(" " + a for a in node.assert_hits[agent_id]) + marker = Marker(color=mark_colors, size=mark_sizes) fig.add_trace(go.Scatter(x=[trace[0, x_dim]], y=[trace[0, y_dim]], - mode='markers+text', + mode='markers+lines', line_color=mode_point_color, opacity=0.5, - text=str(agent_id)+': ' + text, + text=texts, + marker=marker, textposition=text_pos, textfont=dict( - size=text_size, - color=mode_text_color - ), - showlegend=False, - )) + size=text_size, + color=mode_text_color + ), + showlegend=False, + )) previous_mode[agent_id] = node.mode[agent_id] queue += node.child if scale_type == 'trace': @@ -706,7 +715,7 @@ def reachtube_tree_single(root, agent_id, fig=go.Figure(), x_dim: int = 1, y_dim return fig -def simulation_tree_single(root, agent_id, fig: go.Figure() = go.Figure(), x_dim: int = 1, y_dim: int = 2, color=None, print_dim_list=None): +def simulation_tree_single(root, agent_id, fig: go.Figure = go.Figure(), x_dim: int = 1, y_dim: int = 2, color=None, print_dim_list=None): """It statically shows the simulation traces of one given agent.""" global color_cnt queue = [root] @@ -882,7 +891,7 @@ def draw_map(map, color='rgba(0,0,0,1)', fig: go.Figure() = go.Figure(), fill_ty return fig -def check_dim(num_dim: int, x_dim: int = 1, y_dim: int = 2, print_dim_list: List(int) = None): +def check_dim(num_dim: int, x_dim: int = 1, y_dim: int = 2, print_dim_list: List[int] = None): if x_dim <= 0 or x_dim >= num_dim: raise ValueError(f'wrong x dimension value {x_dim}') if y_dim <= 0 or y_dim >= num_dim: -- GitLab