From c2c610c214cc4fbfaaefeb745873f333a3e9f085 Mon Sep 17 00:00:00 2001
From: Yangge Li <li213@illinois.edu>
Date: Tue, 14 Jun 2022 19:06:35 -0500
Subject: [PATCH] working on performance optimization by moving around code
 performing unroll and handle discrete variables

---
 demo/demo3.py                                 |  33 +--
 demo/demo4.py                                 |   2 +-
 .../scene_verifier/analysis/simulator.py      |  20 +-
 .../scene_verifier/analysis/verifier.py       |   2 +-
 .../scene_verifier/automaton/guard.py         | 136 ++++++++++
 .../scene_verifier/scenario/scenario.py       | 239 +++++++++++++++++-
 6 files changed, 390 insertions(+), 42 deletions(-)

diff --git a/demo/demo3.py b/demo/demo3.py
index 079c68d6..807c1f45 100644
--- a/demo/demo3.py
+++ b/demo/demo3.py
@@ -42,7 +42,7 @@ class State:
 
 
 if __name__ == "__main__":
-    input_code_name = './demo/example_controller6.py'
+    input_code_name = './example_controller4.py'
     scenario = Scenario()
 
     car = CarAgent('car1', file_name=input_code_name)
@@ -70,25 +70,18 @@ if __name__ == "__main__":
             (VehicleMode.Normal, LaneMode.Lane1),
         ]
     )
-    traces = scenario.simulate(70)
-    # traces = scenario.verify(60)
+    # traces = scenario.simulate(70)
+    traces = scenario.verify(60)
 
-    # fig = plt.figure(2)
-    # fig = plot_map(tmp_map, 'g', fig)
-    # fig = plot_reachtube_tree(traces, 'car1', 1, [2], 'b', fig)
-    # fig = plot_reachtube_tree(traces, 'car2', 1, [2], 'r', fig)
-    # fig = plot_reachtube_tree(traces, 'car3', 1, [2], 'r', fig)
-    # fig = plot_reachtube_tree(traces, 'car4', 1, [2], 'r', fig)
-    # for traces in res_list:
-    # #     generate_simulation_anime(traces, tmp_map, fig)
-    # fig = plot_simulation_tree(traces, 'car1', 1, [2], 'b', fig)
-    # fig = plot_simulation_tree(traces, 'car2', 1, [2], 'r', fig)
-    # fig = plot_simulation_tree(traces, 'car3', 1, [2], 'r', fig)
-    # fig = plot_simulation_tree(traces, 'car4', 1, [2], 'r', fig)
-    # plt.show()
-        
+    fig = plt.figure(2)
+    fig = plot_map(tmp_map, 'g', fig)
+    fig = plot_reachtube_tree(traces, 'car1', 1, [2], 'b', fig)
+    fig = plot_reachtube_tree(traces, 'car2', 1, [2], 'r', fig)
+    fig = plot_reachtube_tree(traces, 'car3', 1, [2], 'r', fig)
+    fig = plot_reachtube_tree(traces, 'car4', 1, [2], 'r', fig)
+    # plt.show()    
 
-    fig = go.Figure()
-    fig = plotly_simulation_anime(traces, tmp_map, fig)
-    fig.show()
+    # fig = go.Figure()
+    # fig = plotly_simulation_anime(traces, tmp_map, fig)
+    # fig.show()
 
diff --git a/demo/demo4.py b/demo/demo4.py
index 04c7c541..2a9843f2 100644
--- a/demo/demo4.py
+++ b/demo/demo4.py
@@ -43,7 +43,7 @@ class State:
 
 
 if __name__ == "__main__":
-    input_code_name = './example_controller5.py'
+    input_code_name = './example_controller4.py'
     scenario = Scenario()
 
     car = CarAgent('car1', file_name=input_code_name)
diff --git a/dryvr_plus_plus/scene_verifier/analysis/simulator.py b/dryvr_plus_plus/scene_verifier/analysis/simulator.py
index 7d064ff4..d5e8c97f 100644
--- a/dryvr_plus_plus/scene_verifier/analysis/simulator.py
+++ b/dryvr_plus_plus/scene_verifier/analysis/simulator.py
@@ -46,27 +46,13 @@ class Simulator:
                     trace[:,0] += node.start_time
                     node.trace[agent_id] = trace.tolist()
 
