From 794d02fed7f23494378599920612489548305ed4 Mon Sep 17 00:00:00 2001
From: keyis2 <keyis2@illinois.edu>
Date: Mon, 27 Jun 2022 05:15:20 -0500
Subject: [PATCH] general Cont.

---
 demo/demo1.py                                 |   4 +-
 demo/demo2.py                                 |   5 +-
 demo/example_controller5.py                   |   2 +-
 dryvr_plus_plus/plotter/plotter2D_new.py      | 474 ++++++++++++++++--
 dryvr_plus_plus/plotter/plotter_README.md     |  36 +-
 dryvr_plus_plus/scene_verifier/map/lane.py    |   7 +-
 .../scene_verifier/map/lane_map.py            |  16 +-
 7 files changed, 480 insertions(+), 64 deletions(-)

diff --git a/demo/demo1.py b/demo/demo1.py
index 2fee44d0..53c582dc 100644
--- a/demo/demo1.py
+++ b/demo/demo1.py
@@ -1,7 +1,7 @@
 from dryvr_plus_plus.example.example_agent.car_agent import CarAgent, NPCAgent
 from dryvr_plus_plus.scene_verifier.scenario.scenario import Scenario
 from dryvr_plus_plus.example.example_map.simple_map2 import SimpleMap2, SimpleMap3, SimpleMap5, SimpleMap6
-from dryvr_plus_plus.plotter.plotter2D import *
+from dryvr_plus_plus.plotter.plotter2D_new import *
 from dryvr_plus_plus.example.example_sensor.fake_sensor import FakeSensor2
 
 import matplotlib.pyplot as plt
@@ -75,6 +75,6 @@ if __name__ == "__main__":
     fig = go.Figure()
     for traces in res_list:
         # plotly_map(tmp_map, 'g', fig)
-        fig = plotly_simulation_tree(traces, 'car2', 1, [2], 'r', fig)
+        fig = test_simu_anime(traces, tmp_map, fig, 1, 2, 'lines')
         # fig = plotly_simulation_anime(traces, tmp_map, fig)
     fig.show()
diff --git a/demo/demo2.py b/demo/demo2.py
index 8e2e4cf7..0d189ed1 100644
--- a/demo/demo2.py
+++ b/demo/demo2.py
@@ -2,6 +2,7 @@ from dryvr_plus_plus.example.example_agent.car_agent import CarAgent
 from dryvr_plus_plus.scene_verifier.scenario.scenario import Scenario
 from dryvr_plus_plus.example.example_map.simple_map2 import SimpleMap2, SimpleMap3, SimpleMap5, SimpleMap6
 from dryvr_plus_plus.plotter.plotter2D import *
+from dryvr_plus_plus.plotter.plotter2D_new import *
 from dryvr_plus_plus.example.example_sensor.fake_sensor import FakeSensor2
 import plotly.graph_objects as go
 import matplotlib.pyplot as plt
@@ -82,9 +83,9 @@ if __name__ == "__main__":
     # fig = plotly_reachtube_tree_v2(traces, 'car1', 1, [2], 'blue', fig)
     # fig = plotly_reachtube_tree_v2(traces, 'car2', 1, [2], 'red', fig)
     # fig.show()
-    traces = scenario.verify(20)
+    traces = scenario.simulate(20)
     fig = go.Figure()
-    fig = generate_reachtube_anime(traces, tmp_map, fig)
+    fig = test_simu_anime(traces, tmp_map, fig, 1, 2, 'lines')
     # fig = plotly_simulation_anime(traces, tmp_map, fig)
     # # fig = plotly_reachtube_tree_v2(traces, 'car2', 1, [2], 'red', fig)
     fig.show()
diff --git a/demo/example_controller5.py b/demo/example_controller5.py
index 7ccc13ad..c1272a31 100644
--- a/demo/example_controller5.py
+++ b/demo/example_controller5.py
@@ -55,7 +55,7 @@ def controller(ego: State, other: State, sign: State, lane_map: LaneMap):
         #     output.vehicle_mode = VehicleMode.Accelerate
         # if lane_map.get_longitudinal_position(other.lane_mode, [other.x, other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x, ego.y]) > 5:
         #     output.vehicle_mode = VehicleMode.Accelerate
-        if lane_map.get_speed_limit(ego.lane_mode, [ego.x, ego.y]) > 1:
+        if lane_map.get_speed_limit(ego.lane_mode) > 1:
             output.vehicle_mode = VehicleMode.Accelerate
 
     if ego.vehicle_mode == VehicleMode.SwitchLeft:
