diff --git a/demo/tacas2023/exp1/exp1.py b/demo/tacas2023/exp1/exp1.py
index 98d23a551b25f31c479c92db6c954c01636a2a45..7f0e82d8f7fddc0acea8877bd063bc4b05ca8068 100644
--- a/demo/tacas2023/exp1/exp1.py
+++ b/demo/tacas2023/exp1/exp1.py
@@ -1,5 +1,6 @@
 from quadrotor_agent import QuadrotorAgent
 from verse import Scenario
+from verse.scenario.scenario import ScenarioConfig
 from verse.plotter.plotter3D_new import *
 from verse.plotter.plotter3D import *
 from verse.map.example_map.map_tacas import M6 
@@ -28,8 +29,9 @@ class TrackMode(Enum):
 
 if __name__ == "__main__":
     input_code_name = './demo/tacas2023/exp1/quadrotor_controller3.py'
-
-    scenario = Scenario()
+    para = True
+    config = ScenarioConfig(parallel=para)
+    scenario = Scenario(config)
     time_step = 0.1
     quadrotor1 = QuadrotorAgent(
         'test1', file_name=input_code_name, t_v_pair=(1, 1), box_side=[0.4]*3)
@@ -58,12 +60,13 @@ if __name__ == "__main__":
     # traces = scenario.simulate(40, time_step, seed=4)
     # traces.dump("./output1.json")
     # traces = AnalysisTree.load('./output1.json')
-    mode='ser'
-    mode='par'
     start_time = time.time()
-    traces = scenario.verify(40, time_step, mode=mode)
+    traces = scenario.verify(40, time_step)
     run_time = time.time() - start_time