-            trace_length = len(list(node.trace.values())[0])
-            transitions = []
-            for idx in range(trace_length):
-                # For each trace, check with the guard to see if there's any possible transition
-                # Store all possible transition in a list
-                # A transition is defined by (agent, src_mode, dest_mode, corresponding reset, transit idx)
-                # Here we enforce that only one agent transit at a time
-                all_agent_state = {}
-                for agent_id in node.agent:
-                    all_agent_state[agent_id] = (node.trace[agent_id][idx], node.mode[agent_id])
-                possible_transitions = transition_graph.get_all_transition(all_agent_state)
-                if possible_transitions != []:
-                    for agent_idx, src_mode, dest_mode, next_init in possible_transitions:
-                        transitions.append((agent_idx, src_mode, dest_mode, next_init, idx))
-                    break
+            transitions, transition_idx = transition_graph.get_transition_simulate_new(node)
 
             # truncate the computed trajectories from idx and store the content after truncate
             truncated_trace = {}
             for agent_idx in node.agent:
-                truncated_trace[agent_idx] = node.trace[agent_idx][idx:]
-                node.trace[agent_idx] = node.trace[agent_idx][:idx+1]
+                truncated_trace[agent_idx] = node.trace[agent_idx][transition_idx:]
+                node.trace[agent_idx] = node.trace[agent_idx][:transition_idx+1]
 
             # For each possible transition, construct the new node. 
             # Obtain the new initial condition for agent having transition
diff --git a/dryvr_plus_plus/scene_verifier/analysis/verifier.py b/dryvr_plus_plus/scene_verifier/analysis/verifier.py
index 5267f5e8..d3f8aa66 100644
--- a/dryvr_plus_plus/scene_verifier/analysis/verifier.py
+++ b/dryvr_plus_plus/scene_verifier/analysis/verifier.py
@@ -67,7 +67,7 @@ class Verifier:
             # TODO: Check safety conditions here
 
             # Get all possible transitions to next mode
-            all_possible_transitions = transition_graph.get_all_transition_set(node)
+            all_possible_transitions = transition_graph.get_transition_verify_new(node)
             max_end_idx = 0
             for transition in all_possible_transitions:
                 transit_agent_idx, src_mode, dest_mode, next_init, idx = transition 
diff --git a/dryvr_plus_plus/scene_verifier/automaton/guard.py b/dryvr_plus_plus/scene_verifier/automaton/guard.py
index ebaad776..d0cfb9c2 100644
--- a/dryvr_plus_plus/scene_verifier/automaton/guard.py
+++ b/dryvr_plus_plus/scene_verifier/automaton/guard.py
@@ -867,6 +867,142 @@ class GuardExpressionAst:
         # Return the modified node
         return root
 
