diff --git a/demo/demo1.py b/demo/demo1.py
index b4c57e0c6047e4ab6d0feaec9a072a2697989757..44b98844ff3bae7c5c5b80e087f242ee420cd85a 100644
--- a/demo/demo1.py
+++ b/demo/demo1.py
@@ -1,7 +1,7 @@
 from dryvr_plus_plus.example.example_agent.car_agent import CarAgent, NPCAgent
 from dryvr_plus_plus.scene_verifier.scenario.scenario import Scenario
 from dryvr_plus_plus.example.example_map.simple_map2 import SimpleMap2, SimpleMap3, SimpleMap5, SimpleMap6
-from dryvr_plus_plus.plotter.plotter2D import *
+from dryvr_plus_plus.plotter.plotter2D_new import *
 from dryvr_plus_plus.example.example_sensor.fake_sensor import FakeSensor2
 import plotly.graph_objects as go
 import numpy as np
@@ -41,7 +41,7 @@ if __name__ == "__main__":
     scenario.add_agent(car)
     car = CarAgent('car2', file_name=input_code_name)
     scenario.add_agent(car)
-    tmp_map = SimpleMap3()
+    tmp_map = SimpleMap2()
     scenario.set_map(tmp_map)
     scenario.set_sensor(FakeSensor2())
     scenario.set_init(
@@ -54,26 +54,8 @@ if __name__ == "__main__":
             (VehicleMode.Normal, LaneMode.Lane1),
         ]
     )
-    # res_list = scenario.simulate_multi(10, 1)
-    # # traces = scenario.verify(10)
-
-    # fig = plt.figure(2)
-    # # fig = plot_map(tmp_map, 'g', fig)
-    # # fig = plot_reachtube_tree(traces, 'car1', 0, [1], 'b', fig, (1000,-1000), (1000,-1000))
-    # # fig = plot_reachtube_tree(traces, 'car2', 0, [1], 'r', fig)
-    # for traces in res_list:
-    #     fig = plot_simulation_tree(
-    #         traces, 'car1', 1, [2], 'b', fig, (1000, -1000), (1000, -1000))
-    #     fig = plot_simulation_tree(traces, 'car2', 1, [2], 'r', fig)
-    #     # generate_simulation_anime(traces, tmp_map, fig)
-
-    # plt.show()
 
     traces = scenario.simulate(10, 0.01)
     fig = go.Figure()
-    # fig = plotly_map(tmp_map, 'g', fig)
-    # fig = plotly_simulation_tree(
-    #     traces, 'car1', 0, [1], 'b', fig, (1000, -1000), (1000, -1000))
-    # fig = plotly_simulation_tree(traces, 'car2', 0, [1], 'r', fig)
-    fig = plotly_simulation_anime(traces, tmp_map, fig)
+    fig = test_simu_anime(traces, tmp_map, fig, 1, 2, 'lines')
     fig.show()
diff --git a/demo/demo10.py b/demo/demo10.py
new file mode 100644
index 0000000000000000000000000000000000000000..13b158c2fd5d60988bc951df16d7ed911585e435
--- /dev/null
+++ b/demo/demo10.py
@@ -0,0 +1,87 @@
+from dryvr_plus_plus.example.example_agent.car_agent import CarAgent
+from dryvr_plus_plus.scene_verifier.scenario.scenario import Scenario
+from dryvr_plus_plus.example.example_map.simple_map2 import SimpleMap2, SimpleMap3, SimpleMap5, SimpleMap6, SimpleMap3_v2
+from dryvr_plus_plus.plotter.plotter2D import *
+from dryvr_plus_plus.plotter.plotter2D_new import *
+from dryvr_plus_plus.example.example_sensor.fake_sensor import FakeSensor2
+import plotly.graph_objects as go
+import matplotlib.pyplot as plt
+import numpy as np
+from enum import Enum, auto
+
+
+class VehicleMode(Enum):
+    Normal = auto()
+    SwitchLeft = auto()
+    SwitchRight = auto()
+    Brake = auto()
+    Accelerate = auto()
+
+
+class LaneMode(Enum):
+    Lane0 = auto()
+    Lane1 = auto()
+    Lane2 = auto()
+
+
+class State:
+    x = 0.0
+    y = 0.0
+    theta = 0.0
+    v = 0.0
+    vehicle_mode: VehicleMode = VehicleMode.Normal
+    lane_mode: LaneMode = LaneMode.Lane0
+
+    def __init__(self, x, y, theta, v, vehicle_mode: VehicleMode, lane_mode: LaneMode):
+        self.data = []
+
+
+if __name__ == "__main__":
+    input_code_name = 'example_controller10.py'
+    scenario = Scenario()
+
+    car = CarAgent('car1', file_name=input_code_name)
+    scenario.add_agent(car)
+    car = CarAgent('car2', file_name=input_code_name)
+    scenario.add_agent(car)
+    tmp_map = SimpleMap3_v2()
+    scenario.set_map(tmp_map)
+    scenario.set_sensor(FakeSensor2())
+    scenario.set_init(
+        [
+            [[0, -0.2, 0, 1.0], [0.1, 0.2, 0, 1.0]],
+            [[10, 0, 0, 0.5], [10, 0, 0, 0.5]],
+        ],
+        [
+            (VehicleMode.Normal, LaneMode.Lane1),
+            (VehicleMode.Normal, LaneMode.Lane1),
+        ]
+    )
+    # traces = scenario.verify(30)
+    # # fig = go.Figure()
+    # # fig = plotly_reachtube_tree_v2(traces, 'car1', 1, [2], 'blue', fig)
+    # # fig = plotly_reachtube_tree_v2(traces, 'car2', 1, [2], 'red', fig)
+    # # fig.show()
+    # fig = go.Figure()
+    # fig = generate_reachtube_anime(traces, tmp_map, fig)
+    # # # fig = plotly_reachtube_tree_v2(traces, 'car2', 1, [2], 'red', fig)
+    # fig.show()
+    # fig = plt.figure(2)
+    # fig = plot_map(tmp_map, 'g', fig)
+    # # fig = plot_simulation_tree(traces, 'car1', 1, [2], 'b', fig)
+    # # fig = plot_simulation_tree(traces, 'car2', 1, [2], 'r', fig)
+    # fig = plot_reachtube_tree(traces, 'car1', 1, [2], 'b', fig)
+    # fig = plot_reachtube_tree(traces, 'car2', 1, [2], 'r', fig)
+    # plt.show()
+    # # fig1 = plt.figure(2)
+    # fig = generate_simulation_anime(traces, tmp_map, fig)
+    # plt.show()
+
+    traces = scenario.simulate(25)
+    # fig = go.Figure()
+    # fig = plotly_simulation_tree(traces, 'car1', 1, [2], 'b', fig)
+    # fig.show()
+    fig = go.Figure()
+    fig = draw_simulation_tree(traces, tmp_map, fig, 1, 2, 'detailed')
+    # fig = plotly_map(tmp_map, fig=fig)
+    fig.show()
diff --git a/demo/demo2.py b/demo/demo2.py
index 7c8d006f6a2c9bb62dd972e5a2121ee64219dd91..2386b599c4e6e2f6fb3338c18664db1533c60665 100644
--- a/demo/demo2.py
+++ b/demo/demo2.py
@@ -2,8 +2,11 @@ from dryvr_plus_plus.example.example_agent.car_agent import CarAgent
 from dryvr_plus_plus.scene_verifier.scenario.scenario import Scenario
 from dryvr_plus_plus.example.example_map.simple_map2 import SimpleMap2, SimpleMap3, SimpleMap5, SimpleMap6
 from dryvr_plus_plus.plotter.plotter2D import *
+from dryvr_plus_plus.plotter.plotter2D_new import *
 from dryvr_plus_plus.example.example_sensor.fake_sensor import FakeSensor2
 import plotly.graph_objects as go
+# import matplotlib.pyplot as plt
+
 import numpy as np
 from enum import Enum, auto
 
@@ -54,17 +57,9 @@ if __name__ == "__main__":
             (VehicleMode.Normal, LaneMode.Lane1),
         ]
     )
-    # res_list = scenario.simulate(40)
-    traces = scenario.verify(40)
 
-    fig = plt.figure(2)
-    fig = plot_map(tmp_map, 'g', fig)
-    fig = plot_reachtube_tree(traces, 'car1', 1, [2], 'b', fig)
-    fig = plot_reachtube_tree(traces, 'car2', 1, [2], 'r', fig)
-    plt.show()
+    traces = scenario.simulate(30)
+    fig = go.Figure()
+    fig = test_simu_anime(traces, tmp_map, fig, 1, 2, 'lines')
+    fig.show()
 
-    # # this is for plot-based visualization
-    # traces = scenario.simulate(40)
-    # fig = go.Figure()
-    # fig = plotly_simulation_anime(traces, tmp_map, fig)
-    # fig.show()
diff --git a/demo/example_controller1.py b/demo/example_controller1.py
index bc3c11401fe4fbd7c648d09614ef97a091014f9a..3d467ccec825dd16d599d35eafd5cd1bbc83f2b9 100644
--- a/demo/example_controller1.py
+++ b/demo/example_controller1.py
@@ -1,17 +1,24 @@
 from enum import Enum, auto
 import copy
 
+from sympy import false
+
+from dryvr_plus_plus.scene_verifier.map.lane_map import LaneMap
+
+
 class VehicleMode(Enum):
     Normal = auto()
     SwitchLeft = auto()
     SwitchRight = auto()
     Brake = auto()
 
+
 class LaneMode(Enum):
     Lane0 = auto()
     Lane1 = auto()
     Lane2 = auto()
 
+
 class State:
     x = 0.0
     y = 0.0
@@ -23,17 +30,17 @@ class State:
     def __init__(self, x, y, theta, v, vehicle_mode: VehicleMode, lane_mode: LaneMode):
         self.data = []
 
-def controller(ego:State, other:State, lane_map):
+
+def controller(ego: State, other: State, lane_map: LaneMap):
     output = copy.deepcopy(ego)
     if ego.vehicle_mode == VehicleMode.Normal:
-        if lane_map.get_longitudinal_position(other.lane_mode, [other.x,other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x,ego.y]) > 0 \
-        and lane_map.get_longitudinal_position(other.lane_mode, [other.x,other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x,ego.y]) < 3 \
-        and ego.lane_mode == other.lane_mode:
+        if lane_map.get_longitudinal_position(other.lane_mode, [other.x, other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x, ego.y]) > 0 \
+            and lane_map.get_longitudinal_position(other.lane_mode, [other.x, other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x, ego.y]) < 3 \
+                and ego.lane_mode == other.lane_mode:
             output.vehicle_mode = VehicleMode.Brake
     elif ego.vehicle_mode == VehicleMode.Brake:
-        if lane_map.get_longitudinal_position(other.lane_mode, [other.x,other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x,ego.y]) > 10 \
-            or ego.lane_mode != other.lane_mode:
+        if lane_map.get_longitudinal_position(other.lane_mode, [other.x, other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x, ego.y]) > 3 \
+                or ego.lane_mode != other.lane_mode:
             output.vehicle_mode = VehicleMode.Normal
 
     return output
-    
diff --git a/demo/example_controller10.py b/demo/example_controller10.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1272a318ea96b5f99f044d957281047335abe11
--- /dev/null
+++ b/demo/example_controller10.py
@@ -0,0 +1,77 @@
+from enum import Enum, auto
+import copy
+from dryvr_plus_plus.scene_verifier.map.lane_map import LaneMap
+
+
+class VehicleMode(Enum):
+    Normal = auto()
+    SwitchLeft = auto()
+    SwitchRight = auto()
+    Brake = auto()
+    Accelerate = auto()
+
+
+class LaneMode(Enum):
+    Lane0 = auto()
+    Lane1 = auto()
+    Lane2 = auto()
+
+
+class LaneObjectMode(Enum):
+    Vehicle = auto()
+    Ped = auto()        # Pedestrians
+    Sign = auto()       # Signs, stop signs, merge, yield etc.
+    Signal = auto()     # Traffic lights
+    Obstacle = auto()   # Static (to road/lane) obstacles
+
+
+class State:
+    x = 0.0
+    y = 0.0
+    theta = 0.0
+    v = 0.0
+    vehicle_mode: VehicleMode = VehicleMode.Normal
+    lane_mode: LaneMode = LaneMode.Lane0
+    type: LaneObjectMode
+
+    def __init__(self, x, y, theta, v, vehicle_mode: VehicleMode, lane_mode: LaneMode, type: LaneObjectMode):
+        self.data = []
+
+
+def controller(ego: State, other: State, sign: State, lane_map: LaneMap):
+    output = copy.deepcopy(ego)
+    if ego.vehicle_mode == VehicleMode.Normal:
+        if lane_map.get_longitudinal_position(other.lane_mode, [other.x, other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x, ego.y]) > 3 \
+            and lane_map.get_longitudinal_position(other.lane_mode, [other.x, other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x, ego.y]) < 5 \
+                and ego.lane_mode == other.lane_mode:
+            if lane_map.has_left(ego.lane_mode):
+                output.vehicle_mode = VehicleMode.SwitchLeft
+        if lane_map.get_longitudinal_position(other.lane_mode, [other.x, other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x, ego.y]) > 3 \
+            and lane_map.get_longitudinal_position(other.lane_mode, [other.x, other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x, ego.y]) < 5 \
+                and ego.lane_mode == other.lane_mode:
+            if lane_map.has_right(ego.lane_mode):
+                output.vehicle_mode = VehicleMode.SwitchRight
+        # if ego.lane_mode != other.lane_mode:
+        #     output.vehicle_mode = VehicleMode.Accelerate
+        # if lane_map.get_longitudinal_position(other.lane_mode, [other.x, other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x, ego.y]) > 5:
+        #     output.vehicle_mode = VehicleMode.Accelerate
+        if lane_map.get_speed_limit(ego.lane_mode) > 1:
+            output.vehicle_mode = VehicleMode.Accelerate
+
+    if ego.vehicle_mode == VehicleMode.SwitchLeft:
+        if lane_map.get_lateral_distance(ego.lane_mode, [ego.x, ego.y]) >= 2.5:
+            output.vehicle_mode = VehicleMode.Normal
+            output.lane_mode = lane_map.left_lane(ego.lane_mode)
+    if ego.vehicle_mode == VehicleMode.SwitchRight:
+        if lane_map.get_lateral_distance(ego.lane_mode, [ego.x, ego.y]) <= -2.5:
+            output.vehicle_mode = VehicleMode.Normal
+            output.lane_mode = lane_map.right_lane(ego.lane_mode)
+    if ego.vehicle_mode == VehicleMode.Accelerate:
+        if lane_map.get_longitudinal_position(other.lane_mode, [other.x, other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x, ego.y]) > 3 \
+            and lane_map.get_longitudinal_position(other.lane_mode, [other.x, other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x, ego.y]) < 5 \
+                and ego.lane_mode == other.lane_mode:
+            output.vehicle_mode = VehicleMode.Normal
+        # if lane_map.get_speed_limit(ego.lane_mode, [ego.x, ego.y]) <= ego.v:
+        #     output.vehicle_mode = VehicleMode.Normal
+
+    return output
diff --git a/demo/example_two_car_sign_lane_switch.py b/demo/example_two_car_sign_lane_switch.py
index fcfaf9235f447795527bbcc46b3923130295e149..64524b310cb1b90615dae80d0b5878298889e15f 100644
--- a/demo/example_two_car_sign_lane_switch.py
+++ b/demo/example_two_car_sign_lane_switch.py
@@ -1,6 +1,14 @@
+import matplotlib.pyplot as plt
+from dryvr_plus_plus.example.example_sensor.fake_sensor import FakeSensor2
+from dryvr_plus_plus.plotter.plotter2D import plot_reachtube_tree, plot_simulation_tree, generate_simulation_anime, plot_map
+from dryvr_plus_plus.example.example_map.simple_map2 import SimpleMap3
+from dryvr_plus_plus.scene_verifier.scenario.scenario import Scenario
+from dryvr_plus_plus.example.example_agent.sign_agent import SignAgent
+from dryvr_plus_plus.example.example_agent.car_agent import CarAgent
 from enum import Enum, auto
 import copy
 
+
 class LaneObjectMode(Enum):
     Vehicle = auto()
     Ped = auto()        # Pedestrians
@@ -8,6 +16,7 @@ class LaneObjectMode(Enum):
     Signal = auto()     # Traffic lights
     Obstacle = auto()   # Static (to road/lane) obstacles
 
+
 class VehicleMode(Enum):
     Normal = auto()
     SwitchLeft = auto()
@@ -20,6 +29,7 @@ class LaneMode(Enum):
     Lane1 = auto()
     Lane2 = auto()
 
+
 class State:
     x: float
     y: float
@@ -40,24 +50,25 @@ class State:
         # self.lane_mode = lane_mode
         # self.obj_mode = obj_mode
 
+
 def controller(ego: State, other: State, sign: State, lane_map):
     output = copy.deepcopy(ego)
     if ego.vehicle_mode == VehicleMode.Normal:
         if sign.type == LaneObjectMode.Obstacle and sign.x - ego.x < 3 and sign.x - ego.x > 0 and ego.lane_mode == sign.lane_mode:
             output.vehicle_mode = VehicleMode.SwitchLeft
             return output
-        if lane_map.get_longitudinal_position(other.lane_mode, [other.x,other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x,ego.y]) > 3 \
-        and lane_map.get_longitudinal_position(other.lane_mode, [other.x,other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x,ego.y]) < 5 \
-        and ego.lane_mode == other.lane_mode:
+        if lane_map.get_longitudinal_position(other.lane_mode, [other.x, other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x, ego.y]) > 3 \
+            and lane_map.get_longitudinal_position(other.lane_mode, [other.x, other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x, ego.y]) < 5 \
+                and ego.lane_mode == other.lane_mode:
             if lane_map.has_left(ego.lane_mode):
                 output.vehicle_mode = VehicleMode.SwitchLeft
-        if lane_map.get_longitudinal_position(other.lane_mode, [other.x,other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x,ego.y]) > 3 \
-        and lane_map.get_longitudinal_position(other.lane_mode, [other.x,other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x,ego.y]) < 5 \
-        and ego.lane_mode == other.lane_mode:
+        if lane_map.get_longitudinal_position(other.lane_mode, [other.x, other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x, ego.y]) > 3 \
+            and lane_map.get_longitudinal_position(other.lane_mode, [other.x, other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x, ego.y]) < 5 \
+                and ego.lane_mode == other.lane_mode:
             if lane_map.has_right(ego.lane_mode):
                 output.vehicle_mode = VehicleMode.SwitchRight
     if ego.vehicle_mode == VehicleMode.SwitchLeft:
-        if  lane_map.get_lateral_distance(ego.lane_mode, [ego.x, ego.y]) >= 2.5:
+        if lane_map.get_lateral_distance(ego.lane_mode, [ego.x, ego.y]) >= 2.5:
             output.vehicle_mode = VehicleMode.Normal
             output.lane_mode = lane_map.left_lane(ego.lane_mode)
     if ego.vehicle_mode == VehicleMode.SwitchRight:
@@ -68,15 +79,6 @@ def controller(ego: State, other: State, sign: State, lane_map):
     return output
 
 
-from dryvr_plus_plus.example.example_agent.car_agent import CarAgent
-from dryvr_plus_plus.example.example_agent.sign_agent import SignAgent
-from dryvr_plus_plus.scene_verifier.scenario.scenario import Scenario
-from dryvr_plus_plus.example.example_map.simple_map2 import SimpleMap3
-from dryvr_plus_plus.plotter.plotter2D import plot_reachtube_tree, plot_simulation_tree
-from dryvr_plus_plus.example.example_sensor.fake_sensor import FakeSensor2
-
-import matplotlib.pyplot as plt
-
 if __name__ == "__main__":
     import sys
     input_code_name = sys.argv[0]
