From f301069d80cb38600a3818cf4a75d1f3b271cb7e Mon Sep 17 00:00:00 2001
From: keyis2 <keyis2@illinois.edu>
Date: Thu, 16 Jun 2022 14:50:16 -0500
Subject: [PATCH] accelerate mode related

---
 demo/demo1.py                                 |   4 +-
 demo/demo2.py                                 |  20 +--
 demo/demo5.py                                 |  86 +++++++++++
 demo/example_controller5.py                   |  77 ++++++++++
 demo/plot1.py                                 |  38 -----
 demo/plot_test.py                             | 113 ++++++++++++--
 demo/plot_test1.py                            |  23 +++
 .../example/example_agent/car_agent.py        |   5 +
 .../example/example_map/simple_map2.py        | 123 ++++++++++-----
 dryvr_plus_plus/plotter/plotter2D.py          | 145 ++++++++++++++----
 .../code_parser/pythonparser.py               |   3 +-
 dryvr_plus_plus/scene_verifier/map/lane.py    |   6 +
 .../scene_verifier/map/lane_map.py            |  10 ++
 .../scene_verifier/map/lane_segment.py        |  93 +++++++++--
 14 files changed, 595 insertions(+), 151 deletions(-)
 create mode 100644 demo/demo5.py
 create mode 100644 demo/example_controller5.py
 delete mode 100644 demo/plot1.py
 create mode 100644 demo/plot_test1.py

diff --git a/demo/demo1.py b/demo/demo1.py
index 9ab0f8a0..2fee44d0 100644
--- a/demo/demo1.py
+++ b/demo/demo1.py
@@ -75,6 +75,6 @@ if __name__ == "__main__":
     fig = go.Figure()
     for traces in res_list:
         # plotly_map(tmp_map, 'g', fig)
-        # fig = plotly_simulation_tree(traces, 'car2', 1, [2], 'r', fig)
-        fig = plotly_simulation_anime(traces, tmp_map, fig)
+        fig = plotly_simulation_tree(traces, 'car2', 1, [2], 'r', fig)
+        # fig = plotly_simulation_anime(traces, tmp_map, fig)
     fig.show()
diff --git a/demo/demo2.py b/demo/demo2.py
index b8eaf9b8..6deab510 100644
--- a/demo/demo2.py
+++ b/demo/demo2.py
@@ -56,7 +56,7 @@ if __name__ == "__main__":
         ]
     )
     # res_list = scenario.simulate_multi(40,1)
-    traces = scenario.verify(40)
+
     # traces = scenario.simulate(40)
 
     # fig = plt.figure(2)
@@ -73,16 +73,18 @@ if __name__ == "__main__":
     # fig = go.Figure()
     # fig = plotly_simulation_tree(traces, 'car1', 1, [2], 'b', fig)
     # fig.show()
+    # traces = scenario.simulate(40)
     # fig = go.Figure()
-    # fig = plotly_simulation_anime(traces, tmp_map, fig)
+    # # fig = plotly_simulation_anime(traces, tmp_map, fig)
     # fig.show()
 
-    fig = go.Figure()
-    fig = plotly_reachtube_tree_v2(traces, 'car1', 1, [2], 'blue', fig)
-    fig = plotly_reachtube_tree_v2(traces, 'car2', 1, [2], 'red', fig)
-    fig.show()
     # fig = go.Figure()
-    # fig = generate_reachtube_anime(
-    #     traces, 'car1', 1, [2], 'blue', fig, map=tmp_map)
-    # # fig = plotly_reachtube_tree_v2(traces, 'car2', 1, [2], 'red', fig)
+    # fig = plotly_reachtube_tree_v2(traces, 'car1', 1, [2], 'blue', fig)
+    # fig = plotly_reachtube_tree_v2(traces, 'car2', 1, [2], 'red', fig)
     # fig.show()