+    def parse_any_all_new(self, cont_var_dict: Dict[str, float], disc_var_dict: Dict[str, float], len_dict: Dict[str, int]) -> Dict[str, List[str]]: 
+        cont_var_updater = {}
+        for i in range(len(self.ast_list)):
+            root = self.ast_list[i]
+            j = 0
+            while j < sum(1 for _ in ast.walk(root)):
+                # TODO: Find a faster way to access nodes in the tree
+                node = list(ast.walk(root))[j]
+                if isinstance(node, ast.Call) and\
+                    isinstance(node.func, ast.Name) and\
+                    (node.func.id=='any' or node.func.id=='all'):
+                    new_node = self.unroll_any_all_new(node, cont_var_dict, disc_var_dict, len_dict, cont_var_updater)
+                    root = NodeSubstituter(node, new_node).visit(root)
+                j += 1
+            self.ast_list[i] = root 
+        return cont_var_updater
+
+    def unroll_any_all_new(
+        self, node: ast.Call, 
+        cont_var_dict: Dict[str, float], 
+        disc_var_dict: Dict[str, float], 
+        len_dict: Dict[str, float],
+        cont_var_updater: Dict[str, List[str]],
+    ) -> Tuple[ast.BoolOp, Dict[str, List[str]]]:
+        parse_arg = node.args[0]
+        if isinstance(parse_arg, ast.GeneratorExp):
+            iter_name_list = []
+            targ_name_list = []
+            iter_len_list = []
+            # Get all the iter, targets and the length of iter list 
+            for generator in parse_arg.generators:
+                iter_name_list.append(generator.iter.id) # a_list
+                targ_name_list.append(generator.target.id) # a
+                iter_len_list.append(range(len_dict[generator.iter.id])) # len(a_list)
+
+            elt = parse_arg.elt
+            expand_elt_ast_list = []
+            iter_len_list = list(itertools.product(*iter_len_list))
+            # Loop through all possible combination of iter value
+            for i in range(len(iter_len_list)):
+                changed_elt = copy.deepcopy(elt)
+                iter_pos_list = iter_len_list[i]
+                # substitute temporary variable in each of the elt and add corresponding variables in the variable dicts
+                parsed_elt = self._parse_elt_new(changed_elt, cont_var_dict, disc_var_dict, cont_var_updater, iter_name_list, targ_name_list, iter_pos_list)
+                # Add the expanded elt into the list 
+                expand_elt_ast_list.append(parsed_elt)
+            # Create the new boolop (and/or) node based on the list of expanded elt
+            return ValueSubstituter(expand_elt_ast_list, node).visit(node)
+        else:
+            return node
+
+    def _parse_elt_new(self, root, cont_var_dict, disc_var_dict, cont_var_updater, iter_name_list, targ_name_list, iter_pos_list) -> Any:
+        # Loop through all node in the elt ast 
+        for node in ast.walk(root):
+            # If the node is an attribute
+            if isinstance(node, ast.Attribute):
+                if node.value.id in targ_name_list:
+                    # Find corresponding targ_name in the targ_name_list
+                    targ_name = node.value.id
+                    var_index = targ_name_list.index(targ_name)
+
+                    # Find the corresponding iter_name in the iter_name_list 
+                    iter_name = iter_name_list[var_index]
+
+                    # Create the name for the tmp variable 
+                    iter_pos = iter_pos_list[var_index]
+                    tmp_variable_name = f"{iter_name}_{iter_pos}.{node.attr}"
+
+                    # Replace variables in the etl by using tmp variables
+                    root = ValueSubstituter(tmp_variable_name, node).visit(root)
+
+                    # Find the value of the tmp variable in the cont/disc_var_dict
+                    # Add the tmp variables into the cont/disc_var_dict
+                    # NOTE: At each time step, for each agent, the variable value mapping and their 
+                    # sequence in the list is single. Therefore, for the same key, we will always rewrite 
+                    # its content. 
+                    variable_name = iter_name + '.' + node.attr
+                    variable_val = None
+                    if variable_name in cont_var_dict:
+                        # variable_val = cont_var_dict[variable_name][iter_pos]
+                        # cont_var_dict[tmp_variable_name] = variable_val
+                        if variable_name not in cont_var_updater:
+                            cont_var_updater[variable_name] = [(tmp_variable_name, iter_pos)]
+                        else:
+                            if (tmp_variable_name, iter_pos) not in cont_var_updater[variable_name]:
+                                cont_var_updater[variable_name].append((tmp_variable_name, iter_pos))
+                    elif variable_name in disc_var_dict:
+                        variable_val = disc_var_dict[variable_name][iter_pos]
+                        disc_var_dict[tmp_variable_name] = variable_val
+                        # if variable_name not in disc_var_updater:
+                        #     disc_var_updater[variable_name] = [(tmp_variable_name, iter_pos)]
+                        # else:
+                        #     if (tmp_variable_name, iter_pos) not in disc_var_updater[variable_name]:
+                        #         disc_var_updater[variable_name].append((tmp_variable_name, iter_pos))
+
+            elif isinstance(node, ast.Name):
+                if node.id in targ_name_list:
+                    node:ast.Name
+                    # Find corresponding targ_name in the targ_name_list
+                    targ_name = node.id
+                    var_index = targ_name_list.index(targ_name)
+
+                    # Find the corresponding iter_name in the iter_name_list 
+                    iter_name = iter_name_list[var_index]
+
+                    # Create the name for the tmp variable 
+                    iter_pos = iter_pos_list[var_index]
+                    tmp_variable_name = f"{iter_name}_{iter_pos}"
+
+                    # Replace variables in the etl by using tmp variables
+                    root = ValueSubstituter(tmp_variable_name, node).visit(root)
+
+                    # Find the value of the tmp variable in the cont/disc_var_dict
+                    # Add the tmp variables into the cont/disc_var_dict
+                    variable_name = iter_name
+                    variable_val = None
+                    if variable_name in cont_var_dict:
+                        # variable_val = cont_var_dict[variable_name][iter_pos]
+                        # cont_var_dict[tmp_variable_name] = variable_val
+                        if variable_name not in cont_var_updater:
+                            cont_var_updater[variable_name] = [(tmp_variable_name, iter_pos)]
+                        else:
+                            if (tmp_variable_name, iter_pos) not in cont_var_updater[variable_name]:
+                                cont_var_updater.append(tmp_variable_name, iter_pos)
+                    elif variable_name in disc_var_dict:
+                        variable_val = disc_var_dict[variable_name][iter_pos]
+                        disc_var_dict[tmp_variable_name] = variable_val
+                        # if variable_name not in disc_var_updater:
+                        #     disc_var_updater[variable_name] = [(tmp_variable_name, iter_pos)]
+                        # else:
+                        #     if (tmp_variable_name, iter_pos) not in disc_var_updater[variable_name]:
+                        #         disc_var_updater[variable_name].append((tmp_variable_name, iter_pos))
+
+        # Return the modified node
+        return root
+
 if __name__ == "__main__":
     with open('tmp.pickle','rb') as f:
         guard_list = pickle.load(f)