@@ -91,9 +93,9 @@ if __name__ == "__main__":
     scenario.set_sensor(FakeSensor2())
     scenario.set_init(
         [
-            [[0, -0.2, 0, 1.0],[0.2, 0.2, 0, 1.0]],
-            [[10, 0, 0, 0.5],[10, 0, 0, 0.5]], 
-            [[20, 0, 0, 0],[20, 0, 0, 0]],
+            [[0, -0.2, 0, 1.0], [0.2, 0.2, 0, 1.0]],
+            [[10, 0, 0, 0.5], [10, 0, 0, 0.5]],
+            [[20, 0, 0, 0], [20, 0, 0, 0]],
         ],
         [
             (VehicleMode.Normal, LaneMode.Lane1, LaneObjectMode.Vehicle),
@@ -106,8 +108,11 @@ if __name__ == "__main__":
     # traces = scenario.verify(40)
 
     fig = plt.figure()
-    fig = plot_simulation_tree(traces, 'car1', 1, [2], 'b', fig)
-    fig = plot_simulation_tree(traces, 'car2', 1, [2], 'r', fig)
+    fig = plot_map(SimpleMap3(), 'g', fig)
+    # fig = plot_simulation_tree(traces, 'car1', 1, [2], 'b', fig)
+    # fig = plot_simulation_tree(traces, 'car2', 1, [2], 'r', fig)
+    fig = plot_reachtube_tree(traces, 'car1', 1, [2], 'b', fig)
+    fig = plot_reachtube_tree(traces, 'car2', 1, [2], 'r', fig)
 
+    # generate_simulation_anime(traces, SimpleMap3(), fig)
     plt.show()
-
diff --git a/demo/plot2.py b/demo/plot2.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c8a74c6b000de18fc232bed841ea8b0c47a55c8
--- /dev/null
+++ b/demo/plot2.py
@@ -0,0 +1,133 @@
+import plotly.graph_objects as go
+
+import pandas as pd
+
+url = "https://raw.githubusercontent.com/plotly/datasets/master/gapminderDataFiveYear.csv"
+dataset = pd.read_csv(url)
+
+years = ["1952", "1962", "1967", "1972", "1977", "1982", "1987", "1992", "1997", "2002",
+         "2007"]
+
+# make list of continents
+continents = []
+for continent in dataset["continent"]:
+    if continent not in continents:
+        continents.append(continent)
+# make figure
+fig_dict = {
+    "data": [],
+    "layout": {},
+    "frames": []
+}
+
+# fill in most of layout
+fig_dict["layout"]["xaxis"] = {"range": [30, 85], "title": "Life Expectancy"}
+fig_dict["layout"]["yaxis"] = {"title": "GDP per Capita", "type": "log"}
+fig_dict["layout"]["hovermode"] = "closest"
+fig_dict["layout"]["updatemenus"] = [
+    {
+        "buttons": [
+            {
+                "args": [None, {"frame": {"duration": 500, "redraw": False},
+                                "fromcurrent": True, "transition": {"duration": 300,
+                                                                    "easing": "quadratic-in-out"}}],
+                "label": "Play",
+                "method": "animate"
+            },
+            {
+                "args": [[None], {"frame": {"duration": 0, "redraw": False},
+                                  "mode": "immediate",
+                                  "transition": {"duration": 0}}],
+                "label": "Pause",
+                "method": "animate"
+            }
+        ],
+        "direction": "left",
+        "pad": {"r": 10, "t": 87},
+        "showactive": False,
+        "type": "buttons",
+        "x": 0.1,
+        "xanchor": "right",
+        "y": 0,
+        "yanchor": "top"
+    }
+]
+
+sliders_dict = {
+    "active": 0,
+    "yanchor": "top",
+    "xanchor": "left",
+    "currentvalue": {
+        "font": {"size": 20},
+        "prefix": "Year:",
+        "visible": False,
+        "xanchor": "right"
+    },
+    "transition": {"duration": 300, "easing": "cubic-in-out"},
+    "pad": {"b": 10, "t": 50},
+    "len": 0.9,
+    "x": 0.1,
+    "y": 0,
+    "steps": []
+}
+
+# make data
+year = 1952
+for continent in continents:
+    dataset_by_year = dataset[dataset["year"] == year]
+    dataset_by_year_and_cont = dataset_by_year[
+        dataset_by_year["continent"] == continent]
+
+    data_dict = {
+        "x": list(dataset_by_year_and_cont["lifeExp"]),
+        "y": list(dataset_by_year_and_cont["gdpPercap"]),
+        "mode": "lines",
+        "text": list(dataset_by_year_and_cont["country"]),
+        "marker": {
+            "sizemode": "area",
+            "sizeref": 200000,
+            "size": list(dataset_by_year_and_cont["pop"])
+        },
+        "name": continent
+    }
+    fig_dict["data"].append(data_dict)
+
+# make frames
+for year in years:
+    frame = {"data": [], "name": str(year)}
+    for continent in continents:
+        dataset_by_year = dataset[dataset["year"] == int(year)]
+        dataset_by_year_and_cont = dataset_by_year[
+            dataset_by_year["continent"] == continent]
+
+        data_dict = {
+            "x": list(dataset_by_year_and_cont["lifeExp"]),
+            "y": list(dataset_by_year_and_cont["gdpPercap"]),
+            "mode": "lines",
+            "text": list(dataset_by_year_and_cont["country"]),
+            "marker": {
+                "sizemode": "area",
+                "sizeref": 200000,
+                "size": list(dataset_by_year_and_cont["pop"])
+            },
+            "name": continent
+        }
+        frame["data"].append(data_dict)
+
+    fig_dict["frames"].append(frame)
+    slider_step = {"args": [
+        [year],
+        {"frame": {"duration": 300, "redraw": False},
+         "mode": "immediate",
+         "transition": {"duration": 300}}
+    ],
+        "label": year,
+        "method": "animate"}
+    sliders_dict["steps"].append(slider_step)
+
+
+fig_dict["layout"]["sliders"] = [sliders_dict]
+
+fig = go.Figure(fig_dict)
+
+fig.show()
diff --git a/demo/plot_test.py b/demo/plot_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..83bedd06d50750323deaa0716b8113b3bdd10023
--- /dev/null
+++ b/demo/plot_test.py
@@ -0,0 +1,164 @@
+from turtle import color
+import plotly.graph_objects as go
+import numpy as np
+
+
+x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+x_rev = x[::-1]
+x2 = [11, 12, 13, 14, 15, 16, 17, 18, 19, 20]
+x2_rev = x2[::-1]
+
+
+# Line 1
+y1 = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+y1_upper = [2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
+y1_lower = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
+y1_lower = y1_lower[::-1]
+
+# Line 2
+y2 = [5, 2.5, 5, 7.5, 5, 2.5, 7.5, 4.5, 5.5, 5]
+y2_upper = [5.5, 3, 5.5, 8, 6, 3, 8, 5, 6, 5.5]
+y2_lower = [4.5, 2, 4.4, 7, 4, 2, 7, 4, 5, 4.75]
+y2_lower = y2_lower[::-1]
+
+# Line 3
+y3 = [10, 8, 6, 4, 2, 0, 2, 4, 2, 0]
+y3_upper = [11, 9, 7, 5, 3, 1, 3, 5, 3, 1]
+y3_lower = [9, 7, 5, 3, 1, -.5, 1, 3, 1, -1]
+y3_lower = y3_lower[::-1]
+
+
+fig = go.Figure()
+
+# fig.add_trace(go.Scatter(
+#     x=x+x_rev + x2+x2_rev,
+#     y=y1_upper+y1_lower+y1_upper+y1_lower,
+#     fill='toself',
+#     marker=dict(
+#         symbol='square',
+#         size=16,
+#         cmax=39,
+#         cmin=0,
+#         color='rgb(0,100,80)',
+#         colorbar=dict(
+#             title="Colorbar"
+#         ),
+#         colorscale="Viridis"
+#     ),
+#     # fillcolor='rgba(0,100,80,0.2)',
+#     # line_color='rgba(255,255,255,0)',
+#     showlegend=False,
+#     name='Fair',
+#     mode="markers"
+# ))
+# fig.add_trace(go.Scatter(
+#     x=[1, 2, 2, 1,  2, 3, 3, 2],
+#     y=[1, 1, 2, 2,  3, 3, 4, 4],
+#     fill='toself',
+#     marker=dict(
+#         symbol='square',
+#         size=16,
+#         cmax=3,
+#         cmin=0,
+#         color=[1, 2],
+#         colorbar=dict(
+#             title="Colorbar"
+#         ),
+#         colorscale="Viridis"
+#     ),
+#     # fillcolor='rgba(0,100,80,0.2)',
+#     # line_color='rgba(255,255,255,0)',
+#     showlegend=False,
+#     name='Fair',
+#     mode="markers"
+# ))
+start = [0, 0, 255, 0.5]
+end = [255, 0, 0, 0.5]
+rgb = [0, 0, 255, 0.5]
+pot = 4
+for i in range(len(rgb)-1):
+    rgb[i] = rgb[i] + (pot-0)/(4-0)*(end[i]-start[i])
+print(rgb)
+fig.add_trace(go.Scatter(
+    x=[1, 2, 2, 1],
+    y=[1, 1, 2, 2],
+    fill='toself',
+    marker=dict(
+        symbol='square',
+        size=16,
+        cmax=4,
+        cmin=0,
+        # color=[3, 3, 3, 3],
+        colorbar=dict(
+            title="Colorbar"
+        ),
+        colorscale=[[0, 'rgba(0,0,255,0.5)'], [1, 'rgba(255,0,0,0.5)']]
+    ),
+
+    fillcolor='rgba'+str(tuple(rgb)),
+    # fillcolor='rgba(0,100,80,0.2)',
+    # line_color='rgba(255,255,255,0)',
+    showlegend=False,
+    name='Fair',
+    mode="markers"
+))
+# for i in range(10):
+#     fig.add_trace(go.Scatter(
+#         x=x+x_rev,
+#         y=y3_upper+y3_lower,
+#         fill='toself',
+#         marker=dict(
+#             symbol='square',
+#             size=16,
+#             cmax=39,
+#             cmin=0,
+#             color='rgb(0,100,80)',
+#             colorbar=dict(
+#                 title="Colorbar"
+#             ),
+#             colorscale="Viridis"
+#         ),
+#         # fillcolor='rgba(0,100,80,0.2)',
+#         # line_color='rgba(255,255,255,0)',
+#         showlegend=False,
+#         name='Fair',
+#         mode="markers"
+#     ))
+# fig.add_trace(go.Scatter(
+#     x=x+x_rev,
+#     y=y2_upper+y2_lower,
+#     fill='toself',
+#     fillcolor='rgba(0,176,246,0.2)',
+#     line_color='rgba(255,255,255,0)',
+#     name='Premium',
+#     showlegend=False,
+# ))
+# fig.add_trace(go.Scatter(
+#     x=x+x_rev,
+#     y=y3_upper+y3_lower,
+#     fill='toself',
+#     fillcolor='rgba(231,107,243,0.2)',
+#     line_color='rgba(255,255,255,0)',
+#     showlegend=False,
+#     name='Ideal',
+# ))
+# fig.add_trace(go.Scatter(
+#     x=x, y=y1,
+#     line_color='rgb(0,100,80)',
+#     name='Fair',
+# ))
+# fig.add_trace(go.Scatter(
+#     x=x, y=y2,
+#     line_color='rgb(0,176,246)',
+#     name='Premium',
+# ))
+# fig.add_trace(go.Scatter(
+#     x=x, y=y3,
+#     line_color='rgb(231,107,243)',
+#     name='Ideal',
+# ))
+
+# fig.update_traces(mode='lines')
+fig.show()
+# print(x+x_rev)
+# print(y1_upper+y1_lower)
diff --git a/demo/plot_test1.py b/demo/plot_test1.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ba2de2619df0f8ab8f132f866a584f737bd6822
--- /dev/null
+++ b/demo/plot_test1.py
@@ -0,0 +1,23 @@
+import plotly.graph_objects as go
+import numpy as np
+
+fig = go.Figure()
+values = [1, 2, 3]
+fig.add_trace(go.Scatter(
+    x=values,
+    y=values,
+    marker=dict(
+        symbol='square',
+        size=16,
+        cmax=39,
+        cmin=0,
+        color=values,
+        colorbar=dict(
+            title="Colorbar"
+        ),
+        colorscale="Viridis"
+    ),
+    # marker_symbol='square', marker_line_color="midnightblue", marker_color="lightskyblue",
+    # marker_line_width=2, marker_size=15,
+    mode="markers"))
+fig.show()
diff --git a/dist/dryvr_plus_plus-0.1-py3.8.egg b/dist/dryvr_plus_plus-0.1-py3.8.egg
new file mode 100644
index 0000000000000000000000000000000000000000..cdd7c8753229a16431604862518b4f6f9d13a03d
Binary files /dev/null and b/dist/dryvr_plus_plus-0.1-py3.8.egg differ
diff --git a/dryvr_plus_plus/example/example_agent/car_agent.py b/dryvr_plus_plus/example/example_agent/car_agent.py
index dde521e454e420d0e9a541a579c0b4d40592f101..2470e10801588d30b4756fdae91d2126704f3ab6 100644
--- a/dryvr_plus_plus/example/example_agent/car_agent.py
+++ b/dryvr_plus_plus/example/example_agent/car_agent.py
@@ -1,13 +1,14 @@
 # Example agent.
 from typing import Tuple, List
 
-import numpy as np 
+import numpy as np
 from scipy.integrate import ode
 
 from dryvr_plus_plus.scene_verifier.agents.base_agent import BaseAgent
 from dryvr_plus_plus.scene_verifier.map.lane_map import LaneMap
 from dryvr_plus_plus.scene_verifier.code_parser.pythonparser import EmptyAst
 
+
 class NPCAgent(BaseAgent):
     def __init__(self, id):
         self.id = id
@@ -16,99 +17,110 @@ class NPCAgent(BaseAgent):
     @staticmethod
     def dynamic(t, state, u):
         x, y, theta, v = state
-        delta, a = u  
+        delta, a = u
         x_dot = v*np.cos(theta+delta)
         y_dot = v*np.sin(theta+delta)
         theta_dot = v/1.75*np.sin(delta)
-        v_dot = a 
+        v_dot = a
         return [x_dot, y_dot, theta_dot, v_dot]
 
-    def action_handler(self, mode, state, lane_map:LaneMap)->Tuple[float, float]:
+
+    def action_handler(self, mode, state, lane_map: LaneMap) -> Tuple[float, float]:
         ''' Computes steering and acceleration based on current lane, target lane and
             current state using a Stanley controller-like rule'''
-        x,y,theta,v = state
+        x, y, theta, v = state
         vehicle_mode = mode[0]
         vehicle_lane = mode[1]
-        vehicle_pos = np.array([x,y])
+        vehicle_pos = np.array([x, y])
         d = -lane_map.get_lateral_distance(vehicle_lane, vehicle_pos)
         psi = lane_map.get_lane_heading(vehicle_lane, vehicle_pos)-theta
         steering = psi + np.arctan2(0.45*d, v)
         steering = np.clip(steering, -0.61, 0.61)
         a = 0
-        return steering, a  
+        return steering, a
+
 
     def TC_simulate(self, mode: List[str], initialCondition, time_bound, time_step, lane_map:LaneMap=None)->np.ndarray:
         time_bound = float(time_bound)
         number_points = int(np.ceil(time_bound/time_step))
-        t = [i*time_step for i in range(0,number_points)]
+        t = [i*time_step for i in range(0, number_points)]
 
         init = initialCondition
         trace = [[0]+init]
         for i in range(len(t)):
             steering, a = self.action_handler(mode, init, lane_map)
-            r = ode(self.dynamic)    
-            r.set_initial_value(init).set_f_params([steering, a])      
-            res:np.ndarray = r.integrate(r.t + time_step)
+            r = ode(self.dynamic)
+            r.set_initial_value(init).set_f_params([steering, a])
+            res: np.ndarray = r.integrate(r.t + time_step)
             init = res.flatten().tolist()
-            trace.append([t[i] + time_step] + init) 
+            trace.append([t[i] + time_step] + init)
 
         return np.array(trace)
 
+
 class CarAgent(BaseAgent):
-    def __init__(self, id, code = None, file_name = None):
+    def __init__(self, id, code=None, file_name=None):
         super().__init__(id, code, file_name)
 
     @staticmethod
     def dynamic(t, state, u):
         x, y, theta, v = state
-        delta, a = u  
+        delta, a = u
         x_dot = v*np.cos(theta+delta)
         y_dot = v*np.sin(theta+delta)
         theta_dot = v/1.75*np.sin(delta)
-        v_dot = a 
+        v_dot = a
         return [x_dot, y_dot, theta_dot, v_dot]
 
-    def action_handler(self, mode: List[str], state, lane_map:LaneMap)->Tuple[float, float]:
-        x,y,theta,v = state
+    def action_handler(self, mode: List[str], state, lane_map: LaneMap) -> Tuple[float, float]:
+        x, y, theta, v = state
         vehicle_mode = mode[0]
         vehicle_lane = mode[1]
-        vehicle_pos = np.array([x,y])
+        vehicle_pos = np.array([x, y])
         a = 0
         if vehicle_mode == "Normal":
             d = -lane_map.get_lateral_distance(vehicle_lane, vehicle_pos)
+            # # keyi: just toy mod
+            # if v <= 2:
+            #     a = 0.2
         elif vehicle_mode == "SwitchLeft":
             d = -lane_map.get_lateral_distance(vehicle_lane, vehicle_pos) + 3
         elif vehicle_mode == "SwitchRight":
             d = -lane_map.get_lateral_distance(vehicle_lane, vehicle_pos) - 3
         elif vehicle_mode == "Brake":
             d = -lane_map.get_lateral_distance(vehicle_lane, vehicle_pos)
-            a = -1    
-        elif vehicle_mode == "Accel":
+            a = -1
+            if v <= 0.02:
+                a = 0
+        elif vehicle_mode == "Accelerate":
             d = -lane_map.get_lateral_distance(vehicle_lane, vehicle_pos)
             a = 1
+            if v >= lane_map.get_speed_limit(vehicle_lane)-0.02:
+                a = 0
         elif vehicle_mode == 'Stop':
             d = -lane_map.get_lateral_distance(vehicle_lane, vehicle_pos)
             a = 0
         psi = lane_map.get_lane_heading(vehicle_lane, vehicle_pos)-theta
         steering = psi + np.arctan2(0.45*d, v)
         steering = np.clip(steering, -0.61, 0.61)
-        return steering, a  
+        return steering, a
+
 
     def TC_simulate(self, mode: List[str], initialCondition, time_bound, time_step, lane_map:LaneMap=None)->np.ndarray:
         time_bound = float(time_bound)
         number_points = int(np.ceil(time_bound/time_step))
         t = [round(i*time_step,10) for i in range(0,number_points)]
-
         init = initialCondition
+        # [time, x, y, theta, v]
         trace = [[0]+init]
         for i in range(len(t)):
             steering, a = self.action_handler(mode, init, lane_map)
-            r = ode(self.dynamic)    
-            r.set_initial_value(init).set_f_params([steering, a])      
-            res:np.ndarray = r.integrate(r.t + time_step)
+            r = ode(self.dynamic)
+            r.set_initial_value(init).set_f_params([steering, a])
+            res: np.ndarray = r.integrate(r.t + time_step)
             init = res.flatten().tolist()
             if init[3] < 0:
                 init[3] = 0
-            trace.append([t[i] + time_step] + init) 
+            trace.append([t[i] + time_step] + init)
 
         return np.array(trace)
diff --git a/dryvr_plus_plus/example/example_map/simple_map2.py b/dryvr_plus_plus/example/example_map/simple_map2.py
index 1e249851787d6dbd2368eb2427f189ec99e9ac38..257b2308188045131c5f5d5736e10e7be05dc60f 100644
--- a/dryvr_plus_plus/example/example_map/simple_map2.py
+++ b/dryvr_plus_plus/example/example_map/simple_map2.py
@@ -4,18 +4,20 @@ from dryvr_plus_plus.scene_verifier.map.lane import Lane
 
 import numpy as np
 
+
 class SimpleMap2(LaneMap):
     def __init__(self):
         super().__init__()
         segment0 = StraightLane(
             'seg0',
-            [0,0],
-            [100,0],
+            [0, 0],
+            [100, 0],
             3
         )
         lane0 = Lane('Lane1', [segment0])
         self.add_lanes([lane0])
 
+
 class SimpleMap3(LaneMap):
     def __init__(self):
         super().__init__()
@@ -48,6 +50,41 @@ class SimpleMap3(LaneMap):
         self.right_lane_dict[lane0.id].append(lane1.id)
         self.right_lane_dict[lane1.id].append(lane2.id)
 
+class SimpleMap3_v2(LaneMap):
+    def __init__(self):
+        super().__init__()
+        segment0 = StraightLane(
+            'Seg0',
+            [0, 3],
+            [50, 3],
+            3,
+            speed_limit=[(0, 1), (10, 2)]
+        )
+        lane0 = Lane('Lane0', [segment0], speed_limit=2)
+        segment1 = StraightLane(
+            'seg0',
+            [0, 0],
+            [50, 0],
+            3,
+            speed_limit=[(0, 1), (20, 3)]
+        )
+        lane1 = Lane('Lane1', [segment1], speed_limit=1)
+        segment2 = StraightLane(
+            'seg0',
+            [0, -3],
+            [50, -3],
+            3,
+            speed_limit=[(0, 1), (25, 2.5)]
+        )
+        lane2 = Lane('Lane2', [segment2], speed_limit=3)
+        # segment2 = LaneSegment('Lane1', 3)
+        # self.add_lanes([segment1,segment2])
+        self.add_lanes([lane0, lane1, lane2])
+        self.left_lane_dict[lane1.id].append(lane0.id)
+        self.left_lane_dict[lane2.id].append(lane1.id)
+        self.right_lane_dict[lane0.id].append(lane1.id)
+        self.right_lane_dict[lane1.id].append(lane2.id)
+
 class SimpleMap4(LaneMap):
     def __init__(self):
         super().__init__()
@@ -104,58 +141,58 @@ class SimpleMap5(LaneMap):
         super().__init__()
         segment0 = StraightLane(
             'Seg0',
-            [0,3],
-            [15,3],
+            [0, 3],
+            [15, 3],
             3
         )
         segment1 = StraightLane(
             'Seg1',
-            [15,3], 
-            [25,13],
+            [15, 3],
+            [25, 13],
             3
         )
         segment2 = StraightLane(
             'Seg2',
-            [25,13], 
-            [50,13],
+            [25, 13],
+            [50, 13],
             3
         )
         lane0 = Lane('Lane0', [segment0, segment1, segment2])
         segment0 = StraightLane(
             'seg0',
-            [0,0],
-            [17,0],
+            [0, 0],
+            [17, 0],
             3
         )
         segment1 = StraightLane(
             'seg1',
-            [17,0],
-            [27,10],
+            [17, 0],
+            [27, 10],
             3
         )
         segment2 = StraightLane(
             'seg2',
-            [27,10],
-            [50,10],
+            [27, 10],
+            [50, 10],
             3
         )
         lane1 = Lane('Lane1', [segment0, segment1, segment2])
         segment0 = StraightLane(
             'seg0',
-            [0,-3],
-            [19,-3],
+            [0, -3],
+            [19, -3],
             3
         )
         segment1 = StraightLane(
             'seg1',
-            [19,-3],
-            [29,7],
+            [19, -3],
+            [29, 7],
             3
         )
         segment2 = StraightLane(
             'seg2',
-            [29,7],
-            [50,7],
+            [29, 7],
+            [50, 7],
             3
         )
         lane2 = Lane('Lane2', [segment0, segment1, segment2])
@@ -165,18 +202,19 @@ class SimpleMap5(LaneMap):
         self.right_lane_dict[lane0.id].append(lane1.id)
         self.right_lane_dict[lane1.id].append(lane2.id)
 
+
 class SimpleMap6(LaneMap):
     def __init__(self):
         super().__init__()
         segment0 = StraightLane(
             'Seg0',
-            [0,3],
-            [15,3],
+            [0, 3],
+            [15, 3],
             3
         )
         segment1 = CircularLane(
             'Seg1',
-            [15,8],
+            [15, 8],
             5,
             np.pi*3/2,
             np.pi*2,
@@ -185,20 +223,20 @@ class SimpleMap6(LaneMap):
         )
         segment2 = StraightLane(
             'Seg2',
-            [20,8], 
-            [20,30],
+            [20, 8],
+            [20, 30],
             3
         )
         lane0 = Lane('Lane0', [segment0, segment1, segment2])
         segment0 = StraightLane(
             'seg0',
-            [0,0],
-            [18,0],
+            [0, 0],
+            [18, 0],
             3
         )
         segment1 = CircularLane(
             'seg1',
-            [18,5],
+            [18, 5],
             5,
             3*np.pi/2,
             2*np.pi,
@@ -207,20 +245,20 @@ class SimpleMap6(LaneMap):
         )
         segment2 = StraightLane(
             'seg2',
-            [23,5],
-            [23,30],
+            [23, 5],
+            [23, 30],
             3
         )
         lane1 = Lane('Lane1', [segment0, segment1, segment2])
         segment0 = StraightLane(
             'seg0',
-            [0,-3],
-            [21,-3],
+            [0, -3],
+            [21, -3],
             3
         )
         segment1 = CircularLane(
             'seg1',
-            [21,2],
+            [21, 2],
             5,
             np.pi*3/2,
             np.pi*2,
@@ -229,8 +267,8 @@ class SimpleMap6(LaneMap):
         )
         segment2 = StraightLane(
             'seg2',
-            [26,2],
-            [26,30],
+            [26, 2],
+            [26, 30],
             3
         )
         lane2 = Lane('Lane2', [segment0, segment1, segment2])
@@ -240,6 +278,7 @@ class SimpleMap6(LaneMap):
         self.right_lane_dict[lane0.id].append(lane1.id)
         self.right_lane_dict[lane1.id].append(lane2.id)
 
+
 if __name__ == "__main__":
     test_map = SimpleMap3()
     print(test_map.left_lane_dict)
diff --git a/dryvr_plus_plus/plotter/plotter2D.py b/dryvr_plus_plus/plotter/plotter2D.py
index d72a2b35f60325e9ffbf856aba90ecd7aa980c5f..8dc9f5e644c22a97adee89d0cc1902050257660d 100644
--- a/dryvr_plus_plus/plotter/plotter2D.py
+++ b/dryvr_plus_plus/plotter/plotter2D.py
@@ -2,6 +2,10 @@
 This file consist main plotter code for DryVR reachtube output
 """
 
+from __future__ import annotations
+from audioop import reverse
+# from curses import start_color
+from re import A
 import matplotlib.patches as patches
 import matplotlib.pyplot as plt
 import numpy as np
@@ -10,9 +14,50 @@ import plotly.graph_objects as go
 from typing import List
 from PIL import Image, ImageDraw
 import io
+import copy
+import operator
+from collections import OrderedDict
+
+from torch import layout
+from dryvr_plus_plus.scene_verifier.analysis.analysis_tree_node import AnalysisTreeNode
 
 colors = ['red', 'green', 'blue', 'yellow', 'black']
 
+def plotly_plot(data,
+                x_dim: int = 0,
+                y_dim_list: List[int] = [1],
+                color='blue',
+                fig=None,
+                x_lim=None,
+                y_lim=None
+                ):
+    if fig is None:
+        fig = plt.figure()
+
+    x_min, x_max = float('inf'), -float('inf')
+    y_min, y_max = float('inf'), -float('inf')
+    for rect in data:
+        lb = rect[0]
+        ub = rect[1]
+        for y_dim in y_dim_list:
+            fig.add_shape(type="rect",
+                          x0=lb[x_dim], y0=lb[y_dim], x1=ub[x_dim], y1=ub[y_dim],
+                          line=dict(color=color),
+                          fillcolor=color
+                          )
+            # rect_patch = patches.Rectangle(
+            #     (lb[x_dim], lb[y_dim]), ub[x_dim]-lb[x_dim], ub[y_dim]-lb[y_dim], color=color)
+            # ax.add_patch(rect_patch)
+            x_min = min(lb[x_dim], x_min)
+            y_min = min(lb[y_dim], y_min)
+            x_max = max(ub[x_dim], x_max)
+            y_max = max(ub[y_dim], y_max)
+    fig.update_shapes(dict(xref='x', yref='y'))
+    # ax.set_xlim([x_min-1, x_max+1])
+    # ax.set_ylim([y_min-1, y_max+1])
+    # fig.update_xaxes(range=[x_min-1, x_max+1], showgrid=False)
+    # fig.update_yaxes(range=[y_min-1, y_max+1])
+    return fig, (x_min, x_max), (y_min, y_max)
 
 def plot(
     data,
@@ -51,6 +96,448 @@ def plot(
     return fig, (x_min, x_max), (y_min, y_max)
 
 
+def generate_reachtube_anime(root, map=None, fig=None):
+    # make figure
+    fig_dict = {
+        "data": [],
+        "layout": {},
+        "frames": []
+    }
+    # fig = plot_map(map, 'g', fig)
+    timed_point_dict = {}
+    stack = [root]
+    x_min, x_max = float('inf'), -float('inf')
+    y_min, y_max = float('inf'), -float('inf')
+    print("reachtude")
+    end_time = 0
+    while stack != []:
+        node = stack.pop()
+        traces = node.trace
+        for agent_id in traces:
+            trace = np.array(traces[agent_id])
+            if trace[0][0] > 0:
+                trace = trace[4:]
+            # print(trace)
+            end_time = trace[-1][0]
+            for i in range(0, len(trace), 2):
+                x_min = min(x_min, trace[i][1])
+                x_max = max(x_max, trace[i][1])
+                y_min = min(y_min, trace[i][2])
+                y_max = max(y_max, trace[i][2])
+                # if round(trace[i][0], 2) not in timed_point_dict:
+                #     timed_point_dict[round(trace[i][0], 2)] = [
+                #         trace[i][1:].tolist()]
+                # else:
+                #     init = False
+                #     for record in timed_point_dict[round(trace[i][0], 2)]:
+                #         if record == trace[i][1:].tolist():
+                #             init = True
+                #             break
+                #     if init == False:
+                #         timed_point_dict[round(trace[i][0], 2)].append(
+                #             trace[i][1:].tolist())
+                time_point = round(trace[i][0], 2)
+                rect = [trace[i][1:].tolist(), trace[i+1][1:].tolist()]
+                if time_point not in timed_point_dict:
+                    timed_point_dict[time_point] = {agent_id: [rect]}
+                else:
+                    if agent_id in timed_point_dict[time_point].keys():
+                        timed_point_dict[time_point][agent_id].append(rect)
+                    else:
+                        timed_point_dict[time_point][agent_id] = [rect]
+
+        stack += node.child
+    # fill in most of layout
+    # print(end_time)
+    duration = int(100/end_time)
+    fig_dict["layout"]["xaxis"] = {
+        "range": [(x_min-10), (x_max+10)],
+        "title": "x position"}
+    fig_dict["layout"]["yaxis"] = {
+        "range": [(y_min-2), (y_max+2)],
+        "title": "y position"}
+    fig_dict["layout"]["hovermode"] = "closest"
+    fig_dict["layout"]["updatemenus"] = [
+        {
+            "buttons": [
+                {
+                    "args": [None, {"frame": {"duration": duration, "redraw": False},
+                                    "fromcurrent": True, "transition": {"duration": duration,
+                                                                        "easing": "quadratic-in-out"}}],
+                    "label": "Play",
+                    "method": "animate"
+                },
+                {
+                    "args": [[None], {"frame": {"duration": 0, "redraw": False},
+                                      "mode": "immediate",
+                                      "transition": {"duration": 0}}],
+                    "label": "Pause",
+                    "method": "animate"
+                }
+            ],
+            "direction": "left",
+            "pad": {"r": 10, "t": 87},
+            "showactive": False,
+            "type": "buttons",
+            "x": 0.1,
+            "xanchor": "right",
+            "y": 0,
+            "yanchor": "top"
+        }
+    ]
+    sliders_dict = {
+        "active": 0,
+        "yanchor": "top",
+        "xanchor": "left",
+        "currentvalue": {
+            "font": {"size": 20},
+            "prefix": "time:",
+            "visible": True,
+            "xanchor": "right"
+        },
+        # "method": "update",
+        "transition": {"duration": duration, "easing": "cubic-in-out"},
+        "pad": {"b": 10, "t": 50},
+        "len": 0.9,
+        "x": 0.1,
+        "y": 0,
+        "steps": []
+    }
+    # make data
+    agent_dict = timed_point_dict[0]  # {agent1:[rect1,..], ...}
+    x_list = []
+    y_list = []
+    text_list = []
+    for agent_id, rect_list in agent_dict.items():
+        for rect in rect_list:
+            # trace = list(data.values())[0]
+            print(rect)
+            x_list.append((rect[0][0]+rect[1][0])/2)
+            y_list.append((rect[0][1]+rect[1][1])/2)
+            text_list.append(
+                ('{:.2f}'.format((rect[0][2]+rect[1][2])/pi*90), '{:.3f}'.format(rect[0][3]+rect[1][3])))
+    # data_dict = {
+    #     "x": x_list,
+    #     "y": y_list,
+    #     "mode": "markers + text",
+    #     "text": text_list,
+    #     "textposition": "bottom center",
+    #     # "marker": {
+    #     #     "sizemode": "area",
+    #     #     "sizeref": 200000,
+    #     #     "size": 2
+    #     # },
+    #     "name": "Current Position"
+    # }
+    # fig_dict["data"].append(data_dict)
+
+    # make frames
+    for time_point in timed_point_dict:
+        frame = {"data": [], "layout": {
+            "annotations": [], "shapes": []}, "name": str(time_point)}
+        agent_dict = timed_point_dict[time_point]
+        trace_x = []
+        trace_y = []
+        trace_theta = []
+        trace_v = []
+        for agent_id, rect_list in agent_dict.items():
+            for rect in rect_list:
+                trace_x.append((rect[0][0]+rect[1][0])/2)
+                trace_y.append((rect[0][1]+rect[1][1])/2)
+                trace_theta.append((rect[0][2]+rect[1][2])/2)
+                trace_v.append((rect[0][3]+rect[1][3])/2)
+                shape_dict = {
+                    "type": 'rect',
+                    "x0": rect[0][0],
+                    "y0": rect[0][1],
+                    "x1": rect[1][0],
+                    "y1": rect[1][1],
+                    "fillcolor": 'rgba(255,255,255,0.5)',
+                    "line": dict(color='rgba(255,255,255,0)'),
+
+                }
+                frame["layout"]["shapes"].append(shape_dict)
+        # data_dict = {
+        #     "x": trace_x,
+        #     "y": trace_y,
+        #     "mode": "markers + text",
+        #     "text": [('{:.2f}'.format(trace_theta[i]/pi*180), '{:.3f}'.format(trace_v[i])) for i in range(len(trace_theta))],
+        #     "textposition": "bottom center",
+        #     # "marker": {
+        #     #     "sizemode": "area",
+        #     #     "sizeref": 200000,
+        #     #     "size": 2
+        #     # },
+        #     "name": "current position"
+        # }
+        # frame["data"].append(data_dict)
+        # print(trace_x)
+        fig_dict["frames"].append(frame)
+        slider_step = {"args": [
+            [time_point],
+            {"frame": {"duration": duration, "redraw": False},
+             "mode": "immediate",
+             "transition": {"duration": duration}}
+        ],
+            "label": time_point,
+            "method": "animate"}
+        sliders_dict["steps"].append(slider_step)
+        # print(len(frame["layout"]["annotations"]))
+
+    fig_dict["layout"]["sliders"] = [sliders_dict]
+
+    fig = go.Figure(fig_dict)
+    # fig = plotly_map(map, 'g', fig)
+    i = 1
+    for agent_id in traces:
+        fig = plotly_reachtube_tree_v2(root, agent_id, 1, [2], i, fig)
+        i += 2
+
+    return fig
+
+
+def plotly_reachtube_tree(root, agent_id, x_dim: int = 0, y_dim_list: List[int] = [1], color='blue', fig=None, x_lim=None, y_lim=None):
+    if fig is None:
+        fig = go.Figure()
+
+    # ax = fig.gca()
+    # if x_lim is None:
+    #     x_lim = ax.get_xlim()
+    # if y_lim is None:
+    #     y_lim = ax.get_ylim()
+
+    queue = [root]
+    while queue != []:
+        node = queue.pop(0)
+        traces = node.trace
+        trace = traces[agent_id]
+        # print(trace)
+        data = []
+        for i in range(0, len(trace)-1, 2):
+            data.append([trace[i], trace[i+1]])
+        fig, x_lim, y_lim = plotly_plot(
+            data, x_dim, y_dim_list, color, fig, x_lim, y_lim)
+        # print(data)
+        queue += node.child
+
+    return fig
+
+
+def plotly_reachtube_tree_v2(root, agent_id, x_dim: int = 0, y_dim_list: List[int] = [1], color=0, fig=None, x_lim=None, y_lim=None):
+    if fig is None:
+        fig = go.Figure()
+
+    # ax = fig.gca()
+    # if x_lim is None:
+    #     x_lim = ax.get_xlim()
+    # if y_lim is None:
+    #     y_lim = ax.get_ylim()
+    bg_color = ['rgba(31,119,180,1)', 'rgba(255,127,14,0.2)', 'rgba(44,160,44,0.2)', 'rgba(214,39,40,0.2)', 'rgba(148,103,189,0.2)',
+                'rgba(140,86,75,0.2)', 'rgba(227,119,194,0.2)', 'rgba(127,127,127,0.2)', 'rgba(188,189,34,0.2)', 'rgba(23,190,207,0.2)']
+    queue = [root]
+    show_legend = False
+    while queue != []:
+        node = queue.pop(0)
+        traces = node.trace
+        trace = np.array(traces[agent_id])
+        # print(trace[0], trace[1], trace[-2], trace[-1])
+        max_id = len(trace)-1
+        # trace_x = np.zeros(max_id+1)
+        # trace_x_2 = np.zeros(max_id+1)
+        # trace_y = np.zeros(max_id+1)
+        # trace_y_2 = np.zeros(max_id+1)
+        # for y_dim in y_dim_list:
+        #     for i in range(0, max_id, 2):
+        #         id = int(i/2)
+        #         trace_x[id] = trace[i+1][x_dim]
+        #         trace_x[max_id-id] = trace[i][x_dim]
+        #         trace_x_2[id] = trace[i][x_dim]
+        #         trace_x_2[max_id-id] = trace[i+1][x_dim]
+        #         trace_y[id] = trace[i][y_dim]
+        #         trace_y[max_id-id] = trace[i+1][y_dim]
+        # fig.add_trace(go.Scatter(x=trace_x, y=trace_y,
+        #                          fill='toself',
+        #                          fillcolor='blue',
+        #                          line_color='rgba(255,255,255,0)',
+        #                          showlegend=False))
+        # fig.add_trace(go.Scatter(x=trace_x_2, y=trace_y,
+        #                          fill='toself',
+        #                          fillcolor='red',
+        #                          line_color='rgba(255,255,255,0)',
+        #                          showlegend=False))
+        # fig.add_trace(go.Scatter(x=trace[:, 1], y=trace[:, 2],
+        #                          mode='lines',
+        #                          line_color="black",
+        #                          text=[range(0, max_id+1)],
+        #                          name='lines',
+        #                          showlegend=False))
+        trace_x_odd = np.array([trace[i][1] for i in range(0, max_id, 2)])
+        trace_x_even = np.array([trace[i][1] for i in range(1, max_id+1, 2)])
+        trace_y_odd = np.array([trace[i][2] for i in range(0, max_id, 2)])
+        trace_y_even = np.array([trace[i][2] for i in range(1, max_id+1, 2)])
+        fig.add_trace(go.Scatter(x=trace_x_odd.tolist()+trace_x_odd[::-1].tolist(), y=trace_y_odd.tolist()+trace_y_even[::-1].tolist(), mode='lines',
+                                 fill='toself',
+                                 fillcolor=bg_color[color],
+                                 line_color='rgba(255,255,255,0)',
+                                 showlegend=show_legend
+                                 ))
+        fig.add_trace(go.Scatter(x=trace_x_even.tolist()+trace_x_even[::-1].tolist(), y=trace_y_odd.tolist()+trace_y_even[::-1].tolist(), mode='lines',
+                                 fill='toself',
+                                 fillcolor=bg_color[color],
+                                 line_color='rgba(255,255,255,0)',
+                                 showlegend=show_legend))
+        fig.add_trace(go.Scatter(x=trace_x_odd.tolist()+trace_x_even[::-1].tolist(), y=trace_y_odd.tolist()+trace_y_even[::-1].tolist(), mode='lines',
+                                 fill='toself',
+                                 fillcolor=bg_color[color],
+                                 line_color='rgba(255,255,255,0)',
+                                 showlegend=show_legend
+                                 ))
+        fig.add_trace(go.Scatter(x=trace_x_even.tolist()+trace_x_odd[::-1].tolist(), y=trace_y_odd.tolist()+trace_y_even[::-1].tolist(), mode='lines',
+                                 fill='toself',
+                                 fillcolor=bg_color[color],
+                                 line_color='rgba(255,255,255,0)',
+                                 showlegend=show_legend))
+        # fig.add_trace(go.Scatter(x=trace_x_odd.tolist(), y=trace_y_odd.tolist(), mode='lines',
+        #                          #  fill='toself',
+        #                          #  fillcolor=bg_color[0],
+        #                          #  line=dict(width=1, dash="solid"),
+        #                          line_color=bg_color[0],
+        #                          showlegend=True
+        #                          ))
+        # fig.add_trace(go.Scatter(x=trace_x_even.tolist(), y=trace_y_odd.tolist(), mode='lines',
+        #                          #  fill='toself',
+        #                          #  fillcolor=bg_color[0],
+        #                          #  line=dict(width=1, dash="solid"),
+        #                          line_color=bg_color[0],
+        #                          showlegend=True))
+        # fig.add_trace(go.Scatter(x=trace_x_odd.tolist(), y=trace_y_even.tolist(), mode='lines',
+        #                          #  fill='toself',
+        #                          #  fillcolor=bg_color[0],
+        #                          #  line=dict(width=1, dash="solid",shape="spline"),
+        #                          line_color=bg_color[0],
+        #                          showlegend=True
+        #                          ))
+        # fig.add_trace(go.Scatter(x=trace_x_even.tolist(), y=trace_y_even.tolist(), mode='lines',
+        #                          #  fill='toself',
+        #                          #  fillcolor=bg_color[0],
+        #                          #  line=dict(width=1, dash="solid"),
+        #                          line_color=bg_color[0],
+        #                          showlegend=True))
+        # fig.add_trace(go.Scatter(x=trace[:, 1], y=trace[:, 2],
+        #                          mode='markers',
+        #                          #  fill='toself',
+        #                          #  line=dict(dash="dot"),
+        #                          line_color="black",
+        #                          text=[range(0, max_id+1)],
+        #                          name='lines',
+        #                          showlegend=False))
+        queue += node.child
+    queue = [root]
+    while queue != []:
+        node = queue.pop(0)
+        traces = node.trace
+        trace = np.array(traces[agent_id])
+        # print(trace[0], trace[1], trace[-2], trace[-1])
+        max_id = len(trace)-1
+        # trace_x = np.zeros(max_id+1)
+        # trace_x_2 = np.zeros(max_id+1)
+        # trace_y = np.zeros(max_id+1)
+        # trace_y_2 = np.zeros(max_id+1)
+        # for y_dim in y_dim_list:
+        #     for i in range(0, max_id, 2):
+        #         id = int(i/2)
+        #         trace_x[id] = trace[i+1][x_dim]
+        #         trace_x[max_id-id] = trace[i][x_dim]
+        #         trace_x_2[id] = trace[i][x_dim]
+        #         trace_x_2[max_id-id] = trace[i+1][x_dim]
+        #         trace_y[id] = trace[i][y_dim]
+        #         trace_y[max_id-id] = trace[i+1][y_dim]
+        # fig.add_trace(go.Scatter(x=trace_x, y=trace_y,
+        #                          fill='toself',
+        #                          fillcolor='blue',
+        #                          line_color='rgba(255,255,255,0)',
+        #                          showlegend=False))
+        # fig.add_trace(go.Scatter(x=trace_x_2, y=trace_y,
+        #                          fill='toself',
+        #                          fillcolor='red',
+        #                          line_color='rgba(255,255,255,0)',
+        #                          showlegend=False))
+        # fig.add_trace(go.Scatter(x=trace[:, 1], y=trace[:, 2],
+        #                          mode='lines',
+        #                          line_color="black",
+        #                          text=[range(0, max_id+1)],
+        #                          name='lines',
+        #                          showlegend=False))
+        # trace_x_odd = np.array([trace[i][1] for i in range(0, max_id, 2)])
+        # trace_x_even = np.array([trace[i][1] for i in range(1, max_id+1, 2)])
+        # trace_y_odd = np.array([trace[i][2] for i in range(0, max_id, 2)])
+        # trace_y_even = np.array([trace[i][2] for i in range(1, max_id+1, 2)])
+        # fig.add_trace(go.Scatter(x=trace_x_odd.tolist()+trace_x_odd[::-1].tolist(), y=trace_y_odd.tolist()+trace_y_even[::-1].tolist(), mode='lines',
+        #                          fill='toself',
+        #                          fillcolor=bg_color[0],
+        #                          line_color='rgba(255,255,255,0)',
+        #                          showlegend=True
+        #                          ))
+        # fig.add_trace(go.Scatter(x=trace_x_even.tolist()+trace_x_even[::-1].tolist(), y=trace_y_odd.tolist()+trace_y_even[::-1].tolist(), mode='lines',
+        #                          fill='toself',
+        #                          fillcolor=bg_color[0],
+        #                          line_color='rgba(255,255,255,0)',
+        #                          showlegend=True))
+        # fig.add_trace(go.Scatter(x=trace_x_odd.tolist()+trace_x_even[::-1].tolist(), y=trace_y_odd.tolist()+trace_y_even[::-1].tolist(), mode='lines',
+        #                          fill='toself',
+        #                          fillcolor=bg_color[0],
+        #                          line_color='rgba(255,255,255,0)',
+        #                          showlegend=True
+        #                          ))
+        # fig.add_trace(go.Scatter(x=trace_x_even.tolist()+trace_x_odd[::-1].tolist(), y=trace_y_odd.tolist()+trace_y_even[::-1].tolist(), mode='lines',
+        #                          fill='toself',
+        #                          fillcolor=bg_color[0],
+        #                          line_color='rgba(255,255,255,0)',
+        #                          showlegend=True))
+        # fig.add_trace(go.Scatter(x=trace_x_odd.tolist(), y=trace_y_odd.tolist(), mode='lines',
+        #                          #  fill='toself',
+        #                          #  fillcolor=bg_color[0],
+        #                          #  line=dict(width=1, dash="solid"),
+        #                          line_color=bg_color[0],
+        #                          showlegend=True
+        #                          ))
+        # fig.add_trace(go.Scatter(x=trace_x_even.tolist(), y=trace_y_odd.tolist(), mode='lines',
+        #                          #  fill='toself',
+        #                          #  fillcolor=bg_color[0],
+        #                          #  line=dict(width=1, dash="solid"),
+        #                          line_color=bg_color[0],
+        #                          showlegend=True))
+        # fig.add_trace(go.Scatter(x=trace_x_odd.tolist(), y=trace_y_even.tolist(), mode='lines',
+        #                          #  fill='toself',
+        #                          #  fillcolor=bg_color[0],
+        #                          #  line=dict(width=1, dash="solid",shape="spline"),
+        #                          line_color=bg_color[0],
+        #                          showlegend=True
+        #                          ))
+        # fig.add_trace(go.Scatter(x=trace_x_even.tolist(), y=trace_y_even.tolist(), mode='lines',
+        #                          #  fill='toself',
+        #                          #  fillcolor=bg_color[0],
+        #                          #  line=dict(width=1, dash="solid"),
+        #                          line_color=bg_color[0],
+        #                          showlegend=True))
+        fig.add_trace(go.Scatter(x=trace[:, 1], y=trace[:, 2],
+                                 mode='markers',
+                                 #  fill='toself',
+                                 #  line=dict(dash="dot"),
+                                 line_color="black",
+                                 marker={
+            "sizemode": "area",
+            "sizeref": 200000,
+            "size": 2
+        },
+            text=[range(0, max_id+1)],
+            name='lines',
+            showlegend=False))
+        queue += node.child
+    # fig.update_traces(line_dash="dash")
+    return fig
+
+
 def plot_reachtube_tree(root, agent_id, x_dim: int = 0, y_dim_list: List[int] = [1], color='b', fig=None, x_lim=None, y_lim=None):
     if fig is None:
         fig = plt.figure()
@@ -66,14 +553,14 @@ def plot_reachtube_tree(root, agent_id, x_dim: int = 0, y_dim_list: List[int] =
         node = queue.pop(0)
         traces = node.trace
         trace = traces[agent_id]
+        # print(trace)
         data = []
-        for i in range(0, len(trace), 2):
+
+        for i in range(0, len(trace)-1, 2):
             data.append([trace[i], trace[i+1]])
         fig, x_lim, y_lim = plot(
             data, x_dim, y_dim_list, color, fig, x_lim, y_lim)
-
         queue += node.child
-
     return fig
 
 def plot_reachtube_tree_branch(root, agent_id, x_dim: int=0, y_dim_list: List[int]=[1], color='b', fig = None, x_lim = None, y_lim = None):
@@ -98,10 +585,130 @@ def plot_reachtube_tree_branch(root, agent_id, x_dim: int=0, y_dim_list: List[in
 
         if node.child:
             stack += [node.child[0]]
+    return fig
+
+
+def plotly_map(map, color='b', fig: go.Figure() = None, x_lim=None, y_lim=None):
+    if fig is None:
+        fig = go.Figure()
+    all_x = []
+    all_y = []
+    all_v = []
+    for lane_idx in map.lane_dict:
+        lane = map.lane_dict[lane_idx]
+        for lane_seg in lane.segment_list:
+            if lane_seg.type == 'Straight':
+                start1 = lane_seg.start + lane_seg.width/2 * lane_seg.direction_lateral
+                end1 = lane_seg.end + lane_seg.width/2 * lane_seg.direction_lateral
+                # fig.add_trace(go.Scatter(x=[start1[0], end1[0]], y=[start1[1], end1[1]],
+                #                          mode='lines',
+                #                          line_color='black',
+                #                          showlegend=False,
+                #                          # text=theta,
+                #                          name='lines'))
+                start2 = lane_seg.start - lane_seg.width/2 * lane_seg.direction_lateral
+                end2 = lane_seg.end - lane_seg.width/2 * lane_seg.direction_lateral
+                # fig.add_trace(go.Scatter(x=[start2[0], end2[0]], y=[start2[1], end2[1]],
+                #                          mode='lines',
+                #                          line_color='black',
+                #                          showlegend=False,
+                #                          # text=theta,
+                #                          name='lines'))
+                fig.add_trace(go.Scatter(x=[start1[0], end1[0], end2[0], start2[0]], y=[start1[1], end1[1], end2[1], start2[1]],
+                                         mode='lines',
+                                         line_color='black',
+                                         #  fill='toself',
+                                         #  fillcolor='rgba(255,255,255,0)',
+                                         #  line_color='rgba(0,0,0,0)',
+                                         showlegend=False,
+                                         # text=theta,
+                                         name='lines'))
+                # fig = go.Figure().add_heatmap(x=)
+                seg_x, seg_y, seg_v = lane_seg.get_all_speed()
+                all_x += seg_x
+                all_y += seg_y
+                all_v += seg_v
+            elif lane_seg.type == "Circular":
+                phase_array = np.linspace(
+                    start=lane_seg.start_phase, stop=lane_seg.end_phase, num=100)
+                r1 = lane_seg.radius - lane_seg.width/2
+                x = np.cos(phase_array)*r1 + lane_seg.center[0]
+                y = np.sin(phase_array)*r1 + lane_seg.center[1]
+                fig.add_trace(go.Scatter(x=x, y=y,
+                                         mode='lines',
+                                         line_color='black',
+                                         showlegend=False,
+                                         # text=theta,
+                                         name='lines'))
 
+                r2 = lane_seg.radius + lane_seg.width/2
+                x = np.cos(phase_array)*r2 + lane_seg.center[0]
+                y = np.sin(phase_array)*r2 + lane_seg.center[1]
+                fig.add_trace(go.Scatter(x=x, y=y,
+                                         mode='lines',
+                                         line_color='black',
+                                         showlegend=False,
+                                         # text=theta,
+                                         name='lines'))
+            else:
+                raise ValueError(f'Unknown lane segment type {lane_seg.type}')
+    start_color = [0, 0, 255, 0.2]
+    end_color = [255, 0, 0, 0.2]
+    curr_color = copy.deepcopy(start_color)
+    max_speed = max(all_v)
+    min_speed = min(all_v)
+
+    for i in range(len(all_v)):
+        # print(all_x[i])
+        # print(all_y[i])
+        # print(all_v[i])
+        curr_color = copy.deepcopy(start_color)
+        for j in range(len(curr_color)-1):
+            curr_color[j] += (all_v[i]-min_speed)/(max_speed -
+                                                   min_speed)*(end_color[j]-start_color[j])
+        fig.add_trace(go.Scatter(x=all_x[i], y=all_y[i],
+                                 mode='lines',
+                                 line_color='rgba(0,0,0,0)',
+                                 fill='toself',
+                                 fillcolor='rgba'+str(tuple(curr_color)),
+                                 #  marker=dict(
+                                 #     symbol='square',
+                                 #     size=16,
+                                 #     cmax=max_speed,
+                                 #     cmin=min_speed,
+                                 #     # color=all_v[i],
+                                 #     colorbar=dict(
+                                 #         title="Colorbar"
+                                 #     ),
+                                 #     colorscale=[
+                                 #         [0, 'rgba'+str(tuple(start_color))], [1, 'rgba'+str(tuple(end_color))]]
+                                 # ),
+                                 showlegend=False,
+                                 ))
+    fig.add_trace(go.Scatter(x=[0], y=[0],
+                             mode='markers',
+                             # fill='toself',
+                             # fillcolor='rgba'+str(tuple(curr_color)),
+                             marker=dict(
+                                 symbol='square',
+                                 size=16,
+                                 cmax=max_speed,
+                                 cmin=min_speed,
+                                 color='rgba(0,0,0,0)',
+                                 colorbar=dict(
+                                        title="Speed Limit"
+                                 ),
+                                 colorscale=[
+                                     [0, 'rgba'+str(tuple(start_color))], [1, 'rgba'+str(tuple(end_color))]]
+    ),
+        showlegend=False,
+    ))
+    # fig.update_coloraxes(colorbar=dict(title="Colorbar"), colorscale=[
+    #                      [0, 'rgba'+str(tuple(start_color))], [1, 'rgba'+str(tuple(end_color))]])
     return fig
 
-def plot_map(map, color = 'b', fig = None, x_lim = None,y_lim = None):
+
+def plot_map(map, color='b', fig=None, x_lim=None, y_lim=None):
     if fig is None:
         fig = plt.figure()
 
@@ -138,7 +745,50 @@ def plot_map(map, color = 'b', fig = None, x_lim = None,y_lim = None):
     return fig
 
 
-def plot_simulation_tree(root, agent_id, x_dim: int = 0, y_dim_list: List[int] = [1], color='b', fig=None, x_lim=None, y_lim=None):
+def plotly_simulation_tree(root: AnalysisTreeNode, agent_id, x_dim: int = 0, y_dim_list: List[int] = [1], color='b', fig=None, x_lim=None, y_lim=None):
+    if fig is None:
+        fig = go.Figure()
+    i = 0
+    fg_color = ['rgb(31,119,180)', 'rgb(255,127,14)', 'rgb(44,160,44)', 'rgb(214,39,40)', 'rgb(148,103,189)',
+                'rgb(140,86,75)', 'rgb(227,119,194)', 'rgb(127,127,127)', 'rgb(188,189,34)', 'rgb(23,190,207)']
+    bg_color = ['rgba(31,119,180,0.2)', 'rgba(255,127,14,0.2)', 'rgba(44,160,44,0.2)', 'rgba(214,39,40,0.2)', 'rgba(148,103,189,0.2)',
+                'rgba(140,86,75,0.2)', 'rgba(227,119,194,0.2)', 'rgba(127,127,127,0.2)', 'rgba(188,189,34,0.2)', 'rgba(23,190,207,0.2)']
+    queue = [root]
+    while queue != []:
+        node = queue.pop(0)
+        traces = node.trace
+        print(node.mode)
+        # [[time,x,y,theta,v]...]
+        trace = np.array(traces[agent_id])
+        # print(trace)
+        for y_dim in y_dim_list:
+            trace_y = trace[:, y_dim].tolist()
+            trace_x = trace[:, x_dim].tolist()
+            theta = [i/pi*180 for i in trace[:, x_dim+2]]
+            trace_x_rev = trace_x[::-1]
+            # print(trace_x)
+            trace_upper = [i+1 for i in trace_y]
+            trace_lower = [i-1 for i in trace_y]
+            trace_lower = trace_lower[::-1]
+            # print(trace_upper)
+            # print(trace[:, y_dim])
+            fig.add_trace(go.Scatter(x=trace_x+trace_x_rev, y=trace_upper+trace_lower,
+                                     fill='toself',
+                                     fillcolor=bg_color[i % 10],
+                                     line_color='rgba(255,255,255,0)',
+                                     showlegend=False))
+            fig.add_trace(go.Scatter(x=trace[:, x_dim], y=trace[:, y_dim],
+                                     mode='lines',
+                                     line_color=fg_color[i % 10],
+                                     text=theta,
+                                     name='lines'))
+            i += 1
+        queue += node.child
+    fig.update_traces(mode='lines')
+    return fig
+
+
+def plot_simulation_tree(root: AnalysisTreeNode, agent_id, x_dim: int = 0, y_dim_list: List[int] = [1], color='b', fig=None, x_lim=None, y_lim=None):
     if fig is None:
         fig = plt.figure()
 
@@ -155,26 +805,26 @@ def plot_simulation_tree(root, agent_id, x_dim: int = 0, y_dim_list: List[int] =
     while queue != []:
         node = queue.pop(0)
         traces = node.trace
+        print(node.mode)
+        # [[time,x,y,theta,v]...]
         trace = np.array(traces[agent_id])
+        # print(trace)
         for y_dim in y_dim_list:
             ax.plot(trace[:, x_dim], trace[:, y_dim], color)
             x_min = min(x_min, trace[:, x_dim].min())
             x_max = max(x_max, trace[:, x_dim].max())
-
             y_min = min(y_min, trace[:, y_dim].min())
             y_max = max(y_max, trace[:, y_dim].max())
-
         queue += node.child
     ax.set_xlim([x_min-1, x_max+1])
     ax.set_ylim([y_min-1, y_max+1])
-
     return fig
 
 
 def generate_simulation_anime(root, map, fig=None):
     if fig is None:
         fig = plt.figure()
-    fig = plot_map(map, 'g', fig)
+    # fig = plot_map(map, 'g', fig)
     timed_point_dict = {}
     stack = [root]
     ax = fig.gca()
@@ -206,9 +856,9 @@ def generate_simulation_anime(root, map, fig=None):
         point_list = timed_point_dict[time_point]
         plt.xlim((x_min-2, x_max+2))
         plt.ylim((y_min-2, y_max+2))
-        plot_map(map, color='g', fig=fig)
+        # plot_map(map, color='g', fig=fig)
         for data in point_list:
-            point = data[0]
+            point = data
             color = data[1]
             ax = plt.gca()
             ax.plot([point[0]], [point[1]], markerfacecolor=color,
@@ -220,6 +870,7 @@ def generate_simulation_anime(root, map, fig=None):
             ax.arrow(x_tail, y_tail, dx, dy, head_width=1, head_length=0.5)
         plt.pause(0.05)
         plt.clf()
+    return fig
     #     img_buf = io.BytesIO()
     #     plt.savefig(img_buf, format = 'png')
     #     im = Image.open(img_buf)
@@ -229,109 +880,6 @@ def generate_simulation_anime(root, map, fig=None):
     # frame_one.save(fn, format = "GIF", append_images = frames, save_all = True, duration = 100, loop = 0)
 
 
-def plotly_map(map, color='b', fig=None, x_lim=None, y_lim=None):
-    if fig is None:
-        fig = go.figure()
-
-    for lane_idx in map.lane_dict:
-        lane = map.lane_dict[lane_idx]
-        for lane_seg in lane.segment_list:
-            if lane_seg.type == 'Straight':
-                start1 = lane_seg.start + lane_seg.width/2 * lane_seg.direction_lateral
-                end1 = lane_seg.end + lane_seg.width/2 * lane_seg.direction_lateral
-                # fig.add_trace(go.Scatter(x=[start1[0], end1[0]], y=[start1[1], end1[1]],
-                #                          mode='lines',
-                #                          line_color='black',
-                #                          showlegend=False,
-                #                          # text=theta,
-                #                          name='lines'))
-                start2 = lane_seg.start - lane_seg.width/2 * lane_seg.direction_lateral
-                end2 = lane_seg.end - lane_seg.width/2 * lane_seg.direction_lateral
-                # fig.add_trace(go.Scatter(x=[start2[0], end2[0]], y=[start2[1], end2[1]],
-                #                          mode='lines',
-                #                          line_color='black',
-                #                          showlegend=False,
-                #                          # text=theta,
-                #                          name='lines'))
-                fig.add_trace(go.Scatter(x=[start1[0], end1[0], end2[0], start2[0]], y=[start1[1], end1[1], end2[1], start2[1]],
-                                         mode='lines',
-                                         line_color='black',
-                                         #  fill='toself',
-                                         #  fillcolor='rgba(255,255,255,0)',
-                                         #  line_color='rgba(0,0,0,0)',
-                                         showlegend=False,
-                                         # text=theta,
-                                         name='lines'))
-            elif lane_seg.type == "Circular":
-                phase_array = np.linspace(
-                    start=lane_seg.start_phase, stop=lane_seg.end_phase, num=100)
-                r1 = lane_seg.radius - lane_seg.width/2
-                x = np.cos(phase_array)*r1 + lane_seg.center[0]
-                y = np.sin(phase_array)*r1 + lane_seg.center[1]
-                fig.add_trace(go.Scatter(x=x, y=y,
-                                         mode='lines',
-                                         line_color='black',
-                                         showlegend=False,
-                                         # text=theta,
-                                         name='lines'))
-
-                r2 = lane_seg.radius + lane_seg.width/2
-                x = np.cos(phase_array)*r2 + lane_seg.center[0]
-                y = np.sin(phase_array)*r2 + lane_seg.center[1]
-                fig.add_trace(go.Scatter(x=x, y=y,
-                                         mode='lines',
-                                         line_color='black',
-                                         showlegend=False,
-                                         # text=theta,
-                                         name='lines'))
-            else:
-                raise ValueError(f'Unknown lane segment type {lane_seg.type}')
-    return fig
-
-
-def plotly_simulation_tree(root, agent_id, x_dim: int = 0, y_dim_list: List[int] = [1], color='b', fig=None, x_lim=None, y_lim=None):
-    if fig is None:
-        fig = go.Figure()
-    i = 0
-    fg_color = ['rgb(31,119,180)', 'rgb(255,127,14)', 'rgb(44,160,44)', 'rgb(214,39,40)', 'rgb(148,103,189)',
-                'rgb(140,86,75)', 'rgb(227,119,194)', 'rgb(127,127,127)', 'rgb(188,189,34)', 'rgb(23,190,207)']
-    bg_color = ['rgba(31,119,180,0.2)', 'rgba(255,127,14,0.2)', 'rgba(44,160,44,0.2)', 'rgba(214,39,40,0.2)', 'rgba(148,103,189,0.2)',
-                'rgba(140,86,75,0.2)', 'rgba(227,119,194,0.2)', 'rgba(127,127,127,0.2)', 'rgba(188,189,34,0.2)', 'rgba(23,190,207,0.2)']
-    queue = [root]
-    while queue != []:
-        node = queue.pop(0)
-        traces = node.trace
-        # print(node.mode)
-        # [[time,x,y,theta,v]...]
-        trace = np.array(traces[agent_id])
-        # print(trace)
-        for y_dim in y_dim_list:
-            trace_y = trace[:, y_dim].tolist()
-            trace_x = trace[:, x_dim].tolist()
-            theta = [i/pi*180 for i in trace[:, x_dim+2]]
-            trace_x_rev = trace_x[::-1]
-            # print(trace_x)
-            trace_upper = [i+1 for i in trace_y]
-            trace_lower = [i-1 for i in trace_y]
-            trace_lower = trace_lower[::-1]
-            # print(trace_upper)
-            # print(trace[:, y_dim])
-            fig.add_trace(go.Scatter(x=trace_x+trace_x_rev, y=trace_upper+trace_lower,
-                                     fill='toself',
-                                     fillcolor=bg_color[i % 10],
-                                     line_color='rgba(255,255,255,0)',
-                                     showlegend=False))
-            fig.add_trace(go.Scatter(x=trace[:, x_dim], y=trace[:, y_dim],
-                                     mode='lines',
-                                     line_color=fg_color[i % 10],
-                                     text=theta,
-                                     name='lines'))
-            i += 1
-        queue += node.child
-    fig.update_traces(mode='lines')
-    return fig
-
-
 def plotly_simulation_anime(root, map=None, fig=None):
     # make figure
     fig_dict = {
@@ -342,34 +890,47 @@ def plotly_simulation_anime(root, map=None, fig=None):
     # fig = plot_map(map, 'g', fig)
     timed_point_dict = {}
     stack = [root]
+
+    # print("plot")
+    # print(root.mode)
     x_min, x_max = float('inf'), -float('inf')
     y_min, y_max = float('inf'), -float('inf')
+    # segment_start = set()
+    # previous_mode = {}
+    # for agent_id in root.mode:
+    #     previous_mode[agent_id] = []
+
     while stack != []:
         node = stack.pop()
         traces = node.trace
         for agent_id in traces:
             trace = np.array(traces[agent_id])
+
             for i in range(len(trace)):
                 x_min = min(x_min, trace[i][1])
                 x_max = max(x_max, trace[i][1])
                 y_min = min(y_min, trace[i][2])
                 y_max = max(y_max, trace[i][2])
-                if round(trace[i][0], 2) not in timed_point_dict:
-                    timed_point_dict[round(trace[i][0], 2)] = [
-                        trace[i][1:].tolist()]
+                # print(round(trace[i][0], 2))
+                time_point = round(trace[i][0], 2)
+                if time_point not in timed_point_dict:
+                    timed_point_dict[time_point] = [
+                        {agent_id: trace[i][1:].tolist()}]
                 else:
                     init = False
-                    for record in timed_point_dict[round(trace[i][0], 2)]:
-                        if record == trace[i][1:].tolist():
+                    for record in timed_point_dict[time_point]:
+                        if list(record.values())[0] == trace[i][1:].tolist():
                             init = True
                             break
                     if init == False:
-                        timed_point_dict[round(trace[i][0], 2)].append(
-                            trace[i][1:].tolist())
+                        timed_point_dict[time_point].append(
+                            {agent_id: trace[i][1:].tolist()})
             time = round(trace[i][0], 2)
         stack += node.child
     # fill in most of layout
-    # print(time)
+    # print(segment_start)
+    # print(timed_point_dict.keys())
+
     duration = int(600/time)
     fig_dict["layout"]["xaxis"] = {
         "range": [(x_min-10), (x_max+10)],
@@ -425,12 +986,24 @@ def plotly_simulation_anime(root, map=None, fig=None):
     }
     # make data
     point_list = timed_point_dict[0]
-    # print(point_list)
+
+    print(point_list)
+    x_list = []
+    y_list = []
+    text_list = []
+    for data in point_list:
+        trace = list(data.values())[0]
+        # print(trace)
+        x_list.append(trace[0])
+        y_list.append(trace[1])
+        text_list.append(
+            ('{:.2f}'.format(trace[2]/pi*180), '{:.3f}'.format(trace[3])))
     data_dict = {
-        "x": [data[0] for data in point_list],
-        "y": [data[1] for data in point_list],
+        "x": x_list,
+        "y": y_list,
         "mode": "markers + text",
-        "text": [(round(data[3], 2), round(data[2]/pi*180, 2)) for data in point_list],
+        "text": text_list,
+        "textfont": dict(size=14, color="black"),
         "textposition": "bottom center",
         # "marker": {
         #     "sizemode": "area",
@@ -443,34 +1016,48 @@ def plotly_simulation_anime(root, map=None, fig=None):
 
     # make frames
     for time_point in timed_point_dict:
+
+        # print(time_point)
         frame = {"data": [], "layout": {
-            "annotations": []}, "name": str(time_point)}
-        # print(timed_point_dict[time_point])
+            "annotations": []}, "name": '{:.2f}'.format(time_point)}
+        # print(timed_point_dict[time_point][0])
         point_list = timed_point_dict[time_point]
         # point_list = list(OrderedDict.fromkeys(timed_point_dict[time_point]))
-        trace_x = [data[0] for data in point_list]
-        trace_y = [data[1] for data in point_list]
-        trace_theta = [data[2] for data in point_list]
-        trace_v = [data[3] for data in point_list]
+        # todokeyi
+        trace_x = []
+        trace_y = []
+        trace_theta = []
+        trace_v = []
+        for data in point_list:
+            trace = list(data.values())[0]
+            # print(trace)
+            trace_x.append(trace[0])
+            trace_y.append(trace[1])
+            trace_theta.append(trace[2])
+            trace_v.append(trace[3])
         data_dict = {
             "x": trace_x,
             "y": trace_y,
             "mode": "markers + text",
-            "text": [(round(trace_theta[i]/pi*180, 2), round(trace_v[i], 2)) for i in range(len(trace_theta))],
+
+            # "text": [(round(trace_theta[i]/pi*180, 2), round(trace_v[i], 3)) for i in range(len(trace_theta))],
+            "text": [('{:.2f}'.format(trace_theta[i]/pi*180), '{:.3f}'.format(trace_v[i])) for i in range(len(trace_theta))],
+            "textfont": dict(size=14, color="black"),
             "textposition": "bottom center",
             # "marker": {
             #     "sizemode": "area",
             #     "sizeref": 200000,
             #     "size": 2
             # },
-            "name": "current position"
+            "name": "current position",
+            # "show_legend": False
         }
         frame["data"].append(data_dict)
         for i in range(len(trace_x)):
             ax = np.cos(trace_theta[i])*trace_v[i]
             ay = np.sin(trace_theta[i])*trace_v[i]
             # print(trace_x[i]+ax, trace_y[i]+ay)
-            annotations_dict = {"x": trace_x[i]+ax+0.1, "y": trace_y[i]+ay,
+            annotations_dict = {"x": trace_x[i]+ax, "y": trace_y[i]+ay,
                                 # "xshift": ax, "yshift": ay,
                                 "ax": trace_x[i], "ay": trace_y[i],
                                 "arrowwidth": 2,
@@ -483,7 +1070,6 @@ def plotly_simulation_anime(root, map=None, fig=None):
                                 "arrowhead": 1,
                                 "arrowcolor": "black"}
             frame["layout"]["annotations"].append(annotations_dict)
-
         fig_dict["frames"].append(frame)
         slider_step = {"args": [
             [time_point],
@@ -499,34 +1085,124 @@ def plotly_simulation_anime(root, map=None, fig=None):
     fig_dict["layout"]["sliders"] = [sliders_dict]
 
     fig = go.Figure(fig_dict)
-    if map is not None:
-        fig = plotly_map(map, 'g', fig)
+    fig = plotly_map(map, 'g', fig)
     i = 0
     queue = [root]
+    previous_mode = {}
+    agent_list = []
+    for agent_id in root.mode:
+        previous_mode[agent_id] = []
+        agent_list.append(agent_id)
+    text_pos = 'middle center'
     while queue != []:
         node = queue.pop(0)
         traces = node.trace
         # print(node.mode)
         # [[time,x,y,theta,v]...]
+        i = 0
         for agent_id in traces:
             trace = np.array(traces[agent_id])
             # print(trace)
             trace_y = trace[:, 2].tolist()
             trace_x = trace[:, 1].tolist()
             # theta = [i/pi*180 for i in trace[:, 3]]
-            color = 'green'
-            if agent_id == 'car1':
-                color = 'red'
+
+            i = agent_list.index(agent_id)
+            color = colors[i % 5]
             fig.add_trace(go.Scatter(x=trace[:, 1], y=trace[:, 2],
                                      mode='lines',
                                      line_color=color,
-                                     text=[(round(trace[i, 3]/pi*180, 2), round(trace[i, 4], 2))
+                                     text=[(round(trace[i, 3]/pi*180, 2), round(trace[i, 4], 3))
                                            for i in range(len(trace_y))],
                                      showlegend=False)
                           #  name='lines')
                           )
-            i += 1
+            if previous_mode[agent_id] != node.mode[agent_id]:
+                theta = trace[0, 3]
+                veh_mode = node.mode[agent_id][0]
+                if veh_mode == 'Normal':
+                    text_pos = 'middle center'
+                elif veh_mode == 'Brake':
+                    if theta >= -pi/2 and theta <= pi/2:
+                        text_pos = 'middle left'
+                    else:
+                        text_pos = 'middle right'
+                elif veh_mode == 'Accelerate':
+                    if theta >= -pi/2 and theta <= pi/2:
+                        text_pos = 'middle right'
+                    else:
+                        text_pos = 'middle left'
+                elif veh_mode == 'SwitchLeft':
+                    if theta >= -pi/2 and theta <= pi/2:
+                        text_pos = 'top center'
+                    else:
+                        text_pos = 'bottom center'
+                elif veh_mode == 'SwitchRight':
+                    if theta >= -pi/2 and theta <= pi/2:
+                        text_pos = 'bottom center'
+                    else:
+                        text_pos = 'top center'
+                fig.add_trace(go.Scatter(x=[trace[0, 1]], y=[trace[0, 2]],
+                                         mode='markers+text',
+                                         line_color='rgba(255,255,255,0.3)',
+                                         text=str(agent_id)+': ' +
+                                         str(node.mode[agent_id][0]),
+                                         textposition=text_pos,
+                                         textfont=dict(
+                    #  family="sans serif",
+                    size=10,
+                                             color="grey"),
+                                         showlegend=False,
+                                         ))
+                # i += 1
+                previous_mode[agent_id] = node.mode[agent_id]
         queue += node.child
-    # fig.update_traces(mode='lines')
-
+    fig.update_traces(showlegend=False)
+    # fig.update_annotations(textfont=dict(size=14, color="black"))
+    # print(fig.frames[0].layout["annotations"])
     return fig
+    # fig.show()
+
+
+# The 'color' property is a color and may be specified as:
+#       - A hex string (e.g. '#ff0000')
+#       - An rgb/rgba string (e.g. 'rgb(255,0,0)')
+#       - An hsl/hsla string (e.g. 'hsl(0,100%,50%)')
+#       - An hsv/hsva string (e.g. 'hsv(0,100%,100%)')
+#       - A named CSS color:
+#             aliceblue, antiquewhite, aqua, aquamarine, azure,
+#             beige, bisque, black, blanchedalmond, blue,
+#             blueviolet, brown, burlywood, cadetblue,
+#             chartreuse, chocolate, coral, cornflowerblue,
+#             cornsilk, crimson, cyan, darkblue, darkcyan,
+#             darkgoldenrod, darkgray, darkgrey, darkgreen,
+#             darkkhaki, darkmagenta, darkolivegreen, darkorange,
+#             darkorchid, darkred, darksalmon, darkseagreen,
+#             darkslateblue, darkslategray, darkslategrey,
+#             darkturquoise, darkviolet, deeppink, deepskyblue,
+#             dimgray, dimgrey, dodgerblue, firebrick,
+#             floralwhite, forestgreen, fuchsia, gainsboro,
+#             ghostwhite, gold, goldenrod, gray, grey, green,
+#             greenyellow, honeydew, hotpink, indianred, indigo,
+#             ivory, khaki, lavender, lavenderblush, lawngreen,
+#             lemonchiffon, lightblue, lightcoral, lightcyan,
+#             lightgoldenrodyellow, lightgray, lightgrey,
+#             lightgreen, lightpink, lightsalmon, lightseagreen,
+#             lightskyblue, lightslategray, lightslategrey,
+#             lightsteelblue, lightyellow, lime, limegreen,
+#             linen, magenta, maroon, mediumaquamarine,
+#             mediumblue, mediumorchid, mediumpurple,
+#             mediumseagreen, mediumslateblue, mediumspringgreen,
+#             mediumturquoise, mediumvioletred, midnightblue,
+#             mintcream, mistyrose, moccasin, navajowhite, navy,
+#             oldlace, olive, olivedrab, orange, orangered,
+#             orchid, palegoldenrod, palegreen, paleturquoise,
+#             palevioletred, papayawhip, peachpuff, peru, pink,
+#             plum, powderblue, purple, red, rosybrown,
+#             royalblue, rebeccapurple, saddlebrown, salmon,
+#             sandybrown, seagreen, seashell, sienna, silver,
+#             skyblue, slateblue, slategray, slategrey, snow,
+#             springgreen, steelblue, tan, teal, thistle, tomato,
+#             turquoise, violet, wheat, white, whitesmoke,
+#             yellow, yellowgreen
+
diff --git a/dryvr_plus_plus/plotter/plotter2D_new.py b/dryvr_plus_plus/plotter/plotter2D_new.py
new file mode 100644
index 0000000000000000000000000000000000000000..5016b55be0270480c8770d01410191cfd903b7ef
--- /dev/null
+++ b/dryvr_plus_plus/plotter/plotter2D_new.py
@@ -0,0 +1,1609 @@
+"""
+This file consist main plotter code for DryVR reachtube output
+"""
+
+from __future__ import annotations
+from audioop import reverse
+# from curses import start_color
+from re import A
+import matplotlib.patches as patches
+import matplotlib.pyplot as plt
+import numpy as np
+from math import pi
+import plotly.graph_objects as go
+from typing import List
+from PIL import Image, ImageDraw
+import io
+import copy
+import operator
+from collections import OrderedDict
+
+from torch import layout
+from dryvr_plus_plus.scene_verifier.analysis.analysis_tree_node import AnalysisTreeNode
+
+colors = [['#CC0000', '#FF0000', '#FF3333', '#FF6666', '#FF9999'],
+          ['#CC6600', '#FF8000', '#FF9933', '#FFB266', '#FFCC99'],
+          ['#CCCC00', '#FFFF00', '#FFFF33', '#FFFF66', '#FFFF99'],
+          ['#66CC00', '#80FF00', '#99FF33', '#B2FF66', '#CCFF99'],
+          ['#00CC00', '#00FF00', '#33FF33', '#66FF66', '#99FF99'],
+          ['#00CC66', '#00FF80', '#33FF99', '#66FFB2', '#99FFCC'],
+          ['#00CCCC', '#00FFFF', '#33FFFF', '#66FFFF', '#99FFFF'],
+          ['#0066CC', '#0080FF', '#3399FF', '#66B2FF', '#99CCFF'],
+          ['#0000CC', '#0000FF', '#3333FF', '#6666FF', '#9999FF'],
+          ['#6600CC', '#7F00FF', '#9933FF', '#B266FF', '#CC99FF'],
+          ['#CC00CC', '#FF00FF', '#FF33FF', '#FF66FF', '#FF99FF'],
+          ['#CC0066', '#FF007F', '#FF3399', '#FF66B2', '#FF99CC']
+          ]
+scheme_dict = {'red': 0, 'orange': 1, 'yellow': 2, 'yellowgreen': 3, 'lime': 4,
+               'springgreen': 5, 'cyan': 6, 'cyanblue': 7, 'blue': 8, 'purple': 9, 'magenta': 10, 'pink': 11}
+bg_color = ['rgba(31,119,180,1)', 'rgba(255,127,14,0.2)', 'rgba(44,160,44,0.2)', 'rgba(214,39,40,0.2)', 'rgba(148,103,189,0.2)',
+            'rgba(140,86,75,0.2)', 'rgba(227,119,194,0.2)', 'rgba(127,127,127,0.2)', 'rgba(188,189,34,0.2)', 'rgba(23,190,207,0.2)']
+color_cnt = 0
+
+
+def general_reachtube_anime(root, map=None, fig=None, x_dim: int = 1, y_dim: int = 2, map_type='lines'):
+    # make figure
+    fig_dict = {
+        "data": [],
+        "layout": {},
+        "frames": []
+    }
+    fig = draw_map(map=map, fig=fig, fill_type=map_type)
+    timed_point_dict = {}
+    stack = [root]
+    x_min, x_max = float('inf'), -float('inf')
+    y_min, y_max = float('inf'), -float('inf')
+    print("reachtude")
+    end_time = 0
+    while stack != []:
+        node = stack.pop()
+        traces = node.trace
+        for agent_id in traces:
+            trace = np.array(traces[agent_id])
+            if trace[0][0] > 0:
+                trace = trace[4:]
+            # print(trace)
+            end_time = trace[-1][0]
+            for i in range(0, len(trace), 2):
+                x_min = min(x_min, trace[i][x_dim])
+                x_max = max(x_max, trace[i][x_dim])
+                y_min = min(y_min, trace[i][y_dim])
+                y_max = max(y_max, trace[i][y_dim])
+                # if round(trace[i][0], 2) not in timed_point_dict:
+                #     timed_point_dict[round(trace[i][0], 2)] = [
+                #         trace[i][1:].tolist()]
+                # else:
+                #     init = False
+                #     for record in timed_point_dict[round(trace[i][0], 2)]:
+                #         if record == trace[i][1:].tolist():
+                #             init = True
+                #             break
+                #     if init == False:
+                #         timed_point_dict[round(trace[i][0], 2)].append(
+                #             trace[i][1:].tolist())
+                time_point = round(trace[i][0], 2)
+                rect = [trace[i][0:].tolist(), trace[i+1][0:].tolist()]
+                if time_point not in timed_point_dict:
+                    timed_point_dict[time_point] = {agent_id: [rect]}
+                else:
+                    if agent_id in timed_point_dict[time_point].keys():
+                        timed_point_dict[time_point][agent_id].append(rect)
+                    else:
+                        timed_point_dict[time_point][agent_id] = [rect]
+
+        stack += node.child
+    # fill in most of layout
+    # print(end_time)
+    duration = int(100/end_time)
+    fig_dict["layout"]["xaxis"] = {
+        # "range": [(x_min-10), (x_max+10)],
+        "title": "x position"}
+    fig_dict["layout"]["yaxis"] = {
+        "range": [(y_min-2), (y_max+2)],
+        "title": "y position"}
+    fig_dict["layout"]["hovermode"] = "closest"
+    fig_dict["layout"]["updatemenus"] = [
+        {
+            "buttons": [
+                {
+                    "args": [None, {"frame": {"duration": duration, "redraw": False},
+                                    "fromcurrent": True, "transition": {"duration": duration,
+                                                                        "easing": "quadratic-in-out"}}],
+                    "label": "Play",
+                    "method": "animate"
+                },
+                {
+                    "args": [[None], {"frame": {"duration": 0, "redraw": False},
+                                      "mode": "immediate",
+                                      "transition": {"duration": 0}}],
+                    "label": "Pause",
+                    "method": "animate"
+                }
+            ],
+            "direction": "left",
+            "pad": {"r": 10, "t": 87},
+            "showactive": False,
+            "type": "buttons",
+            "x": 0.1,
+            "xanchor": "right",
+            "y": 0,
+            "yanchor": "top"
+        }
+    ]
+    sliders_dict = {
+        "active": 0,
+        "yanchor": "top",
+        "xanchor": "left",
+        "currentvalue": {
+            "font": {"size": 20},
+            "prefix": "time:",
+            "visible": True,
+            "xanchor": "right"
+        },
+        # "method": "update",
+        "transition": {"duration": duration, "easing": "cubic-in-out"},
+        "pad": {"b": 10, "t": 50},
+        "len": 0.9,
+        "x": 0.1,
+        "y": 0,
+        "steps": []
+    }
+    # make data
+    agent_dict = timed_point_dict[0]  # {agent1:[rect1,..], ...}
+    x_list = []
+    y_list = []
+    text_list = []
+    for agent_id, rect_list in agent_dict.items():
+        for rect in rect_list:
+            # trace = list(data.values())[0]
+            # print(rect)
+            x_list.append((rect[0][x_dim]+rect[1][x_dim])/2)
+            y_list.append((rect[0][y_dim]+rect[1][y_dim])/2)
+            text_list.append(
+                ('{:.2f}'.format((rect[0][x_dim]+rect[1][x_dim])/2), '{:.3f}'.format(rect[0][y_dim]+rect[1][y_dim])/2))
+    # data_dict = {
+    #     "x": x_list,
+    #     "y": y_list,
+    #     "mode": "markers + text",
+    #     "text": text_list,
+    #     "textposition": "bottom center",
+    #     # "marker": {
+    #     #     "sizemode": "area",
+    #     #     "sizeref": 200000,
+    #     #     "size": 2
+    #     # },
+    #     "name": "Current Position"
+    # }
+    # fig_dict["data"].append(data_dict)
+
+    # make frames
+    for time_point in timed_point_dict:
+        frame = {"data": [], "layout": {
+            "annotations": [], "shapes": []}, "name": str(time_point)}
+        agent_dict = timed_point_dict[time_point]
+        trace_x = []
+        trace_y = []
+        for agent_id, rect_list in agent_dict.items():
+            for rect in rect_list:
+                trace_x.append((rect[0][x_dim]+rect[1][x_dim])/2)
+                trace_y.append((rect[0][y_dim]+rect[1][y_dim])/2)
+                # trace_theta.append((rect[0][2]+rect[1][2])/2)
+                # trace_v.append((rect[0][3]+rect[1][3])/2)
+                shape_dict = {
+                    "type": 'rect',
+                    "x0": rect[0][x_dim],
+                    "y0": rect[0][y_dim],
+                    "x1": rect[1][x_dim],
+                    "y1": rect[1][y_dim],
+                    "fillcolor": 'rgba(255,255,255,0.5)',
+                    "line": dict(color='rgba(255,255,255,0)'),
+
+                }
+                frame["layout"]["shapes"].append(shape_dict)
+        # data_dict = {
+        #     "x": trace_x,
+        #     "y": trace_y,
+        #     "mode": "markers + text",
+        #     "text": [('{:.2f}'.format(trace_theta[i]/pi*180), '{:.3f}'.format(trace_v[i])) for i in range(len(trace_theta))],
+        #     "textposition": "bottom center",
+        #     # "marker": {
+        #     #     "sizemode": "area",
+        #     #     "sizeref": 200000,
+        #     #     "size": 2
+        #     # },
+        #     "name": "current position"
+        # }
+        # frame["data"].append(data_dict)
+        # print(trace_x)
+        fig_dict["frames"].append(frame)
+        slider_step = {"args": [
+            [time_point],
+            {"frame": {"duration": duration, "redraw": False},
+             "mode": "immediate",
+             "transition": {"duration": duration}}
+        ],
+            "label": time_point,
+            "method": "animate"}
+        sliders_dict["steps"].append(slider_step)
+        # print(len(frame["layout"]["annotations"]))
+
+    fig_dict["layout"]["sliders"] = [sliders_dict]
+
+    fig = go.Figure(fig_dict)
+    # fig = plotly_map(map, 'g', fig)
+    i = 1
+    for agent_id in traces:
+        fig = draw_reachtube_tree_v2(root, agent_id, 1, 2, i, fig)
+        i += 2
+
+    return fig
+
+
+def draw_reachtube_tree_v2(root, agent_id, fig=go.Figure(), x_dim: int = 1, y_dim: int = 2, color_id=None, map_type='lines'):
+    # if fig is None:
+    #     fig = go.Figure()
+    global color_cnt, bg_color
+    fig = draw_map(map=map, fig=fig, fill_type=map_type)
+    if color_id is None:
+        color_id = color_cnt
+    queue = [root]
+    show_legend = False
+    while queue != []:
+        node = queue.pop(0)
+        traces = node.trace
+        trace = np.array(traces[agent_id])
+        # print(trace[0], trace[1], trace[-2], trace[-1])
+        max_id = len(trace)-1
+        trace_x_odd = np.array([trace[i][x_dim] for i in range(0, max_id, 2)])
+        trace_x_even = np.array([trace[i][x_dim]
+                                for i in range(1, max_id+1, 2)])
+        trace_y_odd = np.array([trace[i][y_dim] for i in range(0, max_id, 2)])
+        trace_y_even = np.array([trace[i][y_dim]
+                                for i in range(1, max_id+1, 2)])
+        fig.add_trace(go.Scatter(x=trace_x_odd.tolist()+trace_x_odd[::-1].tolist(), y=trace_y_odd.tolist()+trace_y_even[::-1].tolist(), mode='lines',
+                                 fill='toself',
+                                 fillcolor=bg_color[color_id],
+                                 line_color='rgba(255,255,255,0)',
+                                 showlegend=show_legend
+                                 ))
+        fig.add_trace(go.Scatter(x=trace_x_even.tolist()+trace_x_even[::-1].tolist(), y=trace_y_odd.tolist()+trace_y_even[::-1].tolist(), mode='lines',
+                                 fill='toself',
+                                 fillcolor=bg_color[color_id],
+                                 line_color='rgba(255,255,255,0)',
+                                 showlegend=show_legend))
+        fig.add_trace(go.Scatter(x=trace_x_odd.tolist()+trace_x_even[::-1].tolist(), y=trace_y_odd.tolist()+trace_y_even[::-1].tolist(), mode='lines',
+                                 fill='toself',
+                                 fillcolor=bg_color[color_id],
+                                 line_color='rgba(255,255,255,0)',
+                                 showlegend=show_legend
+                                 ))
+        fig.add_trace(go.Scatter(x=trace_x_even.tolist()+trace_x_odd[::-1].tolist(), y=trace_y_odd.tolist()+trace_y_even[::-1].tolist(), mode='lines',
+                                 fill='toself',
+                                 fillcolor=bg_color[color_id],
+                                 line_color='rgba(255,255,255,0)',
+                                 showlegend=show_legend))
+        queue += node.child
+        color_id = (color_id+1) % 10
+    queue = [root]
+    while queue != []:
+        node = queue.pop(0)
+        traces = node.trace
+        trace = np.array(traces[agent_id])
+        # print(trace[0], trace[1], trace[-2], trace[-1])
+        max_id = len(trace)-1
+        fig.add_trace(go.Scatter(x=trace[:, x_dim], y=trace[:, y_dim],
+                                 mode='markers',
+                                 #  fill='toself',
+                                 #  line=dict(dash="dot"),
+                                 line_color="black",
+                                 marker={
+            "sizemode": "area",
+            "sizeref": 200000,
+            "size": 2
+        },
+            text=[range(0, max_id+1)],
+            name='lines',
+            showlegend=False))
+        queue += node.child
+    color_cnt = color_id
+    # fig.update_traces(line_dash="dash")
+    return fig
+
+
+def draw_map(map, color='rgba(0,0,0,1)', fig: go.Figure() = go.Figure(), fill_type='lines'):
+    x_min, x_max = float('inf'), -float('inf')
+    y_min, y_max = float('inf'), -float('inf')
+    if fill_type == 'detailed':
+        speed_dict = map.get_all_speed_limit()
+        speed_list = list(filter(None, speed_dict.values()))
+        speed_min = min(speed_list)
+        speed_max = max(speed_list)
+        start_color = [255, 255, 255, 0.2]
+        end_color = [0, 0, 0, 0.2]
+        curr_color = [0, 0, 0, 0]
+    for lane_idx in map.lane_dict:
+        lane = map.lane_dict[lane_idx]
+        curr_color = [0, 0, 0, 0]
+        if fill_type == 'detailed':
+            speed_limit = speed_dict[lane_idx]
+            if speed_limit is not None:
+                lens = len(curr_color)-1
+                for j in range(lens):
+                    curr_color[j] = int(start_color[j]+(speed_limit-speed_min)/(speed_max -
+                                                                                speed_min)*(end_color[j]-start_color[j]))
+                curr_color[lens] = start_color[lens]+(speed_limit-speed_min)/(speed_max -
+                                                                              speed_min)*(end_color[lens]-start_color[lens])
+                # print(curr_color)
+        for lane_seg in lane.segment_list:
+            if lane_seg.type == 'Straight':
+                start1 = lane_seg.start + lane_seg.width/2 * lane_seg.direction_lateral
+                end1 = lane_seg.end + lane_seg.width/2 * lane_seg.direction_lateral
+                start2 = lane_seg.start - lane_seg.width/2 * lane_seg.direction_lateral
+                end2 = lane_seg.end - lane_seg.width/2 * lane_seg.direction_lateral
+                trace_x = [start1[0], end1[0], end2[0], start2[0], start1[0]]
+                trace_y = [start1[1], end1[1], end2[1], start2[1], start1[1]]
+                x_min = min(x_min, min(trace_x))
+                y_min = min(y_min, min(trace_y))
+                x_max = max(x_max, max(trace_x))
+                y_max = max(y_max, max(trace_y))
+                if fill_type == 'lines' or speed_limit is None:
+                    fig.add_trace(go.Scatter(x=trace_x, y=trace_y,
+                                             mode='lines',
+                                             line_color=color,
+                                             #  fill='toself',
+                                             #  fillcolor='rgba(255,255,255,0)',
+                                             showlegend=False,
+                                             # text=theta,
+                                             name='lines'))
+                elif fill_type == 'detailed' and speed_limit is not None:
+                    print(curr_color)
+                    fig.add_trace(go.Scatter(x=trace_x, y=trace_y,
+                                             mode='lines',
+                                             line_color=color,
+                                             fill='toself',
+                                             fillcolor='rgba' +
+                                             str(tuple(curr_color)),
+                                             showlegend=False,
+                                             # text=theta,
+                                             name='limit'))
+                elif fill_type == 'fill':
+                    fig.add_trace(go.Scatter(x=trace_x, y=trace_y,
+                                             mode='lines',
+                                             line_color=color,
+                                             fill='toself',
+                                             #  fillcolor='rgba(255,255,255,0)',
+                                             showlegend=False,
+                                             # text=theta,
+                                             name='lines'))
+            elif lane_seg.type == "Circular":
+                phase_array = np.linspace(
+                    start=lane_seg.start_phase, stop=lane_seg.end_phase, num=100)
+                r1 = lane_seg.radius - lane_seg.width/2
+                x1 = (np.cos(phase_array)*r1 + lane_seg.center[0]).tolist()
+                y1 = (np.sin(phase_array)*r1 + lane_seg.center[1]).tolist()
+                # fig.add_trace(go.Scatter(x=x1, y=y1,
+                #                          mode='lines',
+                #                          line_color=color,
+                #                          showlegend=False,
+                #                          # text=theta,
+                #                          name='lines'))
+                r2 = lane_seg.radius + lane_seg.width/2
+                x2 = (np.cos(phase_array)*r2 +
+                      lane_seg.center[0]).tolist().reverse()
+                y2 = (np.sin(phase_array)*r2 +
+                      lane_seg.center[1]).tolist().reverse()
+                trace_x = x1+x2+[x1[0]]
+                trace_y = y1+y2+[y1[0]]
+                x_min = min(x_min, min(trace_x))
+                y_min = min(y_min, min(trace_y))
+                x_max = max(x_max, max(trace_x))
+                y_max = max(y_max, max(trace_y))
+                if fill_type == 'lines':
+                    fig.add_trace(go.Scatter(x=trace_x, y=trace_y,
+                                             mode='lines',
+                                             line_color=color,
+                                             showlegend=False,
+                                             # text=theta,
+                                             name='lines'))
+                elif fill_type == 'detailed' and speed_limit != None:
+                    fig.add_trace(go.Scatter(x=trace_x, y=trace_y,
+                                             mode='lines',
+                                             line_color=color,
+                                             fill='toself',
+                                             fillcolor='rgba' +
+                                             str(tuple(curr_color)),
+                                             showlegend=False,
+                                             # text=theta,
+                                             name='lines'))
+                elif fill_type == 'fill':
+                    fig.add_trace(go.Scatter(x=trace_x, y=trace_y,
+                                             mode='lines',
+                                             line_color=color,
+                                             fill='toself',
+                                             showlegend=False,
+                                             # text=theta,
+                                             name='lines'))
+            else:
+                raise ValueError(f'Unknown lane segment type {lane_seg.type}')
+    if fill_type == 'detailed':
+        fig.add_trace(go.Scatter(x=[0], y=[0],
+                                 mode='markers',
+                                 # fill='toself',
+                                 # fillcolor='rgba'+str(tuple(curr_color)),
+                                 marker=dict(
+            symbol='square',
+            size=16,
+            cmax=speed_max,
+            cmin=speed_min,
+            color='rgba(0,0,0,0)',
+            colorbar=dict(
+                title="Speed Limit"
+            ),
+            colorscale=[
+                [0, 'rgba'+str(tuple(start_color))], [1, 'rgba'+str(tuple(end_color))]]
+        ),
+            showlegend=False,
+        ))
+    fig.update_xaxes(range=[x_min, x_max])
+    fig.update_yaxes(range=[y_min, y_max])
+    return fig
+
+
+def plotly_map(map, color='rgba(0,0,0,1)', fig: go.Figure() = go.Figure()):
+    # if fig is None:
+    #     fig = go.Figure()
+    all_x = []
+    all_y = []
+    all_v = []
+    for lane_idx in map.lane_dict:
+        lane = map.lane_dict[lane_idx]
+        for lane_seg in lane.segment_list:
+            if lane_seg.type == 'Straight':
+                start1 = lane_seg.start + lane_seg.width/2 * lane_seg.direction_lateral
+                end1 = lane_seg.end + lane_seg.width/2 * lane_seg.direction_lateral
+                # fig.add_trace(go.Scatter(x=[start1[0], end1[0]], y=[start1[1], end1[1]],
+                #                          mode='lines',
+                #                          line_color='black',
+                #                          showlegend=False,
+                #                          # text=theta,
+                #                          name='lines'))
+                start2 = lane_seg.start - lane_seg.width/2 * lane_seg.direction_lateral
+                end2 = lane_seg.end - lane_seg.width/2 * lane_seg.direction_lateral
+                # fig.add_trace(go.Scatter(x=[start2[0], end2[0]], y=[start2[1], end2[1]],
+                #                          mode='lines',
+                #                          line_color='black',
+                #                          showlegend=False,
+                #                          # text=theta,
+                #                          name='lines'))
+                fig.add_trace(go.Scatter(x=[start1[0], end1[0], end2[0], start2[0], start1[0]], y=[start1[1], end1[1], end2[1], start2[1], start1[1]],
+                                         mode='lines',
+                                         line_color=color,
+                                         #  fill='toself',
+                                         #  fillcolor='rgba(255,255,255,0)',
+                                         #  line_color='rgba(0,0,0,0)',
+                                         showlegend=False,
+                                         # text=theta,
+                                         name='lines'))
+                # fig = go.Figure().add_heatmap(x=)
+                seg_x, seg_y, seg_v = lane_seg.get_all_speed()
+                all_x += seg_x
+                all_y += seg_y
+                all_v += seg_v
+            elif lane_seg.type == "Circular":
+                phase_array = np.linspace(
+                    start=lane_seg.start_phase, stop=lane_seg.end_phase, num=100)
+                r1 = lane_seg.radius - lane_seg.width/2
+                x = np.cos(phase_array)*r1 + lane_seg.center[0]
+                y = np.sin(phase_array)*r1 + lane_seg.center[1]
+                fig.add_trace(go.Scatter(x=x, y=y,
+                                         mode='lines',
+                                         line_color=color,
+                                         showlegend=False,
+                                         # text=theta,
+                                         name='lines'))
+
+                r2 = lane_seg.radius + lane_seg.width/2
+                x = np.cos(phase_array)*r2 + lane_seg.center[0]
+                y = np.sin(phase_array)*r2 + lane_seg.center[1]
+                fig.add_trace(go.Scatter(x=x, y=y,
+                                         mode='lines',
+                                         line_color=color,
+                                         showlegend=False,
+                                         # text=theta,
+                                         name='lines'))
+            else:
+                raise ValueError(f'Unknown lane segment type {lane_seg.type}')
+    start_color = [0, 0, 255, 0.2]
+    end_color = [255, 0, 0, 0.2]
+    curr_color = copy.deepcopy(start_color)
+    max_speed = max(all_v)
+    min_speed = min(all_v)
+
+    for i in range(len(all_v)):
+        # print(all_x[i])
+        # print(all_y[i])
+        # print(all_v[i])
+        curr_color = copy.deepcopy(start_color)
+        for j in range(len(curr_color)-1):
+            curr_color[j] += (all_v[i]-min_speed)/(max_speed -
+                                                   min_speed)*(end_color[j]-start_color[j])
+        fig.add_trace(go.Scatter(x=all_x[i], y=all_y[i],
+                                 mode='lines',
+                                 line_color='rgba(0,0,0,0)',
+                                 fill='toself',
+                                 fillcolor='rgba'+str(tuple(curr_color)),
+                                 #  marker=dict(
+                                 #     symbol='square',
+                                 #     size=16,
+                                 #     cmax=max_speed,
+                                 #     cmin=min_speed,
+                                 #     # color=all_v[i],
+                                 #     colorbar=dict(
+                                 #         title="Colorbar"
+                                 #     ),
+                                 #     colorscale=[
+                                 #         [0, 'rgba'+str(tuple(start_color))], [1, 'rgba'+str(tuple(end_color))]]
+                                 # ),
+                                 showlegend=False,
+                                 ))
+    fig.add_trace(go.Scatter(x=[0], y=[0],
+                             mode='markers',
+                             # fill='toself',
+                             # fillcolor='rgba'+str(tuple(curr_color)),
+                             marker=dict(
+                                 symbol='square',
+                                 size=16,
+                                 cmax=max_speed,
+                                 cmin=min_speed,
+                                 color='rgba(0,0,0,0)',
+                                 colorbar=dict(
+                                        title="Speed Limit"
+                                 ),
+                                 colorscale=[
+                                     [0, 'rgba'+str(tuple(start_color))], [1, 'rgba'+str(tuple(end_color))]]
+    ),
+        showlegend=False,
+    ))
+    # fig.update_coloraxes(colorbar=dict(title="Colorbar"), colorscale=[
+    #                      [0, 'rgba'+str(tuple(start_color))], [1, 'rgba'+str(tuple(end_color))]])
+    return fig
+
+
+def draw_simulation_tree(root, map=None, fig=None, x_dim: int = 1, y_dim=2, map_type='lines', scale_type='trace'):
+    fig = draw_map(map=map, fig=fig, fill_type=map_type)
+    # mark
+    agent_list = root.agent.keys()
+    print(agent_list)
+    scheme_list = list(scheme_dict.keys())
+    i = 0
+    for agent_id in agent_list:
+        fig = draw_simulation_tree_single(
+            root, agent_id, fig, x_dim, y_dim, scheme_list[i], map_type)
+        i = (i+5) % 12
+    if scale_type == 'trace':
+        queue = [root]
+        x_min, x_max = float('inf'), -float('inf')
+        y_min, y_max = float('inf'), -float('inf')
+        scale_factor = 0.25
+    i = 0
+    queue = [root]
+    previous_mode = {}
+    agent_list = []
+    for agent_id in root.mode:
+        previous_mode[agent_id] = []
+        agent_list.append(agent_id)
+    text_pos = 'middle center'
+    while queue != []:
+        node = queue.pop(0)
+        traces = node.trace
+        # print(node.mode)
+        # [[time,x,y,theta,v]...]
+        i = 0
+        for agent_id in traces:
+            trace = np.array(traces[agent_id])
+            if scale_type == 'trace':
+                x_min = min(x_min, min(trace[:, x_dim]))
+                x_max = max(x_max, max(trace[:, x_dim]))
+                y_min = min(y_min, min(trace[:, y_dim]))
+                y_max = max(y_max, max(trace[:, y_dim]))
+            # print(trace)
+            # trace_y = trace[:, y_dim].tolist()
+            # trace_x = trace[:, x_dim].tolist()
+            # theta = [i/pi*180 for i in trace[:, 3]]
+            i = agent_list.index(agent_id)
+            # color = colors[i % 5]
+            # fig.add_trace(go.Scatter(x=trace[:, x_dim], y=trace[:, y_dim],
+            #                          mode='lines',
+            #                          line_color=color,
+            #                          text=[('{:.2f}'.format(trace_x[i]), '{:.2f}'.format(
+            #                              trace_y[i])) for i in range(len(trace_x))],
+            #                          showlegend=False)
+            #               #  name='lines')
+            #               )
+            if previous_mode[agent_id] != node.mode[agent_id]:
+                veh_mode = node.mode[agent_id][0]
+                if veh_mode == 'Normal':
+                    text_pos = 'middle center'
+                elif veh_mode == 'Brake':
+                    text_pos = 'middle left'
+                elif veh_mode == 'Accelerate':
+                    text_pos = 'middle right'
+                elif veh_mode == 'SwitchLeft':
+                    text_pos = 'top center'
+                elif veh_mode == 'SwitchRight':
+                    text_pos = 'bottom center'
+
+                fig.add_trace(go.Scatter(x=[trace[0, x_dim]], y=[trace[0, y_dim]],
+                                         mode='markers+text',
+                                         line_color='rgba(255,255,255,0.3)',
+                                         text=str(agent_id)+': ' +
+                                         str(node.mode[agent_id][0]),
+                                         textposition=text_pos,
+                                         textfont=dict(
+                    #  family="sans serif",
+                    size=10,
+                                             color="grey"),
+                                         showlegend=False,
+                                         ))
+                # i += 1
+                previous_mode[agent_id] = node.mode[agent_id]
+        queue += node.child
+    if scale_type == 'trace':
+        fig.update_xaxes(
+            range=[x_min-scale_factor*(x_max-x_min), x_max+scale_factor*(x_max-x_min)])
+        fig.update_yaxes(
+            range=[y_min-scale_factor*(y_max-y_min), y_max+scale_factor*(y_max-y_min)])
+    return fig
+
+
+def draw_simulation_tree_single(root: AnalysisTreeNode, agent_id, fig: go.Figure() = go.Figure(), x_dim: int = 1, y_dim: int = 2, color=None, map_type='lines'):
+    global color_cnt
+    # fig = draw_map(map=map, fig=fig, fill_type=map_type)
+    queue = [root]
+    color_id = 0
+    if color == None:
+        color = list(scheme_dict.keys())[color_cnt]
+        color_cnt = (color_cnt+1) % 12
+    while queue != []:
+        node = queue.pop(0)
+        traces = node.trace
+        # print(node.mode)
+        # [[time,x,y,theta,v]...]
+        trace = np.array(traces[agent_id])
+        # print(trace)
+        # trace_y = trace[:, y_dim].tolist()
+        # trace_x = trace[:, x_dim].tolist()
+        # trace_x_rev = trace_x[::-1]
+        # # print(trace_x)
+        # trace_upper = [i+1 for i in trace_y]
+        # trace_lower = [i-1 for i in trace_y]
+        # trace_lower = trace_lower[::-1]
+        # # print(trace_upper)
+        # # print(trace[:, y_dim])
+        # fig.add_trace(go.Scatter(x=trace_x+trace_x_rev, y=trace_upper+trace_lower,
+        #                          fill='toself',
+        #                          fillcolor=bg_color[color_id],
+        #                          line_color='rgba(255,255,255,0)',
+        #                          showlegend=False))'
+        trace_text = []
+        for i in range(len(trace)):
+            trace_text.append([round(trace[i, j], 2)
+                              for j in range(trace.shape[1])])
+
+        fig.add_trace(go.Scatter(x=trace[:, x_dim], y=trace[:, y_dim],
+                                 mode='lines',
+                                 line_color=colors[scheme_dict[color]
+                                                   ][color_id],
+                                 #  text=[('{:.2f}'.format(trace[i, x_dim]), '{:.2f}'.format(
+                                 #      trace[i, y_dim])) for i in range(len(trace))],
+                                 text=trace_text,
+                                 #  name='lines',
+                                 showlegend=False))
+        color_id = (color_id+4) % 5
+        queue += node.child
+    return fig
+
+
+def draw_simulation_anime(root, map=None, fig=None):
+    # make figure
+    fig_dict = {
+        "data": [],
+        "layout": {},
+        "frames": []
+    }
+    # fig = plot_map(map, 'g', fig)
+    timed_point_dict = {}
+    stack = [root]
+    print("plot")
+    # print(root.mode)
+    x_min, x_max = float('inf'), -float('inf')
+    y_min, y_max = float('inf'), -float('inf')
+    # segment_start = set()
+    # previous_mode = {}
+    # for agent_id in root.mode:
+    #     previous_mode[agent_id] = []
+
+    while stack != []:
+        node = stack.pop()
+        traces = node.trace
+        for agent_id in traces:
+            trace = np.array(traces[agent_id])
+            print(trace)
+            # segment_start.add(round(trace[0][0], 2))
+            for i in range(len(trace)):
+                x_min = min(x_min, trace[i][1])
+                x_max = max(x_max, trace[i][1])
+                y_min = min(y_min, trace[i][2])
+                y_max = max(y_max, trace[i][2])
+                # print(round(trace[i][0], 2))
+                time_point = round(trace[i][0], 2)
+                if time_point not in timed_point_dict:
+                    timed_point_dict[time_point] = [
+                        {agent_id: trace[i][1:].tolist()}]
+                else:
+                    init = False
+                    for record in timed_point_dict[time_point]:
+                        if list(record.values())[0] == trace[i][1:].tolist():
+                            init = True
+                            break
+                    if init == False:
+                        timed_point_dict[time_point].append(
+                            {agent_id: trace[i][1:].tolist()})
+            time = round(trace[i][0], 2)
+        stack += node.child
+    # fill in most of layout
+    # print(segment_start)
+    # print(timed_point_dict.keys())
+    duration = int(600/time)
+    fig_dict["layout"]["xaxis"] = {
+        "range": [(x_min-10), (x_max+10)],
+        "title": "x position"}
+    fig_dict["layout"]["yaxis"] = {
+        "range": [(y_min-2), (y_max+2)],
+        "title": "y position"}
+    fig_dict["layout"]["hovermode"] = "closest"
+    fig_dict["layout"]["updatemenus"] = [
+        {
+            "buttons": [
+                {
+                    "args": [None, {"frame": {"duration": duration, "redraw": False},
+                                    "fromcurrent": True, "transition": {"duration": duration,
+                                                                        "easing": "quadratic-in-out"}}],
+                    "label": "Play",
+                    "method": "animate"
+                },
+                {
+                    "args": [[None], {"frame": {"duration": 0, "redraw": False},
+                                      "mode": "immediate",
+                                      "transition": {"duration": 0}}],
+                    "label": "Pause",
+                    "method": "animate"
+                }
+            ],
+            "direction": "left",
+            "pad": {"r": 10, "t": 87},
+            "showactive": False,
+            "type": "buttons",
+            "x": 0.1,
+            "xanchor": "right",
+            "y": 0,
+            "yanchor": "top"
+        }
+    ]
+    sliders_dict = {
+        "active": 0,
+        "yanchor": "top",
+        "xanchor": "left",
+        "currentvalue": {
+            "font": {"size": 20},
+            "prefix": "time:",
+            "visible": True,
+            "xanchor": "right"
+        },
+        "transition": {"duration": duration, "easing": "cubic-in-out"},
+        "pad": {"b": 10, "t": 50},
+        "len": 0.9,
+        "x": 0.1,
+        "y": 0,
+        "steps": []
+    }
+    # make data
+    point_list = timed_point_dict[0]
+    print(point_list)
+    x_list = []
+    y_list = []
+    text_list = []
+    for data in point_list:
+        trace = list(data.values())[0]
+        # print(trace)
+        x_list.append(trace[0])
+        y_list.append(trace[1])
+        text_list.append(
+            ('{:.2f}'.format(trace[2]/pi*180), '{:.3f}'.format(trace[3])))
+    data_dict = {
+        "x": x_list,
+        "y": y_list,
+        "mode": "markers + text",
+        "text": text_list,
+        "textfont": dict(size=14, color="black"),
+        "textposition": "bottom center",
+        # "marker": {
+        #     "sizemode": "area",
+        #     "sizeref": 200000,
+        #     "size": 2
+        # },
+        "name": "Current Position"
+    }
+    fig_dict["data"].append(data_dict)
+
+    # make frames
+    for time_point in timed_point_dict:
+        # print(time_point)
+        frame = {"data": [], "layout": {
+            "annotations": []}, "name": '{:.2f}'.format(time_point)}
+        # print(timed_point_dict[time_point][0])
+        point_list = timed_point_dict[time_point]
+        # point_list = list(OrderedDict.fromkeys(timed_point_dict[time_point]))
+        # todokeyi
+        trace_x = []
+        trace_y = []
+        trace_theta = []
+        trace_v = []
+        for data in point_list:
+            trace = list(data.values())[0]
+            # print(trace)
+            trace_x.append(trace[0])
+            trace_y.append(trace[1])
+            trace_theta.append(trace[2])
+            trace_v.append(trace[3])
+        data_dict = {
+            "x": trace_x,
+            "y": trace_y,
+            "mode": "markers + text",
+            # "text": [(round(trace_theta[i]/pi*180, 2), round(trace_v[i], 3)) for i in range(len(trace_theta))],
+            "text": [('{:.2f}'.format(trace_theta[i]/pi*180), '{:.3f}'.format(trace_v[i])) for i in range(len(trace_theta))],
+            "textfont": dict(size=14, color="black"),
+            "textposition": "bottom center",
+            # "marker": {
+            #     "sizemode": "area",
+            #     "sizeref": 200000,
+            #     "size": 2
+            # },
+            "name": "current position",
+            # "show_legend": False
+        }
+        frame["data"].append(data_dict)
+        for i in range(len(trace_x)):
+            ax = np.cos(trace_theta[i])*trace_v[i]
+            ay = np.sin(trace_theta[i])*trace_v[i]
+            # print(trace_x[i]+ax, trace_y[i]+ay)
+            annotations_dict = {"x": trace_x[i]+ax, "y": trace_y[i]+ay,
+                                # "xshift": ax, "yshift": ay,
+                                "ax": trace_x[i], "ay": trace_y[i],
+                                "arrowwidth": 2,
+                                # "arrowside": 'end',
+                                "showarrow": True,
+                                # "arrowsize": 1,
+                                "xref": 'x', "yref": 'y',
+                                "axref": 'x', "ayref": 'y',
+                                # "text": "erver",
+                                "arrowhead": 1,
+                                "arrowcolor": "black"}
+            frame["layout"]["annotations"].append(annotations_dict)
+
+            # if (time_point in segment_start) and (operator.ne(previous_mode[agent_id], node.mode[agent_id])):
+            #     annotations_dict = {"x": trace_x[i], "y": trace_y[i],
+            #                         # "xshift": ax, "yshift": ay,
+            #                         # "ax": trace_x[i], "ay": trace_y[i],
+            #                         # "arrowwidth": 2,
+            #                         # "arrowside": 'end',
+            #                         "showarrow": False,
+            #                         # "arrowsize": 1,
+            #                         # "xref": 'x', "yref": 'y',
+            #                         # "axref": 'x', "ayref": 'y',
+            #                         "text": str(node.mode[agent_id][0]),
+            #                         # "arrowhead": 1,
+            #                         # "arrowcolor": "black"
+            #                         }
+            #     frame["layout"]["annotations"].append(annotations_dict)
+            #     print(frame["layout"]["annotations"])
+            # i += 1
+            # previous_mode[agent_id] = node.mode[agent_id]
+
+        fig_dict["frames"].append(frame)
+        slider_step = {"args": [
+            [time_point],
+            {"frame": {"duration": duration, "redraw": False},
+             "mode": "immediate",
+             "transition": {"duration": duration}}
+        ],
+            "label": time_point,
+            "method": "animate"}
+        sliders_dict["steps"].append(slider_step)
+        # print(len(frame["layout"]["annotations"]))
+
+    fig_dict["layout"]["sliders"] = [sliders_dict]
+
+    fig = go.Figure(fig_dict)
+    fig = plotly_map(map, 'g', fig)
+    i = 0
+    queue = [root]
+    previous_mode = {}
+    agent_list = []
+    for agent_id in root.mode:
+        previous_mode[agent_id] = []
+        agent_list.append(agent_id)
+    text_pos = 'middle center'
+    while queue != []:
+        node = queue.pop(0)
+        traces = node.trace
+        # print(node.mode)
+        # [[time,x,y,theta,v]...]
+        i = 0
+        for agent_id in traces:
+            trace = np.array(traces[agent_id])
+            # print(trace)
+            trace_y = trace[:, 2].tolist()
+            trace_x = trace[:, 1].tolist()
+            # theta = [i/pi*180 for i in trace[:, 3]]
+            i = agent_list.index(agent_id)
+            color = colors[i % 5]
+            fig.add_trace(go.Scatter(x=trace[:, 1], y=trace[:, 2],
+                                     mode='lines',
+                                     line_color=color,
+                                     text=[(round(trace[i, 3]/pi*180, 2), round(trace[i, 4], 3))
+                                           for i in range(len(trace_y))],
+                                     showlegend=False)
+                          #  name='lines')
+                          )
+            if previous_mode[agent_id] != node.mode[agent_id]:
+                theta = trace[0, 3]
+                veh_mode = node.mode[agent_id][0]
+                if veh_mode == 'Normal':
+                    text_pos = 'middle center'
+                elif veh_mode == 'Brake':
+                    if theta >= -pi/2 and theta <= pi/2:
+                        text_pos = 'middle left'
+                    else:
+                        text_pos = 'middle right'
+                elif veh_mode == 'Accelerate':
+                    if theta >= -pi/2 and theta <= pi/2:
+                        text_pos = 'middle right'
+                    else:
+                        text_pos = 'middle left'
+                elif veh_mode == 'SwitchLeft':
+                    if theta >= -pi/2 and theta <= pi/2:
+                        text_pos = 'top center'
+                    else:
+                        text_pos = 'bottom center'
+                elif veh_mode == 'SwitchRight':
+                    if theta >= -pi/2 and theta <= pi/2:
+                        text_pos = 'bottom center'
+                    else:
+                        text_pos = 'top center'
+                fig.add_trace(go.Scatter(x=[trace[0, 1]], y=[trace[0, 2]],
+                                         mode='markers+text',
+                                         line_color='rgba(255,255,255,0.3)',
+                                         text=str(agent_id)+': ' +
+                                         str(node.mode[agent_id][0]),
+                                         textposition=text_pos,
+                                         textfont=dict(
+                    #  family="sans serif",
+                    size=10,
+                                             color="grey"),
+                                         showlegend=False,
+                                         ))
+                # i += 1
+                previous_mode[agent_id] = node.mode[agent_id]
+        queue += node.child
+    fig.update_traces(showlegend=False)
+    # fig.update_annotations(textfont=dict(size=14, color="black"))
+    # print(fig.frames[0].layout["annotations"])
+    return fig
+
+
+def general_simu_anime(root, map=None, fig=None, x_dim: int = 1, y_dim=2, map_type='lines', scale_type='trace'):
+    # make figure
+    fig_dict = {
+        "data": [],
+        "layout": {},
+        "frames": []
+    }
+    # fig = plot_map(map, 'g', fig)
+    timed_point_dict = {}
+    stack = [root]
+    # print("plot")
+    # print(root.mode)
+    x_min, x_max = float('inf'), -float('inf')
+    y_min, y_max = float('inf'), -float('inf')
+    # segment_start = set()
+    # previous_mode = {}
+    # for agent_id in root.mode:
+    #     previous_mode[agent_id] = []
+
+    while stack != []:
+        node = stack.pop()
+        traces = node.trace
+        for agent_id in traces:
+            trace = np.array(traces[agent_id])
+            print(trace)
+            # segment_start.add(round(trace[0][0], 2))
+            for i in range(len(trace)):
+                x_min = min(x_min, trace[i][x_dim])
+                x_max = max(x_max, trace[i][x_dim])
+                y_min = min(y_min, trace[i][y_dim])
+                y_max = max(y_max, trace[i][y_dim])
+                # print(round(trace[i][0], 2))
+                time_point = round(trace[i][0], 2)
+                if time_point not in timed_point_dict:
+                    timed_point_dict[time_point] = [
+                        {agent_id: trace[i][0:].tolist()}]
+                else:
+                    init = False
+                    for record in timed_point_dict[time_point]:
+                        if list(record.values())[0] == trace[i][0:].tolist():
+                            init = True
+                            break
+                    if init == False:
+                        timed_point_dict[time_point].append(
+                            {agent_id: trace[i][0:].tolist()})
+            time = round(trace[i][0], 2)
+        stack += node.child
+    # fill in most of layout
+    # print(segment_start)
+    # print(timed_point_dict.keys())
+    duration = int(600/time)
+    fig_dict["layout"]["xaxis"] = {
+        "range": [x_min, x_max],
+        "title": "x position"}
+    fig_dict["layout"]["yaxis"] = {
+        "range": [y_min, y_max],
+        "title": "y position"}
+    fig_dict["layout"]["hovermode"] = "closest"
+    fig_dict["layout"]["updatemenus"] = [
+        {
+            "buttons": [
+                {
+                    "args": [None, {"frame": {"duration": duration, "redraw": False},
+                                    "fromcurrent": True, "transition": {"duration": duration,
+                                                                        "easing": "quadratic-in-out"}}],
+                    "label": "Play",
+                    "method": "animate"
+                },
+                {
+                    "args": [[None], {"frame": {"duration": 0, "redraw": False},
+                                      "mode": "immediate",
+                                      "transition": {"duration": 0}}],
+                    "label": "Pause",
+                    "method": "animate"
+                }
+            ],
+            "direction": "left",
+            "pad": {"r": 10, "t": 87},
+            "showactive": False,
+            "type": "buttons",
+            "x": 0.1,
+            "xanchor": "right",
+            "y": 0,
+            "yanchor": "top"
+        }
+    ]
+    sliders_dict = {
+        "active": 0,
+        "yanchor": "top",
+        "xanchor": "left",
+        "currentvalue": {
+            "font": {"size": 20},
+            "prefix": "time:",
+            "visible": True,
+            "xanchor": "right"
+        },
+        "transition": {"duration": duration, "easing": "cubic-in-out"},
+        "pad": {"b": 10, "t": 50},
+        "len": 0.9,
+        "x": 0.1,
+        "y": 0,
+        "steps": []
+    }
+    # make data
+    point_list = timed_point_dict[0]
+    print(point_list)
+    x_list = []
+    y_list = []
+    text_list = []
+    for data in point_list:
+        trace = list(data.values())[0]
+        # print(trace)
+        x_list.append(trace[x_dim])
+        y_list.append(trace[y_dim])
+        # text_list.append(
+        #     ('{:.2f}'.format(trace[x_dim]), '{:.2f}'.format(trace[y_dim])))
+        text_list.append([round(trace[i], 2) for i in range(len(trace))])
+    data_dict = {
+        "x": x_list,
+        "y": y_list,
+        "mode": "markers + text",
+        "text": text_list,
+        "textfont": dict(size=14, color="black"),
+        "textposition": "bottom center",
+        # "marker": {
+        #     "sizemode": "area",
+        #     "sizeref": 200000,
+        #     "size": 2
+        # },
+        "name": "Current Position"
+    }
+    fig_dict["data"].append(data_dict)
+
+    # make frames
+    for time_point in timed_point_dict:
+        # print(time_point)
+        frame = {"data": [], "layout": {
+            "annotations": []}, "name": '{:.2f}'.format(time_point)}
+        # print(timed_point_dict[time_point][0])
+        point_list = timed_point_dict[time_point]
+        # point_list = list(OrderedDict.fromkeys(timed_point_dict[time_point]))
+        # todokeyi
+        trace_x = []
+        trace_y = []
+        text_list = []
+        for data in point_list:
+            trace = list(data.values())[0]
+            # print(trace)
+            trace_x.append(trace[x_dim])
+            trace_y.append(trace[y_dim])
+            # text_list.append(
+            #     ('{:.2f}'.format(trace[x_dim]), '{:.2f}'.format(trace[y_dim])))
+            text_list.append([round(trace[i], 2) for i in range(len(trace))])
+        data_dict = {
+            "x": trace_x,
+            "y": trace_y,
+            "mode": "markers + text",
+            # "text": [(round(trace_theta[i]/pi*180, 2), round(trace_v[i], 3)) for i in range(len(trace_theta))],
+            "text": text_list,
+            "textfont": dict(size=14, color="black"),
+            "textposition": "bottom center",
+            # "marker": {
+            #     "sizemode": "area",
+            #     "sizeref": 200000,
+            #     "size": 2
+            # },
+            "name": "current position",
+            # "show_legend": False
+        }
+        frame["data"].append(data_dict)
+        # for i in range(len(trace_x)):
+        #     ax = np.cos(trace_theta[i])*trace_v[i]
+        #     ay = np.sin(trace_theta[i])*trace_v[i]
+        # print(trace_x[i]+ax, trace_y[i]+ay)
+        # annotations_dict = {"x": trace_x[i]+ax, "y": trace_y[i]+ay,
+        #                     # "xshift": ax, "yshift": ay,
+        #                     "ax": trace_x[i], "ay": trace_y[i],
+        #                     "arrowwidth": 2,
+        #                     # "arrowside": 'end',
+        #                     "showarrow": True,
+        #                     # "arrowsize": 1,
+        #                     "xref": 'x', "yref": 'y',
+        #                     "axref": 'x', "ayref": 'y',
+        #                     # "text": "erver",
+        #                     "arrowhead": 1,
+        #                     "arrowcolor": "black"}
+        # frame["layout"]["annotations"].append(annotations_dict)
+
+        fig_dict["frames"].append(frame)
+        slider_step = {"args": [
+            [time_point],
+            {"frame": {"duration": duration, "redraw": False},
+             "mode": "immediate",
+             "transition": {"duration": duration}}
+        ],
+            "label": time_point,
+            "method": "animate"}
+        sliders_dict["steps"].append(slider_step)
+        # print(len(frame["layout"]["annotations"]))
+
+    fig_dict["layout"]["sliders"] = [sliders_dict]
+
+    fig = go.Figure(fig_dict)
+    # print(map)
+    fig = draw_map(map, 'rgba(0,0,0,1)', fig, map_type)
+    i = 0
+    queue = [root]
+    previous_mode = {}
+    agent_list = []
+    for agent_id in root.mode:
+        previous_mode[agent_id] = []
+        agent_list.append(agent_id)
+    text_pos = 'middle center'
+    while queue != []:
+        node = queue.pop(0)
+        traces = node.trace
+        # print(node.mode)
+        # [[time,x,y,theta,v]...]
+        i = 0
+        for agent_id in traces:
+            trace = np.array(traces[agent_id])
+            # print(trace)
+            trace_y = trace[:, y_dim].tolist()
+            trace_x = trace[:, x_dim].tolist()
+            # theta = [i/pi*180 for i in trace[:, 3]]
+            i = agent_list.index(agent_id)
+            # color = colors[i % 5]
+            # fig.add_trace(go.Scatter(x=trace[:, x_dim], y=trace[:, y_dim],
+            #                          mode='lines',
+            #                          line_color=color,
+            #                          text=[('{:.2f}'.format(trace_x[i]), '{:.2f}'.format(
+            #                              trace_y[i])) for i in range(len(trace_x))],
+            #                          showlegend=False)
+            #               #  name='lines')
+            #               )
+            if previous_mode[agent_id] != node.mode[agent_id]:
+                veh_mode = node.mode[agent_id][0]
+                if veh_mode == 'Normal':
+                    text_pos = 'middle center'
+                elif veh_mode == 'Brake':
+                    text_pos = 'middle left'
+                elif veh_mode == 'Accelerate':
+                    text_pos = 'middle right'
+                elif veh_mode == 'SwitchLeft':
+                    text_pos = 'top center'
+                elif veh_mode == 'SwitchRight':
+                    text_pos = 'bottom center'
+
+                fig.add_trace(go.Scatter(x=[trace[0, x_dim]], y=[trace[0, y_dim]],
+                                         mode='markers+text',
+                                         line_color='rgba(255,255,255,0.3)',
+                                         text=str(agent_id)+': ' +
+                                         str(node.mode[agent_id][0]),
+                                         textposition=text_pos,
+                                         textfont=dict(
+                    #  family="sans serif",
+                    size=10,
+                                             color="grey"),
+                                         showlegend=False,
+                                         ))
+                # i += 1
+                previous_mode[agent_id] = node.mode[agent_id]
+        queue += node.child
+    fig.update_traces(showlegend=False)
+    scale_factor = 0.5
+    if scale_type == 'trace':
+        fig.update_xaxes(
+            range=[x_min-scale_factor*(x_max-x_min), x_max+scale_factor*(x_max-x_min)])
+        fig.update_yaxes(
+            range=[y_min-scale_factor*(y_max-y_min), y_max+scale_factor*(y_max-y_min)])
+    # fig.update_annotations(textfont=dict(size=14, color="black"))
+    # print(fig.frames[0].layout["annotations"])
+    return fig
+
+
+def test_simu_anime(root, map=None, fig=None, x_dim: int = 1, y_dim=2, map_type='lines', scale_type='trace'):
+    # make figure
+    fig_dict = {
+        "data": [],
+        "layout": {},
+        "frames": []
+    }
+    # fig = plot_map(map, 'g', fig)
+    timed_point_dict = {}
+    stack = [root]
+    print("plot")
+    # print(root.mode)
+    x_min, x_max = float('inf'), -float('inf')
+    y_min, y_max = float('inf'), -float('inf')
+    # segment_start = set()
+    # previous_mode = {}
+    # for agent_id in root.mode:
+    #     previous_mode[agent_id] = []
+
+    while stack != []:
+        node = stack.pop()
+        traces = node.trace
+        for agent_id in traces:
+            trace = np.array(traces[agent_id])
+            # print(trace)
+            # segment_start.add(round(trace[0][0], 2))
+            for i in range(len(trace)):
+                x_min = min(x_min, trace[i][x_dim])
+                x_max = max(x_max, trace[i][x_dim])
+                y_min = min(y_min, trace[i][y_dim])
+                y_max = max(y_max, trace[i][y_dim])
+                # print(round(trace[i][0], 2))
+                time_point = round(trace[i][0], 2)
+                tmp_trace = trace[i][0:].tolist()
+                if time_point not in timed_point_dict:
+                    timed_point_dict[time_point] = {agent_id: [tmp_trace]}
+                else:
+                    if agent_id not in timed_point_dict[time_point].keys():
+                        timed_point_dict[time_point][agent_id] = [tmp_trace]
+                    elif tmp_trace not in timed_point_dict[time_point][agent_id]:
+                        timed_point_dict[time_point][agent_id].append(
+                            tmp_trace)
+            time = round(trace[i][0], 2)
+        stack += node.child
+    # fill in most of layout
+    # print(segment_start)
+    # print(timed_point_dict.keys())
+    duration = int(100/time)
+    duration = 1
+    fig_dict["layout"]["xaxis"] = {
+        # "range": [x_min, x_max],
+        "title": "x position"}
+    fig_dict["layout"]["yaxis"] = {
+        # "range": [y_min, y_max],
+        "title": "y position"}
+    fig_dict["layout"]["hovermode"] = "closest"
+    fig_dict["layout"]["updatemenus"] = [
+        {
+            "buttons": [
+                {
+                    "args": [None, {"frame": {"duration": duration, "redraw": False},
+                                    "fromcurrent": True, "transition": {"duration": duration,
+                                                                        "easing": "quadratic-in-out"}}],
+                    "label": "Play",
+                    "method": "animate"
+                },
+                {
+                    "args": [[None], {"frame": {"duration": 0, "redraw": False},
+                                      "mode": "immediate",
+                                      "transition": {"duration": 0}}],
+                    "label": "Pause",
+                    "method": "animate"
+                }
+            ],
+            "direction": "left",
+            "pad": {"r": 10, "t": 87},
+            "showactive": False,
+            "type": "buttons",
+            "x": 0.1,
+            "xanchor": "right",
+            "y": 0,
+            "yanchor": "top"
+        }
+    ]
+    sliders_dict = {
+        "active": 0,
+        "yanchor": "top",
+        "xanchor": "left",
+        "currentvalue": {
+            "font": {"size": 20},
+            "prefix": "time:",
+            "visible": True,
+            "xanchor": "right"
+        },
+        "transition": {"duration": duration, "easing": "cubic-in-out"},
+        "pad": {"b": 10, "t": 50},
+        "len": 0.9,
+        "x": 0.1,
+        "y": 0,
+        "steps": []
+    }
+    # # make data
+    trace_dict = timed_point_dict[0]
+
+    x_list = []
+    y_list = []
+    text_list = []
+    for time_point in timed_point_dict:
+        trace_dict = timed_point_dict[time_point]
+        for agent_id, point_list in trace_dict.items():
+            for point in point_list:
+                x_list.append(point[x_dim])
+                y_list.append(point[y_dim])
+                # text_list.append(
+                #     ('{:.2f}'.format(point[x_dim]), '{:.2f}'.format(point[y_dim])))
+                text_list.append([round(point[i], 2)
+                                 for i in range(len(point))])
+        data_dict = {
+            "x": x_list,
+            "y": y_list,
+            "mode": "markers",
+            "text": text_list,
+            "textfont": dict(size=14, color="black"),
+            "visible": False,
+            "textposition": "bottom center",
+            # "marker": {
+            #     "sizemode": "area",
+            #     "sizeref": 200000,
+            #     "size": 2
+            # },
+            "name": "Current Position"
+        }
+        fig_dict["data"].append(data_dict)
+    time_list = list(timed_point_dict.keys())
+    agent_list = list(trace_dict.keys())
+    # start = time_list[0]
+    # step = time_list[1]-start
+    trail_limit = min(10, len(time_list))
+    # print(agent_list)
+    # make frames
+    for time_point_id in range(trail_limit, len(time_list)):
+        time_point = time_list[time_point_id]
+        frame = {"data": [], "layout": {
+            "annotations": []}, "name": '{:.2f}'.format(time_point)}
+        # todokeyi
+        trail_len = min(time_point_id+1, trail_limit)
+        opacity_step = 1/trail_len
+        size_step = 2/trail_len
+        min_size = 5
+        for agent_id in agent_list:
+            for id in range(0, trail_len, 2):
+                tmp_point_list = timed_point_dict[time_list[time_point_id-id]][agent_id]
+                trace_x = []
+                trace_y = []
+                text_list = []
+                for point in tmp_point_list:
+                    trace_x.append(point[x_dim])
+                    trace_y.append(point[y_dim])
+                    # text_list.append(
+                    #     ('{:.2f}'.format(point[x_dim]), '{:.2f}'.format(point[y_dim])))
+                    text_list.append([round(point[i], 2)
+                                     for i in range(len(point))])
+                #  print(trace_y)
+                if id == 0:
+                    data_dict = {
+                        "x": trace_x,
+                        "y": trace_y,
+                        "mode": "markers",
+                        "text": text_list,
+                        "textfont": dict(size=6, color="black"),
+                        "textposition": "bottom center",
+                        "visible": True,
+                        "marker": {
+                            "color": 'Black',
+                            "opacity": opacity_step*(trail_len-id),
+                            # "sizemode": "area",
+                            # "sizeref": 200000,
+                            "size": min_size + size_step*(trail_len-id)
+                        },
+                        "name": "current position",
+                        # "show_legend": False
+                    }
+                else:
+                    data_dict = {
+                        "x": trace_x,
+                        "y": trace_y,
+                        "mode": "markers",
+                        "text": text_list,
+                        # "textfont": dict(size=6, color="black"),
+                        # "textposition": "bottom center",
+                        "visible": True,
+                        "marker": {
+                            "color": 'Black',
+                            "opacity": opacity_step*(trail_len-id),
+                            # "sizemode": "area",
+                            # "sizeref": 200000,
+                            "size": min_size + size_step*(trail_len-id)
+                        },
+                        "name": "current position",
+                        # "show_legend": False
+                    }
+                frame["data"].append(data_dict)
+
+        fig_dict["frames"].append(frame)
+        slider_step = {"args": [
+            [time_point],
+            {"frame": {"duration": duration, "redraw": False},
+             "mode": "immediate",
+             "transition": {"duration": duration}}
+        ],
+            "label": time_point,
+            "method": "animate"}
+        sliders_dict["steps"].append(slider_step)
+        # print(len(frame["layout"]["annotations"]))
+
+    fig_dict["layout"]["sliders"] = [sliders_dict]
+
+    fig = go.Figure(fig_dict)
+    fig = draw_map(map, 'rgba(0,0,0,1)', fig, map_type)
+    i = 0
+    queue = [root]
+    previous_mode = {}
+    agent_list = []
+    for agent_id in root.mode:
+        previous_mode[agent_id] = []
+        agent_list.append(agent_id)
+    text_pos = 'middle center'
+    while queue != []:
+        node = queue.pop(0)
+        traces = node.trace
+        # print(node.mode)
+        # [[time,x,y,theta,v]...]
+        i = 0
+        for agent_id in traces:
+            trace = np.array(traces[agent_id])
+            # print(trace)
+            trace_y = trace[:, y_dim].tolist()
+            trace_x = trace[:, x_dim].tolist()
+            # theta = [i/pi*180 for i in trace[:, 3]]
+            i = agent_list.index(agent_id)
+            # color = colors[i % 5]
+            # fig.add_trace(go.Scatter(x=trace[:, x_dim], y=trace[:, y_dim],
+            #                          mode='lines',
+            #                          line_color=color,
+            #                          text=[('{:.2f}'.format(trace_x[i]), '{:.2f}'.format(
+            #                              trace_y[i])) for i in range(len(trace_x))],
+            #                          showlegend=False)
+            #               #  name='lines')
+            #               )
+            if previous_mode[agent_id] != node.mode[agent_id]:
+                veh_mode = node.mode[agent_id][0]
+                if veh_mode == 'Normal':
+                    text_pos = 'middle center'
+                elif veh_mode == 'Brake':
+                    text_pos = 'middle left'
+                elif veh_mode == 'Accelerate':
+                    text_pos = 'middle right'
+                elif veh_mode == 'SwitchLeft':
+                    text_pos = 'top center'
+                elif veh_mode == 'SwitchRight':
+                    text_pos = 'bottom center'
+
+                fig.add_trace(go.Scatter(x=[trace[0, x_dim]], y=[trace[0, y_dim]],
+                                         mode='markers+text',
+                                         line_color='rgba(255,255,255,0.3)',
+                                         text=str(agent_id)+': ' +
+                                         str(node.mode[agent_id][0]),
+                                         textposition=text_pos,
+                                         textfont=dict(
+                    #  family="sans serif",
+                    size=10,
+                                             color="grey"),
+                                         showlegend=False,
+                                         ))
+                # i += 1
+                previous_mode[agent_id] = node.mode[agent_id]
+        queue += node.child
+    fig.update_traces(showlegend=False)
+    scale_factor = 0.5
+    if scale_type == 'trace':
+        fig.update_xaxes(
+            range=[x_min-scale_factor*(x_max-x_min), x_max+scale_factor*(x_max-x_min)])
+        fig.update_yaxes(
+            range=[y_min-scale_factor*(y_max-y_min), y_max+scale_factor*(y_max-y_min)])
+    # fig.update_annotations(textfont=dict(size=14, color="black"))
+    # print(fig.frames[0].layout["annotations"])
+    return fig
+
+
+# The 'color' property is a color and may be specified as:
+#       - A hex string (e.g. '#ff0000')
+#       - An rgb/rgba string (e.g. 'rgb(255,0,0)')
+#       - An hsl/hsla string (e.g. 'hsl(0,100%,50%)')
+#       - An hsv/hsva string (e.g. 'hsv(0,100%,100%)')
+#       - A named CSS color:
+#             aliceblue, antiquewhite, aqua, aquamarine, azure,
+#             beige, bisque, black, blanchedalmond, blue,
+#             blueviolet, brown, burlywood, cadetblue,
+#             chartreuse, chocolate, coral, cornflowerblue,
+#             cornsilk, crimson, cyan, darkblue, darkcyan,
+#             darkgoldenrod, darkgray, darkgrey, darkgreen,
+#             darkkhaki, darkmagenta, darkolivegreen, darkorange,
+#             darkorchid, darkred, darksalmon, darkseagreen,
+#             darkslateblue, darkslategray, darkslategrey,
+#             darkturquoise, darkviolet, deeppink, deepskyblue,
+#             dimgray, dimgrey, dodgerblue, firebrick,
+#             floralwhite, forestgreen, fuchsia, gainsboro,
+#             ghostwhite, gold, goldenrod, gray, grey, green,
+#             greenyellow, honeydew, hotpink, indianred, indigo,
+#             ivory, khaki, lavender, lavenderblush, lawngreen,
+#             lemonchiffon, lightblue, lightcoral, lightcyan,
+#             lightgoldenrodyellow, lightgray, lightgrey,
+#             lightgreen, lightpink, lightsalmon, lightseagreen,
+#             lightskyblue, lightslategray, lightslategrey,
+#             lightsteelblue, lightyellow, lime, limegreen,
+#             linen, magenta, maroon, mediumaquamarine,
+#             mediumblue, mediumorchid, mediumpurple,
+#             mediumseagreen, mediumslateblue, mediumspringgreen,
+#             mediumturquoise, mediumvioletred, midnightblue,
+#             mintcream, mistyrose, moccasin, navajowhite, navy,
+#             oldlace, olive, olivedrab, orange, orangered,
+#             orchid, palegoldenrod, palegreen, paleturquoise,
+#             palevioletred, papayawhip, peachpuff, peru, pink,
+#             plum, powderblue, purple, red, rosybrown,
+#             royalblue, rebeccapurple, saddlebrown, salmon,
+#             sandybrown, seagreen, seashell, sienna, silver,
+#             skyblue, slateblue, slategray, slategrey, snow,
+#             springgreen, steelblue, tan, teal, thistle, tomato,
+#             turquoise, violet, wheat, white, whitesmoke,
+#             yellow, yellowgreen
diff --git a/dryvr_plus_plus/plotter/plotter_README.md b/dryvr_plus_plus/plotter/plotter_README.md
new file mode 100644
index 0000000000000000000000000000000000000000..59eb5a686e46282ec4408ff589e67ba8545a1f24
--- /dev/null
+++ b/dryvr_plus_plus/plotter/plotter_README.md
@@ -0,0 +1,105 @@
+# Plotly-based Plotter Development Notes
+
+Now the latest version is placed in plotter2D_new.py. All functions in the plotter2D.py still work. Every function are in developemnt and might change.
+
+## Current work & Todo
+- **Animation with trails** supported in test_simu_anime() and will be tested further.
+- **Modified accelerating mode** modified, will be tested
+- **new quadrotor agent** next
+- **different color for segments of trace** done.
+
+## Functions
+Belows are the functions currently used. Some of the functions in the file are deprecated.
+
+#### general_reachtube_anime(root, map, fig, x_dim, y_dim, map_type)
+
+The genernal plotter for reachtube animation. It draws the all traces of reachtubes and the map. Animation is implemented as rectangles.
+
+**parameters:**
+- **root:** the root node of the trace, should be the return value of Scenario.verify().
+- **map:** the map of the scenario, templates are in dryvr_plus_plus.example.example_map.simple_map2.py.
+- **fig:** the object of the figure, its type should be plotly.graph_objects.Figure().
+- **x_dim:** the dimension of x coordinate in the trace list of every time step. The Default value is 1.
+- **y_dim:** the dimension of y coordinate in the trace list of every time step. The Default value is 2.
+- **map_type** the way to draw the map. It should be 'lines' or 'fill'. For the 'lines' mode, map is only drawn by margins of lanes. For the 'fill' mode, the lanes will be filled semitransparent colors.
+
+#### draw_reachtube_tree_v2(root, agent_id, fig, x_dim, y_dim, color_id, map_type)
+
+The genernal static plotter for reachtube tree. It draws the all traces of reachtubes and the map.
+The original version is implemented with rectangle and very inefficient.
+
+**parameters:**
+- **root:** the root node of the trace, should be the return value of Scenario.verify().
+- **agent_id:** the id of target agent. It should a string, which is the id/name of agent.
+- **fig:** the object of the figure, its type should be plotly.graph_objects.Figure().
+- **x_dim:** the dimension of x coordinate in the trace list of every time step. The Default value is 1.
+- **y_dim:** the dimension of y coordinate in the trace list of every time step. The Default value is 2.
+- **color_id:** a int indicating the color. Now 10 kinds of colors are supported. If it is None, the colors of segments will be auto-assigned. The default value is None.
+- **map_type** the way to draw the map. It should be 'lines' or 'fill' or 'detailed'. For the 'lines' mode, map is only drawn by margins of lanes. For the 'fill' mode, the lanes will be filled semitransparent colors. For the 'detailed' mode, it will vistualize the speed limit information (if exists) by fill the lanes with different colors. The Default value is 'lines'.
+
+#### draw_map(map, color, fig, fill_type)
+
+The genernal static plotter for map. It is called in many functions drawing traces, so it is often unnecessary to call it separately.
+
+**parameters:**
+- **map:** the map of the scenario, templates are in dryvr_plus_plus.example.example_map.simple_map2.py.
+- **color** the color of the margin of the lanes, should be a string like 'black' or in rgb/rgba format, like 'rgb(0,0,0)' or 'rgba(0,0,0,1)'. The default value is 'rgba(0,0,0,1)' which is non-transparent black.
+- **fig:** the object of the figure, its type should be plotly.graph_objects.Figure().
+- **fill_type** the way to draw the map. It should be 'lines' or 'fill'. For the 'lines' mode, map is only drawn by margins of lanes. For the 'fill' mode, the lanes will be filled semitransparent colors.
+
+#### plotly_map(map, color, fig):
+
+The old ungenernal static plotter for map which support visualization of speed limit of lanes which is a pending feature.
+
+**parameters:**
+- **map:** the map of the scenario, templates are in dryvr_plus_plus.example.example_map.simple_map2.py. It doesn't handle the map with wrong format. Only SimpleMap3_v2() class is supported now.
+- **color** the color of the margin of the lanes, should be a string like 'black' or in rgb/rgba format, like 'rgb(0,0,0)' or 'rgba(0,0,0,1)'. The default value is 'rgba(0,0,0,1)' which is non-transparent black.
+- **fig:** the object of the figure, its type should be plotly.graph_objects.Figure().
+
+#### draw_simulation_tree(root, map, fig, x_dim, y_dim, map_type, scale_type):
+
+The genernal static plotter for simulation trees. It draws the traces of agents and map.
+
+**parameters:**
+- **root:** the root node of the trace, should be the return value of Scenario.simulate().
+- **map:** the map of the scenario, templates are in dryvr_plus_plus.example.example_map.simple_map2.py.
+- **fig:** the object of the figure, its type should be plotly.graph_objects.Figure().
+- **x_dim:** the dimension of x coordinate in the trace list of every time step. The Default value is 1.
+- **y_dim:** the dimension of y coordinate in the trace list of every time step. The Default value is 2.
+- **map_type** the way to draw the map. It should be 'lines' or 'fill'. For the 'lines' mode, map is only drawn by margins of lanes. For the 'fill' mode, the lanes will be filled semitransparent colors.
+- **scale_type** the way to draw the map. It should be 'trace' or 'map'. For the 'trace' mode, the plot will be scaled to show all traces. For the 'map' mode, the plot will be scaled to show the whole map. The Default value is 'trace'.
+
+#### draw_simulation_tree_single(root, agent_id, x_dim, y_dim, color_id, fig):
+
+The genernal static plotter for simulation tree. It draws the  traces of one specific agent.
+
+**parameters:**
+- **root:** the root node of the trace, should be the return value of Scenario.simulate().
+- **agent_id:** the id of target agent. It should a string, which is the id/name of agent.
+- **fig:** the object of the figure, its type should be plotly.graph_objects.Figure().
+- **x_dim:** the dimension of x coordinate in the trace list of every time step. The Default value is 1.
+- **y_dim:** the dimension of y coordinate in the trace list of every time step. The Default value is 2.
+- **color_id:** a int indicating the color. Now 10 kinds of colors are supported. If it is None, the colors of segments will be auto-assigned. The default value is None.
+
+#### draw_simulation_anime(root, map, fig)
+
+The old ungenernal plotter for simulation animation. It draws the all traces and the map. Animation is implemented as points and arrows. 
+
+**parameters:**
+- **root:** the root node of the trace, should be the return value of Scenario.simulate().
+- **map:** the map of the scenario, templates are in dryvr_plus_plus.example.example_map.simple_map2.py. It doesn't handle the map with wrong format. Only SimpleMap3_v2() class is supported now.
+- **fig:** the object of the figure, its type should be plotly.graph_objects.Figure().
+
+#### general_simu_anime(root, map, fig, x_dim, y_dim, map_type, scale_type):
+
+The genernal plotter for simulation animation. It draws the all traces and the map. Animation is implemented as points. Since arrow is hard to generalize.
+
+**parameters:**
+- **root:** the root node of the trace, should be the return value of Scenario.simulate().
+- **map:** the map of the scenario, templates are in dryvr_plus_plus.example.example_map.simple_map2.py.
+- **fig:** the object of the figure, its type should be plotly.graph_objects.Figure().
+- **x_dim:** the dimension of x coordinate in the trace list of every time step. The Default value is 1.
+- **y_dim:** the dimension of y coordinate in the trace list of every time step. The Default value is 2.
+- **map_type** the way to draw the map. It should be 'lines' or 'fill'. For the 'lines' mode, map is only drawn by margins of lanes. For the 'fill' mode, the lanes will be filled semitransparent colors. The Default value is 'lines'.
+- **scale_type** the way to draw the map. It should be 'trace' or 'map'. For the 'trace' mode, the plot will be scaled to show all traces. For the 'map' mode, the plot will be scaled to show the whole map. The Default value is 'trace'.
+
diff --git a/dryvr_plus_plus/scene_verifier/analysis/analysis_tree_node.py b/dryvr_plus_plus/scene_verifier/analysis/analysis_tree_node.py
index f307e0313fc87a60dda4d109daa1ce1c8e7a9364..381de10d0e4095168d1112299cd889050e72423e 100644
--- a/dryvr_plus_plus/scene_verifier/analysis/analysis_tree_node.py
+++ b/dryvr_plus_plus/scene_verifier/analysis/analysis_tree_node.py
@@ -1,13 +1,14 @@
 from typing import List, Dict
 
+
 class AnalysisTreeNode:
     """AnalysisTreeNode class
     A AnalysisTreeNode stores the continous execution of the system without transition happening"""
     trace: Dict
     """The trace for each agent. 
     The key of the dict is the agent id and the value of the dict is simulated traces for each agent"""
-    init: Dict 
-    
+    init: Dict
+
     def __init__(
         self,
         trace={},
@@ -15,14 +16,14 @@ class AnalysisTreeNode:
         mode={},
         agent={},
         child=[],
-        start_time = 0,
-        ndigits = 10,
-        type = 'simtrace'
+        start_time=0,
+        ndigits=10,
+        type='simtrace'
     ):
-        self.trace:Dict = trace
+        self.trace: Dict = trace
         self.init: Dict[str, List[float]] = init
         self.mode: Dict[str, List[str]] = mode
-        self.agent:Dict = agent
-        self.child:List[AnalysisTreeNode] = child
-        self.start_time:float = round(start_time,ndigits)
-        self.type:str = type
+        self.agent: Dict = agent
+        self.child: List[AnalysisTreeNode] = child
+        self.start_time: float = round(start_time, ndigits)
+        self.type: str = type
diff --git a/dryvr_plus_plus/scene_verifier/analysis/simulator.py b/dryvr_plus_plus/scene_verifier/analysis/simulator.py
index 853122e70a76c806346dda383fa2fa90e47a8d98..c7567cef5aaa6562b0cb16b252400f48df702491 100644
--- a/dryvr_plus_plus/scene_verifier/analysis/simulator.py
+++ b/dryvr_plus_plus/scene_verifier/analysis/simulator.py
@@ -7,6 +7,7 @@ import numpy as np
 from dryvr_plus_plus.scene_verifier.agents.base_agent import BaseAgent
 from dryvr_plus_plus.scene_verifier.analysis.analysis_tree_node import AnalysisTreeNode
 
+
 class Simulator:
     def __init__(self):
         self.simulation_tree_root = None
@@ -19,7 +20,7 @@ class Simulator:
             mode={},
             agent={},
             child=[],
-            start_time = 0,
+            start_time=0,
         )
         for i, agent in enumerate(agent_list):
             root.init[agent.id] = init_list[i]
@@ -41,6 +42,7 @@ class Simulator:
             for agent_id in node.agent:
                 if agent_id not in node.trace:
                     # Simulate the trace starting from initial condition
+                    # [time, x, y, theta, v]
                     mode = node.mode[agent_id]
                     init = node.init[agent_id]
                     trace = node.agent[agent_id].TC_simulate(mode, init, remain_time, time_step, lane_map)
@@ -53,8 +55,10 @@ class Simulator:
             if not transitions:
                 continue
 
+
             # truncate the computed trajectories from idx and store the content after truncate
             truncated_trace = {}
+            print("idx", idx)
             for agent_idx in node.agent:
                 truncated_trace[agent_idx] = node.trace[agent_idx][transition_idx:]
                 node.trace[agent_idx] = node.trace[agent_idx][:transition_idx+1]
@@ -63,9 +67,10 @@ class Simulator:
             transition_list = list(transitions.values())
             all_transition_combinations = itertools.product(*transition_list)
 
-            # For each possible transition, construct the new node. 
+            # For each possible transition, construct the new node.
             # Obtain the new initial condition for agent having transition
             # copy the traces that are not under transition
+
             for transition_combination in all_transition_combinations:
                 next_node_mode = copy.deepcopy(node.mode) 
                 next_node_agent = node.agent 
@@ -82,15 +87,15 @@ class Simulator:
                 for agent_idx in next_node_agent:
                     if agent_idx not in next_node_init:
                         next_node_trace[agent_idx] = truncated_trace[agent_idx]
-                    
+
                 tmp = AnalysisTreeNode(
-                    trace = next_node_trace,
-                    init = next_node_init,
-                    mode = next_node_mode,
-                    agent = next_node_agent,
-                    child = [],
-                    start_time = next_node_start_time,
-                    type = 'simtrace'
+                    trace=next_node_trace,
+                    init=next_node_init,
+                    mode=next_node_mode,
+                    agent=next_node_agent,
+                    child=[],
+                    start_time=next_node_start_time,
+                    type='simtrace'
                 )
                 node.child.append(tmp)
                 simulation_queue.append(tmp)
diff --git a/dryvr_plus_plus/scene_verifier/automaton/guard.py b/dryvr_plus_plus/scene_verifier/automaton/guard.py
index cfb409a8cbc619789070b321a9796c622a9ee3e4..33908c1caeb5f008576572e9876f1e940b5bf2ac 100644
--- a/dryvr_plus_plus/scene_verifier/automaton/guard.py
+++ b/dryvr_plus_plus/scene_verifier/automaton/guard.py
@@ -13,13 +13,16 @@ import numpy as np
 from dryvr_plus_plus.scene_verifier.map.lane_map import LaneMap
 from dryvr_plus_plus.scene_verifier.map.lane_segment import AbstractLane
 from dryvr_plus_plus.scene_verifier.utils.utils import *
+
+
 class LogicTreeNode:
-    def __init__(self, data, child = [], val = None, mode_guard = None):
-        self.data = data 
+    def __init__(self, data, child=[], val=None, mode_guard=None):
+        self.data = data
         self.child = child
         self.val = val
         self.mode_guard = mode_guard
 
+
 class NodeSubstituter(ast.NodeTransformer):
     def __init__(self, old_node, new_node):
         super().__init__()
@@ -88,6 +91,7 @@ class GuardExpressionAst:
         self.cont_variables = {}
         self.varDict = {}
 
+
     def _build_guard(self, guard_str, agent):
         """
         Build solver for current guard based on guard string
@@ -107,21 +111,25 @@ class GuardExpressionAst:
         # Thus we need to replace "==" to something else
         sympy_guard_str = guard_str.replace("==", ">=")
         for vars in self.cont_variables:
-            sympy_guard_str = sympy_guard_str.replace(vars, self.cont_variables[vars])
+            sympy_guard_str = sympy_guard_str.replace(
+                vars, self.cont_variables[vars])
 
-        symbols = list(sympy.sympify(sympy_guard_str, evaluate=False).free_symbols)
+        symbols = list(sympy.sympify(
+            sympy_guard_str, evaluate=False).free_symbols)
         symbols = [str(s) for s in symbols]
         tmp = list(self.cont_variables.values())
         symbols_map = {}
         for s in symbols:
             if s in tmp:
-                key = list(self.cont_variables.keys())[list(self.cont_variables.values()).index(s)]
+                key = list(self.cont_variables.keys())[
+                    list(self.cont_variables.values()).index(s)]
                 symbols_map[s] = key
 
         for vars in reversed(self.cont_variables):
             guard_str = guard_str.replace(vars, self.cont_variables[vars])
         guard_str = self._handleReplace(guard_str)
-        cur_solver.add(eval(guard_str))  # TODO use an object instead of `eval` a string
+        # TODO use an object instead of `eval` a string
+        cur_solver.add(eval(guard_str))
         return cur_solver, symbols_map
 
     def _handleReplace(self, input_str):
@@ -132,7 +140,7 @@ class GuardExpressionAst:
                 And(y<=0,t>=0.2,v>=-0.1)
             output: 
                 And(self.varDic["y"]<=0,self.varDic["t"]>=0.2,self.varDic["v"]>=-0.1)
-        
+
         Args:
             input_str (str): original string need to be replaced
             keys (list): list of variable strings
@@ -151,7 +159,8 @@ class GuardExpressionAst:
             for i in range(len(input_str)):
                 if input_str[i:].startswith(key):
                     idxes.append((i, i + len(key)))
-                    input_str = input_str[:i] + "@" * len(key) + input_str[i + len(key):]
+                    input_str = input_str[:i] + "@" * \
+                        len(key) + input_str[i + len(key):]
 
         idxes = sorted(idxes)
 
@@ -167,31 +176,35 @@ class GuardExpressionAst:
         is_contained = False
 
         for cont_vars in continuous_variable_dict:
-            self.cont_variables[cont_vars] = cont_vars.replace('.','_')
-            self.varDict[cont_vars.replace('.','_')] = Real(cont_vars.replace('.','_'))
+            self.cont_variables[cont_vars] = cont_vars.replace('.', '_')
+            self.varDict[cont_vars.replace('.', '_')] = Real(
+                cont_vars.replace('.', '_'))
 
-        z3_string = self.generate_z3_expression() 
+        z3_string = self.generate_z3_expression()
         if isinstance(z3_string, bool):
             if z3_string:
-                return True, True 
+                return True, True
             else:
                 return False, False
 
         cur_solver, symbols = self._build_guard(z3_string, agent)
         cur_solver.push()
         for symbol in symbols:
-            cur_solver.add(self.varDict[symbol] >= continuous_variable_dict[symbols[symbol]][0])
-            cur_solver.add(self.varDict[symbol] <= continuous_variable_dict[symbols[symbol]][1])
+            cur_solver.add(
+                self.varDict[symbol] >= continuous_variable_dict[symbols[symbol]][0])
+            cur_solver.add(
+                self.varDict[symbol] <= continuous_variable_dict[symbols[symbol]][1])
         if cur_solver.check() == sat:
             # The reachtube hits the guard
             cur_solver.pop()
             res = True
-            
             tmp_solver = Solver()
             tmp_solver.add(Not(cur_solver.assertions()[0]))
             for symbol in symbols:
-                tmp_solver.add(self.varDict[symbol] >= continuous_variable_dict[symbols[symbol]][0])
-                tmp_solver.add(self.varDict[symbol] <= continuous_variable_dict[symbols[symbol]][1])
+                tmp_solver.add(
+                    self.varDict[symbol] >= continuous_variable_dict[symbols[symbol]][0])
+                tmp_solver.add(
+                    self.varDict[symbol] <= continuous_variable_dict[symbols[symbol]][1])
             if tmp_solver.check() == unsat:
                 print("Full intersect, break")
                 is_contained = True
@@ -234,7 +247,7 @@ class GuardExpressionAst:
 
         If without evaluating the continuous variables the result is True, then
         the guard condition will automatically be satisfied
-        
+
         If without evaluating the continuous variables the result is False, then
         the guard condition will not be satisfied
 
@@ -246,11 +259,11 @@ class GuardExpressionAst:
             # For each value in the boolop, check results
             if isinstance(node.op, ast.And):
                 z3_str = []
-                for i,val in enumerate(node.values):
+                for i, val in enumerate(node.values):
                     tmp = self._generate_z3_expression_node(val)
                     if isinstance(tmp, bool):
                         if tmp:
-                            continue 
+                            continue
                         else:
                             return False
                     z3_str.append(tmp)
@@ -276,7 +289,7 @@ class GuardExpressionAst:
                 return z3_str
             # If string, construct string
             # If bool, check result and discard/evaluate result according to operator
-            pass 
+            pass
         elif isinstance(node, ast.Constant):
             # If is bool, return boolean result
             if isinstance(node.value, bool):
@@ -287,7 +300,7 @@ class GuardExpressionAst:
                 expr = expr.strip('\n')
                 return expr
         elif isinstance(node, ast.UnaryOp):
-            # If is UnaryOp, 
+            # If is UnaryOp,
             value = self._generate_z3_expression_node(node.operand)
             if isinstance(node.op, ast.USub):
                 return -value
@@ -302,7 +315,7 @@ class GuardExpressionAst:
             expr = expr.strip('\n')
             return expr
 
-    def evaluate_guard_hybrid(self, agent, discrete_variable_dict, continuous_variable_dict, lane_map:LaneMap):
+    def evaluate_guard_hybrid(self, agent, discrete_variable_dict, continuous_variable_dict, lane_map: LaneMap):
         """
         Handle guard atomics that contains both continuous and hybrid variables
         Especially, we want to handle function calls that need both continuous and 
@@ -312,40 +325,47 @@ class GuardExpressionAst:
         By doing this, all calls that need both continuous and discrete variables as input will now become only continuous
         variables. We can then handle these using what we already have for the continous variables
         """
-        res = True 
+        res = True
         for i, node in enumerate(self.ast_list):
-            tmp, self.ast_list[i] = self._evaluate_guard_hybrid(node, agent, discrete_variable_dict, continuous_variable_dict, lane_map)
-            res = res and tmp 
+            tmp, self.ast_list[i] = self._evaluate_guard_hybrid(
+                node, agent, discrete_variable_dict, continuous_variable_dict, lane_map)
+            res = res and tmp
         return res
 
-    def _evaluate_guard_hybrid(self, root, agent, disc_var_dict, cont_var_dict, lane_map:LaneMap):
-        if isinstance(root, ast.Compare): 
+    def _evaluate_guard_hybrid(self, root, agent, disc_var_dict, cont_var_dict, lane_map: LaneMap):
+        if isinstance(root, ast.Compare):
             expr = astunparse.unparse(root)
-            left, root.left = self._evaluate_guard_hybrid(root.left, agent, disc_var_dict, cont_var_dict, lane_map)
-            right, root.comparators[0] = self._evaluate_guard_hybrid(root.comparators[0], agent, disc_var_dict, cont_var_dict, lane_map)
+            left, root.left = self._evaluate_guard_hybrid(
+                root.left, agent, disc_var_dict, cont_var_dict, lane_map)
+            right, root.comparators[0] = self._evaluate_guard_hybrid(
+                root.comparators[0], agent, disc_var_dict, cont_var_dict, lane_map)
             return True, root
         elif isinstance(root, ast.BoolOp):
             if isinstance(root.op, ast.And):
                 res = True
                 for i, val in enumerate(root.values):
-                    tmp, root.values[i] = self._evaluate_guard_hybrid(val, agent, disc_var_dict, cont_var_dict, lane_map)
-                    res = res and tmp 
+                    tmp, root.values[i] = self._evaluate_guard_hybrid(
+                        val, agent, disc_var_dict, cont_var_dict, lane_map)
+                    res = res and tmp
                     if not res:
-                        break 
-                return res, root 
+                        break
+                return res, root
             elif isinstance(root.op, ast.Or):
                 res = False
                 for val in root.values:
-                    tmp,val = self._evaluate_guard_hybrid(val, agent, disc_var_dict, cont_var_dict, lane_map)
+                    tmp, val = self._evaluate_guard_hybrid(
+                        val, agent, disc_var_dict, cont_var_dict, lane_map)
                     res = res or tmp
                 return res, root  
         elif isinstance(root, ast.BinOp):
-            left, root.left = self._evaluate_guard_hybrid(root.left, agent, disc_var_dict, cont_var_dict, lane_map)
-            right, root.right = self._evaluate_guard_hybrid(root.right, agent, disc_var_dict, cont_var_dict, lane_map)
+            left, root.left = self._evaluate_guard_hybrid(
+                root.left, agent, disc_var_dict, cont_var_dict, lane_map)
+            right, root.right = self._evaluate_guard_hybrid(
+                root.right, agent, disc_var_dict, cont_var_dict, lane_map)
             return True, root
         elif isinstance(root, ast.Call):
             if isinstance(root.func, ast.Attribute):
-                func = root.func        
+                func = root.func
                 if func.value.id == 'lane_map':
                     if func.attr == 'get_lateral_distance':
                         # Get function arguments
@@ -373,16 +393,21 @@ class GuardExpressionAst:
                         vehicle_pos = (arg1_lower, arg1_upper)
 
                         # Get corresponding lane segments with respect to the set of vehicle pos
-                        lane_seg1 = lane_map.get_lane_segment(vehicle_lane, arg1_lower)
-                        lane_seg2 = lane_map.get_lane_segment(vehicle_lane, arg1_upper)
+                        lane_seg1 = lane_map.get_lane_segment(
+                            vehicle_lane, arg1_lower)
+                        lane_seg2 = lane_map.get_lane_segment(
+                            vehicle_lane, arg1_upper)
 
                         # Compute the set of possible lateral values with respect to all possible segments
-                        lateral_set1 = self._handle_lateral_set(lane_seg1, np.array(vehicle_pos))
-                        lateral_set2 = self._handle_lateral_set(lane_seg2, np.array(vehicle_pos))
+                        lateral_set1 = self._handle_lateral_set(
+                            lane_seg1, np.array(vehicle_pos))
+                        lateral_set2 = self._handle_lateral_set(
+                            lane_seg2, np.array(vehicle_pos))
 
                         # Use the union of two sets as the set of possible lateral positions
-                        lateral_set = [min(lateral_set1[0], lateral_set2[0]), max(lateral_set1[1], lateral_set2[1])]
-                        
+                        lateral_set = [min(lateral_set1[0], lateral_set2[0]), max(
+                            lateral_set1[1], lateral_set2[1])]
+
                         # Construct the tmp variable
                         tmp_var_name = f'tmp_variable{len(cont_var_dict)+1}'
                         # Add the tmp variable to the cont var dict
@@ -417,16 +442,21 @@ class GuardExpressionAst:
                         vehicle_pos = (arg1_lower, arg1_upper)
 
                         # Get corresponding lane segments with respect to the set of vehicle pos
-                        lane_seg1 = lane_map.get_lane_segment(vehicle_lane, arg1_lower)
-                        lane_seg2 = lane_map.get_lane_segment(vehicle_lane, arg1_upper)
+                        lane_seg1 = lane_map.get_lane_segment(
+                            vehicle_lane, arg1_lower)
+                        lane_seg2 = lane_map.get_lane_segment(
+                            vehicle_lane, arg1_upper)
 
                         # Compute the set of possible longitudinal values with respect to all possible segments
-                        longitudinal_set1 = self._handle_longitudinal_set(lane_seg1, np.array(vehicle_pos))
-                        longitudinal_set2 = self._handle_longitudinal_set(lane_seg2, np.array(vehicle_pos))
+                        longitudinal_set1 = self._handle_longitudinal_set(
+                            lane_seg1, np.array(vehicle_pos))
+                        longitudinal_set2 = self._handle_longitudinal_set(
+                            lane_seg2, np.array(vehicle_pos))
 
                         # Use the union of two sets as the set of possible longitudinal positions
-                        longitudinal_set = [min(longitudinal_set1[0], longitudinal_set2[0]), max(longitudinal_set1[1], longitudinal_set2[1])]
-                        
+                        longitudinal_set = [min(longitudinal_set1[0], longitudinal_set2[0]), max(
+                            longitudinal_set1[1], longitudinal_set2[1])]
+
                         # Construct the tmp variable
                         tmp_var_name = f'tmp_variable{len(cont_var_dict)+1}'
                         # Add the tmp variable to the cont var dict
@@ -435,13 +465,16 @@ class GuardExpressionAst:
                         root = ast.parse(tmp_var_name).body[0].value
                         return True, root
                     else:
-                        raise ValueError(f'Node type {func} from {astunparse.unparse(func)} is not supported')
+                        raise ValueError(
+                            f'Node type {func} from {astunparse.unparse(func)} is not supported')
                 else:
-                    raise ValueError(f'Node type {func} from {astunparse.unparse(func)} is not supported')
+                    raise ValueError(
+                        f'Node type {func} from {astunparse.unparse(func)} is not supported')
             else:
-                raise ValueError(f'Node type {root.func} from {astunparse.unparse(root.func)} is not supported')   
+                raise ValueError(
+                    f'Node type {root.func} from {astunparse.unparse(root.func)} is not supported')
         elif isinstance(root, ast.Attribute):
-            return True, root 
+            return True, root
         elif isinstance(root, ast.Constant):
             return root.value, root 
         elif isinstance(root, ast.Name):
@@ -455,32 +488,36 @@ class GuardExpressionAst:
                     root.operand = ast.parse('False').body[0].value
                     return True, ast.parse('True').body[0].value
             else:
-                raise ValueError(f'Node type {root} from {astunparse.unparse(root)} is not supported')
-            return True, root 
+                raise ValueError(
+                    f'Node type {root} from {astunparse.unparse(root)} is not supported')
+            return True, root
         else:
-            raise ValueError(f'Node type {root} from {astunparse.unparse(root)} is not supported')
+            raise ValueError(
+                f'Node type {root} from {astunparse.unparse(root)} is not supported')
 
     def _handle_longitudinal_set(self, lane_seg: AbstractLane, position: np.ndarray) -> List[float]:
         if lane_seg.type == "Straight":
             # Delta lower
-            delta0 = position[0,:] - lane_seg.start
+            delta0 = position[0, :] - lane_seg.start
             # Delta upper
-            delta1 = position[1,:] - lane_seg.start
+            delta1 = position[1, :] - lane_seg.start
 
             longitudinal_low = min(delta0[0]*lane_seg.direction[0], delta1[0]*lane_seg.direction[0]) + \
-                min(delta0[1]*lane_seg.direction[1], delta1[1]*lane_seg.direction[1])
+                min(delta0[1]*lane_seg.direction[1],
+                    delta1[1]*lane_seg.direction[1])
             longitudinal_high = max(delta0[0]*lane_seg.direction[0], delta1[0]*lane_seg.direction[0]) + \
-                max(delta0[1]*lane_seg.direction[1], delta1[1]*lane_seg.direction[1])
+                max(delta0[1]*lane_seg.direction[1],
+                    delta1[1]*lane_seg.direction[1])
             longitudinal_low += lane_seg.longitudinal_start
             longitudinal_high += lane_seg.longitudinal_start
 
             assert longitudinal_high >= longitudinal_low
-            return longitudinal_low, longitudinal_high            
+            return longitudinal_low, longitudinal_high
         elif lane_seg.type == "Circular":
             # Delta lower
-            delta0 = position[0,:] - lane_seg.center
+            delta0 = position[0, :] - lane_seg.center
             # Delta upper
-            delta1 = position[1,:] - lane_seg.center
+            delta1 = position[1, :] - lane_seg.center
 
             phi0 = np.min([
                 np.arctan2(delta0[1], delta0[0]),
@@ -495,50 +532,66 @@ class GuardExpressionAst:
                 np.arctan2(delta1[1], delta1[0]),
             ])
 
-            phi0 = lane_seg.start_phase + wrap_to_pi(phi0 - lane_seg.start_phase)
-            phi1 = lane_seg.start_phase + wrap_to_pi(phi1 - lane_seg.start_phase)
+            phi0 = lane_seg.start_phase + \
+                wrap_to_pi(phi0 - lane_seg.start_phase)
+            phi1 = lane_seg.start_phase + \
+                wrap_to_pi(phi1 - lane_seg.start_phase)
             longitudinal_low = min(
-                lane_seg.direction * (phi0 - lane_seg.start_phase)*lane_seg.radius,
-                lane_seg.direction * (phi1 - lane_seg.start_phase)*lane_seg.radius
+                lane_seg.direction *
+                (phi0 - lane_seg.start_phase)*lane_seg.radius,
+                lane_seg.direction *
+                (phi1 - lane_seg.start_phase)*lane_seg.radius
             ) + lane_seg.longitudinal_start
             longitudinal_high = max(
-                lane_seg.direction * (phi0 - lane_seg.start_phase)*lane_seg.radius,
-                lane_seg.direction * (phi1 - lane_seg.start_phase)*lane_seg.radius
+                lane_seg.direction *
+                (phi0 - lane_seg.start_phase)*lane_seg.radius,
+                lane_seg.direction *
+                (phi1 - lane_seg.start_phase)*lane_seg.radius
             ) + lane_seg.longitudinal_start
 
             assert longitudinal_high >= longitudinal_low
             return longitudinal_low, longitudinal_high
         else:
-            raise ValueError(f'Lane segment with type {lane_seg.type} is not supported')
+            raise ValueError(
+                f'Lane segment with type {lane_seg.type} is not supported')
 
     def _handle_lateral_set(self, lane_seg: AbstractLane, position: np.ndarray) -> List[float]:
         if lane_seg.type == "Straight":
             # Delta lower
-            delta0 = position[0,:] - lane_seg.start
+            delta0 = position[0, :] - lane_seg.start
             # Delta upper
-            delta1 = position[1,:] - lane_seg.start
+            delta1 = position[1, :] - lane_seg.start
 
             lateral_low = min(delta0[0]*lane_seg.direction_lateral[0], delta1[0]*lane_seg.direction_lateral[0]) + \
-                min(delta0[1]*lane_seg.direction_lateral[1], delta1[1]*lane_seg.direction_lateral[1])
+                min(delta0[1]*lane_seg.direction_lateral[1],
+                    delta1[1]*lane_seg.direction_lateral[1])
             lateral_high = max(delta0[0]*lane_seg.direction_lateral[0], delta1[0]*lane_seg.direction_lateral[0]) + \
-                max(delta0[1]*lane_seg.direction_lateral[1], delta1[1]*lane_seg.direction_lateral[1])
+                max(delta0[1]*lane_seg.direction_lateral[1],
+                    delta1[1]*lane_seg.direction_lateral[1])
             assert lateral_high >= lateral_low
             return lateral_low, lateral_high
         elif lane_seg.type == "Circular":
-            dx = np.max([position[0,0]-lane_seg.center[0],0,lane_seg.center[0]-position[1,0]])
-            dy = np.max([position[0,1]-lane_seg.center[1],0,lane_seg.center[1]-position[1,1]])
+            dx = np.max([position[0, 0]-lane_seg.center[0],
+                        0, lane_seg.center[0]-position[1, 0]])
+            dy = np.max([position[0, 1]-lane_seg.center[1],
+                        0, lane_seg.center[1]-position[1, 1]])
             r_low = np.linalg.norm([dx, dy])
 
-            dx = np.max([np.abs(position[0,0]-lane_seg.center[0]),np.abs(position[1,0]-lane_seg.center[0])])
-            dy = np.max([np.abs(position[0,1]-lane_seg.center[1]),np.abs(position[1,1]-lane_seg.center[1])])
+            dx = np.max([np.abs(position[0, 0]-lane_seg.center[0]),
+                        np.abs(position[1, 0]-lane_seg.center[0])])
+            dy = np.max([np.abs(position[0, 1]-lane_seg.center[1]),
+                        np.abs(position[1, 1]-lane_seg.center[1])])
             r_high = np.linalg.norm([dx, dy])
-            lateral_low = min(lane_seg.direction*(lane_seg.radius - r_high),lane_seg.direction*(lane_seg.radius - r_low))
-            lateral_high = max(lane_seg.direction*(lane_seg.radius - r_high),lane_seg.direction*(lane_seg.radius - r_low))
+            lateral_low = min(lane_seg.direction*(lane_seg.radius - r_high),
+                              lane_seg.direction*(lane_seg.radius - r_low))
+            lateral_high = max(lane_seg.direction*(lane_seg.radius - r_high),
+                               lane_seg.direction*(lane_seg.radius - r_low))
             # print(lateral_low, lateral_high)
             assert lateral_high >= lateral_low
             return lateral_low, lateral_high
         else:
-            raise ValueError(f'Lane segment with type {lane_seg.type} is not supported')
+            raise ValueError(
+                f'Lane segment with type {lane_seg.type} is not supported')
 
     def evaluate_guard_disc(self, agent, discrete_variable_dict, continuous_variable_dict, lane_map):
         """
@@ -546,50 +599,55 @@ class GuardExpressionAst:
         """
         res = True
         for i, node in enumerate(self.ast_list):
-            tmp, self.ast_list[i] = self._evaluate_guard_disc(node, agent, discrete_variable_dict, continuous_variable_dict, lane_map)
-            res = res and tmp 
+            tmp, self.ast_list[i] = self._evaluate_guard_disc(
+                node, agent, discrete_variable_dict, continuous_variable_dict, lane_map)
+            res = res and tmp
         return res
-            
+
     def _evaluate_guard_disc(self, root, agent, disc_var_dict, cont_var_dict, lane_map):
         """
         Recursively called function to evaluate guard with only discrete variables
         The function will evaluate all guards with discrete variables and replace the nodes with discrete guards by
         boolean constants
-        
+
         :params:
         :return: The return value will be a tuple. The first element in the tuple will either be a boolean value or a the evaluated value of of an expression involving guard
         The second element in the tuple will be the updated ast node 
         """
         if isinstance(root, ast.Compare):
             expr = astunparse.unparse(root)
-            left, root.left = self._evaluate_guard_disc(root.left, agent, disc_var_dict, cont_var_dict, lane_map)
-            right, root.comparators[0] = self._evaluate_guard_disc(root.comparators[0], agent, disc_var_dict, cont_var_dict, lane_map)
+            left, root.left = self._evaluate_guard_disc(
+                root.left, agent, disc_var_dict, cont_var_dict, lane_map)
+            right, root.comparators[0] = self._evaluate_guard_disc(
+                root.comparators[0], agent, disc_var_dict, cont_var_dict, lane_map)
             if isinstance(left, bool) or isinstance(right, bool):
                 return True, root
             if isinstance(root.ops[0], ast.GtE):
-                res = left>=right
+                res = left >= right
             elif isinstance(root.ops[0], ast.Gt):
-                res = left>right 
+                res = left > right
             elif isinstance(root.ops[0], ast.Lt):
-                res = left<right
+                res = left < right
             elif isinstance(root.ops[0], ast.LtE):
-                res = left<=right
+                res = left <= right
             elif isinstance(root.ops[0], ast.Eq):
-                res = left == right 
+                res = left == right
             elif isinstance(root.ops[0], ast.NotEq):
-                res = left != right 
+                res = left != right
             else:
-                raise ValueError(f'Node type {root} from {astunparse.unparse(root)} is not supported')
+                raise ValueError(
+                    f'Node type {root} from {astunparse.unparse(root)} is not supported')
             if res:
                 root = ast.parse('True').body[0].value
             else:
-                root = ast.parse('False').body[0].value    
+                root = ast.parse('False').body[0].value
             return res, root
         elif isinstance(root, ast.BoolOp):
             if isinstance(root.op, ast.And):
                 res = True
-                for i,val in enumerate(root.values):
-                    tmp,root.values[i] = self._evaluate_guard_disc(val, agent, disc_var_dict, cont_var_dict, lane_map)
+                for i, val in enumerate(root.values):
+                    tmp, root.values[i] = self._evaluate_guard_disc(
+                        val, agent, disc_var_dict, cont_var_dict, lane_map)
                     res = res and tmp
                     if not res:
                         break
@@ -597,13 +655,16 @@ class GuardExpressionAst:
             elif isinstance(root.op, ast.Or):
                 res = False
                 for val in root.values:
-                    tmp,val = self._evaluate_guard_disc(val, agent, disc_var_dict, cont_var_dict, lane_map)
+                    tmp, val = self._evaluate_guard_disc(
+                        val, agent, disc_var_dict, cont_var_dict, lane_map)
                     res = res or tmp
                 return res, root     
         elif isinstance(root, ast.BinOp):
             # Check left and right in the binop and replace all attributes involving discrete variables
-            left, root.left = self._evaluate_guard_disc(root.left, agent, disc_var_dict, cont_var_dict, lane_map)
-            right, root.right = self._evaluate_guard_disc(root.right, agent, disc_var_dict, cont_var_dict, lane_map)
+            left, root.left = self._evaluate_guard_disc(
+                root.left, agent, disc_var_dict, cont_var_dict, lane_map)
+            right, root.right = self._evaluate_guard_disc(
+                root.right, agent, disc_var_dict, cont_var_dict, lane_map)
             return True, root
         elif isinstance(root, ast.Call):
             expr = astunparse.unparse(root)
@@ -623,7 +684,7 @@ class GuardExpressionAst:
                     if res:
                         root = ast.parse('True').body[0].value
                     else:
-                        root = ast.parse('False').body[0].value    
+                        root = ast.parse('False').body[0].value
                 else:
                     for mode_name in agent.controller.modes:
                         if res in agent.controller.modes[mode_name]:
@@ -658,7 +719,8 @@ class GuardExpressionAst:
                     root.operand = ast.parse('False').body[0].value
                     return True, ast.parse('True').body[0].value
             else:
-                raise ValueError(f'Node type {root} from {astunparse.unparse(root)} is not supported')
+                raise ValueError(
+                    f'Node type {root} from {astunparse.unparse(root)} is not supported')
             return True, root
         elif isinstance(root, ast.Name):
             expr = root.id
@@ -672,7 +734,8 @@ class GuardExpressionAst:
             else:
                 return True, root 
         else:
-            raise ValueError(f'Node type {root} from {astunparse.unparse(root)} is not supported')
+            raise ValueError(
+                f'Node type {root} from {astunparse.unparse(root)} is not supported')
 
     def evaluate_guard(self, agent, continuous_variable_dict, discrete_variable_dict, lane_map):
         res = True
@@ -685,28 +748,32 @@ class GuardExpressionAst:
 
     def _evaluate_guard(self, root, agent, cnts_var_dict, disc_var_dict, lane_map):
         if isinstance(root, ast.Compare):
-            left = self._evaluate_guard(root.left, agent, cnts_var_dict, disc_var_dict, lane_map)
-            right = self._evaluate_guard(root.comparators[0], agent, cnts_var_dict, disc_var_dict, lane_map)
+            left = self._evaluate_guard(
+                root.left, agent, cnts_var_dict, disc_var_dict, lane_map)
+            right = self._evaluate_guard(
+                root.comparators[0], agent, cnts_var_dict, disc_var_dict, lane_map)
             if isinstance(root.ops[0], ast.GtE):
-                return left>=right
+                return left >= right
             elif isinstance(root.ops[0], ast.Gt):
-                return left>right 
+                return left > right
             elif isinstance(root.ops[0], ast.Lt):
-                return left<right
+                return left < right
             elif isinstance(root.ops[0], ast.LtE):
-                return left<=right
+                return left <= right
             elif isinstance(root.ops[0], ast.Eq):
-                return left == right 
+                return left == right
             elif isinstance(root.ops[0], ast.NotEq):
-                return left != right 
+                return left != right
             else:
-                raise ValueError(f'Node type {root} from {astunparse.unparse(root)} is not supported')
+                raise ValueError(
+                    f'Node type {root} from {astunparse.unparse(root)} is not supported')
 
         elif isinstance(root, ast.BoolOp):
             if isinstance(root.op, ast.And):
                 res = True
                 for val in root.values:
-                    tmp = self._evaluate_guard(val, agent, cnts_var_dict, disc_var_dict, lane_map)
+                    tmp = self._evaluate_guard(
+                        val, agent, cnts_var_dict, disc_var_dict, lane_map)
                     res = res and tmp
                     if not res:
                         break
@@ -714,20 +781,24 @@ class GuardExpressionAst:
             elif isinstance(root.op, ast.Or):
                 res = False
                 for val in root.values:
-                    tmp = self._evaluate_guard(val, agent, cnts_var_dict, disc_var_dict, lane_map)
+                    tmp = self._evaluate_guard(
+                        val, agent, cnts_var_dict, disc_var_dict, lane_map)
                     res = res or tmp
                     if res:
                         break
                 return res
         elif isinstance(root, ast.BinOp):
-            left = self._evaluate_guard(root.left, agent, cnts_var_dict, disc_var_dict, lane_map)
-            right = self._evaluate_guard(root.right, agent, cnts_var_dict, disc_var_dict, lane_map)
+            left = self._evaluate_guard(
+                root.left, agent, cnts_var_dict, disc_var_dict, lane_map)
+            right = self._evaluate_guard(
+                root.right, agent, cnts_var_dict, disc_var_dict, lane_map)
             if isinstance(root.op, ast.Sub):
                 return left - right
             elif isinstance(root.op, ast.Add):
                 return left + right
             else:
-                raise ValueError(f'Node type {root} from {astunparse.unparse(root)} is not supported')
+                raise ValueError(
+                    f'Node type {root} from {astunparse.unparse(root)} is not supported')
         elif isinstance(root, ast.Call):
             expr = astunparse.unparse(root)
             # Check if the root is a function
@@ -743,7 +814,7 @@ class GuardExpressionAst:
                 for arg in disc_var_dict:
                     expr = expr.replace(arg, f'"{disc_var_dict[arg]}"')
                 for arg in cnts_var_dict:
-                    expr = expr.replace(arg, str(cnts_var_dict[arg]))    
+                    expr = expr.replace(arg, str(cnts_var_dict[arg]))
                 res = eval(expr)
                 for mode_name in agent.controller.modes:
                     if res in agent.controller.modes[mode_name]:
@@ -768,7 +839,8 @@ class GuardExpressionAst:
         elif isinstance(root, ast.Constant):
             return root.value
         elif isinstance(root, ast.UnaryOp):
-            val = self._evaluate_guard(root.operand, agent, cnts_var_dict, disc_var_dict, lane_map)
+            val = self._evaluate_guard(
+                root.operand, agent, cnts_var_dict, disc_var_dict, lane_map)
             if isinstance(root.op, ast.USub):
                 return -val
             if isinstance(root.op, ast.Not):
@@ -790,7 +862,9 @@ class GuardExpressionAst:
             else:
                 raise ValueError(f"{variable} doesn't exist in either continuous varibales or discrete variables") 
         else:
-            raise ValueError(f'Node type {root} from {astunparse.unparse(root)} is not supported')
+            raise ValueError(
+                f'Node type {root} from {astunparse.unparse(root)} is not supported')
+
 
     def parse_any_all(self, cont_var_dict: Dict[str, float], disc_var_dict: Dict[str, float], len_dict: Dict[str, int]) -> None: 
         for i in range(len(self.ast_list)):
@@ -1042,9 +1116,9 @@ class GuardExpressionAst:
         return root
 
 if __name__ == "__main__":
-    with open('tmp.pickle','rb') as f:
+    with open('tmp.pickle', 'rb') as f:
         guard_list = pickle.load(f)
     tmp = GuardExpressionAst(guard_list)
     # tmp.evaluate_guard()
     # tmp.construct_tree_from_str('(other_x-ego_x<20) and other_x-ego_x>10 and other_vehicle_lane==ego_vehicle_lane')
-    print("stop")
\ No newline at end of file
+    print("stop")
diff --git a/dryvr_plus_plus/scene_verifier/code_parser/pythonparser.py b/dryvr_plus_plus/scene_verifier/code_parser/pythonparser.py
index 780c4bd2a2a06f5a730d6774de9ac82b3080e91f..d276e52d5d6264797e8e5bf0dac4fc94d1842121 100644
--- a/dryvr_plus_plus/scene_verifier/code_parser/pythonparser.py
+++ b/dryvr_plus_plus/scene_verifier/code_parser/pythonparser.py
@@ -1,21 +1,23 @@
-#parse python file
+# parse python file
 
-#REQUIRES PYTHON 3.8!
+# REQUIRES PYTHON 3.8!
 from cgitb import reset
-#import clang.cindex
+# import clang.cindex
 import typing
 import json
 import sys
 from typing import List, Tuple
-import re 
+import re
 import itertools
 import ast
-
+import time
 from treelib import Node, Tree
 
 '''
 Edge class: utility class to hold the source, dest, guards, and resets for a transition
 '''
+
+
 class Edge:
     def __init__(self, source, dest, guards, resets):
         self.source = source
@@ -23,18 +25,21 @@ class Edge:
         self.guards = guards
         self.resets = resets
 
+
 '''
-Statement super class. Holds the code and mode information for a statement. 
+Statement super class. Holds the code and mode information for a statement.
 If there is no mode information, mode and modeType are None.
 '''
+
+
 class Statement:
-    def __init__(self, code, mode, modeType, func = None, args = None):
+    def __init__(self, code, mode, modeType, func=None, args=None):
         self.code = code
         self.modeType = modeType
         self.mode = mode
-        self.func = func 
+        self.func = func
         self.args = args
-    
+
     def print(self):
         print(self.code)
 
@@ -42,26 +47,29 @@ class Statement:
 '''
 Guard class. Subclass of statement.
 '''
+
+
 class Guard(Statement):
     def __init__(self, code, mode, modeType, inp_ast, func=None, args=None):
         super().__init__(code, mode, modeType, func, args)
         self.ast = inp_ast
 
-
     '''
-    Returns true if a guard is checking that we are in a mode. 
+    Returns true if a guard is checking that we are in a mode.
     '''
+
     def isModeCheck(self):
         return self.modeType != None
 
     '''
     Helper function to parse a node that contains a guard. Parses out the code and mode.
-    Returns a Guard. 
+    Returns a Guard.
     TODO: needs to handle more complex guards.
     '''
     def parseGuard(node, code):
-        #assume guard is a strict comparision (modeType == mode)
-        if isinstance(node.test, ast.Compare):
+        # assume guard is a strict comparision (modeType == mode)
+        # keyi mark: may support more comparision
+        if isinstance(node.test, ast.Compare):  # ==
             if isinstance(node.test.comparators[0], ast.Attribute):
                 if ("Mode" in str(node.test.comparators[0].value.id)):
                     modeType = str(node.test.comparators[0].value.id)
@@ -69,9 +77,9 @@ class Guard(Statement):
                     return Guard(ast.get_source_segment(code, node.test), mode, modeType, node.test)
             else:
                 return Guard(ast.get_source_segment(code, node.test), None, None, node.test)
-        elif isinstance(node.test, ast.BoolOp):
+        elif isinstance(node.test, ast.BoolOp):  # or and
             return Guard(ast.get_source_segment(code, node.test), None, None, node.test)
-        elif isinstance(node.test, ast.Call):
+        elif isinstance(node.test, ast.Call):  # function not used
             source_segment = ast.get_source_segment(code, node.test)
             # func = node.test.func.id 
             # args = []
@@ -79,29 +87,33 @@ class Guard(Statement):
             #     args.append(arg.value.id + '.' + arg.attr)
             return Guard(source_segment, None, None, node.test)
 
+
 '''
 Reset class. Subclass of statement.
 '''
+
+
 class Reset(Statement):
     def __init__(self, code, mode, modeType, inp_ast):
         super().__init__(code, mode, modeType)
         self.ast = inp_ast
 
     '''
-    Returns true if a reset is updating our mode. 
+    Returns true if a reset is updating our mode.
     '''
+
     def isModeUpdate(self):
         return self.modeType != None
 
     '''
     Helper function to parse a node that contains a reset. Parses out the code and mode.
-    Returns a reset. 
+    Returns a reset.
     '''
     def parseReset(node, code):
-        #assume reset is modeType = newMode
+        # assume reset is modeType = newMode
         if isinstance(node.value, ast.Attribute):
-            #print("resets " + str(node.value.value.id))
-            #print("resets " + str(node.value.attr))
+            # print("resets " + str(node.value.value.id))
+            # print("resets " + str(node.value.attr))
             if ("Mode" in str(node.value.value.id)):
                 modeType = str(node.value.value.id)
                 mode = str(node.value.attr)
@@ -112,6 +124,8 @@ class Reset(Statement):
 '''
 Util class to handle building transitions given a path.
 '''
+
+
 class TransitionUtil:
     '''
     Takes in a list of reset objects. Returns a string in the format json expected.
@@ -119,13 +133,13 @@ class TransitionUtil:
     def resetString(resets):
         outstr = ""
         for reset in resets:
-            outstr+= reset.code + ";"
+            outstr += reset.code + ";"
         outstr = outstr.strip(";")
         return outstr
 
     '''
     Takes in guard code. Returns a string in the format json expected.
-    TODO: needs to handle more complex guards. 
+    TODO: needs to handle more complex guards.
     '''
     def parseGuardCode(code):
         parts = code.split("and")
@@ -137,33 +151,33 @@ class TransitionUtil:
         return out
 
     '''
-    Helper function for parseGuardCode. 
+    Helper function for parseGuardCode.
     '''
     def guardString(guards):
         output = ""
         first = True
-        for guard in guards: 
-            #print(type(condition))
+        for guard in guards:
+            # print(type(condition))
             if first:
-                output+= TransitionUtil.parseGuardCode(guard.code)
+                output += TransitionUtil.parseGuardCode(guard.code)
             else:
-                output = "And(" + TransitionUtil.parseGuardCode(guard.code) + ",(" + output + "))"
+                output = "And(" + TransitionUtil.parseGuardCode(guard.code) + \
+                    ",(" + output + "))"
             first = False
         return output
 
-
     '''
     Helper function to get the index of the vertex for a set of modes.
     Modes is a list of all modes in the current vertex.
-    Vertices is the list of vertices. 
+    Vertices is the list of vertices.
     TODO: needs to be tested on more complex examples to see if ordering stays and we can use index function
     '''
     def getIndex(modes, vertices):
         return vertices.index(tuple(modes))
 
     '''
-    Function that creates transitions given a path. 
-    Will create multiple transitions if not all modeTypes are checked/set in the path. 
+    Function that creates transitions given a path.
+    Will create multiple transitions if not all modeTypes are checked/set in the path.
     Returns a list of edges that correspond to the path.
     '''
     def createTransition(path, vertices, modes):
@@ -189,8 +203,8 @@ class TransitionUtil:
         for modeType in modes.keys():
             foundMode = False
             for condition in modeChecks:
-                #print(condition.modeType)
-                #print(modeType)
+                # print(condition.modeType)
+                # print(modeType)
                 if condition.modeType == modeType:
                     sourceModes.append(condition.mode)
                     foundMode = True
@@ -232,6 +246,7 @@ class TransitionUtil:
                 edges.append(Edge(sourceindex, destindex, guards, resets))
         return edges
 
+
 class ControllerAst():
     '''
     Initalizing function for a controllerAst object.
@@ -239,13 +254,14 @@ class ControllerAst():
     Statement tree is a tree of nodes that contain a list in their data. The list contains a single guard or a list of resets.
     Variables (inputs to the controller) are collected.
     Modes are collected from all enums that have the word "mode" in them.
-    Vertices are generated by taking the products of mode types. 
+    Vertices are generated by taking the products of mode types.
     '''
-    def __init__(self, code = None, file_name = None):
+
+    def __init__(self, code=None, file_name=None):
         assert code is not None or file_name is not None
         if file_name is not None:
-            with open(file_name,'r') as f:
-                code = f.read()        
+            with open(file_name, 'r') as f:
+                code = f.read()
 
         self.code = code
         self.tree = ast.parse(code)
@@ -261,62 +277,65 @@ class ControllerAst():
     '''
     Function to populate paths variable with all paths of the controller.
     '''
+
     def getAllPaths(self):
         self.paths = self.getNextModes([], True)
         return self.paths
-    
+
     '''
-    getNextModes takes in a list of current modes. It should include all modes. 
+    getNextModes takes in a list of current modes. It should include all modes.
     getNextModes returns a list of paths that can be followed when in the given mode.
     A path is a list of statements, all guards and resets along the path. They are in the order they are encountered in the code.
     TODO: should we not force all modes be listed? Or rerun for each unknown/don't care node? Or add them all to the list
     '''
-    def getNextModes(self, currentModes: List[str], getAllPaths= False) -> List[str]:
-        #walk the tree and capture all paths that have modes that are listed. Path is a list of statements
+
+    def getNextModes(self, currentModes: List[str], getAllPaths=False) -> List[str]:
+        # walk the tree and capture all paths that have modes that are listed. Path is a list of statements
         paths = []
         rootid = self.statementtree.root
         currnode = self.statementtree.get_node(rootid)
         paths = self.walkstatements(currnode, currentModes, getAllPaths)
-        
-        return paths 
+
+        return paths
 
     '''
     Helper function to walk the statement tree from parentnode and find paths that are allowed in the currentMode.
-    Returns a list of paths. 
+    Returns a list of paths.
     '''
-    def walkstatements(self, parentnode, currentModes, getAllPaths):
-        nextsPaths = []
 
+    def walkstatements(self, parentnode: Node, currentModes, getAllPaths):
+        nextsPaths = []
+        # print("walkstatements", parentnode.tag)
         for node in self.statementtree.children(parentnode.identifier):
             statement = node.data
-            
+            # if parentnode.tag == "ego.vehicle_mode == VehicleMode.Brake":
+            #     print(statement[0])
             if isinstance(statement[0], Guard) and statement[0].isModeCheck():
                 if getAllPaths or statement[0].mode in currentModes:
-                    #print(statement.mode)
-                    newPaths = self.walkstatements(node, currentModes, getAllPaths)
+                    newPaths = self.walkstatements(
+                        node, currentModes, getAllPaths)
                     for path in newPaths:
                         newpath = statement.copy()
                         newpath.extend(path)
                         nextsPaths.append(newpath)
                     if len(nextsPaths) == 0:
                         nextsPaths.append(statement)
-    
+
             else:
-                newPaths =self.walkstatements(node, currentModes, getAllPaths)
+                newPaths = self.walkstatements(node, currentModes, getAllPaths)
                 for path in newPaths:
                     newpath = statement.copy()
                     newpath.extend(path)
                     nextsPaths.append(newpath)
                 if len(nextsPaths) == 0:
-                            nextsPaths.append(statement)
- 
+                    nextsPaths.append(statement)
         return nextsPaths
 
-
     '''
     Function to create a json of the full graph.
-    Requires that paths class variables has been set. 
+    Requires that paths class variables has been set.
     '''
+
     def create_json(self, input_file_name, output_file_name):
         if not self.paths:
             print("Cannot call create_json without calling getAllPaths")
@@ -335,22 +354,23 @@ class ControllerAst():
         resets = []
 
         for path in self.paths:
-            transitions = TransitionUtil.createTransition(path, self.vertices, self.modes)
+            transitions = TransitionUtil.createTransition(
+                path, self.vertices, self.modes)
             for edge in transitions:
                 edges.append([edge.source, edge.dest])
                 guards.append(TransitionUtil.guardString(edge.guards))
                 resets.append(TransitionUtil.resetString(edge.resets))
-    
+
         output_dict['vertex'] = self.vertexStrings
-        #print(vertices)
+        # print(vertices)
         output_dict['variables'] = self.variables
         # #add edge, transition(guards) and resets
         output_dict['edge'] = edges
-        #print(len(edges))
+        # print(len(edges))
         output_dict['guards'] = guards
-        #print(len(guards))
+        # print(len(guards))
         output_dict['resets'] = resets
-        #print(len(resets))
+        # print(len(resets))
 
         output_json = json.dumps(output_dict, indent=4)
         outfile = open(output_file_name, "w")
@@ -359,11 +379,12 @@ class ControllerAst():
 
         print("wrote json to " + output_file_name)
 
-    #inital tree walk, parse into a tree of resets/modes
+    # inital tree walk, parse into a tree of resets/modes
     '''
-    Function called by init function. Walks python ast and parses to a statement tree. 
+    Function called by init function. Walks python ast and parses to a statement tree.
     Returns a statement tree (nodes contain a list of either a single guard or muliple resets), the variables, and a mode dictionary
     '''
+
     def initalwalktree(self, code, tree):
         vars = []
         discrete_vars = []
@@ -372,7 +393,8 @@ class ControllerAst():
         state_object_dict = {}
         vars_dict = {}
         statementtree = Tree()
-        for node in ast.walk(tree): #don't think we want to walk the whole thing because lose ordering/depth
+        # don't think we want to walk the whole thing because lose ordering/depth
+        for node in ast.walk(tree):
             # Get all the modes
             if isinstance(node, ast.ClassDef):
                 if "Mode" in node.name:
@@ -383,7 +405,8 @@ class ControllerAst():
                     mode_dict[modeType] = modes
             if isinstance(node, ast.ClassDef):
                 if "State" in node.name:
-                    state_object_dict[node.name] = {"cont":[],"disc":[], "type": []}
+                    state_object_dict[node.name] = {
+                        "cont": [], "disc": [], "type": []}
                     for item in node.body:
                         if isinstance(item, ast.FunctionDef):
                             if "init" in item.name:
@@ -393,13 +416,15 @@ class ControllerAst():
                                             state_object_dict[node.name]['cont'].append(arg.arg)
                                             # vars.append(arg.arg)
                                         else:
-                                            state_object_dict[node.name]['disc'].append(arg.arg)
+                                            state_object_dict[node.name]['disc'].append(
+                                                arg.arg)
                                             # discrete_vars.append(arg.arg)
             if isinstance(node, ast.FunctionDef):
                 if node.name == 'controller':
-                    #print(node.body)
-                    statementtree = self.parsenodelist(code, node.body, False, Tree(), None)
-                    #print(type(node.args))
+                    # print(node.body)
+                    statementtree = self.parsenodelist(
+                        code, node.body, False, Tree(), None)
+                    # print(type(node.args))
                     args = node.args.args
                     for arg in args:
                         if arg.annotation is None:
@@ -421,25 +446,26 @@ class ControllerAst():
                                 arg_annotation = arg.annotation.slice.id
                             
                         arg_name = arg.arg
-                        vars_dict[arg_name] = {'cont':[], 'disc':[], "type": []}
+                        vars_dict[arg_name] = {
+                            'cont': [], 'disc': [], "type": []}
                         for var in state_object_dict[arg_annotation]['cont']:
                             vars.append(arg_name+"."+var)
                             vars_dict[arg_name]['cont'].append(var)
                         for var in state_object_dict[arg_annotation]['disc']:
                             discrete_vars.append(arg_name+"."+var)
                             vars_dict[arg_name]['disc'].append(var)
-                        
         return [statementtree, vars, mode_dict, discrete_vars, state_object_dict, vars_dict]
 
 
     '''
     Helper function for initalwalktree which parses the statements in the controller function into a statement tree
     '''
-    def parsenodelist(self, code, nodes, addResets, tree, parent):
-        childrens_guards=[]
-        childrens_resets=[]
+
+    def parsenodelist(self, code, nodes, addResets, tree: Tree, parent):
+        childrens_guards = []
+        childrens_resets = []
         recoutput = []
-        #tree.show()
+        tree.show()
         if parent == None:
             s = Statement("root", None, None)
             tree.create_node("root")
@@ -448,56 +474,72 @@ class ControllerAst():
         for childnode in nodes:
             if isinstance(childnode, ast.Assign) and addResets:
                 reset = Reset.parseReset(childnode, code)
-                #print("found reset: " + reset.code)
+                # print("found reset: " + reset.code)
                 childrens_resets.append(reset)
             if isinstance(childnode, ast.If):
                 guard = Guard.parseGuard(childnode, code)
                 childrens_guards.append(guard)
-                #print("found if statement: " + guard.code)
+                # print("found if statement: " + guard.code)
                 newTree = Tree()
-                newTree.create_node(tag= guard.code, data = [guard])
-                #print(self.nodect)
-                tempresults = self.parsenodelist(code, childnode.body, True, newTree, newTree.root)
-                #for result in tempresults:
+                newTree.create_node(tag=guard.code, data=[guard])
+                # print(self.nodect)
+                tempresults = self.parsenodelist(
+                    code, childnode.body, True, newTree, newTree.root)
+                # for result in tempresults:
                 recoutput.append(tempresults)
 
-        
-        #pathsafterme = [] 
+                # if (len(childnode.orelse) > 0):
+                #     childnode = childnode.orelse[0]
+                #     if isinstance(childnode, ast.If):
+                #         guard = Guard.parseGuard(childnode, code)
+                #         childrens_guards.append(guard)
+                #         print("found if statement: " + guard.code)
+                #         newTree = Tree()
+                #         newTree.create_node(tag=guard.code, data=[guard])
+                #         # print(self.nodect)
+                #         tempresults = self.parsenodelist(
+                #             code, childnode.body, True, newTree, newTree.root)
+                #         # for result in tempresults:
+                #         recoutput.append(tempresults)
+
+        # pathsafterme = []
         if len(childrens_resets) > 0:
-            #print("adding node:" + str(self.nodect) + "with parent:" + str(parent))
-            tree.create_node(tag = childrens_resets[0].code, data = childrens_resets, parent= parent)
+            # print("adding node:" + str(self.nodect) + "with parent:" + str(parent))
+            tree.create_node(
+                tag=childrens_resets[0].code, data=childrens_resets, parent=parent)
         for subtree in recoutput:
-            #print("adding subtree:" + " to parent:" + str(parent))
+            # print("adding subtree:" + " to parent:" + str(parent))
             tree.paste(parent, subtree)
-                
-        
+        # tree.show()
         return tree
 
+
 class EmptyAst(ControllerAst):
     def __init__(self):
         super().__init__(code="True", file_name=None)
         self.discrete_variables = []
         self.modes = {
-            'NullMode':['Null'],
-            'LaneMode':['Normal']
+            'NullMode': ['Null'],
+            'LaneMode': ['Normal']
         }
         self.paths = None
         self.state_object_dict = {
-            'State':{
-                'cont':[],
-                'disc':[],
-                'type':[]
+            'State': {
+                'cont': [],
+                'disc': [],
+                'type': []
             }
         }
         self.variables = []
         self.vars_dict = []
         self.vertexStrings = ['Null,Normal']
-        self.vertices=[('Null','Normal')]
+        self.vertices = [('Null', 'Normal')]
         self.statementtree.create_node("root")
 
+
 ##main code###
 if __name__ == "__main__":
-    #if len(sys.argv) < 4:
+    # if len(sys.argv) < 4:
     #    print("incorrect usage. call createGraph.py program inputfile outputfilename")
     #    quit()
 
@@ -511,18 +553,18 @@ if __name__ == "__main__":
     output_dict = {
     }
 
-    #read in the controler code
-    f = open(input_code_name,'r')
+    # read in the controler code
+    f = open(input_code_name, 'r')
     code = f.read()
 
-    #parse the controller code into our controller ast objct
+    # parse the controller code into our controller ast objct
     controller_obj = ControllerAst(code)
 
     print(controller_obj.variables)
 
-    #demonstrate you can check getNextModes after only initalizing
+    # demonstrate you can check getNextModes after only initalizing
     paths = controller_obj.getNextModes("NormalA;Normal3")
-   
+
     print("Results")
     for path in paths:
         for item in path:
@@ -530,16 +572,11 @@ if __name__ == "__main__":
         print()
     print("Done")
 
-    #attempt to write to json, fail because we haven't populated paths yet
+    # attempt to write to json, fail because we haven't populated paths yet
     controller_obj.create_json(input_file_name, output_file_name)
 
-    #call function that gets all paths
+    # call function that gets all paths
     controller_obj.getAllPaths()
 
-    #write json with all paths
+    # write json with all paths
     controller_obj.create_json(input_file_name, output_file_name)
-
-
-    
-    
-
diff --git a/dryvr_plus_plus/scene_verifier/dryvr/common/utils.py b/dryvr_plus_plus/scene_verifier/dryvr/common/utils.py
index 4bc395d78f90465297283dc6f5cfc5b0c67bf249..7b79b60897d57c8a19d5aa989f8d0f941688af1e 100644
--- a/dryvr_plus_plus/scene_verifier/dryvr/common/utils.py
+++ b/dryvr_plus_plus/scene_verifier/dryvr/common/utils.py
@@ -28,7 +28,7 @@ def importSimFunction(path):
     And the simulation function must be named TC_Simulate
     The function should looks like following:
         TC_Simulate(Mode, initialCondition, time_bound)
-    
+
     Args:
         path (str): Simulator directory.
 
@@ -56,7 +56,7 @@ def randomPoint(lower, upper):
     """
     Pick a random point between lower and upper bound
     This function supports both int or list
-    
+
     Args:
         lower (list or int or float): lower bound.
         upper (list or int or float): upper bound.
@@ -80,7 +80,7 @@ def calcDelta(lower, upper):
     """
     Calculate the delta value between the lower and upper bound
     The function only supports list since we assue initial set is always list
-    
+
     Args:
         lower (list): lowerbound.
         upper (list): upperbound.
@@ -101,7 +101,7 @@ def calcCenterPoint(lower, upper):
     """
     Calculate the center point between the lower and upper bound
     The function only supports list since we assue initial set is always list
-    
+
     Args:
         lower (list): lowerbound.
         upper (list): upperbound.
@@ -122,7 +122,7 @@ def buildModeStr(g, vertex):
     """
     Build a unique string to represent a mode
     This should be something like "modeName,modeNum"
-    
+
     Args:
         g (igraph.Graph): Graph object.
         vertex (int or str): vertex number.
@@ -142,7 +142,7 @@ def handleReplace(input_str, keys):
             And(y<=0,t>=0.2,v>=-0.1)
         output: 
             And(self.varDic["y"]<=0,self.varDic["t"]>=0.2,self.varDic["v"]>=-0.1)
-    
+
     Args:
         input_str (str): original string need to be replaced
         keys (list): list of variable strings
@@ -160,7 +160,8 @@ def handleReplace(input_str, keys):
         for i in range(len(input_str)):
             if input_str[i:].startswith(key):
                 idxes.append((i, i + len(key)))
-                input_str = input_str[:i] + "@" * len(key) + input_str[i + len(key):]
+                input_str = input_str[:i] + "@" * \
+                    len(key) + input_str[i + len(key):]
 
     idxes = sorted(idxes)
 
@@ -180,7 +181,7 @@ def neg(orig):
             And(y<=0,t>=0.2,v>=-0.1)
         output: 
             Not(And(y<=0,t>=0.2,v>=-0.1))
-    
+
     Args:
         orig (str): original string need to be neg
 
@@ -194,7 +195,7 @@ def neg(orig):
 def trimTraces(traces):
     """
     trim all traces to the same length
-    
+
     Args:
         traces (list): list of traces generated by simulator
     Returns:
@@ -207,7 +208,7 @@ def trimTraces(traces):
     for trace in traces:
         trace_lengths.append(len(trace))
     trace_len = min(trace_lengths)
-
+    print(trace_lengths)
     for trace in traces:
         ret_traces.append(trace[:trace_len])
 
@@ -217,16 +218,18 @@ def trimTraces(traces):
 def checkVerificationInput(data):
     """
     Check verification input to make sure it is valid
-    
+
     Args:
         data (obj): json data object
     Returns:
         None
 
     """
-    assert len(data['variables']) == len(data['initialSet'][0]), "Initial set dimension mismatch"
+    assert len(data['variables']) == len(
+        data['initialSet'][0]), "Initial set dimension mismatch"
 
-    assert len(data['variables']) == len(data['initialSet'][1]), "Initial set dimension mismatch"
+    assert len(data['variables']) == len(
+        data['initialSet'][1]), "Initial set dimension mismatch"
 
     assert len(data['edge']) == len(data["guards"]), "guard number mismatch"
 
@@ -242,16 +245,18 @@ def checkVerificationInput(data):
 def checkSynthesisInput(data):
     """
     Check Synthesis input to make sure it is valid
-    
+
     Args:
         data (obj): json data object
     Returns:
         None
 
     """
-    assert len(data['variables']) == len(data['initialSet'][0]), "Initial set dimension mismatch"
+    assert len(data['variables']) == len(
+        data['initialSet'][0]), "Initial set dimension mismatch"
 
-    assert len(data['variables']) == len(data['initialSet'][1]), "Initial set dimension mismatch"
+    assert len(data['variables']) == len(
+        data['initialSet'][1]), "Initial set dimension mismatch"
 
     for i in range(len(data['variables'])):
         assert data['initialSet'][0][i] <= data['initialSet'][1][i], "initial set lowerbound is larger than upperbound"
@@ -279,7 +284,7 @@ def isIpynb():
 def overloadConfig(configObj, userConfig):
     """
     Overload example config to config module
-    
+
     Args:
         configObj (module): config module
         userConfig (dict): example specified config
diff --git a/dryvr_plus_plus/scene_verifier/dryvr/core/guard.py b/dryvr_plus_plus/scene_verifier/dryvr/core/guard.py
index 79f1b70677881254fe6d2676f41485c3fdaef17a..b294aae441508b220604e1baa6149099c8daa7f0 100644
--- a/dryvr_plus_plus/scene_verifier/dryvr/core/guard.py
+++ b/dryvr_plus_plus/scene_verifier/dryvr/core/guard.py
@@ -47,14 +47,17 @@ class Guard:
         # Thus we need to replace "==" to something else
         sympy_guard_str = guard_str.replace("==", ">=")
 
-        symbols = list(sympy.sympify(sympy_guard_str, evaluate=False).free_symbols)
+        symbols = list(sympy.sympify(
+            sympy_guard_str, evaluate=False).free_symbols)
         symbols = [str(s) for s in symbols]
-        symbols_idx = {s: self.variables.index(s) + 1 for s in symbols if s in self.variables}
+        symbols_idx = {s: self.variables.index(
+            s) + 1 for s in symbols if s in self.variables}
         if 't' in symbols:
             symbols_idx['t'] = 0
 
         guard_str = handleReplace(guard_str, list(self.varDic.keys()))
-        cur_solver.add(eval(guard_str))  # TODO use an object instead of `eval` a string
+        # TODO use an object instead of `eval` a string
+        cur_solver.add(eval(guard_str))
         return cur_solver, symbols_idx
 
     def guard_sim_trace(self, trace, guard_str):
@@ -97,8 +100,10 @@ class Guard:
             upper = trace[idx + 1]
             cur_solver.push()
             for symbol in symbols:
-                cur_solver.add(self.varDic[symbol] >= min(lower[symbols[symbol]], upper[symbols[symbol]]))
-                cur_solver.add(self.varDic[symbol] <= max(lower[symbols[symbol]], upper[symbols[symbol]]))
+                cur_solver.add(self.varDic[symbol] >= min(
+                    lower[symbols[symbol]], upper[symbols[symbol]]))
+                cur_solver.add(self.varDic[symbol] <= max(
+                    lower[symbols[symbol]], upper[symbols[symbol]]))
             if cur_solver.check() == sat:
                 cur_solver.pop()
                 guard_set[idx] = upper
@@ -163,8 +168,10 @@ class Guard:
             lower_bound = tube[i]
             upper_bound = tube[i + 1]
             for symbol in symbols:
-                cur_solver.add(self.varDic[symbol] >= lower_bound[symbols[symbol]])
-                cur_solver.add(self.varDic[symbol] <= upper_bound[symbols[symbol]])
+                cur_solver.add(self.varDic[symbol] >=
+                               lower_bound[symbols[symbol]])
+                cur_solver.add(self.varDic[symbol] <=
+                               upper_bound[symbols[symbol]])
             if cur_solver.check() == sat:
                 # The reachtube hits the guard
                 cur_solver.pop()
@@ -174,8 +181,10 @@ class Guard:
                 tmp_solver = Solver()
                 tmp_solver.add(Not(cur_solver.assertions()[0]))
                 for symbol in symbols:
-                    tmp_solver.add(self.varDic[symbol] >= lower_bound[symbols[symbol]])
-                    tmp_solver.add(self.varDic[symbol] <= upper_bound[symbols[symbol]])
+                    tmp_solver.add(
+                        self.varDic[symbol] >= lower_bound[symbols[symbol]])
+                    tmp_solver.add(
+                        self.varDic[symbol] <= upper_bound[symbols[symbol]])
                 if tmp_solver.check() == unsat:
                     print("Full intersect, break")
                     break
@@ -188,8 +197,10 @@ class Guard:
                     init_upper = guard_set_upper[0][1:]
                     for j in range(1, len(guard_set_lower)):
                         for k in range(1, len(guard_set_lower[0])):
-                            init_lower[k - 1] = min(init_lower[k - 1], guard_set_lower[j][k])
-                            init_upper[k - 1] = max(init_upper[k - 1], guard_set_upper[j][k])
+                            init_lower[k - 1] = min(init_lower[k - 1],
+                                                    guard_set_lower[j][k])
+                            init_upper[k - 1] = max(init_upper[k - 1],
+                                                    guard_set_upper[j][k])
                     # Return next initial Set, the result tube, and the true transit time
                     return [init_lower, init_upper], tube[:i], guard_set_lower[0][0]
 
@@ -204,7 +215,7 @@ class Guard:
                     init_lower[k - 1] = min(init_lower[k - 1], guard_set_lower[j][k])
                     init_upper[k - 1] = max(init_upper[k - 1], guard_set_upper[j][k])
             # init_upper[0] = init_lower[0]
-            
+
             # Return next initial Set, the result tube, and the true transit time
             return [init_lower, init_upper], tube[:i], guard_set_lower[0][0]
 
diff --git a/dryvr_plus_plus/scene_verifier/dryvr/core/uniformchecker.py b/dryvr_plus_plus/scene_verifier/dryvr/core/uniformchecker.py
index dfa49b0a6d2d55b3c539029518cfbb8f1459a128..cce5c1a73199dda227b3ade0b8bbade9794c0d0b 100644
--- a/dryvr_plus_plus/scene_verifier/dryvr/core/uniformchecker.py
+++ b/dryvr_plus_plus/scene_verifier/dryvr/core/uniformchecker.py
@@ -49,7 +49,8 @@ class UniformChecker:
             cond = cond.replace("==", ">=")
             symbols = list(sympy.sympify(cond).free_symbols)
             symbols = [str(s) for s in symbols]
-            symbols_idx = {s: self._variables.index(s) + 1 for s in symbols if s in self._variables}
+            symbols_idx = {s: self._variables.index(
+                s) + 1 for s in symbols if s in self._variables}
             if 't' in symbols:
                 symbols_idx['t'] = 0
             self._solver_dict[mode].append(symbols_idx)  # TODO Fix typing
diff --git a/dryvr_plus_plus/scene_verifier/dryvr/discrepancy/Global_Disc.py b/dryvr_plus_plus/scene_verifier/dryvr/discrepancy/Global_Disc.py
index d84b5176e7c3ae8aca917af0ea533e5f6aab137d..068ff23a5e68b584716358234a8ef6d7b405b8b7 100644
--- a/dryvr_plus_plus/scene_verifier/dryvr/discrepancy/Global_Disc.py
+++ b/dryvr_plus_plus/scene_verifier/dryvr/discrepancy/Global_Disc.py
@@ -20,13 +20,15 @@ def get_reachtube_segment(training_traces: np.ndarray, initial_radii: np.ndarray
     trace_initial_time = center_trace[0, 0]
     x_points: np.ndarray = center_trace[:, 0] - trace_initial_time
     assert np.all(training_traces[0, :, 0] == training_traces[1:, :, 0])
-    y_points: np.ndarray = all_sensitivities_calc(training_traces, initial_radii)
+    y_points: np.ndarray = all_sensitivities_calc(
+        training_traces, initial_radii)
     points: np.ndarray = np.zeros((ndims - 1, trace_len, 2))
     points[np.where(initial_radii != 0), 0, 1] = 1.0
     points[:, :, 0] = np.reshape(x_points, (1, x_points.shape[0]))
     points[:, 1:, 1] = y_points
     normalizing_initial_set_radii: np.ndarray = initial_radii.copy()
-    normalizing_initial_set_radii[np.where(normalizing_initial_set_radii == 0)] = 1.0
+    normalizing_initial_set_radii[np.where(
+        normalizing_initial_set_radii == 0)] = 1.0
     df: np.ndarray = np.zeros((trace_len, ndims))
     if method == 'PW':
         df[:, 1:] = np.transpose(
@@ -38,7 +40,8 @@ def get_reachtube_segment(training_traces: np.ndarray, initial_radii: np.ndarray
         points[:, :, 1] = np.maximum(points[:, :, 1], _EPSILON)
         points[:, :, 1] = np.log(points[:, :, 1])
         for dim_ind in range(1, ndims):
-            new_min = min(np.min(points[dim_ind - 1, 1:, 1]) + _TRUE_MIN_CONST, -10)
+            new_min = min(
+                np.min(points[dim_ind - 1, 1:, 1]) + _TRUE_MIN_CONST, -10)
             if initial_radii[dim_ind - 1] == 0:
                 # exclude initial set, then add true minimum points
                 new_points: np.ndarray = np.row_stack(
@@ -46,27 +49,35 @@ def get_reachtube_segment(training_traces: np.ndarray, initial_radii: np.ndarray
             else:
                 # start from zero, then add true minimum points
                 new_points: np.ndarray = np.row_stack((points[dim_ind - 1, 0, :],
-                                                       np.array((points[dim_ind - 1, 0, 0], new_min)),
+                                                       np.array(
+                                                           (points[dim_ind - 1, 0, 0], new_min)),
                                                        np.array((points[dim_ind - 1, -1, 0], new_min))))
                 df[0, dim_ind] = initial_radii[dim_ind - 1]
                 # Tuple order is start_time, end_time, slope, y-intercept
-            cur_dim_points = np.concatenate((points[dim_ind - 1, 1:, :], new_points), axis=0)
-            cur_hull: sp.spatial.ConvexHull = sp.spatial.ConvexHull(cur_dim_points)
-            linear_separators: List[Tuple[float, float, float, float, int, int]] = []
-            vert_inds = list(zip(cur_hull.vertices[:-1], cur_hull.vertices[1:]))
+            cur_dim_points = np.concatenate(
+                (points[dim_ind - 1, 1:, :], new_points), axis=0)
+            cur_hull: sp.spatial.ConvexHull = sp.spatial.ConvexHull(
+                cur_dim_points)
+            linear_separators: List[Tuple[float,
+                                          float, float, float, int, int]] = []
+            vert_inds = list(
+                zip(cur_hull.vertices[:-1], cur_hull.vertices[1:]))
             vert_inds.append((cur_hull.vertices[-1], cur_hull.vertices[0]))
             for end_ind, start_ind in vert_inds:
                 if cur_dim_points[start_ind, 1] != new_min and cur_dim_points[end_ind, 1] != new_min:
                     slope = (cur_dim_points[end_ind, 1] - cur_dim_points[start_ind, 1]) / (
-                                cur_dim_points[end_ind, 0] - cur_dim_points[start_ind, 0])
-                    y_intercept = cur_dim_points[start_ind, 1] - cur_dim_points[start_ind, 0] * slope
+                        cur_dim_points[end_ind, 0] - cur_dim_points[start_ind, 0])
+                    y_intercept = cur_dim_points[start_ind,
+                                                 1] - cur_dim_points[start_ind, 0] * slope
                     start_time = cur_dim_points[start_ind, 0]
                     end_time = cur_dim_points[end_ind, 0]
                     assert start_time < end_time
                     if start_time == 0:
-                        linear_separators.append((start_time, end_time, slope, y_intercept, 0, end_ind + 1))
+                        linear_separators.append(
+                            (start_time, end_time, slope, y_intercept, 0, end_ind + 1))
                     else:
-                        linear_separators.append((start_time, end_time, slope, y_intercept, start_ind + 1, end_ind + 1))
+                        linear_separators.append(
+                            (start_time, end_time, slope, y_intercept, start_ind + 1, end_ind + 1))
             linear_separators.sort()
             prev_val = 0
             prev_ind = 1 if initial_radii[dim_ind - 1] == 0 else 0
@@ -86,13 +97,17 @@ def get_reachtube_segment(training_traces: np.ndarray, initial_radii: np.ndarray
         raise ValueError
     assert (np.all(df >= 0))
     reachtube_segment: np.ndarray = np.zeros((trace_len - 1, 2, ndims))
-    reachtube_segment[:, 0, :] = np.minimum(center_trace[1:, :] - df[1:, :], center_trace[:-1, :] - df[:-1, :])
-    reachtube_segment[:, 1, :] = np.maximum(center_trace[1:, :] + df[1:, :], center_trace[:-1, :] + df[:-1, :])
+    reachtube_segment[:, 0, :] = np.minimum(
+        center_trace[1:, :] - df[1:, :], center_trace[:-1, :] - df[:-1, :])
+    reachtube_segment[:, 1, :] = np.maximum(
+        center_trace[1:, :] + df[1:, :], center_trace[:-1, :] + df[:-1, :])
     # assert 100% training accuracy (all trajectories are contained)
     for trace_ind in range(training_traces.shape[0]):
         if not (np.all(reachtube_segment[:, 0, :] <= training_traces[trace_ind, 1:, :]) and np.all(reachtube_segment[:, 1, :] >= training_traces[trace_ind, 1:, :])):
-            assert np.any(np.abs(training_traces[trace_ind, 0, 1:]-center_trace[0, 1:]) > initial_radii)
-            print(f"Warning: Trace #{trace_ind}", "of this initial set is sampled outside of the initial set because of floating point error and is not contained in the initial set")
+            assert np.any(
+                np.abs(training_traces[trace_ind, 0, 1:]-center_trace[0, 1:]) > initial_radii)
+            print(f"Warning: Trace #{trace_ind}",
+                  "of this initial set is sampled outside of the initial set because of floating point error and is not contained in the initial set")
     return reachtube_segment
 
 
@@ -102,24 +117,30 @@ def all_sensitivities_calc(training_traces: np.ndarray, initial_radii: np.ndarra
     ndims: int
     num_traces, trace_len, ndims = training_traces.shape
     normalizing_initial_set_radii: np.array = initial_radii.copy()
-    y_points: np.array = np.zeros((normalizing_initial_set_radii.shape[0], trace_len - 1))
-    normalizing_initial_set_radii[np.where(normalizing_initial_set_radii == 0)] = 1.0
+    y_points: np.array = np.zeros(
+        (normalizing_initial_set_radii.shape[0], trace_len - 1))
+    normalizing_initial_set_radii[np.where(
+        normalizing_initial_set_radii == 0)] = 1.0
     for cur_dim_ind in range(1, ndims):
-        normalized_initial_points: np.array = training_traces[:, 0, 1:] / normalizing_initial_set_radii
-        initial_distances = sp.spatial.distance.pdist(normalized_initial_points, 'chebyshev') + _SMALL_EPSILON
+        # keyi: move out of loop
+        normalized_initial_points: np.array = training_traces[:,
+                                                              0, 1:] / normalizing_initial_set_radii
+        initial_distances = sp.spatial.distance.pdist(
+            normalized_initial_points, 'chebyshev') + _SMALL_EPSILON
         for cur_time_ind in range(1, trace_len):
             y_points[cur_dim_ind - 1, cur_time_ind - 1] = np.max((sp.spatial.distance.pdist(
-                    np.reshape(training_traces[:, cur_time_ind, cur_dim_ind],
-                    (training_traces.shape[0], 1)), 'chebychev'
-                )/ normalizing_initial_set_radii[cur_dim_ind - 1]) / initial_distances)
+                np.reshape(training_traces[:, cur_time_ind, cur_dim_ind],
+                           (training_traces.shape[0], 1)), 'chebychev'
+            ) / normalizing_initial_set_radii[cur_dim_ind - 1]) / initial_distances)
     return y_points
 
-if __name__=="__main__":
+
+if __name__ == "__main__":
     with open("test.npy", "rb") as f:
         training_traces = np.load(f)
-    initial_radii = np.array([1.96620653e-06, 2.99999995e+00, 3.07000514e-07, 8.84958773e-13, 1.05625786e-16, 3.72500000e+00, 0.00000000e+00, 0.00000000e+00])
-    result = get_reachtube_segment(training_traces, initial_radii, method='PWGlobal')
+    initial_radii = np.array([1.96620653e-06, 2.99999995e+00, 3.07000514e-07,
+                             8.84958773e-13, 1.05625786e-16, 3.72500000e+00, 0.00000000e+00, 0.00000000e+00])
+    result = get_reachtube_segment(
+        training_traces, initial_radii, method='PWGlobal')
     print(training_traces.dtype)
     # plot_rtsegment_and_traces(result, training_traces[np.array((0, 6))])
-
-
diff --git a/dryvr_plus_plus/scene_verifier/map/lane.py b/dryvr_plus_plus/scene_verifier/map/lane.py
index e48e3f10cc14647bd5d6d3cfcba4d2198a4e7b25..e9d120ef17d545e560d711433cf4bcef1f49e6fd 100644
--- a/dryvr_plus_plus/scene_verifier/map/lane.py
+++ b/dryvr_plus_plus/scene_verifier/map/lane.py
@@ -4,11 +4,14 @@ import numpy as np
 
 from dryvr_plus_plus.scene_verifier.map.lane_segment import AbstractLane
 
+
 class Lane():
     COMPENSATE = 3
-    def __init__(self, id, seg_list: List[AbstractLane]):
+
+    def __init__(self, id, seg_list: List[AbstractLane], speed_limit=None):
         self.id = id
         self.segment_list: List[AbstractLane] = seg_list
+        self.speed_limit = speed_limit
         self._set_longitudinal_start()
 
     def _set_longitudinal_start(self):
@@ -17,28 +20,37 @@ class Lane():
             lane_seg.longitudinal_start = longitudinal_start
             longitudinal_start += lane_seg.length
 
-    def get_lane_segment(self, position:np.ndarray) -> AbstractLane:
+    def get_lane_segment(self, position: np.ndarray) -> AbstractLane:
         for seg_idx, segment in enumerate(self.segment_list):
             logitudinal, lateral = segment.local_coordinates(position)
+            # why COMPENSATE? why no lateral? use on_lane?
             is_on = 0-Lane.COMPENSATE <= logitudinal < segment.length
             if is_on:
                 return seg_idx, segment
-        return -1,None
+        return -1, None
 
-    def get_heading(self, position:np.ndarray) -> float:
+    def get_heading(self, position: np.ndarray) -> float:
         seg_idx, segment = self.get_lane_segment(position)
         longitudinal, lateral = segment.local_coordinates(position)
         heading = segment.heading_at(longitudinal)
         return heading
 
-    def get_longitudinal_position(self, position:np.ndarray) -> float:
+    def get_longitudinal_position(self, position: np.ndarray) -> float:
         seg_idx, segment = self.get_lane_segment(position)
         longitudinal, lateral = segment.local_coordinates(position)
-        for i in range(seg_idx):    
+        for i in range(seg_idx):
             longitudinal += self.segment_list[i].length
         return longitudinal
 
-    def get_lateral_distance(self, position:np.ndarray) -> float:
+    def get_lateral_distance(self, position: np.ndarray) -> float:
         seg_idx, segment = self.get_lane_segment(position)
         longitudinal, lateral = segment.local_coordinates(position)
         return lateral
+
+    def get_speed_limit_old(self, position: np.ndarray) -> float:
+        seg_idx, segment = self.get_lane_segment(position)
+        longitudinal, lateral = segment.local_coordinates(position)
+        return segment.speed_limit_at(longitudinal)
+
+    def get_speed_limit(self):
+        return self.speed_limit
diff --git a/dryvr_plus_plus/scene_verifier/map/lane_map.py b/dryvr_plus_plus/scene_verifier/map/lane_map.py
index 3e5b9c2d58e943269964d354a5a095e3c69ec3a3..c4a5672e12b9cebe520c9e7ffbb9a52bea3da892 100644
--- a/dryvr_plus_plus/scene_verifier/map/lane_map.py
+++ b/dryvr_plus_plus/scene_verifier/map/lane_map.py
@@ -1,3 +1,4 @@
+from ctypes.wintypes import PINT
 from typing import Dict, List
 import copy
 from enum import Enum
@@ -7,17 +8,19 @@ import numpy as np
 from dryvr_plus_plus.scene_verifier.map.lane_segment import AbstractLane
 from dryvr_plus_plus.scene_verifier.map.lane import Lane
 
+
 class LaneMap:
-    def __init__(self, lane_seg_list:List[Lane] = []):
-        self.lane_dict:Dict[str, Lane] = {}
-        self.left_lane_dict:Dict[str, List[str]] = {}
-        self.right_lane_dict:Dict[str, List[str]] = {}
+    def __init__(self, lane_seg_list: List[Lane] = []):
+        self.lane_dict: Dict[str, Lane] = {}
+        self.left_lane_dict: Dict[str, List[str]] = {}
+        self.right_lane_dict: Dict[str, List[str]] = {}
         for lane_seg in lane_seg_list:
             self.lane_dict[lane_seg.id] = lane_seg
             self.left_lane_dict[lane_seg.id] = []
             self.right_lane_dict[lane_seg.id] = []
 
-    def add_lanes(self, lane_seg_list:List[AbstractLane]):
+    # why AbstractLane not Lane
+    def add_lanes(self, lane_seg_list: List[AbstractLane]):
         for lane_seg in lane_seg_list:
             self.lane_dict[lane_seg.id] = lane_seg
             self.left_lane_dict[lane_seg.id] = []
@@ -30,7 +33,7 @@ class LaneMap:
             Warning(f'lane {lane_idx} not available')
             return False
         left_lane_list = self.left_lane_dict[lane_idx]
-        return len(left_lane_list)>0
+        return len(left_lane_list) > 0
 
     def left_lane(self, lane_idx):
         assert all((elem in self.left_lane_dict) for elem in self.lane_dict)
@@ -40,7 +43,6 @@ class LaneMap:
             raise ValueError(f"lane_idx {lane_idx} not in lane_dict")
         left_lane_list = self.left_lane_dict[lane_idx]
         return copy.deepcopy(left_lane_list[0])
-        
     def has_right(self, lane_idx):
         if isinstance(lane_idx, Enum):
             lane_idx = lane_idx.name
@@ -48,7 +50,7 @@ class LaneMap:
             Warning(f'lane {lane_idx} not available')
             return False
         right_lane_list = self.right_lane_dict[lane_idx]
-        return len(right_lane_list)>0
+        return len(right_lane_list) > 0
 
     def right_lane(self, lane_idx):
         assert all((elem in self.right_lane_dict) for elem in self.lane_dict)
@@ -58,36 +60,57 @@ class LaneMap:
             raise ValueError(f"lane_idx {lane_idx} not in lane_dict")
         right_lane_list = self.right_lane_dict[lane_idx]
         return copy.deepcopy(right_lane_list[0])
-        
+
     def lane_geometry(self, lane_idx):
         if isinstance(lane_idx, Enum):
             lane_idx = lane_idx.name
         return self.lane_dict[lane_idx].get_geometry()
 
-    def get_longitudinal_position(self, lane_idx:str, position:np.ndarray) -> float:
+    def get_longitudinal_position(self, lane_idx: str, position: np.ndarray) -> float:
         if not isinstance(position, np.ndarray):
             position = np.array(position)
         lane = self.lane_dict[lane_idx]
         return lane.get_longitudinal_position(position)
 
-    def get_lateral_distance(self, lane_idx:str, position:np.ndarray) -> float:
+    def get_lateral_distance(self, lane_idx: str, position: np.ndarray) -> float:
         if not isinstance(position, np.ndarray):
             position = np.array(position)
         lane = self.lane_dict[lane_idx]
         return lane.get_lateral_distance(position)
 
-    def get_altitude(self, lane_idx, position:np.ndarray) -> float:
+    def get_altitude(self, lane_idx, position: np.ndarray) -> float:
         raise NotImplementedError
 
-    def get_lane_heading(self, lane_idx:str, position: np.ndarray) -> float:
+    def get_lane_heading(self, lane_idx: str, position: np.ndarray) -> float:
         if not isinstance(position, np.ndarray):
             position = np.array(position)
         lane = self.lane_dict[lane_idx]
         return lane.get_heading(position)
 
-    def get_lane_segment(self, lane_idx:str, position: np.ndarray) -> AbstractLane:
+    def get_lane_segment(self, lane_idx: str, position: np.ndarray) -> AbstractLane:
         if not isinstance(position, np.ndarray):
             position = np.array(position)
         lane = self.lane_dict[lane_idx]
         seg_idx, segment = lane.get_lane_segment(position)
-        return segment
\ No newline at end of file
+        return segment
+
+    def get_speed_limit_old(self, lane_idx: str, position: np.ndarray) -> float:
+        if not isinstance(position, np.ndarray):
+            position = np.array(position)
+        lane = self.lane_dict[lane_idx]
+        limit = lane.get_speed_limit_old(position)
+        # print(limit)
+        # print(position)
+        return limit
+
+    def get_speed_limit(self, lane_idx: str) -> float:
+        lane = self.lane_dict[lane_idx]
+        # print(lane.get_speed_limit())
+        return lane.get_speed_limit()
+
+    def get_all_speed_limit(self) -> Dict[str, float]:
+        ret_dict = {}
+        for lane_idx, lane in self.lane_dict.items():
+            ret_dict[lane_idx] = lane.get_speed_limit()
+        print(ret_dict)
+        return ret_dict
diff --git a/dryvr_plus_plus/scene_verifier/map/lane_segment.py b/dryvr_plus_plus/scene_verifier/map/lane_segment.py
index 642d7a005ba84f6d3956dbed49dcd5e4c7ac94e5..fd67782d64eba56a7470767eaccde988874de97a 100644
--- a/dryvr_plus_plus/scene_verifier/map/lane_segment.py
+++ b/dryvr_plus_plus/scene_verifier/map/lane_segment.py
@@ -1,9 +1,13 @@
+from turtle import speed
 from typing import List
 import numpy as np
 from abc import ABCMeta, abstractmethod
 from typing import Tuple, List, Optional, Union
+import copy
+from sympy import false
+
+from dryvr_plus_plus.scene_verifier.utils.utils import wrap_to_pi, Vector, get_class_path, class_from_path, to_serializable
 
-from dryvr_plus_plus.scene_verifier.utils.utils import wrap_to_pi, Vector, get_class_path, class_from_path,to_serializable
 
 class LineType:
 
@@ -14,6 +18,7 @@ class LineType:
     CONTINUOUS = 2
     CONTINUOUS_LINE = 3
 
+
 class AbstractLane(object):
 
     """A lane on the road, described by its central curve."""
@@ -25,7 +30,7 @@ class AbstractLane(object):
     longitudinal_start: float = 0
     line_types: List["LineType"]
 
-    def __init__(self, id:str):
+    def __init__(self, id: str):
         self.id = id
         self.type = None
 
@@ -137,18 +142,19 @@ class AbstractLane(object):
         angle = np.abs(wrap_to_pi(heading - self.heading_at(s)))
         return abs(r) + max(s - self.length, 0) + max(0 - s, 0) + heading_weight*angle
 
+
 class StraightLane(AbstractLane):
 
     """A lane going in straight line."""
 
     def __init__(self,
-                 id: str, 
+                 id: str,
                  start: Vector,
                  end: Vector,
                  width: float = AbstractLane.DEFAULT_WIDTH,
                  line_types: Tuple[LineType, LineType] = None,
                  forbidden: bool = False,
-                 speed_limit: float = 20,
+                 speed_limit: List[Tuple[float, float]] = None,
                  priority: int = 0) -> None:
         """
         New straight lane.
@@ -164,14 +170,19 @@ class StraightLane(AbstractLane):
         self.start = np.array(start)
         self.end = np.array(end)
         self.width = width
-        self.heading = np.arctan2(self.end[1] - self.start[1], self.end[0] - self.start[0])
+        self.heading = np.arctan2(
+            self.end[1] - self.start[1], self.end[0] - self.start[0])
         self.length = np.linalg.norm(self.end - self.start)
         self.line_types = line_types or [LineType.STRIPED, LineType.STRIPED]
         self.direction = (self.end - self.start) / self.length
-        self.direction_lateral = np.array([-self.direction[1], self.direction[0]])
+        self.direction_lateral = np.array(
+            [-self.direction[1], self.direction[0]])
         self.forbidden = forbidden
         self.priority = priority
-        self.speed_limit = speed_limit
+        if speed_limit != None:
+            self.speed_limit = sorted(speed_limit, key=lambda elem: elem[0])
+        else:
+            self.speed_limit = None
         self.type = 'Straight'
         self.longitudinal_start = 0
 
@@ -190,6 +201,54 @@ class StraightLane(AbstractLane):
         lateral = np.dot(delta, self.direction_lateral)
         return float(longitudinal), float(lateral)
 
+    def speed_limit_at(self, longitudinal: float) -> float:
+        # print(self.speed_limit)
+        if longitudinal >= self.speed_limit[-1][0]:
+            # print(longitudinal, self.speed_limit[-1][1])
+            return self.speed_limit[-1][1]
+        prev_limit = self.speed_limit[0][1]
+        for (start, limit) in self.speed_limit:
+            if longitudinal <= start:
+                # print(longitudinal, prev_limit)
+                return prev_limit
+            prev_limit = limit
+
+        return -1
+    # in format for polty filling mode
+
+    def get_all_speed(self):
+        end_longitudinal, end_lateral = self.local_coordinates(self.end)
+        ret_x = []
+        ret_y = []
+        ret_v = []
+        x_y = np.ndarray(shape=2)
+        seg_pos = []
+        speed_limit = copy.deepcopy(self.speed_limit)
+        speed_limit.append(tuple([end_longitudinal, self.speed_limit[-1][1]]))
+        for i in range(len(self.speed_limit)):
+            seg_start = speed_limit[i][0]
+            limit = speed_limit[i][1]
+            if end_longitudinal < seg_start:
+                break
+            seg_pos = []
+            seg_end = min(end_longitudinal, speed_limit[i+1][0])
+            x_y = self.position(seg_start, self.width/2)
+            seg_pos.append(x_y.tolist())
+            x_y = self.position(seg_end, self.width/2)
+            seg_pos.append(x_y.tolist())
+            x_y = self.position(seg_end, -self.width/2)
+            seg_pos.append(x_y.tolist())
+            x_y = self.position(seg_start, -self.width/2)
+            seg_pos.append(x_y.tolist())
+            ret_x.append([pos[0] for pos in seg_pos])
+            ret_y.append([pos[1] for pos in seg_pos])
+            ret_v.append(limit)
+        # print('get_all_speed')
+        # print(ret_x)
+        # print(ret_y)
+        # print(ret_v)
+        return ret_x, ret_y, ret_v
+
     @classmethod
     def from_config(cls, config: dict):
         config["start"] = np.array(config["start"])
@@ -216,7 +275,7 @@ class CircularLane(AbstractLane):
     """A lane going in circle arc."""
 
     def __init__(self,
-                 id, 
+                 id,
                  center: Vector,
                  radius: float,
                  start_phase: float,
@@ -225,7 +284,7 @@ class CircularLane(AbstractLane):
                  width: float = AbstractLane.DEFAULT_WIDTH,
                  line_types: List[LineType] = None,
                  forbidden: bool = False,
-                 speed_limit: float = 20,
+                 speed_limit: List[Tuple[float, float]] = None,
                  priority: int = 0) -> None:
         super().__init__(id)
         self.center = np.array(center)
@@ -239,7 +298,7 @@ class CircularLane(AbstractLane):
         self.forbidden = forbidden
         self.length = radius*(end_phase - start_phase) * self.direction
         self.priority = priority
-        self.speed_limit = speed_limit
+        self.speed_limit = sorted(speed_limit, key=lambda elem: elem[0])
         self.type = 'Circular'
         self.longitudinal_start = 0
 
@@ -264,6 +323,12 @@ class CircularLane(AbstractLane):
         lateral = self.direction*(self.radius - r)
         return longitudinal, lateral
 
+    def speed_limit_at(self, longitudinal: float) -> float:
+        for (start, limit) in self.speed_limit:
+            if longitudinal <= start:
+                return limit
+        return -1
+
     @classmethod
     def from_config(cls, config: dict):
         config["center"] = np.array(config["center"])
@@ -288,15 +353,15 @@ class CircularLane(AbstractLane):
 
 
 class LaneSegment:
-    def __init__(self, id, lane_parameter = None):
+    def __init__(self, id, lane_parameter=None):
         self.id = id
         # self.left_lane:List[str] = left_lane
-        # self.right_lane:List[str] = right_lane 
+        # self.right_lane:List[str] = right_lane
         # self.next_segment:int = next_segment
 
-        self.lane_parameter = None 
+        self.lane_parameter = None
         if lane_parameter is not None:
             self.lane_parameter = lane_parameter
 
     def get_geometry(self):
-        return self.lane_parameter
\ No newline at end of file
+        return self.lane_parameter
diff --git a/dryvr_plus_plus/scene_verifier/scenario/scenario.py b/dryvr_plus_plus/scene_verifier/scenario/scenario.py
index 3b003429d3c5323a2b31dc0cd97b8898cd26f1df..e42f2d6930a26a41d5a213c2c9af2896e4bcc6b2 100644
--- a/dryvr_plus_plus/scene_verifier/scenario/scenario.py
+++ b/dryvr_plus_plus/scene_verifier/scenario/scenario.py
@@ -18,6 +18,7 @@ from dryvr_plus_plus.scene_verifier.analysis.analysis_tree_node import AnalysisT
 from dryvr_plus_plus.scene_verifier.sensor.base_sensor import BaseSensor
 from dryvr_plus_plus.scene_verifier.map.lane_map import LaneMap
 
+
 class Scenario:
     def __init__(self):
         self.agent_dict = {}
@@ -31,14 +32,14 @@ class Scenario:
     def set_sensor(self, sensor):
         self.sensor = sensor
 
-    def set_map(self, lane_map:LaneMap):
+    def set_map(self, lane_map: LaneMap):
         self.map = lane_map
         # Update the lane mode field in the agent
         for agent_id in self.agent_dict:
             agent = self.agent_dict[agent_id]
             self.update_agent_lane_mode(agent, lane_map)
 
-    def add_agent(self, agent:BaseAgent):
+    def add_agent(self, agent: BaseAgent):
         if self.map is not None:
             # Update the lane mode field in the agent
             self.update_agent_lane_mode(agent, self.map)
@@ -50,12 +51,13 @@ class Scenario:
                 agent.controller.modes['LaneMode'].append(lane_id)
         mode_vals = list(agent.controller.modes.values())
         agent.controller.vertices = list(itertools.product(*mode_vals))
-        agent.controller.vertexStrings = [','.join(elem) for elem in agent.controller.vertices]
+        agent.controller.vertexStrings = [
+            ','.join(elem) for elem in agent.controller.vertices]
 
     def set_init(self, init_list, init_mode_list):
         assert len(init_list) == len(self.agent_dict)
         assert len(init_mode_list) == len(self.agent_dict)
-        for i,agent_id in enumerate(self.agent_dict.keys()):
+        for i, agent_id in enumerate(self.agent_dict.keys()):
             self.init_dict[agent_id] = copy.deepcopy(init_list[i])
             self.init_mode_dict[agent_id] = copy.deepcopy(init_mode_list[i])
 
@@ -84,6 +86,7 @@ class Scenario:
         for agent_id in self.agent_dict:
             init = self.init_dict[agent_id]
             tmp = np.array(init)
+            print(tmp)
             if tmp.ndim < 2:
                 init = [init, init]
             init_list.append(init)
@@ -92,13 +95,13 @@ class Scenario:
         return self.verifier.compute_full_reachtube(init_list, init_mode_list, agent_list, self, time_horizon, time_step, self.map)
 
     def check_guard_hit(self, state_dict):
-        lane_map = self.map 
+        lane_map = self.map
         guard_hits = []
         any_contained = False        
         for agent_id in state_dict:
-            agent:BaseAgent = self.agent_dict[agent_id]
+            agent: BaseAgent = self.agent_dict[agent_id]
             agent_state, agent_mode = state_dict[agent_id]
-        
+
             t = agent_state[0]
             agent_state = agent_state[1:]
             paths = agent.controller.getNextModes(agent_mode)
@@ -118,11 +121,12 @@ class Scenario:
                 
                 # Unroll all the any/all functions in the guard
                 guard_expression.parse_any_all(continuous_variable_dict, discrete_variable_dict, length_dict)
-                
+
                 # Check if the guard can be satisfied
-                # First Check if the discrete guards can be satisfied by actually evaluate the values 
+                # First Check if the discrete guards can be satisfied by actually evaluate the values
                 # since there's no uncertainty. If there's functions, actually execute the functions
-                guard_can_satisfied = guard_expression.evaluate_guard_disc(agent, discrete_variable_dict, continuous_variable_dict, self.map)
+                guard_can_satisfied = guard_expression.evaluate_guard_disc(
+                    agent, discrete_variable_dict, continuous_variable_dict, self.map)
                 if not guard_can_satisfied:
                     continue
 
@@ -132,7 +136,8 @@ class Scenario:
                     continue
 
                 # Handle guards realted only to continuous variables using SMT solvers. These types of guards can be pretty general
-                guard_satisfied, is_contained = guard_expression.evaluate_guard_cont(agent, continuous_variable_dict, self.map)
+                guard_satisfied, is_contained = guard_expression.evaluate_guard_cont(
+                    agent, continuous_variable_dict, self.map)
                 any_contained = any_contained or is_contained
                 if guard_satisfied:
                     guard_hits.append((agent_id, guard_list, reset_list))
@@ -145,14 +150,15 @@ class Scenario:
         guard_hit_bool = False
 
         # TODO: can add parallalization for this loop
-        for idx in range(0,trace_length):
+        for idx in range(0, trace_length):
             # For each trace, check with the guard to see if there's any possible transition
             # Store all possible transition in a list
             # A transition is defined by (agent, src_mode, dest_mode, corresponding reset, transit idx)
             # Here we enforce that only one agent transit at a time
             all_agent_state = {}
             for agent_id in node.agent:
-                all_agent_state[agent_id] = (node.trace[agent_id][idx*2:idx*2+2], node.mode[agent_id])
+                all_agent_state[agent_id] = (
+                    node.trace[agent_id][idx*2:idx*2+2], node.mode[agent_id])
             hits, is_contain = self.check_guard_hit(all_agent_state)
             # print(idx, is_contain)
             if hits != []:
@@ -167,12 +173,14 @@ class Scenario:
         reset_idx_dict = {}
         for hits, all_agent_state, hit_idx in guard_hits:
             for agent_id, guard_list, reset_list in hits:
-                dest_list,reset_rect = self.apply_reset(node.agent[agent_id], reset_list, all_agent_state)
+                dest_list, reset_rect = self.apply_reset(
+                    node.agent[agent_id], reset_list, all_agent_state)
                 if agent_id not in reset_dict:
                     reset_dict[agent_id] = {}
                     reset_idx_dict[agent_id] = {}
                 if not dest_list:
-                    warnings.warn(f"Guard hit for mode {node.mode[agent_id]} for agent {agent_id} without available next mode")
+                    warnings.warn(
+                        f"Guard hit for mode {node.mode[agent_id]} for agent {agent_id} without available next mode")
                     dest_list.append(None)
                 for dest in dest_list:
                     if dest not in reset_dict[agent_id]:
@@ -180,22 +188,25 @@ class Scenario:
                         reset_idx_dict[agent_id][dest] = []
                     reset_dict[agent_id][dest].append(reset_rect)
                     reset_idx_dict[agent_id][dest].append(hit_idx)
-            
+
         # Combine reset rects and construct transitions
         for agent in reset_dict:
             for dest in reset_dict[agent]:
-                combined_rect = None 
+                combined_rect = None
                 for rect in reset_dict[agent][dest]:
                     rect = np.array(rect)
                     if combined_rect is None:
-                        combined_rect = rect 
+                        combined_rect = rect
                     else:
-                        combined_rect[0,:] = np.minimum(combined_rect[0,:], rect[0,:])
-                        combined_rect[1,:] = np.maximum(combined_rect[1,:], rect[1,:])
+                        combined_rect[0, :] = np.minimum(
+                            combined_rect[0, :], rect[0, :])
+                        combined_rect[1, :] = np.maximum(
+                            combined_rect[1, :], rect[1, :])
                 combined_rect = combined_rect.tolist()
                 min_idx = min(reset_idx_dict[agent][dest])
                 max_idx = max(reset_idx_dict[agent][dest])
-                transition = (agent, node.mode[agent], dest, combined_rect, (min_idx, max_idx))
+                transition = (
+                    agent, node.mode[agent], dest, combined_rect, (min_idx, max_idx))
                 possible_transitions.append(transition)
         # Return result
         return possible_transitions
@@ -211,7 +222,7 @@ class Scenario:
         lane_map = self.map
         satisfied_guard = []
         for agent_id in state_dict:
-            agent:BaseAgent = self.agent_dict[agent_id]
+            agent: BaseAgent = self.agent_dict[agent_id]
             agent_state, agent_mode = state_dict[agent_id]
             t = agent_state[0]
             agent_state = agent_state[1:]
@@ -221,6 +232,7 @@ class Scenario:
                 guard_list = []
                 reset_list = []
                 for item in path:
+                    # print(item.code)
                     if isinstance(item, Guard):
                         guard_list.append(item)
                     elif isinstance(item, Reset):
@@ -237,20 +249,28 @@ class Scenario:
                 '''Check guards related to modes to see if the guards can be satisfied'''
                 '''Actually plug in the values to see if the guards can be satisfied'''
                 # Check if the guard can be satisfied
-                guard_satisfied = guard_expression.evaluate_guard(agent, continuous_variable_dict, discrete_variable_dict, self.map)
+                guard_satisfied = guard_expression.evaluate_guard(
+                    agent, continuous_variable_dict, discrete_variable_dict, self.map)
                 if guard_satisfied:
                     # If the guard can be satisfied, handle resets
                     next_init = agent_state
                     dest = copy.deepcopy(agent_mode)
                     possible_dest = [[elem] for elem in dest]
+                    # like [['Normal'], ['Lane1']]
+                    print('possible_dest', possible_dest)
                     for reset in reset_list:
                         # Specify the destination mode
                         reset = reset.code
                         if "mode" in reset:
+                            print(agent.controller.vars_dict['ego'])
+                            print(reset)
+                            # why break
+                            # vars_dict: {'cont': ['x', 'y', 'theta', 'v'], 'disc': ['vehicle_mode', 'lane_mode'], 'type': []}
                             for i, discrete_variable_ego in enumerate(agent.controller.vars_dict['ego']['disc']):
                                 if discrete_variable_ego in reset:
                                     break
                             tmp = reset.split('=')
+                            # like output.lane_mode = lane_map.right_lane(ego.lane_mode)
                             if 'map' in tmp[1]:
                                 tmp = tmp[1]
                                 for var in discrete_variable_dict:
@@ -262,26 +282,32 @@ class Scenario:
                             else:
                                 tmp = tmp[1].split('.')
                                 if tmp[0].strip(' ') in agent.controller.modes:
-                                    possible_dest[i] = [tmp[1]]                            
-                        else: 
+                                    possible_dest[i] = [tmp[1]]
+                        else:
+                            #
                             for i, cts_variable in enumerate(agent.controller.vars_dict['ego']['cont']):
                                 if "output."+cts_variable in reset:
-                                    break 
+                                    break
                             tmp = reset.split('=')
                             tmp = tmp[1]
                             for cts_variable in continuous_variable_dict:
-                                tmp = tmp.replace(cts_variable, str(continuous_variable_dict[cts_variable]))
+                                tmp = tmp.replace(cts_variable, str(
+                                    continuous_variable_dict[cts_variable]))
                             next_init[i] = eval(tmp)
+                    # print('possible_dest', possible_dest)
+                    # [['Brake'], ['Lane1']] -> [('Brake', 'Lane1')]
                     all_dest = list(itertools.product(*possible_dest))
+                    # print('all_dest', all_dest)
                     if not all_dest:
-                        warnings.warn(f"Guard hit for mode {agent_mode} for agent {agent_id} without available next mode")
+                        warnings.warn(
+                            f"Guard hit for mode {agent_mode} for agent {agent_id} without available next mode")
                         all_dest.append(None)
                     for dest in all_dest:
                         next_transition = (
-                            agent_id, agent_mode, dest, next_init, 
+                            agent_id, agent_mode, dest, next_init,
                         )
+                        # print(next_transition)
                         satisfied_guard.append(next_transition)
-
         return satisfied_guard
 
     def get_transition_simulate(self, node:AnalysisTreeNode) -> Tuple[Dict[str,List[Tuple[float]]], int]: