diff --git a/demo/demo1.py b/demo/demo1.py
index 44b98844ff3bae7c5c5b80e087f242ee420cd85a..446e380bbe7cdfdfb024f0f45984c3c4c64ae7d0 100644
--- a/demo/demo1.py
+++ b/demo/demo1.py
@@ -1,11 +1,10 @@
 from dryvr_plus_plus.example.example_agent.car_agent import CarAgent, NPCAgent
 from dryvr_plus_plus.scene_verifier.scenario.scenario import Scenario
 from dryvr_plus_plus.example.example_map.simple_map2 import SimpleMap2, SimpleMap3, SimpleMap5, SimpleMap6
-from dryvr_plus_plus.plotter.plotter2D_new import *
 from dryvr_plus_plus.example.example_sensor.fake_sensor import FakeSensor2
-import plotly.graph_objects as go
-import numpy as np
 from enum import Enum, auto
+import plotly.graph_objects as go
+from dryvr_plus_plus.plotter.plotter2D_new import *
 
 
 class VehicleMode(Enum):
@@ -57,5 +56,5 @@ if __name__ == "__main__":
 
     traces = scenario.simulate(10, 0.01)
     fig = go.Figure()
-    fig = test_simu_anime(traces, tmp_map, fig, 1, 2, 'lines')
+    fig = general_simu_anime(traces, tmp_map, fig, 1, 2, 'lines', 'trace')
     fig.show()
diff --git a/demo/demo2.py b/demo/demo2.py
index 2386b599c4e6e2f6fb3338c18664db1533c60665..d58e88f7cce2f33714e11dd8fe02b6261c3ab225 100644
--- a/demo/demo2.py
+++ b/demo/demo2.py
@@ -1,11 +1,11 @@
-from dryvr_plus_plus.example.example_agent.car_agent import CarAgent
+from dryvr_plus_plus.example.example_agent.car_agent import CarAgent, NPCAgent
 from dryvr_plus_plus.scene_verifier.scenario.scenario import Scenario
 from dryvr_plus_plus.example.example_map.simple_map2 import SimpleMap2, SimpleMap3, SimpleMap5, SimpleMap6
-from dryvr_plus_plus.plotter.plotter2D import *
-from dryvr_plus_plus.plotter.plotter2D_new import *
 from dryvr_plus_plus.example.example_sensor.fake_sensor import FakeSensor2
+from enum import Enum, auto
 import plotly.graph_objects as go
-# import matplotlib.pyplot as plt
+from dryvr_plus_plus.plotter.plotter2D_new import *
+
 
 import numpy as np
 from enum import Enum, auto
@@ -58,8 +58,7 @@ if __name__ == "__main__":
         ]
     )
 
-    traces = scenario.simulate(30)
+    traces = scenario.simulate(30, 0.05)
     fig = go.Figure()
     fig = test_simu_anime(traces, tmp_map, fig, 1, 2, 'lines')
     fig.show()
-
diff --git a/demo/demo3.py b/demo/demo3.py
index a459d2ecf1799151f86ef91b974ddcc3dbc5eb1a..1ec2ce32b6490e90759d66512643e908737c0ea3 100644
--- a/demo/demo3.py
+++ b/demo/demo3.py
@@ -2,13 +2,11 @@ from dryvr_plus_plus.example.example_agent.car_agent import CarAgent, NPCAgent
 from dryvr_plus_plus.example.example_agent.car_agent import CarAgent
 from dryvr_plus_plus.scene_verifier.scenario.scenario import Scenario
 from dryvr_plus_plus.example.example_map.simple_map2 import SimpleMap2, SimpleMap3, SimpleMap5, SimpleMap6
-from dryvr_plus_plus.plotter.plotter2D import *
 from dryvr_plus_plus.example.example_sensor.fake_sensor import FakeSensor3
-
-import matplotlib.pyplot as plt
-import plotly.graph_objects as go
-import numpy as np
 from enum import Enum, auto
+import plotly.graph_objects as go
+from dryvr_plus_plus.plotter.plotter2D_new import *
+
 
 class LaneObjectMode(Enum):
     Vehicle = auto()