diff --git a/dryvr_plus_plus/scene_verifier/scenario/scenario.py b/dryvr_plus_plus/scene_verifier/scenario/scenario.py
index 8d1d8e76..9fb105be 100644
--- a/dryvr_plus_plus/scene_verifier/scenario/scenario.py
+++ b/dryvr_plus_plus/scene_verifier/scenario/scenario.py
@@ -4,6 +4,7 @@ import itertools
 import warnings
 
 import numpy as np
+from sympy import Q
 
 from dryvr_plus_plus.scene_verifier.agents.base_agent import BaseAgent
 from dryvr_plus_plus.scene_verifier.automaton.guard import GuardExpressionAst
@@ -13,6 +14,7 @@ from dryvr_plus_plus.scene_verifier.analysis.simulator import Simulator
 from dryvr_plus_plus.scene_verifier.analysis.verifier import Verifier
 from dryvr_plus_plus.scene_verifier.map.lane_map import LaneMap
 from dryvr_plus_plus.scene_verifier.utils.utils import *
+from dryvr_plus_plus.scene_verifier.analysis.analysis_tree_node import AnalysisTreeNode
 
 class Scenario:
     def __init__(self):
@@ -90,7 +92,7 @@ class Scenario:
     def check_guard_hit(self, state_dict):
         lane_map = self.map 
         guard_hits = []
-        any_contained = False        # TODO: Handle this
+        any_contained = False        
         for agent_id in state_dict:
             agent:BaseAgent = self.agent_dict[agent_id]
             agent_state, agent_mode = state_dict[agent_id]
@@ -122,7 +124,6 @@ class Scenario:
                 if not guard_can_satisfied:
                     continue
 
-                # TODO: Handle hybrid guards that involves both continuous and discrete dynamics 
                 # Will have to limit the amount of hybrid guards that we want to handle. The difficulty will be handle function guards.
                 guard_can_satisfied = guard_expression.evaluate_guard_hybrid(agent, discrete_variable_dict, continuous_variable_dict, self.map)
                 if not guard_can_satisfied:
@@ -135,7 +136,7 @@ class Scenario:
                     guard_hits.append((agent_id, guard_list, reset_list))
         return guard_hits, any_contained
 