-    traces.dump('demo/tacas2023/exp1/output_'+mode+'.json')
+    if scenario.config.parallel:
+        traces.dump('demo/tacas2023/exp1/output_par.json')
+    else:
+        traces.dump('demo/tacas2023/exp1/output_ser.json')
     print({
         "#A": len(scenario.agent_dict),
         "A": "Q",
diff --git a/demo/tacas2023/exp11/inc-expr.py b/demo/tacas2023/exp11/inc-expr.py
index 57c3c24b0900afb43bd46dd3e4e7e793175768d9..8330494d5bdbfffd8c66467c083080801922937a 100644
--- a/demo/tacas2023/exp11/inc-expr.py
+++ b/demo/tacas2023/exp11/inc-expr.py
@@ -59,8 +59,6 @@ if 'p' in arg:
     from verse.plotter.plotter2D import simulation_tree
 
 def run(sim, meas=False):
-    mode='ser'
-    mode='par'
     time = timeit.default_timer()
     if sim:
         scenario.simulator.cache_hits = (0, 0)
@@ -68,11 +66,14 @@ def run(sim, meas=False):
     else:
         scenario.verifier.tube_cache_hits = (0,0)
         scenario.verifier.trans_cache_hits = (0,0)
-        traces = scenario.verify(60, 0.1, mode=mode)
+        traces = scenario.verify(60, 0.1)
 
     if 'd' in arg:
         # traces.dump_tree()
-        traces.dump('./demo/tacas2023/exp11/main_'+mode+'.json')
+        if scenario.config.parallel:
+            traces.dump('./demo/tacas2023/exp11/main_par.json')
+        else:
+            traces.dump('./demo/tacas2023/exp11/main_ser.json')
         # traces.dump("tree2.json" if meas else "tree1.json") 
 
     if 'p' in arg and meas:
@@ -97,7 +98,8 @@ def run(sim, meas=False):
 
 if __name__ == "__main__":
     input_code_name = './demo/tacas2023/exp11/decision_logic/inc-expr6.py' if "6" in arg else './demo/tacas2023/exp11/decision_logic/inc-expr.py'
-    config = ScenarioConfig()
+    para = True
+    config = ScenarioConfig(parallel=para)
     config.incremental = 'i' in arg
     scenario = Scenario(config)
 
diff --git a/demo/tacas2023/exp3/exp3.py b/demo/tacas2023/exp3/exp3.py
index ff788da1a448f2eba19e9aac96901afd6c0d0e7f..fda338f99ac31beebe5a092d96da5510d2274f70 100644
--- a/demo/tacas2023/exp3/exp3.py
+++ b/demo/tacas2023/exp3/exp3.py
@@ -1,6 +1,7 @@
 from verse.agents.example_agent import CarAgent, NPCAgent
 from verse.map.example_map.map_tacas import M2
 from verse import Scenario
+from verse.scenario.scenario import ScenarioConfig
 from verse.plotter.plotter2D import *
 
 from enum import Enum, auto
@@ -34,7 +35,9 @@ class TrackMode(Enum):
 
 if __name__ == "__main__":
     input_code_name = './demo/tacas2023/exp3/example_controller7.py'
-    scenario = Scenario()
+    para = True
+    config = ScenarioConfig(parallel=para)
+    scenario = Scenario(config)
 
     car = CarAgent('car1', file_name=input_code_name)
     scenario.add_agent(car)
@@ -78,12 +81,13 @@ if __name__ == "__main__":
     # fig = simulation_anime(traces, tmp_map, fig, 1, 2, [
     #                        1, 2], 'lines', 'trace', sample_rate=1)
     # fig.show()
-    mode='ser'
-    mode='par'
     start_time = time.time()
-    traces = scenario.verify(80, 0.05, mode=mode)
+    traces = scenario.verify(80, 0.05)
     run_time = time.time()-start_time 
-    traces.dump('./demo/tacas2023/exp3/output3_'+mode+'.json')
+    if scenario.config.parallel:
+        traces.dump('./demo/tacas2023/exp3/output3_par.json')
+    else:
+        traces.dump('./demo/tacas2023/exp3/output3_ser.json')
 
     print({
         "#A": len(scenario.agent_dict),
diff --git a/demo/tacas2023/exp5/exp5.py b/demo/tacas2023/exp5/exp5.py
index 7e9ead6cc687f12c7ba54f3e6c277f0d1c919f74..a4a8d005dc1da7ca044fc12d1b5df8c92d11d797 100644
--- a/demo/tacas2023/exp5/exp5.py
+++ b/demo/tacas2023/exp5/exp5.py
@@ -38,7 +38,8 @@ class TrackMode(Enum):
 if __name__ == "__main__":
     input_code_name = './demo/tacas2023/exp5/example_controller5.py'    
     # config = 
-    scenario = Scenario(ScenarioConfig(init_seg_length=1))
+    para = True
+    scenario = Scenario(ScenarioConfig(init_seg_length=1,parallel=para))
     scenario.add_agent(CarAgent('car1', file_name=input_code_name))
     scenario.add_agent(NPCAgent('car2'))
     scenario.add_agent(NPCAgent('car3'))
@@ -62,12 +63,14 @@ if __name__ == "__main__":
         ]
     )
     # scenario.config.init_seg_length = 5
-    mode='ser'
-    mode='par'
     start_time = time.time()
-    traces = scenario.verify(60, 0.1, mode=mode)  # traces.dump('./output1.json')
+    traces = scenario.verify(60, 0.1)  # traces.dump('./output1.json')
     run_time = time.time()-start_time
-    traces.dump('./demo/tacas2023/exp5/output5_'+mode+'.json')
+    if scenario.config.parallel:
+        traces.dump('./demo/tacas2023/exp5/output5_par.json')
+    else:
+        traces.dump('./demo/tacas2023/exp5/output5_ser.json')
+
     print({
         "#A": len(scenario.agent_dict),
         "A": "C",
diff --git a/demo/tacas2023/exp9/exp9_dryvr.py b/demo/tacas2023/exp9/exp9_dryvr.py
index ed99416e5cbe6ab6e0c23cc0aca7894745319223..7c34bfaf16c27c0b112ed7022b69cc4ef0d2ed9e 100644
--- a/demo/tacas2023/exp9/exp9_dryvr.py
+++ b/demo/tacas2023/exp9/exp9_dryvr.py
@@ -1,5 +1,6 @@
 from quadrotor_agent import QuadrotorAgent
 from verse import Scenario
+from verse.scenario.scenario import ScenarioConfig
 from verse.plotter.plotter2D import reachtube_tree
 from verse.plotter.plotter3D_new import *
 from verse.plotter.plotter3D import *
@@ -31,7 +32,9 @@ if __name__ == "__main__":
     input_code_name = './demo/tacas2023/exp9/quadrotor_controller3.py'
     input_code_name2 = './demo/tacas2023/exp9/quadrotor_controller4.py'
 
-    scenario = Scenario()
+    para = True
+    config = ScenarioConfig(parallel=para)
+    scenario = Scenario(config)
     time_step = 0.2
     quadrotor1 = QuadrotorAgent(
         'test1', file_name=input_code_name, t_v_pair=(1, 1), box_side=[0.4]*3)
@@ -61,12 +64,14 @@ if __name__ == "__main__":
     # fig = go.Figure()
     # fig = simulation_tree_3d(traces, tmp_map, fig, 1, 2, 3, [1, 2, 3])
     # fig.show()
-    mode='ser'
-    mode='par'
+
     start_time = time.time()
-    traces = scenario.verify(60, time_step, mode=mode)
+    traces = scenario.verify(60, time_step)
     run_time = time.time() - start_time
-    traces.dump('./demo/tacas2023/exp9/output9_dryvr_'+mode+'.json')
+    if scenario.config.parallel:
+        traces.dump('./demo/tacas2023/exp9/output9_dryvr_par.json')
+    else:
+        traces.dump('./demo/tacas2023/exp9/output9_dryvr_ser.json')
     print({
         "#A": len(scenario.agent_dict),
         "A": "Q",
diff --git a/demo/tacas2023/exp9/exp9_neureach.py b/demo/tacas2023/exp9/exp9_neureach.py
index e24c745b1b1248579a8cf96c0aca476d68f3f416..14dc97fe1e377198a0a805a4756867171ac62c8a 100644
--- a/demo/tacas2023/exp9/exp9_neureach.py
+++ b/demo/tacas2023/exp9/exp9_neureach.py
@@ -65,8 +65,6 @@ if __name__ == "__main__":
 
     # traces = scenario.verify(60, time_step)
     # traces.dump('./demo/tacas2023/exp9/output9_DryVR.json')
-    mode='ser'
-    # mode='par'
     start_time = time.time()
     traces = scenario.verify(60, time_step,
                              params={
@@ -80,7 +78,7 @@ if __name__ == "__main__":
                              }
                             )
     run_time = time.time() - start_time
-    traces.dump('./demo/tacas2023/exp9/output9_neureach_'+mode+'.json')
+    traces.dump('./demo/tacas2023/exp9/output9_neureach.json')
 
     print({
         "#A": len(scenario.agent_dict),
diff --git a/verse/analysis/verifier.py b/verse/analysis/verifier.py
index b4bcc2c8b05c93fecd6405388815b10a4a7b9c6d..8b7ce42579c916f80fde1c50cf987e073ff4cbe6 100644
--- a/verse/analysis/verifier.py
+++ b/verse/analysis/verifier.py
@@ -104,6 +104,7 @@ class Verifier:
         node: AnalysisTreeNode,
         remain_time: float,
         time_step: float,
+        max_height: int,
         lane_map: LaneMap,
         init_seg_length,
         reachability_method,
@@ -112,6 +113,8 @@ class Verifier:
         transition_graph,
         params = {},
     ):
+        if (max_height == None):
+            max_height = float('inf')
         combined_inits = {a: combine_all(inits) for a, inits in node.init.items()}
         # print(node.init)
         # print(node.mode)
@@ -224,7 +227,9 @@ class Verifier:
                     # pp(("dedup!", pre_len, len(cached_tubes[agent_id].transitions)))
                 else:
                     self.trans_cache.add_tube(agent_id, combined_inits, node, transit_agents, transition, transit_ind, run_num)
-
+        if (node.height >= max_height):
+            print("max depth reached")
+            return node
         next_nodes = []
         max_end_idx = 0
         for transition in all_possible_transitions:
@@ -293,6 +298,7 @@ class Verifier:
         transition_graph,
         time_horizon,
         time_step,
+        max_height,
         lane_map,
         init_seg_length,
         reachability_method,
@@ -301,8 +307,8 @@ class Verifier:
         params = {},
     ):
 
-        # if (max_height == None):
-        #     max_height = float('inf')
+        if (max_height == None):
+            max_height = float('inf')
         root = AnalysisTreeNode(
             trace={},
             init={},
@@ -310,6 +316,7 @@ class Verifier:
             static = {},
             uncertain_param={},
             agent={},
+            height=0,
             assert_hits={},
             child=[],
             start_time = 0,
@@ -331,6 +338,7 @@ class Verifier:
         while True:
             wait = False
             if len(verification_queue) > 0:
+                print([node.id for node in verification_queue])
                 node: AnalysisTreeNode = verification_queue.pop(0)
                 num_transitions+=1
                 # pp(("start ver", node.start_time, {a: (*node.mode[a], *node.init[a]) for a in node.mode}))
@@ -339,28 +347,29 @@ class Verifier:
                     continue
                 # For trace not already verified
                 result_refs.append(self.compute_full_reachtube_step.remote
-                                (self, node, remain_time, time_step, lane_map, init_seg_length, reachability_method,run_num, past_runs, transition_graph, params))
+                                (self, node, remain_time, time_step, max_height, lane_map, init_seg_length, reachability_method,run_num, past_runs, transition_graph, params))
                 if len(result_refs) >= self.config.parallel_ver_ahead:
                     wait = True
             elif len(result_refs) > 0:
                 wait = True
             else:
                 break
-            print(len(verification_queue), len(result_refs))
+            # print(len(verification_queue), len(result_refs))
             if wait:
                 [res], result_refs = ray.wait(result_refs)
                 node= ray.get(res)
                 id=node.id
                 next_nodes = node.child
-                print("got id:", id)
+                # print("got id:", id)
                 nodes[id].assert_hits=node.assert_hits
                 nodes[id].child=next_nodes
                 nodes[id].trace=node.trace
                 last_id = nodes[-1].id
                 for i, node in enumerate(next_nodes):
                     node.id = i + 1 + last_id
-                verification_queue.extend(next_nodes)
-                nodes.extend(next_nodes)
+                if node.height <= max_height:
+                    verification_queue.extend(next_nodes)
+                    nodes.extend(next_nodes)
         self.reachtube_tree = AnalysisTree(root)
         # print(f">>>>>>>> Number of calls to reachability engine: {num_calls}")
         # print(f">>>>>>>> Number of transitions happening: {num_transitions}")
@@ -378,6 +387,7 @@ class Verifier:
         transition_graph,
         time_horizon,
         time_step,
+        max_height,
         lane_map,
         init_seg_length,
         reachability_method,
@@ -392,6 +402,7 @@ class Verifier:
             static = {},
             uncertain_param={},
             agent={},
+            height=0,
             assert_hits={},
             child=[],
             start_time = 0,
@@ -513,13 +524,13 @@ class Verifier:
                     # pp(("to sim", new_cache.keys(), len(paths_to_sim)))
 
             # Get all possible transitions to next mode
-            start = time.perf_counter()
-            asserts_orig, all_possible_transitions_orig = transition_graph.get_transition_verify(new_cache, paths_to_sim, node)
-            time_transition_orig += time.perf_counter() - start
+            # start = time.perf_counter()
+            # asserts_orig, all_possible_transitions_orig = transition_graph.get_transition_verify(new_cache, paths_to_sim, node)
+            # time_transition_orig += time.perf_counter() - start
             start = time.perf_counter()
             asserts, all_possible_transitions = transition_graph.get_transition_verify_opt(new_cache, paths_to_sim, node)
             time_transition += time.perf_counter() - start
-            assert asserts == asserts_orig and all_possible_transitions_orig == all_possible_transitions
+            # assert asserts == asserts_orig and all_possible_transitions_orig == all_possible_transitions
             # pp(("transitions:", [(t[0], t[2]) for t in all_possible_transitions]))
             node.assert_hits = asserts
             if asserts != None:
@@ -543,6 +554,9 @@ class Verifier:
                         # pp(("dedup!", pre_len, len(cached_tubes[agent_id].transitions)))
                     else:
                         self.trans_cache.add_tube(agent_id, combined_inits, node, transit_agents, transition, transit_ind, run_num)
+            if (node.height >= max_height):
+                print("max depth reached")
+                continue
             max_end_idx = 0
             for transition in all_possible_transitions:
                 # Each transition will contain a list of rectangles and their corresponding indexes in the original list
@@ -605,3 +619,14 @@ class Verifier:
         self.num_transitions = num_transitions
         print('reachtube time:', time_reachtube, 'transition time:', time_transition_orig, 'transition opt time:', time_transition)
         return self.reachtube_tree
+
+def checkHeight(root, max_height):
+    if root:
+        # First recur on left child
+        # then print the data of node
+        if(root.child == []):
+            print("HEIGHT", root.height)
+            if(root.height > max_height):
+                print("Exceeds max height")
+        for c in root.child:
+            checkHeight(c, max_height)
\ No newline at end of file
diff --git a/verse/scenario/scenario.py b/verse/scenario/scenario.py
index a000a7e2266f51d0008265c70bcb964759c72feb..aa23afbb7af54e550b997e84903927434630c91a 100644
--- a/verse/scenario/scenario.py
+++ b/verse/scenario/scenario.py
@@ -158,6 +158,7 @@ class ScenarioConfig:
     reachability_method: str = 'DRYVR'
     parallel_sim_ahead: int = 8
     parallel_ver_ahead: int = 8
+    parallel: bool = True
 
 class Scenario:
     def __init__(self, config=ScenarioConfig()):
@@ -322,7 +323,7 @@ class Scenario:
         self.past_runs.append(tree)
         return tree
 
-    def verify(self, time_horizon, time_step, params={}, mode='ser') -> AnalysisTree:
+    def verify(self, time_horizon, time_step, params={}, max_height=None) -> AnalysisTree:
         self.check_init()
         init_list = []
         init_mode_list = []
@@ -339,13 +340,13 @@ class Scenario:
             static_list.append(self.static_dict[agent_id])
             uncertain_param_list.append(self.uncertain_param_dict[agent_id])
             agent_list.append(self.agent_dict[agent_id])
-        if mode == 'ser':
+        if not self.config.parallel:
             tree = self.verifier.compute_full_reachtube_ser(init_list, init_mode_list, static_list, uncertain_param_list, agent_list, self, time_horizon,
-                                                    time_step, self.map, self.config.init_seg_length, self.config.reachability_method, len(self.past_runs), self.past_runs, params)
+                                                    time_step, max_height,self.map, self.config.init_seg_length, self.config.reachability_method, len(self.past_runs), self.past_runs, params)
         else:
             ray.init()
             tree = self.verifier.compute_full_reachtube(init_list, init_mode_list, static_list, uncertain_param_list, agent_list, self, time_horizon,
-                                                    time_step, self.map, self.config.init_seg_length, self.config.reachability_method, len(self.past_runs), self.past_runs, params)
+                                                    time_step, max_height, self.map, self.config.init_seg_length, self.config.reachability_method, len(self.past_runs), self.past_runs, params)
             ray.shutdown()
         self.past_runs.append(tree)
         return tree