@@ -17,17 +15,20 @@ class LaneObjectMode(Enum):
     Signal = auto()     # Traffic lights
     Obstacle = auto()   # Static (to road/lane) obstacles
 
+
 class VehicleMode(Enum):
     Normal = auto()
     SwitchLeft = auto()
     SwitchRight = auto()
     Brake = auto()
 
+
 class LaneMode(Enum):
     Lane0 = auto()
     Lane1 = auto()
     Lane2 = auto()
 
+
 class State:
     x = 0.0
     y = 0.0
@@ -57,10 +58,10 @@ if __name__ == "__main__":
     scenario.set_map(tmp_map)
     scenario.set_init(
         [
-            [[0, -0.2, 0, 1.0],[0.01, 0.2, 0, 1.0]],
-            [[10, 0, 0, 0.5],[10, 0, 0, 0.5]], 
-            [[20, 3, 0, 0.5],[20, 3, 0, 0.5]], 
-            [[30, 0, 0, 0.5],[30, 0, 0, 0.5]], 
+            [[0, -0.2, 0, 1.0], [0.01, 0.2, 0, 1.0]],
+            [[10, 0, 0, 0.5], [10, 0, 0, 0.5]],
+            [[20, 3, 0, 0.5], [20, 3, 0, 0.5]],
+            [[30, 0, 0, 0.5], [30, 0, 0, 0.5]],
         ],
         [
             (VehicleMode.Normal, LaneMode.Lane1, LaneObjectMode.Vehicle),
@@ -78,9 +79,8 @@ if __name__ == "__main__":
     # fig = plot_reachtube_tree(traces, 'car2', 1, [2], 'r', fig)
     # fig = plot_reachtube_tree(traces, 'car3', 1, [2], 'r', fig)
     # fig = plot_reachtube_tree(traces, 'car4', 1, [2], 'r', fig)
-    # plt.show()    
+    # plt.show()
 
     fig = go.Figure()
-    fig = plotly_simulation_anime(traces, tmp_map, fig)
+    fig = general_simu_anime(traces, tmp_map, fig, 1, 2, 'lines')
     fig.show()
-
diff --git a/demo/demo4.py b/demo/demo4.py
index a0025e590cdc268c8ae6d27accc5bed50bd614d0..40f430b5e0e8589cd29d60312af67a59ca98fdec 100644
--- a/demo/demo4.py
+++ b/demo/demo4.py
@@ -5,11 +5,10 @@ from dryvr_plus_plus.example.example_map.simple_map2 import SimpleMap2, SimpleMa
 from dryvr_plus_plus.plotter.plotter2D import *
 from dryvr_plus_plus.example.example_sensor.fake_sensor import FakeSensor3
 from dryvr_plus_plus.scene_verifier.sensor.base_sensor import BaseSensor
-
-import matplotlib.pyplot as plt
-import plotly.graph_objects as go
-import numpy as np
 from enum import Enum, auto
+import plotly.graph_objects as go
+from dryvr_plus_plus.plotter.plotter2D_new import *
+
 
 class LaneObjectMode(Enum):
     Vehicle = auto()
@@ -18,18 +17,21 @@ class LaneObjectMode(Enum):
     Signal = auto()     # Traffic lights
     Obstacle = auto()   # Static (to road/lane) obstacles
 
+
 class VehicleMode(Enum):
     Normal = auto()
     SwitchLeft = auto()
     SwitchRight = auto()
     Brake = auto()
 
+
 class LaneMode(Enum):
     Lane0 = auto()
     Lane1 = auto()
     Lane2 = auto()
     Lane3 = auto()
 
+
 class State:
     x = 0.0
     y = 0.0
@@ -63,12 +65,12 @@ if __name__ == "__main__":
     scenario.set_map(tmp_map)
     scenario.set_init(
         [
-            [[0, -0.2, 0, 1.0],[0.05, 0.2, 0, 1.0]],
-            [[10, 0, 0, 0.5],[10, 0, 0, 0.5]], 
-            [[20, 3, 0, 0.5],[20, 3, 0, 0.5]], 
-            [[30, 0, 0, 0.5],[30, 0, 0, 0.5]], 
-            [[23, -3, 0, 0.5],[23, -3, 0, 0.5]], 
-            [[40, -6, 0, 0.5],[40, -6, 0, 0.5]], 
+            [[0, -0.2, 0, 1.0], [0.05, 0.2, 0, 1.0]],
+            [[10, 0, 0, 0.5], [10, 0, 0, 0.5]],
+            [[20, 3, 0, 0.5], [20, 3, 0, 0.5]],
+            [[30, 0, 0, 0.5], [30, 0, 0, 0.5]],
+            [[23, -3, 0, 0.5], [23, -3, 0, 0.5]],
+            [[40, -6, 0, 0.5], [40, -6, 0, 0.5]],
         ],
         [
             (VehicleMode.Normal, LaneMode.Lane1),
@@ -92,7 +94,6 @@ if __name__ == "__main__":
     # fig = plot_reachtube_tree(traces, 'car6', 1, [2], 'r', fig)
     # plt.show()
 
-    # fig = go.Figure()
-    # fig = plotly_simulation_anime(traces, tmp_map, fig)
-    # fig.show()    
-    
+    fig = go.Figure()
+    fig = draw_simulation_tree(traces, tmp_map, fig, 1, 2, 'lines')
+    fig.show()
diff --git a/dryvr_plus_plus/plotter/plotter2D_new.py b/dryvr_plus_plus/plotter/plotter2D_new.py
index 5016b55be0270480c8770d01410191cfd903b7ef..6013916481481fb198e25e469af3fd4ee8285c2b 100644
--- a/dryvr_plus_plus/plotter/plotter2D_new.py
+++ b/dryvr_plus_plus/plotter/plotter2D_new.py
@@ -3,22 +3,22 @@ This file consist main plotter code for DryVR reachtube output
 """
 
 from __future__ import annotations
-from audioop import reverse
+# from audioop import reverse
 # from curses import start_color
-from re import A
-import matplotlib.patches as patches
-import matplotlib.pyplot as plt
+# from re import A
+# import matplotlib.patches as patches
+# import matplotlib.pyplot as plt
 import numpy as np
 from math import pi
 import plotly.graph_objects as go
 from typing import List
-from PIL import Image, ImageDraw
-import io
+# from PIL import Image, ImageDraw
+# import io
 import copy
-import operator
-from collections import OrderedDict
+# import operator
+# from collections import OrderedDict
 
-from torch import layout
+# from torch import layout
 from dryvr_plus_plus.scene_verifier.analysis.analysis_tree_node import AnalysisTreeNode
 
 colors = [['#CC0000', '#FF0000', '#FF3333', '#FF6666', '#FF9999'],
@@ -39,6 +39,9 @@ scheme_dict = {'red': 0, 'orange': 1, 'yellow': 2, 'yellowgreen': 3, 'lime': 4,
 bg_color = ['rgba(31,119,180,1)', 'rgba(255,127,14,0.2)', 'rgba(44,160,44,0.2)', 'rgba(214,39,40,0.2)', 'rgba(148,103,189,0.2)',
             'rgba(140,86,75,0.2)', 'rgba(227,119,194,0.2)', 'rgba(127,127,127,0.2)', 'rgba(188,189,34,0.2)', 'rgba(23,190,207,0.2)']
 color_cnt = 0
+text_size = 8
+scale_factor = 0.25
+mode_color = 'rgba(0,0,0,0.5)'
 
 
 def general_reachtube_anime(root, map=None, fig=None, x_dim: int = 1, y_dim: int = 2, map_type='lines'):
@@ -69,18 +72,6 @@ def general_reachtube_anime(root, map=None, fig=None, x_dim: int = 1, y_dim: int
                 x_max = max(x_max, trace[i][x_dim])
                 y_min = min(y_min, trace[i][y_dim])
                 y_max = max(y_max, trace[i][y_dim])
-                # if round(trace[i][0], 2) not in timed_point_dict:
-                #     timed_point_dict[round(trace[i][0], 2)] = [
-                #         trace[i][1:].tolist()]
-                # else:
-                #     init = False
-                #     for record in timed_point_dict[round(trace[i][0], 2)]:
-                #         if record == trace[i][1:].tolist():
-                #             init = True
-                #             break
-                #     if init == False:
-                #         timed_point_dict[round(trace[i][0], 2)].append(
-                #             trace[i][1:].tolist())
                 time_point = round(trace[i][0], 2)
                 rect = [trace[i][0:].tolist(), trace[i+1][0:].tolist()]
                 if time_point not in timed_point_dict:
@@ -584,7 +575,7 @@ def draw_simulation_tree(root, map=None, fig=None, x_dim: int = 1, y_dim=2, map_
         queue = [root]
         x_min, x_max = float('inf'), -float('inf')
         y_min, y_max = float('inf'), -float('inf')
-        scale_factor = 0.25
+        # scale_factor = 0.25
     i = 0
     queue = [root]
     previous_mode = {}
@@ -635,13 +626,13 @@ def draw_simulation_tree(root, map=None, fig=None, x_dim: int = 1, y_dim=2, map_
 
                 fig.add_trace(go.Scatter(x=[trace[0, x_dim]], y=[trace[0, y_dim]],
                                          mode='markers+text',
-                                         line_color='rgba(255,255,255,0.3)',
+                                         line_color=mode_color,
                                          text=str(agent_id)+': ' +
                                          str(node.mode[agent_id][0]),
                                          textposition=text_pos,
                                          textfont=dict(
                     #  family="sans serif",
-                    size=10,
+                    size=text_size,
                                              color="grey"),
                                          showlegend=False,
                                          ))
@@ -687,7 +678,7 @@ def draw_simulation_tree_single(root: AnalysisTreeNode, agent_id, fig: go.Figure
         #                          showlegend=False))'
         trace_text = []
         for i in range(len(trace)):
-            trace_text.append([round(trace[i, j], 2)
+            trace_text.append(['{:.2f}'.format(trace[i, j])
                               for j in range(trace.shape[1])])
 
         fig.add_trace(go.Scatter(x=trace[:, x_dim], y=trace[:, y_dim],
@@ -825,7 +816,7 @@ def draw_simulation_anime(root, map=None, fig=None):
         "y": y_list,
         "mode": "markers + text",
         "text": text_list,
-        "textfont": dict(size=14, color="black"),
+        "textfont": dict(size=text_size, color="black"),
         "textposition": "bottom center",
         # "marker": {
         #     "sizemode": "area",
@@ -862,7 +853,7 @@ def draw_simulation_anime(root, map=None, fig=None):
             "mode": "markers + text",
             # "text": [(round(trace_theta[i]/pi*180, 2), round(trace_v[i], 3)) for i in range(len(trace_theta))],
             "text": [('{:.2f}'.format(trace_theta[i]/pi*180), '{:.3f}'.format(trace_v[i])) for i in range(len(trace_theta))],
-            "textfont": dict(size=14, color="black"),
+            "textfont": dict(size=text_size, color="black"),
             "textposition": "bottom center",
             # "marker": {
             #     "sizemode": "area",
@@ -983,13 +974,13 @@ def draw_simulation_anime(root, map=None, fig=None):
                         text_pos = 'top center'
                 fig.add_trace(go.Scatter(x=[trace[0, 1]], y=[trace[0, 2]],
                                          mode='markers+text',
-                                         line_color='rgba(255,255,255,0.3)',
+                                         line_color=mode_color,
                                          text=str(agent_id)+': ' +
                                          str(node.mode[agent_id][0]),
                                          textposition=text_pos,
                                          textfont=dict(
                     #  family="sans serif",
-                    size=10,
+                    size=text_size,
                                              color="grey"),
                                          showlegend=False,
                                          ))
@@ -1026,7 +1017,7 @@ def general_simu_anime(root, map=None, fig=None, x_dim: int = 1, y_dim=2, map_ty
         traces = node.trace
         for agent_id in traces:
             trace = np.array(traces[agent_id])
-            print(trace)
+            # print(trace)
             # segment_start.add(round(trace[0][0], 2))
             for i in range(len(trace)):
                 x_min = min(x_min, trace[i][x_dim])
@@ -1052,7 +1043,7 @@ def general_simu_anime(root, map=None, fig=None, x_dim: int = 1, y_dim=2, map_ty
     # fill in most of layout
     # print(segment_start)
     # print(timed_point_dict.keys())
-    duration = int(600/time)
+    duration = 10
     fig_dict["layout"]["xaxis"] = {
         "range": [x_min, x_max],
         "title": "x position"}
@@ -1107,7 +1098,7 @@ def general_simu_anime(root, map=None, fig=None, x_dim: int = 1, y_dim=2, map_ty
     }
     # make data
     point_list = timed_point_dict[0]
-    print(point_list)
+    # print(point_list)
     x_list = []
     y_list = []
     text_list = []
@@ -1118,13 +1109,14 @@ def general_simu_anime(root, map=None, fig=None, x_dim: int = 1, y_dim=2, map_ty
         y_list.append(trace[y_dim])
         # text_list.append(
         #     ('{:.2f}'.format(trace[x_dim]), '{:.2f}'.format(trace[y_dim])))
-        text_list.append([round(trace[i], 2) for i in range(len(trace))])
+        text_list.append(['{:.2f}'.format(trace[i])
+                         for i in range(len(trace))])
     data_dict = {
         "x": x_list,
         "y": y_list,
         "mode": "markers + text",
         "text": text_list,
-        "textfont": dict(size=14, color="black"),
+        "textfont": dict(size=text_size, color="black"),
         "textposition": "bottom center",
         # "marker": {
         #     "sizemode": "area",
@@ -1148,20 +1140,28 @@ def general_simu_anime(root, map=None, fig=None, x_dim: int = 1, y_dim=2, map_ty
         trace_y = []
         text_list = []
         for data in point_list:
+            # todo
             trace = list(data.values())[0]
             # print(trace)
             trace_x.append(trace[x_dim])
             trace_y.append(trace[y_dim])
             # text_list.append(
             #     ('{:.2f}'.format(trace[x_dim]), '{:.2f}'.format(trace[y_dim])))
-            text_list.append([round(trace[i], 2) for i in range(len(trace))])
+            text_list.append(['{:.2f}'.format(trace[i])
+                              for i in range(len(trace))])
         data_dict = {
             "x": trace_x,
             "y": trace_y,
             "mode": "markers + text",
-            # "text": [(round(trace_theta[i]/pi*180, 2), round(trace_v[i], 3)) for i in range(len(trace_theta))],
             "text": text_list,
-            "textfont": dict(size=14, color="black"),
+            "marker": {
+                # "color": color,
+                # "opacity": opacity_step*(trail_len-id),
+                # # "sizemode": "area",
+                # # "sizeref": 200000,
+                # "size": min_size + size_step*(trail_len-id)
+            },
+            "textfont": dict(size=text_size, color="black"),
             "textposition": "bottom center",
             # "marker": {
             #     "sizemode": "area",
@@ -1252,13 +1252,14 @@ def general_simu_anime(root, map=None, fig=None, x_dim: int = 1, y_dim=2, map_ty
 
                 fig.add_trace(go.Scatter(x=[trace[0, x_dim]], y=[trace[0, y_dim]],
                                          mode='markers+text',
-                                         line_color='rgba(255,255,255,0.3)',
+                                         line_color=mode_color,
                                          text=str(agent_id)+': ' +
                                          str(node.mode[agent_id][0]),
                                          textposition=text_pos,
+                                         opacity=0.5,
                                          textfont=dict(
                     #  family="sans serif",
-                    size=10,
+                    size=text_size,
                                              color="grey"),
                                          showlegend=False,
                                          ))
@@ -1266,7 +1267,7 @@ def general_simu_anime(root, map=None, fig=None, x_dim: int = 1, y_dim=2, map_ty
                 previous_mode[agent_id] = node.mode[agent_id]
         queue += node.child
     fig.update_traces(showlegend=False)
-    scale_factor = 0.5
+    # scale_factor = 0.25
     if scale_type == 'trace':
         fig.update_xaxes(
             range=[x_min-scale_factor*(x_max-x_min), x_max+scale_factor*(x_max-x_min)])
@@ -1301,14 +1302,11 @@ def test_simu_anime(root, map=None, fig=None, x_dim: int = 1, y_dim=2, map_type=
         traces = node.trace
         for agent_id in traces:
             trace = np.array(traces[agent_id])
-            # print(trace)
-            # segment_start.add(round(trace[0][0], 2))
             for i in range(len(trace)):
                 x_min = min(x_min, trace[i][x_dim])
                 x_max = max(x_max, trace[i][x_dim])
                 y_min = min(y_min, trace[i][y_dim])
                 y_max = max(y_max, trace[i][y_dim])
-                # print(round(trace[i][0], 2))
                 time_point = round(trace[i][0], 2)
                 tmp_trace = trace[i][0:].tolist()
                 if time_point not in timed_point_dict:
@@ -1392,14 +1390,14 @@ def test_simu_anime(root, map=None, fig=None, x_dim: int = 1, y_dim=2, map_type=
                 y_list.append(point[y_dim])
                 # text_list.append(
                 #     ('{:.2f}'.format(point[x_dim]), '{:.2f}'.format(point[y_dim])))
-                text_list.append([round(point[i], 2)
-                                 for i in range(len(point))])
+                text_list.append(['{:.2f}'.format(point[i])
+                                  for i in range(len(point))])
         data_dict = {
             "x": x_list,
             "y": y_list,
             "mode": "markers",
             "text": text_list,
-            "textfont": dict(size=14, color="black"),
+            "textfont": dict(size=text_size, color="black"),
             "visible": False,
             "textposition": "bottom center",
             # "marker": {
@@ -1427,6 +1425,7 @@ def test_simu_anime(root, map=None, fig=None, x_dim: int = 1, y_dim=2, map_type=
         size_step = 2/trail_len
         min_size = 5
         for agent_id in agent_list:
+            color = colors[agent_list.index(agent_id) % 12][1]
             for id in range(0, trail_len, 2):
                 tmp_point_list = timed_point_dict[time_list[time_point_id-id]][agent_id]
                 trace_x = []
@@ -1437,20 +1436,20 @@ def test_simu_anime(root, map=None, fig=None, x_dim: int = 1, y_dim=2, map_type=
                     trace_y.append(point[y_dim])
                     # text_list.append(
                     #     ('{:.2f}'.format(point[x_dim]), '{:.2f}'.format(point[y_dim])))
-                    text_list.append([round(point[i], 2)
-                                     for i in range(len(point))])
+                    text_list.append(['{:.2f}'.format(point[i])
+                                      for i in range(len(point))])
                 #  print(trace_y)
                 if id == 0:
                     data_dict = {
                         "x": trace_x,
                         "y": trace_y,
-                        "mode": "markers",
+                        "mode": "markers+text",
                         "text": text_list,
-                        "textfont": dict(size=6, color="black"),
+                        "textfont": dict(size=text_size, color="black"),
                         "textposition": "bottom center",
                         "visible": True,
                         "marker": {
-                            "color": 'Black',
+                            "color": color,
                             "opacity": opacity_step*(trail_len-id),
                             # "sizemode": "area",
                             # "sizeref": 200000,
@@ -1469,7 +1468,7 @@ def test_simu_anime(root, map=None, fig=None, x_dim: int = 1, y_dim=2, map_type=
                         # "textposition": "bottom center",
                         "visible": True,
                         "marker": {
-                            "color": 'Black',
+                            "color": color,
                             "opacity": opacity_step*(trail_len-id),
                             # "sizemode": "area",
                             # "sizeref": 200000,
@@ -1541,13 +1540,14 @@ def test_simu_anime(root, map=None, fig=None, x_dim: int = 1, y_dim=2, map_type=
 
                 fig.add_trace(go.Scatter(x=[trace[0, x_dim]], y=[trace[0, y_dim]],
                                          mode='markers+text',
-                                         line_color='rgba(255,255,255,0.3)',
+                                         line_color=mode_color,
                                          text=str(agent_id)+': ' +
                                          str(node.mode[agent_id][0]),
+                                         opacity=0.5,
                                          textposition=text_pos,
                                          textfont=dict(
                     #  family="sans serif",
-                    size=10,
+                    size=text_size,
                                              color="grey"),
                                          showlegend=False,
                                          ))
@@ -1555,7 +1555,7 @@ def test_simu_anime(root, map=None, fig=None, x_dim: int = 1, y_dim=2, map_type=
                 previous_mode[agent_id] = node.mode[agent_id]
         queue += node.child
     fig.update_traces(showlegend=False)
-    scale_factor = 0.5
+    # scale_factor = 0.5
     if scale_type == 'trace':
         fig.update_xaxes(
             range=[x_min-scale_factor*(x_max-x_min), x_max+scale_factor*(x_max-x_min)])
diff --git a/dryvr_plus_plus/scene_verifier/analysis/simulator.py b/dryvr_plus_plus/scene_verifier/analysis/simulator.py
index c7567cef5aaa6562b0cb16b252400f48df702491..39f753a39376ff3532cd528e02971c715527144f 100644
--- a/dryvr_plus_plus/scene_verifier/analysis/simulator.py
+++ b/dryvr_plus_plus/scene_verifier/analysis/simulator.py
@@ -12,7 +12,7 @@ class Simulator:
     def __init__(self):
         self.simulation_tree_root = None
 
-    def simulate(self, init_list, init_mode_list, agent_list:List[BaseAgent], transition_graph, time_horizon, time_step, lane_map):
+    def simulate(self, init_list, init_mode_list, agent_list: List[BaseAgent], transition_graph, time_horizon, time_step, lane_map):
         # Setup the root of the simulation tree
         root = AnalysisTreeNode(
             trace={},
@@ -33,9 +33,9 @@ class Simulator:
         simulation_queue.append(root)
         # Perform BFS through the simulation tree to loop through all possible transitions
         while simulation_queue != []:
-            node:AnalysisTreeNode = simulation_queue.pop(0)
+            node: AnalysisTreeNode = simulation_queue.pop(0)
             print(node.start_time, node.mode)
-            remain_time = round(time_horizon - node.start_time,10)
+            remain_time = round(time_horizon - node.start_time, 10)
             if remain_time <= 0:
                 continue
             # For trace not already simulated
@@ -45,20 +45,20 @@ class Simulator:
                     # [time, x, y, theta, v]
                     mode = node.mode[agent_id]
                     init = node.init[agent_id]
-                    trace = node.agent[agent_id].TC_simulate(mode, init, remain_time, time_step, lane_map)
-                    trace[:,0] += node.start_time
+                    trace = node.agent[agent_id].TC_simulate(
+                        mode, init, remain_time, time_step, lane_map)
+                    trace[:, 0] += node.start_time
                     node.trace[agent_id] = trace.tolist()
 
-            transitions, transition_idx = transition_graph.get_transition_simulate_new(node)
+            transitions, transition_idx = transition_graph.get_transition_simulate_new(
+                node)
 
             # If there's no transitions (returned transitions is empty), continue
             if not transitions:
                 continue
 
-
             # truncate the computed trajectories from idx and store the content after truncate
             truncated_trace = {}
-            print("idx", idx)
             for agent_idx in node.agent:
                 truncated_trace[agent_idx] = node.trace[agent_idx][transition_idx:]
                 node.trace[agent_idx] = node.trace[agent_idx][:transition_idx+1]
@@ -72,8 +72,8 @@ class Simulator:
             # copy the traces that are not under transition
 
             for transition_combination in all_transition_combinations:
-                next_node_mode = copy.deepcopy(node.mode) 
-                next_node_agent = node.agent 
+                next_node_mode = copy.deepcopy(node.mode)
+                next_node_agent = node.agent
                 next_node_start_time = list(truncated_trace.values())[0][0][0]
                 next_node_init = {}
                 next_node_trace = {}
@@ -82,8 +82,8 @@ class Simulator:
                     if dest_mode is None:
                         continue
                     # next_node = AnalysisTreeNode(trace = {},init={},mode={},agent={}, child = [], start_time = 0)
-                    next_node_mode[transit_agent_idx] = dest_mode 
-                    next_node_init[transit_agent_idx] = next_init 
+                    next_node_mode[transit_agent_idx] = dest_mode
+                    next_node_init[transit_agent_idx] = next_init
                 for agent_idx in next_node_agent:
                     if agent_idx not in next_node_init:
                         next_node_trace[agent_idx] = truncated_trace[agent_idx]