-    def get_all_transition_set(self, node):
+    def get_transition_verify(self, node):
         possible_transitions = []
         trace_length = int(len(list(node.trace.values())[0])/2)
         guard_hits = []
@@ -277,3 +278,235 @@ class Scenario:
                         satisfied_guard.append(next_transition)
 
         return satisfied_guard
+
+    def get_transition_simulate(self, node:AnalysisTreeNode) -> Tuple[List[Tuple[float]], int]:
+        trace_length = len(list(node.trace.values())[0])
+        transitions = []
+        for idx in range(trace_length):
+            # For each trace, check with the guard to see if there's any possible transition
+            # Store all possible transition in a list
+            # A transition is defined by (agent, src_mode, dest_mode, corresponding reset, transit idx)
+            # Here we enforce that only one agent transit at a time
+            all_agent_state = {}
+            for agent_id in node.agent:
+                all_agent_state[agent_id] = (node.trace[agent_id][idx], node.mode[agent_id])
+            possible_transitions = self.get_all_transition(all_agent_state)
+            if possible_transitions != []:
+                for agent_idx, src_mode, dest_mode, next_init in possible_transitions:
+                    transitions.append((agent_idx, src_mode, dest_mode, next_init, idx))
+                break
+        return transitions, idx
+             
+    def apply_cont_var_updater(self,cont_var_dict, updater):
+        for variable in updater:
+            for unrolled_variable, unrolled_variable_index in updater[variable]:
+                cont_var_dict[unrolled_variable] = cont_var_dict[variable][unrolled_variable_index]
+
+    # def apply_disc_var_updater(self,disc_var_dict, updater):
+    #     for variable in updater:
+    #         unrolled_variable, unrolled_variable_index = updater[variable]
+    #         disc_var_dict[unrolled_variable] = disc_var_dict[variable][unrolled_variable_index]
+
+    def get_transition_simulate_new(self, node:AnalysisTreeNode) -> List[Tuple[float]]:
+        lane_map = self.map
+        trace_length = len(list(node.trace.values())[0])
+
+        # For each agent
+        agent_guard_dict:Dict[str,List[GuardExpressionAst]] = {}
+
+        for agent_id in node.agent:
+            # Get guard
+            agent:BaseAgent = self.agent_dict[agent_id]
+            agent_mode = node.mode[agent_id]
+            paths = agent.controller.getNextModes(agent_mode)
+            state_dict = {}
+            for tmp in node.agent:
+                state_dict[tmp] = (node.trace[tmp][0], node.mode[tmp])
+            cont_var_dict_template, discrete_variable_dict, len_dict = self.sensor.sense(self, agent, state_dict, self.map)
+            for path in paths:
+                guard_list = []
+                reset_list = []
+                for item in path:
+                    if isinstance(item, Guard):
+                        guard_list.append(item)
+                    elif isinstance(item, Reset):
+                        reset_list.append(item)
+                guard_expression = GuardExpressionAst(guard_list)
+
+                continuous_variable_updater = guard_expression.parse_any_all_new(cont_var_dict_template, discrete_variable_dict, len_dict)
+                if agent_id not in agent_guard_dict:
+                    agent_guard_dict[agent_id] = [(guard_expression, continuous_variable_updater, copy.deepcopy(discrete_variable_dict), reset_list)]
+                else:
+                    agent_guard_dict[agent_id].append((guard_expression, continuous_variable_updater, copy.deepcopy(discrete_variable_dict), reset_list))
+
+        transitions = []
+        for idx in range(trace_length):
+            satisfied_guard = []
+            for agent_id in agent_guard_dict:
+                agent:BaseAgent = self.agent_dict[agent_id]
+                state_dict = {}
+                for tmp in node.agent:
+                    state_dict[tmp] = (node.trace[tmp][idx], node.mode[tmp])
+                agent_state, agent_mode = state_dict[agent_id]
+                agent_state = agent_state[1:]
+                continuous_variable_dict, _, _ = self.sensor.sense(self, agent, state_dict, self.map)
+                for guard_expression, continuous_variable_updater, discrete_variable_dict, reset_list in agent_guard_dict[agent_id]:
+                    new_cont_var_dict = copy.deepcopy(continuous_variable_dict)
+                    # new_disc_var_dict = copy.deepcopy(discrete_variable_dict)
+                    one_step_guard:GuardExpressionAst = copy.deepcopy(guard_expression)
+                    self.apply_cont_var_updater(new_cont_var_dict, continuous_variable_updater)
+                    # self.apply_disc_var_updater(new_disc_var_dict, discrete_variable_updater)
+                    guard_satisfied = one_step_guard.evaluate_guard(agent, new_cont_var_dict, discrete_variable_dict, self.map)
+                    if guard_satisfied:
+                        # If the guard can be satisfied, handle resets
+                        next_init = agent_state
+                        dest = copy.deepcopy(agent_mode)
+                        possible_dest = [[elem] for elem in dest]
+                        for reset in reset_list:
+                            # Specify the destination mode
+                            reset = reset.code
+                            if "mode" in reset:
+                                for i, discrete_variable_ego in enumerate(agent.controller.vars_dict['ego']['disc']):
+                                    if discrete_variable_ego in reset:
+                                        break
+                                tmp = reset.split('=')
+                                if 'map' in tmp[1]:
+                                    tmp = tmp[1]
+                                    for var in discrete_variable_dict:
+                                        tmp = tmp.replace(var, f"'{discrete_variable_dict[var]}'")
+                                    possible_dest[i] = eval(tmp)
+                                else:
+                                    tmp = tmp[1].split('.')
+                                    if tmp[0].strip(' ') in agent.controller.modes:
+                                        possible_dest[i] = [tmp[1]]                            
+                            else: 
+                                for i, cts_variable in enumerate(agent.controller.vars_dict['ego']['cont']):
+                                    if "output."+cts_variable in reset:
+                                        break 
+                                tmp = reset.split('=')
+                                tmp = tmp[1]
+                                for cts_variable in continuous_variable_dict:
+                                    tmp = tmp.replace(cts_variable, str(continuous_variable_dict[cts_variable]))
+                                next_init[i] = eval(tmp)
+                        all_dest = list(itertools.product(*possible_dest))
+                        if not all_dest:
+                            warnings.warn(f"Guard hit for mode {agent_mode} for agent {agent_id} without available next mode")
+                            all_dest.append(None)
+                        for dest in all_dest:
+                            next_transition = (
+                                agent_id, agent_mode, dest, next_init, 
+                            )
+                            satisfied_guard.append(next_transition)
+            if satisfied_guard != []:
+                for agent_idx, src_mode, dest_mode, next_init in satisfied_guard:
+                    transitions.append((agent_idx, src_mode, dest_mode, next_init, idx))    
+                break
+        return transitions, idx
+
+
+    def get_transition_verify_new(self, node:AnalysisTreeNode):
+        lane_map = self.map 
+        possible_transitions = []
+        
+        agent_guard_dict = {}
+        for agent_id in node.agent:
+            agent:BaseAgent = self.agent_dict[agent_id]
+            agent_mode = node.mode[agent_id]
+            state_dict = {}
+            for tmp in node.agent:
+                state_dict[tmp] = (node.trace[tmp][0*2:0*2+2], node.mode[tmp])
+            
+            continuous_variable_dict, discrete_variable_dict, length_dict = self.sensor.sense(self, agent, state_dict, self.map)
+            paths = agent.controller.getNextModes(agent_mode)
+            for path in paths:
+                # Construct the guard expression
+                guard_list = []
+                reset_list = []
+                for item in path:
+                    if isinstance(item, Guard):
+                        guard_list.append(item)
+                    elif isinstance(item, Reset):
+                        reset_list.append(item)
+                guard_expression = GuardExpressionAst(guard_list)
+                
+                cont_var_updater = guard_expression.parse_any_all_new(continuous_variable_dict, discrete_variable_dict, length_dict)
+
+                guard_can_satisfied = guard_expression.evaluate_guard_disc(agent, discrete_variable_dict, continuous_variable_dict, self.map)
+                if not guard_can_satisfied:
+                    continue
+                if agent_id not in agent_guard_dict:
+                    agent_guard_dict[agent_id] = [(guard_expression, cont_var_updater, copy.deepcopy(discrete_variable_dict), reset_list)]
+                else:
+                    agent_guard_dict[agent_id].append((guard_expression, cont_var_updater, copy.deepcopy(discrete_variable_dict), reset_list))
+
+        trace_length = int(len(list(node.trace.values())[0])/2)
+        guard_hits = []
+        guard_hit_bool = False
+        for idx in range(0,trace_length):
+            any_contained = False 
+            hits = []
+            state_dict = {}
+            for tmp in node.agent:
+                state_dict[tmp] = (node.trace[tmp][idx*2:idx*2+2], node.mode[tmp])
+            
+            for agent_id in agent_guard_dict:
+                agent:BaseAgent = self.agent_dict[agent_id]
+                agent_state, agent_mode = state_dict[agent_id]
+                agent_state = agent_state[1:]
+                continuous_variable_dict, _, _ = self.sensor.sense(self, agent, state_dict, self.map)
+                for guard_expression, continuous_variable_updater, discrete_variable_dict, reset_list in agent_guard_dict[agent_id]:
+                    new_cont_var_dict = copy.deepcopy(continuous_variable_dict)
+                    one_step_guard:GuardExpressionAst = copy.deepcopy(guard_expression)
+
+                    self.apply_cont_var_updater(new_cont_var_dict, continuous_variable_updater)
+                    guard_can_satisfied = one_step_guard.evaluate_guard_hybrid(agent, discrete_variable_dict, new_cont_var_dict, self.map)
+                    if not guard_can_satisfied:
+                        continue
+                    guard_satisfied, is_contained = one_step_guard.evaluate_guard_cont(agent, new_cont_var_dict, self.map)
+                    any_contained = any_contained or is_contained
+                    if guard_satisfied:
+                        hits.append((agent_id, guard_list, reset_list))
+            if hits != []:
+                guard_hits.append((hits, state_dict, idx))
+                guard_hit_bool = True 
+            if hits == [] and guard_hit_bool:
+                break 
+            if any_contained:
+                break
+
+        reset_dict = {}
+        reset_idx_dict = {}
+        for hits, all_agent_state, hit_idx in guard_hits:
+            for agent_id, guard_list, reset_list in hits:
+                dest_list,reset_rect = self.apply_reset(node.agent[agent_id], reset_list, all_agent_state)
+                if agent_id not in reset_dict:
+                    reset_dict[agent_id] = {}
+                    reset_idx_dict[agent_id] = {}
+                if not dest_list:
+                    warnings.warn(f"Guard hit for mode {node.mode[agent_id]} for agent {agent_id} without available next mode")
+                    dest_list.append(None)
+                for dest in dest_list:
+                    if dest not in reset_dict[agent_id]:
+                        reset_dict[agent_id][dest] = []
+                        reset_idx_dict[agent_id][dest] = []
+                    reset_dict[agent_id][dest].append(reset_rect)
+                    reset_idx_dict[agent_id][dest].append(hit_idx)
+            
+        # Combine reset rects and construct transitions
+        for agent in reset_dict:
+            for dest in reset_dict[agent]:
+                combined_rect = None 
+                for rect in reset_dict[agent][dest]:
+                    rect = np.array(rect)
+                    if combined_rect is None:
+                        combined_rect = rect 
+                    else:
+                        combined_rect[0,:] = np.minimum(combined_rect[0,:], rect[0,:])
+                        combined_rect[1,:] = np.maximum(combined_rect[1,:], rect[1,:])
+                combined_rect = combined_rect.tolist()
+                min_idx = min(reset_idx_dict[agent][dest])
+                max_idx = max(reset_idx_dict[agent][dest])
+                transition = (agent, node.mode[agent], dest, combined_rect, (min_idx, max_idx))
+                possible_transitions.append(transition)
+        # Return result
+        return possible_transitions
\ No newline at end of file
-- 
GitLab