From 1706c3b5344115ef711cafcee4c2eeda34d52667 Mon Sep 17 00:00:00 2001
From: crides <zhuhaoqing@live.cn>
Date: Mon, 18 Jul 2022 16:25:15 -0500
Subject: [PATCH] unsafe: simple simulation tree visualization

---
 dryvr_plus_plus/plotter/plotter2D.py | 31 ++++++++++++++++++----------
 1 file changed, 20 insertions(+), 11 deletions(-)

diff --git a/dryvr_plus_plus/plotter/plotter2D.py b/dryvr_plus_plus/plotter/plotter2D.py
index 0f8cd26c..6bd9e996 100644
--- a/dryvr_plus_plus/plotter/plotter2D.py
+++ b/dryvr_plus_plus/plotter/plotter2D.py
@@ -7,6 +7,7 @@ import numpy as np
 from math import pi
 import plotly.graph_objects as go
 from typing import List, Tuple
+from plotly.graph_objs.scatter import Marker
 import copy
 # from dryvr_plus_plus.scene_verifier.analysis.analysis_tree_node import AnalysisTreeNode
 
@@ -210,6 +211,7 @@ def reachtube_tree(root, map=None, fig=go.Figure(), x_dim: int = 1, y_dim=2, map
     while queue != []:
         node = queue.pop(0)
         traces = node.trace
+        print({k: len(v) for k, v in traces.items()})
         i = 0
         for agent_id in traces:
             trace = np.array(traces[agent_id])
@@ -260,7 +262,6 @@ def simulation_tree(root, map=None, fig=None, x_dim: int = 1, y_dim=2, map_type=
             root, agent_id, fig, x_dim, y_dim, scheme_list[i], print_dim_list)
         i = (i+1) % 12
     if scale_type == 'trace':
-        queue = [root]
         x_min, x_max = float('inf'), -float('inf')
         y_min, y_max = float('inf'), -float('inf')
     i = 0
@@ -280,22 +281,30 @@ def simulation_tree(root, map=None, fig=None, x_dim: int = 1, y_dim=2, map_type=
                 x_max = max(x_max, max(trace[:, x_dim]))
                 y_min = min(y_min, min(trace[:, y_dim]))
                 y_max = max(y_max, max(trace[:, y_dim]))
-            i = agent_list.index(agent_id)
             mode_point_color = colors[agent_list.index(agent_id) % 12][0]
             if previous_mode[agent_id] != node.mode[agent_id]:
                 text_pos, text = get_text_pos(node.mode[agent_id][0])
+                texts = [f"{agent_id}: {text}" for _ in trace]
+                mark_colors = [mode_point_color for _ in trace]
+                mark_sizes = [0 for _ in trace]
+                if node.assert_hits != None and agent_id in node.assert_hits:
+                    mark_colors[-1] = "black"
+                    mark_sizes[-1] = 10
+                    texts[-1] = "BOOM!!!\nAssertions hit:\n" + "\n".join("  " + a for a in node.assert_hits[agent_id])
+                marker = Marker(color=mark_colors, size=mark_sizes)
                 fig.add_trace(go.Scatter(x=[trace[0, x_dim]], y=[trace[0, y_dim]],
-                                         mode='markers+text',
+                                         mode='markers+lines',
                                          line_color=mode_point_color,
                                          opacity=0.5,
-                                         text=str(agent_id)+': ' + text,
+                                         text=texts,
+                                         marker=marker,
                                          textposition=text_pos,
                                          textfont=dict(
-                    size=text_size,
-                    color=mode_text_color
-                ),
-                    showlegend=False,
-                ))
+                                             size=text_size,
+                                             color=mode_text_color
+                                         ),
+                                         showlegend=False,
+                                         ))
                 previous_mode[agent_id] = node.mode[agent_id]
         queue += node.child
     if scale_type == 'trace':
@@ -706,7 +715,7 @@ def reachtube_tree_single(root, agent_id, fig=go.Figure(), x_dim: int = 1, y_dim
     return fig
 
 
-def simulation_tree_single(root, agent_id, fig: go.Figure() = go.Figure(), x_dim: int = 1, y_dim: int = 2, color=None, print_dim_list=None):
+def simulation_tree_single(root, agent_id, fig: go.Figure = go.Figure(), x_dim: int = 1, y_dim: int = 2, color=None, print_dim_list=None):
     """It statically shows the simulation traces of one given agent."""
     global color_cnt
     queue = [root]
@@ -882,7 +891,7 @@ def draw_map(map, color='rgba(0,0,0,1)', fig: go.Figure() = go.Figure(), fill_ty
     return fig
 
 
-def check_dim(num_dim: int, x_dim: int = 1, y_dim: int = 2, print_dim_list: List(int) = None):
+def check_dim(num_dim: int, x_dim: int = 1, y_dim: int = 2, print_dim_list: List[int] = None):
     if x_dim <= 0 or x_dim >= num_dim:
         raise ValueError(f'wrong x dimension value {x_dim}')
     if y_dim <= 0 or y_dim >= num_dim:
-- 
GitLab