+    traces = scenario.verify(5)
+    fig = go.Figure()
+    fig = generate_reachtube_anime(traces, tmp_map, fig)
+    # fig = plotly_simulation_anime(traces, tmp_map, fig)
+    # # fig = plotly_reachtube_tree_v2(traces, 'car2', 1, [2], 'red', fig)
+    fig.show()
diff --git a/demo/demo5.py b/demo/demo5.py
new file mode 100644
index 00000000..e698e425
--- /dev/null
+++ b/demo/demo5.py
@@ -0,0 +1,86 @@
+from dryvr_plus_plus.example.example_agent.car_agent import CarAgent
+from dryvr_plus_plus.scene_verifier.scenario.scenario import Scenario
+from dryvr_plus_plus.example.example_map.simple_map2 import SimpleMap2, SimpleMap3, SimpleMap5, SimpleMap6, SimpleMap3_v2
+from dryvr_plus_plus.plotter.plotter2D import *
+from dryvr_plus_plus.example.example_sensor.fake_sensor import FakeSensor2
+import plotly.graph_objects as go
+import matplotlib.pyplot as plt
+import numpy as np
+from enum import Enum, auto
+
+
+class VehicleMode(Enum):
+    Normal = auto()
+    SwitchLeft = auto()
+    SwitchRight = auto()
+    Brake = auto()
+    Accelerate = auto()
+
+
+class LaneMode(Enum):
+    Lane0 = auto()
+    Lane1 = auto()
+    Lane2 = auto()
+
+
+class State:
+    x = 0.0
+    y = 0.0
+    theta = 0.0
+    v = 0.0
+    vehicle_mode: VehicleMode = VehicleMode.Normal
+    lane_mode: LaneMode = LaneMode.Lane0
+
+    def __init__(self, x, y, theta, v, vehicle_mode: VehicleMode, lane_mode: LaneMode):
+        self.data = []
+
+
+if __name__ == "__main__":
+    input_code_name = 'example_controller5.py'
+    scenario = Scenario()
+
+    car = CarAgent('car1', file_name=input_code_name)
+    scenario.add_agent(car)
+    car = CarAgent('car2', file_name=input_code_name)
+    scenario.add_agent(car)
+    tmp_map = SimpleMap3_v2()
+    scenario.set_map(tmp_map)
+    scenario.set_sensor(FakeSensor2())
+    scenario.set_init(
+        [
+            [[0, -0.2, 0, 1.0], [0.1, 0.2, 0, 1.0]],
+            [[10, 0, 0, 0.5], [10, 0, 0, 0.5]],
+        ],
+        [
+            (VehicleMode.Normal, LaneMode.Lane1),
+            (VehicleMode.Normal, LaneMode.Lane1),
+        ]
+    )
+    # traces = scenario.verify(30)
+    # # fig = go.Figure()
+    # # fig = plotly_reachtube_tree_v2(traces, 'car1', 1, [2], 'blue', fig)
+    # # fig = plotly_reachtube_tree_v2(traces, 'car2', 1, [2], 'red', fig)
+    # # fig.show()
+    # fig = go.Figure()
+    # fig = generate_reachtube_anime(traces, tmp_map, fig)
+    # # # fig = plotly_reachtube_tree_v2(traces, 'car2', 1, [2], 'red', fig)
+    # fig.show()
+    # fig = plt.figure(2)
+    # fig = plot_map(tmp_map, 'g', fig)
+    # # fig = plot_simulation_tree(traces, 'car1', 1, [2], 'b', fig)
+    # # fig = plot_simulation_tree(traces, 'car2', 1, [2], 'r', fig)
+    # fig = plot_reachtube_tree(traces, 'car1', 1, [2], 'b', fig)
+    # fig = plot_reachtube_tree(traces, 'car2', 1, [2], 'r', fig)
+    # plt.show()
+    # # fig1 = plt.figure(2)
+    # fig = generate_simulation_anime(traces, tmp_map, fig)
+    # plt.show()
+
+    traces = scenario.simulate(30)
+    # fig = go.Figure()
+    # fig = plotly_simulation_tree(traces, 'car1', 1, [2], 'b', fig)
+    # fig.show()
+    fig = go.Figure()
+    fig = plotly_simulation_anime(traces, tmp_map, fig)
+    # fig = plotly_map(tmp_map, fig=fig)
+    fig.show()
diff --git a/demo/example_controller5.py b/demo/example_controller5.py
new file mode 100644
index 00000000..7ccc13ad
--- /dev/null
+++ b/demo/example_controller5.py
@@ -0,0 +1,77 @@
+from enum import Enum, auto
+import copy
+from dryvr_plus_plus.scene_verifier.map.lane_map import LaneMap
+
+
+class VehicleMode(Enum):
+    Normal = auto()
+    SwitchLeft = auto()
+    SwitchRight = auto()
+    Brake = auto()
+    Accelerate = auto()
+
+
+class LaneMode(Enum):
+    Lane0 = auto()
+    Lane1 = auto()
+    Lane2 = auto()
+
+
+class LaneObjectMode(Enum):
+    Vehicle = auto()
+    Ped = auto()        # Pedestrians
+    Sign = auto()       # Signs, stop signs, merge, yield etc.
+    Signal = auto()     # Traffic lights
+    Obstacle = auto()   # Static (to road/lane) obstacles
+
+
+class State:
+    x = 0.0
+    y = 0.0
+    theta = 0.0
+    v = 0.0
+    vehicle_mode: VehicleMode = VehicleMode.Normal
+    lane_mode: LaneMode = LaneMode.Lane0
+    type: LaneObjectMode
+
+    def __init__(self, x, y, theta, v, vehicle_mode: VehicleMode, lane_mode: LaneMode, type: LaneObjectMode):
+        self.data = []
+
+
+def controller(ego: State, other: State, sign: State, lane_map: LaneMap):
+    output = copy.deepcopy(ego)
+    if ego.vehicle_mode == VehicleMode.Normal:
+        if lane_map.get_longitudinal_position(other.lane_mode, [other.x, other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x, ego.y]) > 3 \
+            and lane_map.get_longitudinal_position(other.lane_mode, [other.x, other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x, ego.y]) < 5 \
+                and ego.lane_mode == other.lane_mode:
+            if lane_map.has_left(ego.lane_mode):
+                output.vehicle_mode = VehicleMode.SwitchLeft
+        if lane_map.get_longitudinal_position(other.lane_mode, [other.x, other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x, ego.y]) > 3 \
+            and lane_map.get_longitudinal_position(other.lane_mode, [other.x, other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x, ego.y]) < 5 \
+                and ego.lane_mode == other.lane_mode:
+            if lane_map.has_right(ego.lane_mode):
+                output.vehicle_mode = VehicleMode.SwitchRight
+        # if ego.lane_mode != other.lane_mode:
+        #     output.vehicle_mode = VehicleMode.Accelerate
+        # if lane_map.get_longitudinal_position(other.lane_mode, [other.x, other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x, ego.y]) > 5:
+        #     output.vehicle_mode = VehicleMode.Accelerate
+        if lane_map.get_speed_limit(ego.lane_mode, [ego.x, ego.y]) > 1:
+            output.vehicle_mode = VehicleMode.Accelerate
+
+    if ego.vehicle_mode == VehicleMode.SwitchLeft:
+        if lane_map.get_lateral_distance(ego.lane_mode, [ego.x, ego.y]) >= 2.5:
+            output.vehicle_mode = VehicleMode.Normal
+            output.lane_mode = lane_map.left_lane(ego.lane_mode)
+    if ego.vehicle_mode == VehicleMode.SwitchRight:
+        if lane_map.get_lateral_distance(ego.lane_mode, [ego.x, ego.y]) <= -2.5:
+            output.vehicle_mode = VehicleMode.Normal
+            output.lane_mode = lane_map.right_lane(ego.lane_mode)
+    if ego.vehicle_mode == VehicleMode.Accelerate:
+        if lane_map.get_longitudinal_position(other.lane_mode, [other.x, other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x, ego.y]) > 3 \
+            and lane_map.get_longitudinal_position(other.lane_mode, [other.x, other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x, ego.y]) < 5 \
+                and ego.lane_mode == other.lane_mode:
+            output.vehicle_mode = VehicleMode.Normal
+        # if lane_map.get_speed_limit(ego.lane_mode, [ego.x, ego.y]) <= ego.v:
+        #     output.vehicle_mode = VehicleMode.Normal
+
+    return output
diff --git a/demo/plot1.py b/demo/plot1.py
deleted file mode 100644
index 93733594..00000000
--- a/demo/plot1.py
+++ /dev/null
@@ -1,38 +0,0 @@
-from dash import Dash, dcc, html, Input, Output
-import plotly.express as px
-
-app = Dash(__name__)
-
-
-app.layout = html.Div([
-    html.H4('Animated GDP and population over decades'),
-    html.P("Select an animation:"),
-    dcc.RadioItems(
-        id='selection',
-        options=["GDP - Scatter", "Population - Bar"],
-        value='GDP - Scatter',
-    ),
-    dcc.Loading(dcc.Graph(id="graph"), type="cube")
-])
-
-
-@app.callback(
-    Output("graph", "figure"),
-    Input("selection", "value"))
-def display_animated_graph(selection):
-    df = px.data.gapminder()  # replace with your own data source
-    animations = {
-        'GDP - Scatter': px.scatter(
-            df, x="gdpPercap", y="lifeExp", animation_frame="year",
-            animation_group="country", size="pop", color="continent",
-            hover_name="country", log_x=True, size_max=55,
-            range_x=[100, 100000], range_y=[25, 90]),
-        'Population - Bar': px.bar(
-            df, x="continent", y="pop", color="continent",
-            animation_frame="year", animation_group="country",
-            range_y=[0, 4000000000]),
-    }
-    return animations[selection]
-
-
-app.run_server()
diff --git a/demo/plot_test.py b/demo/plot_test.py
index 66102f97..83bedd06 100644
--- a/demo/plot_test.py
+++ b/demo/plot_test.py
@@ -1,9 +1,13 @@
+from turtle import color
 import plotly.graph_objects as go
 import numpy as np
 
 
 x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
 x_rev = x[::-1]