diff --git a/dryvr_plus_plus/plotter/plotter2D_new.py b/dryvr_plus_plus/plotter/plotter2D_new.py
index 1d3de003..6c699244 100644
--- a/dryvr_plus_plus/plotter/plotter2D_new.py
+++ b/dryvr_plus_plus/plotter/plotter2D_new.py
@@ -27,7 +27,7 @@ bg_color = ['rgba(31,119,180,1)', 'rgba(255,127,14,0.2)', 'rgba(44,160,44,0.2)',
 color_cnt = 0
 
 
-def general_reachtube_anime(root, map=None, fig=None, x_dim: int = 1, y_dim=2, map_type='lines'):
+def general_reachtube_anime(root, map=None, fig=None, x_dim: int = 1, y_dim: int = 2, map_type='lines'):
     # make figure
     fig_dict = {
         "data": [],
@@ -68,7 +68,7 @@ def general_reachtube_anime(root, map=None, fig=None, x_dim: int = 1, y_dim=2, m
                 #         timed_point_dict[round(trace[i][0], 2)].append(
                 #             trace[i][1:].tolist())
                 time_point = round(trace[i][0], 2)
-                rect = [trace[i][1:].tolist(), trace[i+1][1:].tolist()]
+                rect = [trace[i][0:].tolist(), trace[i+1][0:].tolist()]
                 if time_point not in timed_point_dict:
                     timed_point_dict[time_point] = {agent_id: [rect]}
                 else:
@@ -169,8 +169,6 @@ def general_reachtube_anime(root, map=None, fig=None, x_dim: int = 1, y_dim=2, m
         agent_dict = timed_point_dict[time_point]
         trace_x = []
         trace_y = []
-        trace_theta = []
-        trace_v = []
         for agent_id, rect_list in agent_dict.items():
             for rect in rect_list:
                 trace_x.append((rect[0][x_dim]+rect[1][x_dim])/2)
@@ -299,16 +297,36 @@ def draw_reachtube_tree_v2(root, agent_id, fig=go.Figure(), x_dim: int = 1, y_di
 
 
 def draw_map(map, color='rgba(0,0,0,1)', fig: go.Figure() = go.Figure(), fill_type='lines'):
+    x_min, x_max = float('inf'), -float('inf')
+    y_min, y_max = float('inf'), -float('inf')
+    speed_dict = map.get_all_speed_limit()
+    speed_list = list(filter(None, speed_dict.values()))
+    speed_min = min(speed_list)
+    speed_max = max(speed_list)
+    start_color = [0, 0, 255, 0.2]
+    end_color = [255, 0, 0, 0.2]
+    curr_color = [0, 0, 0, 0]
     for lane_idx in map.lane_dict:
         lane = map.lane_dict[lane_idx]
+        speed_limit = speed_dict[lane_idx]
+        if speed_limit is not None:
+            for j in range(len(curr_color)-1):
+                curr_color[j] += (speed_limit-speed_min)/(speed_max -
+                                                          speed_min)*(end_color[j]-start_color[j])
         for lane_seg in lane.segment_list:
             if lane_seg.type == 'Straight':
                 start1 = lane_seg.start + lane_seg.width/2 * lane_seg.direction_lateral
                 end1 = lane_seg.end + lane_seg.width/2 * lane_seg.direction_lateral
                 start2 = lane_seg.start - lane_seg.width/2 * lane_seg.direction_lateral
                 end2 = lane_seg.end - lane_seg.width/2 * lane_seg.direction_lateral
-                if fill_type == 'lines':
-                    fig.add_trace(go.Scatter(x=[start1[0], end1[0], end2[0], start2[0], start1[0]], y=[start1[1], end1[1], end2[1], start2[1], start1[1]],
+                trace_x = [start1[0], end1[0], end2[0], start2[0], start1[0]]
+                trace_y = [start1[1], end1[1], end2[1], start2[1], start1[1]]
+                x_min = min(x_min, min(trace_x))
+                y_min = min(y_min, min(trace_y))
+                x_max = max(x_max, max(trace_x))
+                y_max = max(y_max, max(trace_y))
+                if fill_type == 'lines' or speed_limit == None:
+                    fig.add_trace(go.Scatter(x=trace_x, y=trace_y,
                                              mode='lines',
                                              line_color=color,
                                              #  fill='toself',
@@ -316,8 +334,18 @@ def draw_map(map, color='rgba(0,0,0,1)', fig: go.Figure() = go.Figure(), fill_ty
                                              showlegend=False,
                                              # text=theta,
                                              name='lines'))
+                elif fill_type == 'detailed' and speed_limit != None:
+                    fig.add_trace(go.Scatter(x=trace_x, y=trace_y,
+                                             mode='lines',
+                                             line_color=color,
+                                             fill='toself',
+                                             fillcolor='rgba' +
+                                             str(tuple(curr_color)),
+                                             showlegend=False,
+                                             # text=theta,
+                                             name='lines'))
                 elif fill_type == 'fill':
-                    fig.add_trace(go.Scatter(x=[start1[0], end1[0], end2[0], start2[0], start1[0]], y=[start1[1], end1[1], end2[1], start2[1], start1[1]],
+                    fig.add_trace(go.Scatter(x=trace_x, y=trace_y,
                                              mode='lines',
                                              line_color=color,
                                              fill='toself',
@@ -342,21 +370,31 @@ def draw_map(map, color='rgba(0,0,0,1)', fig: go.Figure() = go.Figure(), fill_ty
                       lane_seg.center[0]).tolist().reverse()
                 y2 = (np.sin(phase_array)*r2 +
                       lane_seg.center[1]).tolist().reverse()
-                # fig.add_trace(go.Scatter(x=x, y=y,
-                #                          mode='lines',
-                #                          line_color=color,
-                #                          showlegend=False,
-                #                          # text=theta,
-                #                          name='lines'))
+                trace_x = x1+x2+[x1[0]]
+                trace_y = y1+y2+[y1[0]]
+                x_min = min(x_min, min(trace_x))
+                y_min = min(y_min, min(trace_y))
+                x_max = max(x_max, max(trace_x))
+                y_max = max(y_max, max(trace_y))
                 if fill_type == 'lines':
-                    fig.add_trace(go.Scatter(x=x1+x2+[x1[0]], y=y1+y2+[y1[0]],
+                    fig.add_trace(go.Scatter(x=trace_x, y=trace_y,
+                                             mode='lines',
+                                             line_color=color,
+                                             showlegend=False,
+                                             # text=theta,
+                                             name='lines'))
+                elif fill_type == 'detailed' and speed_limit != None:
+                    fig.add_trace(go.Scatter(x=trace_x, y=trace_y,
                                              mode='lines',
                                              line_color=color,
+                                             fill='toself',
+                                             fillcolor='rgba' +
+                                             str(tuple(curr_color)),
                                              showlegend=False,
                                              # text=theta,
                                              name='lines'))
                 elif fill_type == 'fill':
-                    fig.add_trace(go.Scatter(x=x1+x2+[x1[0]], y=y1+y2+[y1[0]],
+                    fig.add_trace(go.Scatter(x=trace_x, y=trace_y,
                                              mode='lines',
                                              line_color=color,
                                              fill='toself',
@@ -365,6 +403,26 @@ def draw_map(map, color='rgba(0,0,0,1)', fig: go.Figure() = go.Figure(), fill_ty
                                              name='lines'))
             else:
                 raise ValueError(f'Unknown lane segment type {lane_seg.type}')
+    fig.add_trace(go.Scatter(x=[0], y=[0],
+                             mode='markers',
+                             # fill='toself',
+                             # fillcolor='rgba'+str(tuple(curr_color)),
+                             marker=dict(
+                                 symbol='square',
+                                 size=16,
+                                 cmax=speed_max,
+                                 cmin=speed_min,
+                                 color='rgba(0,0,0,0)',
+                                 colorbar=dict(
+                                        title="Speed Limit"
+                                 ),
+                                 colorscale=[
+                                     [0, 'rgba'+str(tuple(start_color))], [1, 'rgba'+str(tuple(end_color))]]
+    ),
+        showlegend=False,
+    ))
+    fig.update_xaxes(range=[x_min, x_max])
+    fig.update_yaxes(range=[y_min, y_max])
     return fig
 
 
@@ -488,16 +546,46 @@ def plotly_map(map, color='rgba(0,0,0,1)', fig: go.Figure() = go.Figure()):
     return fig
 
 
-def draw_simulation_tree(root: AnalysisTreeNode, agent_id, fig=None, x_dim: int = 1, y_dim: int = 2, color_id=None, map_type='lines'):
+def draw_simulation_tree(root, map=None, fig=None, x_dim: int = 1, y_dim=2, map_type='lines', scale_type='trace'):
+    fig = draw_map(map=map, fig=fig, fill_type=map_type)
+    # mark
+    agent_list = root.agent.keys()
+    print(agent_list)
+    i = 0
+    for agent_id in agent_list:
+        fig = draw_simulation_tree_single(
+            root, agent_id, fig, x_dim, y_dim, i, map_type)
+        i += 1
+    if scale_type == 'trace':
+        queue = [root]
+        x_min, x_max = float('inf'), -float('inf')
+        y_min, y_max = float('inf'), -float('inf')
+        scale_factor = 0.5
+        while queue != []:
+            node = queue.pop(0)
+            traces = node.trace
+            for agent_id in agent_list:
+                trace = np.array(traces[agent_id])
+                x_min = min(x_min, min(trace[:, x_dim]))
+                x_max = max(x_max, max(trace[:, x_dim]))
+                y_min = min(y_min, min(trace[:, y_dim]))
+                y_max = max(y_max, max(trace[:, y_dim]))
+        fig.update_xaxes(
+            range=[x_min-scale_factor*(x_max-x_min), x_max+scale_factor*(x_max-x_min)])
+        fig.update_yaxes(
+            range=[y_min-scale_factor*(y_max-y_min), y_max+scale_factor*(y_max-y_min)])
+    return fig
+
+
+def draw_simulation_tree_single(root: AnalysisTreeNode, agent_id, fig: go.Figure() = go.Figure(), x_dim: int = 1, y_dim: int = 2, color_id=None, map_type='lines'):
     global color_cnt, bg_color
-    if fig is None:
-        fig = go.Figure()
     fig = draw_map(map=map, fig=fig, fill_type=map_type)
     if color_id is None:
         color_id = color_cnt
     fg_color = ['rgb(31,119,180)', 'rgb(255,127,14)', 'rgb(44,160,44)', 'rgb(214,39,40)', 'rgb(148,103,189)',
                 'rgb(140,86,75)', 'rgb(227,119,194)', 'rgb(127,127,127)', 'rgb(188,189,34)', 'rgb(23,190,207)']
     queue = [root]
+
     while queue != []:
         node = queue.pop(0)
         traces = node.trace
@@ -505,20 +593,20 @@ def draw_simulation_tree(root: AnalysisTreeNode, agent_id, fig=None, x_dim: int
         # [[time,x,y,theta,v]...]
         trace = np.array(traces[agent_id])
         # print(trace)
-        trace_y = trace[:, y_dim].tolist()
-        trace_x = trace[:, x_dim].tolist()
-        trace_x_rev = trace_x[::-1]
-        # print(trace_x)
-        trace_upper = [i+1 for i in trace_y]
-        trace_lower = [i-1 for i in trace_y]
-        trace_lower = trace_lower[::-1]
-        # print(trace_upper)
-        # print(trace[:, y_dim])
-        fig.add_trace(go.Scatter(x=trace_x+trace_x_rev, y=trace_upper+trace_lower,
-                                 fill='toself',
-                                 fillcolor=bg_color[color_id],
-                                 line_color='rgba(255,255,255,0)',
-                                 showlegend=False))
+        # trace_y = trace[:, y_dim].tolist()
+        # trace_x = trace[:, x_dim].tolist()
+        # trace_x_rev = trace_x[::-1]
+        # # print(trace_x)
+        # trace_upper = [i+1 for i in trace_y]
+        # trace_lower = [i-1 for i in trace_y]
+        # trace_lower = trace_lower[::-1]
+        # # print(trace_upper)
+        # # print(trace[:, y_dim])
+        # fig.add_trace(go.Scatter(x=trace_x+trace_x_rev, y=trace_upper+trace_lower,
+        #                          fill='toself',
+        #                          fillcolor=bg_color[color_id],
+        #                          line_color='rgba(255,255,255,0)',
+        #                          showlegend=False))
         fig.add_trace(go.Scatter(x=trace[:, x_dim], y=trace[:, y_dim],
                                  mode='lines',
                                  line_color=fg_color[color_id],
@@ -527,7 +615,6 @@ def draw_simulation_tree(root: AnalysisTreeNode, agent_id, fig=None, x_dim: int
                                  name='lines'))
         color_id = (color_id+1) % 10
         queue += node.child
-    fig.update_traces(mode='lines')
     color_cnt = color_id
     return fig
 
@@ -830,7 +917,7 @@ def draw_simulation_anime(root, map=None, fig=None):
     return fig
 
 
-def general_simu_anime(root, map=None, fig=None, x_dim: int = 1, y_dim=2, map_type='lines'):
+def general_simu_anime(root, map=None, fig=None, x_dim: int = 1, y_dim=2, map_type='lines', scale_type='trace'):
     # make figure
     fig_dict = {
         "data": [],
@@ -865,16 +952,16 @@ def general_simu_anime(root, map=None, fig=None, x_dim: int = 1, y_dim=2, map_ty
                 time_point = round(trace[i][0], 2)
                 if time_point not in timed_point_dict:
                     timed_point_dict[time_point] = [
-                        {agent_id: trace[i][1:].tolist()}]
+                        {agent_id: trace[i][0:].tolist()}]
                 else:
                     init = False
                     for record in timed_point_dict[time_point]:
-                        if list(record.values())[0] == trace[i][1:].tolist():
+                        if list(record.values())[0] == trace[i][0:].tolist():
                             init = True
                             break
                     if init == False:
                         timed_point_dict[time_point].append(
-                            {agent_id: trace[i][1:].tolist()})
+                            {agent_id: trace[i][0:].tolist()})
             time = round(trace[i][0], 2)
         stack += node.child
     # fill in most of layout
@@ -882,10 +969,10 @@ def general_simu_anime(root, map=None, fig=None, x_dim: int = 1, y_dim=2, map_ty
     # print(timed_point_dict.keys())
     duration = int(600/time)
     fig_dict["layout"]["xaxis"] = {
-        "range": [(x_min-10), (x_max+10)],
+        "range": [x_min, x_max],
         "title": "x position"}
     fig_dict["layout"]["yaxis"] = {
-        "range": [(y_min-2), (y_max+2)],
+        "range": [y_min, y_max],
         "title": "y position"}
     fig_dict["layout"]["hovermode"] = "closest"
     fig_dict["layout"]["updatemenus"] = [
@@ -942,10 +1029,10 @@ def general_simu_anime(root, map=None, fig=None, x_dim: int = 1, y_dim=2, map_ty
     for data in point_list:
         trace = list(data.values())[0]
         # print(trace)
-        x_list.append(trace[x_dim - 1])
-        y_list.append(trace[y_dim - 1])
+        x_list.append(trace[x_dim])
+        y_list.append(trace[y_dim])
         text_list.append(
-            ('{:.2f}'.format(trace[x_dim-1]), '{:.2f}'.format(trace[y_dim-1])))
+            ('{:.2f}'.format(trace[x_dim]), '{:.2f}'.format(trace[y_dim])))
     data_dict = {
         "x": x_list,
         "y": y_list,
@@ -977,10 +1064,10 @@ def general_simu_anime(root, map=None, fig=None, x_dim: int = 1, y_dim=2, map_ty
         for data in point_list:
             trace = list(data.values())[0]
             # print(trace)
-            trace_x.append(trace[x_dim-1])
-            trace_y.append(trace[y_dim-1])
+            trace_x.append(trace[x_dim])
+            trace_y.append(trace[y_dim])
             text_list.append(
-                ('{:.2f}'.format(trace[x_dim-1]), '{:.2f}'.format(trace[y_dim-1])))
+                ('{:.2f}'.format(trace[x_dim]), '{:.2f}'.format(trace[y_dim])))
         data_dict = {
             "x": trace_x,
             "y": trace_y,
@@ -1031,7 +1118,8 @@ def general_simu_anime(root, map=None, fig=None, x_dim: int = 1, y_dim=2, map_ty
     fig_dict["layout"]["sliders"] = [sliders_dict]
 
     fig = go.Figure(fig_dict)
-    fig = draw_map(map, 'g', fig, map_type)
+    print(map)
+    fig = draw_map(map, 'rgba(0,0,0,1)', fig, map_type)
     i = 0
     queue = [root]
     previous_mode = {}
@@ -1091,10 +1179,304 @@ def general_simu_anime(root, map=None, fig=None, x_dim: int = 1, y_dim=2, map_ty
                 previous_mode[agent_id] = node.mode[agent_id]
         queue += node.child
     fig.update_traces(showlegend=False)
+    scale_factor = 0.5
+    if scale_type == 'trace':
+        fig.update_xaxes(
+            range=[x_min-scale_factor*(x_max-x_min), x_max+scale_factor*(x_max-x_min)])
+        fig.update_yaxes(
+            range=[y_min-scale_factor*(y_max-y_min), y_max+scale_factor*(y_max-y_min)])
+    # fig.update_annotations(textfont=dict(size=14, color="black"))
+    # print(fig.frames[0].layout["annotations"])
+    return fig
+
+
+def test_simu_anime(root, map=None, fig=None, x_dim: int = 1, y_dim=2, map_type='lines', scale_type='trace'):
+    # make figure
+    fig_dict = {
+        "data": [],
+        "layout": {},
+        "frames": []
+    }
+    # fig = plot_map(map, 'g', fig)
+    timed_point_dict = {}
+    stack = [root]
+    print("plot")
+    # print(root.mode)
+    x_min, x_max = float('inf'), -float('inf')
+    y_min, y_max = float('inf'), -float('inf')
+    # segment_start = set()
+    # previous_mode = {}
+    # for agent_id in root.mode:
+    #     previous_mode[agent_id] = []
+
+    while stack != []:
+        node = stack.pop()
+        traces = node.trace
+        for agent_id in traces:
+            trace = np.array(traces[agent_id])
+            # print(trace)
+            # segment_start.add(round(trace[0][0], 2))
+            for i in range(len(trace)):
+                x_min = min(x_min, trace[i][x_dim])
+                x_max = max(x_max, trace[i][x_dim])
+                y_min = min(y_min, trace[i][y_dim])
+                y_max = max(y_max, trace[i][y_dim])
+                # print(round(trace[i][0], 2))
+                time_point = round(trace[i][0], 2)
+                tmp_trace = trace[i][0:].tolist()
+                if time_point not in timed_point_dict:
+                    timed_point_dict[time_point] = {agent_id: [tmp_trace]}
+                else:
+                    if agent_id not in timed_point_dict[time_point].keys():
+                        timed_point_dict[time_point][agent_id] = [tmp_trace]
+                    elif tmp_trace not in timed_point_dict[time_point][agent_id]:
+                        timed_point_dict[time_point][agent_id].append(
+                            tmp_trace)
+            time = round(trace[i][0], 2)
+        stack += node.child
+    # fill in most of layout
+    # print(segment_start)
+    # print(timed_point_dict.keys())
+    duration = int(300/time)
+    fig_dict["layout"]["xaxis"] = {
+        # "range": [x_min, x_max],
+        "title": "x position"}
+    fig_dict["layout"]["yaxis"] = {
+        # "range": [y_min, y_max],
+        "title": "y position"}
+    fig_dict["layout"]["hovermode"] = "closest"
+    fig_dict["layout"]["updatemenus"] = [
+        {
+            "buttons": [
+                {
+                    "args": [None, {"frame": {"duration": duration, "redraw": False},
+                                    "fromcurrent": True, "transition": {"duration": duration,
+                                                                        "easing": "quadratic-in-out"}}],
+                    "label": "Play",
+                    "method": "animate"
+                },
+                {
+                    "args": [[None], {"frame": {"duration": 0, "redraw": False},
+                                      "mode": "immediate",
+                                      "transition": {"duration": 0}}],
+                    "label": "Pause",
+                    "method": "animate"
+                }
+            ],
+            "direction": "left",
+            "pad": {"r": 10, "t": 87},
+            "showactive": False,
+            "type": "buttons",
+            "x": 0.1,
+            "xanchor": "right",
+            "y": 0,
+            "yanchor": "top"
+        }
+    ]
+    sliders_dict = {
+        "active": 0,
+        "yanchor": "top",
+        "xanchor": "left",
+        "currentvalue": {
+            "font": {"size": 20},
+            "prefix": "time:",
+            "visible": True,
+            "xanchor": "right"
+        },
+        "transition": {"duration": duration, "easing": "cubic-in-out"},
+        "pad": {"b": 10, "t": 50},
+        "len": 0.9,
+        "x": 0.1,
+        "y": 0,
+        "steps": []
+    }
+    # # make data
+    trace_dict = timed_point_dict[0]
+
+    x_list = []
+    y_list = []
+    text_list = []
+    for time_point in timed_point_dict:
+        trace_dict = timed_point_dict[time_point]
+        for agent_id, point_list in trace_dict.items():
+            for point in point_list:
+                x_list.append(point[x_dim])
+                y_list.append(point[y_dim])
+                text_list.append(
+                    ('{:.2f}'.format(point[x_dim]), '{:.2f}'.format(point[y_dim])))
+        data_dict = {
+            "x": x_list,
+            "y": y_list,
+            "mode": "markers",
+            "text": text_list,
+            "textfont": dict(size=14, color="black"),
+            "visible": False,
+            "textposition": "bottom center",
+            # "marker": {
+            #     "sizemode": "area",
+            #     "sizeref": 200000,
+            #     "size": 2
+            # },
+            "name": "Current Position"
+        }
+        fig_dict["data"].append(data_dict)
+    time_list = list(timed_point_dict.keys())
+    agent_list = list(trace_dict.keys())
+    # start = time_list[0]
+    # step = time_list[1]-start
+    trail_limit = min(15, len(time_list))
+    print(agent_list)
+    # make frames
+    for time_point_id in range(trail_limit, len(time_list)):
+        time_point = time_list[time_point_id]
+        frame = {"data": [], "layout": {
+            "annotations": []}, "name": '{:.2f}'.format(time_point)}
+        # todokeyi
+        trail_len = min(time_point_id+1, trail_limit)
+        opacity_step = 1/trail_len
+        size_step = 3/trail_len
+        for agent_id in agent_list:
+            for id in range(0, trail_len):
+                tmp_point_list = timed_point_dict[time_list[time_point_id-id]][agent_id]
+                trace_x = []
+                trace_y = []
+                text_list = []
+                for point in tmp_point_list:
+                    trace_x.append(point[x_dim])
+                    trace_y.append(point[y_dim])
+                    text_list.append(
+                        ('{:.2f}'.format(point[x_dim]), '{:.2f}'.format(point[y_dim])))
+                    # text_list.append(
+                    #     (agent_id, opacity_step*(trail_len-id)))
+                #  print(trace_y)
+                if id == 0:
+                    data_dict = {
+                        "x": trace_x,
+                        "y": trace_y,
+                        "mode": "markers + text",
+                        # "text": [(round(trace_theta[i]/pi*180, 2), round(trace_v[i], 3)) for i in range(len(trace_theta))],
+                        "text": text_list,
+                        "textfont": dict(size=6, color="black"),
+                        "textposition": "bottom center",
+                        "visible": True,
+                        "marker": {
+                            "color": 'Black',
+                            "opacity": opacity_step*(trail_len-id),
+                            # "sizemode": "area",
+                            # "sizeref": 200000,
+                            "size": 6 + size_step*(trail_len-id)
+                        },
+                        "name": "current position",
+                        # "show_legend": False
+                    }
+                else:
+                    data_dict = {
+                        "x": trace_x,
+                        "y": trace_y,
+                        "mode": "markers",
+                        # "text": [(round(trace_theta[i]/pi*180, 2), round(trace_v[i], 3)) for i in range(len(trace_theta))],
+                        # "text": text_list,
+                        # "textfont": dict(size=6, color="black"),
+                        # "textposition": "bottom center",
+                        "visible": True,
+                        "marker": {
+                            "color": 'Black',
+                            "opacity": opacity_step*(trail_len-id),
+                            # "sizemode": "area",
+                            # "sizeref": 200000,
+                            "size": 6 + size_step*(trail_len-id)
+                        },
+                        "name": "current position",
+                        # "show_legend": False
+                    }
+                frame["data"].append(data_dict)
+
+        fig_dict["frames"].append(frame)
+        slider_step = {"args": [
+            [time_point],
+            {"frame": {"duration": duration, "redraw": False},
+             "mode": "immediate",
+             "transition": {"duration": duration}}
+        ],
+            "label": time_point,
+            "method": "animate"}
+        sliders_dict["steps"].append(slider_step)
+        # print(len(frame["layout"]["annotations"]))
+
+    fig_dict["layout"]["sliders"] = [sliders_dict]
+
+    fig = go.Figure(fig_dict)
+    fig = draw_map(map, 'rgba(0,0,0,1)', fig, map_type)
+    i = 0
+    queue = [root]
+    previous_mode = {}
+    agent_list = []
+    for agent_id in root.mode:
+        previous_mode[agent_id] = []
+        agent_list.append(agent_id)
+    text_pos = 'middle center'
+    # while queue != []:
+    #     node = queue.pop(0)
+    #     traces = node.trace
+    #     # print(node.mode)
+    #     # [[time,x,y,theta,v]...]
+    #     i = 0
+    #     for agent_id in traces:
+    #         trace = np.array(traces[agent_id])
+    #         # print(trace)
+    #         trace_y = trace[:, y_dim].tolist()
+    #         trace_x = trace[:, x_dim].tolist()
+    #         # theta = [i/pi*180 for i in trace[:, 3]]
+    #         i = agent_list.index(agent_id)
+    #         color = colors[i % 5]
+    #         fig.add_trace(go.Scatter(x=trace[:, x_dim], y=trace[:, y_dim],
+    #                                  mode='lines',
+    #                                  line_color=color,
+    #                                  text=[('{:.2f}'.format(trace_x[i]), '{:.2f}'.format(
+    #                                      trace_y[i])) for i in range(len(trace_x))],
+    #                                  showlegend=False)
+    #                       #  name='lines')
+    #                       )
+    #         if previous_mode[agent_id] != node.mode[agent_id]:
+    #             veh_mode = node.mode[agent_id][0]
+    #             if veh_mode == 'Normal':
+    #                 text_pos = 'middle center'
+    #             elif veh_mode == 'Brake':
+    #                 text_pos = 'middle left'
+    #             elif veh_mode == 'Accelerate':
+    #                 text_pos = 'middle right'
+    #             elif veh_mode == 'SwitchLeft':
+    #                 text_pos = 'top center'
+    #             elif veh_mode == 'SwitchRight':
+    #                 text_pos = 'bottom center'
+
+    #             fig.add_trace(go.Scatter(x=[trace[0, x_dim]], y=[trace[0, y_dim]],
+    #                                      mode='markers+text',
+    #                                      line_color='rgba(255,255,255,0.3)',
+    #                                      text=str(agent_id)+': ' +
+    #                                      str(node.mode[agent_id][0]),
+    #                                      textposition=text_pos,
+    #                                      textfont=dict(
+    #                 #  family="sans serif",
+    #                 size=10,
+    #                                          color="grey"),
+    #                                      showlegend=False,
+    #                                      ))
+    #             # i += 1
+    #             previous_mode[agent_id] = node.mode[agent_id]
+    #     queue += node.child
+    fig.update_traces(showlegend=False)
+    scale_factor = 0.5
+    if scale_type == 'trace':
+        fig.update_xaxes(
+            range=[x_min-scale_factor*(x_max-x_min), x_max+scale_factor*(x_max-x_min)])
+        fig.update_yaxes(
+            range=[y_min-scale_factor*(y_max-y_min), y_max+scale_factor*(y_max-y_min)])
     # fig.update_annotations(textfont=dict(size=14, color="black"))
     # print(fig.frames[0].layout["annotations"])
     return fig
 
+
 # The 'color' property is a color and may be specified as:
 #       - A hex string (e.g. '#ff0000')
 #       - An rgb/rgba string (e.g. 'rgb(255,0,0)')
diff --git a/dryvr_plus_plus/plotter/plotter_README.md b/dryvr_plus_plus/plotter/plotter_README.md
index 4610b406..d5cc7a84 100644
--- a/dryvr_plus_plus/plotter/plotter_README.md
+++ b/dryvr_plus_plus/plotter/plotter_README.md
@@ -1,7 +1,13 @@
-# Plotly-based Plotter Notes
+# Plotly-based Plotter Development Notes
 
 Now the latest version is placed in plotter2D_new.py. All functions in the plotter2D.py still work. Every function are in developemnt and might change.
 
+## Current work & Todo
+- **Animation with trails** supported in test_simu_anime() and will be tested further.
+- **Modified accelerating mode** modified, will be tested
+- **new quadrotor agent**
+- **different color for segments of trace** ongoing. Color choice ?
+
 ## Functions
 Belows are the functions currently used. Some of the functions in the file are deprecated.
 
@@ -29,11 +35,11 @@ The original version is implemented with rectangle and very inefficient.
 - **x_dim:** the dimension of x coordinate in the trace list of every time step. The Default value is 1.
 - **y_dim:** the dimension of y coordinate in the trace list of every time step. The Default value is 2.
 - **color_id:** a int indicating the color. Now 10 kinds of colors are supported. If it is None, the colors of segments will be auto-assigned. The default value is None.
-- **map_type** the way to draw the map. It should be 'lines' or 'fill'. For the 'lines' mode, map is only drawn by margins of lanes. For the 'fill' mode, the lanes will be filled semitransparent colors.
+- **map_type** the way to draw the map. It should be 'lines' or 'fill' or 'detailed'. For the 'lines' mode, map is only drawn by margins of lanes. For the 'fill' mode, the lanes will be filled semitransparent colors. For the 'detailed' mode, it will vistualize the speed limit information (if exists) by fill the lanes with different colors. The Default value is 'lines'.
 
 #### draw_map(map, color, fig, fill_type)
 
-The genernal static plotter for map. It is called in many functions drawing traces.
+The genernal static plotter for map. It is called in many functions drawing traces, so it is often unnecessary to call it separately.
 
 **parameters:**
 - **map:** the map of the scenario, templates are in dryvr_plus_plus.example.example_map.simple_map2.py.
@@ -50,10 +56,22 @@ The old ungenernal static plotter for map which support visualization of speed l
 - **color** the color of the margin of the lanes, should be a string like 'black' or in rgb/rgba format, like 'rgb(0,0,0)' or 'rgba(0,0,0,1)'. The default value is 'rgba(0,0,0,1)' which is non-transparent black.
 - **fig:** the object of the figure, its type should be plotly.graph_objects.Figure().
 
-#### draw_simulation_tree(root, agent_id, x_dim, y_dim, color, fig):
+#### draw_simulation_tree(root, map, fig, x_dim, y_dim, map_type, scale_type):
 
-The genernal static plotter for reachtube tree. It draws the all traces of reachtubes and the map.
-The original version is implemented with rectangle and very inefficient.
+The genernal static plotter for simulation trees. It draws the traces of agents and map.
+
+**parameters:**
+- **root:** the root node of the trace, should be the return value of Scenario.simulate().
+- **map:** the map of the scenario, templates are in dryvr_plus_plus.example.example_map.simple_map2.py.
+- **fig:** the object of the figure, its type should be plotly.graph_objects.Figure().
+- **x_dim:** the dimension of x coordinate in the trace list of every time step. The Default value is 1.
+- **y_dim:** the dimension of y coordinate in the trace list of every time step. The Default value is 2.
+- **map_type** the way to draw the map. It should be 'lines' or 'fill'. For the 'lines' mode, map is only drawn by margins of lanes. For the 'fill' mode, the lanes will be filled semitransparent colors.
+- **scale_type** the way to draw the map. It should be 'trace' or 'map'. For the 'trace' mode, the plot will be scaled to show all traces. For the 'map' mode, the plot will be scaled to show the whole map. The Default value is 'trace'.
+
+#### draw_simulation_tree_single(root, agent_id, x_dim, y_dim, color_id, fig):
+
+The genernal static plotter for simulation tree. It draws the  traces of one specific agent.
 
 **parameters:**
 - **root:** the root node of the trace, should be the return value of Scenario.simulate().
@@ -62,7 +80,6 @@ The original version is implemented with rectangle and very inefficient.
 - **x_dim:** the dimension of x coordinate in the trace list of every time step. The Default value is 1.
 - **y_dim:** the dimension of y coordinate in the trace list of every time step. The Default value is 2.
 - **color_id:** a int indicating the color. Now 10 kinds of colors are supported. If it is None, the colors of segments will be auto-assigned. The default value is None.
-- **map_type** the way to draw the map. It should be 'lines' or 'fill'. For the 'lines' mode, map is only drawn by margins of lanes. For the 'fill' mode, the lanes will be filled semitransparent colors.
 
 #### draw_simulation_anime(root, map, fig)
 
@@ -73,7 +90,7 @@ The old ungenernal plotter for simulation animation. It draws the all traces and
 - **map:** the map of the scenario, templates are in dryvr_plus_plus.example.example_map.simple_map2.py. It doesn't handle the map with wrong format. Only SimpleMap3_v2() class is supported now.
 - **fig:** the object of the figure, its type should be plotly.graph_objects.Figure().
 
-#### general_simu_anime(root, map, fig, x_dim, y_dim, map_type):
+#### general_simu_anime(root, map, fig, x_dim, y_dim, map_type, scale_type):
 
 The genernal plotter for simulation animation. It draws the all traces and the map. Animation is implemented as points. Since arrow is hard to generalize.
 
@@ -83,5 +100,6 @@ The genernal plotter for simulation animation. It draws the all traces and the m
 - **fig:** the object of the figure, its type should be plotly.graph_objects.Figure().
 - **x_dim:** the dimension of x coordinate in the trace list of every time step. The Default value is 1.
 - **y_dim:** the dimension of y coordinate in the trace list of every time step. The Default value is 2.
-- **map_type** the way to draw the map. It should be 'lines' or 'fill'. For the 'lines' mode, map is only drawn by margins of lanes. For the 'fill' mode, the lanes will be filled semitransparent colors.
+- **map_type** the way to draw the map. It should be 'lines' or 'fill'. For the 'lines' mode, map is only drawn by margins of lanes. For the 'fill' mode, the lanes will be filled semitransparent colors. The Default value is 'lines'.
+- **scale_type** the way to draw the map. It should be 'trace' or 'map'. For the 'trace' mode, the plot will be scaled to show all traces. For the 'map' mode, the plot will be scaled to show the whole map. The Default value is 'trace'.
 
diff --git a/dryvr_plus_plus/scene_verifier/map/lane.py b/dryvr_plus_plus/scene_verifier/map/lane.py
index aca0fa3d..e9d120ef 100644
--- a/dryvr_plus_plus/scene_verifier/map/lane.py
+++ b/dryvr_plus_plus/scene_verifier/map/lane.py
@@ -8,9 +8,10 @@ from dryvr_plus_plus.scene_verifier.map.lane_segment import AbstractLane
 class Lane():
     COMPENSATE = 3
 
-    def __init__(self, id, seg_list: List[AbstractLane]):
+    def __init__(self, id, seg_list: List[AbstractLane], speed_limit=None):
         self.id = id
         self.segment_list: List[AbstractLane] = seg_list
+        self.speed_limit = speed_limit
         self._set_longitudinal_start()
 
     def _set_longitudinal_start(self):
@@ -46,8 +47,10 @@ class Lane():
         longitudinal, lateral = segment.local_coordinates(position)
         return lateral
 
-    def get_speed_limit(self, position: np.ndarray) -> float:
+    def get_speed_limit_old(self, position: np.ndarray) -> float:
         seg_idx, segment = self.get_lane_segment(position)
         longitudinal, lateral = segment.local_coordinates(position)
         return segment.speed_limit_at(longitudinal)
 
+    def get_speed_limit(self):
+        return self.speed_limit
diff --git a/dryvr_plus_plus/scene_verifier/map/lane_map.py b/dryvr_plus_plus/scene_verifier/map/lane_map.py
index b7723fe9..eb75f752 100644
--- a/dryvr_plus_plus/scene_verifier/map/lane_map.py
+++ b/dryvr_plus_plus/scene_verifier/map/lane_map.py
@@ -95,11 +95,23 @@ class LaneMap:
         seg_idx, segment = lane.get_lane_segment(position)
         return segment
 
-    def get_speed_limit(self, lane_idx: str, position: np.ndarray) -> AbstractLane:
+    def get_speed_limit_old(self, lane_idx: str, position: np.ndarray) -> float:
         if not isinstance(position, np.ndarray):
             position = np.array(position)
         lane = self.lane_dict[lane_idx]
-        limit = lane.get_speed_limit(position)
+        limit = lane.get_speed_limit_old(position)
         # print(limit)
         # print(position)
         return limit
+
+    def get_speed_limit(self, lane_idx: str) -> float:
+        lane = self.lane_dict[lane_idx]
+        print(lane.get_speed_limit())
+        return lane.get_speed_limit()
+
+    def get_all_speed_limit(self) -> Dict[str, float]:
+        ret_dict = {}
+        for lane_idx, lane in self.lane_dict.items():
+            ret_dict[lane_idx] = lane.get_speed_limit()
+        print(ret_dict)
+        return ret_dict
-- 
GitLab