+x2 = [11, 12, 13, 14, 15, 16, 17, 18, 19, 20]
+x2_rev = x2[::-1]
+
 
 # Line 1
 y1 = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
@@ -26,15 +30,100 @@ y3_lower = y3_lower[::-1]
 
 fig = go.Figure()
 
+# fig.add_trace(go.Scatter(
+#     x=x+x_rev + x2+x2_rev,
+#     y=y1_upper+y1_lower+y1_upper+y1_lower,
+#     fill='toself',
+#     marker=dict(
+#         symbol='square',
+#         size=16,
+#         cmax=39,
+#         cmin=0,
+#         color='rgb(0,100,80)',
+#         colorbar=dict(
+#             title="Colorbar"
+#         ),
+#         colorscale="Viridis"
+#     ),
+#     # fillcolor='rgba(0,100,80,0.2)',
+#     # line_color='rgba(255,255,255,0)',
+#     showlegend=False,
+#     name='Fair',
+#     mode="markers"
+# ))
+# fig.add_trace(go.Scatter(
+#     x=[1, 2, 2, 1,  2, 3, 3, 2],
+#     y=[1, 1, 2, 2,  3, 3, 4, 4],
+#     fill='toself',
+#     marker=dict(
+#         symbol='square',
+#         size=16,
+#         cmax=3,
+#         cmin=0,
+#         color=[1, 2],
+#         colorbar=dict(
+#             title="Colorbar"
+#         ),
+#         colorscale="Viridis"
+#     ),
+#     # fillcolor='rgba(0,100,80,0.2)',
+#     # line_color='rgba(255,255,255,0)',
+#     showlegend=False,
+#     name='Fair',
+#     mode="markers"
+# ))
+start = [0, 0, 255, 0.5]
+end = [255, 0, 0, 0.5]
+rgb = [0, 0, 255, 0.5]
+pot = 4
+for i in range(len(rgb)-1):
+    rgb[i] = rgb[i] + (pot-0)/(4-0)*(end[i]-start[i])
+print(rgb)
 fig.add_trace(go.Scatter(
-    x=x+x_rev,
-    y=y1_upper+y1_lower,
-    # fill='toself',
+    x=[1, 2, 2, 1],
+    y=[1, 1, 2, 2],
+    fill='toself',
+    marker=dict(
+        symbol='square',
+        size=16,
+        cmax=4,
+        cmin=0,
+        # color=[3, 3, 3, 3],
+        colorbar=dict(
+            title="Colorbar"
+        ),
+        colorscale=[[0, 'rgba(0,0,255,0.5)'], [1, 'rgba(255,0,0,0.5)']]
+    ),
+
+    fillcolor='rgba'+str(tuple(rgb)),
     # fillcolor='rgba(0,100,80,0.2)',
     # line_color='rgba(255,255,255,0)',
-    # showlegend=False,
+    showlegend=False,
     name='Fair',
+    mode="markers"
 ))
+# for i in range(10):
+#     fig.add_trace(go.Scatter(
+#         x=x+x_rev,
+#         y=y3_upper+y3_lower,
+#         fill='toself',
+#         marker=dict(
+#             symbol='square',
+#             size=16,
+#             cmax=39,
+#             cmin=0,
+#             color='rgb(0,100,80)',
+#             colorbar=dict(
+#                 title="Colorbar"
+#             ),
+#             colorscale="Viridis"
+#         ),
+#         # fillcolor='rgba(0,100,80,0.2)',
+#         # line_color='rgba(255,255,255,0)',
+#         showlegend=False,
+#         name='Fair',
+#         mode="markers"
+#     ))
 # fig.add_trace(go.Scatter(
 #     x=x+x_rev,
 #     y=y2_upper+y2_lower,
@@ -53,11 +142,11 @@ fig.add_trace(go.Scatter(
 #     showlegend=False,
 #     name='Ideal',
 # ))
-fig.add_trace(go.Scatter(
-    x=x, y=y1,
-    line_color='rgb(0,100,80)',
-    name='Fair',
-))
+# fig.add_trace(go.Scatter(
+#     x=x, y=y1,
+#     line_color='rgb(0,100,80)',
+#     name='Fair',
+# ))
 # fig.add_trace(go.Scatter(
 #     x=x, y=y2,
 #     line_color='rgb(0,176,246)',
@@ -69,7 +158,7 @@ fig.add_trace(go.Scatter(
 #     name='Ideal',
 # ))
 
-fig.update_traces(mode='lines')
+# fig.update_traces(mode='lines')
 fig.show()
-print(x+x_rev)
-print(y1_upper+y1_lower)
+# print(x+x_rev)
+# print(y1_upper+y1_lower)
diff --git a/demo/plot_test1.py b/demo/plot_test1.py
new file mode 100644
index 00000000..3ba2de26
--- /dev/null
+++ b/demo/plot_test1.py
@@ -0,0 +1,23 @@
+import plotly.graph_objects as go
+import numpy as np
+
+fig = go.Figure()
+values = [1, 2, 3]
+fig.add_trace(go.Scatter(
+    x=values,
+    y=values,
+    marker=dict(
+        symbol='square',
+        size=16,
+        cmax=39,
+        cmin=0,
+        color=values,
+        colorbar=dict(
+            title="Colorbar"
+        ),
+        colorscale="Viridis"
+    ),
+    # marker_symbol='square', marker_line_color="midnightblue", marker_color="lightskyblue",
+    # marker_line_width=2, marker_size=15,
+    mode="markers"))
+fig.show()
diff --git a/dryvr_plus_plus/example/example_agent/car_agent.py b/dryvr_plus_plus/example/example_agent/car_agent.py
index cfad53fb..7accb5b3 100644
--- a/dryvr_plus_plus/example/example_agent/car_agent.py
+++ b/dryvr_plus_plus/example/example_agent/car_agent.py
@@ -88,6 +88,11 @@ class CarAgent(BaseAgent):
             a = -1
             if v <= 0.02:
                 a = 0
+        elif vehicle_mode == "Accelerate":
+            d = -lane_map.get_lateral_distance(vehicle_lane, vehicle_pos)
+            a = 1
+            if v >= lane_map.get_speed_limit(vehicle_lane, vehicle_pos)-0.02:
+                a = 0
         elif vehicle_mode == 'Stop':
             d = -lane_map.get_lateral_distance(vehicle_lane, vehicle_pos)
             a = 0
diff --git a/dryvr_plus_plus/example/example_map/simple_map2.py b/dryvr_plus_plus/example/example_map/simple_map2.py
index 05721390..4bf59de7 100644
--- a/dryvr_plus_plus/example/example_map/simple_map2.py
+++ b/dryvr_plus_plus/example/example_map/simple_map2.py
@@ -4,39 +4,41 @@ from dryvr_plus_plus.scene_verifier.map.lane import Lane
 
 import numpy as np
 
+
 class SimpleMap2(LaneMap):
     def __init__(self):
         super().__init__()
         segment0 = StraightLane(
             'seg0',
-            [0,0],
-            [100,0],
+            [0, 0],
+            [100, 0],
             3
         )
         lane0 = Lane('Lane1', [segment0])
         self.add_lanes([lane0])
 
+
 class SimpleMap3(LaneMap):
     def __init__(self):
         super().__init__()
         segment0 = StraightLane(
             'Seg0',
-            [0,3],
-            [50,3],
+            [0, 3],
+            [50, 3],
             3
         )
         lane0 = Lane('Lane0', [segment0])
         segment1 = StraightLane(
             'seg0',
-            [0,0],
-            [50,0],
+            [0, 0],
+            [50, 0],
             3
         )
         lane1 = Lane('Lane1', [segment1])
         segment2 = StraightLane(
             'seg0',
-            [0,-3],
-            [50,-3],
+            [0, -3],
+            [50, -3],
             3
         )
         lane2 = Lane('Lane2', [segment2])
@@ -48,63 +50,100 @@ class SimpleMap3(LaneMap):
         self.right_lane_dict[lane0.id].append(lane1.id)
         self.right_lane_dict[lane1.id].append(lane2.id)
 
+
+class SimpleMap3_v2(LaneMap):
+    def __init__(self):
+        super().__init__()
+        segment0 = StraightLane(
+            'Seg0',
+            [0, 3],
+            [50, 3],
+            3,
+            speed_limit=[(0, 1), (10, 2)]
+        )
+        lane0 = Lane('Lane0', [segment0])
+        segment1 = StraightLane(
+            'seg0',
+            [0, 0],
+            [50, 0],
+            3,
+            speed_limit=[(0, 1), (20, 3)]
+        )
+        lane1 = Lane('Lane1', [segment1])
+        segment2 = StraightLane(
+            'seg0',
+            [0, -3],
+            [50, -3],
+            3,
+            speed_limit=[(0, 1), (25, 2.5)]
+        )
+        lane2 = Lane('Lane2', [segment2])
+        # segment2 = LaneSegment('Lane1', 3)
+        # self.add_lanes([segment1,segment2])
+        self.add_lanes([lane0, lane1, lane2])
+        self.left_lane_dict[lane1.id].append(lane0.id)
+        self.left_lane_dict[lane2.id].append(lane1.id)
+        self.right_lane_dict[lane0.id].append(lane1.id)
+        self.right_lane_dict[lane1.id].append(lane2.id)
+
+
 class SimpleMap5(LaneMap):
     def __init__(self):
         super().__init__()
         segment0 = StraightLane(
             'Seg0',
-            [0,3],
-            [15,3],
+            [0, 3],
+            [15, 3],
             3
         )
         segment1 = StraightLane(
             'Seg1',
-            [15,3], 
-            [25,13],
+            [15, 3],
+            [25, 13],
             3
         )
         segment2 = StraightLane(
             'Seg2',
-            [25,13], 
-            [50,13],
+            [25, 13],
+            [50, 13],
             3
         )
         lane0 = Lane('Lane0', [segment0, segment1, segment2])
         segment0 = StraightLane(
             'seg0',
-            [0,0],
-            [17,0],
+            [0, 0],
+            [17, 0],
             3
         )
         segment1 = StraightLane(
             'seg1',
-            [17,0],
-            [27,10],
+            [17, 0],
+            [27, 10],
             3
         )
         segment2 = StraightLane(
             'seg2',
-            [27,10],
-            [50,10],
+            [27, 10],
+            [50, 10],
             3
         )
         lane1 = Lane('Lane1', [segment0, segment1, segment2])
         segment0 = StraightLane(
             'seg0',
-            [0,-3],
-            [19,-3],
+            [0, -3],
+            [19, -3],
             3
         )
         segment1 = StraightLane(
             'seg1',
-            [19,-3],
-            [29,7],
+            [19, -3],
+            [29, 7],
             3
         )
         segment2 = StraightLane(
             'seg2',
-            [29,7],
-            [50,7],
+            [29, 7],
+            [50, 7],
             3
         )
         lane2 = Lane('Lane2', [segment0, segment1, segment2])
@@ -114,18 +153,19 @@ class SimpleMap5(LaneMap):
         self.right_lane_dict[lane0.id].append(lane1.id)
         self.right_lane_dict[lane1.id].append(lane2.id)
 
+
 class SimpleMap6(LaneMap):
     def __init__(self):
         super().__init__()
         segment0 = StraightLane(
             'Seg0',
-            [0,3],
-            [15,3],
+            [0, 3],
+            [15, 3],
             3
         )
         segment1 = CircularLane(
             'Seg1',
-            [15,8],
+            [15, 8],
             5,
             np.pi*3/2,
             np.pi*2,
@@ -134,20 +174,20 @@ class SimpleMap6(LaneMap):
         )
         segment2 = StraightLane(
             'Seg2',
-            [20,8], 
-            [20,30],
+            [20, 8],
+            [20, 30],
             3
         )
         lane0 = Lane('Lane0', [segment0, segment1, segment2])
         segment0 = StraightLane(
             'seg0',
-            [0,0],
-            [18,0],
+            [0, 0],
+            [18, 0],
             3
         )
         segment1 = CircularLane(
             'seg1',
-            [18,5],
+            [18, 5],
             5,
             3*np.pi/2,
             2*np.pi,
@@ -156,20 +196,20 @@ class SimpleMap6(LaneMap):
         )
         segment2 = StraightLane(
             'seg2',
-            [23,5],
-            [23,30],
+            [23, 5],
+            [23, 30],
             3
         )
         lane1 = Lane('Lane1', [segment0, segment1, segment2])
         segment0 = StraightLane(
             'seg0',
-            [0,-3],
-            [21,-3],
+            [0, -3],
+            [21, -3],
             3
         )
         segment1 = CircularLane(
             'seg1',
-            [21,2],
+            [21, 2],
             5,
             np.pi*3/2,
             np.pi*2,
@@ -178,8 +218,8 @@ class SimpleMap6(LaneMap):
         )
         segment2 = StraightLane(
             'seg2',
-            [26,2],
-            [26,30],
+            [26, 2],
+            [26, 30],
             3
         )
         lane2 = Lane('Lane2', [segment0, segment1, segment2])
@@ -189,6 +229,7 @@ class SimpleMap6(LaneMap):
         self.right_lane_dict[lane0.id].append(lane1.id)
         self.right_lane_dict[lane1.id].append(lane2.id)
 
+
 if __name__ == "__main__":
     test_map = SimpleMap3()
     print(test_map.left_lane_dict)
diff --git a/dryvr_plus_plus/plotter/plotter2D.py b/dryvr_plus_plus/plotter/plotter2D.py
index 9e8fd27b..bb8fc7fd 100644
--- a/dryvr_plus_plus/plotter/plotter2D.py
+++ b/dryvr_plus_plus/plotter/plotter2D.py
@@ -4,16 +4,17 @@ This file consist main plotter code for DryVR reachtube output
 
 from __future__ import annotations
 from audioop import reverse
+# from curses import start_color
 from re import A
 import matplotlib.patches as patches
 import matplotlib.pyplot as plt
 import numpy as np
 from math import pi
-import plotly.express as px
 import plotly.graph_objects as go
 from typing import List
 from PIL import Image, ImageDraw
 import io
+import copy
 import operator
 from collections import OrderedDict
 
@@ -97,7 +98,7 @@ def plot(
     return fig, (x_min, x_max), (y_min, y_max)
 
 
-def generate_reachtube_anime(root, agent_id, x_dim: int = 0, y_dim_list: List[int] = [1], color='blue', fig=None, x_lim=None, y_lim=None, map=None):
+def generate_reachtube_anime(root, map=None, fig=None):
     # make figure
     fig_dict = {
         "data": [],
@@ -109,11 +110,13 @@ def generate_reachtube_anime(root, agent_id, x_dim: int = 0, y_dim_list: List[in
     stack = [root]
     x_min, x_max = float('inf'), -float('inf')
     y_min, y_max = float('inf'), -float('inf')
+    print("reachtude")
     while stack != []:
         node = stack.pop()
         traces = node.trace
         for agent_id in traces:
             trace = np.array(traces[agent_id])
+            print(trace)
             for i in range(len(trace)):
                 x_min = min(x_min, trace[i][1])
                 x_max = max(x_max, trace[i][1])
@@ -181,7 +184,7 @@ def generate_reachtube_anime(root, agent_id, x_dim: int = 0, y_dim_list: List[in
             "visible": True,
             "xanchor": "right"
         },
-        "method": "update",
+        # "method": "update",
         "transition": {"duration": duration, "easing": "cubic-in-out"},
         "pad": {"b": 10, "t": 50},
         "len": 0.9,
@@ -191,6 +194,7 @@ def generate_reachtube_anime(root, agent_id, x_dim: int = 0, y_dim_list: List[in
     }
     # make data
     point_list = timed_point_dict[0]
+    # print("reachtude")
     # print(point_list)
     data_dict = {
         "x": [data[0] for data in point_list],
@@ -232,23 +236,25 @@ def generate_reachtube_anime(root, agent_id, x_dim: int = 0, y_dim_list: List[in
             "name": "current position"
         }
         frame["data"].append(data_dict)
+        # print(trace_x)
         for i in range(len(trace_x)):
-            ax = np.cos(trace_theta[i])*trace_v[i]
-            ay = np.sin(trace_theta[i])*trace_v[i]
+            pass
+            # ax = np.cos(trace_theta[i])*trace_v[i]
+            # ay = np.sin(trace_theta[i])*trace_v[i]
             # print(trace_x[i]+ax, trace_y[i]+ay)
-            annotations_dict = {"x": trace_x[i]+ax+0.1, "y": trace_y[i]+ay,
-                                # "xshift": ax, "yshift": ay,
-                                "ax": trace_x[i], "ay": trace_y[i],
-                                "arrowwidth": 2,
-                                # "arrowside": 'end',
-                                "showarrow": True,
-                                # "arrowsize": 1,
-                                "xref": 'x', "yref": 'y',
-                                "axref": 'x', "ayref": 'y',
-                                # "text": "erver",
-                                "arrowhead": 1,
-                                "arrowcolor": "black"}
-            frame["layout"]["annotations"].append(annotations_dict)
+            # annotations_dict = {"x": trace_x[i]+ax+0.1, "y": trace_y[i]+ay,
+            #                     # "xshift": ax, "yshift": ay,
+            #                     "ax": trace_x[i], "ay": trace_y[i],
+            #                     "arrowwidth": 2,
+            #                     # "arrowside": 'end',
+            #                     "showarrow": True,
+            #                     # "arrowsize": 1,
+            #                     "xref": 'x', "yref": 'y',
+            #                     "axref": 'x', "ayref": 'y',
+            #                     # "text": "erver",
+            #                     "arrowhead": 1,
+            #                     "arrowcolor": "black"}
+            # frame["layout"]["annotations"].append(annotations_dict)
 
         fig_dict["frames"].append(frame)
         slider_step = {"args": [
@@ -265,7 +271,7 @@ def generate_reachtube_anime(root, agent_id, x_dim: int = 0, y_dim_list: List[in
     fig_dict["layout"]["sliders"] = [sliders_dict]
 
     fig = go.Figure(fig_dict)
-    fig = plotly_map(map, 'g', fig)
+    # fig = plotly_map(map, 'g', fig)
     for agent_id in traces:
         fig = plotly_reachtube_tree_v2(root, agent_id, 1, [2], 'blue', fig)
 
@@ -274,7 +280,7 @@ def generate_reachtube_anime(root, agent_id, x_dim: int = 0, y_dim_list: List[in
 
 def plotly_reachtube_tree(root, agent_id, x_dim: int = 0, y_dim_list: List[int] = [1], color='blue', fig=None, x_lim=None, y_lim=None):
     if fig is None:
-        fig = go.figure()
+        fig = go.Figure()
 
     # ax = fig.gca()
     # if x_lim is None:
@@ -301,7 +307,7 @@ def plotly_reachtube_tree(root, agent_id, x_dim: int = 0, y_dim_list: List[int]
 
 def plotly_reachtube_tree_v2(root, agent_id, x_dim: int = 0, y_dim_list: List[int] = [1], color='blue', fig=None, x_lim=None, y_lim=None):
     if fig is None:
-        fig = go.figure()
+        fig = go.Figure()
 
     # ax = fig.gca()
     # if x_lim is None:
@@ -316,7 +322,7 @@ def plotly_reachtube_tree_v2(root, agent_id, x_dim: int = 0, y_dim_list: List[in
         node = queue.pop(0)
         traces = node.trace
         trace = np.array(traces[agent_id])
-        print(trace[0], trace[1], trace[-2], trace[-1])
+        # print(trace[0], trace[1], trace[-2], trace[-1])
         max_id = len(trace)-1
         # trace_x = np.zeros(max_id+1)
         # trace_x_2 = np.zeros(max_id+1)
@@ -413,7 +419,7 @@ def plotly_reachtube_tree_v2(root, agent_id, x_dim: int = 0, y_dim_list: List[in
         node = queue.pop(0)
         traces = node.trace
         trace = np.array(traces[agent_id])
-        print(trace[0], trace[1], trace[-2], trace[-1])
+        # print(trace[0], trace[1], trace[-2], trace[-1])
         max_id = len(trace)-1
         # trace_x = np.zeros(max_id+1)
         # trace_x_2 = np.zeros(max_id+1)
@@ -541,10 +547,12 @@ def plot_reachtube_tree(root, agent_id, x_dim: int = 0, y_dim_list: List[int] =
     return fig
 
 
-def plotly_map(map, color='b', fig=None, x_lim=None, y_lim=None):
+def plotly_map(map, color='b', fig: go.Figure() = None, x_lim=None, y_lim=None):
     if fig is None:
-        fig = go.figure()
-
+        fig = go.Figure()
+    all_x = []
+    all_y = []
+    all_v = []
     for lane_idx in map.lane_dict:
         lane = map.lane_dict[lane_idx]
         for lane_seg in lane.segment_list:
@@ -574,6 +582,11 @@ def plotly_map(map, color='b', fig=None, x_lim=None, y_lim=None):
                                          showlegend=False,
                                          # text=theta,
                                          name='lines'))
+                # fig = go.Figure().add_heatmap(x=)
+                seg_x, seg_y, seg_v = lane_seg.get_all_speed()
+                all_x += seg_x
+                all_y += seg_y
+                all_v += seg_v
             elif lane_seg.type == "Circular":
                 phase_array = np.linspace(
                     start=lane_seg.start_phase, stop=lane_seg.end_phase, num=100)
@@ -598,6 +611,59 @@ def plotly_map(map, color='b', fig=None, x_lim=None, y_lim=None):
                                          name='lines'))
             else:
                 raise ValueError(f'Unknown lane segment type {lane_seg.type}')
+    start_color = [0, 0, 255, 0.2]
+    end_color = [255, 0, 0, 0.2]
+    curr_color = copy.deepcopy(start_color)
+    max_speed = max(all_v)
+    min_speed = min(all_v)
+
+    for i in range(len(all_v)):
+        # print(all_x[i])
+        # print(all_y[i])
+        # print(all_v[i])
+        curr_color = copy.deepcopy(start_color)
+        for j in range(len(curr_color)-1):
+            curr_color[j] += (all_v[i]-min_speed)/(max_speed -
+                                                   min_speed)*(end_color[j]-start_color[j])
+        fig.add_trace(go.Scatter(x=all_x[i], y=all_y[i],
+                                 mode='lines',
+                                 line_color='rgba(0,0,0,0)',
+                                 fill='toself',
+                                 fillcolor='rgba'+str(tuple(curr_color)),
+                                 #  marker=dict(
+                                 #     symbol='square',
+                                 #     size=16,
+                                 #     cmax=max_speed,
+                                 #     cmin=min_speed,
+                                 #     # color=all_v[i],
+                                 #     colorbar=dict(
+                                 #         title="Colorbar"
+                                 #     ),
+                                 #     colorscale=[
+                                 #         [0, 'rgba'+str(tuple(start_color))], [1, 'rgba'+str(tuple(end_color))]]
+                                 # ),
+                                 showlegend=False,
+                                 ))
+    fig.add_trace(go.Scatter(x=[0], y=[0],
+                             mode='markers',
+                             # fill='toself',
+                             # fillcolor='rgba'+str(tuple(curr_color)),
+                             marker=dict(
+                                 symbol='square',
+                                 size=16,
+                                 cmax=max_speed,
+                                 cmin=min_speed,
+                                 color='rgba(0,0,0,0)',
+                                 colorbar=dict(
+                                        title="Speed Limit"
+                                 ),
+                                 colorscale=[
+                                     [0, 'rgba'+str(tuple(start_color))], [1, 'rgba'+str(tuple(end_color))]]
+    ),
+        showlegend=False,
+    ))
+    # fig.update_coloraxes(colorbar=dict(title="Colorbar"), colorscale=[
+    #                      [0, 'rgba'+str(tuple(start_color))], [1, 'rgba'+str(tuple(end_color))]])
     return fig
 
 
@@ -791,7 +857,7 @@ def plotly_simulation_anime(root, map=None, fig=None):
     # print(root.mode)
     x_min, x_max = float('inf'), -float('inf')
     y_min, y_max = float('inf'), -float('inf')
-    segment_start = set()
+    # segment_start = set()
     # previous_mode = {}
     # for agent_id in root.mode:
     #     previous_mode[agent_id] = []
@@ -801,7 +867,8 @@ def plotly_simulation_anime(root, map=None, fig=None):
         traces = node.trace
         for agent_id in traces:
             trace = np.array(traces[agent_id])
-            segment_start.add(round(trace[0][0], 2))
+            print(trace)
+            # segment_start.add(round(trace[0][0], 2))
             for i in range(len(trace)):
                 x_min = min(x_min, trace[i][1])
                 x_max = max(x_max, trace[i][1])
@@ -823,7 +890,7 @@ def plotly_simulation_anime(root, map=None, fig=None):
             time = round(trace[i][0], 2)
         stack += node.child
     # fill in most of layout
-    print(segment_start)
+    # print(segment_start)
     # print(timed_point_dict.keys())
     duration = int(600/time)
     fig_dict["layout"]["xaxis"] = {
@@ -880,6 +947,7 @@ def plotly_simulation_anime(root, map=None, fig=None):
     }
     # make data
     point_list = timed_point_dict[0]
+    print(point_list)
     x_list = []
     y_list = []
     text_list = []
@@ -888,12 +956,14 @@ def plotly_simulation_anime(root, map=None, fig=None):
         # print(trace)
         x_list.append(trace[0])
         y_list.append(trace[1])
-        text_list.append((round(trace[3], 2), round(trace[2]/pi*180, 2)))
+        text_list.append(
+            ('{:.2f}'.format(trace[2]/pi*180), '{:.3f}'.format(trace[3])))
     data_dict = {
         "x": x_list,
         "y": y_list,
         "mode": "markers + text",
         "text": text_list,
+        "textfont": dict(size=14, color="black"),
         "textposition": "bottom center",
         # "marker": {
         #     "sizemode": "area",
@@ -908,7 +978,7 @@ def plotly_simulation_anime(root, map=None, fig=None):
     for time_point in timed_point_dict:
         # print(time_point)
         frame = {"data": [], "layout": {
-            "annotations": []}, "name": str(time_point)}
+            "annotations": []}, "name": '{:.2f}'.format(time_point)}
         # print(timed_point_dict[time_point][0])
         point_list = timed_point_dict[time_point]
         # point_list = list(OrderedDict.fromkeys(timed_point_dict[time_point]))
@@ -928,7 +998,8 @@ def plotly_simulation_anime(root, map=None, fig=None):
             "x": trace_x,
             "y": trace_y,
             "mode": "markers + text",
-            "text": [(round(trace_theta[i]/pi*180, 2), round(trace_v[i], 3)) for i in range(len(trace_theta))],
+            # "text": [(round(trace_theta[i]/pi*180, 2), round(trace_v[i], 3)) for i in range(len(trace_theta))],
+            "text": [('{:.2f}'.format(trace_theta[i]/pi*180), '{:.3f}'.format(trace_v[i])) for i in range(len(trace_theta))],
             "textfont": dict(size=14, color="black"),
             "textposition": "bottom center",
             # "marker": {
@@ -936,7 +1007,8 @@ def plotly_simulation_anime(root, map=None, fig=None):
             #     "sizeref": 200000,
             #     "size": 2
             # },
-            "name": "current position"
+            "name": "current position",
+            # "show_legend": False
         }
         frame["data"].append(data_dict)
         for i in range(len(trace_x)):
@@ -1032,6 +1104,11 @@ def plotly_simulation_anime(root, map=None, fig=None):
                         text_pos = 'middle left'
                     else:
                         text_pos = 'middle right'
+                elif veh_mode == 'Accelerate':
+                    if theta >= -pi/2 and theta <= pi/2:
+                        text_pos = 'middle right'
+                    else:
+                        text_pos = 'middle left'
                 elif veh_mode == 'SwitchLeft':
                     if theta >= -pi/2 and theta <= pi/2:
                         text_pos = 'top center'
@@ -1057,7 +1134,7 @@ def plotly_simulation_anime(root, map=None, fig=None):
                 # i += 1
                 previous_mode[agent_id] = node.mode[agent_id]
         queue += node.child
-    # fig.update_traces(mode='lines')
+    fig.update_traces(showlegend=False)
     # fig.update_annotations(textfont=dict(size=14, color="black"))
     # print(fig.frames[0].layout["annotations"])
     return fig
diff --git a/dryvr_plus_plus/scene_verifier/code_parser/pythonparser.py b/dryvr_plus_plus/scene_verifier/code_parser/pythonparser.py
index c468de0b..7f75ed5a 100644
--- a/dryvr_plus_plus/scene_verifier/code_parser/pythonparser.py
+++ b/dryvr_plus_plus/scene_verifier/code_parser/pythonparser.py
@@ -68,6 +68,7 @@ class Guard(Statement):
     '''
     def parseGuard(node, code):
         # assume guard is a strict comparision (modeType == mode)
+        # keyi mark: may support more comparision
         if isinstance(node.test, ast.Compare):  # ==
             if isinstance(node.test.comparators[0], ast.Attribute):
                 if ("Mode" in str(node.test.comparators[0].value.id)):
@@ -467,7 +468,7 @@ class ControllerAst():
         childrens_guards = []
         childrens_resets = []
         recoutput = []
-        # tree.show()
+        tree.show()
         if parent == None:
             s = Statement("root", None, None)
             tree.create_node("root")
diff --git a/dryvr_plus_plus/scene_verifier/map/lane.py b/dryvr_plus_plus/scene_verifier/map/lane.py
index da3adef5..aca0fa3d 100644
--- a/dryvr_plus_plus/scene_verifier/map/lane.py
+++ b/dryvr_plus_plus/scene_verifier/map/lane.py
@@ -45,3 +45,9 @@ class Lane():
         seg_idx, segment = self.get_lane_segment(position)
         longitudinal, lateral = segment.local_coordinates(position)
         return lateral
+
+    def get_speed_limit(self, position: np.ndarray) -> float:
+        seg_idx, segment = self.get_lane_segment(position)
+        longitudinal, lateral = segment.local_coordinates(position)
+        return segment.speed_limit_at(longitudinal)
+
diff --git a/dryvr_plus_plus/scene_verifier/map/lane_map.py b/dryvr_plus_plus/scene_verifier/map/lane_map.py
index f17afe6f..b7723fe9 100644
--- a/dryvr_plus_plus/scene_verifier/map/lane_map.py
+++ b/dryvr_plus_plus/scene_verifier/map/lane_map.py
@@ -1,3 +1,4 @@
+from ctypes.wintypes import PINT
 from typing import Dict, List
 import copy
 from enum import Enum
@@ -93,3 +94,12 @@ class LaneMap:
         lane = self.lane_dict[lane_idx]
         seg_idx, segment = lane.get_lane_segment(position)
         return segment
+
+    def get_speed_limit(self, lane_idx: str, position: np.ndarray) -> AbstractLane:
+        if not isinstance(position, np.ndarray):
+            position = np.array(position)
+        lane = self.lane_dict[lane_idx]
+        limit = lane.get_speed_limit(position)
+        # print(limit)
+        # print(position)
+        return limit
diff --git a/dryvr_plus_plus/scene_verifier/map/lane_segment.py b/dryvr_plus_plus/scene_verifier/map/lane_segment.py
index 642d7a00..fd67782d 100644
--- a/dryvr_plus_plus/scene_verifier/map/lane_segment.py
+++ b/dryvr_plus_plus/scene_verifier/map/lane_segment.py
@@ -1,9 +1,13 @@
+from turtle import speed
 from typing import List
 import numpy as np
 from abc import ABCMeta, abstractmethod
 from typing import Tuple, List, Optional, Union
+import copy
+from sympy import false
+
+from dryvr_plus_plus.scene_verifier.utils.utils import wrap_to_pi, Vector, get_class_path, class_from_path, to_serializable
 
-from dryvr_plus_plus.scene_verifier.utils.utils import wrap_to_pi, Vector, get_class_path, class_from_path,to_serializable
 
 class LineType:
 
@@ -14,6 +18,7 @@ class LineType:
     CONTINUOUS = 2
     CONTINUOUS_LINE = 3
 
+
 class AbstractLane(object):
 
     """A lane on the road, described by its central curve."""
@@ -25,7 +30,7 @@ class AbstractLane(object):
     longitudinal_start: float = 0
     line_types: List["LineType"]
 
-    def __init__(self, id:str):
+    def __init__(self, id: str):
         self.id = id
         self.type = None
 
@@ -137,18 +142,19 @@ class AbstractLane(object):
         angle = np.abs(wrap_to_pi(heading - self.heading_at(s)))
         return abs(r) + max(s - self.length, 0) + max(0 - s, 0) + heading_weight*angle
 
+
 class StraightLane(AbstractLane):
 
     """A lane going in straight line."""
 
     def __init__(self,
-                 id: str, 
+                 id: str,
                  start: Vector,
                  end: Vector,
                  width: float = AbstractLane.DEFAULT_WIDTH,
                  line_types: Tuple[LineType, LineType] = None,
                  forbidden: bool = False,
-                 speed_limit: float = 20,
+                 speed_limit: List[Tuple[float, float]] = None,
                  priority: int = 0) -> None:
         """
         New straight lane.
@@ -164,14 +170,19 @@ class StraightLane(AbstractLane):
         self.start = np.array(start)
         self.end = np.array(end)
         self.width = width
-        self.heading = np.arctan2(self.end[1] - self.start[1], self.end[0] - self.start[0])
+        self.heading = np.arctan2(
+            self.end[1] - self.start[1], self.end[0] - self.start[0])
         self.length = np.linalg.norm(self.end - self.start)
         self.line_types = line_types or [LineType.STRIPED, LineType.STRIPED]
         self.direction = (self.end - self.start) / self.length
-        self.direction_lateral = np.array([-self.direction[1], self.direction[0]])
+        self.direction_lateral = np.array(
+            [-self.direction[1], self.direction[0]])
         self.forbidden = forbidden
         self.priority = priority
-        self.speed_limit = speed_limit
+        if speed_limit != None:
+            self.speed_limit = sorted(speed_limit, key=lambda elem: elem[0])
+        else:
+            self.speed_limit = None
         self.type = 'Straight'
         self.longitudinal_start = 0
 
@@ -190,6 +201,54 @@ class StraightLane(AbstractLane):
         lateral = np.dot(delta, self.direction_lateral)
         return float(longitudinal), float(lateral)
 
+    def speed_limit_at(self, longitudinal: float) -> float:
+        # print(self.speed_limit)
+        if longitudinal >= self.speed_limit[-1][0]:
+            # print(longitudinal, self.speed_limit[-1][1])
+            return self.speed_limit[-1][1]
+        prev_limit = self.speed_limit[0][1]
+        for (start, limit) in self.speed_limit:
+            if longitudinal <= start:
+                # print(longitudinal, prev_limit)
+                return prev_limit
+            prev_limit = limit
+
+        return -1
+    # in format for polty filling mode
+
+    def get_all_speed(self):
+        end_longitudinal, end_lateral = self.local_coordinates(self.end)
+        ret_x = []
+        ret_y = []
+        ret_v = []
+        x_y = np.ndarray(shape=2)
+        seg_pos = []
+        speed_limit = copy.deepcopy(self.speed_limit)
+        speed_limit.append(tuple([end_longitudinal, self.speed_limit[-1][1]]))
+        for i in range(len(self.speed_limit)):
+            seg_start = speed_limit[i][0]
+            limit = speed_limit[i][1]
+            if end_longitudinal < seg_start:
+                break
+            seg_pos = []
+            seg_end = min(end_longitudinal, speed_limit[i+1][0])
+            x_y = self.position(seg_start, self.width/2)
+            seg_pos.append(x_y.tolist())
+            x_y = self.position(seg_end, self.width/2)
+            seg_pos.append(x_y.tolist())
+            x_y = self.position(seg_end, -self.width/2)
+            seg_pos.append(x_y.tolist())
+            x_y = self.position(seg_start, -self.width/2)
+            seg_pos.append(x_y.tolist())
+            ret_x.append([pos[0] for pos in seg_pos])
+            ret_y.append([pos[1] for pos in seg_pos])
+            ret_v.append(limit)
+        # print('get_all_speed')
+        # print(ret_x)
+        # print(ret_y)
+        # print(ret_v)
+        return ret_x, ret_y, ret_v
+
     @classmethod
     def from_config(cls, config: dict):
         config["start"] = np.array(config["start"])
@@ -216,7 +275,7 @@ class CircularLane(AbstractLane):
     """A lane going in circle arc."""
 
     def __init__(self,
-                 id, 
+                 id,
                  center: Vector,
                  radius: float,
                  start_phase: float,
@@ -225,7 +284,7 @@ class CircularLane(AbstractLane):
                  width: float = AbstractLane.DEFAULT_WIDTH,
                  line_types: List[LineType] = None,
                  forbidden: bool = False,
-                 speed_limit: float = 20,
+                 speed_limit: List[Tuple[float, float]] = None,
                  priority: int = 0) -> None:
         super().__init__(id)
         self.center = np.array(center)
@@ -239,7 +298,7 @@ class CircularLane(AbstractLane):
         self.forbidden = forbidden
         self.length = radius*(end_phase - start_phase) * self.direction
         self.priority = priority
-        self.speed_limit = speed_limit
+        self.speed_limit = sorted(speed_limit, key=lambda elem: elem[0])
         self.type = 'Circular'
         self.longitudinal_start = 0
 
@@ -264,6 +323,12 @@ class CircularLane(AbstractLane):
         lateral = self.direction*(self.radius - r)
         return longitudinal, lateral
 
+    def speed_limit_at(self, longitudinal: float) -> float:
+        for (start, limit) in self.speed_limit:
+            if longitudinal <= start:
+                return limit
+        return -1
+
     @classmethod
     def from_config(cls, config: dict):
         config["center"] = np.array(config["center"])
@@ -288,15 +353,15 @@ class CircularLane(AbstractLane):
 
 
 class LaneSegment:
-    def __init__(self, id, lane_parameter = None):
+    def __init__(self, id, lane_parameter=None):
         self.id = id
         # self.left_lane:List[str] = left_lane
-        # self.right_lane:List[str] = right_lane 
+        # self.right_lane:List[str] = right_lane
         # self.next_segment:int = next_segment
 
-        self.lane_parameter = None 
+        self.lane_parameter = None
         if lane_parameter is not None:
             self.lane_parameter = lane_parameter
 
     def get_geometry(self):
-        return self.lane_parameter
\ No newline at end of file
+        return self.lane_parameter
-- 
GitLab